File size: 4,064 Bytes
8c65df5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import streamlit as st
from pinecone import Pinecone
from dotenv import load_dotenv
import os
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPModel
import numpy as np

# Load environment variables from a local .env file, if one exists.
# PINECONE_API_KEY is read from the environment below.
load_dotenv()

index_name = "image-index-50000"

api_key = os.environ.get("PINECONE_API_KEY")
if not api_key:
    # Stop with an actionable message up front rather than failing later
    # inside the Pinecone client with a less obvious error.
    st.error("PINECONE_API_KEY is not set. Add it to your environment or a .env file.")
    st.stop()


@st.cache_resource
def _connect_pinecone(key, name):
    """Create the Pinecone client and index handle once per server process.

    Streamlit re-executes this whole script on every widget interaction.
    Caching this function (as is done for the CLIP model) lets those reruns
    reuse the same client and index handle instead of rebuilding them each time.

    Args:
        key: Pinecone API key.
        name: Name of the index to open.

    Returns:
        tuple: ``(client, index)``.
    """
    client = Pinecone(api_key=key)
    return client, client.Index(name)


pc, unsplash_index = _connect_pinecone(api_key, index_name)

# Load CLIP model and processor (once per server process, via Streamlit's cache)
@st.cache_resource
def load_clip_model():
    """Load the CLIP ViT-B/32 model and its matching processor.

    Decorated with ``st.cache_resource``, so Streamlit reruns reuse the
    already-loaded objects instead of reloading the weights.

    Returns:
        tuple: ``(model, processor)`` for ``openai/clip-vit-base-patch32``.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = AutoProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor


model, processor = load_clip_model()

# Function to generate embedding from text
def get_text_embedding(text):
    """Encode *text* with CLIP's text encoder and return the feature vector.

    Args:
        text: The search query string. ``truncation=True`` clips inputs that
            exceed the tokenizer's maximum sequence length.

    Returns:
        list[float]: The flattened output of ``model.get_text_features``.
        This function applies no normalisation; it is passed as-is to the
        Pinecone query.
    """
    import torch  # already required: the processor is asked for "pt" (PyTorch) tensors

    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    # Inference only: disable autograd so no computation graph or saved
    # activations are built for a backward pass that never happens.
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    return text_features.cpu().numpy().flatten().tolist()

# Function to generate embedding from image
def get_image_embedding(image):
    """Encode *image* with CLIP's vision encoder and return the feature vector.

    Args:
        image: A PIL image. The caller in this file converts uploads to
            RGB before calling.

    Returns:
        list[float]: The flattened output of ``model.get_image_features``.
        This function applies no normalisation; it is passed as-is to the
        Pinecone query.
    """
    import torch  # already required: the processor is asked for "pt" (PyTorch) tensors

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: disable autograd so no computation graph or saved
    # activations are built for a backward pass that never happens.
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    return image_features.cpu().numpy().flatten().tolist()

# Function to query Pinecone and fetch similar images
def search_similar_images(embedding, top_k=10):
    """Return the Pinecone matches nearest to *embedding*.

    Args:
        embedding: Query vector, a list of floats as produced by
            ``get_text_embedding`` / ``get_image_embedding``.
        top_k: Maximum number of matches to request.

    Returns:
        The ``"matches"`` entry of the query response from the
        ``"image-search-dataset"`` namespace, with metadata included.
    """
    query_kwargs = dict(
        vector=embedding,
        top_k=top_k,
        include_metadata=True,
        namespace="image-search-dataset",
    )
    response = unsplash_index.query(**query_kwargs)
    return response["matches"]

# Seconds to wait for a result image download before giving up.
IMAGE_FETCH_TIMEOUT = 10


def _fetch_image(url):
    """Download the image at *url* and return it as a fully loaded PIL image.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status.
        OSError: if the body cannot be decoded as an image
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    # The timeout keeps one unresponsive host from hanging the whole page.
    # The ``with`` block closes the streamed response (releasing its
    # connection) even when an exception is raised.
    with requests.get(url, stream=True, timeout=IMAGE_FETCH_TIMEOUT) as response:
        # Surface HTTP errors directly. Otherwise an error page would fail
        # later as an unhelpful "cannot identify image file".
        response.raise_for_status()
        response.raw.decode_content = True
        img = Image.open(response.raw)
        # Image.open is lazy; decode now, before the stream is closed.
        img.load()
    return img


def _show_results(matches):
    """Render each Pinecone match as an ID/score line followed by its image."""
    st.subheader("Top Similar Images")
    if not matches:
        st.info("No similar images found.")
        return
    for match in matches:
        score = match["score"]
        photo_id = match["id"]
        url = match["metadata"]["url"]
        st.write(f"**Photo ID**: {photo_id} | **Similarity Score**: {score:.4f}")
        try:
            img = _fetch_image(url)
        except Exception as e:  # UI boundary: report this image, keep rendering the rest
            st.error(f"Could not load image from {url}: {e}")
        else:
            st.image(img, caption=f"Photo ID: {photo_id}", use_container_width=True)


# Streamlit UI
st.title("🔍 Image Search App")

# Sidebar for search controls
with st.sidebar:
    st.header("Search Options")

    # Search type selection
    search_type = st.radio(
        "Select search type:",
        ("Text to Image", "Image to Image")
    )

    # Input based on search type
    if search_type == "Text to Image":
        # Strip so a whitespace-only query is treated as empty.
        search_query = st.text_input("Enter your search query (e.g. Flower)").strip()
        uploaded_file = None
    else:  # Image to Image
        uploaded_file = st.file_uploader("Upload an image to search",
                                         type=["jpg", "jpeg", "png"])
        search_query = None

    # Search button
    search_button = st.button("Search")

# Main content area for results
if search_button:
    has_input = (search_type == "Text to Image" and search_query) or (
        search_type == "Image to Image" and uploaded_file
    )
    if has_input:
        # Stays None if the uploaded file cannot be decoded.
        embedding = None
        with st.spinner("Generating embedding..."):
            if search_type == "Text to Image":
                embedding = get_text_embedding(search_query)
            else:  # Image to Image
                try:
                    image = Image.open(uploaded_file).convert("RGB")
                except OSError as e:  # includes PIL.UnidentifiedImageError
                    st.error(f"Could not read the uploaded image: {e}")
                else:
                    embedding = get_image_embedding(image)
                    # Display the uploaded image
                    st.image(image, caption="Uploaded Image", use_container_width=True)

        if embedding is not None:
            # Search for similar images
            with st.spinner("Searching for similar images..."):
                matches = search_similar_images(embedding, top_k=10)
            _show_results(matches)
    else:
        st.warning("Please provide a search query or upload an image!")

# Instructions
st.write("---")
st.write("Note: This app searches an Unsplash dataset indexed in Pinecone using CLIP embeddings based on your input.")