File size: 535 Bytes
c00b118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35aee79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

import gradio as gr
from transformers import pipeline

# Module-level classifier, built once when the script is loaded so every
# prediction request reuses the same loaded model instead of reloading it.
# device_map="auto" leaves device placement to transformers/accelerate
# (presumably a GPU when one is visible, else CPU -- confirm for your setup).
clf = pipeline(
    task="image-classification",
    model="google/vit-base-patch16-224",
    device_map="auto",
)

def predict(img):
    """Classify an image and return its top-5 labels with confidence scores.

    Args:
        img: The image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A ``{label: score}`` dict for the (at most) five highest-scoring
        classes, the shape ``gr.Label`` consumes. Empty when no image was
        supplied.
    """
    # Gradio hands us None for an empty image input; passing that to the
    # pipeline raises, so return an empty result instead of a traceback.
    if img is None:
        return {}
    # Request exactly the five labels the UI displays rather than relying on
    # the pipeline's default top_k and slicing afterwards.
    out = clf(img, top_k=5)
    return {o["label"]: float(o["score"]) for o in out}

# Web UI: one PIL image in, a ranked label/score panel (top 5) out.
# Components are created in the same order as before: image input, then label.
demo = gr.Interface(
    predict,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=5),
    title="Image Classification (ViT)",
)

if __name__ == "__main__":
    # share=True asks Gradio for a temporary public share link in addition to
    # the local URL (handy in Colab/notebooks). Anyone holding that link can
    # reach the app, so drop it for private use.
    demo.launch(share=True)