"""Gradio demo: zero-shot image classification with OpenAI CLIP.

The user uploads an image and types a comma-separated list of candidate
labels; CLIP scores the image against every label and the app shows the
softmax probabilities of the best matches.
"""

import logging

import gradio as gr
import torch
from PIL import Image  # noqa: F401  (kept from the original script)
from transformers import CLIPModel, CLIPProcessor

logger = logging.getLogger(__name__)

MODEL_NAME = "openai/clip-vit-base-patch32"

# Load the CLIP model and processor once at start-up; every request reuses them.
model = CLIPModel.from_pretrained(MODEL_NAME)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)


def _parse_labels(candidate_labels):
    """Return a cleaned, order-preserving, de-duplicated list of labels.

    Accepts a comma-separated string (what the Textbox delivers) or a list of
    strings. Blank entries (from "cat,,dog", a trailing comma, or an empty
    box) are dropped. Duplicates are removed because the result is returned
    as a dict keyed by label, where duplicates would silently overwrite each
    other after having split the softmax mass between them.
    """
    if candidate_labels is None:
        return []
    if isinstance(candidate_labels, str):
        candidate_labels = candidate_labels.split(",")
    stripped = (label.strip() for label in candidate_labels)
    return list(dict.fromkeys(label for label in stripped if label))


def classify_image(image, candidate_labels):
    """Score *image* against each candidate label with CLIP (zero-shot).

    Args:
        image: PIL image (the Gradio ``Image`` input uses ``type="pil"``);
            ``None`` when the user has not uploaded anything.
        candidate_labels: comma-separated string of labels, or a list of
            label strings.

    Returns:
        dict mapping each distinct, non-blank label to its softmax
        probability; the values sum to 1 across the returned labels.

    Raises:
        gr.Error: if no image was supplied or no non-blank label was given;
            Gradio displays the message to the user.
        ValueError: if the model returns a different number of scores than
            labels (sanity check kept from the original).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    labels = _parse_labels(candidate_labels)
    if not labels:
        raise gr.Error("Please enter at least one candidate label.")
    logger.debug("Candidate labels: %s", labels)

    # Tokenize the labels and preprocess the image in a single call.
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        outputs = model(**inputs)

    # A single image is passed, so logits_per_image is [1, len(labels)].
    logits_per_image = outputs.logits_per_image
    logger.debug("Logits shape: %s", tuple(logits_per_image.shape))
    if logits_per_image.size(1) != len(labels):
        raise ValueError("Mismatch between logits and candidate labels.")

    # Softmax across labels, then drop the batch dimension.
    probs = logits_per_image.softmax(dim=1).squeeze(0).tolist()
    return dict(zip(labels, probs))


# Define the Gradio interface.
interface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),  # Accept an image
        gr.Textbox(label="Candidate Labels (comma-separated)"),  # Accept text input
    ],
    outputs=gr.Label(num_top_classes=5),  # Show the top-5 label probabilities
    title="Zero-Shot Image Classification with CLIP",
)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    interface.launch()