File size: 4,201 Bytes
5aa6736
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
from transformers import AutoModel, AutoProcessor
import gradio as gr
from PIL import Image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

import warnings
warnings.filterwarnings(action="ignore")

# Register a custom "orange_red" palette on Gradio's color registry so the
# theme class below can reference it like a built-in hue.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)

class OrangeRedTheme(Soft):
    """Gradio Soft theme tinted with the custom ``orange_red`` palette."""

    def __init__(self):
        # Base palette/typography handed to the Soft constructor.
        base_settings = dict(
            primary_hue=colors.orange_red,
            secondary_hue=colors.orange_red,
            neutral_hue=colors.slate,
            text_size=sizes.text_lg,
            font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
            font_mono=(fonts.GoogleFont("IBM Plex Mono"), "monospace"),
        )
        super().__init__(**base_settings)
        # Per-variable CSS overrides layered on top of the Soft defaults.
        super().set(
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_text_color="white",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
        )

# Single shared theme instance used by the demo app.
orange_red_theme = OrangeRedTheme()

MODEL_ID = "openai/clip-vit-base-patch32"

# Load the CLIP checkpoint once at import time, in bfloat16 with SDPA attention.
model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# Inference-only service: switch off training-mode behavior (dropout etc.).
model.eval()

def postprocess_metaclip(probs, labels):
    """Map each candidate label to its probability taken from row 0 of `probs`.

    `probs` is expected to be a 2-D tensor whose first row holds one score
    per entry of `labels`; the result is a plain {label: float} dict.
    """
    return {lbl: probs[0][idx].item() for idx, lbl in enumerate(labels)}

def metaclip_detector(image, texts):
    """Score `image` against each candidate text with the global CLIP model.

    Args:
        image: input image (anything the CLIP processor accepts, e.g. PIL).
        texts: list of candidate label strings.

    Returns:
        Softmax probability tensor from `logits_per_image` — one row per
        image with one column per text.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    # BUG FIX: the model was moved to `device` (possibly CUDA) at module level,
    # but the processor returns CPU tensors — forward() would crash with a
    # device mismatch. Move the whole batch to the model's device first.
    inputs = inputs.to(device)
    # The model is loaded in bfloat16 while the processor emits float32 pixel
    # values; cast so the vision tower's weight/input dtypes agree.
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    with torch.no_grad():
        outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1)
    return probs

def infer(image, candidate_labels):
    """Gradio callback: classify `image` against a comma-separated label string.

    Args:
        image: uploaded PIL image.
        candidate_labels: labels as one comma-separated string, e.g.
            "a photo of a cat, a photo of a dog".

    Returns:
        {label: probability} dict consumed by the gr.Label component.
    """
    # Robustness: drop empty entries caused by trailing commas, double commas,
    # or an all-whitespace input — empty prompts would be scored otherwise.
    labels = [lbl.strip() for lbl in candidate_labels.split(",") if lbl.strip()]
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)

# Page-layout CSS injected into the Blocks app: widens the main container
# (wider still on very large screens) and enlarges the page title.
css_style = """
#container {
    max-width: 1280px;   /* wider layout */
    margin: auto;
}

@media (min-width: 1600px) {
    #container {
        max-width: 1440px;
    }
}

#title h1 {
    font-size: 2.4em !important;
}
"""

# Build the UI.
# BUG FIX: `theme=` and `css=` are gr.Blocks() constructor parameters, not
# launch() parameters — passed to launch() they were never applied (and recent
# Gradio versions raise TypeError). They now go to the constructor.
# Also fixed the browser-tab title, which was a copy-paste leftover
# ("AI Document Summarizer") unrelated to this classification demo.
with gr.Blocks(
    title="OpenAI CLIP Zero-Shot Classification",
    theme=orange_red_theme,
    css=css_style,
) as demo:
    with gr.Column(elem_id="container"):

        gr.Markdown("# **Open AI Zero-Shot Classification**", elem_id="title")
        gr.Markdown("This is the demo of model 'openai/clip-vit-base-patch32' for zero-shot classification.")

        with gr.Row(equal_height=True):
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image", height=310)
                text_input = gr.Textbox(label="Input labels (comma separated)")
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                metaclip_output = gr.Label(
                    label="Open AI Zero-Shot Classification Output",
                    num_top_classes=5,
                )

        with gr.Row(equal_height=True):
            # NOTE(review): example image paths are relative — assumes the app
            # is launched from the directory containing them; confirm.
            gr.Examples(
                examples=[
                    ["./zebra.jpg", "a photo of a zebra, a photo of a horse, a photo of a donkey"],
                    ["./cat.jpg", "a photo of a cat, a photo of two cats, a photo of three cats"],
                    ["./fridge.jpg", "a photo of a fridge, a photo of a cupboard, a photo of a wardrobe"],
                ],
                inputs=[image_input, text_input],
                outputs=[metaclip_output],
                fn=infer,
            )

        run_button.click(
            fn=infer,
            inputs=[image_input, text_input],
            outputs=[metaclip_output],
        )

if __name__ == "__main__":
    demo.queue().launch(
        show_error=True,
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,
        debug=True,
    )