import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
import warnings
# Suppress future warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
# Specify the single model to use
model_name = "openai/clip-vit-large-patch14-336"
embedding_file = 'openai_clip_vit_large_patch14_336_embeddings_16000.pt'
# Load the precomputed image embeddings and their file names
print(f"Loading {model_name} embeddings from {embedding_file}...")
embeddings = {}
image_paths = {}
try:
    # Load on CPU; the tensor is moved to the active device per query
    data = torch.load(embedding_file, map_location="cpu")
    embeddings = data['embeddings']
    image_paths = data['image_paths']
except Exception as e:
    print(f"Failed to load {model_name} embeddings: {e}")
# Load the model and processor from Hugging Face Transformers
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
# Load the mapping CSV file for the URLs
mapping_file = 'image_mapping.csv'
mapping_df = pd.read_csv(mapping_file) # Columns: file_name, public_url
url_mapping = dict(zip(mapping_df['file_name'], mapping_df['public_url']))
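# Illustrative rows for image_mapping.csv (made-up data, shown only to document the format):
#   file_name,public_url
#   img_00001.jpg,https://storage.example.com/images/img_00001.jpg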
def get_top_images(text_embedding, image_embeddings, image_paths, url_mapping, top_n=5):
    try:
        # Cosine similarity between the (1, D) text embedding and each row of the (N, D) image matrix
        similarities = torch.nn.functional.cosine_similarity(text_embedding, image_embeddings)
        _, top_k_indices = similarities.topk(top_n)
        # Map the best-matching file names to public URLs; files without a mapping are skipped
        return [url_mapping[image_paths[i]] for i in top_k_indices.tolist() if image_paths[i] in url_mapping]
    except Exception as e:
        print(f"Error during similarity calculation: {e}")
        return []
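# Usage note: given a (1, D) text embedding and an (N, D) image-embedding matrix,
# get_top_images returns up to `top_n` public URLs ranked by cosine similarity;
# the list may be shorter than `top_n` when some files lack a URL mapping.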
def find_top_images(text_query):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    all_results = []
    print(f"Generating text embedding for {model_name}...")
    try:
        model.to(device)
        embeddings_tensor = embeddings.to(device)
        # Tokenize and encode the text query
        inputs = processor(text=[text_query], return_tensors="pt").to(device)
        with torch.no_grad():
            text_embedding = model.get_text_features(**inputs)
        # Ensure the text and image embeddings have matching dimensions
        if text_embedding.size(1) != embeddings_tensor.size(1):
            raise ValueError(
                f"Dimension mismatch: text embeddings have {text_embedding.size(1)} dimensions "
                f"but image embeddings have {embeddings_tensor.size(1)} dimensions."
            )
        # Retrieve the top-k most similar images
        top_images = get_top_images(text_embedding, embeddings_tensor, image_paths, url_mapping)
        all_results.extend(top_images)
    except Exception as e:
        print(f"An error occurred while processing {model_name}: {e}")
    # Pad to exactly five outputs so Gradio always gets one value per image slot
    while len(all_results) < 5:
        all_results.append(None)
    return all_results
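# For reference: a minimal sketch of how an embeddings file in the format loaded
# above could have been produced. This is an assumption about the offline
# pipeline, not part of the app's runtime path; `build_image_embeddings`,
# `image_dir`, and `output_file` are hypothetical names.
def build_image_embeddings(image_dir, output_file):
    import os
    from PIL import Image
    paths, features = [], []
    for name in sorted(os.listdir(image_dir)):
        image = Image.open(os.path.join(image_dir, name)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            # (1, D) image embedding from the same CLIP model used for queries
            features.append(model.get_image_features(**inputs).squeeze(0))
        paths.append(name)
    # Same dict layout as the file consumed at startup
    torch.save({'embeddings': torch.stack(features), 'image_paths': paths}, output_file)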
# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("Query Results using OpenAI CLIP (openai/clip-vit-large-patch14-336)")
    text_input = gr.Textbox(label="Enter your text query")
    # Output row for the model
    image_elements = []
    with gr.Row():
        for i in range(5):
            image = gr.Image(label=f"Image {i+1}", width=200, height=200)
            image_elements.append(image)
    text_input.submit(
        fn=find_top_images,
        inputs=text_input,
        outputs=image_elements
    )
demo.launch()