velmurugan1122 commited on
Commit
07d2580
·
1 Parent(s): 0bba07a

added app file

Browse files
Files changed (1) hide show
  1. src/app.py +91 -0
src/app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import logging
4
+ import streamlit as st
5
+ import requests
6
+ import torch
7
+ from dotenv import load_dotenv
8
+ from pinecone import Pinecone, ServerlessSpec
9
+ from transformers import AutoTokenizer, CLIPModel, AutoProcessor
10
+ from PIL import Image
11
+
12
+ # Logging setup
13
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Load environment variables
17
+ load_dotenv()
18
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
19
+ # HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
20
+
21
+ # # Ensure Hugging Face authentication
22
+ # from huggingface_hub import login
23
+ # login(HF_ACCESS_TOKEN)
24
+
25
+ # Load CLIP model and processor
26
+ tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
27
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
28
+ processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
29
+
30
+ # Connect to Pinecone
31
+ pc = Pinecone(api_key=PINECONE_API_KEY)
32
+
33
+ # Ensure the index exists
34
+ index_name = "index-search"
35
+ if not pc.has_index(index_name):
36
+ pc.create_index(name=index_name, metric="cosine",
37
+ dimension=512,
38
+ spec=ServerlessSpec(cloud="aws", region="us-east-1"))
39
+ time.sleep(5) # Wait for index to initialize
40
+
41
+ unsplash_index = pc.Index(index_name)
42
+
43
+ # Streamlit UI
44
+ st.title("Search Images by Text or Image")
45
+
46
+ search_mode = st.radio("Choose search mode:", ["Text Search", "Image Search"])
47
+
48
+ if search_mode == "Text Search":
49
+ search_query = st.text_input("Search (at least 3 characters)")
50
+ if len(search_query) >= 3:
51
+ with st.spinner("Searching images..."):
52
+ inputs = tokenizer([search_query], padding=True, return_tensors="pt")
53
+ text_features = model.get_text_features(**inputs)
54
+ text_embedding = text_features.detach().numpy().flatten().tolist()
55
+
56
+ response = unsplash_index.query(
57
+ top_k=10,
58
+ vector=text_embedding,
59
+ namespace="image-search-dataset",
60
+ include_metadata=True
61
+ )
62
+
63
+ # Display results
64
+ cols = st.columns(2)
65
+ for i, result in enumerate(response.matches):
66
+ with cols[i % 2]:
67
+ st.image(result.metadata["url"], caption=f"Score: {result.score:.4f}")
68
+
69
+ elif search_mode == "Image Search":
70
+ uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
71
+ if uploaded_file:
72
+ image = Image.open(uploaded_file).convert("RGB")
73
+ st.image(image, caption="Uploaded Image", use_column_width=True)
74
+
75
+ with st.spinner("Searching similar images..."):
76
+ inputs = processor(images=image, return_tensors="pt")
77
+ image_features = model.get_image_features(**inputs)
78
+ image_embedding = image_features.detach().numpy().flatten().tolist()
79
+
80
+ response = unsplash_index.query(
81
+ top_k=10,
82
+ vector=image_embedding,
83
+ namespace="image-search-dataset",
84
+ include_metadata=True
85
+ )
86
+
87
+ # Display results
88
+ cols = st.columns(2)
89
+ for i, result in enumerate(response.matches):
90
+ with cols[i % 2]:
91
+ st.image(result.metadata["url"], caption=f"Score: {result.score:.4f}")