Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| import pickle | |
| from transformers import AutoProcessor | |
| from src.model import MMEBModel | |
| from src.arguments import ModelArguments | |
# Directory scanned by load_examples() for example query images and prompts.
QUERY_DIR = "imgs/queries"
# Directory holding the candidate images that form the retrieval library.
IMAGE_DIR = "imgs/candidates"

# Paths of every candidate image; the extension match is case-sensitive.
image_paths = [
    os.path.join(IMAGE_DIR, name)
    for name in os.listdir(IMAGE_DIR)
    if name.endswith((".jpg", ".png"))
]

# Image placeholder inserted into prompts. load_model() rebinds this when the
# InternVL backbone is selected.
IMAGE_TOKEN = "<|image_1|>"
# Default number of results returned per query.
TOP_N = 5

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")
def load_model():
    """Load the IDMR retrieval model together with its input processor.

    The backbone is fixed to ``internvl_2_5`` below; the ``phi35v`` branch is
    kept so the backbone can be switched by editing ``model_args``.

    Side effect: when the InternVL backbone is used, the module-level
    ``IMAGE_TOKEN`` is rebound to ``"<image>"``.

    Returns:
        A ``(model, processor)`` tuple. The model has been moved to ``device``
        in bfloat16 and switched to eval mode.

    Raises:
        ValueError: If the configured backbone has no processor branch here.
    """
    global IMAGE_TOKEN
    model_args = ModelArguments(
        model_name="lbw18601752667/IDMR-2B",
        model_backbone="internvl_2_5",
    )
    backbone = model_args.model_backbone

    if backbone == "phi35v":
        processor = AutoProcessor.from_pretrained(
            model_args.model_name,
            trust_remote_code=True,
            num_crops=model_args.num_crops,
        )
        processor.tokenizer.padding_side = "right"
    elif backbone == "internvl_2_5":
        # Imported lazily: only this backbone needs them.
        from src.vlm_backbone.intern_vl import InternVLProcessor
        from transformers import AutoTokenizer, AutoImageProcessor

        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name,
            trust_remote_code=True,
        )
        image_processor = AutoImageProcessor.from_pretrained(
            model_args.model_name,
            trust_remote_code=True,
            use_fast=False,
        )
        processor = InternVLProcessor(
            image_processor=image_processor,
            tokenizer=tokenizer,
        )
        # This backbone uses a different image placeholder in its prompts.
        IMAGE_TOKEN = "<image>"
    else:
        # Without this branch `processor` would be unbound at the return below.
        raise ValueError(f"Unsupported model backbone: {backbone!r}")

    model = MMEBModel.load(model_args)
    model = model.to(device, dtype=torch.bfloat16)
    model.eval()
    return model, processor
# Load once at import time; the encoding and retrieval functions below read
# these two module-level objects.
model, processor = load_model()
def get_inputs(processor, text, image_path=None, image=None):
    """Build the model input dict for a text prompt and an optional image.

    Args:
        processor: Callable accepting ``text``, ``images``, ``return_tensors``,
            ``max_length`` and ``truncation`` keyword arguments and returning a
            mapping of tensors.
        text: Prompt text. Every occurrence of ``IMAGE_TOKEN`` is removed from
            it when no image is available.
        image_path: Optional path of an image to open; when truthy it takes
            precedence over ``image``.
        image: Optional already-loaded image, used when ``image_path`` is not
            given.

    Returns:
        dict: The processor outputs moved to ``device``, plus an
        ``"image_flags"`` long tensor holding ``[1]`` when an image is present
        and ``[0]`` otherwise. ``"pixel_values"`` is removed for text-only
        input.
    """
    if image_path:
        image = Image.open(image_path)

    # Test identity rather than truthiness so the result never depends on how
    # the image object defines __bool__/__len__.
    has_image = image is not None
    if not has_image:
        text = text.replace(IMAGE_TOKEN, "")

    inputs = processor(
        text=text,
        images=[image] if has_image else None,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}
    inputs["image_flags"] = torch.tensor(
        [1 if has_image else 0], dtype=torch.long
    ).to(device)

    if not has_image:
        # pop() instead of del: a processor that emits no pixel values for
        # text-only input must not trigger a KeyError here.
        inputs.pop("pixel_values", None)
    return inputs
def encode_image_library(image_paths):
    """Embed every image in ``image_paths`` with the target-side encoder.

    Args:
        image_paths: Iterable of image file paths.

    Returns:
        dict: Maps each file's base name to the NumPy array taken from the
        model's ``"tgt_reps"`` output (cast to float32 on the CPU). Files
        sharing a base name overwrite one another.
    """
    library = {}
    for path in image_paths:
        prompt = f"{IMAGE_TOKEN}\n Represent the given image."
        print(f"text: {prompt}")
        model_inputs = get_inputs(processor, prompt, image_path=path)
        # Inference only: no gradients, bfloat16 autocast on the active device.
        with torch.no_grad():
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                result = model(tgt=model_inputs)
        key = os.path.basename(path)
        library[key] = result["tgt_reps"].float().cpu().numpy()
    return library
def save_embeddings(embeddings, file_path="image_embeddings.pkl"):
    """Pickle ``embeddings`` to ``file_path``, overwriting any existing file.

    Args:
        embeddings: Object to serialize (the app passes a name -> array dict).
        file_path: Destination path of the pickle file.
    """
    with open(file_path, mode="wb") as sink:
        pickle.dump(embeddings, sink)
def load_embeddings(file_path="image_embeddings.pkl"):
    """Read back the object previously pickled to ``file_path``.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only point this at cache
    # files this application wrote itself.
    with open(file_path, mode="rb") as source:
        cached = pickle.load(source)
    return cached
def cosine_similarity(query_embedding, embeddings, eps=1e-12):
    """Return the cosine similarity between a query and candidate embeddings.

    The two inputs are broadcast against each other and reduced over the last
    axis, so the result has the broadcast shape minus that axis. Each side is
    divided by its L2 norm, which makes the score a true cosine even when the
    embeddings are not unit-length (a plain dot product is only equivalent for
    pre-normalized vectors).

    Args:
        query_embedding: Array-like whose last axis is the feature axis.
        embeddings: Array-like broadcastable against ``query_embedding``.
        eps: Lower bound for the norm product; keeps all-zero vectors from
            dividing by zero (their similarity comes out as 0).

    Returns:
        numpy.ndarray: Similarity scores in ``[-1, 1]``.
    """
    query = np.asarray(query_embedding)
    candidates = np.asarray(embeddings)

    dot = np.sum(query * candidates, axis=-1)
    query_norm = np.linalg.norm(query, axis=-1, keepdims=True)
    candidate_norm = np.linalg.norm(candidates, axis=-1, keepdims=True)
    # keepdims lets the norms broadcast exactly like the inputs; the trailing
    # singleton axis is dropped afterwards to line up with `dot`.
    scale = np.maximum(query_norm * candidate_norm, eps)[..., 0]
    return dot / scale
def retrieve_images(query_text, query_image, top_n=TOP_N):
    """Return the paths of the candidate images that best match a query.

    Args:
        query_text: Instruction text; when empty, the generic
            "Represent the given image." instruction is used instead.
        query_image: Optional NumPy image array, converted with
            ``Image.fromarray``; ``None`` means a text-only query.
        top_n: Maximum number of paths to return.

    Returns:
        list[str]: Up to ``top_n`` paths under ``IMAGE_DIR``, most similar
        first. Empty when no file in ``IMAGE_DIR`` has a cached embedding.
    """
    # The prompt always starts with the image placeholder; get_inputs() strips
    # it again when no image accompanies the query.
    instruction = query_text if query_text else "Represent the given image."
    prompt = f"{IMAGE_TOKEN}\n {instruction}"
    image = Image.fromarray(query_image) if query_image is not None else None

    inputs = get_inputs(processor, prompt, image=image)
    print(f"inputs: {inputs}")
    with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16):
        query_embedding = model(qry=inputs)["qry_reps"].float().cpu().numpy()

    # Only rank files that are still on disk and have a cached embedding.
    embeddings_dict = load_embeddings()
    img_names = [name for name in os.listdir(IMAGE_DIR) if name in embeddings_dict]
    if not img_names:
        # np.stack() raises ValueError on an empty list; an empty result is
        # the meaningful answer for an empty library.
        return []

    embeddings = np.stack([embeddings_dict[name] for name in img_names])
    # NOTE(review): the transpose + squeeze(0) below assume each cached
    # embedding has a leading axis of size 1 -- confirm against the model.
    similarity = cosine_similarity(query_embedding, embeddings).T
    print(f"cosine_similarity: {similarity}")
    top_indices = np.argsort(-similarity).squeeze(0)[:top_n]
    print(f"top_indices: {top_indices}")
    return [os.path.join(IMAGE_DIR, img_names[i]) for i in top_indices]
def demo(query_text, query_image):
    """Gradio callback: run retrieval and return the matches as opened images.

    Args:
        query_text: Text entered in the query box.
        query_image: Query image passed through to ``retrieve_images``.

    Returns:
        list: One opened PIL image per retrieved path, best match first.
    """
    hit_paths = retrieve_images(query_text, query_image)
    return list(map(Image.open, hit_paths))
def load_examples():
    """Collect the ``[text, image_path]`` example pairs found in ``QUERY_DIR``.

    An image (``.jpg`` / ``.png``) becomes an example only when a ``.txt``
    file with the same stem sits next to it. The text is stripped of
    surrounding whitespace and of the ``"<|image_1|>\\n"`` prefix marker.

    Returns:
        list[list[str]]: One ``[query_text, image_path]`` entry per pair.
    """
    pairs = []
    for entry in os.listdir(QUERY_DIR):
        if not entry.endswith((".jpg", ".png")):
            continue
        image_path = os.path.join(QUERY_DIR, entry)
        stem, _ext = os.path.splitext(entry)
        caption_path = os.path.join(QUERY_DIR, stem + ".txt")
        if not os.path.exists(caption_path):
            # Images without a caption file are not offered as examples.
            continue
        with open(caption_path, 'r', encoding='utf-8') as handle:
            caption = handle.read().strip().replace("<|image_1|>\n", "")
        pairs.append([caption, image_path])
    return pairs
# Gradio UI: a text box and an image upload feed demo(); the results are shown
# in a gallery. The clickable examples are read from disk once, at build time.
iface = gr.Interface(
    title="Instance-Driven Multi-modal Retrieval (IDMR) Demo",
    description="Enter a query text or upload an image to retrieve relevant images from the library. You can click on the examples below to try them out.",
    fn=demo,
    inputs=[
        gr.Textbox(label="Query Text", placeholder="Enter your query text here..."),
        # type="numpy" matches the Image.fromarray() conversion in
        # retrieve_images().
        gr.Image(type="numpy", label="Query Image"),
    ],
    outputs=gr.Gallery(columns=3, label=f"Retrieved Images (Top {TOP_N})"),
    examples=load_examples(),
)
# Encode the candidate library only when no cache file exists yet; an existing
# cache is reused as-is, even if the candidate images have changed since.
if not os.path.exists("image_embeddings.pkl"):
    embeddings = encode_image_library(image_paths)
    save_embeddings(embeddings)

# Start the web server.
iface.launch()