"""Streamlit app that searches an image collection by text or by image.

Queries are embedded with a CLIP model and matched against vectors stored
in a Pinecone index; the matching images are rendered in a two-column grid.
"""

import os
import sys

# The project's modules (``utils``, ``database``) live in ``<repo>/src``, two
# levels above this file. Streamlit re-executes this script on every user
# interaction, so guard the append to keep ``sys.path`` from growing.
src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
if src_directory not in sys.path:
    sys.path.append(src_directory)

import torch
from transformers import AutoProcessor, CLIPModel
import streamlit as st
from utils import logger
from database import pinecone_index
from PIL import Image

# NOTE: this rebinds the imported ``logger`` module name to the logger
# instance it returns (kept as in the original script).
logger = logger.get_logger()

CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
PINECONE_NAMESPACE = "image-search-dataset"


# ``show_spinner=False`` keeps these loaders from emitting any Streamlit
# element, so they are safe to call before ``st.set_page_config`` in ``main``.
@st.cache_resource(show_spinner=False)
def _load_clip():
    """Load the CLIP model and its processor once per server process."""
    clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
    clip_processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT)
    return clip_model, clip_processor


@st.cache_resource(show_spinner=False)
def _get_index():
    """Create the Pinecone index handle once per server process."""
    return pinecone_index.create_index()


model, processor = _load_clip()
PINECONE_INDEX = _get_index()


def _query_index(index, features, top_k):
    """Flatten an embedding tensor and run a similarity query against ``index``."""
    query_vector = features.detach().cpu().numpy().flatten().tolist()
    return index.query(
        vector=query_vector,
        top_k=top_k,
        include_metadata=True,
        namespace=PINECONE_NAMESPACE,
    )


def search_by_text(query_text, index, top_k=10):
    """Return the ``top_k`` index matches for a text query.

    Args:
        query_text: The text to embed with CLIP.
        index: Object exposing a Pinecone-style ``query`` method.
        top_k: Number of matches to request (defaults to 10).

    Returns:
        Whatever ``index.query`` returns for the embedded query.
    """
    # Truncate so that an over-long query cannot exceed the text encoder's
    # maximum sequence length.
    inputs = processor(text=query_text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only; no autograd graph needed
        text_features = model.get_text_features(**inputs)
    return _query_index(index, text_features, top_k)


def search_by_image(image, index, top_k=5):
    """Return the ``top_k`` index matches for an image query.

    Args:
        image: The image to embed with CLIP (a PIL image in ``main``).
        index: Object exposing a Pinecone-style ``query`` method.
        top_k: Number of matches to request (defaults to 5).

    Returns:
        Whatever ``index.query`` returns for the embedded query.
    """
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        image_features = model.get_image_features(**inputs)
    return _query_index(index, image_features, top_k)


def _render_matches(results):
    """Show the matched images in two columns, alternating left and right."""
    matches = results['matches']
    if not matches:
        st.info("No matching images found.")
        return

    columns = st.columns(2)
    for idx, match in enumerate(matches):
        metadata = match['metadata'] or {}
        url = metadata.get('url')
        if not url:
            # Nothing to display for a match stored without an image URL.
            continue
        with columns[idx % 2]:
            st.image(
                url,
                caption=f"Match: {metadata.get('photo_id', 'unknown')}",
                width=500
            )


def main():
    """Render the search page and dispatch to text or image search."""
    st.set_page_config(page_title="Clip Search", layout="wide")
    st.title("📸Image Search with Pinecone and CLIP")

    option = st.selectbox("Choose Input Type", ["Text", "Image Upload"])

    if option == "Text":
        user_text = st.text_input("Enter your search text", placeholder="for eg: dogs or cat etc..")
        if st.button("Search"):
            query = user_text.strip()
            if not query:
                st.warning("Please enter some text to search for.")
            else:
                _render_matches(search_by_text(query, PINECONE_INDEX))

    elif option == "Image Upload":
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if uploaded_file is not None:
            try:
                image = Image.open(uploaded_file)
                # ``Image.open`` is lazy; force decoding so a corrupt file
                # fails here rather than deep inside the model call.
                image.load()
            except OSError:
                st.error("The uploaded file could not be read as an image.")
                return

            st.image(image, caption="Uploaded Image")
            if st.button("Search by Image"):
                _render_matches(search_by_image(image, PINECONE_INDEX))


if __name__ == "__main__":
    main()