# anyonehomep1mane
# Code Changes
# 5aa6736
import torch
from transformers import AutoModel, AutoProcessor
import gradio as gr
from PIL import Image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
import warnings
warnings.filterwarnings(action="ignore")
# Register a custom orange-red palette on gradio's color registry so the
# theme class below can reference it exactly like a built-in hue.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
class OrangeRedTheme(Soft):
    """Gradio Soft theme tinted with the custom orange-red palette.

    Sets the base hues/typography via the Soft constructor, then layers a
    handful of fine-grained CSS-variable overrides on top.
    """

    def __init__(self):
        # Base palette and typography choices for the Soft theme.
        super().__init__(
            primary_hue=colors.orange_red,
            secondary_hue=colors.orange_red,
            neutral_hue=colors.slate,
            text_size=sizes.text_lg,
            font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
            font_mono=(fonts.GoogleFont("IBM Plex Mono"), "monospace"),
        )
        # Targeted overrides applied on top of the Soft defaults;
        # "*name" tokens are resolved by gradio against the theme palette.
        overrides = dict(
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_text_color="white",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
        )
        super().set(**overrides)
# Single theme instance passed to the UI below.
orange_red_theme = OrangeRedTheme()
# CLIP checkpoint used for zero-shot image classification.
MODEL_ID = "openai/clip-vit-base-patch32"
# Load weights in bfloat16 with PyTorch SDPA attention to reduce memory use.
model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa"
)
# Processor handles both text tokenization and image preprocessing.
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Prefer GPU when available; weights are moved there eagerly at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
def postprocess_metaclip(probs, labels):
    """Map each candidate label to its probability as a plain Python float.

    `probs` is expected to be a (1, len(labels)) tensor of per-label scores;
    only the first row is read.
    """
    return {label: probs[0][idx].item() for idx, label in enumerate(labels)}
def metaclip_detector(image, texts):
    """Score `image` against the candidate `texts` with the CLIP model.

    Returns a (1, len(texts)) tensor of softmax probabilities over the texts.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    # Bug fix: the model was moved to `device` at load time but the inputs
    # never were, so any CUDA run failed with a device-mismatch error.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Bug fix: the model is loaded in bfloat16, while the processor emits
    # float32 pixel values; cast the image tensor to the model's dtype so the
    # vision stem does not reject it. Text token ids stay integer-typed.
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax across the text candidates for the single input image.
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs
def infer(image, candidate_labels):
    """Gradio callback: classify `image` against a comma-separated label string.

    Returns a dict of {label: probability} suitable for a gr.Label component.
    """
    # Robustness fix: drop blank entries so inputs like "a,,b" or a trailing
    # comma don't send empty-string prompts to the model.
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
    if not labels:
        # Nothing to classify against; avoid crashing the processor on an
        # empty prompt list and just show an empty result.
        return {}
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)
css_style = """
#container {
max-width: 1280px; /* wider layout */
margin: auto;
}
@media (min-width: 1600px) {
#container {
max-width: 1440px;
}
}
#title h1 {
font-size: 2.4em !important;
}
"""
# Build the UI. Two fixes here:
#  1. `theme=` and `css=` are gr.Blocks constructor arguments, not launch()
#     arguments — previously they were passed to launch(), so the custom
#     theme and CSS defined above were never applied.
#  2. The browser-tab title said "AI Document Summarizer", a copy-paste error;
#     it now matches what the app actually does.
with gr.Blocks(
    theme=orange_red_theme,
    css=css_style,
    title="Open AI Zero-Shot Classification",
) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("# **Open AI Zero-Shot Classification**", elem_id="title")
        gr.Markdown("This is the demo of model 'openai/clip-vit-base-patch32' for zero-shot classification.")
        with gr.Row(equal_height=True):
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image", height=310)
                text_input = gr.Textbox(label="Input labels (comma separated)")
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                metaclip_output = gr.Label(
                    label="Open AI Zero-Shot Classification Output",
                    num_top_classes=5,
                )
        with gr.Row(equal_height=True):
            # Clickable example pairs; paths are relative to the working dir.
            gr.Examples(
                examples=[
                    ["./zebra.jpg", "a photo of a zebra, a photo of a horse, a photo of a donkey"],
                    ["./cat.jpg", "a photo of a cat, a photo of two cats, a photo of three cats"],
                    ["./fridge.jpg", "a photo of a fridge, a photo of a cupboard, a photo of a wardrobe"],
                ],
                inputs=[image_input, text_input],
                outputs=[metaclip_output],
                fn=infer,
            )
        run_button.click(
            fn=infer,
            inputs=[image_input, text_input],
            outputs=[metaclip_output],
        )

if __name__ == "__main__":
    # launch() keeps only its own options; theme/css moved to gr.Blocks above.
    demo.queue().launch(
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )