"""Gradio demo that classifies an uploaded image with a CIFAR-10 AlexNet variant."""

import torch
import gradio as gr
from torchvision import transforms
from PIL import Image  # kept from the original; not referenced directly below
from model import ALexNet  # Make sure this file and class exist

# Label names in output-index order: predict() maps argmax index i to CLASS_NAMES[i].
CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

CHECKPOINT_PATH = "Modified_ALexnet_for_CIFAR.pth"

print("App is starting...")

# Load the model once at startup. The network is built in a local variable and
# only published as `model` after the weights load successfully, so a failed
# load never leaves a randomly initialised network serving predictions. On
# failure the app still starts, and predict() reports the stored error.
model = None
model_load_error = None
try:
    # NOTE(review): argument meaning is defined in model.py — presumably
    # (in_channels, width, num_classes); confirm.
    _net = ALexNet(3, 64, 10)
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # (torch >= 1.13) if the checkpoint could ever come from an untrusted source.
    _net.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))
    )
    _net.eval()
except Exception as e:
    model_load_error = e
    print(f"Failed to load model: {e}")
else:
    model = _net
    print("Model loaded successfully.")

# Resize to 32x32 and convert the PIL image to a float tensor (ToTensor scales to [0, 1]).
# NOTE(review): there is no Normalize step — confirm this matches the
# preprocessing used when the checkpoint was trained.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])


def predict(img):
    """Classify a PIL image and return a human-readable label string.

    Args:
        img: PIL image from the Gradio input component, or None when the
            input is empty.

    Returns:
        "Predicted class: <name>" on success, or an explanatory message when
        no image was provided or the model failed to load at startup.
    """
    if img is None:
        return "Please upload an image."
    if model is None:
        return f"Model is not available: {model_load_error}"
    # Uploads may be RGBA/palette (PNG) or grayscale; the model is constructed
    # with ALexNet(3, ...), presumably 3 input channels, so force RGB first.
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        outputs = model(batch)
    predicted_class = torch.argmax(outputs, dim=1).item()
    return f"Predicted class: {CLASS_NAMES[predicted_class]}"


demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")

if __name__ == "__main__":
    demo.launch()