"""Gradio demo: zero-shot image classification with MetaCLIP 2.

The user uploads an image and a comma-separated list of candidate labels;
the model scores each label against the image and the top matches are shown.
"""

import torch
from transformers import AutoModel, AutoProcessor
import gradio as gr
from PIL import Image
import requests
from typing import Iterable

from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# Register a custom "orange_red" palette on gradio's color registry so the
# theme below can reference it like any built-in hue.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)


class OrangeRedTheme(Soft):
    """Soft-based theme with a gray primary hue and orange-red accents."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Override individual CSS variables: gradient backgrounds plus
        # orange-red primary buttons in both light and dark modes.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            block_title_text_weight="600",
            block_shadow="*shadow_drop_lg",
        )


orange_red_theme = OrangeRedTheme()

# Load the MetaCLIP 2 checkpoint once at module level (from_pretrained returns
# the model in eval mode). bfloat16 + SDPA attention keep memory and latency down.
model = AutoModel.from_pretrained(
    "facebook/metaclip-2-mt5-worldwide-s16",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")


def postprocess_metaclip(probs, labels):
    """Map each candidate label to its probability.

    Args:
        probs: Tensor of shape (1, num_labels) — softmaxed image-text logits
            for a single image (only row 0 is read).
        labels: Candidate label strings, in the same order as the logits.

    Returns:
        dict[str, float] suitable for a ``gr.Label`` component.
    """
    return {label: probs[0][i].item() for i, label in enumerate(labels)}


def metaclip_detector(image, texts):
    """Score ``texts`` against ``image`` and return per-label probabilities.

    Args:
        image: PIL image to classify.
        texts: List of candidate label strings.

    Returns:
        Tensor of shape (1, len(texts)) with probabilities summing to 1.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the label axis turns image-text logits into probabilities.
    return outputs.logits_per_image.softmax(dim=1)


def infer(image, candidate_labels):
    """Gradio handler: parse the comma-separated labels and classify ``image``.

    Blank entries (e.g. from a trailing comma) are dropped so empty-string
    prompts are never sent to the model.

    Raises:
        gr.Error: If no non-empty label was provided.
    """
    labels = [part.strip() for part in candidate_labels.split(",") if part.strip()]
    if not labels:
        raise gr.Error("Please provide at least one comma-separated label.")
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)


# Page-level CSS: center the content column and the title.
css = """
#root, body, html { margin: 0; padding: 0; height: 100%; }
.center-container { max-width: 1000px; margin: 0 auto !important; display: flex; flex-direction: column; align-items: center; }
#main-title h1 { text-align: center !important; width: 100%; }
"""

with gr.Blocks(css=css, theme=orange_red_theme) as demo:
    with gr.Column(elem_classes="center-container"):
        gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**", elem_id="main-title")
        gr.Markdown("This is the demo of MetaCLIP 2 for zero-shot classification.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", height=310)
            text_input = gr.Textbox(label="Input labels (comma separated)")
            run_button = gr.Button("Run", variant="primary")
        with gr.Column():
            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)
    gr.Examples(
        examples=[
            ["./baklava.jpg", "dessert on a plate, baklava"],
            ["./cat.jpg", "a cat, two cats, three cats"],
            ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
        ],
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
        fn=infer,
    )
    run_button.click(
        fn=infer,
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
    )

if __name__ == "__main__":
    demo.launch()