Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import time | |
| import streamlit as st | |
| from dotenv import load_dotenv | |
| from pinecone import Pinecone, ServerlessSpec | |
| from transformers import AutoProcessor, CLIPModel | |
| from PIL import Image | |
| import torch | |
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"


@st.cache_resource
def _load_clip(checkpoint=CLIP_CHECKPOINT):
    """Load the CLIP processor and model for *checkpoint* once per server process.

    Streamlit re-executes this script on every widget interaction; caching the
    loaded objects as a resource avoids re-reading the weights on each rerun.

    Returns:
        A ``(processor, model)`` tuple.
    """
    return (
        AutoProcessor.from_pretrained(checkpoint),
        CLIPModel.from_pretrained(checkpoint),
    )


# Module-level handles used by the embedding helpers below.
processor, model = _load_clip()
# Pull variables from a local .env file (if any) into the process environment.
load_dotenv()

# Connect to Pinecone. Fail with a readable message instead of an SDK
# traceback when the key is not configured.
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
if not pinecone_api_key:
    st.error("PINECONE_API_KEY is not set; cannot connect to Pinecone.")
    st.stop()
pc = Pinecone(api_key=pinecone_api_key)

# Create the index if it does not exist.
index_name = "index-search"
# Upper bound (seconds) on how long to wait for a freshly created index.
INDEX_READY_TIMEOUT_S = 120

if not pc.has_index(index_name):
    pc.create_index(
        name=index_name,
        metric="cosine",
        dimension=512,
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
    # Poll until the index reports ready, but never block the app forever.
    deadline = time.monotonic() + INDEX_READY_TIMEOUT_S
    while True:
        index = pc.describe_index(index_name)
        if index.status.get("ready", False):
            break
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Pinecone index {index_name!r} was not ready after "
                f"{INDEX_READY_TIMEOUT_S} seconds"
            )
        print("Waiting for index to be ready...")
        time.sleep(1)

# Handle used by the search code for all queries.
unsplash_index = pc.Index(index_name)
# Streamlit UI: page heading followed by a short intro line.
# NOTE(review): the leading "π" in the title looks like a mis-decoded emoji;
# confirm the intended character before changing it.
st.title("π CLIP-Powered Image Search")
st.markdown("Search images using **text** or **image**!")

# Search type selection, rendered as a horizontal radio group.
search_type = st.radio(
    label="Select Search Type",
    options=["Text Search", "Image Search"],
    horizontal=True,
)
def get_text_embedding(query):
    """Embed a text query with the CLIP text encoder.

    Args:
        query: The search text to embed.

    Returns:
        The text features flattened into a list of Python floats.
    """
    # CLIP's text encoder has a fixed context length; truncate so an over-long
    # query is clipped by the tokenizer instead of raising inside the model.
    inputs = processor(
        text=query,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    return text_features.flatten().tolist()
def get_image_embedding(image):
    """Embed a PIL image with the CLIP image encoder.

    Args:
        image: A ``PIL.Image.Image``; it is converted to RGB before encoding.

    Returns:
        The image features flattened into a list of Python floats.
    """
    # NOTE(review): this fixed 224x224 resize ignores aspect ratio and runs
    # before the processor's own preprocessing. It is kept as-is so query
    # embeddings stay comparable with what is already indexed -- confirm how
    # the indexed images were preprocessed before changing it.
    rgb_image = image.convert("RGB").resize((224, 224))
    inputs = processor(images=rgb_image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    return image_features.flatten().tolist()
SEARCH_NAMESPACE = "image-search-dataset"
RESULT_COUNT = 10


def _query_similar(embedding):
    """Return the index matches closest to *embedding*, with metadata."""
    response = unsplash_index.query(
        top_k=RESULT_COUNT,
        vector=embedding,
        namespace=SEARCH_NAMESPACE,
        include_metadata=True,
    )
    return response.matches


def _render_matches(matches):
    """Show *matches* in two columns, skipping entries without a URL."""
    if not matches:
        st.info("No matching images found.")
        return
    cols = st.columns(2)
    shown = 0
    for rank, match in enumerate(matches, start=1):
        # A match stored without a "url" cannot be displayed; skip it rather
        # than failing the whole page with a KeyError.
        url = (match.metadata or {}).get("url")
        if not url:
            continue
        with cols[shown % 2]:
            st.image(url, caption=f"Match {rank}")
        shown += 1


if search_type == "Text Search":
    # Strip so that a whitespace-only entry does not count as a query.
    search_query = st.text_input("Enter a search query (min 3 characters)").strip()
    if len(search_query) >= 3:
        with st.spinner("Searching images..."):
            matches = _query_similar(get_text_embedding(search_query))
        _render_matches(matches)

elif search_type == "Image Search":
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file)
        except OSError:
            # PIL raises an OSError subclass for files it cannot decode.
            image = None
            st.error("The uploaded file could not be read as an image.")
        if image is not None:
            st.image(image, caption="Uploaded Image")
            with st.spinner("Searching for similar images..."):
                matches = _query_similar(get_image_embedding(image))
            _render_matches(matches)