aupfe08's picture
Update app.py
526351d verified
raw
history blame
567 Bytes
import gradio as gr
from transformers import pipeline
# Model identifier handed to transformers.pipeline (a Hugging Face Hub repo id).
MODEL_ID = "microsoft/resnet-50"

# Build the image-classification pipeline once, at module import, so every
# call to predict() reuses the same loaded model instead of rebuilding it.
clf = pipeline(task="image-classification", model=MODEL_ID)
def predict(img, *, top_k=3, classifier=None):
    """Classify an uploaded image and return its highest-scoring labels.

    Args:
        img: The image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        top_k: Maximum number of labels to return (keyword-only; default 3,
            the number the UI's label output is configured to show).
        classifier: Optional callable with the same call convention as the
            module-level ``clf`` pipeline (``classifier(img, top_k=...)``
            returning a list of ``{"label": ..., "score": ...}`` dicts).
            Defaults to ``clf``; lets callers swap in another model or run
            the function without loading one.

    Returns:
        A dict mapping label -> score (as a Python ``float``), ordered from
        highest to lowest score, with at most ``top_k`` entries. Empty when
        ``img`` is ``None``.
    """
    if img is None:
        # Nothing was uploaded: return an empty result instead of letting
        # the pipeline fail on a missing image.
        return {}
    if classifier is None:
        classifier = clf
    # Request top_k results from the pipeline directly. Its own default
    # (5 labels) would otherwise silently cap larger top_k values.
    results = classifier(img, top_k=top_k)
    # Sort defensively so the output is best-first regardless of backend.
    best = sorted(results, key=lambda r: r["score"], reverse=True)[:top_k]
    # float() guarantees plain Python floats for the gr.Label output.
    return {r["label"]: float(r["score"]) for r in best}
# Web UI: one image upload in, top-scoring labels with their scores out.
# Bound to ``demo`` (Gradio's conventional name, which ``gradio app.py``
# reload mode looks for) rather than being launched inline.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload image"),
    outputs=gr.Label(num_top_classes=3),
    title="Image Classifier (pre-tuned)",
    # Relative paths; presumably these images ship alongside app.py — confirm.
    examples=["banana-1.jpg", "cat1.jpg", "zebra.jpg"],
)

# Start the server only when run as a script, so that importing this module
# (tests, reload mode) does not block inside launch().
if __name__ == "__main__":
    demo.launch()