# Hugging Face Space (status: Running) - app.py, file size: 5,120 bytes.
# Revisions shown by the file viewer: 0a3df57, f79c9e0.
from collections.abc import Iterable

import gradio as gr
import requests
import torch
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from PIL import Image
from transformers import AutoModel, AutoProcessor
# Register a custom "orange_red" palette on gradio's ``colors`` module so it
# can be referenced like a built-in hue (it is used as a default hue by the
# theme class defined right after this statement).
_orange_red_shades = {
    "c50": "#FFF0E5",
    "c100": "#FFE0CC",
    "c200": "#FFC299",
    "c300": "#FFA366",
    "c400": "#FF8533",
    "c500": "#FF4500",
    "c600": "#E63E00",
    "c700": "#CC3700",
    "c800": "#B33000",
    "c900": "#992900",
    "c950": "#802200",
}
colors.orange_red = colors.Color(name="orange_red", **_orange_red_shades)
class OrangeRedTheme(Soft):
    """Gradio theme derived from ``Soft`` with gray surfaces and orange-red accents.

    The constructor forwards the hue, text-size and font choices to ``Soft``
    and then overrides a set of theme variables (backgrounds, buttons,
    sliders, block styling) via ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        # Default evaluated at definition time: ``colors.orange_red`` must
        # already be registered on the ``colors`` module.
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        """Build the theme.

        Args:
            primary_hue: Hue used for the ``*primary_*`` variables
                (backgrounds, secondary buttons). Defaults to gray.
            secondary_hue: Hue used for the ``*secondary_*`` variables
                (primary buttons, sliders). Defaults to the custom
                orange-red palette.
            neutral_hue: Neutral hue forwarded to ``Soft``. Defaults to slate.
            text_size: Base text size forwarded to ``Soft``.
            font: Main font, or an iterable of fonts used as a fallback chain.
            font_mono: Monospace font, or an iterable of fallbacks.
        """
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Override individual theme variables; "*name" values reference other
        # theme variables (e.g. shades of the primary/secondary hue).
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )
# Instantiate the custom theme once; it is handed to gr.Blocks further down.
orange_red_theme = OrangeRedTheme()

# Load the MetaCLIP 2 checkpoint and its matching processor at import time so
# every request reuses the same objects.
_CHECKPOINT = "facebook/metaclip-2-mt5-worldwide-s16"
model = AutoModel.from_pretrained(
    _CHECKPOINT,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained(_CHECKPOINT)
def postprocess_metaclip(probs: torch.Tensor, labels: list[str]) -> dict[str, float]:
    """Map each label to its probability from the first row of *probs*.

    Args:
        probs: Indexable 2-D structure whose elements expose ``.item()``
            (in this file: the softmax output of ``metaclip_detector``).
            Only row 0 is read.
        labels: Labels in the same order as the columns of ``probs[0]``.

    Returns:
        A dict ``{label: probability}``. If a label occurs more than once,
        the later occurrence overwrites the earlier one.
    """
    return {label: probs[0][i].item() for i, label in enumerate(labels)}
def metaclip_detector(image, texts):
    """Score *image* against every candidate text with the module-level model.

    The image and texts are encoded by the module-level ``processor`` (padded,
    returned as PyTorch tensors), run through ``model`` without gradient
    tracking, and the image-to-text logits are normalised with a softmax over
    dimension 1.

    Returns:
        The softmax of ``logits_per_image`` along dim 1.
    """
    batch = processor(text=texts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_logits = model(**batch).logits_per_image
        return image_logits.softmax(dim=1)
def infer(image, candidate_labels):
    """Classify *image* against a comma-separated list of labels.

    Args:
        image: The image to classify (the UI supplies a PIL image), or
            ``None`` when the user has not provided one.
        candidate_labels: Comma-separated label text, e.g.
            ``"a cat, two cats"``. Whitespace around each label is ignored
            and empty entries (such as a trailing comma) are dropped.

    Returns:
        A dict mapping each label to its probability.

    Raises:
        gr.Error: If no image was provided or no non-empty label remains,
            so the UI shows a readable message instead of a processor
            traceback.
    """
    if image is None:
        raise gr.Error("Please provide an image.")

    # Strip both sides (not only leading spaces) and discard blank entries so
    # the model never scores an empty or whitespace-padded prompt.
    stripped = (label.strip() for label in (candidate_labels or "").split(","))
    labels = [label for label in stripped if label]
    if not labels:
        raise gr.Error("Please enter at least one label.")

    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)
# Build the UI: inputs (image, labels, run button) on the left, the
# classification result on the right, and clickable examples underneath.
with gr.Blocks(theme=orange_red_theme) as demo:
    gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**")
    gr.Markdown(
        "Test the performance of MetaCLIP 2 on zero-shot classification in this Space"
    )
    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio hand the callback a PIL image.
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a list of labels (comma separated)")
            run_button = gr.Button("Run", visible=True)
        with gr.Column():
            # Only the three highest-scoring labels are displayed.
            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)

    # Each example is [image path, comma-separated candidate labels].
    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
        ["./cat.jpg", "a cat, two cats, three cats"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
        fn=infer,
    )
    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output])

demo.launch()