import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
import warnings

# Suppress future warnings (e.g. from torch.load / transformers deprecations)
warnings.simplefilter(action='ignore', category=FutureWarning)

# Single CLIP model used for text -> image retrieval, plus the file holding
# its precomputed image embeddings.
model_name = "openai/clip-vit-large-patch14-336"
embedding_file = 'openai_clip_vit_large_patch14_336_embeddings_16000.pt'

# Load precomputed image embeddings and their corresponding file paths.
print(f"Loading {model_name} embeddings from {embedding_file}...")
embeddings = {}
image_paths = {}
try:
    data = torch.load(embedding_file)
    embeddings = data['embeddings']
    image_paths = data['image_paths']
except Exception as e:
    # Best-effort: the app still starts; queries will then return empty slots.
    print(f"Failed to load {model_name} embeddings: {e}")

# Load the model and processor from Hugging Face Transformers
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Map local image file names to publicly servable URLs.
mapping_file = 'image_mapping.csv'
mapping_df = pd.read_csv(mapping_file)  # Columns: file_name, public_url
url_mapping = dict(zip(mapping_df['file_name'], mapping_df['public_url']))


def get_top_images(text_embedding, image_embeddings, image_paths, url_mapping, top_n=5):
    """Return public URLs of the images most similar to a text embedding.

    Args:
        text_embedding: (1, D) tensor produced by ``model.get_text_features``.
        image_embeddings: (N, D) tensor of precomputed image embeddings.
        image_paths: sequence mapping row index -> image file name.
        url_mapping: dict mapping image file name -> public URL.
        top_n: maximum number of URLs to return.

    Returns:
        List of up to ``top_n`` URLs, best match first; ``[]`` on any error.
    """
    try:
        similarities = torch.nn.functional.cosine_similarity(text_embedding, image_embeddings)
        # Never ask topk for more results than there are embeddings;
        # topk raises if k exceeds the tensor size.
        k = min(top_n, similarities.numel())
        top_k_scores, top_k_indices = similarities.topk(k)
        # Skip images with no known public URL rather than raising KeyError.
        return [url_mapping[image_paths[i]] for i in top_k_indices if image_paths[i] in url_mapping]
    except Exception as e:
        print(f"Error during similarity calculation: {e}")
        return []


def find_top_images(text_query):
    """Encode ``text_query`` with CLIP and return 5 matching image URLs.

    Always returns a list of exactly 5 items (URLs, padded with ``None`` on
    shortfall or error) so it can feed the 5 Gradio image outputs directly.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    all_results = []

    print(f"Generating text embedding for {model_name}...")
    try:
        model.to(device)
        embeddings_tensor = embeddings.to(device)

        # Tokenize and encode the text query. Truncate so queries longer than
        # CLIP's 77-token context window do not raise inside the tokenizer.
        inputs = processor(
            text=[text_query], return_tensors="pt", padding=True, truncation=True
        ).to(device)
        with torch.no_grad():
            text_embedding = model.get_text_features(**inputs)

        # Ensure dimension match between text and image embeddings.
        # BUG FIX: the original had a bare `raise` statement with the
        # ValueError(...) call on the following line as a separate expression,
        # so this path raised "RuntimeError: No active exception to re-raise"
        # and the intended message was never produced.
        if text_embedding.size(1) != embeddings_tensor.size(1):
            raise ValueError(f"Dimension mismatch: text embeddings have {text_embedding.size(1)} dimensions but image embeddings have {embeddings_tensor.size(1)} dimensions.")

        # Retrieve top-k similar images
        top_images = get_top_images(text_embedding, embeddings_tensor, image_paths, url_mapping)
        all_results.extend(top_images[:5])
    except Exception as e:
        print(f"An error occurred while processing {model_name}: {e}")
        all_results.extend([None] * 5)

    # Pad so the 5 Gradio image outputs always receive exactly 5 values.
    while len(all_results) < 5:
        all_results.append(None)

    return all_results


# Gradio interface: one text box in, five image slots out.
with gr.Blocks() as demo:
    gr.Markdown("Query Results using OpenAI CLIP (openai/clip-vit-large-patch14-336)")
    text_input = gr.Textbox(label="Enter your text query")

    image_elements = []
    with gr.Row():
        for i in range(5):
            image = gr.Image(label=f"Image {i+1}", width=200, height=200)
            image_elements.append(image)

    text_input.submit(
        fn=find_top_images,
        inputs=text_input,
        outputs=image_elements
    )

demo.launch()