from typing import List, Dict, Optional, Tuple

import faiss
import angle_emb
import torch
import numpy as np
from tqdm import tqdm
from datasets import Dataset


class FlickrAngleSearch:
    def __init__(self, model_name: str = "WhereIsAI/UAE-Large-V1", device: str = "cuda:0"):
        """Initialize the search engine with the embedding model and an empty index."""
        self.model = angle_emb.AnglE(model_name, pooling_strategy='cls', device=device)
        self._index: Optional[faiss.IndexFlatIP] = None
        self.captions: Optional[List[str]] = None
        self.caption2image: Optional[Dict[str, int]] = None
        self.ds: Optional[Dataset] = None

    def index(self, dataset: Dataset) -> "FlickrAngleSearch":
        """Build the search index from a dataset."""
        self.ds = dataset
        # Extract unique captions and map each caption back to its dataset row;
        # every image carries several captions, so deduplicate as we go
        captions: List[str] = []
        caption2image: Dict[str, int] = {}
        for i, example in enumerate(tqdm(dataset)):
            for caption in example['caption']:
                if caption not in caption2image:
                    captions.append(caption)
                    caption2image[caption] = i
        self.captions = captions
        self.caption2image = caption2image
        # Encode captions
        print(f"Encoding {len(captions)} unique captions...")
        caption_embeddings = self.encode(captions)
        # Build a FAISS inner-product index over the caption embeddings
        dimension = caption_embeddings.shape[1]
        self._index = faiss.IndexFlatIP(dimension)
        self._index.add(caption_embeddings)
        return self
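
    # Note: IndexFlatIP ranks by raw inner product, which equals cosine
    # similarity only when the embeddings are L2-normalized. If cosine
    # scores are wanted, one optional tweak (an assumption, not part of
    # the original code) is to normalize rows in place before adding them,
    # and to do the same to the query embedding in search():
    #
    #     faiss.normalize_L2(caption_embeddings)  # in place, then self._index.add(...)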

    @classmethod
    def from_preindexed(cls, index_path: str, captions_path: str, caption2image_path: str, dataset: Dataset, device: str = "cpu") -> "FlickrAngleSearch":
        """Load a pre-built index and mappings from disk."""
        instance = cls(device=device)
        instance._index = faiss.read_index(index_path)
        instance.captions = torch.load(captions_path)
        instance.caption2image = torch.load(caption2image_path)
        instance.ds = dataset
        return instance

    def save_index(self, index_path: str, captions_path: str, caption2image_path: str) -> None:
        """Save the index and mappings to disk."""
        faiss.write_index(self._index, index_path)
        torch.save(self.captions, captions_path)
        torch.save(self.caption2image, caption2image_path)

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode a list of texts into a (len(texts), dim) float32 array."""
        embeddings: List[np.ndarray] = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
            batch = texts[i:i + batch_size]
            with torch.no_grad():
                embs = self.model.encode(batch, to_numpy=True, device=self.model.device)
            embeddings.extend(embs)
        # FAISS only accepts float32 input
        return np.stack(embeddings).astype(np.float32)

    def search(self, query: str, k: int = 5) -> List[Tuple[float, str, int]]:
        """
        Search for the top-k most relevant captions and their corresponding images.

        Args:
            query: Text query to search for
            k: Number of results to return

        Returns:
            List of (score, caption, image_index) tuples
        """
        # Encode the query text
        query_embedding = self.encode([query])
        # Search the index
        scores, indices = self._index.search(query_embedding, k)
        # Map each hit back to its caption and source image row
        results: List[Tuple[float, str, int]] = []
        for score, idx in zip(scores[0], indices[0]):
            caption = self.captions[idx]
            image_idx = self.caption2image[caption]
            results.append((float(score), caption, image_idx))
        return results
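

# A minimal sketch of the one-time indexing path, for reference; the demo
# below loads a pre-built index from the Hub instead. The save paths are
# illustrative assumptions, not files shipped with this Space:
#
#     from datasets import load_dataset
#     ds = load_dataset("WhereIsAI/flickr30k-v2", split="train")
#     engine = FlickrAngleSearch(device="cuda:0").index(ds)
#     engine.save_index("index.faiss", "captions.pt", "caption2image.pt")
#     hits = engine.search("a dog catching a frisbee", k=5)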


if __name__ == "__main__":
    import os
    import gradio as gr
    from datasets import load_dataset
    from huggingface_hub import snapshot_download

    # Fetch the pre-built FAISS index and caption mappings from the Hub
    local_dir = snapshot_download(repo_id='WhereIsAI/angle-flickr-index-v2')
    ds = load_dataset("WhereIsAI/flickr30k-v2", split='train')
    search = FlickrAngleSearch.from_preindexed(
        os.path.join(local_dir, 'index.faiss'),
        os.path.join(local_dir, 'captions.pt'),
        os.path.join(local_dir, 'caption2image.pt'),
        ds,
        device='cpu'
    )

    def search_and_display(query, num_results=5):
        # The slider delivers a number; cast to int for the FAISS top-k
        results = search.search(query, k=int(num_results))
        images = []
        captions = []
        similarities = []
        visited_images = set()
        for similarity, caption, image_idx in results:
            # Several captions can map to the same image; show each image once
            if image_idx not in visited_images:
                visited_images.add(image_idx)
                images.append(ds[image_idx]['image'])
                # gr.Dataframe expects rows, so wrap each value as a one-column row
                captions.append([caption])
                similarities.append([f"{similarity:.4f}"])
        return images, captions, similarities

    demo = gr.Interface(
        fn=search_and_display,
        inputs=[
            gr.Textbox(label="Search Query"),
            gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Results")
        ],
        outputs=[
            gr.Gallery(label="Top Results"),
            gr.Dataframe(headers=["Caption"], label="Captions"),
            gr.Dataframe(headers=["Similarity Score"], label="Similarity Scores")
        ],
        title="Flickr Image Search",
        description="Search through Flickr images using natural language queries"
    )
    demo.launch(share=True)