File size: 6,459 Bytes
08adf1a
 
 
 
 
 
a263f63
af5379c
 
 
08adf1a
ffd2453
 
 
 
a263f63
 
 
 
90ce7ca
 
 
 
 
 
 
 
 
af5379c
90ce7ca
 
af5379c
 
b05b11a
af5379c
 
 
90ce7ca
 
af5379c
a263f63
 
 
 
 
af5379c
a263f63
af5379c
 
 
 
a263f63
 
 
 
af5379c
 
 
a263f63
ffd2453
08adf1a
 
b05b11a
 
08adf1a
 
 
 
 
 
 
 
 
 
73a10b7
08adf1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af5379c
 
 
08adf1a
ffd2453
af5379c
ffd2453
 
 
 
 
 
 
 
af5379c
 
 
 
 
08adf1a
ebe4d36
af5379c
 
 
08adf1a
 
 
 
 
ebe4d36
08adf1a
 
 
af5379c
 
08adf1a
 
 
c902692
 
08adf1a
c902692
08adf1a
 
 
b05b11a
 
de7af60
 
c902692
af5379c
c902692
08adf1a
 
 
455daca
c902692
455daca
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import gradio as gr
import torch
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import zipfile
import os
from huggingface_hub import hf_hub_download


from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 512

def unzip_file(zip_path, extract_path):
    """Extract every member of the archive at ``zip_path`` into ``extract_path``.

    The destination directory (including missing parents) is created before
    the archive is opened, so it exists even if opening the archive fails.
    """
    os.makedirs(extract_path, exist_ok=True)

    with zipfile.ZipFile(zip_path, mode='r') as archive:
        archive.extractall(path=extract_path)

# Setup files: unpack the bundled sample archive whenever it is present.
# The existence check lets startup proceed when the archive is missing
# (e.g. only the already-extracted directory is shipped).
zip_path = "sample_evaluation.zip"
extract_path = "sample_evaluation"
if os.path.exists(zip_path):
    unzip_file(zip_path=zip_path, extract_path=extract_path)

# Fetch the checkpoint from the Hugging Face Hub only when no local copy exists.
if not os.path.exists("weights.pth"):
    hf_hub_download(
        repo_id="safinal/compositional-image-retrieval",
        filename="weights.pth",
        local_dir='.',
    )


def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """
    Encode every database image into an L2-normalised embedding.

    Images are processed in chunks of the module-level ``batch_size`` and run
    on the module-level ``device``.

    Args:
        model: Object exposing ``processor`` (applied to each opened PIL image)
            and ``feature_extractor.encode_image``.
        df: Frame whose ``target_image`` column holds image file paths.

    Returns:
        The per-batch embeddings concatenated along axis 0 in row order, or an
        empty 1-D array (``np.array([])``) when ``df`` has no rows.
    """
    model.eval()
    image_paths = df['target_image']
    all_embeddings = []

    # range() never yields a start index >= len(df), so every slice below is
    # non-empty and an empty frame simply skips the loop.
    for start in tqdm(range(0, len(df), batch_size)):
        # .iloc makes the positional (not label-based) slicing explicit.
        batch_paths = image_paths.iloc[start:start + batch_size]

        processed = []
        for target_image_path in batch_paths:
            # Close each file as soon as the processor has consumed it rather
            # than leaving the handle to the garbage collector.
            with Image.open(target_image_path) as img:
                processed.append(model.processor(img))
        target_imgs = torch.stack(processed).to(device)

        # Keep encoding and normalisation both outside autograd.
        with torch.no_grad():
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
            target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
        all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())

    if not all_embeddings:
        return np.array([])
    return np.concatenate(all_embeddings)


def load_model():
    """Build the ``ViTamin-L-384`` model, load ``weights.pth`` into it and return it after ``eval()``."""
    retrieval_model = Model(model_name="ViTamin-L-384", pretrained=None)
    retrieval_model.load("weights.pth")
    retrieval_model.eval()
    return retrieval_model


# (token_classifier, tokenizer) pairs keyed by device string, so the token
# classifier is loaded once per process instead of once per query.
_TOKEN_CLASSIFIER_CACHE = {}


def _get_token_classifier(target_device):
    """Return the (token_classifier, tokenizer) pair for ``target_device``, loading it on first use."""
    key = str(target_device)
    if key not in _TOKEN_CLASSIFIER_CACHE:
        classifier, classifier_tokenizer = load_token_classifier(
            "safinal/compositional-image-retrieval-token-classifier",
            target_device
        )
        _TOKEN_CLASSIFIER_CACHE[key] = (classifier, classifier_tokenizer)
    return _TOKEN_CLASSIFIER_CACHE[key]


def _build_object_prompts(predictions):
    """Turn (token, label) pairs into "a photo of a ..." prompts for positive and negative objects.

    Consecutive tokens carrying the same object label are merged into a single
    prompt; any other label ends the current run.
    """
    pos = []
    neg = []
    buckets = {'<positive_object>': pos, '<negative_object>': neg}
    last_tag = ''
    for token, label in predictions:
        bucket = buckets.get(label)
        if bucket is not None:
            if last_tag != label:
                bucket.append(f"a photo of a {token}.")
            else:
                # Continue a multi-token object: drop the trailing "." and
                # append the next token before re-terminating the prompt.
                bucket[-1] = bucket[-1][:-1] + f" {token}."
        last_tag = label
    return pos, neg


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """
    Retrieve the database image that best matches an image + text query.

    The query image embedding is shifted by adding the text embedding of every
    object tagged ``<positive_object>`` in ``query_text`` and subtracting that
    of every object tagged ``<negative_object>``. The result is L2-normalised
    and compared with ``database_embeddings`` by cosine similarity.

    Args:
        model: Object exposing ``processor``, ``tokenizer`` and
            ``feature_extractor`` (``encode_image`` / ``encode_text``).
        query_image_path: Path of the query image file.
        query_text: Modification text passed to the token classifier.
        database_embeddings: 2-D array of database embeddings, one row per
            row of ``database_df``.
        database_df: Frame with a ``target_image`` column.

    Returns:
        The ``target_image`` value of the highest-scoring database row.
    """
    # Process query image; close the file once the processor has consumed it.
    with Image.open(query_image_path) as img:
        query_img = model.processor(img).unsqueeze(0).to(device)

    # Reuse the cached classifier instead of reloading it on every request.
    token_classifier, token_classifier_tokenizer = _get_token_classifier(device)

    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)

        # Tag the tokens of the text query.
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )
        pos, neg = _build_object_prompts(predictions)

        # Combine embeddings: add wanted objects, subtract unwanted ones.
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]

        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)

    # Calculate similarities against the whole database.
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]

    # Get most similar image.
    most_similar_idx = np.argmax(similarities)
    most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']

    return most_similar_image_path

# --- Initialization ---
# Everything below runs once at import time; the resulting globals
# (model, database_df, database_embeddings) are read by interface_fn.

print("Loading model...")
model = load_model()

print("Loading dataset...")
# The dataset reuses the model's own image processor and tokenizer.
test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer
)

# Load database once globally to avoid reloading it on every user request
print("Encoding database...")
# database_df must provide a 'target_image' column: encode_database reads it.
database_df = test_dataset.load_database() 
database_embeddings = encode_database(model, database_df)


def interface_fn(selected_image: str, query_text: str) -> Image.Image:
    """Gradio callback: return the best-matching database image for the query.

    Returns ``None`` when no query image was supplied; otherwise runs the
    retrieval against the module-level model and database and opens the
    winning image file.
    """
    if selected_image is not None:
        best_match_path = process_single_query(
            model=model,
            query_image_path=selected_image,
            query_text=query_text,
            database_embeddings=database_embeddings,
            database_df=database_df,
        )
        return Image.open(best_match_path)

    return None

# --- Gradio Interface ---

# Input widgets, in the order interface_fn receives their values.
_query_inputs = [
    gr.Image(type="filepath", label="Select Query Image", image_mode="RGB"),
    gr.Textbox(label="Enter Query Text", lines=2),
]

# Ready-made (image path, modification text) pairs shown under the form.
_example_queries = [
    ["sample_evaluation/images/261684.png", "Bring cow into the picture, and then follow up with removing bench."],
    ["sample_evaluation/images/283700.png", "add bowl and bench and remove shoe and elephant"],
    ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
    ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."],
]

demo = gr.Interface(
    fn=interface_fn,
    inputs=_query_inputs,
    outputs=gr.Image(label="Retrieved Image", type="pil"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=_example_queries,
    flagging_mode="never",
    cache_examples=False,
)

if __name__ == "__main__":
    # Enable request queuing, then serve on 0.0.0.0:7860.
    try:
        queued_app = demo.queue()
        queued_app.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as exc:
        # Report the failure on stdout, then re-raise so the original
        # traceback and non-zero exit are preserved.
        print(f"Error launching app: {str(exc)}")
        raise