Spaces:
Sleeping
Sleeping
Upload app.py
Browse files- 2.CNN/app.py +31 -13
2.CNN/app.py
CHANGED
|
@@ -158,9 +158,11 @@ def erode_binary_like(arr, radius=1):
|
|
| 158 |
def generate_inference_variants(arr, *, fast: bool = False):
|
| 159 |
variants = []
|
| 160 |
if fast:
|
| 161 |
-
# Space-optimized:
|
| 162 |
for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
|
| 163 |
variants.append(shift_with_zero_pad(arr, dy, dx))
|
|
|
|
|
|
|
| 164 |
return variants
|
| 165 |
# Full set: 8 shifts + morphology
|
| 166 |
for dy in (-1, 0, 1):
|
|
@@ -308,17 +310,34 @@ def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool
|
|
| 308 |
clamp=(0.7, 1.4),
|
| 309 |
)
|
| 310 |
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8)
|
| 319 |
|
| 320 |
diagnostics = compute_diagnostics(
|
| 321 |
-
|
| 322 |
bbox,
|
| 323 |
original_canvas_shape,
|
| 324 |
mean_image,
|
|
@@ -331,7 +350,7 @@ def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool
|
|
| 331 |
"mass_fraction_after": float(balanced_mass_fraction),
|
| 332 |
}
|
| 333 |
|
| 334 |
-
return augmented_standardized,
|
| 335 |
|
| 336 |
|
| 337 |
def compute_diagnostics(arr_float, bbox, original_shape, mean_image, standardized, std_safe):
|
|
@@ -528,9 +547,8 @@ def predict_number(left_canvas, right_canvas, stroke_scale, auto_balance):
|
|
| 528 |
|
| 529 |
variants_matrix = np.concatenate(standardized_variants, axis=1).astype(np.float32, copy=False)
|
| 530 |
cache, probs_matrix = forward_prop(variants_matrix, params, training=False)
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
probs = softmax(avg_logits)
|
| 534 |
|
| 535 |
pred = int(get_predictions(probs)[0])
|
| 536 |
|
|
|
|
| 158 |
def generate_inference_variants(arr, *, fast: bool = False):
|
| 159 |
variants = []
|
| 160 |
if fast:
|
| 161 |
+
# Space-optimized: cardinal shifts plus light morphology (6 variants)
|
| 162 |
for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
|
| 163 |
variants.append(shift_with_zero_pad(arr, dy, dx))
|
| 164 |
+
variants.append(dilate_binary_like(arr, radius=1))
|
| 165 |
+
variants.append(erode_binary_like(arr, radius=1))
|
| 166 |
return variants
|
| 167 |
# Full set: 8 shifts + morphology
|
| 168 |
for dy in (-1, 0, 1):
|
|
|
|
| 310 |
clamp=(0.7, 1.4),
|
| 311 |
)
|
| 312 |
|
| 313 |
+
# Light recentering by center-of-mass to reduce sensitivity to placement
|
| 314 |
+
mass = arr_resized
|
| 315 |
+
total_intensity = float(mass.sum())
|
| 316 |
+
arr_centered = arr_resized
|
| 317 |
+
if total_intensity > 1e-6:
|
| 318 |
+
gy, gx = np.indices(mass.shape)
|
| 319 |
+
cy = float((gy * mass).sum() / total_intensity)
|
| 320 |
+
cx = float((gx * mass).sum() / total_intensity)
|
| 321 |
+
ideal_cy = (TARGET_HEIGHT - 1) / 2.0
|
| 322 |
+
ideal_cx = (TARGET_WIDTH - 1) / 2.0
|
| 323 |
+
dy = int(np.clip(round(ideal_cy - cy), -2, 2))
|
| 324 |
+
dx = int(np.clip(round(ideal_cx - cx), -2, 2))
|
| 325 |
+
if dy != 0 or dx != 0:
|
| 326 |
+
arr_centered = shift_with_zero_pad(arr_resized, dy, dx)
|
| 327 |
+
|
| 328 |
+
augmented_arrays = [arr_centered, *generate_inference_variants(arr_centered, fast=IS_SPACE)]
|
| 329 |
+
# Standardize each variant and clip to tame outliers for stable inference
|
| 330 |
+
augmented_standardized = []
|
| 331 |
+
for arr in augmented_arrays:
|
| 332 |
+
z = (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
|
| 333 |
+
z = np.clip(z, -8.0, 8.0)
|
| 334 |
+
augmented_standardized.append(z.astype(np.float32, copy=False))
|
| 335 |
+
|
| 336 |
+
mean_diff = np.abs(arr_centered - mean_image)
|
| 337 |
mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8)
|
| 338 |
|
| 339 |
diagnostics = compute_diagnostics(
|
| 340 |
+
arr_centered,
|
| 341 |
bbox,
|
| 342 |
original_canvas_shape,
|
| 343 |
mean_image,
|
|
|
|
| 350 |
"mass_fraction_after": float(balanced_mass_fraction),
|
| 351 |
}
|
| 352 |
|
| 353 |
+
return augmented_standardized, arr_centered, mean_diff_uint8, diagnostics
|
| 354 |
|
| 355 |
|
| 356 |
def compute_diagnostics(arr_float, bbox, original_shape, mean_image, standardized, std_safe):
|
|
|
|
| 547 |
|
| 548 |
variants_matrix = np.concatenate(standardized_variants, axis=1).astype(np.float32, copy=False)
|
| 549 |
cache, probs_matrix = forward_prop(variants_matrix, params, training=False)
|
| 550 |
+
# Average probabilities across variants to reduce domination by any single variant
|
| 551 |
+
probs = np.mean(probs_matrix, axis=1, keepdims=True)
|
|
|
|
| 552 |
|
| 553 |
pred = int(get_predictions(probs)[0])
|
| 554 |
|