File size: 1,929 Bytes
7dd345f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio as gr
import torch
from transformers import pipeline

# 1. Initialize the zero-shot image classification pipeline using CLIP.
# The device is chosen explicitly instead of relying on the pipeline default:
# the first CUDA GPU (ordinal 0) when torch reports one is available,
# otherwise -1, which selects the CPU.
print("Loading OpenAI CLIP model...")
classifier = pipeline(
    "zero-shot-image-classification",
    model="openai/clip-vit-base-patch32",
    device=0 if torch.cuda.is_available() else -1,
)

def classify_image(image, labels_text):
    """Score an image against comma-separated text labels using CLIP.

    Args:
        image: The uploaded image as delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when nothing has been uploaded.
        labels_text: Comma-separated candidate labels from the textbox. May be
            empty, whitespace-only, or ``None``.

    Returns:
        A ``{label: score}`` dict in the shape ``gr.Label`` displays. When the
        image or labels are missing, or no non-blank label remains after
        parsing, a single-entry dict holding a user-facing message with score
        1.0 is returned instead.
    """
    # Guard against missing inputs. `not labels_text` also covers None, which
    # would otherwise raise AttributeError on `.strip()`.
    if image is None or not labels_text or not labels_text.strip():
        return {"Please upload an image and provide labels.": 1.0}

    # Split on commas, trim whitespace, drop empty entries, and de-duplicate
    # while keeping first-seen order. The pipeline normalizes scores across
    # the candidate list, so a repeated label would split its probability
    # mass between copies. The copies would then collapse into a single key
    # in the result dict below, under-reporting that label's score.
    candidate_labels = list(
        dict.fromkeys(
            label.strip() for label in labels_text.split(",") if label.strip()
        )
    )

    if not candidate_labels:
        return {"Please enter at least one valid label.": 1.0}

    # 2. Run inference through CLIP; the pipeline handles tokenizing the text
    # labels and preprocessing the image.
    predictions = classifier(image, candidate_labels=candidate_labels)

    # 3. Convert to plain floats so the dict is JSON-serializable for gr.Label.
    return {pred["label"]: float(pred["score"]) for pred in predictions}

# 4. Build the user interface: an image upload and a free-text box of
# comma-separated labels feed classify_image, and the returned scores are
# rendered as a ranked label list.
demo = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(
            type="pil",
            label="1. Upload your Image",
        ),
        gr.Textbox(
            label="2. Candidate Labels (Separate with commas)",
            placeholder=(
                "e.g., a sunny beach, a cozy rainy day, "
                "a cute animal, corporate office"
            ),
            value=(
                "a playful dog, a quiet cat, "
                "an outdoor landscape, indoor architecture"
            ),
        ),
    ],
    # Display at most the five highest-scoring labels.
    outputs=gr.Label(
        num_top_classes=5,
        label="Matching Confidence",
    ),
    title="CLIP Zero-Shot Image Matcher",
    description=(
        "Type any descriptive phrases or labels you can think of, "
        "separate them with commas, and see how well OpenAI's CLIP "
        "aligns them to your uploaded photo."
    ),
    flagging_mode="never",
)

# Start the web server only when this file is executed as a script, so that
# importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()