import torch
from transformers import AutoModel, AutoProcessor
import gradio as gr
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
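# Custom orange-red palette, registered on gradio's color registry so the
# theme below can reference it as a named hue.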
colors.orange_red = colors.Color(
name="orange_red",
c50="#FFF0E5",
c100="#FFE0CC",
c200="#FFC299",
c300="#FFA366",
c400="#FF8533",
c500="#FF4500",
c600="#E63E00",
c700="#CC3700",
c800="#B33000",
c900="#992900",
c950="#802200",
)
class OrangeRedTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.gray,
secondary_hue: colors.Color | str = colors.orange_red,
neutral_hue: colors.Color | str = colors.slate,
text_size: sizes.Size | str = sizes.text_lg,
font: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
),
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
text_size=text_size,
font=font,
font_mono=font_mono,
)
super().set(
background_fill_primary="*primary_50",
background_fill_primary_dark="*primary_900",
body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
button_primary_text_color="white",
button_primary_text_color_hover="white",
button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
block_title_text_weight="600",
block_shadow="*shadow_drop_lg",
)
orange_red_theme = OrangeRedTheme()
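# Load the MetaCLIP 2 checkpoint and its processor once at startup; keep the
# weights in bfloat16 and run on GPU when one is available.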
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(
    "facebook/metaclip-2-mt5-worldwide-s16",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).to(device)
processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")
def postprocess_metaclip(probs, labels):
    # Map each candidate label to its probability so gr.Label can rank them.
    return {label: prob.item() for label, prob in zip(labels, probs[0])}
def metaclip_detector(image, texts):
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
    # The model weights are bfloat16, so cast the pixel values to match.
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits_per_image.softmax(dim=1)
def infer(image, candidate_labels):
    # Split the comma-separated label string into a clean candidate list.
    candidate_labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
    probs = metaclip_detector(image, candidate_labels)
    return postprocess_metaclip(probs, labels=candidate_labels)
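# Quick sanity check outside the UI (hypothetical; assumes a cat.jpg next to app.py):
#   from PIL import Image
#   print(infer(Image.open("./cat.jpg"), "a cat, a dog, a bird"))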
css = """
#root, body, html {
margin: 0;
padding: 0;
height: 100%;
}
.center-container {
max-width: 1000px;
margin: 0 auto !important;
display: flex;
flex-direction: column;
align-items: center;
}
#main-title h1 {
text-align: center !important;
width: 100%;
}
"""
with gr.Blocks(css=css, theme=orange_red_theme) as demo:
with gr.Column(elem_classes="center-container"):
gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**", elem_id="main-title")
gr.Markdown("This is the demo of MetaCLIP 2 for zero-shot classification.")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Upload Image", height=310)
text_input = gr.Textbox(label="Input labels (comma separated)")
run_button = gr.Button("Run", variant="primary")
with gr.Column():
metaclip_output = gr.Label(
label="MetaCLIP 2 Output",
num_top_classes=3
)
gr.Examples(
examples=[
["./baklava.jpg", "dessert on a plate, baklava"],
["./cat.jpg", "a cat, two cats, three cats"],
["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
],
inputs=[image_input, text_input],
outputs=[metaclip_output],
fn=infer,
)
run_button.click(
fn=infer,
inputs=[image_input, text_input],
outputs=[metaclip_output]
)
demo.launch()