Jagjeet2003's picture
Update app.py
1cab240 verified
raw
history blame contribute delete
995 Bytes
import torch
import gradio as gr
from torchvision import transforms
from PIL import Image
from model import ALexNet # Make sure this file and class exist
print("App is starting...")

# Trained weights shipped alongside this script; mapped onto the CPU so the app
# can run on hardware without a GPU.
# NOTE(review): torch.load unpickles the file — fine for a bundled checkpoint,
# but never point MODEL_PATH at user-supplied data.
MODEL_PATH = "Modified_ALexnet_for_CIFAR.pth"

try:
    # NOTE(review): the positional args (3, 64, 10) presumably mean input
    # channels, base channel width and number of classes — confirm against
    # model.ALexNet.
    model = ALexNet(3, 64, 10)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
    # Switch to inference behaviour (affects dropout / batch-norm layers, if any).
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    # Keep the app process alive so the error is visible in the logs, but never
    # serve a half-initialised network: if construction succeeded and weight
    # loading failed, `model` would otherwise be a randomly initialised net
    # still in training mode and would return meaningless predictions.
    model = None
    print(f"Failed to load model: {e}")
# Preprocessing applied to every uploaded image before inference:
# resize to 32x32 pixels, then convert the PIL image to a float tensor
# (ToTensor yields channels-first values scaled to [0, 1]).
# NOTE(review): there is no Normalize step here — this must match the
# preprocessing used when the weights were trained; confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
    ]
)
# Human-readable label for each output index of the model.
# NOTE(review): this is the standard torchvision CIFAR-10 ordering; it must
# match the label encoding used in training — confirm.
CLASS_NAMES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def predict(img):
    """Classify one image and return its predicted label as text.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or None
            when the user submits without providing an image.

    Returns:
        ``"Predicted class: <name>"`` for the highest-scoring output, or a
        prompt asking for an image when ``img`` is None.
    """
    if img is None:
        # An empty image input presumably arrives as None; transform(None)
        # would raise, so answer with a message instead.
        return "Please upload an image."

    # The network is constructed with 3 as its first argument (presumably the
    # input channel count), so force RGB: RGBA, grayscale or palette uploads
    # would otherwise produce a tensor with the wrong number of channels.
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = model(batch)
    predicted_class = torch.argmax(logits, dim=1).item()
    return f"Predicted class: {CLASS_NAMES[predicted_class]}"
# Web UI: a single image upload (handed to `predict` as a PIL image) and a
# text output showing the predicted label. Bound to the name `demo`, which is
# the name Gradio's reload mode looks for by default.
demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests or tooling) does not launch a web server.
# NOTE(review): assumes the host runs this file directly (`python app.py`).
if __name__ == "__main__":
    demo.launch()