import torch
from transformers import AutoModel, AutoProcessor
import gradio as gr
from PIL import Image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
import warnings
warnings.filterwarnings(action="ignore")
# Custom orange-red palette, registered on gradio's color module so the
# theme below can use it as a primary/secondary hue.
_ORANGE_RED_SHADES = {
    "c50": "#FFF0E5", "c100": "#FFE0CC", "c200": "#FFC299", "c300": "#FFA366",
    "c400": "#FF8533", "c500": "#FF4500", "c600": "#E63E00", "c700": "#CC3700",
    "c800": "#B33000", "c900": "#992900", "c950": "#802200",
}
colors.orange_red = colors.Color(name="orange_red", **_ORANGE_RED_SHADES)
class OrangeRedTheme(Soft):
    """Soft-derived gradio theme restyled around the custom orange-red hue."""

    def __init__(self):
        super().__init__(
            primary_hue=colors.orange_red,
            secondary_hue=colors.orange_red,
            neutral_hue=colors.slate,
            text_size=sizes.text_lg,
            font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
            font_mono=(fonts.GoogleFont("IBM Plex Mono"), "monospace"),
        )
        # Overrides layered on top of Soft's defaults: gradient body and
        # primary-button fills, plus chunkier block borders/shadows.
        overrides = {
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_text_color": "white",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
        }
        super().set(**overrides)
orange_red_theme = OrangeRedTheme()

# CLIP checkpoint used for zero-shot image classification.
MODEL_ID = "openai/clip-vit-base-patch32"

# Pick the accelerator first so the model can be moved in one chained call.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID)
def postprocess_metaclip(probs, labels):
    """Map each candidate label to its predicted probability.

    Args:
        probs: tensor of shape (1, num_labels) — softmaxed logits for one image.
        labels: list of label strings, in the same order as the probability row.

    Returns:
        dict mapping label -> float probability (via ``Tensor.item()``).
    """
    return {label: probs[0][idx].item() for idx, label in enumerate(labels)}
def metaclip_detector(image, texts):
    """Score an image against candidate text labels with the CLIP model.

    Args:
        image: PIL image to classify.
        texts: list of candidate label strings.

    Returns:
        Softmaxed image-to-text probabilities, shape (1, len(texts)).
    """
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    # Bug fix: the processor returns CPU tensors; they must be moved to the
    # same device as the model, otherwise the forward pass crashes on CUDA.
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs
def infer(image, candidate_labels):
    """Run zero-shot classification for a comma-separated label string.

    Args:
        image: PIL image uploaded by the user.
        candidate_labels: comma-separated label string, e.g. "a cat, a dog".

    Returns:
        dict mapping each (non-empty) label to its probability.
    """
    # Robustness: drop empty entries so a trailing or doubled comma in the
    # user's input does not produce an empty-string candidate label.
    labels = [part.strip() for part in candidate_labels.split(",") if part.strip()]
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)
# Custom CSS injected into the gradio page: widens the #container column
# (wider still on >=1600px screens) and enlarges the title heading.
css_style = """
#container {
max-width: 1280px; /* wider layout */
margin: auto;
}
@media (min-width: 1600px) {
#container {
max-width: 1440px;
}
}
#title h1 {
font-size: 2.4em !important;
}
"""
# UI layout. Fixes: the browser-tab title said "AI Document Summarizer",
# which belongs to a different app; and theme/css must be passed to
# gr.Blocks(...) — launch() does not accept them.
with gr.Blocks(
    title="Open AI Zero-Shot Classification",
    theme=orange_red_theme,
    css=css_style,
) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("# **Open AI Zero-Shot Classification**", elem_id="title")
        gr.Markdown("This is the demo of model 'openai/clip-vit-base-patch32' for zero-shot classification.")
        with gr.Row(equal_height=True):
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image", height=310)
                text_input = gr.Textbox(label="Input labels (comma separated)")
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                metaclip_output = gr.Label(
                    label="Open AI Zero-Shot Classification Output",
                    num_top_classes=5
                )
        with gr.Row(equal_height=True):
            gr.Examples(
                examples=[
                    ["./zebra.jpg", "a photo of a zebra, a photo of a horse, a photo of a donkey"],
                    ["./cat.jpg", "a photo of a cat, a photo of two cats, a photo of three cats"],
                    ["./fridge.jpg", "a photo of a fridge, a photo of a cupboard, a photo of a wardrobe"]
                ],
                inputs=[image_input, text_input],
                outputs=[metaclip_output],
                fn=infer,
            )
        run_button.click(
            fn=infer,
            inputs=[image_input, text_input],
            outputs=[metaclip_output]
        )
if __name__ == "__main__":
    # Bug fix: gradio's launch() has no `theme` or `css` parameters — passing
    # them raises TypeError. Theming/CSS must be configured on gr.Blocks(...).
    demo.queue().launch(
        show_error=True,
        server_name="0.0.0.0",  # bind all interfaces (e.g. for Docker/Spaces)
        server_port=7860,
        debug=True,
    )