import base64
from io import BytesIO

import clip
import faiss
import gradio as gr
import torch
from datasets import load_dataset

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
model_clip, preprocess = clip.load("ViT-B/32", device=device)

# Load your fine-tuned model weights
fine_tuned_state_dict = torch.load("best_model.pt", map_location=device)
model_clip.load_state_dict(fine_tuned_state_dict)
model_clip.eval()

# Load 50 samples from the Flickr30k test split
dataset = load_dataset("nlphuji/flickr30k", split="test[:50]")

# Precompute image embeddings
image_embeddings = []
images = []
captions = []
valid_indices = []

print("Extracting embeddings...")
for i, example in enumerate(dataset):
    try:
        # The datasets Image feature decodes to a PIL image directly,
        # so no HTTP request is needed here.
        img = example["image"].convert("RGB")
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            img_feat = model_clip.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
        image_embeddings.append(img_feat.cpu())
        images.append(img)
        captions.append(example["caption"][0])  # first of the reference captions
        valid_indices.append(i)
    except Exception as e:
        print(f"Skipping sample {i} due to error: {e}")
        continue

# Stack image features into a single (N, dim) matrix
image_embeddings = torch.cat(image_embeddings, dim=0)

# Build a FAISS inner-product index; with L2-normalized features, inner
# product equals cosine similarity. FAISS expects float32 input, and CLIP
# encodes in float16 on CUDA, so cast before adding.
image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
image_index.add(image_embeddings.float().numpy())

# Search function: embed the text query and retrieve the top-k images
def search_by_text(query, k=5):
    with torch.no_grad():
        tokens = clip.tokenize([query]).to(device)
        text_feat = model_clip.encode_text(tokens)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
    text_feat_np = text_feat.cpu().float().numpy()
    D, I = image_index.search(text_feat_np, k)
    return [(images[idx], captions[idx]) for idx in I[0]]

# Convert a PIL image to a base64 string for inline Markdown display
def image_to_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()

# Gradio handler: render the retrieved images and captions as Markdown
def display_results(text_query):
    results = search_by_text(text_query)
    output = ""
    for i, (img, caption) in enumerate(results):
        output += f"### Result {i + 1}\n"
        output += f"**Caption:** {caption}\n\n"
        output += f"![img](data:image/png;base64,{image_to_base64(img)})\n\n"
    return output

iface = gr.Interface(
    fn=display_results,
    inputs=gr.Textbox(lines=2, placeholder="Enter text to search..."),
    outputs="markdown",
    title="Text-to-Image Retrieval with CLIP",
    description="Enter a sentence to retrieve similar images using a fine-tuned CLIP model.",
)
iface.launch()
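
# Optional sanity check (a minimal sketch with an assumed sample query):
# uncomment and run *before* iface.launch(), since launch() blocks.
# for _img, _caption in search_by_text("a dog running on the beach"):
#     print(_caption)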