import torch
from typing import Iterable
from transformers import AutoModel, AutoProcessor
import gradio as gr
from PIL import Image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
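# Gradio palettes define eleven shades (c50 through c950); attaching the new
# palette to the colors module lets the theme below reference it like a
# built-in hue.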
class OrangeRedTheme(Soft):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,  # use the new palette
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )

orange_red_theme = OrangeRedTheme()
# Load the MetaCLIP 2 checkpoint in bfloat16 with SDPA attention.
model = AutoModel.from_pretrained(
    "facebook/metaclip-2-mt5-worldwide-s16",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")
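# Hedged addition, not in the original Space: move the model to a GPU when one
# is available. metaclip_detector below sends inputs to model.device, so the
# demo still runs unchanged on CPU.
if torch.cuda.is_available():
    model = model.to("cuda")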
def postprocess_metaclip(probs, labels):
    # Map each candidate label to its softmax probability.
    return {labels[i]: probs[0][i].item() for i in range(len(labels))}
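# For reference (values illustrative): probs = tensor([[0.7, 0.2, 0.1]]) with
# labels = ["a cat", "a dog", "a bird"] yields
# {"a cat": 0.7, "a dog": 0.2, "a bird": 0.1}.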
def metaclip_detector(image, texts):
    # Score the image against every candidate text prompt.
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    inputs = inputs.to(model.device)  # keep inputs on the same device as the model
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs
def infer(image, candidate_labels):
    # Split the comma-separated label string into a clean list of prompts.
    candidate_labels = [label.strip() for label in candidate_labels.split(",")]
    probs = metaclip_detector(image, candidate_labels)
    return postprocess_metaclip(probs, labels=candidate_labels)
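# Quick smoke test outside Gradio (assumes ./cat.jpg exists next to app.py, as
# the examples below do):
#   print(infer(Image.open("./cat.jpg"), "a cat, a dog"))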
with gr.Blocks(theme=orange_red_theme) as demo:
    gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**")
    gr.Markdown(
        "Test the performance of MetaCLIP 2 on zero-shot classification in this Space."
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a list of labels (comma separated)")
            run_button = gr.Button("Run", visible=True)
        with gr.Column():
            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)
    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
        ["./cat.jpg", "a cat, two cats, three cats"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
        fn=infer,
    )
    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output])

demo.launch()