# metaclip-2-demo / app.py
# (Hugging Face Space header retained as comments: author prithivMLmods,
#  commit "update app", 6f00b27, verified)
import torch
from transformers import AutoModel, AutoProcessor
import gradio as gr
from PIL import Image
import requests
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# Register a custom "orange_red" palette on Gradio's color registry so the
# theme class below can reference it like any built-in hue.
_ORANGE_RED_SHADES = dict(
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
colors.orange_red = colors.Color(name="orange_red", **_ORANGE_RED_SHADES)
class OrangeRedTheme(Soft):
    """Gradio `Soft` theme variant: gray primary hue with an orange-red accent.

    All constructor arguments are keyword-only and forwarded to `Soft`;
    after initialization, `set()` overrides individual CSS variables for
    backgrounds, primary buttons, titles, and block shadows.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Fine-tune CSS variables on top of the Soft defaults: gradient page
        # background (light/dark), gradient primary buttons using the
        # orange_red secondary hue, bolder block titles, larger drop shadow.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            block_title_text_weight="600",
            block_shadow="*shadow_drop_lg",
        )
# Theme instance passed to gr.Blocks below.
orange_red_theme = OrangeRedTheme()
# Load the MetaCLIP 2 multilingual checkpoint in bfloat16 with SDPA attention.
# NOTE(review): the processor emits float32 pixel values; presumably the model
# handles the dtype mismatch with the bfloat16 weights — confirm on target
# hardware. Downloads weights on first run.
model = AutoModel.from_pretrained(
    "facebook/metaclip-2-mt5-worldwide-s16",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa"
)
# Companion processor: tokenizes text prompts and preprocesses images.
processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")
def postprocess_metaclip(probs, labels):
    """Map each candidate label to its probability for the first image.

    Args:
        probs: tensor of shape (num_images, num_labels); only row 0 is read.
        labels: label strings, in the same order as the probability columns.

    Returns:
        dict of {label: float probability} suitable for `gr.Label`.
    """
    scores = probs[0]
    return {label: scores[idx].item() for idx, label in enumerate(labels)}
def metaclip_detector(image, texts):
    """Score `image` against each text prompt with the module-level model.

    Args:
        image: PIL image (or anything the processor accepts as `images`).
        texts: list of candidate label strings.

    Returns:
        Softmax probabilities over the labels, shape (1, len(texts)).
    """
    batch = processor(text=texts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**batch).logits_per_image
    return logits.softmax(dim=1)
def infer(image, candidate_labels):
    """Gradio handler: run zero-shot classification on one image.

    Args:
        image: uploaded PIL image.
        candidate_labels: comma-separated label string from the textbox.

    Returns:
        dict of {label: probability} for `gr.Label`.
    """
    # Drop blank entries so trailing/double commas (e.g. "a, b,") don't
    # produce empty-string labels that get tokenized and scored.
    labels = [part.strip() for part in candidate_labels.split(",") if part.strip()]
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)
css = """
#root, body, html {
margin: 0;
padding: 0;
height: 100%;
}
.center-container {
max-width: 1000px;
margin: 0 auto !important;
display: flex;
flex-direction: column;
align-items: center;
}
#main-title h1 {
text-align: center !important;
width: 100%;
}
"""
# Build the UI: image + label textbox on the left, top-3 predictions on the
# right, with clickable examples underneath.
with gr.Blocks(css=css, theme=orange_red_theme) as demo:
    with gr.Column(elem_classes="center-container"):
        gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**", elem_id="main-title")
        gr.Markdown("This is the demo of MetaCLIP 2 for zero-shot classification.")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image", height=310)
                text_input = gr.Textbox(label="Input labels (comma separated)")
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                # Shows only the 3 highest-probability labels from infer().
                metaclip_output = gr.Label(
                    label="MetaCLIP 2 Output",
                    num_top_classes=3
                )
        # Example rows reference image files expected next to this script.
        gr.Examples(
            examples=[
                ["./baklava.jpg", "dessert on a plate, baklava"],
                ["./cat.jpg", "a cat, two cats, three cats"],
                ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
            ],
            inputs=[image_input, text_input],
            outputs=[metaclip_output],
            fn=infer,
        )
        run_button.click(
            fn=infer,
            inputs=[image_input, text_input],
            outputs=[metaclip_output]
        )
demo.launch()