File size: 556 Bytes
42db648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2d52bd
f1466dc
42db648
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import gradio as gr
from transformers import pipeline

# Model identifier handed to the transformers pipeline below.
MODEL_ID = "microsoft/resnet-50"

# Created once at module level so every call to the classifier reuses the
# same loaded pipeline instead of rebuilding it per request.
clf = pipeline(task="image-classification", model=MODEL_ID)

def predict(img, top_k: int = 3):
    """Classify an uploaded image and return its highest-scoring labels.

    Args:
        img: The image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.
        top_k: Maximum number of labels to return (default 3).

    Returns:
        A ``{label: score}`` dict with at most ``top_k`` entries, ordered from
        highest to lowest score. This is the mapping shape ``gr.Label``
        renders. The dict is empty when no image was provided or ``top_k`` is
        not positive.
    """
    # Gradio hands us None for an empty image component. Passing that to the
    # pipeline would raise and surface as an error in the UI.
    if img is None or top_k <= 0:
        return {}
    # Request exactly top_k results. Relying on the pipeline's own default
    # would silently cap the output when a larger top_k is asked for.
    out = clf(img, top_k=top_k)
    # Sort defensively so the slice always keeps the best scores.
    best = sorted(out, key=lambda r: r["score"], reverse=True)[:top_k]
    # float() converts numpy/torch scalar scores into plain JSON-friendly floats.
    return {r["label"]: float(r["score"]) for r in best}

# Bind the app to a module-level `demo`: Gradio's reload mode
# (`gradio <this_file>.py`) looks up the Interface/Blocks by that name, and
# an anonymous Interface cannot be found by it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload image"),
    # Show three classes, matching the three labels predict returns.
    outputs=gr.Label(num_top_classes=3),
    title="Image Classifier",
    # NOTE(review): these example images are expected to be deployed
    # alongside the app — confirm they exist where Gradio resolves them.
    examples=["banana-1.jpg", "cat1.png", "zebra.jpg"],
)
demo.launch()