Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import pandas as pd | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from tqdm import tqdm | |
| import zipfile | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| from token_classifier import load_token_classifier, predict | |
| from model import Model | |
| from dataset import RetrievalDataset | |
# Prefer the GPU whenever torch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of database images pushed through the image encoder per forward pass.
batch_size = 512
def unzip_file(zip_path, extract_path):
    """Extract every member of the archive at ``zip_path`` into ``extract_path``.

    The destination directory (including missing parents) is created first;
    an already-existing directory is reused.
    """
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        archive.extractall(extract_path)
# Setup files
zip_path = "sample_evaluation.zip"
extract_path = "sample_evaluation"
# Only attempt extraction when the archive is actually present; note that an
# existing archive is re-extracted (overwriting files) on every start.
if os.path.exists(zip_path):
    unzip_file(zip_path=zip_path, extract_path=extract_path)

# Fetch the fine-tuned weights from the Hub into the working directory unless
# a local copy already exists.
if not os.path.exists("weights.pth"):
    hf_hub_download(
        repo_id="safinal/compositional-image-retrieval",
        filename="weights.pth",
        local_dir='.',
    )
def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """Encode every database image into an L2-normalised embedding.

    Image paths are read from the ``target_image`` column of ``df`` in chunks
    of ``batch_size`` rows. Each image is preprocessed with
    ``model.processor``, moved to ``device`` and passed through
    ``model.feature_extractor.encode_image``.

    Args:
        model: retrieval model exposing ``processor`` and ``feature_extractor``;
            it is switched to eval mode as a side effect.
        df: database table whose ``target_image`` column holds image paths.

    Returns:
        The per-batch embeddings concatenated row-wise, in the same order as
        the rows of ``df``. An empty ``np.array([])`` is returned when ``df``
        has no rows.
    """
    model.eval()
    image_paths = df['target_image']
    all_embeddings = []
    for start in tqdm(range(0, len(df), batch_size)):
        # .iloc makes the slice explicitly positional, independent of df's index labels.
        batch_paths = image_paths.iloc[start:start + batch_size]
        batch_tensors = []
        for target_image_path in batch_paths:
            # Close each file once it has been preprocessed instead of leaving
            # up to batch_size open handles for the garbage collector.
            with Image.open(target_image_path) as img:
                batch_tensors.append(model.processor(img))
        target_imgs = torch.stack(batch_tensors).to(device)
        with torch.no_grad():
            embeddings = model.feature_extractor.encode_image(target_imgs)
            embeddings = torch.nn.functional.normalize(embeddings, dim=1, p=2)
        all_embeddings.append(embeddings.cpu().numpy())
    if not all_embeddings:
        return np.array([])
    return np.concatenate(all_embeddings)
def load_model():
    """Build a ViTamin-L-384 ``Model`` (no pretrained tag), load ``weights.pth``
    into it and return it in eval mode."""
    net = Model(model_name="ViTamin-L-384", pretrained=None)
    net.load("weights.pth")
    net.eval()
    return net
# (token_classifier, tokenizer) pair, populated on first use. Loading it once
# instead of inside every query avoids re-instantiating the classifier per request.
_token_classifier_bundle = None


def _get_token_classifier():
    """Return the cached ``(token_classifier, tokenizer)`` pair, loading it on first call."""
    global _token_classifier_bundle
    if _token_classifier_bundle is None:
        classifier, tokenizer = load_token_classifier(
            "safinal/compositional-image-retrieval-token-classifier",
            device
        )
        _token_classifier_bundle = (classifier, tokenizer)
    return _token_classifier_bundle


def _build_object_prompts(predictions):
    """Turn ``(token, label)`` pairs into positive and negative text prompts.

    Runs of consecutive tokens carrying the same object label are merged into a
    single prompt, e.g. two consecutive ``<positive_object>`` tokens "tennis"
    and "racket" yield ``"a photo of a tennis racket."``. Tokens with any other
    label are ignored but still break a run.

    Returns:
        ``(pos, neg)`` lists of prompt strings.
    """
    pos = []
    neg = []
    last_tag = ''
    for token, label in predictions:
        if label == '<positive_object>':
            if last_tag != '<positive_object>':
                pos.append(f"a photo of a {token}.")
            else:
                # Drop the trailing "." and extend the current phrase.
                pos[-1] = pos[-1][:-1] + f" {token}."
        elif label == '<negative_object>':
            if last_tag != '<negative_object>':
                neg.append(f"a photo of a {token}.")
            else:
                neg[-1] = neg[-1][:-1] + f" {token}."
        last_tag = label
    return pos, neg


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """Return the ``target_image`` of the database row that best matches a composed query.

    The query embedding is the image embedding of ``query_image_path``, plus the
    text embedding of every object the token classifier tags as
    ``<positive_object>`` in ``query_text``, minus every ``<negative_object>``.
    The result is L2-normalised and compared to ``database_embeddings`` with
    cosine similarity.

    Args:
        model: retrieval model exposing ``processor``, ``tokenizer`` and
            ``feature_extractor``.
        query_image_path: path of the query image.
        query_text: edit instruction; ``None``/blank text means "image only".
        database_embeddings: embeddings whose rows align with ``database_df``.
        database_df: table with a ``target_image`` column.

    Returns:
        The ``target_image`` value of the most similar database row.
    """
    with Image.open(query_image_path) as query_pil:
        query_img = model.processor(query_pil).unsqueeze(0).to(device)

    pos, neg = [], []
    if query_text and query_text.strip():
        token_classifier, token_classifier_tokenizer = _get_token_classifier()
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )
        pos, neg = _build_object_prompts(predictions)

    # All encoder calls and the in-place arithmetic stay under no_grad so the
    # final tensor never requires grad and .numpy() below is always valid.
    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)

    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]
    most_similar_idx = int(np.argmax(similarities))
    return database_df.iloc[most_similar_idx]['target_image']
# --- Initialization ---
# Model, dataset and database embeddings are all built once at import time;
# the Gradio callback below only reads these module-level objects.
print("Loading model...")
model = load_model()

print("Loading dataset...")
test_dataset = RetrievalDataset(
    split='test',
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    transform=model.processor,
    tokenizer=model.tokenizer,
)

# Encode the whole database a single time so requests never redo this work.
print("Encoding database...")
database_df = test_dataset.load_database()
database_embeddings = encode_database(model, database_df)
def interface_fn(selected_image: "str | None", query_text: str) -> "Image.Image | None":
    """Gradio callback: return the database image that best matches the query.

    Args:
        selected_image: file path of the query image (the input component uses
            ``type="filepath"``); ``None`` or ``""`` when no image is provided.
        query_text: edit instruction from the textbox.

    Returns:
        The retrieved image loaded into memory, or ``None`` when no query image
        was provided.
    """
    if not selected_image:
        return None
    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        database_df
    )
    # Copy the pixels into memory so the file handle opened here is released
    # right away instead of lingering until garbage collection.
    with Image.open(result_image_path) as result_image:
        return result_image.copy()
# --- Gradio Interface ---
# Outputs for the bundled examples are not precomputed at startup
# (cache_examples=False), and the flagging UI is disabled.
demo = gr.Interface(
    fn=interface_fn,
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    inputs=[
        gr.Image(type="filepath", label="Select Query Image", image_mode="RGB"),
        gr.Textbox(label="Enter Query Text", lines=2),
    ],
    outputs=gr.Image(label="Retrieved Image", type="pil"),
    examples=[
        ["sample_evaluation/images/261684.png", "Bring cow into the picture, and then follow up with removing bench."],
        ["sample_evaluation/images/283700.png", "add bowl and bench and remove shoe and elephant"],
        ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
        ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."],
    ],
    cache_examples=False,
    flagging_mode="never",
)
if __name__ == "__main__":
    # Listen on all interfaces on port 7860; any launch failure is reported
    # on stdout and then re-raised so the process still exits with an error.
    try:
        demo.queue().launch(server_name="0.0.0.0", server_port=7860)
    except Exception as exc:
        print(f"Error launching app: {exc!s}")
        raise