Spaces:
Sleeping
Sleeping
| import chromadb | |
| from chromadb.config import Settings | |
| import torchvision.models as models | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import logging | |
| import streamlit as st | |
| import requests | |
| import json | |
| import uuid | |
| import os | |
try:
    # Everything from here to the matching ``except`` at the end of the script
    # runs inside this guard. Logging is configured first so that later
    # steps can report progress through the module-level ``logger``.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
@st.cache_resource
def load_mobilenet_model():
    """Build the MobileNetV3-Small network and load its saved weights on CPU.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` keeps a single model instance alive across those
    reruns so the state dict is read from disk only once.

    Returns:
        The model in evaluation mode, placed on the CPU.
    """
    device = 'cpu'
    model = models.mobilenet_v3_small(pretrained=False)
    # Swap the last classifier layer for one with 768 outputs so the
    # architecture matches the checkpoint loaded below.
    model.classifier[3] = torch.nn.Linear(1024, 768)
    state_dict = torch.load(
        'mobilenet_v3_small_distilled_new_state_dict.pth', map_location=device)
    model.load_state_dict(state_dict)
    model.eval().to(device)
    return model
@st.cache_resource
def load_chromadb():
    """Open the on-disk ChromaDB store and return its ``images`` collection.

    Cached with ``st.cache_resource`` so that Streamlit reruns reuse one
    client/collection instead of reopening the persistent store in ``data``
    on every interaction.

    Returns:
        The existing collection named ``images``.
    """
    chroma_client = chromadb.PersistentClient(
        path='data', settings=Settings(anonymized_telemetry=False))
    return chroma_client.get_collection(name='images')
# Per-channel normalisation constants applied after ``ToTensor``
# (presumably the usual ImageNet statistics -- confirm against training).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Image encoder used for uploaded pictures.
model = load_mobilenet_model()
logger.info("MobileNet loaded")

# Vector store holding the image embeddings and their metadata.
collection = load_chromadb()
logger.info("ChromaDB loaded")
_item_count = collection.count()
logger.info(
    f"Connected to ChromaDB collection images with {_item_count} items")

# Resize to 224x224, convert to a tensor, then normalise each channel.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
def get_image_embedding(image):
    """Return the L2-normalised model embedding of an image as a list.

    Args:
        image: Anything ``PIL.Image.open`` accepts -- a path string or a
            file-like object such as an uploaded file.

    Returns:
        A flat list of floats: the model output for the image, normalised
        to unit L2 norm.
    """
    # ``Image.open`` handles paths and file objects alike, so no type
    # dispatch is needed. The context manager releases a file handle that
    # Pillow opened itself; a caller-supplied file object is left open.
    with Image.open(image) as source:
        img = source.convert('RGB')
    # Add a batch dimension of one before feeding the model.
    input_tensor = preprocess(img).unsqueeze(0).to('cpu')
    with torch.no_grad():
        student_embedding = model(input_tensor)
    normalised = torch.nn.functional.normalize(student_embedding, p=2, dim=1)
    return normalised.squeeze(0).tolist()
def save_image(image_file):
    """Write an uploaded image into the local ``images`` directory.

    Args:
        image_file: Uploaded-file-like object exposing a ``name`` attribute
            and a ``getbuffer()`` method returning the file's bytes.

    Returns:
        The relative path the file was written to (``images/<file name>``).

    Raises:
        ValueError: If ``image_file.name`` has no usable file-name component.
    """
    # The name comes from the uploader, so keep only its last path component;
    # otherwise a name like "../../app.py" would escape the images directory.
    filename = os.path.basename(image_file.name.replace("\\", "/"))
    if filename in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {image_file.name!r}")
    # The target directory may not exist yet on a fresh deployment.
    os.makedirs('images', exist_ok=True)
    save_path = os.path.join('images', filename)
    with open(save_path, "wb") as f:
        f.write(image_file.getbuffer())
    return save_path
def resize_image(image_path, size=(224, 224)):
    """Load an image, convert it to RGB and resize it with LANCZOS filtering.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts -- a path string or a
            file-like object such as an uploaded file.
        size: Target ``(width, height)`` in pixels.

    Returns:
        A new RGB ``PIL.Image.Image`` of the requested size.
    """
    # Paths and file objects are opened the same way, so no type check is
    # needed. The context manager releases a handle Pillow opened itself.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    return img.resize(size, Image.LANCZOS)  # High-quality resizing
# Sidebar: image upload controls plus the result-count slider used by search.
st.sidebar.header("Upload Images")
image_files = st.sidebar.file_uploader(
    "Upload images (Please do not upload personal data.)",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)
num_images = st.sidebar.slider(
    "Number of results to return", min_value=1, max_value=10, value=3)
if image_files:
    st.sidebar.subheader(
        "Add Images to collection")
    if st.sidebar.button("Add uploaded images"):
        for image_file in image_files:
            image_embedding = get_image_embedding(image_file)
            saved_path = save_image(image_file)
            # Record the path the file was actually written to, so the
            # metadata cannot drift from what save_image() did.
            metadata = {'path': saved_path, "type": "photo"}
            collection.add(
                embeddings=[image_embedding],
                ids=[str(uuid.uuid4())],
                metadatas=[metadata],
            )
            st.sidebar.success(
                f"Image {image_file.name} added to the collection")
# Main page: embed the query text through the remote API, then show the
# nearest images from the collection in a three-column grid.
st.title('Image Search Using Text')
st.write(
    "The images stored in this database are sourced from the [COCO 2017 Validation Dataset](https://cocodataset.org/#download).")
st.write('Enter the text to search for images with matching description')
text_input = st.text_input("Description", "Road")
if st.button("Search"):
    if not text_input.strip():
        st.write("Please enter text in the text area")
    else:
        try:
            # Without a timeout an unresponsive API would block this
            # script run indefinitely.
            response = requests.get(
                'https://ashish-001-clip-api.hf.space/embedding',
                params={'text': text_input},
                timeout=30,
            )
        except requests.RequestException as exc:
            logger.warning("Embedding API request failed: %s", exc)
            response = None

        if response is not None and response.status_code == 200:
            logger.info("Embedding returned by API successfully")
            embedding = response.json()['embedding']
            results = collection.query(
                query_embeddings=[embedding],
                n_results=num_images
            )
            # Only one query embedding is sent, so take the first result set.
            paths = [meta['path'] for meta in results['metadatas'][0]]
            hits = list(zip(paths, results['distances'][0]))
            if hits:
                cols_per_row = 3
                for row_start in range(0, len(hits), cols_per_row):
                    cols = st.columns(cols_per_row)
                    row_hits = hits[row_start:row_start + cols_per_row]
                    # zip() stops at the last hit, leaving trailing columns
                    # of a partially filled row empty.
                    for offset, (col, (path, distance)) in enumerate(
                            zip(cols, row_hits)):
                        resized_img = resize_image(path, size=(224, 224))
                        col.image(
                            resized_img,
                            caption=f"Image {row_start + offset + 1}\ndistance {distance}",
                            use_container_width=True,
                        )
            else:
                st.write("No image found")
        else:
            st.write("Please try again later")
            if response is not None:
                logger.info(f"status code {response.status_code} returned")
| except Exception as e: | |
| logger.info(f"Exception occured: {e}") | |