# app.py — safinal's compositional image retrieval Space (commit ebe4d36)
import gradio as gr
import torch
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import zipfile
import os
from huggingface_hub import hf_hub_download
from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset
# Run inference on GPU when available; every model and tensor below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of database images embedded per forward pass in encode_database().
batch_size = 512
def unzip_file(zip_path, extract_path):
    """Extract every entry of the zip archive at *zip_path* into *extract_path*.

    The destination directory is created first if it does not already exist.
    """
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(extract_path)
# Setup files
# Unpack the bundled sample database (images + annotation CSV) next to the app.
zip_path = "sample_evaluation.zip"
extract_path = "sample_evaluation"
if os.path.exists(zip_path): # Check exists to prevent errors if already unzipped
    unzip_file(zip_path, extract_path)
# Download weights if not present
# Fetch the trained retrieval-model weights from the Hugging Face Hub on first run only.
if not os.path.exists("weights.pth"):
    hf_hub_download(repo_id="safinal/compositional-image-retrieval", filename="weights.pth", local_dir='.')
def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """Embed every image listed in df['target_image'] with the model's image encoder.

    Images are processed in fixed-size batches of `batch_size`; each batch of
    embeddings is L2-normalised and moved to the CPU. Returns a single
    concatenated (n_images, dim) array, or an empty array for an empty frame.
    """
    model.eval()
    chunks = []
    image_paths = df['target_image']
    for start in tqdm(range(0, len(df), batch_size)):
        batch_paths = image_paths[start:start + batch_size]
        if len(batch_paths) == 0:
            continue
        batch_tensor = torch.stack(
            [model.processor(Image.open(path)) for path in batch_paths]
        ).to(device)
        with torch.no_grad():
            embeddings = model.feature_extractor.encode_image(batch_tensor)
            embeddings = torch.nn.functional.normalize(embeddings, dim=1, p=2)
        chunks.append(embeddings.detach().cpu().numpy())
    if not chunks:
        return np.array([])
    return np.concatenate(chunks)
def load_model():
    """Build the ViTamin-L-384 retrieval model, restore trained weights, and set eval mode."""
    retrieval_model = Model(model_name="ViTamin-L-384", pretrained=None)
    retrieval_model.load("weights.pth")
    retrieval_model.eval()
    return retrieval_model
def _parse_objects(predictions):
    """Group classified tokens into positive/negative natural-language phrases.

    `predictions` is a sequence of (token, label) pairs. Consecutive tokens
    carrying the same '<positive_object>' (or '<negative_object>') label are
    merged into one phrase of the form "a photo of a <tok> <tok>...". Tokens
    with any other label act as separators and are otherwise ignored.

    Returns (positive_phrases, negative_phrases).
    """
    pos = []
    neg = []
    last_tag = ''
    for token, label in predictions:
        if label == '<positive_object>':
            if last_tag != '<positive_object>':
                pos.append(f"a photo of a {token}.")
            else:
                # Continue the current phrase: drop the trailing '.' and append the token.
                pos[-1] = pos[-1][:-1] + f" {token}."
        elif label == '<negative_object>':
            if last_tag != '<negative_object>':
                neg.append(f"a photo of a {token}.")
            else:
                neg[-1] = neg[-1][:-1] + f" {token}."
        last_tag = label
    return pos, neg

def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """Retrieve the database image best matching (query image + edit text).

    The query image embedding is shifted by adding text embeddings for objects
    to introduce and subtracting those for objects to remove, normalised, and
    compared against the precomputed database embeddings by cosine similarity.

    Returns the 'target_image' path of the most similar database row.
    """
    # Encode the query image as a single-item batch on the inference device.
    query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)
    # NOTE(review): the token classifier is reloaded on every query; caching it
    # at module level would avoid repeated loads per request.
    token_classifier, token_classifier_tokenizer = load_token_classifier(
        "safinal/compositional-image-retrieval-token-classifier",
        device
    )
    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)
        # Tag each token of the text query as a positive/negative object (or neither).
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )
        pos, neg = _parse_objects(predictions)
        # Shift the image embedding toward added objects and away from removed ones.
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)
    # Rank database images by cosine similarity against the composed query embedding.
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]
    # Return the path of the single best match.
    most_similar_idx = np.argmax(similarities)
    return database_df.iloc[most_similar_idx]['target_image']
# --- Initialization ---
# One-time startup: load the model, the sample dataset, and precompute all
# database embeddings so each user query only has to encode its own image.
print("Loading model...")
model = load_model()
print("Loading dataset...")
test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,  # reuse the model's own image preprocessing
    tokenizer=model.tokenizer
)
# Load database once globally to avoid reloading it on every user request
print("Encoding database...")
database_df = test_dataset.load_database()
database_embeddings = encode_database(model, database_df)
def interface_fn(selected_image: str, query_text: str) -> Image.Image:
    """Gradio callback: map a (query image path, query text) pair to the retrieved image.

    Returns None when no query image was provided, so the output panel stays empty.
    """
    if selected_image is None:
        return None
    retrieved_path = process_single_query(
        model, selected_image, query_text, database_embeddings, database_df
    )
    return Image.open(retrieved_path)
# --- Gradio Interface ---
# Two inputs (image file path + free-text edit instruction), one PIL image output.
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(type="filepath", label="Select Query Image", image_mode="RGB"),
        gr.Textbox(label="Enter Query Text", lines=2)
    ],
    outputs=gr.Image(label="Retrieved Image", type="pil"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    # Example pairs drawn from the bundled sample_evaluation set.
    examples=[
        ["sample_evaluation/images/261684.png", "Bring cow into the picture, and then follow up with removing bench."],
        ["sample_evaluation/images/283700.png", "add bowl and bench and remove shoe and elephant"],
        ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
        ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."]
    ],
    flagging_mode="never",
    cache_examples=False  # examples run the full retrieval pipeline; don't precompute at build time
)
if __name__ == "__main__":
    try:
        # queue() serialises concurrent requests; bind on all interfaces and the
        # standard Hugging Face Spaces port.
        demo.queue().launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        # Log the failure for the Space logs, then re-raise so the container exits non-zero.
        print(f"Error launching app: {str(e)}")
        raise