File size: 827 Bytes
e7f1ba9
c2570eb
e7f1ba9
 
 
 
c2570eb
e7f1ba9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37a50a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import gradio as gr
from transformers import ViTForImageClassification, AutoImageProcessor
from PIL import Image
import torch

# Hugging Face Hub repository that supplies both the classification model and
# its image processor. Defined once so the two loads below cannot drift apart.
MODEL_ID = "rakib730/vit-base-oxford-iiit-pets"

# Loaded once when the module runs; classify_image reads these module globals
# on every call rather than reloading them per request.
model = ViTForImageClassification.from_pretrained(MODEL_ID)
processor = AutoImageProcessor.from_pretrained(MODEL_ID)

def classify_image(image):
    """Predict the single most likely class label for an uploaded image.

    Args:
        image: The image delivered by the ``gr.Image(type="pil")`` input
            (a PIL image), or ``None`` when the user submits without
            uploading anything.

    Returns:
        The label string that ``model.config.id2label`` maps the top-scoring
        logit index to, or ``None`` when no image was provided, so the UI
        shows no prediction instead of an error.
    """
    if image is None:
        return None

    # Uploads can be RGBA/palette PNGs or grayscale photos. The ViT processor
    # presumably expects 3-channel input (per-channel normalization stats;
    # confirm against the processor config), so force RGB first.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]

# Web UI: one uploaded image in, one predicted label out.
iface = gr.Interface(
    classify_image,         # fn: called with the uploaded image
    gr.Image(type="pil"),   # inputs: hands classify_image a PIL image
    "label",                # outputs: rendered as a Gradio label
    description="Upload a pet image to classify its breed using ViT.",
    title="Oxford-IIIT Pets Classifier",
)

# Start serving immediately whenever this module is executed.
iface.launch()