File size: 904 Bytes
c71972d
d55f345
c71972d
 
ce49d3c
75b19fb
 
 
 
ce49d3c
 
 
 
 
fc0afa2
ce49d3c
 
 
 
 
 
 
 
fc0afa2
75b19fb
 
ce49d3c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import logging

import gradio as gr
from PIL import Image
from transformers import pipeline

# Zero-shot image classifier backed by the CLIP ViT-L/14 checkpoint. It is
# built once at module import so every request reuses the same pipeline
# object rather than reconstructing it per call.
checkpoint = "openai/clip-vit-large-patch14"
detector = pipeline(
    task="zero-shot-image-classification",
    model=checkpoint,
)

# Inference function
def classify_image(image, labels=None):
    """Score *image* against candidate text labels with the zero-shot pipeline.

    Args:
        image: PIL image delivered by ``gr.Image(type="pil")``, or ``None``
            when the user submits without uploading anything.
        labels: Candidate class names to score. Defaults to ``["cat", "dog"]``
            when omitted, so the single-input Gradio interface is unaffected.

    Returns:
        dict mapping each label to its score rounded to 3 decimals, the
        mapping shape ``gr.Label`` renders as confidences.

    Raises:
        gr.Error: If no image was supplied or the pipeline call fails; Gradio
            displays the message to the user instead of a malformed label.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if labels is None:
        labels = ["cat", "dog"]
    try:
        # Uploaded images may be RGBA, palette or grayscale; normalize to
        # 3-channel RGB before handing them to the model.
        rgb_image = image.convert("RGB")
        results = detector(rgb_image, candidate_labels=list(labels))
    except Exception as err:
        # UI boundary: keep the full traceback in the server log, then show a
        # readable message. Returning {"error": msg} is not usable here because
        # gr.Label expects numeric confidences, not a string value.
        logging.getLogger(__name__).exception("Zero-shot classification failed")
        raise gr.Error(f"Classification failed: {err}") from err
    return {res["label"]: round(res["score"], 3) for res in results}


# Gradio interface: one PIL image in, label -> score mapping out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module (tests, tooling) does not block on a running server.
if __name__ == "__main__":
    iface.launch()