Eli181927 commited on
Commit
4ba948e
·
verified ·
1 Parent(s): 5c1d99f

Upload app.py

Browse files
Files changed (1) hide show
  1. 2.CNN/app.py +16 -21
2.CNN/app.py CHANGED
@@ -225,7 +225,7 @@ def compose_dual_canvas(left_input, right_input):
225
  return combined
226
 
227
 
228
- def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool = True):
229
  ensure_model_loaded()
230
  img = extract_canvas_array(img_input)
231
  if img is None:
@@ -301,14 +301,14 @@ def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool
301
  arr_resized = np.clip(arr_resized * stroke_scale, 0.0, 1.0)
302
 
303
  auto_balance_scale = 1.0
304
- balanced_mass_fraction = float(arr_resized.sum() / (TARGET_HEIGHT * TARGET_WIDTH))
305
- if auto_balance:
306
- target_mass = sum(METRIC_TARGETS["mass_fraction"]) / 2.0
307
- arr_resized, auto_balance_scale, balanced_mass_fraction = _auto_balance_stroke(
308
- arr_resized,
309
- target_mass_fraction=target_mass,
310
- clamp=(0.7, 1.4),
311
- )
312
 
313
  # Light recentering by center-of-mass to reduce sensitivity to placement
314
  mass = arr_resized
@@ -345,9 +345,11 @@ def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool
345
  std_safe,
346
  )
347
  diagnostics["applied_auto_balance"] = {
348
- "enabled": bool(auto_balance),
349
  "scale": float(auto_balance_scale),
350
  "mass_fraction_after": float(balanced_mass_fraction),
 
 
351
  }
352
 
353
  return augmented_standardized, arr_centered, mean_diff_uint8, diagnostics
@@ -521,7 +523,7 @@ def enrich_diagnostics(stats, probs):
521
  return stats
522
 
523
 
524
- def predict_number(left_canvas, right_canvas, stroke_scale, auto_balance):
525
  ensure_model_loaded()
526
  combined_canvas = compose_dual_canvas(left_canvas, right_canvas)
527
  if combined_canvas is None:
@@ -534,7 +536,6 @@ def predict_number(left_canvas, right_canvas, stroke_scale, auto_balance):
534
  result = preprocess_image(
535
  combined_canvas,
536
  stroke_scale=stroke_scale,
537
- auto_balance=bool(auto_balance),
538
  )
539
  if result is None:
540
  blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
@@ -582,11 +583,6 @@ with gr.Blocks() as demo:
582
  step=0.05,
583
  label="Stroke Intensity (scale)",
584
  )
585
- auto_balance = gr.Checkbox(
586
- value=True,
587
- label="Auto Balance Stroke Thickness",
588
- info="Automatically rescales the digit to match training mass and brightness.",
589
- )
590
 
591
  with gr.Column(scale=1):
592
  pred_box = gr.Number(label="Predicted Number", precision=0, value=None)
@@ -605,7 +601,6 @@ with gr.Blocks() as demo:
605
  left_canvas,
606
  right_canvas,
607
  stroke_slider,
608
- auto_balance,
609
  pred_box,
610
  prob_table,
611
  preview,
@@ -616,19 +611,19 @@ with gr.Blocks() as demo:
616
 
617
  predict_btn.click(
618
  fn=predict_number,
619
- inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
620
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
621
  )
622
  # On Spaces, avoid per-stroke inference to prevent event floods
623
  if not IS_SPACE:
624
  left_canvas.change(
625
  fn=predict_number,
626
- inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
627
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
628
  )
629
  right_canvas.change(
630
  fn=predict_number,
631
- inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
632
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
633
  )
634
 
 
225
  return combined
226
 
227
 
228
+ def preprocess_image(img_input, stroke_scale: float = 1.0):
229
  ensure_model_loaded()
230
  img = extract_canvas_array(img_input)
231
  if img is None:
 
301
  arr_resized = np.clip(arr_resized * stroke_scale, 0.0, 1.0)
302
 
303
  auto_balance_scale = 1.0
304
+ # Match the dataset's global mean intensity (more faithful than a fixed midpoint)
305
+ pre_balance_mass_fraction = float(arr_resized.mean())
306
+ target_mass = float(mean.mean())
307
+ arr_resized, auto_balance_scale, balanced_mass_fraction = _auto_balance_stroke(
308
+ arr_resized,
309
+ target_mass_fraction=target_mass,
310
+ clamp=(0.6, 1.6),
311
+ )
312
 
313
  # Light recentering by center-of-mass to reduce sensitivity to placement
314
  mass = arr_resized
 
345
  std_safe,
346
  )
347
  diagnostics["applied_auto_balance"] = {
348
+ "enabled": True,
349
  "scale": float(auto_balance_scale),
350
  "mass_fraction_after": float(balanced_mass_fraction),
351
+ "mass_fraction_before": float(pre_balance_mass_fraction),
352
+ "target_mass_fraction": float(target_mass),
353
  }
354
 
355
  return augmented_standardized, arr_centered, mean_diff_uint8, diagnostics
 
523
  return stats
524
 
525
 
526
+ def predict_number(left_canvas, right_canvas, stroke_scale):
527
  ensure_model_loaded()
528
  combined_canvas = compose_dual_canvas(left_canvas, right_canvas)
529
  if combined_canvas is None:
 
536
  result = preprocess_image(
537
  combined_canvas,
538
  stroke_scale=stroke_scale,
 
539
  )
540
  if result is None:
541
  blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
 
583
  step=0.05,
584
  label="Stroke Intensity (scale)",
585
  )
 
 
 
 
 
586
 
587
  with gr.Column(scale=1):
588
  pred_box = gr.Number(label="Predicted Number", precision=0, value=None)
 
601
  left_canvas,
602
  right_canvas,
603
  stroke_slider,
 
604
  pred_box,
605
  prob_table,
606
  preview,
 
611
 
612
  predict_btn.click(
613
  fn=predict_number,
614
+ inputs=[left_canvas, right_canvas, stroke_slider],
615
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
616
  )
617
  # On Spaces, avoid per-stroke inference to prevent event floods
618
  if not IS_SPACE:
619
  left_canvas.change(
620
  fn=predict_number,
621
+ inputs=[left_canvas, right_canvas, stroke_slider],
622
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
623
  )
624
  right_canvas.change(
625
  fn=predict_number,
626
+ inputs=[left_canvas, right_canvas, stroke_slider],
627
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
628
  )
629