Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from torchvision import transforms | |
| from PIL import Image | |
| from model import ALexNet # Make sure this file and class exist | |
print("App is starting...")

# Bind the name up front: if anything below fails, `model` stays None instead
# of being undefined (NameError on first request) or, worse, pointing at a
# network whose weights were never loaded.
model = None
try:
    # Constructor arguments are passed through unchanged; presumably they are
    # (input channels, base width, number of classes) -- confirm in model.py.
    candidate = ALexNet(3, 64, 10)
    # NOTE(review): torch.load unpickles the checkpoint, so the .pth file must
    # come from a trusted source.
    state_dict = torch.load(
        "Modified_ALexnet_for_CIFAR.pth",
        map_location=torch.device("cpu"),
    )
    candidate.load_state_dict(state_dict)
    candidate.eval()
    # Publish the network only once it is fully loaded and in eval mode.
    model = candidate
    print("Model loaded successfully.")
except Exception as e:
    # Best-effort startup: report the problem and let the UI come up anyway.
    print(f"Failed to load model: {e}")

# Preprocessing applied to every uploaded image: resize to 32x32, then convert
# to a tensor. No normalization is applied here -- confirm this matches how the
# checkpoint was trained.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
# Label for each output index of the classifier, in index order.
CLASS_NAMES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def predict(img):
    """Classify an uploaded image and return a human-readable label.

    Args:
        img: PIL image supplied by the Gradio image input, or None when the
            user submits the form without providing an image.

    Returns:
        A string of the form "Predicted class: <name>", or a short
        explanatory message when there is no image or no usable model.
    """
    if img is None:
        return "No image provided."

    # The module-level `model` may be unbound (or None) if loading failed at
    # startup; report that clearly instead of failing with an opaque error.
    try:
        net = model
    except NameError:
        net = None
    if net is None:
        return "Model is not available; check the startup logs."

    # Grayscale or RGBA uploads would otherwise produce 1- or 4-channel
    # tensors; presumably the network expects 3 channels -- confirm in model.py.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        outputs = net(batch)
    predicted_class = torch.argmax(outputs, dim=1).item()
    return f"Predicted class: {CLASS_NAMES[predicted_class]}"
# Build the web UI: one PIL image input wired to `predict`, plain-text output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
)
# Start serving the interface.
demo.launch()