import os
from pathlib import Path

import torch
import numpy as np
from PIL import Image
import gradio as gr
from tokenizers import Tokenizer
from torch.utils.data import Dataset
import albumentations as A
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from datasets import load_dataset

from fourm.vq.vqvae import VQVAE
from fourm.models.fm import FM
from fourm.models.generate import (
    GenerationSampler,
    build_chained_generation_schedules,
    init_empty_target_modality,
    custom_text,
)
from fourm.utils.plotting_utils import decode_dict
from fourm.data.modality_info import MODALITY_INFO
from fourm.data.modality_transforms import RGBTransform
from torchvision.transforms.functional import center_crop

# Constants and configurations
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 224
TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
IMAGE_DATASET_PATH = "./data"

# Load models
text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
vqvae = VQVAE.from_pretrained(VQVAE_PATH)
fm_model = FM.from_pretrained(FM_MODEL_PATH).eval().to(DEVICE)

# Generation configurations
cond_domains = ["caption", "metadata"]
target_domains = ["tok_dinov2_global"]
tokens_per_target = [16]
generation_config = {
    "autoregression_schemes": ["roar"],
    "decoding_steps": [1],
    "token_decoding_schedules": ["linear"],
    "temps": [2.0],
    "temp_schedules": ["onex:0.5:0.5"],
    "cfg_scales": [1.0],
    "cfg_schedules": ["constant"],
    "cfg_grow_conditioning": True,
}
top_p, top_k = 0.8, 0.0

schedule = build_chained_generation_schedules(
    cond_domains=cond_domains,
    target_domains=target_domains,
    tokens_per_target=tokens_per_target,
    **generation_config,
)
sampler = GenerationSampler(fm_model)
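
# Optional sanity check, not part of the original app: in the 4M codebase the
# returned schedule is (to my understanding) a list-like sequence of per-step
# generation configs, so its length should reflect the decoding_steps set above.
print(f"Generation schedule has {len(schedule)} step(s)")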


class HuggingFaceImageDataset(Dataset):
    def __init__(self, dataset_name, split="train", img_sz=224):
        self.dataset = load_dataset(dataset_name, split=split)
        self.tfms = A.Compose([
            A.SmallestMaxSize(img_sz),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.dataset[idx]["image"]
        img = np.array(img)
        img = self.tfms(image=img)["image"]
        return Image.fromarray(img)


# Usage
dataset = HuggingFaceImageDataset("aroraaman/4m-21-demo")
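
# Quick, optional inspection of the retrieval dataset (illustrative only): each
# item comes back as a PIL.Image whose shortest side has been resized to 224 px
# by the SmallestMaxSize transform above.
print(f"Retrieval dataset contains {len(dataset)} images")
sample_image = dataset[0]
print(f"First image size (W x H): {sample_image.size}")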


def load_image_embeddings():
    # Download the precomputed embedding file from the Hub
    file_path = hf_hub_download(repo_id="aroraaman/img-tensor", filename="image_emb.pt")
    # Load the tensor
    image_embeddings = torch.load(file_path)
    return image_embeddings


# Use the embeddings in the app
image_embeddings = load_image_embeddings()
image_embeddings = image_embeddings.to(DEVICE)
print(image_embeddings.shape)
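
# Optional consistency check (illustrative, assuming the tensor stores one
# embedding row per image in the retrieval dataset loaded above):
print(f"{image_embeddings.shape[0]} embeddings for {len(dataset)} dataset images")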


def get_similar_images(caption, brightness, num_items):
    batched_sample = {}
    for target_mod, ntoks in zip(target_domains, tokens_per_target):
        batched_sample = init_empty_target_modality(
            batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
        )

    # Metadata conditioning string: the "v1=... v0=..." pairs encode the two UI
    # controls (number of COCO instances and brightness) as 4M metadata text.
    metadata = f"v1=6 v0={num_items} v1=10 v0={brightness}"
    print(metadata)

    batched_sample = custom_text(
        batched_sample,
        input_text=caption,
        eos_token="[EOS]",
        key="caption",
        device=DEVICE,
        text_tokenizer=text_tokenizer,
    )
    batched_sample = custom_text(
        batched_sample,
        input_text=metadata,
        eos_token="[EOS]",
        key="metadata",
        device=DEVICE,
        text_tokenizer=text_tokenizer,
    )

    out_dict = sampler.generate(
        batched_sample,
        schedule,
        text_tokenizer=text_tokenizer,
        verbose=True,
        seed=0,
        top_p=top_p,
        top_k=top_k,
    )

    with torch.no_grad():
        dec_dict = decode_dict(
            out_dict,
            {"tok_dinov2_global": vqvae.to(DEVICE)},
            text_tokenizer,
            image_size=IMG_SIZE,
            patch_size=16,
            decoding_steps=1,
        )

    combined_features = dec_dict["tok_dinov2_global"]
    similarities = torch.nn.functional.cosine_similarity(
        combined_features, image_embeddings
    )
    top_indices = similarities.argsort(descending=True)[:1]
    print(top_indices, similarities[top_indices])
    return [dataset[int(i)] for i in top_indices.cpu().numpy()]
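
# Illustrative standalone usage, kept commented out so the Space only runs the
# model through the Gradio callback; the values mirror one of the example rows
# registered in the UI below.
# retrieved = get_similar_images("swimming pool", brightness=27, num_items=7)
# retrieved[0].save("retrieved.png")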


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Retrieval using 4M-21: An Any-to-Any Vision Model")
    gr.Markdown("""
    This app demonstrates image retrieval with 4M-21, an any-to-any vision model.
    Enter a caption, adjust the brightness, and choose the number of items to retrieve similar images.
    The retrieval dataset for this demo is available at: https://huggingface.co/datasets/aroraaman/4m-21-demo
    """)

    with gr.Row():
        with gr.Column(scale=1):
            caption = gr.Textbox(
                label="Caption Description", placeholder="Enter image description..."
            )
            brightness = gr.Slider(
                minimum=0, maximum=255, value=5, step=1,
                label="Brightness", info="Adjust image brightness (0-255)"
            )
            num_items = gr.Slider(
                minimum=0, maximum=50, value=5, step=1,
                label="Number of Items", info="Number of COCO instances in image (0-50)"
            )
        with gr.Column(scale=1):
            output_images = gr.Gallery(
                label="Retrieved Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=2,
                height=512,
            )

    submit_btn = gr.Button("Retrieve Most Similar Image")
    submit_btn.click(
        fn=get_similar_images,
        inputs=[caption, brightness, num_items],
        outputs=output_images,
    )

    # Add examples
    gr.Examples(
        examples=[
            ["swimming pool", 27, 7],
            ["swimming pool", 255, 7],
            ["dining room", 22, 7],
            ["dining room", 5, 7],
            ["dining room", 5, 46],
        ],
        inputs=[caption, brightness, num_items],
    )

if __name__ == "__main__":
    demo.launch()