# metaclip-2-demo / app.py — MetaCLIP 2 zero-shot image classification Gradio demo
import torch
from transformers import AutoModel, AutoProcessor
import gradio as gr
from PIL import Image
import requests
# Load the MetaCLIP 2 checkpoint and its processor once at import time so
# every inference request reuses the same in-memory model.
model = AutoModel.from_pretrained(
    "facebook/metaclip-2-mt5-worldwide-s16",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")
def postprocess_metaclip(probs, labels):
    """Map each candidate label to its probability.

    ``probs`` is a (1, num_labels) tensor; only the first (single-image)
    batch row is read. Returns a plain ``{label: float}`` dict suitable
    for ``gr.Label``.
    """
    return {label: probs[0][idx].item() for idx, label in enumerate(labels)}
def metaclip_detector(image, texts):
    """Score *texts* against *image* with MetaCLIP 2.

    Returns a (1, len(texts)) tensor of softmax probabilities over the
    candidate texts for the single input image.
    """
    # NOTE(review): the model weights are bfloat16 while the processor emits
    # float32 pixel values — confirm the model tolerates the mixed dtypes.
    batch = processor(text=texts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        result = model(**batch)
    # Normalise image-to-text logits into a probability distribution.
    return result.logits_per_image.softmax(dim=1)
def infer(image, candidate_labels):
    """Run zero-shot classification on *image* against comma-separated labels.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to classify.
    candidate_labels : str
        Comma-separated candidate labels, e.g. ``"a cat, two cats"``.

    Returns
    -------
    dict
        Mapping of each label to its softmax probability.
    """
    # Fix: use strip() instead of lstrip(" ") so trailing whitespace (and
    # tabs/newlines) around each label is removed too — previously
    # "a cat , dog" produced the label "a cat " with a trailing space.
    labels = [label.strip() for label in candidate_labels.split(",")]
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)
with gr.Blocks() as demo:
    gr.Markdown("# MetaCLIP 2 Zero-Shot Classification")
    gr.Markdown(
        "Test the performance of MetaCLIP 2 on zero-shot classification in this Space :point_down:"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            # Fix: corrected user-facing typo "seperated" -> "separated".
            text_input = gr.Textbox(label="Input a list of labels (comma separated)")
            run_button = gr.Button("Run", visible=True)
        with gr.Column():
            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)

    # It's recommended to have local images for the examples.
    # For demonstration purposes, we download them if they don't exist.
    def download_image(url, filename):
        """Download *url* to *filename*, skipping the request if the file exists."""
        import os
        if not os.path.exists(filename):
            # Fix: add a timeout so a stalled CDN cannot hang Space startup
            # forever — requests.get blocks indefinitely without one.
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            with open(filename, 'wb') as f:
                # Stream in 8 KiB chunks to avoid holding the whole image in memory.
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

    # Fetch the example images once so gr.Examples can reference local paths.
    download_image("https://gradio-builds.s3.amazonaws.com/demo-files/baklava.jpg", "baklava.jpg")
    download_image("https://gradio-builds.s3.amazonaws.com/demo-files/cat.jpg", "cat.jpg")

    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
        ["./cat.jpg", "a cat, two cats, three cats"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
        fn=infer,
    )
    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output])
demo.launch()