openai-clip-vit-base-patch32 / zeroShot_app.py
yilchenko's picture
rm: unnecessary comments (#1)
d441df8 verified
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
# Hub identifier of the CLIP checkpoint; the model weights and the processor
# are both loaded from this one name.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# Both objects are created once at import time and then read as module-level
# globals by the classification function below.
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
def classify_image(image, candidate_labels):
    """Score an image against free-text labels with CLIP (zero-shot).

    Args:
        image: The image to classify, in a form accepted by the CLIP
            processor (the Gradio input in this file is configured to pass
            a PIL image). ``None`` is rejected.
        candidate_labels: Either a comma-separated string of labels or an
            iterable of label strings. For a string, each piece is stripped
            of surrounding whitespace. Blank labels and repeated labels are
            dropped; the first occurrence of each label is kept.

    Returns:
        A dict mapping each remaining label to its softmax probability, in
        the order the labels were given.

    Raises:
        ValueError: If no image is given, if no non-empty label remains
            after cleaning, or if the model returns a number of logits that
            does not match the number of labels.
    """
    # Validate before touching the model so bad input fails fast with a
    # readable message instead of an error from inside the processor.
    if image is None:
        raise ValueError("No image provided.")

    if isinstance(candidate_labels, str):
        labels = [label.strip() for label in candidate_labels.split(",")]
    else:
        labels = list(candidate_labels or [])

    # Blank entries (e.g. from a trailing comma) would be scored as empty
    # prompts, and duplicates would split the probability mass and then
    # collapse in the result dict. dict.fromkeys keeps first-seen order.
    labels = list(dict.fromkeys(label for label in labels if label))
    if not labels:
        raise ValueError("At least one non-empty candidate label is required.")

    # Tokenize the labels and preprocess the image into model-ready tensors.
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # One row for the single image, one column per label.
    logits_per_image = outputs.logits_per_image
    if logits_per_image.size(1) != len(labels):
        raise ValueError("Mismatch between logits and candidate labels.")

    # Softmax across labels, then drop the image dimension to get a flat list.
    probs = logits_per_image.softmax(dim=1).squeeze(0).tolist()
    return dict(zip(labels, probs))
# Input widgets: an image (handed to the callback as a PIL image) and a
# free-text box for the comma-separated candidate labels.
image_input = gr.Image(type="pil")
labels_input = gr.Textbox(label="Candidate Labels (comma-separated)")

# Output widget: renders the label -> probability mapping returned by the
# callback, configured with num_top_classes=5.
scores_output = gr.Label(num_top_classes=5)

# Wire the widgets to classify_image.
interface = gr.Interface(
    title="Zero-Shot Image Classification with CLIP",
    fn=classify_image,
    inputs=[image_input, labels_input],
    outputs=scores_output,
)

# Start the web app.
interface.launch()