eaglelandsonce committed on
Commit
fc26943
·
verified ·
1 Parent(s): b80739d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -91
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import os
2
- import time
3
  import json
 
4
  import threading
5
 
6
  import numpy as np
7
- from PIL import Image, ImageOps
8
 
9
  import torch
10
  import torch.nn as nn
@@ -21,9 +21,9 @@ import gradio as gr
21
  class MnistCNN(nn.Module):
22
  def __init__(self, num_classes: int = 10, dropout: float = 0.25):
23
  super().__init__()
24
- self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) # 28x28 -> 28x28
25
- self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 28x28 -> 28x28
26
- self.pool = nn.MaxPool2d(2, 2) # 28x28 -> 14x14
27
  self.dropout = nn.Dropout(dropout)
28
  self.fc1 = nn.Linear(64 * 14 * 14, 128)
29
  self.fc2 = nn.Linear(128, num_classes)
@@ -46,7 +46,7 @@ MODEL = MnistCNN().to(DEVICE)
46
  WEIGHTS_PATH = "mnist_cnn.pth"
47
  CONFIG_PATH = "mnist_config.json"
48
 
49
- DEFAULT_CONFIG = {
50
  "num_classes": 10,
51
  "dropout": 0.25,
52
  "normalize_mean": 0.1307,
@@ -54,30 +54,28 @@ DEFAULT_CONFIG = {
54
  "image_size": 28
55
  }
56
 
57
- # Use deterministic-ish behavior for demos (not perfect determinism on all systems)
58
  torch.manual_seed(42)
59
  np.random.seed(42)
60
 
61
 
62
- def save_config():
63
- with open(CONFIG_PATH, "w") as f:
64
- json.dump(DEFAULT_CONFIG, f, indent=2)
65
-
66
-
67
- def load_config():
68
  if os.path.exists(CONFIG_PATH):
69
  with open(CONFIG_PATH, "r") as f:
70
  return json.load(f)
71
- save_config()
72
- return DEFAULT_CONFIG
 
73
 
74
 
75
- CFG = load_config()
 
 
 
 
 
 
76
 
77
 
78
- # -----------------------------
79
- # Utilities
80
- # -----------------------------
81
  def maybe_load_weights():
82
  global MODEL
83
  if os.path.exists(WEIGHTS_PATH):
@@ -91,36 +89,25 @@ def maybe_load_weights():
91
 
92
  def preprocess_pil(img: Image.Image) -> torch.Tensor:
93
  """
94
- Converts a PIL image to MNIST-like tensor: (1,1,28,28), normalized.
95
- Also attempts to handle "black ink on white background" by auto-inverting.
96
  """
97
  if img is None:
98
  raise ValueError("No image provided.")
99
 
100
- # Convert to grayscale
101
- img = img.convert("L")
102
-
103
- # Resize to 28x28
104
- img = img.resize((CFG["image_size"], CFG["image_size"]))
105
-
106
- # Convert to numpy [0..1]
107
  arr = np.array(img).astype(np.float32) / 255.0
108
 
109
- # Auto-invert if background looks white-ish (common with sketch tools)
110
- # MNIST digits are typically bright strokes on darker background.
111
  if arr.mean() > 0.5:
112
  arr = 1.0 - arr
113
 
114
- # Normalize like training
115
  arr = (arr - CFG["normalize_mean"]) / CFG["normalize_std"]
116
-
117
- # Shape to (1,1,28,28)
118
- x = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)
119
  return x.to(DEVICE)
120
 
121
 
122
- def predict_digit(img: Image.Image):
123
- global MODEL
124
  if img is None:
125
  return "No image", {}
126
 
@@ -137,6 +124,13 @@ def predict_digit(img: Image.Image):
137
  return pred, prob_dict
138
 
139
 
 
 
 
 
 
 
 
140
  # -----------------------------
141
  # Training
142
  # -----------------------------
@@ -149,13 +143,11 @@ def get_dataloaders(batch_size: int, max_train_samples: int, max_test_samples: i
149
  train_ds = datasets.MNIST(root="data", train=True, download=True, transform=transform)
150
  test_ds = datasets.MNIST(root="data", train=False, download=True, transform=transform)
151
 
152
- # Subset for faster training on Spaces (optional)
153
  if max_train_samples and max_train_samples < len(train_ds):
154
  train_ds = Subset(train_ds, range(max_train_samples))
155
  if max_test_samples and max_test_samples < len(test_ds):
156
  test_ds = Subset(test_ds, range(max_test_samples))
157
 
158
- # num_workers=0 is safest in Spaces
159
  train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
160
  test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
161
  return train_dl, test_dl
@@ -163,49 +155,38 @@ def get_dataloaders(batch_size: int, max_train_samples: int, max_test_samples: i
163
 
164
  def evaluate(model: nn.Module, test_dl: DataLoader):
165
  model.eval()
166
- correct = 0
167
- total = 0
168
- loss_sum = 0.0
169
  criterion = nn.CrossEntropyLoss()
 
170
 
171
  with torch.no_grad():
172
  for x, y in test_dl:
173
  x, y = x.to(DEVICE), y.to(DEVICE)
174
  logits = model(x)
175
- loss = criterion(logits, y)
176
- loss_sum += loss.item()
177
-
178
  preds = logits.argmax(dim=1)
179
  correct += (preds == y).sum().item()
180
  total += y.numel()
181
 
182
- avg_loss = loss_sum / max(1, len(test_dl))
183
- acc = correct / max(1, total)
184
- return avg_loss, acc
185
 
186
 
187
- def train_mnist(epochs: int, lr: float, batch_size: int, max_train_samples: int, max_test_samples: int, progress=gr.Progress()):
188
  global MODEL
189
 
 
190
  train_dl, test_dl = get_dataloaders(batch_size, max_train_samples, max_test_samples)
191
 
192
- # Re-init model each time you train (simple + predictable)
193
  model = MnistCNN(num_classes=CFG["num_classes"], dropout=CFG["dropout"]).to(DEVICE)
194
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
195
  criterion = nn.CrossEntropyLoss()
196
 
197
  logs = []
198
- start = time.time()
199
-
200
  for epoch in range(1, epochs + 1):
201
  model.train()
202
- running_loss = 0.0
203
- correct = 0
204
- total = 0
205
 
206
- for step, (x, y) in enumerate(progress.tqdm(train_dl, desc=f"Epoch {epoch}/{epochs}")):
207
  x, y = x.to(DEVICE), y.to(DEVICE)
208
-
209
  optimizer.zero_grad()
210
  logits = model(x)
211
  loss = criterion(logits, y)
@@ -219,7 +200,6 @@ def train_mnist(epochs: int, lr: float, batch_size: int, max_train_samples: int,
219
 
220
  train_loss = running_loss / max(1, len(train_dl))
221
  train_acc = correct / max(1, total)
222
-
223
  test_loss, test_acc = evaluate(model, test_dl)
224
 
225
  logs.append(
@@ -228,100 +208,95 @@ def train_mnist(epochs: int, lr: float, batch_size: int, max_train_samples: int,
228
  f"test loss {test_loss:.4f} acc {test_acc:.4f}"
229
  )
230
 
231
- # Save weights locally
232
  torch.save(model.state_dict(), WEIGHTS_PATH)
233
- save_config()
234
 
235
- # Swap global model
236
  with MODEL_LOCK:
237
  MODEL.load_state_dict(model.state_dict())
238
  MODEL.eval()
239
 
240
  elapsed = time.time() - start
241
- header = f"Done. Saved weights to `{WEIGHTS_PATH}`. Device: {DEVICE}. Time: {elapsed:.1f}s\n"
242
- return header + "\n".join(logs)
243
 
244
 
245
- def load_saved_weights_ui():
246
  ok = maybe_load_weights()
247
- if ok:
248
- return f"Loaded saved weights from `{WEIGHTS_PATH}`."
249
- return f"No saved weights found at `{WEIGHTS_PATH}`. Train first."
250
 
251
 
252
- # Try to load weights at startup (if present)
253
- _ = maybe_load_weights()
254
 
255
 
256
  # -----------------------------
257
- # Gradio UI
258
  # -----------------------------
259
  with gr.Blocks() as demo:
260
- gr.Markdown("# MNIST (Custom `nn.Module`) — Train + Predict (PyTorch + Gradio)")
261
- gr.Markdown(
262
- "Use **Train** to fit a small CNN on MNIST. Then **draw** or **upload** a digit to predict.\n\n"
263
- f"- Running on: `{DEVICE}`\n"
264
- f"- Weights file: `{WEIGHTS_PATH}`"
265
- )
266
 
267
  with gr.Row():
268
  with gr.Column():
269
  gr.Markdown("## 1) Train (optional)")
270
- epochs = gr.Slider(1, 5, value=1, step=1, label="Epochs (start with 1)")
271
  lr = gr.Number(value=1e-3, label="Learning rate", precision=6)
272
  batch = gr.Slider(32, 256, value=128, step=32, label="Batch size")
273
 
274
- gr.Markdown("### Speed controls (use smaller values for faster training)")
275
  max_train = gr.Slider(1000, 60000, value=10000, step=1000, label="Max train samples")
276
  max_test = gr.Slider(500, 10000, value=2000, step=500, label="Max test samples")
277
 
278
  train_btn = gr.Button("Train model")
279
  load_btn = gr.Button("Load saved weights")
280
 
281
- train_log = gr.Textbox(label="Training log", lines=10)
282
  status = gr.Textbox(label="Status", lines=2)
 
283
 
284
  with gr.Column():
285
  gr.Markdown("## 2) Predict")
 
286
  with gr.Tab("Draw"):
287
- draw_img = gr.Image(source="canvas", tool="sketch", type="pil", label="Draw a digit (0-9)")
 
 
 
 
 
 
 
288
  draw_btn = gr.Button("Predict from drawing")
 
289
  with gr.Tab("Upload"):
290
- up_img = gr.Image(source="upload", type="pil", label="Upload an image of a digit")
291
  up_btn = gr.Button("Predict from upload")
292
 
293
  pred_out = gr.Number(label="Prediction")
294
  prob_out = gr.Label(num_top_classes=3, label="Probabilities (top 3)")
295
 
296
- # Wiring
297
  train_btn.click(
298
  fn=train_mnist,
299
  inputs=[epochs, lr, batch, max_train, max_test],
300
- outputs=[train_log],
301
- ).then(
302
- fn=lambda: "Training complete. You can now predict.",
303
- inputs=[],
304
- outputs=[status],
305
  )
306
 
307
  load_btn.click(
308
- fn=load_saved_weights_ui,
309
  inputs=[],
310
  outputs=[status],
311
  )
312
 
313
  draw_btn.click(
314
- fn=predict_digit,
315
- inputs=[draw_img],
316
  outputs=[pred_out, prob_out],
317
  )
318
 
319
  up_btn.click(
320
- fn=predict_digit,
321
  inputs=[up_img],
322
  outputs=[pred_out, prob_out],
323
  )
324
 
325
 
326
  if __name__ == "__main__":
327
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
 
2
  import json
3
+ import time
4
  import threading
5
 
6
  import numpy as np
7
+ from PIL import Image
8
 
9
  import torch
10
  import torch.nn as nn
 
21
  class MnistCNN(nn.Module):
22
  def __init__(self, num_classes: int = 10, dropout: float = 0.25):
23
  super().__init__()
24
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
25
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
26
+ self.pool = nn.MaxPool2d(2, 2) # 28x28 -> 14x14
27
  self.dropout = nn.Dropout(dropout)
28
  self.fc1 = nn.Linear(64 * 14 * 14, 128)
29
  self.fc2 = nn.Linear(128, num_classes)
 
46
  WEIGHTS_PATH = "mnist_cnn.pth"
47
  CONFIG_PATH = "mnist_config.json"
48
 
49
+ CFG_DEFAULT = {
50
  "num_classes": 10,
51
  "dropout": 0.25,
52
  "normalize_mean": 0.1307,
 
54
  "image_size": 28
55
  }
56
 
 
57
  torch.manual_seed(42)
58
  np.random.seed(42)
59
 
60
 
61
+ def load_or_init_config():
 
 
 
 
 
62
  if os.path.exists(CONFIG_PATH):
63
  with open(CONFIG_PATH, "r") as f:
64
  return json.load(f)
65
+ with open(CONFIG_PATH, "w") as f:
66
+ json.dump(CFG_DEFAULT, f, indent=2)
67
+ return CFG_DEFAULT
68
 
69
 
70
+ CFG = load_or_init_config()
71
+
72
+
73
+ def blank_editor_value(size=280):
74
+ """Initial blank canvas for ImageEditor."""
75
+ img = Image.new("RGBA", (size, size), (255, 255, 255, 255))
76
+ return {"background": img, "layers": [], "composite": img}
77
 
78
 
 
 
 
79
  def maybe_load_weights():
80
  global MODEL
81
  if os.path.exists(WEIGHTS_PATH):
 
89
 
90
  def preprocess_pil(img: Image.Image) -> torch.Tensor:
91
  """
92
+ Convert PIL image to MNIST tensor (1,1,28,28), normalized like training.
93
+ Auto-invert if the background is bright.
94
  """
95
  if img is None:
96
  raise ValueError("No image provided.")
97
 
98
+ img = img.convert("L").resize((CFG["image_size"], CFG["image_size"]))
 
 
 
 
 
 
99
  arr = np.array(img).astype(np.float32) / 255.0
100
 
101
+ # If background is mostly white, invert so digit becomes bright on dark
 
102
  if arr.mean() > 0.5:
103
  arr = 1.0 - arr
104
 
 
105
  arr = (arr - CFG["normalize_mean"]) / CFG["normalize_std"]
106
+ x = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) # (1,1,28,28)
 
 
107
  return x.to(DEVICE)
108
 
109
 
110
+ def predict_from_pil(img: Image.Image):
 
111
  if img is None:
112
  return "No image", {}
113
 
 
124
  return pred, prob_dict
125
 
126
 
127
+ def predict_from_editor(editor_value):
128
+ # ImageEditor returns a dict with keys: background, layers, composite
129
+ if editor_value is None or "composite" not in editor_value:
130
+ return "No drawing", {}
131
+ return predict_from_pil(editor_value["composite"])
132
+
133
+
134
  # -----------------------------
135
  # Training
136
  # -----------------------------
 
143
  train_ds = datasets.MNIST(root="data", train=True, download=True, transform=transform)
144
  test_ds = datasets.MNIST(root="data", train=False, download=True, transform=transform)
145
 
 
146
  if max_train_samples and max_train_samples < len(train_ds):
147
  train_ds = Subset(train_ds, range(max_train_samples))
148
  if max_test_samples and max_test_samples < len(test_ds):
149
  test_ds = Subset(test_ds, range(max_test_samples))
150
 
 
151
  train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
152
  test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
153
  return train_dl, test_dl
 
155
 
156
  def evaluate(model: nn.Module, test_dl: DataLoader):
157
  model.eval()
 
 
 
158
  criterion = nn.CrossEntropyLoss()
159
+ loss_sum, correct, total = 0.0, 0, 0
160
 
161
  with torch.no_grad():
162
  for x, y in test_dl:
163
  x, y = x.to(DEVICE), y.to(DEVICE)
164
  logits = model(x)
165
+ loss_sum += criterion(logits, y).item()
 
 
166
  preds = logits.argmax(dim=1)
167
  correct += (preds == y).sum().item()
168
  total += y.numel()
169
 
170
+ return loss_sum / max(1, len(test_dl)), correct / max(1, total)
 
 
171
 
172
 
173
+ def train_mnist(epochs: int, lr: float, batch_size: int, max_train_samples: int, max_test_samples: int):
174
  global MODEL
175
 
176
+ start = time.time()
177
  train_dl, test_dl = get_dataloaders(batch_size, max_train_samples, max_test_samples)
178
 
 
179
  model = MnistCNN(num_classes=CFG["num_classes"], dropout=CFG["dropout"]).to(DEVICE)
180
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
181
  criterion = nn.CrossEntropyLoss()
182
 
183
  logs = []
 
 
184
  for epoch in range(1, epochs + 1):
185
  model.train()
186
+ running_loss, correct, total = 0.0, 0, 0
 
 
187
 
188
+ for x, y in train_dl:
189
  x, y = x.to(DEVICE), y.to(DEVICE)
 
190
  optimizer.zero_grad()
191
  logits = model(x)
192
  loss = criterion(logits, y)
 
200
 
201
  train_loss = running_loss / max(1, len(train_dl))
202
  train_acc = correct / max(1, total)
 
203
  test_loss, test_acc = evaluate(model, test_dl)
204
 
205
  logs.append(
 
208
  f"test loss {test_loss:.4f} acc {test_acc:.4f}"
209
  )
210
 
 
211
  torch.save(model.state_dict(), WEIGHTS_PATH)
 
212
 
 
213
  with MODEL_LOCK:
214
  MODEL.load_state_dict(model.state_dict())
215
  MODEL.eval()
216
 
217
  elapsed = time.time() - start
218
+ status = f"Done. Saved `{WEIGHTS_PATH}`. Device: {DEVICE}. Time: {elapsed:.1f}s"
219
+ return status, "\n".join(logs)
220
 
221
 
222
+ def load_weights_ui():
223
  ok = maybe_load_weights()
224
+ return f"✅ Loaded `{WEIGHTS_PATH}`." if ok else f"⚠️ No `{WEIGHTS_PATH}` found yet. Train first."
 
 
225
 
226
 
227
+ # Try to load weights at startup
228
+ maybe_load_weights()
229
 
230
 
231
  # -----------------------------
232
+ # Gradio UI (Gradio 6+)
233
  # -----------------------------
234
  with gr.Blocks() as demo:
235
+ gr.Markdown("# MNIST — Train + Predict (PyTorch custom `nn.Module`)")
236
+ gr.Markdown(f"- Running on: `{DEVICE}` \n- Weights file: `{WEIGHTS_PATH}`")
 
 
 
 
237
 
238
  with gr.Row():
239
  with gr.Column():
240
  gr.Markdown("## 1) Train (optional)")
241
+ epochs = gr.Slider(1, 5, value=1, step=1, label="Epochs")
242
  lr = gr.Number(value=1e-3, label="Learning rate", precision=6)
243
  batch = gr.Slider(32, 256, value=128, step=32, label="Batch size")
244
 
245
+ gr.Markdown("### Speed controls (smaller = faster)")
246
  max_train = gr.Slider(1000, 60000, value=10000, step=1000, label="Max train samples")
247
  max_test = gr.Slider(500, 10000, value=2000, step=500, label="Max test samples")
248
 
249
  train_btn = gr.Button("Train model")
250
  load_btn = gr.Button("Load saved weights")
251
 
 
252
  status = gr.Textbox(label="Status", lines=2)
253
+ train_log = gr.Textbox(label="Training log", lines=10)
254
 
255
  with gr.Column():
256
  gr.Markdown("## 2) Predict")
257
+
258
  with gr.Tab("Draw"):
259
+ # ImageEditor is the Gradio 6 way to draw/paint
260
+ draw_editor = gr.ImageEditor(
261
+ value=blank_editor_value,
262
+ type="pil",
263
+ canvas_size=(280, 280),
264
+ fixed_canvas=True,
265
+ label="Draw a digit (0–9)"
266
+ )
267
  draw_btn = gr.Button("Predict from drawing")
268
+
269
  with gr.Tab("Upload"):
270
+ up_img = gr.Image(type="pil", label="Upload a digit image")
271
  up_btn = gr.Button("Predict from upload")
272
 
273
  pred_out = gr.Number(label="Prediction")
274
  prob_out = gr.Label(num_top_classes=3, label="Probabilities (top 3)")
275
 
 
276
  train_btn.click(
277
  fn=train_mnist,
278
  inputs=[epochs, lr, batch, max_train, max_test],
279
+ outputs=[status, train_log],
 
 
 
 
280
  )
281
 
282
  load_btn.click(
283
+ fn=load_weights_ui,
284
  inputs=[],
285
  outputs=[status],
286
  )
287
 
288
  draw_btn.click(
289
+ fn=predict_from_editor,
290
+ inputs=[draw_editor],
291
  outputs=[pred_out, prob_out],
292
  )
293
 
294
  up_btn.click(
295
+ fn=predict_from_pil,
296
  inputs=[up_img],
297
  outputs=[pred_out, prob_out],
298
  )
299
 
300
 
301
  if __name__ == "__main__":
302
+ demo.launch()