Spaces:
Running
Running
File size: 6,459 Bytes
08adf1a a263f63 af5379c 08adf1a ffd2453 a263f63 90ce7ca af5379c 90ce7ca af5379c b05b11a af5379c 90ce7ca af5379c a263f63 af5379c a263f63 af5379c a263f63 af5379c a263f63 ffd2453 08adf1a b05b11a 08adf1a 73a10b7 08adf1a af5379c 08adf1a ffd2453 af5379c ffd2453 af5379c 08adf1a ebe4d36 af5379c 08adf1a ebe4d36 08adf1a af5379c 08adf1a c902692 08adf1a c902692 08adf1a b05b11a de7af60 c902692 af5379c c902692 08adf1a 455daca c902692 455daca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import gradio as gr
import torch
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import zipfile
import os
from huggingface_hub import hf_hub_download
from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Number of database images encoded per forward pass.
batch_size = 512
def unzip_file(zip_path, extract_path):
    """Extract every member of the archive at ``zip_path`` into ``extract_path``.

    The destination directory (including any missing parents) is created
    first; an already existing directory is reused as-is.
    """
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, mode='r') as archive:
        archive.extractall(path=extract_path)
# Setup files
zip_path = "sample_evaluation.zip"
extract_path = "sample_evaluation"

# Unpack the sample data whenever the archive is present next to the app;
# when the archive is absent this step is skipped entirely.
if os.path.exists(zip_path):
    unzip_file(zip_path, extract_path)

# Fetch the checkpoint from the Hugging Face Hub only when there is no local
# "weights.pth" yet.
if not os.path.exists("weights.pth"):
    hf_hub_download(
        repo_id="safinal/compositional-image-retrieval",
        filename="weights.pth",
        local_dir='.',
    )
def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """
    Process database images and generate L2-normalised embeddings.

    Args:
        model: Object exposing ``eval()``, a ``processor`` callable applied to
            each opened PIL image, and ``feature_extractor.encode_image``.
        df: Table whose ``target_image`` column holds the image file paths.

    Returns:
        The per-batch embeddings concatenated along axis 0, in the row order
        of ``df``. When ``df`` has no rows an empty 1-D array is returned.
    """
    model.eval()
    image_paths = df['target_image']
    all_embeddings = []
    # range() never yields a start index past the last row, so every slice
    # below holds at least one path and needs no emptiness check.
    for start in tqdm(range(0, len(df), batch_size)):
        processed = []
        for target_image_path in image_paths[start:start + batch_size]:
            # Close each file as soon as it has been preprocessed so a large
            # batch does not keep hundreds of file handles open. The
            # processor output is stacked as a tensor below, so it no longer
            # depends on the open file.
            with Image.open(target_image_path) as target_image:
                processed.append(model.processor(target_image))
        target_imgs = torch.stack(processed).to(device)
        with torch.no_grad():
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
            # Unit-length rows (L2 norm along the feature axis).
            target_imgs_embedding = torch.nn.functional.normalize(
                target_imgs_embedding, dim=1, p=2
            )
        all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
    if not all_embeddings:
        return np.array([])
    return np.concatenate(all_embeddings)
def load_model():
    """Create the "ViTamin-L-384" ``Model``, call its ``load`` on "weights.pth",
    switch it to eval mode and return it.
    """
    retrieval_model = Model(model_name="ViTamin-L-384", pretrained=None)
    retrieval_model.load("weights.pth")
    retrieval_model.eval()
    return retrieval_model
# Lazily populated (token_classifier, tokenizer) pair shared by every query.
_token_classifier_bundle = None


def _get_token_classifier():
    """Return the token classifier and its tokenizer, loading them on first use only."""
    global _token_classifier_bundle
    if _token_classifier_bundle is None:
        _token_classifier_bundle = tuple(load_token_classifier(
            "safinal/compositional-image-retrieval-token-classifier",
            device
        ))
    return _token_classifier_bundle


def _build_object_prompts(predictions):
    """Turn (token, label) pairs into positive and negative "a photo of a ..." prompts."""
    prompts = {'<positive_object>': [], '<negative_object>': []}
    last_tag = ''
    for token, label in predictions:
        bucket = prompts.get(label)
        if bucket is not None:
            if last_tag == label:
                # Consecutive tokens with the same tag extend the previous
                # prompt: drop its trailing "." and append the new word.
                bucket[-1] = bucket[-1][:-1] + f" {token}."
            else:
                bucket.append(f"a photo of a {token}.")
        last_tag = label
    return prompts['<positive_object>'], prompts['<negative_object>']


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """
    Retrieve the database image that best matches an image + text query.

    The query image embedding is shifted by adding the text embedding of
    every object tagged ``<positive_object>`` in ``query_text`` and
    subtracting that of every object tagged ``<negative_object>``; the result
    is L2-normalised and compared to ``database_embeddings`` by cosine
    similarity.

    Args:
        model: Object exposing ``processor``, ``tokenizer`` and
            ``feature_extractor.encode_image`` / ``encode_text``.
        query_image_path: Path of the query image file.
        query_text: Modification text passed to the token classifier.
        database_embeddings: Embeddings to search; row ``i`` is taken to
            correspond to row ``i`` of ``database_df``.
        database_df: Table with a ``target_image`` column.

    Returns:
        The ``target_image`` value of the most similar database row.
    """
    # Process query image; the file is closed once it has been preprocessed.
    with Image.open(query_image_path) as query_image:
        query_img = model.processor(query_image).unsqueeze(0).to(device)

    # Reuse the token classifier across queries instead of reloading it on
    # every request.
    token_classifier, token_classifier_tokenizer = _get_token_classifier()

    # Tag the query text; the result is iterated as (token, label) pairs.
    predictions = predict(
        tokens=query_text,
        model=token_classifier,
        tokenizer=token_classifier_tokenizer,
        device=device,
        max_length=128
    )
    pos, neg = _build_object_prompts(predictions)

    # Keep all encoder calls under no_grad: otherwise the in-place updates
    # below could attach an autograd graph to the query embedding and the
    # later .numpy() conversion would be rejected.
    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)

    # Calculate similarities
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]

    # Get most similar image
    most_similar_idx = np.argmax(similarities)
    most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']
    return most_similar_image_path
# --- Initialization ---
# Everything in this section runs once, at import time, so that individual
# requests only have to encode the query.
print("Loading model...")
model = load_model()

print("Loading dataset...")
test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer,
)

# The database table and its embeddings are built a single time and shared by
# all user requests instead of being recomputed per query.
print("Encoding database...")
database_df = test_dataset.load_database()
database_embeddings = encode_database(model, database_df)
def interface_fn(selected_image: str, query_text: str) -> Image.Image:
    """Gradio callback: run retrieval for one query and return the matched image.

    ``selected_image`` is the query image file path and ``query_text`` the
    accompanying text. When no image has been selected, ``None`` is returned
    and no retrieval is attempted; otherwise the best match found by
    ``process_single_query`` over the module-level database is opened and
    returned.
    """
    if selected_image is not None:
        match_path = process_single_query(
            model,
            selected_image,
            query_text,
            database_embeddings,
            database_df,
        )
        return Image.open(match_path)
    return None
# --- Gradio Interface ---
# Input widgets: the query image (handed to the callback as a file path) and
# the free-text modification request.
_query_inputs = [
    gr.Image(type="filepath", label="Select Query Image", image_mode="RGB"),
    gr.Textbox(label="Enter Query Text", lines=2),
]

# Output widget showing the retrieved image.
_result_output = gr.Image(label="Retrieved Image", type="pil")

# Ready-made (image path, query text) pairs offered in the UI.
_example_queries = [
    ["sample_evaluation/images/261684.png", "Bring cow into the picture, and then follow up with removing bench."],
    ["sample_evaluation/images/283700.png", "add bowl and bench and remove shoe and elephant"],
    ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
    ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."],
]

demo = gr.Interface(
    fn=interface_fn,
    inputs=_query_inputs,
    outputs=_result_output,
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=_example_queries,
    flagging_mode="never",
    cache_examples=False,
)
if __name__ == "__main__":
    # Serve the demo with request queueing, listening on all interfaces at
    # port 7860.
    try:
        demo.queue().launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        # Report the failure, then re-raise so the process still exits with
        # the original error and traceback.
        print(f"Error launching app: {str(e)}")
        raise