"""Gradio demo: load a saved ALexNet checkpoint and predict a CIFAR-10 class name for an uploaded image."""

import torch
import gradio as gr
from torchvision import transforms
from PIL import Image

from model import ALexNet  # Make sure this file and class exist

print("App is starting...")

# Build the network and restore its saved weights onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints from a trusted source (recent PyTorch versions offer
# weights_only=True for plain state dicts -- confirm the installed version).
try:
    model = ALexNet(3, 64, 10)
    model.load_state_dict(
        torch.load("Modified_ALexnet_for_CIFAR.pth", map_location=torch.device("cpu"))
    )
except Exception as e:
    # Without a model every prediction would later fail with an unrelated
    # NameError, so report the real cause and stop at startup instead.
    print(f"Failed to load model: {e}")
    raise
else:
    # Switch to inference mode only once the weights are actually in place.
    model.eval()
    print("Model loaded successfully.")

# Preprocessing applied to every uploaded image before inference:
# resize to 32x32, then convert the image to a tensor.
_preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

def predict(img):
    """Classify an uploaded image and return a human-readable label.

    Args:
        img: PIL image supplied by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        ``"Predicted class: <name>"`` for the highest-scoring class, or a
        short prompt asking for an image when ``img`` is ``None``.
    """
    # An empty submission arrives as None; without this guard the transform
    # below raises and the user only sees a generic error.
    if img is None:
        return "Please upload an image."

    class_names = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")

    # Force three channels so grayscale / RGBA / palette images cannot
    # produce a channel mismatch (the model is built as ALexNet(3, ...),
    # presumably 3 input channels -- TODO confirm).
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():  # inference only, no autograd bookkeeping
        outputs = model(batch)
        predicted_class = torch.argmax(outputs, dim=1).item()
    return f"Predicted class: {class_names[predicted_class]}"

# Build the UI at module level: one image input (delivered to predict() as a
# PIL image) and one text output.
demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")

# Start the server only when executed as a script, so importing this module
# does not block on launch().
if __name__ == "__main__":
    demo.launch()