Eli181927 commited on
Commit
5c1d99f
·
verified ·
1 Parent(s): 8dcf078

Upload app.py

Browse files
Files changed (1) hide show
  1. 2.CNN/app.py +31 -13
2.CNN/app.py CHANGED
@@ -158,9 +158,11 @@ def erode_binary_like(arr, radius=1):
158
  def generate_inference_variants(arr, *, fast: bool = False):
159
  variants = []
160
  if fast:
161
- # Space-optimized: only cardinal shifts (4 variants)
162
  for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
163
  variants.append(shift_with_zero_pad(arr, dy, dx))
 
 
164
  return variants
165
  # Full set: 8 shifts + morphology
166
  for dy in (-1, 0, 1):
@@ -308,17 +310,34 @@ def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool
308
  clamp=(0.7, 1.4),
309
  )
310
 
311
- augmented_arrays = [arr_resized, *generate_inference_variants(arr_resized, fast=IS_SPACE)]
312
- augmented_standardized = [
313
- (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
314
- for arr in augmented_arrays
315
- ]
316
-
317
- mean_diff = np.abs(arr_resized - mean_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8)
319
 
320
  diagnostics = compute_diagnostics(
321
- arr_resized,
322
  bbox,
323
  original_canvas_shape,
324
  mean_image,
@@ -331,7 +350,7 @@ def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool
331
  "mass_fraction_after": float(balanced_mass_fraction),
332
  }
333
 
334
- return augmented_standardized, arr_resized, mean_diff_uint8, diagnostics
335
 
336
 
337
  def compute_diagnostics(arr_float, bbox, original_shape, mean_image, standardized, std_safe):
@@ -528,9 +547,8 @@ def predict_number(left_canvas, right_canvas, stroke_scale, auto_balance):
528
 
529
  variants_matrix = np.concatenate(standardized_variants, axis=1).astype(np.float32, copy=False)
530
  cache, probs_matrix = forward_prop(variants_matrix, params, training=False)
531
- logits_matrix = cache["Z_fc2"]
532
- avg_logits = np.mean(logits_matrix, axis=1, keepdims=True)
533
- probs = softmax(avg_logits)
534
 
535
  pred = int(get_predictions(probs)[0])
536
 
 
158
  def generate_inference_variants(arr, *, fast: bool = False):
159
  variants = []
160
  if fast:
161
+ # Space-optimized: cardinal shifts plus light morphology (6 variants)
162
  for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
163
  variants.append(shift_with_zero_pad(arr, dy, dx))
164
+ variants.append(dilate_binary_like(arr, radius=1))
165
+ variants.append(erode_binary_like(arr, radius=1))
166
  return variants
167
  # Full set: 8 shifts + morphology
168
  for dy in (-1, 0, 1):
 
310
  clamp=(0.7, 1.4),
311
  )
312
 
313
+ # Light recentering by center-of-mass to reduce sensitivity to placement
314
+ mass = arr_resized
315
+ total_intensity = float(mass.sum())
316
+ arr_centered = arr_resized
317
+ if total_intensity > 1e-6:
318
+ gy, gx = np.indices(mass.shape)
319
+ cy = float((gy * mass).sum() / total_intensity)
320
+ cx = float((gx * mass).sum() / total_intensity)
321
+ ideal_cy = (TARGET_HEIGHT - 1) / 2.0
322
+ ideal_cx = (TARGET_WIDTH - 1) / 2.0
323
+ dy = int(np.clip(round(ideal_cy - cy), -2, 2))
324
+ dx = int(np.clip(round(ideal_cx - cx), -2, 2))
325
+ if dy != 0 or dx != 0:
326
+ arr_centered = shift_with_zero_pad(arr_resized, dy, dx)
327
+
328
+ augmented_arrays = [arr_centered, *generate_inference_variants(arr_centered, fast=IS_SPACE)]
329
+ # Standardize each variant and clip to tame outliers for stable inference
330
+ augmented_standardized = []
331
+ for arr in augmented_arrays:
332
+ z = (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
333
+ z = np.clip(z, -8.0, 8.0)
334
+ augmented_standardized.append(z.astype(np.float32, copy=False))
335
+
336
+ mean_diff = np.abs(arr_centered - mean_image)
337
  mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8)
338
 
339
  diagnostics = compute_diagnostics(
340
+ arr_centered,
341
  bbox,
342
  original_canvas_shape,
343
  mean_image,
 
350
  "mass_fraction_after": float(balanced_mass_fraction),
351
  }
352
 
353
+ return augmented_standardized, arr_centered, mean_diff_uint8, diagnostics
354
 
355
 
356
  def compute_diagnostics(arr_float, bbox, original_shape, mean_image, standardized, std_safe):
 
547
 
548
  variants_matrix = np.concatenate(standardized_variants, axis=1).astype(np.float32, copy=False)
549
  cache, probs_matrix = forward_prop(variants_matrix, params, training=False)
550
+ # Average probabilities across variants to reduce domination by any single variant
551
+ probs = np.mean(probs_matrix, axis=1, keepdims=True)
 
552
 
553
  pred = int(get_predictions(probs)[0])
554