from collections.abc import Iterable  # FIX: was missing -> NameError in OrangeRedTheme.__init__ annotations

import torch
from transformers import AutoModel, AutoProcessor
import gradio as gr
from PIL import Image
import requests
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# Register a custom orange-red palette on gradio's color registry so the theme
# below can reference it like any built-in hue.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)


class OrangeRedTheme(Soft):
    """Soft-based gradio theme whose accent color is the custom orange_red hue.

    Only colors/fonts/sizing are customized; all widget behavior is inherited
    from `gradio.themes.Soft`.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,  # the custom hue registered above
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Override individual CSS variables on top of the Soft defaults.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )


orange_red_theme = OrangeRedTheme()

# Load the MetaCLIP 2 checkpoint once at startup; inference below reuses these.
# NOTE(review): bf16 + SDPA attention — assumes the serving hardware supports
# bfloat16; confirm for the target Space runtime.
model = AutoModel.from_pretrained(
    "facebook/metaclip-2-mt5-worldwide-s16",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")


def postprocess_metaclip(probs, labels):
    """Map each candidate label to its probability for the (single) input image.

    Args:
        probs: tensor of shape (1, num_labels) — softmaxed image-text logits.
        labels: list of label strings, same order as the probability columns.

    Returns:
        dict[str, float] suitable for a `gr.Label` component.
    """
    return {label: prob.item() for label, prob in zip(labels, probs[0])}


def metaclip_detector(image, texts):
    """Run MetaCLIP 2 on one image against a list of text labels.

    Returns a (1, len(texts)) tensor of per-label probabilities.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    return logits_per_image.softmax(dim=1)


def infer(image, candidate_labels):
    """Gradio callback: split the comma-separated label string and classify.

    FIX: use strip() (was lstrip(" ")) so trailing whitespace on labels is
    also removed.
    """
    candidate_labels = [label.strip() for label in candidate_labels.split(",")]
    probs = metaclip_detector(image, candidate_labels)
    return postprocess_metaclip(probs, labels=candidate_labels)


with gr.Blocks(theme=orange_red_theme) as demo:
    gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**")
    gr.Markdown(
        "Test the performance of MetaCLIP 2 on zero-shot classification in this Space"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            # FIX: "seperated" -> "separated"
            text_input = gr.Textbox(label="Input a list of labels (comma separated)")
            run_button = gr.Button("Run", visible=True)
        with gr.Column():
            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)

    # FIX: first example string was broken across a line break in the source.
    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
        ["./cat.jpg", "a cat, two cats, three cats"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
        fn=infer,
    )
    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output])

# Guard launch so importing this module (e.g. for tests) does not start a server;
# running the script directly behaves exactly as before.
if __name__ == "__main__":
    demo.launch()