"""Gradio demo: zero-shot image classification with openai/clip-vit-base-patch32."""

import warnings

import gradio as gr
import torch
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from PIL import Image
from transformers import AutoModel, AutoProcessor

warnings.filterwarnings(action="ignore")

# Custom orange-red palette registered on gradio's `colors` module so the
# theme below can reference it as its primary/secondary hue.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)


class OrangeRedTheme(Soft):
    """Soft-based Gradio theme using the orange_red palette defined above."""

    def __init__(self):
        super().__init__(
            primary_hue=colors.orange_red,
            secondary_hue=colors.orange_red,
            neutral_hue=colors.slate,
            text_size=sizes.text_lg,
            font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
            font_mono=(fonts.GoogleFont("IBM Plex Mono"), "monospace"),
        )
        super().set(
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_text_color="white",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
        )


orange_red_theme = OrangeRedTheme()

MODEL_ID = "openai/clip-vit-base-patch32"

model = AutoModel.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()  # inference-only: disable dropout/batch-norm training behavior


def postprocess_metaclip(probs, labels):
    """Map each candidate label to its probability.

    Args:
        probs: tensor of shape (1, num_labels) — softmaxed image-text logits.
        labels: list of label strings, same order as the probs columns.

    Returns:
        dict mapping label -> float probability (first batch row).
    """
    return {labels[i]: probs[0][i].item() for i in range(len(labels))}


def metaclip_detector(image, texts):
    """Run CLIP zero-shot classification of `image` against `texts`.

    Returns a (1, len(texts)) tensor of softmax probabilities.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    # BUG FIX: the processor output stays on CPU in float32, while the model
    # was moved to `device` and loaded in bfloat16 — without this the forward
    # pass fails on CUDA (device mismatch) and on dtype mismatch.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs


def infer(image, candidate_labels):
    """Gradio handler: comma-separated label string -> {label: probability}."""
    # Drop empty entries so a trailing comma or blank input cannot crash the model call.
    candidate_labels = [l.strip() for l in candidate_labels.split(",") if l.strip()]
    if image is None or not candidate_labels:
        return {}
    probs = metaclip_detector(image, candidate_labels)
    return postprocess_metaclip(probs, labels=candidate_labels)


css_style = """
#container {
    max-width: 1280px; /* wider layout */
    margin: auto;
}
@media (min-width: 1600px) {
    #container { max-width: 1440px; }
}
#title h1 { font-size: 2.4em !important; }
"""

# BUG FIX: `theme` and `css` must be passed to gr.Blocks(), not launch() —
# launch() does not accept them, so the custom theme/CSS were never applied.
# Also fixed the copy-pasted window title ("AI Document Summarizer") to match
# what this demo actually is.
with gr.Blocks(
    theme=orange_red_theme, css=css_style, title="Open AI Zero-Shot Classification"
) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("# **Open AI Zero-Shot Classification**", elem_id="title")
        gr.Markdown("This is the demo of model 'openai/clip-vit-base-patch32' for zero-shot classification.")
        with gr.Row(equal_height=True):
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image", height=310)
                text_input = gr.Textbox(label="Input labels (comma separated)")
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                metaclip_output = gr.Label(
                    label="Open AI Zero-Shot Classification Output", num_top_classes=5
                )
        with gr.Row(equal_height=True):
            gr.Examples(
                examples=[
                    ["./zebra.jpg", "a photo of a zebra, a photo of a horse, a photo of a donkey"],
                    ["./cat.jpg", "a photo of a cat, a photo of two cats, a photo of three cats"],
                    ["./fridge.jpg", "a photo of a fridge, a photo of a cupboard, a photo of a wardrobe"],
                ],
                inputs=[image_input, text_input],
                outputs=[metaclip_output],
                fn=infer,
            )
        run_button.click(
            fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output]
        )

if __name__ == "__main__":
    demo.queue().launch(
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )