|
|
import streamlit as st |
|
|
from pinecone import Pinecone |
|
|
import os |
|
|
from PIL import Image |
|
|
import requests |
|
|
from transformers import AutoProcessor, CLIPModel |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
# Page-level settings. Keep this as the first Streamlit call in the script
# (older Streamlit releases reject set_page_config after other st.* calls).
st.set_page_config(
    page_title="Image Search App",
    layout="wide",
    initial_sidebar_state="expanded",
)
|
|
|
|
|
|
|
|
# Read the Pinecone API key from the environment instead of hard-coding it.
# A key committed to source is readable by anyone with repository access; the
# key that was previously embedded on this line should be treated as leaked
# and rotated in the Pinecone console.
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
if not pinecone_api_key:
    st.error("PINECONE_API_KEY environment variable is not set.")
    # Halt this script run: nothing below can work without an index handle.
    st.stop()

pc = Pinecone(api_key=pinecone_api_key)

# Name of the Pinecone index that holds the image embeddings.
index_name = "unsplash-index"

# Index handle used by search_similar_images() to run vector queries.
unsplash_index = pc.Index(index_name)
|
|
|
|
|
|
|
|
@st.cache_resource
def load_clip_model(model_name="openai/clip-vit-base-patch32"):
    """Load a CLIP model and its matching processor, cached by Streamlit.

    ``st.cache_resource`` keeps the returned objects across script reruns
    (keyed on ``model_name``), so the weights are loaded only once rather than
    on every widget interaction.

    Args:
        model_name: Hugging Face checkpoint id used for both the model and the
            processor, so the two can never drift apart. NOTE(review): the
            Pinecone index was presumably populated with embeddings from the
            default checkpoint; a different checkpoint must produce vectors of
            the same dimension and embedding space for queries to be meaningful.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = CLIPModel.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)
    return model, processor
|
|
|
|
|
# Load (or fetch from Streamlit's resource cache) the CLIP model + processor.
model, processor = load_clip_model()

# Sidebar: result count plus the text- and image-search inputs. Everything in
# this `with` block is rendered into the sidebar container.
with st.sidebar:
    st.title("π Search Options")

    # How many nearest neighbours to request from Pinecone (min 1, max 20, default 10).
    top_k = st.slider("π’ Number of Similar Images", 1, 20, 10)

    # Text query controls.
    st.subheader("π Search by Text")
    search_query = st.text_input("Enter a description (e.g., 'a cute cat', 'a red car')")
    text_search_btn = st.button("π Search by Text")

    # Image query controls.
    st.subheader("πΌοΈ Search by Image")
    uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "png", "jpeg"])
    image_search_btn = st.button("π Search by Image")
|
|
|
|
|
|
|
|
def get_text_embedding(text):
    """Encode ``text`` with CLIP's text encoder and return it as a list of floats.

    The string is wrapped in a one-element batch, padded, and truncated to the
    tokenizer's maximum length. Inference runs without gradient tracking, and
    the single output vector is flattened into a plain Python list, suitable
    for a Pinecone query.
    """
    token_batch = processor(text=[text], return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        text_vectors = model.get_text_features(**token_batch)

    host_array = text_vectors.detach().cpu().numpy()
    return host_array.ravel().tolist()
|
|
|
|
|
|
|
|
def get_image_embedding(image):
    """Encode ``image`` with CLIP's vision encoder and return it as a list of floats.

    Args:
        image: the image to embed. The upload flow in this app passes a PIL
            image converted to RGB.

    Returns:
        The model's image features, flattened into a plain Python list,
        suitable for a Pinecone query.
    """
    pixel_batch = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        vision_vectors = model.get_image_features(**pixel_batch)

    host_array = vision_vectors.detach().cpu().numpy()
    return host_array.ravel().tolist()
|
|
|
|
|
|
|
|
def search_similar_images(embedding, top_k=10, namespace="image-search-dataset"):
    """Query the Pinecone index for the vectors nearest to ``embedding``.

    Args:
        embedding: query vector as a list of floats. It must have the same
            dimension as the vectors stored in the index.
        top_k: maximum number of matches to return.
        namespace: Pinecone namespace to search. Defaults to the namespace this
            app has always queried.

    Returns:
        A list of matches (possibly empty), each including its metadata,
        because ``include_metadata=True`` is requested.
    """
    response = unsplash_index.query(
        vector=embedding,
        top_k=top_k,
        include_metadata=True,
        namespace=namespace,
    )
    # Normalise a missing *or null* ``matches`` field to an empty list so
    # callers can iterate the result unconditionally.
    return response.get("matches") or []
|
|
|
|
|
|
|
|
st.title("π Image & Text Search with CLIP & Pinecone")


def _render_matches(matches):
    """Show Pinecone matches in a three-column grid, or a warning when empty.

    Shared by the text and image search flows so both present results, and the
    "no results" case, identically.

    Args:
        matches: sequence of Pinecone matches exposing ``.get`` for ``"id"``,
            ``"score"`` and ``"metadata"``.
    """
    st.subheader("π Top Similar Images")

    if not matches:
        st.warning("β οΈ No similar images found!")
        return

    cols = st.columns(3)
    for i, match in enumerate(matches):
        # Displayed as 1 - score. NOTE(review): this is a cosine distance only
        # if the index uses the cosine metric; confirm the index configuration.
        cosine_distance = 1 - match.get("score", 0)
        photo_id = match.get("id", "Unknown ID")
        # ``metadata`` may be absent or explicitly None; guard before .get().
        url = (match.get("metadata") or {}).get("url")

        # Fill the grid row by row: match i goes into column i mod 3.
        with cols[i % 3]:
            st.write(f"π· **Photo ID**: {photo_id} | π **Cosine Distance**: {cosine_distance:.4f}")
            if url:
                st.image(url, caption=f"Photo ID: {photo_id}", use_container_width=True)
            else:
                st.warning(f"β οΈ Image URL not found for Photo ID: {photo_id}")


def _search_and_render(embed_fn, payload, k):
    """Embed ``payload`` with ``embed_fn``, fetch the ``k`` nearest images, and render them."""
    with st.spinner("Generating embedding..."):
        embedding = embed_fn(payload)

    with st.spinner("Searching for similar images..."):
        matches = search_similar_images(embedding, top_k=k)

    _render_matches(matches)


# Text search: runs when the text button is clicked. An empty or whitespace-only
# query gets feedback instead of being silently ignored.
if text_search_btn:
    query = search_query.strip()
    if query:
        _search_and_render(get_text_embedding, query, top_k)
    else:
        st.info("Enter a description in the sidebar before searching by text.")

# Image search: runs when the image button is clicked. Files that cannot be
# decoded are reported instead of crashing the script run.
if image_search_btn:
    if uploaded_file is None:
        st.info("Upload an image in the sidebar before searching by image.")
    else:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # undecodable files, and OSError for truncated/corrupt image data.
            st.error("Could not read the uploaded file as an image.")
        else:
            st.image(image, caption="Uploaded Image", use_container_width=True)
            _search_and_render(get_image_embedding, image, top_k)
|
|
|
|
|
|