jugalgajjar commited on
Commit
92bcc69
·
verified ·
1 Parent(s): 824f6ea

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from torchvision import transforms
8
+
9
+ class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
10
+
11
+ device = torch.device("cpu")
12
+ inference_transform = T.Compose([
13
+ T.ToTensor(),
14
+ T.Normalize(mean=(0.4914, 0.4822, 0.4465),
15
+ std=(0.2023, 0.1994, 0.2010)),
16
+ ])
17
+
18
+ class SmallCifarCNN(nn.Module):
19
+ def __init__(self, num_classes: int = 10):
20
+ super().__init__()
21
+ self.features = nn.Sequential(
22
+ # Block 1
23
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
24
+ nn.BatchNorm2d(32),
25
+ nn.ReLU(inplace=True),
26
+ nn.MaxPool2d(2), # 32 -> 16
27
+ # Block 2
28
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
29
+ nn.BatchNorm2d(64),
30
+ nn.ReLU(inplace=True),
31
+ nn.MaxPool2d(2), # 16 -> 8
32
+ # Block 3
33
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
34
+ nn.BatchNorm2d(128),
35
+ nn.ReLU(inplace=True),
36
+ nn.MaxPool2d(2), # 8 -> 4
37
+ )
38
+ self.classifier = nn.Sequential(
39
+ nn.Flatten(),
40
+ nn.Linear(128 * 4 * 4, 256),
41
+ nn.ReLU(inplace=True),
42
+ nn.Dropout(p=0.5),
43
+ nn.Linear(256, num_classes),
44
+ )
45
+
46
+ def forward(self, x):
47
+ x = self.features(x)
48
+ x = self.classifier(x)
49
+ return x
50
+
51
+ deployed_model = SmallCifarCNN(num_classes=len(class_names)).to(device)
52
+
53
+ model_path = 'cifar_cnn_best.pt'
54
+ deployed_model.load_state_dict(
55
+ torch.load(model_path, map_location=device)
56
+ )
57
+ deployed_model.to(device)
58
+ deployed_model.eval()
59
+
60
+ def predict_cifar_image(img: Image.Image):
61
+ """
62
+ Gradio callback:
63
+ - Takes a PIL Image
64
+ - Resizes to 32x32 (CIFAR size)
65
+ - Normalizes and runs through the CNN
66
+ - Returns top-3 class probabilities
67
+ """
68
+ img = img.convert("RGB")
69
+ img = img.resize((32, 32), Image.BILINEAR)
70
+
71
+ x = inference_transform(img).unsqueeze(0).to(device)
72
+ with torch.no_grad():
73
+ logits = deployed_model(x)
74
+ probs = F.softmax(logits, dim=1).cpu().numpy().ravel()
75
+
76
+ topk = 3
77
+ idxs = np.argsort(-probs)[:topk]
78
+ return {class_names[i]: float(probs[i]) for i in idxs}
79
+
80
+ demo = gr.Interface(
81
+ fn=predict_cifar_image,
82
+ inputs=gr.Image(type="pil", label="Upload an RGB image (will be resized to 32×32)"),
83
+ outputs=gr.Label(num_top_classes=3, label="Top-3 CIFAR-10 predictions"),
84
+ title="CIFAR-10 CNN Classifier",
85
+ description="Small CNN trained on CIFAR-10. Upload an image and see top-3 class probabilities.",
86
+ )
87
+
88
+
89
+ if __name__ == "__main__":
90
+ demo.launch()