File size: 3,313 Bytes
fd166c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132d2e0
fd166c6
fc14255
132d2e0
 
fd166c6
 
 
 
fc14255
fd166c6
 
3b77722
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
import warnings

# Suppress future warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Specify the single model to use
model_name = "openai/clip-vit-large-patch14-336"
embedding_file = 'openai_clip_vit_large_patch14_336_embeddings_16000.pt'

# Load model and preprocessor
print(f"Loading {model_name} embeddings from {embedding_file}...")

# Precomputed image embeddings and their aligned file names. Left as empty
# dicts on failure so the app can still start; the query handler surfaces
# the problem at request time instead of crashing on import.
embeddings = {}
image_paths = {}

try:
    # map_location='cpu' so embeddings saved on a GPU machine still load on
    # CPU-only hosts; find_top_images moves them to the right device later.
    data = torch.load(embedding_file, map_location='cpu')
    embeddings = data['embeddings']
    image_paths = data['image_paths']
except Exception as e:
    print(f"Failed to load {model_name} embeddings: {e}")

# Load the model and processor from Hugging Face Transformers
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Load the mapping CSV file for the URLs
mapping_file = 'image_mapping.csv'
mapping_df = pd.read_csv(mapping_file)  # Columns: file_name, public_url
url_mapping = dict(zip(mapping_df['file_name'], mapping_df['public_url']))

def get_top_images(text_embedding, image_embeddings, image_paths, url_mapping, top_n=5):
    """Return public URLs of the images most similar to the text query.

    Args:
        text_embedding: Tensor of shape (1, d) — presumably a CLIP text
            feature; only cosine similarity against the rows is assumed.
        image_embeddings: Tensor of shape (N, d) of precomputed embeddings.
        image_paths: Sequence of N file names aligned with the embedding rows.
        url_mapping: Mapping from file name to public URL.
        top_n: Maximum number of results to return.

    Returns:
        Up to ``top_n`` URLs ordered by descending cosine similarity; file
        names absent from ``url_mapping`` are skipped. Empty list on error.
    """
    try:
        similarities = torch.nn.functional.cosine_similarity(text_embedding, image_embeddings)
        # Clamp k so topk does not raise when fewer than top_n images exist.
        k = min(top_n, similarities.numel())
        top_k_indices = similarities.topk(k).indices
        return [url_mapping[image_paths[i]] for i in top_k_indices if image_paths[i] in url_mapping]
    except Exception as e:
        print(f"Error during similarity calculation: {e}")
        return []

def find_top_images(text_query):
    """Gradio handler: embed the text query and return the 5 best image URLs.

    Args:
        text_query: Free-text search query entered by the user.

    Returns:
        A list of exactly 5 elements — public URLs in descending similarity
        order, padded with ``None`` on failure or when fewer matches exist —
        one per ``gr.Image`` output component.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    all_results = []

    print(f"Generating text embedding for {model_name}...")

    try:
        # If the embedding file failed to load at startup, `embeddings` is
        # still an empty dict; fail with a clear message instead of the
        # confusing AttributeError that `{}.to(device)` would raise.
        if not torch.is_tensor(embeddings):
            raise RuntimeError(f"Embeddings for {model_name} are not available.")

        model.to(device)
        embeddings_tensor = embeddings.to(device)

        # Tokenize and encode text; truncate so queries longer than CLIP's
        # 77-token context window don't raise inside the tokenizer.
        inputs = processor(text=[text_query], return_tensors="pt", truncation=True).to(device)
        with torch.no_grad():
            text_embedding = model.get_text_features(**inputs)

        # Ensure dimension match between text and image embeddings
        if text_embedding.size(1) != embeddings_tensor.size(1):
            raise ValueError(f"Dimension mismatch: text embeddings have {text_embedding.size(1)} dimensions but image embeddings have {embeddings_tensor.size(1)} dimensions.")

        # Retrieve top-k similar images
        top_images = get_top_images(text_embedding, embeddings_tensor, image_paths, url_mapping)
        all_results.extend(top_images[:5])

    except Exception as e:
        print(f"An error occurred while processing {model_name}: {e}")

    # Pad to exactly 5 outputs so every gr.Image component gets a value.
    while len(all_results) < 5:
        all_results.append(None)

    return all_results

# --- Gradio interface ---
with gr.Blocks() as demo:
    gr.Markdown("Query Results using OpenAI CLIP (openai/clip-vit-large-patch14-336)")
    text_input = gr.Textbox(label="Enter your text query")

    # Single row holding the five result slots for the model's output.
    with gr.Row():
        image_elements = [
            gr.Image(label=f"Image {idx+1}", width=200, height=200)
            for idx in range(5)
        ]

    # Pressing Enter in the textbox triggers the image search.
    text_input.submit(fn=find_top_images, inputs=text_input, outputs=image_elements)

demo.launch()