eaglelandsonce committed on
Commit
75aca80
·
verified ·
1 Parent(s): f839b6f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +327 -0
app.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import threading
5
+
6
+ import numpy as np
7
+ from PIL import Image, ImageOps
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import DataLoader, Subset
13
+ from torchvision import datasets, transforms
14
+
15
+ import gradio as gr
16
+
17
+
18
+ # -----------------------------
19
+ # Custom PyTorch model (nn.Module)
20
+ # -----------------------------
21
class MnistCNN(nn.Module):
    """Small convolutional classifier that maps 1x28x28 images to class logits.

    Two 3x3 convolutions (padding 1, so spatial size is preserved) are followed
    by a single 2x2 max-pool, a hidden linear layer with dropout, and a linear
    output layer.
    """

    def __init__(self, num_classes: int = 10, dropout: float = 0.25):
        """Build the layers.

        Args:
            num_classes: Width of the output (logit) layer.
            dropout: Drop probability applied after the hidden linear layer.
        """
        super().__init__()
        # Both convolutions keep the 28x28 resolution thanks to padding=1.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Halves each spatial dimension: 28x28 -> 14x14.
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(dropout)
        # 64 channels on a 14x14 grid feed the hidden layer.
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return unnormalised class scores (logits) for a batch ``x``."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        pooled = self.pool(features)
        flat = torch.flatten(pooled, 1)
        hidden = F.relu(self.fc1(flat))
        hidden = self.dropout(hidden)
        logits = self.fc2(hidden)
        return logits
37
+
38
+
39
+ # -----------------------------
40
+ # Global state
41
+ # -----------------------------
42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Guards MODEL while it is used for inference or has its weights replaced.
MODEL_LOCK = threading.Lock()

# Use deterministic-ish behavior for demos (not perfect determinism on all systems).
# Seeding happens BEFORE the model is built so that the initial (untrained)
# weights of MODEL are reproducible across restarts as well.
torch.manual_seed(42)
np.random.seed(42)

MODEL = MnistCNN().to(DEVICE)

WEIGHTS_PATH = "mnist_cnn.pth"
CONFIG_PATH = "mnist_config.json"

# Values written to CONFIG_PATH when no config file exists yet.
DEFAULT_CONFIG = {
    "num_classes": 10,
    "dropout": 0.25,
    "normalize_mean": 0.1307,
    "normalize_std": 0.3081,
    "image_size": 28,
}
60
+
61
+
62
def save_config(config=None):
    """Write a configuration mapping to ``CONFIG_PATH`` as indented JSON.

    Args:
        config: Mapping to persist. When omitted (the way existing callers
            invoke this function) ``DEFAULT_CONFIG`` is written.
    """
    payload = DEFAULT_CONFIG if config is None else config
    # Explicit encoding keeps the file identical regardless of the host locale.
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
65
+
66
+
67
def load_config():
    """Return the active configuration as a fresh dict.

    If ``CONFIG_PATH`` exists, its JSON object is read and layered on top of
    ``DEFAULT_CONFIG`` so that keys missing from an older or hand-edited file
    still have a value. If the file does not exist, the defaults are written
    to disk via ``save_config()`` and a copy of them is returned.

    Returns:
        dict: A new dict; mutating it never alters ``DEFAULT_CONFIG``.

    Raises:
        ValueError: If the file holds valid JSON that is not an object.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    try:
        # EAFP: open directly instead of checking existence first.
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        save_config()
        return dict(DEFAULT_CONFIG)

    if not isinstance(loaded, dict):
        raise ValueError(
            f"{CONFIG_PATH} must contain a JSON object, got {type(loaded).__name__}"
        )
    return {**DEFAULT_CONFIG, **loaded}


CFG = load_config()
76
+
77
+
78
+ # -----------------------------
79
+ # Utilities
80
+ # -----------------------------
81
def maybe_load_weights():
    """Load weights from ``WEIGHTS_PATH`` into the global ``MODEL`` if the file exists.

    Returns:
        bool: True when a weights file was found and loaded, False when no
        file exists at ``WEIGHTS_PATH``.
    """
    if not os.path.exists(WEIGHTS_PATH):
        return False

    # NOTE(review): torch.load unpickles the file; only load weight files this
    # app wrote itself (consider weights_only=True where the torch version allows).
    state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE)

    # Swap weights under the lock so inference never sees a half-loaded model.
    with MODEL_LOCK:
        MODEL.load_state_dict(state_dict)
        MODEL.eval()
    return True
90
+
91
+
92
def preprocess_pil(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalised single-image batch on ``DEVICE``.

    The image is converted to grayscale, resized to a square of
    ``CFG["image_size"]`` pixels, scaled to [0, 1], inverted when it looks like
    dark ink on a light background, and finally normalised with the configured
    mean and standard deviation.

    Args:
        img: Input image.

    Returns:
        torch.Tensor: Tensor of shape (1, 1, image_size, image_size).

    Raises:
        ValueError: If ``img`` is None.
    """
    if img is None:
        raise ValueError("No image provided.")

    side = CFG["image_size"]
    gray = img.convert("L").resize((side, side))

    # Pixel intensities as float32 in [0, 1].
    pixels = np.array(gray).astype(np.float32) / 255.0

    # MNIST digits are bright strokes on a dark background; a mostly-bright
    # image is taken to be the opposite convention and is flipped.
    if pixels.mean() > 0.5:
        pixels = 1.0 - pixels

    # Same normalisation that the training transform applies.
    pixels = (pixels - CFG["normalize_mean"]) / CFG["normalize_std"]

    # Add batch and channel dimensions in front of (H, W).
    batch = torch.from_numpy(pixels)[None, None]
    return batch.to(DEVICE)
120
+
121
+
122
def predict_digit(img: Image.Image):
    """Classify a digit image with the global ``MODEL``.

    Args:
        img: Image to classify, or None when the UI component is empty.

    Returns:
        tuple: ``(pred, prob_dict)`` where ``pred`` is the most likely class
        index as an int and ``prob_dict`` maps each class index (as a string)
        to its softmax probability. When ``img`` is None the result is
        ``(None, {})``: the first output is wired to a numeric UI component,
        so a placeholder string there would fail number conversion instead of
        simply clearing the field.
    """
    if img is None:
        return None, {}

    x = preprocess_pil(img)

    # Hold the lock only for the forward pass so weight swaps cannot interleave.
    with MODEL_LOCK:
        MODEL.eval()
        with torch.no_grad():
            logits = MODEL(x)
            probs = torch.softmax(logits, dim=1).cpu().numpy().squeeze(0)

    pred = int(np.argmax(probs))
    # Iterate over the actual output width rather than assuming ten classes.
    prob_dict = {str(i): float(p) for i, p in enumerate(probs)}
    return pred, prob_dict
138
+
139
+
140
+ # -----------------------------
141
+ # Training
142
+ # -----------------------------
143
def get_dataloaders(batch_size: int, max_train_samples: int, max_test_samples: int):
    """Build MNIST train/test loaders, optionally truncated to the first N samples.

    Args:
        batch_size: Batch size for both loaders.
        max_train_samples: If truthy and smaller than the training set, only
            the first ``max_train_samples`` training items are used.
        max_test_samples: Same cap for the test set.

    Returns:
        tuple: ``(train_dl, test_dl)``; only the training loader shuffles.
    """
    normalize = transforms.Normalize((CFG["normalize_mean"],), (CFG["normalize_std"],))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    def _load_split(is_train, limit):
        # Downloads into ./data on first use, then trims to the first `limit` items.
        split = datasets.MNIST(root="data", train=is_train, download=True, transform=pipeline)
        if limit and limit < len(split):
            split = Subset(split, range(limit))
        return split

    train_ds = _load_split(True, max_train_samples)
    test_ds = _load_split(False, max_test_samples)

    # num_workers=0 is safest in Spaces
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_dl, test_dl
162
+
163
+
164
def evaluate(model: nn.Module, test_dl: DataLoader):
    """Compute mean cross-entropy loss and accuracy of ``model`` over ``test_dl``.

    The loss is averaged per sample (not per batch), so a smaller final batch
    does not skew the result. The model is switched to eval mode and is left
    in eval mode on return.

    Args:
        model: Classifier returning logits of shape (batch, classes).
        test_dl: Iterable of ``(inputs, targets)`` batches.

    Returns:
        tuple: ``(avg_loss, acc)`` as Python floats; both are 0.0 when
        ``test_dl`` yields no samples.
    """
    model.eval()

    # Run on whatever device the model lives on; fall back to CPU for a
    # parameter-less module.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Sum (rather than mean) per batch so the final division is per sample.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in test_dl:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss_sum += criterion(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.numel()

    denom = max(1, total)  # avoid division by zero on an empty loader
    return loss_sum / denom, correct / denom
185
+
186
+
187
def train_mnist(epochs: int, lr: float, batch_size: int, max_train_samples: int, max_test_samples: int, progress=gr.Progress()):
    """Train a fresh ``MnistCNN`` on MNIST, save it, and install it as the global model.

    Args:
        epochs: Number of passes over the training subset.
        lr: Adam learning rate.
        batch_size: Batch size for training and evaluation.
        max_train_samples: Cap on training samples (falsy means no cap).
        max_test_samples: Cap on test samples (falsy means no cap).
        progress: Gradio progress tracker used to wrap the batch loop.

    Returns:
        str: A summary header followed by one log line per epoch.

    Raises:
        ValueError: If ``lr`` is None (e.g. the UI field was cleared).
    """
    if lr is None:
        raise ValueError("Learning rate is required.")

    # UI widgets may deliver numeric values as floats (e.g. 3.0); range() and
    # DataLoader require real ints. Falsy caps stay falsy (meaning "no cap").
    epochs = int(epochs)
    batch_size = int(batch_size)
    max_train_samples = int(max_train_samples or 0)
    max_test_samples = int(max_test_samples or 0)

    train_dl, test_dl = get_dataloaders(batch_size, max_train_samples, max_test_samples)

    # Re-init model each time you train (simple + predictable)
    model = MnistCNN(num_classes=CFG["num_classes"], dropout=CFG["dropout"]).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    logs = []
    start = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        loss_total = 0.0
        correct = 0
        seen = 0

        for x, y in progress.tqdm(train_dl, desc=f"Epoch {epoch}/{epochs}"):
            x, y = x.to(DEVICE), y.to(DEVICE)

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so the epoch average is
            # per sample even when the last batch is smaller.
            batch_n = y.numel()
            loss_total += loss.item() * batch_n
            correct += (logits.argmax(dim=1) == y).sum().item()
            seen += batch_n

        train_loss = loss_total / max(1, seen)
        train_acc = correct / max(1, seen)

        test_loss, test_acc = evaluate(model, test_dl)

        logs.append(
            f"Epoch {epoch}/{epochs} | "
            f"train loss {train_loss:.4f} acc {train_acc:.4f} | "
            f"test loss {test_loss:.4f} acc {test_acc:.4f}"
        )

    # Save weights locally
    torch.save(model.state_dict(), WEIGHTS_PATH)
    save_config()

    # Copy the trained weights into the shared model under the lock so that
    # concurrent predictions never observe a partially updated model.
    with MODEL_LOCK:
        MODEL.load_state_dict(model.state_dict())
        MODEL.eval()

    elapsed = time.time() - start
    header = f"Done. Saved weights to `{WEIGHTS_PATH}`. Device: {DEVICE}. Time: {elapsed:.1f}s\n"
    return header + "\n".join(logs)
243
+
244
+
245
def load_saved_weights_ui():
    """Attempt to load saved weights and return a status message for the UI."""
    if maybe_load_weights():
        return f"Loaded saved weights from `{WEIGHTS_PATH}`."
    return f"No saved weights found at `{WEIGHTS_PATH}`. Train first."


# Try to load weights at startup (if present)
_ = maybe_load_weights()
254
+
255
+
256
+ # -----------------------------
257
+ # Gradio UI
258
+ # -----------------------------
259
with gr.Blocks() as demo:
    gr.Markdown("# MNIST (Custom `nn.Module`) — Train + Predict (PyTorch + Gradio)")
    gr.Markdown(
        "Use **Train** to fit a small CNN on MNIST. Then **draw** or **upload** a digit to predict.\n\n"
        f"- Running on: `{DEVICE}`\n"
        f"- Weights file: `{WEIGHTS_PATH}`"
    )

    with gr.Row():
        # Left column: training controls and their text outputs.
        with gr.Column():
            gr.Markdown("## 1) Train (optional)")
            epochs = gr.Slider(1, 5, value=1, step=1, label="Epochs (start with 1)")
            lr = gr.Number(value=1e-3, label="Learning rate", precision=6)
            batch = gr.Slider(32, 256, value=128, step=32, label="Batch size")

            gr.Markdown("### Speed controls (use smaller values for faster training)")
            max_train = gr.Slider(1000, 60000, value=10000, step=1000, label="Max train samples")
            max_test = gr.Slider(500, 10000, value=2000, step=500, label="Max test samples")

            train_btn = gr.Button("Train model")
            load_btn = gr.Button("Load saved weights")

            train_log = gr.Textbox(label="Training log", lines=10)
            status = gr.Textbox(label="Status", lines=2)

        # Right column: two input tabs sharing one pair of prediction outputs.
        with gr.Column():
            gr.Markdown("## 2) Predict")
            with gr.Tab("Draw"):
                draw_img = gr.Image(source="canvas", tool="sketch", type="pil", label="Draw a digit (0-9)")
                draw_btn = gr.Button("Predict from drawing")
            with gr.Tab("Upload"):
                up_img = gr.Image(source="upload", type="pil", label="Upload an image of a digit")
                up_btn = gr.Button("Predict from upload")

            pred_out = gr.Number(label="Prediction")
            prob_out = gr.Label(num_top_classes=3, label="Probabilities (top 3)")

    # Wiring
    # NOTE(review): the chained step posts its message after the training
    # handler returns; confirm whether it should be skipped when training fails.
    train_btn.click(
        fn=train_mnist,
        inputs=[epochs, lr, batch, max_train, max_test],
        outputs=[train_log],
    ).then(
        fn=lambda: "Training complete. You can now predict.",
        inputs=[],
        outputs=[status],
    )

    load_btn.click(
        fn=load_saved_weights_ui,
        inputs=[],
        outputs=[status],
    )

    # Both predict buttons run the same handler on their own image component.
    for button, image in ((draw_btn, draw_img), (up_btn, up_img)):
        button.click(
            fn=predict_digit,
            inputs=[image],
            outputs=[pred_out, prob_out],
        )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)