| import torch | |
| from transformers import AutoModel, AutoProcessor | |
| import gradio as gr | |
| from PIL import Image | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
import warnings
# Silence all warnings globally: torch/transformers emit noisy deprecation
# messages that clutter the demo's console output.
warnings.filterwarnings(action="ignore")
# Register a custom "orange_red" palette on Gradio's color registry so the
# theme class below can reference it like any built-in hue.
_ORANGE_RED_SHADES = {
    "c50": "#FFF0E5",
    "c100": "#FFE0CC",
    "c200": "#FFC299",
    "c300": "#FFA366",
    "c400": "#FF8533",
    "c500": "#FF4500",
    "c600": "#E63E00",
    "c700": "#CC3700",
    "c800": "#B33000",
    "c900": "#992900",
    "c950": "#802200",
}
colors.orange_red = colors.Color(name="orange_red", **_ORANGE_RED_SHADES)
class OrangeRedTheme(Soft):
    """Soft-derived Gradio theme built on the custom orange_red palette.

    Uses orange_red for both primary and secondary hues, slate neutrals,
    larger body text, and the Outfit / IBM Plex Mono Google fonts.
    """

    def __init__(self):
        super().__init__(
            primary_hue=colors.orange_red,
            secondary_hue=colors.orange_red,
            neutral_hue=colors.slate,
            text_size=sizes.text_lg,
            font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
            font_mono=(fonts.GoogleFont("IBM Plex Mono"), "monospace"),
        )
        # Fine-grained CSS-variable overrides on top of the Soft defaults:
        # gradient page background, gradient primary buttons, and heavier
        # block borders/shadows.
        overrides = {
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_text_color": "white",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
        }
        super().set(**overrides)
orange_red_theme = OrangeRedTheme()

MODEL_ID = "openai/clip-vit-base-patch32"

# Resolve the compute device up front so the model can be placed on it as
# soon as it is loaded.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP in bfloat16 with the SDPA attention backend, then move it to the
# chosen device in one chained call.
model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID)
def postprocess_metaclip(probs, labels):
    """Turn a (1, num_labels) probability tensor into a label->score dict.

    Reads only the first row of *probs* (the demo classifies one image at a
    time) and pairs each label with its scalar probability.
    """
    row = probs[0]
    return dict(zip(labels, (p.item() for p in row)))
def metaclip_detector(image, texts):
    """Score *image* against the candidate *texts* with CLIP.

    Returns the softmax over ``logits_per_image`` as a (1, len(texts))
    probability tensor.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    # BUG FIX: the model was moved to `device` at load time, but the processor
    # always returns CPU tensors — on CUDA the forward pass would fail with a
    # device mismatch. Move the whole batch to the model's device.
    inputs = inputs.to(device)
    # BUG FIX: the model weights are bfloat16 while the processor emits
    # float32 pixel_values; cast them to the model dtype to avoid a
    # float/bfloat16 mismatch in the vision tower.
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits_per_image.softmax(dim=1)
def infer(image, candidate_labels):
    """Gradio handler: classify *image* against comma-separated labels.

    Splits *candidate_labels* on commas, strips whitespace, and drops empty
    entries (robust against trailing commas or a blank textbox), then returns
    a label->probability dict for the ``gr.Label`` component.
    """
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)
# Custom CSS: cap the content width (wider cap on very large screens) and
# enlarge the page title. Applied via the `css=` argument when the UI is built.
css_style = """
#container {
    max-width: 1280px; /* wider layout */
    margin: auto;
}
@media (min-width: 1600px) {
    #container {
        max-width: 1440px;
    }
}
#title h1 {
    font-size: 2.4em !important;
}
"""
# BUG FIX: `theme=` and `css=` are constructor arguments of gr.Blocks, not of
# launch() — passing them to launch() raises a TypeError in current Gradio.
# Also fixed the copy-pasted window title ("AI Document Summarizer") to match
# what this app actually does.
with gr.Blocks(
    theme=orange_red_theme,
    css=css_style,
    title="OpenAI Zero-Shot Classification",
) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("# **Open AI Zero-Shot Classification**", elem_id="title")
        gr.Markdown("This is the demo of model 'openai/clip-vit-base-patch32' for zero-shot classification.")
        with gr.Row(equal_height=True):
            # Left column: inputs; right column: top-5 label scores.
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image", height=310)
                text_input = gr.Textbox(label="Input labels (comma separated)")
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                metaclip_output = gr.Label(
                    label="Open AI Zero-Shot Classification Output",
                    num_top_classes=5
                )
        with gr.Row(equal_height=True):
            gr.Examples(
                examples=[
                    ["./zebra.jpg", "a photo of a zebra, a photo of a horse, a photo of a donkey"],
                    ["./cat.jpg", "a photo of a cat, a photo of two cats, a photo of three cats"],
                    ["./fridge.jpg", "a photo of a fridge, a photo of a cupboard, a photo of a wardrobe"]
                ],
                inputs=[image_input, text_input],
                outputs=[metaclip_output],
                fn=infer,
            )
        run_button.click(
            fn=infer,
            inputs=[image_input, text_input],
            outputs=[metaclip_output]
        )

if __name__ == "__main__":
    demo.queue().launch(
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )