File size: 5,120 Bytes
0a3df57
 
 
 
 
 
f79c9e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a3df57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f79c9e0
 
0a3df57
f79c9e0
0a3df57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from collections.abc import Iterable

import gradio as gr
import requests
import torch
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Attach a custom "orange_red" palette to gradio's `colors` module so it can
# be referenced as `colors.orange_red` (OrangeRedTheme uses it as the default
# secondary hue). Shades are listed from c50 through c950; c500 is #FF4500.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)

class OrangeRedTheme(Soft):
    """Gradio theme derived from ``Soft`` with orange-red accents.

    The constructor forwards hue, text-size and font choices to ``Soft`` and
    then overrides a set of theme variables: backgrounds and secondary
    buttons are drawn from the primary hue, while primary buttons and
    sliders are drawn from the secondary hue.

    All constructor arguments are keyword-only.

    Args:
        primary_hue: Palette (or palette name) for backgrounds and secondary
            buttons. Defaults to ``colors.gray``.
        secondary_hue: Palette (or palette name) for primary buttons and
            sliders. Defaults to the custom ``colors.orange_red`` palette.
        neutral_hue: Neutral palette forwarded to ``Soft``.
        text_size: Text size preset forwarded to ``Soft``.
        font: Main font, or fallback chain of fonts.
        font_mono: Monospace font, or fallback chain of fonts.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,  # custom palette attached to `colors`
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Values starting with "*" reference other theme variables (e.g. a
        # shade of the primary/secondary hue) rather than literal CSS values.
        super().set(
            # Page and block backgrounds (light / dark variants).
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary buttons: white text on a secondary-hue gradient.
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            # Secondary buttons: primary-hue fills.
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            # Sliders follow the secondary hue.
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            # Block chrome: titles, borders, shadows, padding, labels.
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )

# Theme instance handed to gr.Blocks when the UI is built.
orange_red_theme = OrangeRedTheme()

# One checkpoint id shared by the model and its processor.
MODEL_ID = "facebook/metaclip-2-mt5-worldwide-s16"

# Load the weights in bfloat16 and request the SDPA attention implementation.
model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

def postprocess_metaclip(probs, labels):
    """Map each label to its probability from the first row of ``probs``.

    Args:
        probs: Indexable container whose first row holds one score per label;
            each score must expose ``.item()`` (e.g. a 2-D torch tensor).
        labels: Sequence of label strings, in the same order as the scores.

    Returns:
        Dict of ``{label: score}`` with scores converted via ``.item()``.
        If a label occurs more than once, the last occurrence wins.

    Raises:
        IndexError: If ``labels`` has more entries than the first row of
            ``probs``.
    """
    # Only the first row is used, so look it up once.
    first_row = probs[0]
    return {label: first_row[i].item() for i, label in enumerate(labels)}


def metaclip_detector(image, texts):
    """Run the module-level model on ``image`` and ``texts``.

    The inputs are prepared by the module-level ``processor`` (PyTorch
    tensors, padded), and the result is the softmax of the model's
    ``logits_per_image`` along dim 1.
    """
    batch = processor(text=texts, images=image, return_tensors="pt", padding=True)
    # Inference only, so gradient tracking is disabled for the forward pass.
    with torch.no_grad():
        image_logits = model(**batch).logits_per_image
        return image_logits.softmax(dim=1)


def infer(image, candidate_labels):
    """Classify ``image`` against a comma-separated list of candidate labels.

    Args:
        image: Image to classify; passed unchanged to ``metaclip_detector``.
        candidate_labels: Comma-separated label text, e.g.
            ``"a cat, two cats, three cats"``.

    Returns:
        Dict mapping each distinct, non-empty label to its score, as built by
        ``postprocess_metaclip``.

    Raises:
        gr.Error: If no image was provided, or if no usable label remains
            after parsing ``candidate_labels``.
    """
    if image is None:
        raise gr.Error("Please provide an image before running.")

    # Trim whitespace on both sides and drop blanks, so inputs such as
    # "a cat , two cats," do not yield padded or empty labels.
    trimmed = (part.strip() for part in (candidate_labels or "").split(","))
    # dict.fromkeys removes duplicates while keeping first-seen order; a
    # repeated label would otherwise collapse into a single output key while
    # still being scored separately by the model.
    labels = list(dict.fromkeys(label for label in trimmed if label))
    if not labels:
        raise gr.Error("Please enter at least one label (comma separated).")

    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)

# Build the demo UI: image and label inputs on the left, results on the right.
with gr.Blocks(theme=orange_red_theme) as demo:
    gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**")
    gr.Markdown(
        "Test the performance of MetaCLIP 2 on zero-shot classification in this Space"
    )
    with gr.Row():
        with gr.Column():
            # type="pil" makes the component hand `infer` a PIL image.
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a list of labels (comma separated)")
            run_button = gr.Button("Run", visible=True)
        with gr.Column():
            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)

    # Each example is [image path, comma-separated labels]; the image paths
    # are relative to the working directory.
    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
        ["./cat.jpg", "a cat, two cats, three cats"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
        fn=infer,
    )
    # The Run button sends the current image and label text through `infer`.
    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output])
demo.launch()