Spaces:
Build error
Build error
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
import warnings

# Suppress future warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Specify the single model to use
model_name = "openai/clip-vit-large-patch14-336"
embedding_file = 'openai_clip_vit_large_patch14_336_embeddings_16000.pt'

# Load the precomputed image embeddings (tensor + parallel list of file names).
print(f"Loading {model_name} embeddings from {embedding_file}...")
embeddings = {}
image_paths = {}
try:
    # map_location="cpu" so embeddings saved on a GPU machine still load on a
    # CPU-only host (e.g. a Spaces container without CUDA) — without it,
    # torch.load raises when deserializing CUDA tensors.
    data = torch.load(embedding_file, map_location="cpu")
    embeddings = data['embeddings']
    image_paths = data['image_paths']
except Exception as e:
    # Best effort: keep the app importable even if the embedding file is missing.
    print(f"Failed to load {model_name} embeddings: {e}")

# Load the model and processor from Hugging Face Transformers
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Load the mapping CSV file for the URLs
mapping_file = 'image_mapping.csv'
mapping_df = pd.read_csv(mapping_file)  # Columns: file_name, public_url
url_mapping = dict(zip(mapping_df['file_name'], mapping_df['public_url']))
def get_top_images(text_embedding, image_embeddings, image_paths, url_mapping, top_n=5):
    """Return public URLs of the images most similar to the text embedding.

    Args:
        text_embedding: (1, D) text feature tensor.
        image_embeddings: (N, D) image feature tensor.
        image_paths: sequence of N file names, parallel to image_embeddings.
        url_mapping: dict mapping file name -> public URL.
        top_n: maximum number of results to return.

    Returns:
        List of up to top_n URLs, best match first. Images whose file name
        has no entry in url_mapping are skipped. Returns [] on any error.
    """
    try:
        similarities = torch.nn.functional.cosine_similarity(text_embedding, image_embeddings)
        # Never ask topk for more results than there are images — it raises otherwise.
        k = min(top_n, similarities.numel())
        _, top_k_indices = similarities.topk(k)
        return [url_mapping[image_paths[i]] for i in top_k_indices if image_paths[i] in url_mapping]
    except Exception as e:
        print(f"Error during similarity calculation: {e}")
        return []
def find_top_images(text_query):
    """Embed text_query with CLIP and return exactly 5 image URLs.

    Args:
        text_query: free-text search string from the Gradio textbox.

    Returns:
        A list of exactly 5 elements (public URLs, padded with None), one per
        Gradio Image output. On any failure the list is 5 Nones.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    all_results = []
    print(f"Generating text embedding for {model_name}...")
    try:
        model.to(device)
        embeddings_tensor = embeddings.to(device)
        # Tokenize and encode text
        inputs = processor(text=[text_query], return_tensors="pt").to(device)
        with torch.no_grad():
            text_embedding = model.get_text_features(**inputs)
        # Ensure dimension match between text and image embeddings
        if text_embedding.size(1) != embeddings_tensor.size(1):
            raise ValueError(f"Dimension mismatch: text embeddings have {text_embedding.size(1)} dimensions but image embeddings have {embeddings_tensor.size(1)} dimensions.")
        # Retrieve top-k similar images
        top_images = get_top_images(text_embedding, embeddings_tensor, image_paths, url_mapping)
        all_results.extend(top_images[:5])
    except Exception as e:
        print(f"An error occurred while processing {model_name}: {e}")
    # Pad with None AND truncate so the length is always exactly 5 — Gradio
    # requires one value per declared output component.
    return (all_results + [None] * 5)[:5]
# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("Query Results using OpenAI CLIP (openai/clip-vit-large-patch14-336)")
    text_input = gr.Textbox(label="Enter your text query")
    # One row holding the five result thumbnails.
    with gr.Row():
        image_elements = [
            gr.Image(label=f"Image {idx + 1}", width=200, height=200)
            for idx in range(5)
        ]
    # Pressing Enter in the textbox runs the search and fills the five images.
    text_input.submit(
        fn=find_top_images,
        inputs=text_input,
        outputs=image_elements,
    )
demo.launch()