Eli181927 commited on
Commit
9617586
·
verified ·
1 Parent(s): c9f2cb0

Upload app.py

Browse files
Files changed (1) hide show
  1. 2.CNN/app.py +45 -33
2.CNN/app.py CHANGED
@@ -35,6 +35,9 @@ forward_prop = training_mod.forward_prop
35
  get_predictions = training_mod.get_predictions
36
  softmax = training_mod.softmax
37
 
 
 
 
38
 
39
  def _metric_status(name, value):
40
  target = METRIC_TARGETS.get(name)
@@ -128,36 +131,43 @@ def shift_with_zero_pad(arr, shift_y=0, shift_x=0):
128
 
129
 
130
  def dilate_binary_like(arr, radius=1):
131
- pad = radius
132
- padded = np.pad(arr, pad, mode="constant", constant_values=0.0)
133
- out = np.zeros_like(arr)
134
- for i in range(arr.shape[0]):
135
- for j in range(arr.shape[1]):
136
- window = padded[i : i + 2 * pad + 1, j : j + 2 * pad + 1]
137
- out[i, j] = window.max()
138
- return out
 
 
139
 
140
 
141
  def erode_binary_like(arr, radius=1):
142
- pad = radius
143
- padded = np.pad(arr, pad, mode="constant", constant_values=1.0)
144
- out = np.zeros_like(arr)
145
- for i in range(arr.shape[0]):
146
- for j in range(arr.shape[1]):
147
- window = padded[i : i + 2 * pad + 1, j : j + 2 * pad + 1]
148
- out[i, j] = window.min()
149
- return out
 
150
 
151
 
152
- def generate_inference_variants(arr):
153
  variants = []
154
- # slight shifts
 
 
 
 
 
155
  for dy in (-1, 0, 1):
156
  for dx in (-1, 0, 1):
157
  if dy == 0 and dx == 0:
158
  continue
159
  variants.append(shift_with_zero_pad(arr, dy, dx))
160
- # dilation/erosion to handle thin or thick strokes
161
  variants.append(dilate_binary_like(arr, radius=1))
162
  variants.append(erode_binary_like(arr, radius=1))
163
  return variants
@@ -298,7 +308,7 @@ def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool
298
  clamp=(0.7, 1.4),
299
  )
300
 
301
- augmented_arrays = [arr_resized, *generate_inference_variants(arr_resized)]
302
  augmented_standardized = [
303
  (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
304
  for arr in augmented_arrays
@@ -534,7 +544,7 @@ def predict_number(left_canvas, right_canvas, stroke_scale, auto_balance):
534
  return pred, prob_rows, (preview * 255).astype(np.uint8), mean_diff, json.dumps(diagnostics, indent=2)
535
 
536
 
537
- with gr.Blocks() as demo:
538
  gr.Markdown(
539
  """
540
  # Elliot's MNIST-100 Classifier
@@ -591,24 +601,26 @@ with gr.Blocks() as demo:
591
  inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
592
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
593
  )
594
- left_canvas.change(
595
- fn=predict_number,
596
- inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
597
- outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
598
- )
599
- right_canvas.change(
600
- fn=predict_number,
601
- inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
602
- outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
603
- )
 
 
604
 
605
 
606
  if __name__ == "__main__":
607
  space_env = os.getenv("SPACE_ID")
608
  if space_env:
609
- demo.launch(show_api=False)
610
  else:
611
- demo.launch(server_name="0.0.0.0", share=True, show_api=False)
612
  def _disable_gradio_api_schema(*_args, **_kwargs):
613
  """Work around Gradio schema bug on Python 3.13 by returning empty metadata."""
614
  return {}
 
35
  get_predictions = training_mod.get_predictions
36
  softmax = training_mod.softmax
37
 
38
+ # Detect if running on Hugging Face Spaces
39
+ IS_SPACE = bool(os.getenv("SPACE_ID"))
40
+
41
 
42
  def _metric_status(name, value):
43
  target = METRIC_TARGETS.get(name)
 
131
 
132
 
133
  def dilate_binary_like(arr, radius=1):
134
+ # Vectorized dilation via max over shifted windows (3x3 when radius=1)
135
+ if radius != 1:
136
+ # Fallback to radius=1 behavior for simplicity/perf
137
+ radius = 1
138
+ shifts = []
139
+ for dy in (-1, 0, 1):
140
+ for dx in (-1, 0, 1):
141
+ shifts.append(shift_with_zero_pad(arr, dy, dx))
142
+ stacked = np.stack(shifts, axis=0)
143
+ return np.max(stacked, axis=0)
144
 
145
 
146
  def erode_binary_like(arr, radius=1):
147
+ # Vectorized erosion via min over shifted windows (3x3 when radius=1)
148
+ if radius != 1:
149
+ radius = 1
150
+ shifts = []
151
+ for dy in (-1, 0, 1):
152
+ for dx in (-1, 0, 1):
153
+ shifts.append(shift_with_zero_pad(arr, dy, dx))
154
+ stacked = np.stack(shifts, axis=0)
155
+ return np.min(stacked, axis=0)
156
 
157
 
158
+ def generate_inference_variants(arr, *, fast: bool = False):
159
  variants = []
160
+ if fast:
161
+ # Space-optimized: only cardinal shifts (4 variants)
162
+ for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
163
+ variants.append(shift_with_zero_pad(arr, dy, dx))
164
+ return variants
165
+ # Full set: 8 shifts + morphology
166
  for dy in (-1, 0, 1):
167
  for dx in (-1, 0, 1):
168
  if dy == 0 and dx == 0:
169
  continue
170
  variants.append(shift_with_zero_pad(arr, dy, dx))
 
171
  variants.append(dilate_binary_like(arr, radius=1))
172
  variants.append(erode_binary_like(arr, radius=1))
173
  return variants
 
308
  clamp=(0.7, 1.4),
309
  )
310
 
311
+ augmented_arrays = [arr_resized, *generate_inference_variants(arr_resized, fast=IS_SPACE)]
312
  augmented_standardized = [
313
  (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
314
  for arr in augmented_arrays
 
544
  return pred, prob_rows, (preview * 255).astype(np.uint8), mean_diff, json.dumps(diagnostics, indent=2)
545
 
546
 
547
+ with gr.Blocks(queue=True) as demo:
548
  gr.Markdown(
549
  """
550
  # Elliot's MNIST-100 Classifier
 
601
  inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
602
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
603
  )
604
+ # On Spaces, avoid per-stroke inference to prevent event floods
605
+ if not IS_SPACE:
606
+ left_canvas.change(
607
+ fn=predict_number,
608
+ inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
609
+ outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
610
+ )
611
+ right_canvas.change(
612
+ fn=predict_number,
613
+ inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
614
+ outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
615
+ )
616
 
617
 
618
  if __name__ == "__main__":
619
  space_env = os.getenv("SPACE_ID")
620
  if space_env:
621
+ demo.queue(concurrency_count=1).launch(show_api=False)
622
  else:
623
+ demo.queue(concurrency_count=1).launch(server_name="0.0.0.0", share=True, show_api=False)
624
  def _disable_gradio_api_schema(*_args, **_kwargs):
625
  """Work around Gradio schema bug on Python 3.13 by returning empty metadata."""
626
  return {}