import base64
from io import BytesIO

import requests
import torch
import clip
import faiss
import gradio as gr
from datasets import load_dataset
from PIL import Image
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
model_clip, preprocess = clip.load("ViT-B/32", device=device)

# Load the fine-tuned model weights
fine_tuned_state_dict = torch.load("best_model.pt", map_location=device)
model_clip.load_state_dict(fine_tuned_state_dict)
model_clip.eval()
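# Note: on CUDA, clip.load returns a half-precision (fp16) model, so the
# embeddings computed below are cast to float32 before being handed to
# FAISS, which only accepts float32 input.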
# Load 50 samples from the Flickr30k test split
dataset = load_dataset("nlphuji/flickr30k", split="test[:50]")
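# nlphuji/flickr30k ships a single "test" split; the slice syntax above
# keeps only the first 50 rows so the demo starts quickly.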
# Precompute image embeddings
image_embeddings = []
images = []
captions = []
valid_indices = []

print("Extracting embeddings...")
for i, example in enumerate(dataset):
    try:
        # datasets decodes the "image" column to a PIL image; fall back to
        # fetching by URL in case the column holds a string instead
        img = example["image"]
        if isinstance(img, str):
            img = Image.open(requests.get(img, stream=True).raw)
        img = img.convert("RGB")
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            img_feat = model_clip.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
        # Cast to float32 on the CPU so FAISS can consume the vectors
        image_embeddings.append(img_feat.cpu().float())
        images.append(img)
        # nlphuji/flickr30k stores its five captions as a list under
        # "caption"; keep the first and retain "sentence" as a fallback key
        caption = example.get("caption", example.get("sentence"))
        captions.append(caption[0] if isinstance(caption, list) else caption)
        valid_indices.append(i)
    except Exception as e:
        print(f"Skipping sample {i} due to error: {e}")
        continue
# Stack image features into a single (N, 512) float32 tensor
image_embeddings = torch.cat(image_embeddings, dim=0)

# Build a FAISS index for exact inner-product search
image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
image_index.add(image_embeddings.numpy())
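# Because every vector is L2-normalized, the inner product computed by
# IndexFlatIP is exactly the cosine similarity CLIP was trained with, so
# the index returns the k most semantically similar images.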
# Search function: embed the text query and retrieve the top-k images
def search_by_text(query, k=5):
    with torch.no_grad():
        tokens = clip.tokenize([query]).to(device)
        text_feat = model_clip.encode_text(tokens)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
    text_feat_np = text_feat.cpu().float().numpy()
    D, I = image_index.search(text_feat_np, k)
    return [(images[idx], captions[idx]) for idx in I[0]]
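# Example usage outside the Gradio UI:
#   results = search_by_text("a dog catching a frisbee")
#   for img, caption in results:
#       print(caption)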
# Gradio callback: render the retrieved images and captions as markdown
def display_results(text_query):
    results = search_by_text(text_query)
    output = ""
    for i, (img, caption) in enumerate(results):
        output += f"### Result {i+1}\n"
        output += f"**Caption:** {caption}\n\n"
        # Inline the image as a base64 data URI so it renders in markdown
        output += f"![Result {i+1}](data:image/png;base64,{image_to_base64(img)})\n\n"
    return output
# Convert a PIL image to a base64-encoded PNG string
def image_to_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()
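# Optional: thumbnail images before encoding to keep the markdown payload
# small (a full-resolution PNG data URI can run to several megabytes), e.g.
#   img = img.copy()
#   img.thumbnail((512, 512))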
# Gradio interface
iface = gr.Interface(
    fn=display_results,
    inputs=gr.Textbox(lines=2, placeholder="Enter text to search..."),
    outputs="markdown",
    title="Text-to-Image Retrieval with CLIP",
    description="Enter a sentence to retrieve similar images using a fine-tuned CLIP model.",
)

iface.launch()
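# launch() starts a local server; on Hugging Face Spaces the app is served
# automatically when this file is run.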