Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import CLIPProcessor, CLIPModel | |
| from PIL import Image | |
| import torch | |
# Load the pretrained CLIP weights and the matching pre-processor once at
# import time so every request reuses the same objects.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
| # Define a function to perform zero-shot classification | |
def classify_image(image, candidate_labels):
    """Score an image against free-text labels with CLIP (zero-shot).

    Args:
        image: The image to classify; it is handed to the CLIP processor
            as-is (the Gradio input below is configured with type="pil").
        candidate_labels: Either a comma-separated string of labels or an
            already-split sequence of label strings.

    Returns:
        A dict mapping each distinct, non-empty label to its softmax
        probability for the image.

    Raises:
        ValueError: If no image is given, if no usable label remains after
            cleaning, or if the number of scores returned by the model does
            not match the number of labels.
    """
    # Fail early with a clear message instead of an opaque downstream error.
    if image is None:
        raise ValueError("No image provided.")

    if candidate_labels is None:
        candidate_labels = []
    elif isinstance(candidate_labels, str):
        candidate_labels = candidate_labels.split(",")

    # Strip whitespace, drop blanks (e.g. from a trailing comma) and remove
    # duplicates while keeping order: a repeated label would be scored twice
    # by the softmax but collapse into a single key in the returned dict.
    stripped = (label.strip() for label in candidate_labels)
    candidate_labels = list(dict.fromkeys(label for label in stripped if label))
    if not candidate_labels:
        raise ValueError("Provide at least one candidate label.")

    # Tokenize the labels and preprocess the image into model-ready tensors.
    inputs = processor(
        text=candidate_labels,
        images=image,
        return_tensors="pt",
        padding=True,
    )

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # One row per image, one column per candidate label.
    logits_per_image = outputs.logits_per_image
    if logits_per_image.size(1) != len(candidate_labels):
        raise ValueError("Mismatch between logits and candidate labels.")

    # Normalize across labels and convert the single row to a plain list.
    probs = logits_per_image.softmax(dim=1).squeeze(0).tolist()
    return dict(zip(candidate_labels, probs))
# Build the Gradio UI: an image upload plus a free-text label box feed
# classify_image, and the returned label -> probability dict is rendered
# by a Label component limited to the five top classes.
image_input = gr.Image(type="pil")
labels_input = gr.Textbox(label="Candidate Labels (comma-separated)")
top_predictions = gr.Label(num_top_classes=5)

interface = gr.Interface(
    fn=classify_image,
    inputs=[image_input, labels_input],
    outputs=top_predictions,
    title="Zero-Shot Image Classification with CLIP",
)

# Start serving the app.
interface.launch()