Spaces:
Sleeping
Sleeping
Upload app.py
Browse files- 2.CNN/app.py +45 -33
2.CNN/app.py
CHANGED
|
@@ -35,6 +35,9 @@ forward_prop = training_mod.forward_prop
|
|
| 35 |
get_predictions = training_mod.get_predictions
|
| 36 |
softmax = training_mod.softmax
|
| 37 |
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def _metric_status(name, value):
|
| 40 |
target = METRIC_TARGETS.get(name)
|
|
@@ -128,36 +131,43 @@ def shift_with_zero_pad(arr, shift_y=0, shift_x=0):
|
|
| 128 |
|
| 129 |
|
| 130 |
def dilate_binary_like(arr, radius=1):
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
| 139 |
|
| 140 |
|
| 141 |
def erode_binary_like(arr, radius=1):
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
-
def generate_inference_variants(arr):
|
| 153 |
variants = []
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
for dy in (-1, 0, 1):
|
| 156 |
for dx in (-1, 0, 1):
|
| 157 |
if dy == 0 and dx == 0:
|
| 158 |
continue
|
| 159 |
variants.append(shift_with_zero_pad(arr, dy, dx))
|
| 160 |
-
# dilation/erosion to handle thin or thick strokes
|
| 161 |
variants.append(dilate_binary_like(arr, radius=1))
|
| 162 |
variants.append(erode_binary_like(arr, radius=1))
|
| 163 |
return variants
|
|
@@ -298,7 +308,7 @@ def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool
|
|
| 298 |
clamp=(0.7, 1.4),
|
| 299 |
)
|
| 300 |
|
| 301 |
-
augmented_arrays = [arr_resized, *generate_inference_variants(arr_resized)]
|
| 302 |
augmented_standardized = [
|
| 303 |
(arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
|
| 304 |
for arr in augmented_arrays
|
|
@@ -534,7 +544,7 @@ def predict_number(left_canvas, right_canvas, stroke_scale, auto_balance):
|
|
| 534 |
return pred, prob_rows, (preview * 255).astype(np.uint8), mean_diff, json.dumps(diagnostics, indent=2)
|
| 535 |
|
| 536 |
|
| 537 |
-
with gr.Blocks() as demo:
|
| 538 |
gr.Markdown(
|
| 539 |
"""
|
| 540 |
# Elliot's MNIST-100 Classifier
|
|
@@ -591,24 +601,26 @@ with gr.Blocks() as demo:
|
|
| 591 |
inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
|
| 592 |
outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
|
| 593 |
)
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
|
|
|
|
|
|
| 604 |
|
| 605 |
|
| 606 |
if __name__ == "__main__":
|
| 607 |
space_env = os.getenv("SPACE_ID")
|
| 608 |
if space_env:
|
| 609 |
-
demo.launch(show_api=False)
|
| 610 |
else:
|
| 611 |
-
demo.launch(server_name="0.0.0.0", share=True, show_api=False)
|
| 612 |
def _disable_gradio_api_schema(*_args, **_kwargs):
|
| 613 |
"""Work around Gradio schema bug on Python 3.13 by returning empty metadata."""
|
| 614 |
return {}
|
|
|
|
| 35 |
get_predictions = training_mod.get_predictions
|
| 36 |
softmax = training_mod.softmax
|
| 37 |
|
| 38 |
+
# Detect if running on Hugging Face Spaces
|
| 39 |
+
IS_SPACE = bool(os.getenv("SPACE_ID"))
|
| 40 |
+
|
| 41 |
|
| 42 |
def _metric_status(name, value):
|
| 43 |
target = METRIC_TARGETS.get(name)
|
|
|
|
| 131 |
|
| 132 |
|
| 133 |
def dilate_binary_like(arr, radius=1):
|
| 134 |
+
# Vectorized dilation via max over shifted windows (3x3 when radius=1)
|
| 135 |
+
if radius != 1:
|
| 136 |
+
# Fallback to radius=1 behavior for simplicity/perf
|
| 137 |
+
radius = 1
|
| 138 |
+
shifts = []
|
| 139 |
+
for dy in (-1, 0, 1):
|
| 140 |
+
for dx in (-1, 0, 1):
|
| 141 |
+
shifts.append(shift_with_zero_pad(arr, dy, dx))
|
| 142 |
+
stacked = np.stack(shifts, axis=0)
|
| 143 |
+
return np.max(stacked, axis=0)
|
| 144 |
|
| 145 |
|
| 146 |
def erode_binary_like(arr, radius=1):
|
| 147 |
+
# Vectorized erosion via min over shifted windows (3x3 when radius=1)
|
| 148 |
+
if radius != 1:
|
| 149 |
+
radius = 1
|
| 150 |
+
shifts = []
|
| 151 |
+
for dy in (-1, 0, 1):
|
| 152 |
+
for dx in (-1, 0, 1):
|
| 153 |
+
shifts.append(shift_with_zero_pad(arr, dy, dx))
|
| 154 |
+
stacked = np.stack(shifts, axis=0)
|
| 155 |
+
return np.min(stacked, axis=0)
|
| 156 |
|
| 157 |
|
| 158 |
+
def generate_inference_variants(arr, *, fast: bool = False):
|
| 159 |
variants = []
|
| 160 |
+
if fast:
|
| 161 |
+
# Space-optimized: only cardinal shifts (4 variants)
|
| 162 |
+
for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
|
| 163 |
+
variants.append(shift_with_zero_pad(arr, dy, dx))
|
| 164 |
+
return variants
|
| 165 |
+
# Full set: 8 shifts + morphology
|
| 166 |
for dy in (-1, 0, 1):
|
| 167 |
for dx in (-1, 0, 1):
|
| 168 |
if dy == 0 and dx == 0:
|
| 169 |
continue
|
| 170 |
variants.append(shift_with_zero_pad(arr, dy, dx))
|
|
|
|
| 171 |
variants.append(dilate_binary_like(arr, radius=1))
|
| 172 |
variants.append(erode_binary_like(arr, radius=1))
|
| 173 |
return variants
|
|
|
|
| 308 |
clamp=(0.7, 1.4),
|
| 309 |
)
|
| 310 |
|
| 311 |
+
augmented_arrays = [arr_resized, *generate_inference_variants(arr_resized, fast=IS_SPACE)]
|
| 312 |
augmented_standardized = [
|
| 313 |
(arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
|
| 314 |
for arr in augmented_arrays
|
|
|
|
| 544 |
return pred, prob_rows, (preview * 255).astype(np.uint8), mean_diff, json.dumps(diagnostics, indent=2)
|
| 545 |
|
| 546 |
|
| 547 |
+
with gr.Blocks(queue=True) as demo:
|
| 548 |
gr.Markdown(
|
| 549 |
"""
|
| 550 |
# Elliot's MNIST-100 Classifier
|
|
|
|
| 601 |
inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
|
| 602 |
outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
|
| 603 |
)
|
| 604 |
+
# On Spaces, avoid per-stroke inference to prevent event floods
|
| 605 |
+
if not IS_SPACE:
|
| 606 |
+
left_canvas.change(
|
| 607 |
+
fn=predict_number,
|
| 608 |
+
inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
|
| 609 |
+
outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
|
| 610 |
+
)
|
| 611 |
+
right_canvas.change(
|
| 612 |
+
fn=predict_number,
|
| 613 |
+
inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
|
| 614 |
+
outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
|
| 615 |
+
)
|
| 616 |
|
| 617 |
|
| 618 |
if __name__ == "__main__":
|
| 619 |
space_env = os.getenv("SPACE_ID")
|
| 620 |
if space_env:
|
| 621 |
+
demo.queue(concurrency_count=1).launch(show_api=False)
|
| 622 |
else:
|
| 623 |
+
demo.queue(concurrency_count=1).launch(server_name="0.0.0.0", share=True, show_api=False)
|
| 624 |
def _disable_gradio_api_schema(*_args, **_kwargs):
|
| 625 |
"""Work around Gradio schema bug on Python 3.13 by returning empty metadata."""
|
| 626 |
return {}
|