oscarw-t committed on
Commit
f9eae93
·
1 Parent(s): e3a2969
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import gradio as gr
7
+
8
+ # --- Define the MLP_one CNN architecture ---
9
class MLP_one(nn.Module):
    """Small convolutional classifier producing 10 logits per image.

    Despite the ``MLP`` in its name, the network is convolutional: two
    conv + ReLU + max-pool stages followed by three fully connected layers.
    ``fc1`` takes ``16 * 5 * 5`` features, which is what the conv/pool stack
    yields for a 3x32x32 input (32 -> 28 -> 14 -> 10 -> 5 per spatial side).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        # A single pooling module is reused after both convolutions.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw class scores (no softmax) for a batch of images.

        Args:
            x: float tensor of shape ``[batch, 3, 32, 32]``.

        Returns:
            Tensor of shape ``[batch, 10]`` holding unnormalized logits.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Keep the batch dimension, collapse channels and spatial dims.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
27
+
28
+
29
+ # --- Load trained model weights ---
30
# Build the network and restore its trained parameters onto the CPU.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered "model.pth" cannot execute arbitrary code while being loaded.
model = MLP_one()
model.load_state_dict(
    torch.load("model.pth", map_location="cpu", weights_only=True)
)
model.eval()  # put the module in evaluation mode for inference
33
+
34
+ # --- CIFAR-10 class names ---
35
# Label names; position i names the class scored by model output index i.
classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
39
+
40
+ # --- Transform pipeline ---
41
# Preprocessing applied to every uploaded image before it reaches the model.
# NOTE(review): presumably mirrors the preprocessing used at training time;
# confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # resize any image to 32x32
    transforms.ToTensor(),  # PIL image -> float tensor, channels first
    # (value - 0.5) / 0.5 per channel
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
46
+
47
+ # --- Prediction function ---
48
def predict(image):
    """Classify an uploaded image into one of the CIFAR-10 classes.

    Takes any image (JPG, PNG, etc.), converts it to RGB, resizes it to
    32x32 via ``transform``, runs it through the CNN, and returns class
    probabilities.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        dict mapping each class name to its softmax probability as a float,
        which is the format ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Surface a readable message instead of an AttributeError on None.
        raise gr.Error("Please upload an image before submitting.")

    # Convert to RGB (in case of grayscale or RGBA input)
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # shape: [1, 3, 32, 32]

    with torch.no_grad():
        # Run the forward pass; the original assigned a gr.Label component
        # here, which made the softmax below fail.
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)[0]

    # Convert to dictionary: {class: probability}
    return dict(zip(classes, probs.tolist()))
63
+
64
+ # --- Gradio Interface ---
65
# Web UI: one image upload in, the three highest-probability labels out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload any image"),
    outputs=gr.Label(num_top_classes=3),
    title="CIFAR-10 Image Classifier (MLP_one)",
    description=(
        "Upload any image (JPG, PNG, etc.) and this model will resize it to 32×32 "
        "and predict the closest CIFAR-10 class."
    ),
    # NOTE(review): remote example images; confirm these URLs resolve.
    examples=[
        ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cifar10-dog.png"],
        ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cifar10-truck.png"],
    ],
)

# Start the app when this file is executed.
demo.launch()