File size: 1,116 Bytes
0f79521
 
 
 
 
ae47ae6
696435b
024824d
 
0f79521
024824d
0f79521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor


# Hugging Face Hub repository used for both the model weights and the
# image processor below.
_MODEL_ID = "shravvvv/SAG-ViT"

# Both loaders need custom code enabled: trust_remote_code=True makes
# transformers import and execute Python code shipped in the Hub repo.
# NOTE(review): consider pinning ``revision=`` to a known commit so upstream
# changes to that remote code cannot silently change this app's behavior.
_load_kwargs = {"trust_remote_code": True}
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID, **_load_kwargs)
processor = AutoImageProcessor.from_pretrained(_MODEL_ID, **_load_kwargs)

# CIFAR-10 class names, looked up by the argmax index of the model logits.
# NOTE(review): assumes the model's output order matches this list — confirm
# against the checkpoint's config (e.g. ``model.config.id2label``).
class_labels = (
    "airplane automobile bird cat deer "
    "dog frog horse ship truck"
).split()

# Define prediction function
def predict(image):
    """Classify one image with SAG-ViT and return the predicted class name.

    Args:
        image: The uploaded image. The Gradio input is configured with
            ``type="pil"``, so this is normally a ``PIL.Image.Image``; it is
            ``None`` when the user submits without uploading anything.

    Returns:
        The entry of ``class_labels`` at the highest-scoring logit index.

    Raises:
        gr.Error: If no image was provided. Gradio displays this message to
            the user instead of a preprocessing traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")

    # Uploads may be RGBA (PNG with alpha), palette ("P") or grayscale ("L").
    # Normalize PIL images to 3-channel RGB before preprocessing.
    # NOTE(review): assumes the SAG-ViT processor expects RGB input — confirm
    # against its preprocessor config. Non-PIL inputs are passed through.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    # argmax over the last (class) axis; .item() requires exactly one image.
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return class_labels[predicted_class_idx]

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    # gr.inputs.* / gr.outputs.* were deprecated in Gradio 3.x and removed in
    # 4.0; components are now top-level (gr.Image, gr.Label), which also works
    # on Gradio 3.x.
    inputs=gr.Image(type="pil"),  # predict() receives a PIL image (or None)
    outputs=gr.Label(),
    title="SAG-ViT Image Classifier",
    description="Upload an image to classify it using the SAG-ViT model.",
)

def _main():
    """Start the Gradio server for the classifier interface."""
    iface.launch()


if __name__ == "__main__":
    _main()