Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| # --- Define the MLP_one CNN architecture --- | |
class MLP_one(nn.Module):
    """Small convolutional classifier producing 10 output scores.

    Despite the "MLP" in its name, the network starts with two
    convolution + max-pool stages and only then applies three fully
    connected layers. ``fc1`` expects 16 * 5 * 5 input features, which is
    what the two 5x5 convolutions and 2x2 poolings yield for a 3-channel
    32x32 input (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state-dict keys, so they must stay as-is
        # for saved checkpoints to load.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for a batch ``x``."""
        # Feature extractor: conv -> ReLU -> shared 2x2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        # Keep the batch dimension, collapse everything else.
        x = x.flatten(start_dim=1)

        # Classifier head: two ReLU-activated hidden layers, linear output.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)
# --- Load trained model weights ---
model = MLP_one()
# torch.load unpickles the file. weights_only=True restricts that to tensors
# and plain containers, which is all a state dict needs, so a tampered
# checkpoint cannot run arbitrary code while being loaded.
model.load_state_dict(
    torch.load("model.pth", map_location="cpu", weights_only=True)
)
# Put the module in evaluation mode; this app only runs inference.
model.eval()
# --- CIFAR-10 class names ---
# Order matters: position i is the label reported for model output index i.
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
# --- Transform pipeline ---
# Applied to every uploaded image before it is fed to the model.
transform = transforms.Compose(
    [
        # Force any input resolution down/up to the 32x32 the network expects.
        transforms.Resize(size=(32, 32)),
        # PIL image -> tensor.
        transforms.ToTensor(),
        # Per-channel (value - 0.5) / 0.5.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
| # --- Prediction function --- | |
def predict(image):
    """Classify an uploaded image and return per-class probabilities.

    Args:
        image: The PIL image supplied by the Gradio image input, or ``None``
            when the user submitted without uploading anything.

    Returns:
        A dict mapping each name in ``classes`` to the softmax probability
        (a Python ``float``) the model assigns to it.

    Raises:
        gr.Error: If ``image`` is ``None``, so the UI shows a clear message
            instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Drop alpha / expand grayscale so the model always sees 3 channels, then
    # add a leading batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(batch)  # shape [1, 10]
        probs = torch.softmax(logits, dim=1)[0]

    # Pair labels with probabilities by position rather than a hard-coded
    # count; tolist() yields plain Python floats, as Gradio's Label expects.
    return dict(zip(classes, probs.tolist()))
# --- Gradio Interface ---
# Input: a single image handed to predict() as a PIL object.
image_input = gr.Image(type="pil", label="Upload any image")
# Output: a label widget limited to the three highest-scoring classes.
label_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    title="CIFAR-10 Image Classifier (MLP_one)",
    description="Upload any image (JPG, PNG, etc.) and this model will resize it to 32×32 and predict the closest CIFAR-10 class.",
    fn=predict,
    inputs=image_input,
    outputs=label_output,
)

# Start the web server as soon as the script is executed.
demo.launch()