# Last modified by molehh (commit fe1a29a).
import os
import sys
src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
sys.path.append(src_directory)
from transformers import AutoProcessor, CLIPModel
import streamlit as st
from utils import logger
from database import pinecone_index
from PIL import Image
# NOTE: rebinds the imported `utils.logger` module name to the object it returns.
logger = logger.get_logger()

# Hugging Face model id shared by the CLIP model and its processor.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"


@st.cache_resource(show_spinner=False)
def _load_clip(model_name=CLIP_MODEL_NAME):
    """Load the CLIP model and its processor once and reuse them across reruns.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the model weights would be reloaded on each rerun.
    ``show_spinner=False`` keeps the cache from emitting a UI element before
    ``st.set_page_config`` runs in ``main``.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        A ``(model, processor)`` tuple.
    """
    return (
        CLIPModel.from_pretrained(model_name),
        AutoProcessor.from_pretrained(model_name),
    )


@st.cache_resource(show_spinner=False)
def _load_index():
    """Call ``pinecone_index.create_index()`` once and reuse its result across reruns."""
    return pinecone_index.create_index()


model, processor = _load_clip()
PINECONE_INDEX = _load_index()
def search_by_text(query_text, index, top_k=10, namespace="image-search-dataset"):
    """Find indexed images that match a free-text query using CLIP text embeddings.

    Args:
        query_text: The search text entered by the user.
        index: Object exposing ``query(vector=..., top_k=..., include_metadata=...,
            namespace=...)``; in this app it is the module-level ``PINECONE_INDEX``.
        top_k: Maximum number of matches to return (default 10, as before).
        namespace: Index namespace to search.

    Returns:
        Whatever ``index.query`` returns; callers read ``results['matches']``.
    """
    import torch  # local import; transformers' CLIPModel already requires torch

    # truncation=True clips queries longer than the tokenizer's maximum length
    # instead of failing inside the model.
    inputs = processor(text=query_text, return_tensors="pt", truncation=True)
    # Inference only: skip autograd bookkeeping rather than detaching afterwards.
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    # flatten() turns the single-query batch into one plain vector for the query.
    query_vector = text_features.cpu().numpy().flatten().tolist()
    return index.query(
        vector=query_vector,
        top_k=top_k,
        include_metadata=True,
        namespace=namespace,
    )
def search_by_image(image, index, top_k=5, namespace="image-search-dataset"):
    """Find indexed images visually similar to ``image`` using CLIP image embeddings.

    Args:
        image: Image handed to the CLIP processor; in this app it is a
            ``PIL.Image`` opened from an uploaded file.
        index: Object exposing ``query(vector=..., top_k=..., include_metadata=...,
            namespace=...)``; in this app it is the module-level ``PINECONE_INDEX``.
        top_k: Maximum number of matches to return (default 5, as before).
        namespace: Index namespace to search.

    Returns:
        Whatever ``index.query`` returns; callers read ``results['matches']``.
    """
    import torch  # local import; transformers' CLIPModel already requires torch

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping rather than detaching afterwards.
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    # flatten() turns the single-image batch into one plain vector for the query.
    query_vector = image_features.cpu().numpy().flatten().tolist()
    return index.query(
        vector=query_vector,
        top_k=top_k,
        include_metadata=True,
        namespace=namespace,
    )
def _render_matches(results, num_columns=2, width=500):
    """Render query matches as a grid of captioned images.

    Args:
        results: Response from ``index.query``; read as ``results['matches']``,
            where each match carries ``metadata['url']`` and ``metadata['photo_id']``.
        num_columns: Number of grid columns the matches are distributed across.
        width: Display width of each image, in pixels.
    """
    matches = results['matches']
    if not matches:
        st.info("No matching images found.")
        return
    columns = st.columns(num_columns)
    for idx, match in enumerate(matches):
        metadata = match['metadata']
        # Fill the grid row by row: match 0 -> column 0, match 1 -> column 1, ...
        with columns[idx % num_columns]:
            st.image(
                metadata['url'],
                caption=f"Match: {metadata['photo_id']}",
                width=width,
            )


def main():
    """Build the Streamlit page: choose text or image input, search, show matches."""
    # set_page_config is the first Streamlit call made in main().
    st.set_page_config(page_title="Clip Search", layout="wide")
    st.title("📸Image Search with Pinecone and CLIP")
    option = st.selectbox("Choose Input Type", ["Text", "Image Upload"])
    if option == "Text":
        user_text = st.text_input("Enter your search text", placeholder="for eg: dogs or cat etc..")
        if st.button("Search"):
            # Don't spend a model call and an index query on an empty query.
            if not user_text.strip():
                st.warning("Please enter some text to search for.")
            else:
                _render_matches(search_by_text(user_text, PINECONE_INDEX))
    elif option == "Image Upload":
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image")
            if st.button("Search by Image"):
                _render_matches(search_by_image(image, PINECONE_INDEX))
# Entry point: render the app when this file is executed as the main script
# (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()