# MetaCLIP 2 zero-shot image classification — Hugging Face Space demo.
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
# Load MetaCLIP 2 (mT5 "worldwide" text tower, S/16 vision tower) once at import
# time so every request reuses the same weights. bfloat16 + SDPA attention keep
# memory and latency down on the Space's hardware.
model = AutoModel.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16", torch_dtype=torch.bfloat16, attn_implementation="sdpa")
# Matching processor: tokenizes the candidate texts and preprocesses the image.
processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")
def postprocess_metaclip(probs, labels):
    """Pair each candidate label with its probability.

    Args:
        probs: tensor of shape (1, num_labels) — softmaxed scores for one image.
        labels: list of label strings, aligned with the second axis of *probs*.

    Returns:
        dict mapping label -> float probability, the format gr.Label expects.
    """
    scores = probs[0]
    return {label: score.item() for label, score in zip(labels, scores)}
def metaclip_detector(image, texts):
    """Score *image* against the candidate *texts* with MetaCLIP 2.

    Args:
        image: PIL image to classify.
        texts: list of candidate label strings.

    Returns:
        Tensor of shape (1, len(texts)): softmax probabilities over the labels.
    """
    batch = processor(text=texts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        result = model(**batch)
    # logits_per_image has one row per image; softmax over the candidate texts
    # turns the similarity logits into a probability distribution.
    return result.logits_per_image.softmax(dim=1)
def infer(image, candidate_labels):
    """Gradio handler: classify *image* against a comma-separated label string.

    Args:
        image: PIL image from the gr.Image input.
        candidate_labels: e.g. "a cat, two cats, three cats".

    Returns:
        dict of label -> probability for the gr.Label output.
    """
    # strip() (not lstrip(" ")) so trailing whitespace and tabs don't leak into
    # the label keys; the filter drops empty entries from stray commas
    # such as "a,,b" or a trailing comma.
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)
# Build the Gradio UI: image + label textbox on the left, top-3 label scores on
# the right, with downloadable example images.
with gr.Blocks() as demo:
    gr.Markdown("# MetaCLIP 2 Zero-Shot Classification")
    gr.Markdown(
        "Test the performance of MetaCLIP 2 on zero-shot classification in this Space :point_down:"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            # Typo fix: "seperated" -> "separated" in the user-facing label.
            text_input = gr.Textbox(label="Input a list of labels (comma separated)")
            run_button = gr.Button("Run", visible=True)
        with gr.Column():
            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)

    # It's recommended to have local images for the examples.
    # For demonstration purposes, we download them if they don't exist.
    def download_image(url, filename):
        """Stream *url* to *filename* unless the file already exists."""
        import os
        if not os.path.exists(filename):
            # timeout so a stalled CDN response can't hang Space startup forever
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            with open(filename, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

    download_image("https://gradio-builds.s3.amazonaws.com/demo-files/baklava.jpg", "baklava.jpg")
    download_image("https://gradio-builds.s3.amazonaws.com/demo-files/cat.jpg", "cat.jpg")

    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
        ["./cat.jpg", "a cat, two cats, three cats"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
        fn=infer,
    )
    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output])

demo.launch()