ymlin105 committed
Commit d9b5881 · 1 parent: c173f49

feat: Refactor experiments and update report
Files changed (28)
  1. README.md +3 -3
  2. app.py +57 -12
  3. docs/REPORT.md +69 -65
  4. docs/research_results/fig_01_spectrum.png +2 -2
  5. docs/research_results/{fig_02_svd_confusion.png → fig_01_svd_confusion.png} +2 -2
  6. docs/research_results/{fig_03_eigen_digits.png → fig_02_eigen_digits.png} +2 -2
  7. docs/research_results/{fig_05_interpolation.png → fig_03_interpolation.png} +2 -2
  8. docs/research_results/fig_04_cnn_confusion.png +2 -2
  9. docs/research_results/{fig_06_explainability.png → fig_04_explainability.png} +0 -0
  10. docs/research_results/{fig_08_manifold_collapse.png → fig_05_manifold_collapse.png} +2 -2
  11. docs/research_results/{fig_14_learning_curves.png → fig_06_robustness_mnist_gaussian.png} +2 -2
  12. docs/research_results/fig_07_robustness_mnist_svd_aligned.png +3 -0
  13. docs/research_results/fig_08_robustness_fashion.png +3 -0
  14. docs/research_results/fig_09_learning_curves.png +3 -0
  15. docs/research_results/fig_10_per_class_metrics_comparison.png +3 -0
  16. docs/research_results/fig_19_per_class_metrics_comparison.png +0 -3
  17. docs/research_results/fig_robustness_fashion.png +0 -3
  18. docs/research_results/fig_robustness_mnist.png +0 -3
  19. experiments/01_phenomenon_diagnosis.py +2 -2
  20. experiments/{02_mechanistic_analysis.py → 02_mechanistic_proof.py} +23 -17
  21. experiments/{run_robustness_test.py → 03_operational_boundaries.py} +25 -9
  22. experiments/{appendix_learning_curves.py → 04_appendix_learning_curves.py} +1 -1
  23. experiments/{appendix_per_class_metrics.py → 05_appendix_per_class_metrics.py} +5 -2
  24. run_all_experiments.sh +4 -4
  25. run_migration.sh +68 -0
  26. src/exp_utils.py +35 -1
  27. src/hybrid_model.py +5 -0
  28. src/viz.py +37 -17
README.md CHANGED
@@ -10,14 +10,14 @@ app_file: app.py
 pinned: false
 ---
 
- # SVD vs CNN: Mechanistic Analysis of Manifold Alignment on MNIST
+ # Why Does SVD Turn a "3" into an "8"? Linear vs. Non-linear Manifolds on MNIST
 
 [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ymlin105/Coconut-MNIST) [![Full Report](https://img.shields.io/badge/📖_Read-Full_Report-blue)](./docs/REPORT.md)
 
 While it is a known theoretical property that linear dimensionality reduction (SVD) acts as a low-pass filter, this project provides a **concrete, visual, and quantitative mechanistic explanation** of how this property manifests in neural network classification—specifically, why linear subspaces consistently force a "3" to collapse into an "8".
 
 <p align="center">
- <img src="./docs/research_results/fig_06_explainability.png" width="600" alt="Mechanistic Analysis of SVD Inductive Bias">
+ <img src="./docs/research_results/fig_04_explainability.png" width="600" alt="Mechanistic Analysis: SVD Blind Spot">
 </p>
 
 By mapping the exact decision boundaries where linear global variance models fail and non-linear topological models (CNNs) succeed, I empirically validate the **inherent trade-offs** of linear denoising in high-stakes domains like medical imaging or satellite data—where a linear filter might suppress critical diagnostic features to minimize noise variance.
@@ -93,7 +93,7 @@ streamlit run app.py
 ### Project Structure
 ```
 ├── src/                    Core modules (CNN, SVD layer) + Experimental Utils
- ├── experiments/            Theme-based scripts (01 Diagnosis, 02 Analysis, 03 Robustness)
+ ├── experiments/            Sequential scripts (01 Diagnosis, 02 Proof, 03 Boundaries)
 ├── docs/                   Full report (REPORT.md) + figures
 ├── models/                 Pretrained checkpoints
 ├── run_all_experiments.sh  One-click reproduction script
app.py CHANGED
@@ -144,7 +144,13 @@ def get_reconstruction(svd_model, img_flat):
     recons = svd_model.inverse_transform(svd_model.transform(flat))
     if mean is not None:
         recons = recons + mean
-    return torch.clamp(torch.tensor(recons).float(), 0, 1).view(28, 28)
+
+    # Monitor truncation ratio
+    out_of_range = (recons < 0) | (recons > 1)
+    clamp_ratio = np.mean(out_of_range)
+
+    recons_tensor = torch.clamp(torch.tensor(recons).float(), 0, 1).view(28, 28)
+    return recons_tensor, clamp_ratio
 
 
 # --- UI Sidebar ---
@@ -161,6 +167,14 @@ with st.sidebar:
     )
     if noise_mode:
         st.success("SVD Denoising Active")
+
+    st.markdown("---")
+    st.subheader("Model Calibration")
+    temp_scaling = st.slider(
+        "Softmax Temperature (T)",
+        0.1, 5.0, 1.0, 0.1,
+        help="Higher T = smoother transitions (less over-confident), Lower T = sharper 'snaps'."
+    )
 
 
 # --- Initialization ---
@@ -204,11 +218,11 @@ with tab1:
     img_s = X_flat[np.where(y_orig == start_digit)[0][0]]
     img_e = X_flat[np.where(y_orig == end_digit)[0][0]]
     img_interp = (1 - alpha) * img_s + alpha * img_e
-    img_svd = get_reconstruction(svd_model, img_interp)
+    img_svd, clamp_ratio = get_reconstruction(svd_model, img_interp)
 
     with torch.no_grad():
         logits = cnn_model(img_interp.view(1, 1, 28, 28))
-        probs = torch.softmax(logits, dim=1)
+        probs = torch.softmax(logits / temp_scaling, dim=1)
         conf, pred = torch.max(probs, 1)
 
     # Visual Display
@@ -219,11 +233,35 @@ with tab1:
     with v2:
         st.markdown("**SVD Reconstruction**")
         st.image(img_svd.numpy(), width=150)
+        st.caption(f"Truncation: {clamp_ratio:.1%}")
     with v3:
         st.markdown(f"**CNN Prediction: {pred.item()}**")
         st.progress(conf.item(), text=f"Confidence: {conf.item():.1%}")
         st.caption("Note: CNN 'snaps' at the topological midpoint, not a smooth transition.")
 
+    # --- Confidence Curve Visualization ---
+    alphas_curve = np.linspace(0, 1, 21)
+    curve_probs = []
+
+    with torch.no_grad():
+        for a in alphas_curve:
+            img_a = (1 - a) * img_s + a * img_e
+            logits_a = cnn_model(img_a.view(1, 1, 28, 28))
+            probs_a = torch.softmax(logits_a / temp_scaling, dim=1)
+            # Probability of being the "end_digit"
+            curve_probs.append(probs_a[0, end_digit].item())
+
+    st.markdown("---")
+    st.markdown(f"#### Confidence Snap: Probability of '{end_digit}'")
+
+    df_curve = pd.DataFrame({
+        "alpha": alphas_curve,
+        "Probability": curve_probs
+    }).set_index("alpha")
+
+    st.line_chart(df_curve, height=200)
+    st.caption("The vertical 'snap' in this curve highlights the non-linear decision boundary. Even as the pixels fade linearly, the CNN's internal representation jumps once a topological threshold is crossed.")
+
 
 # --- Tab 2: Robustness (The SVD Advantage) ---
 with tab2:
@@ -237,12 +275,12 @@ with tab2:
 
     img_clean = X_flat[np.where(y_orig == noise_digit)[0][0]]
     img_noisy = torch.clamp(img_clean + torch.randn_like(img_clean) * sigma, 0, 1)
-    img_denoised = get_reconstruction(svd_model, img_noisy)
+    img_denoised, clamp_ratio_robust = get_reconstruction(svd_model, img_noisy)
 
     with col2:
         res1, res2 = st.columns(2)
         res1.image(img_noisy.view(28, 28).numpy(), caption="Noisy Input", width=150)
-        res2.image(img_denoised.numpy(), caption="SVD Subspace Projection", width=150)
+        res2.image(img_denoised.numpy(), caption=f"SVD Projection (Trunc: {clamp_ratio_robust:.1%})", width=150)
 
     st.markdown("---")
     st.markdown("#### Accuracy Breakdown at This Noise Level")
@@ -257,8 +295,10 @@ with tab2:
     results = data["results"]
 
     def interp(model_name: str, s: float) -> float:
+        # Robust boundary treatment for floating point sigma
+        s_clipped = np.clip(s, levels.min(), levels.max())
         vals = np.array(results[model_name], dtype=float)
-        return float(np.interp(s, levels, vals))
+        return float(np.interp(s_clipped, levels, vals))
 
     acc_svd = interp("SVD", sigma)
     acc_cnn = interp("CNN", sigma)
@@ -334,7 +374,8 @@ with tab3:
     )
     fig_svd.update_layout(
         margin=dict(l=20, r=20, b=40, t=20),
-        showlegend=False,
+        showlegend=True,
+        legend=dict(orientation="h", yanchor="top", y=-0.15, xanchor="center", x=0.5),
         xaxis_title="Component 1",
         yaxis_title="Component 2"
     )
@@ -353,7 +394,8 @@ with tab3:
     )
     fig_umap.update_layout(
         margin=dict(l=20, r=20, b=40, t=20),
-        showlegend=False,
+        showlegend=True,
+        legend=dict(orientation="h", yanchor="top", y=-0.15, xanchor="center", x=0.5),
         xaxis_title="UMAP 1",
         yaxis_title="UMAP 2"
     )
@@ -421,7 +463,7 @@ with tab4:
     if add_noise:
         img_sample = torch.clamp(img_sample + torch.randn_like(img_sample) * noise_sigma, 0, 1)
 
-    img_svd_sample = get_reconstruction(svd_model, img_sample)
+    img_svd_sample, clamp_ratio_lab = get_reconstruction(svd_model, img_sample)
 
     if noise_mode:
         cnn_input = img_svd_sample.view(1, 1, 28, 28)
@@ -429,7 +471,7 @@ with tab4:
         cnn_input = img_sample.view(1, 1, 28, 28)
 
     with torch.no_grad():
-        probs = torch.softmax(cnn_model(cnn_input), dim=1)
+        probs = torch.softmax(cnn_model(cnn_input) / temp_scaling, dim=1)
         conf, pred = torch.max(probs, 1)
 
     r1, r2, r3 = st.columns(3)
@@ -441,6 +483,7 @@ with tab4:
     with r2:
         st.markdown("**SVD Projection**")
         st.image(img_svd_sample.numpy(), width=120)
+        st.caption(f"Truncation: {clamp_ratio_lab:.1%}")
 
     with r3:
         st.markdown(f"**Prediction: {pred.item()}**")
@@ -454,6 +497,7 @@ with tab4:
 
     else:  # Draw Digit
         st.info("Draw a digit in the box below. SVD and CNN will analyze it in real-time.")
+        st.caption("*Tip: Draw the digit large and centered for best results.*")
 
         col_canvas, col_preview = st.columns([1, 1])
 
@@ -481,7 +525,7 @@ with tab4:
         img_tensor = torch.tensor(img_resized, dtype=torch.float32) / 255.0
 
         img_flat_up = img_tensor.view(1, 784)
-        img_svd_up = get_reconstruction(svd_model, img_flat_up)
+        img_svd_up, clamp_ratio_draw = get_reconstruction(svd_model, img_flat_up)
 
         if noise_mode:
             cnn_input = img_svd_up.view(1, 1, 28, 28)
@@ -489,7 +533,7 @@ with tab4:
             cnn_input = img_tensor.view(1, 1, 28, 28)
 
         with torch.no_grad():
-            probs = torch.softmax(cnn_model(cnn_input), dim=1)
+            probs = torch.softmax(cnn_model(cnn_input) / temp_scaling, dim=1)
             conf, pred = torch.max(probs, 1)
 
         with col_preview:
@@ -503,6 +547,7 @@ with tab4:
             with r1:
                 st.markdown("**SVD View**")
                 st.image(img_svd_up.numpy(), width=100)
+                st.caption(f"Truncation: {clamp_ratio_draw:.1%}")
             with r2:
                 if noise_mode:
                     st.caption("Using SVD Denoised Input")
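The recurring change across this file is dividing the logits by a user-chosen temperature before the softmax. A minimal, standalone sketch of why that calibration trick works (plain NumPy rather than the app's PyTorch/Streamlit stack; the function name is illustrative, not from the repo):

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Temperature-scaled softmax: higher T flattens the distribution,
    lower T sharpens it, and the argmax never changes."""
    z = np.asarray(logits, dtype=float) / T
    z -= z.max()  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = [4.0, 1.0, 0.5]
p_sharp = softmax_with_temperature(logits, T=0.5)   # near one-hot
p_base = softmax_with_temperature(logits, T=1.0)
p_smooth = softmax_with_temperature(logits, T=5.0)  # closer to uniform

# Top-class confidence shrinks monotonically as T grows,
# while the prediction itself is unchanged.
assert p_sharp[0] > p_base[0] > p_smooth[0]
assert p_sharp.argmax() == p_base.argmax() == p_smooth.argmax() == 0
```

This is why the slider's help text talks about "smoother transitions": scaling logits by 1/T only rescales confidence, so the interpolation curve's snap softens without moving the decision boundary.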
docs/REPORT.md CHANGED
@@ -1,125 +1,129 @@
- # Mechanistic Analysis of Linear vs. Non-linear Manifolds on MNIST: A Validation of Inductive Biases
 
- While linear dimensionality reduction (SVD) is a standard low-pass filter, its failure modes in classification are often described as "accuracy drops" without mechanistic explanation. This report provides a concrete, visual, and quantitative analysis of how linear subspaces consistently force a "3" to collapse into an "8". By mapping the exact decision boundaries where global variance models fail and non-linear topological models (CNNs) succeed, we empirically validate the inherent trade-offs of linear denoising in classification tasks.
 
 ---
 
- ## I. Executive Summary
 
- - **The Variance Trap**: SVD-like methods behave as global low-pass filters that prioritize high-energy shared structures over local discriminative cues. In digits like '3', the critical topological gap has low pixel variance and is thus suppressed as "noise."
- - **Quantitative Manifold Collapse**: We demonstrate that linear projection **more than doubles the classification error rate (+130% relative increase)** on ambiguous pairs, providing "iron-clad" proof that SVD intrinsically destroys discriminative manifold information.
- - **Operational Boundary**: A Hybrid SVD→CNN pipeline provides significant robustness gains in high-noise environments ($\sigma=0.7$) but fails on texture-rich data (Fashion-MNIST), defining its scope as a specialized shape-driven filter.
 
 ---
 
- ## II. The Phenomenon: Linear Subspace Failure on Digit Manifolds
 
- Linear dimensionality reduction (SVD) on full MNIST yields an accuracy of 88.12%. However, this aggregate metric masks a systematic failure mode: digits with high pixel-wise overlap (3/8, 5/3, 4/9) exhibit catastrophic confusion.
 
- ### Spectral Diagnosis: Variance vs. Discrimination
- Applying SVD to the isolated '3 vs 8' subset reveals that the first 10 principal components (eigen-digits) capture 49.2% of the total variance.
 
 <p align="center">
- <img src="research_results/fig_02_svd_confusion.png" alt="Fig 1: Confusion matrix for Global SVD" width="350" />
- <img src="research_results/fig_03_eigen_digits.png" alt="Fig 3: Eigen-digits" width="350" />
 <br>
- <em>Figure 1 & 2: SVD confusion matrix and the resulting eigen-digits. The principal components emphasize the shared circular energy common to both digits (the "8-like" outline), while attenuating the discriminative gap of the '3' as low-variance residual.</em>
 </p>
 
- This identifies the **Variance-Discrimination Paradox**: SVD optimizes for reconstruction (global energy) rather than separation (local topology). Since the "8-like" silhouette contains more pixel variance than the tiny gap in a '3', the linear model "hallucinates" a closed loop to minimize reconstruction error.
-
 ---
 
- ## III. Mechanistic Proof: Global Variance vs. Local Topology
 
- To confirm that the failure is intrinsic to the linear projection method, we contrast SVD with a small CNN and a non-linear manifold mapping (UMAP).
 
- ### 1. Dynamic Snap vs. The Variance Trap
- We observed class responses while smoothly interpolating a '3' into an '8'.
 
 <p align="center">
- <img src="research_results/fig_05_interpolation.png" alt="Fig 5: Decision Boundary Interpolation" width="700" />
 <br>
- <em>Figure 3: CNN class probability (light blue) vs SVD reconstruction error (deep blue). The CNN's sharp "snap" indicates a learned topological boundary, while the SVD's U-shaped error dip at the midpoint proves it treats blurred, overlapping superpositions as higher-fidelity matches than the original digits.</em>
 </p>
 
- ### 2. Static Attention and Hallucination
- Grad-CAM heatmaps confirm that the CNN focuses exclusively on the **topological gap**. In contrast, SVD reconstruction forcibly closes this gap to satisfy global energy constraints, effectively "reconstructing" a phantom 8.
 
 <p align="center">
- <img src="research_results/fig_06_explainability.png" alt="Fig 6: Grad-CAM vs SVD Inductive Bias" width="700" />
 <br>
- <em>Figure 4: Grad-CAM attention (center) vs. SVD reconstruction (right). CNN attention on the gap confirms it classifies by shape discontinuity; SVD's hallucination proves it classifies by global pixel coincidence.</em>
 </p>
 
- ### 3. Quantifying Manifold Collapse
- We use the internal accuracy of a $k$-Nearest Neighbors ($k$-NN, $k=5$) classifier as a strict benchmark for local neighborhood integrity.
-
- * **Raw 784D Pixel Space: 98.7% Accuracy (1.3% Error Rate).**
- * **SVD 10D Subspace: 97.0% Accuracy (3.0% Error Rate).**
-
- The error rate **more than doubles (+130% relative increase)** after linear projection. This provides iron-clad proof that SVD structurally forces distinct local neighborhoods of 3s and 8s to overlap, destroying information that is physically present in the pixels.
 
 <p align="center">
- <img src="research_results/fig_08_manifold_collapse.png" alt="Figure 5: Manifold Comparison" width="600"/>
 <br>
- <em>Figure 5: Side-by-side manifold contrast. SVD (left) collapses boundaries to maximize variance, while UMAP (right) preserves the manifold separation required for high-accuracy discrimination.</em>
 </p>
 
 ---
 
- ## IV. Operational Boundaries: High-Noise Defense and Texture Limits
 
- While SVD fails as a standalone discriminator, its low-pass filtering property provides a powerful inductive bias in high-noise environments.
 
- ### 1. SVD as a Data-Adapted Denoising Filter
- In environments with heavy Gaussian noise ($\sigma=0.7$), a standalone CNN collapses to 30.4% accuracy. However, a **Hybrid SVD→CNN** pipeline maintains **65.3% accuracy**, outperforming both the pure CNN and naive Gaussian blurring.
 
 <p align="center">
- <img src="research_results/fig_robustness_mnist.png" alt="Fig 10: Hybrid Robustness" width="500" />
 <br>
- <em>Figure 6: Robustness gain on MNIST. SVD reconstruction acts as a data-adapted filter, discarding destructive high-frequency noise before CNN feature extraction.</em>
 </p>
 
- ### 2. The Texture Breakdown (Fashion-MNIST)
- The defense fails on **Fashion-MNIST**, where accuracy collapses from 91% (CNN) to 67% (Hybrid). Unlike digits, fashion items (Shirts vs Pullovers) rely on high-frequency textures (buttons, collars). SVD's global silhouette objective suppresses these critical textures, defining the **physical limit** of linear denoising.
 
 <p align="center">
- <img src="research_results/fig_robustness_fashion.png" alt="Fig 11: Fashion-MNIST Robustness" width="500" />
 <br>
- <em>Figure 7: Texture collapse on Fashion-MNIST. Unlike digits, fashion items rely on high-frequency details that SVD's low-pass bias catastrophically suppresses, leading to poor robustness compared to direct CNN classification.</em>
 </p>
 
 ---
 
 ## Conclusion
- This investigation proves that SVD's reliance on global variance conflates discriminative local features with noise. While this serves as an effective low-pass filter for shape-dominated data (MNIST) under extreme corruption, it systematically degrades the manifold boundaries required for precise non-linear classification.
 
- ---
 
- ## Appendix: Implementation & Technical Details
 
- ### A. Training & Reproducibility
- - **Trainer**: Adam optimizer ($lr=10^{-3}$, batch size 64) with stratified 80/20 train/val splits.
- - **Denoising**: All SVD components use explicit mean-centering and a rank of $k=20$.
- - **Early Stopping**: Early stopping with a patience of 3 monitored on validation accuracy prevented overfitting, typically converging in 5-8 epochs.
 
 <p align="center">
- <img src="research_results/fig_14_learning_curves.png" alt="Figure A1: Learning Curves" width="450" />
- <br>
- <em>Figure A1: Standardized learning curves showing convergence and early-stopping preservation.</em>
 </p>
 
- ### B. Per-Class Metrics & SVD Failure Clusters
- SVD failure is concentrated in specific clusters with high energy overlap:
- - **Digit 5 (81.3% F1)**: Systematic confusion with digit 3 due to similar stroke energy.
- - **Digit 9 (83.6% F1)**: Confusion with 4 due to loop vs. open-top similarity.
-
 <p align="center">
- <img src="research_results/fig_19_per_class_metrics_comparison.png" alt="Figure A2: F1-Score Comparison" width="800" />
- <br>
- <em>Figure A2: Side-by-side per-class comparison highlighting SVD's failure regions.</em>
 </p>
-
- ### C. Grad-CAM Implementation Note
- To ensure accurate saliency mapping, we utilized `register_full_backward_hook` to capture complete gradient tensors from the intermediate convolutional layers, avoiding the gradient-drop issues found in legacy PyTorch hooks.
1
+ # Why Does SVD Turn a "3" into an "8"? A Mechanistic Comparison of Linear vs. Non-linear Manifolds
2
 
3
+ Why do linear models fail at tasks that seem trivial to a human or a simple neural network? While it is a known property that linear dimensionality reduction (SVD) acts as a low-pass filter, this report provides a **concrete, visual, and mechanistic explanation** of how this manifests in classificationβ€”specifically, why linear subspaces force a "3" to collapse into an "8".
4
+
5
+ By mapping the decision boundaries where linear global variance models fail and non-linear topological models (CNNs) succeed, we validate the **inherent trade-offs** of linear denoising in high-stakes domains like medical imaging or satellite dataβ€”where a linear filter might suppress critical diagnostic features to minimize noise variance.
6
+
7
+ ---
8
+
9
+ ## TL;DR: The 15-Second Summary
10
+
11
+ - **The Problem (The Variance Trap):** SVD prioritizes global pixel energy. In a '3', the tiny gap that distinguishes it from an '8' has very low variance, so SVD deletes it as "noise."
12
+ - **The Mechanism:** Linear models see **Global Energy** (the overall silhouette), while CNNs see **Local Topology** (the gap). SVD literally "welds" the ends of a '3' together to minimize reconstruction error.
13
+ - **The Solution & Boundary: We built a Hybrid SVD→CNN pipeline.** While SVD fails as a standalone classifier, it works as a powerful **low-pass filter** and defensive shield against high noise ($\sigma=0.7$), provided the data isn't too texture-rich (like Fashion-MNIST).
14
 
15
  ---
16
 
17
+ ## The Investigative Approach
18
 
19
+ ```text
20
+ Diagnosis Mechanism Solution & Boundary
21
+ ───────────────────── ───────────────────── ─────────────────────
22
+ SVD fails on 3 vs 8 → Why? Grad-CAM + UMAP → Hybrid SVD→CNN pipeline
23
+ (The Variance Trap) (Global vs. Local) + Texture stress test
24
+ ```
25
 
26
  ---
27
 
28
+ ## 1. Diagnosis: The "3 vs 8" Failure Mode
29
 
30
+ Aggregate accuracy metrics often hide the real story. While SVD achieves 88.1% accuracy on MNIST, it systematically fails on digits with high pixel overlap.
31
 
32
+ ### The Variance Trap
33
+ Linear dimensionality reduction (SVD) treats classification like a reconstruction problem. It looks for the directions of maximum variance (total pixel brightness). In the cluster of 3s and 8s, the shared "8-like" outline contains the most energy. The small gap that makes a '3' unique is mathematically ignored.
34
 
35
  <p align="center">
36
+ <img src="research_results/fig_01_svd_confusion.png" alt="SVD Confusion Matrix" width="350" />
37
+ <img src="research_results/fig_02_eigen_digits.png" alt="SVD Eigen-digits" width="350" />
38
  <br>
39
+ <em><strong>Figure 1 & 2:</strong> SVD concentrates its errors on ambiguous pairs. The first few "eigen-digits" capture the shared circular structure, smoothing over the critical discriminative gap.</em>
40
  </p>
41
 
 
 
42
  ---
43
 
44
+ ## 2. Mechanism: Global Energy vs. Local Topology
45
 
46
+ To understand *why* this happens, we compared how a CNN (non-linear) and SVD (linear) "see" the same image.
47
 
48
+ ### Linear Hallucination
49
+ When we interpolate a '3' into an '8', the CNN shows a sharp "snap" in confidenceβ€”it recognizes a topological boundary. In contrast, SVD's reconstruction error peaks at the midpoint because it's trying to fit a "linear bridge" between two distinct manifolds.
50
 
51
  <p align="center">
52
+ <img src="research_results/fig_03_interpolation.png" alt="Decision Boundary Interpolation" width="700" />
53
  <br>
54
+ <em><strong>Figure 3:</strong> CNN probability vs. Manifold Distance. The CNN's sharp boundary persists, while SVD creates "ghost" images at the midpoint that don't belong to either digit.</em>
55
  </p>
56
 
57
+ ### Local Topology: The Gap is the Signal
58
+ Grad-CAM heatmaps confirm that a CNN focuses exclusively on the **topological gap**. SVD, however, reconstructs a phantom loop. The linear model is "blind" to the gap because it optimizes for global pixel coincidence rather than shape continuity.
59
 
60
  <p align="center">
61
+ <img src="research_results/fig_04_explainability.png" alt="Grad-CAM vs SVD" width="700" />
62
  <br>
63
+ <em><strong>Figure 4:</strong> CNN attention (center) vs. SVD reconstruction (right). CNNs classify by shape discontinuity; SVD "hallucinates" a loop to satisfy energy constraints.</em>
64
  </p>
65
 
66
+ ### Quantifying Manifold Collapse
67
+ Using $k$-NN as a benchmark, we measured the damage:
68
+ - **Raw Pixel Space:** 98.7% Accuracy
69
+ - **SVD Subspace:** 97.0% Accuracy
70
+ This **130% relative increase in error** proves that SVD physically crushes the separation between 3s and 8s.
 
 
71
 
72
  <p align="center">
73
+ <img src="research_results/fig_05_manifold_collapse.png" alt="Manifold Comparison" width="600"/>
74
  <br>
75
+ <em><strong>Figure 5:</strong> SVD (left) collapses boundaries to maximize variance, whereas UMAP (right) preserves the topological separation required for high accuracy.</em>
76
  </p>
77
 
78
  ---
79
 
80
+ ## 3. Solution: Success on Low-Rank Manifolds (MNIST)
81
 
82
+ If SVD is so bad at discriminating, why use it? Because its "Variance Trap" is a perfect **Low-Pass Filter**.
83
 
84
+ In high-noise environments ($\sigma=0.7$), a raw CNN's accuracy drops to **30.1%**. A **Hybrid SVD→CNN** pipeline, however, maintains **65.5%** accuracy. By discarding low-variance dimensions, SVD acts as a "data-adapted shield," stripping away random Gaussian noise before it reaches the classifier.
85
+
86
+ However, this shield has a **blind spot**: if the noise is carefully aligned with the data's principal components (**Fig 7**), SVD preserves the noise rather than filtering it, making the model even more vulnerable than a raw CNN.
87
 
88
  <p align="center">
89
+ <img src="research_results/fig_06_robustness_mnist_gaussian.png" alt="Hybrid Robustness" width="450" />
90
+ <img src="research_results/fig_07_robustness_mnist_svd_aligned.png" alt="Subspace Risk" width="450" />
91
  <br>
92
+ <em><strong>Figure 6 & 7:</strong> SVD dominates under random noise (left) but becomes a liability if the noise is "aligned" with the data subspace (right), proving the defense is narrow-band.</em>
93
  </p>
94
 
95
+ ---
96
+
97
+ ## 4. Boundary: Failure on Texture-Rich Manifolds (Fashion-MNIST)
98
+
99
+ The SVD defense only works when the objects are simple silhouettes. On **Fashion-MNIST**, the strategy collapses.
100
+
101
+ Items like Shirts and Pullovers aren't distinguished by global outlines, but by **high-frequency textures** (buttons, zippers, collar stitching). SVD's low-pass bias treats these textures as noise and deletes them. Accuracy drops from 91% (CNN) to 67% (Hybrid), defining the physical limit of linear denoising.
102
 
103
  <p align="center">
104
+ <img src="research_results/fig_08_robustness_fashion.png" alt="Fashion texture collapse" width="500" />
105
  <br>
106
+ <em><strong>Figure 8:</strong> On texture-rich data, SVD's "low-pass filter" becomes a "detail-destroyer," failing to preserve the features needed for non-linear models to succeed.</em>
107
  </p>
108
 
109
  ---
110
 
111
  ## Conclusion
 
112
 
113
+ This study demonstrates that SVD's fundamental bias toward **Global Variance** makes it a poor standalone classifier but a specialized defensive tool. It excels at denoising simple manifolds but fails catastrophically when locally discriminative details (like the gap in a '3' or a button on a shirt) are suppressed in favor of global energy.
114
 
115
+ ---
116
 
117
+ ## Appendix: Technical Details
 
 
 
118
 
119
+ ### A. Learning Curves
120
+ Convergence was typically reached within 5-8 epochs using the Adam optimizer.
121
  <p align="center">
122
+ <img src="research_results/fig_09_learning_curves.png" alt="Learning Curves" width="450" />
 
 
123
  </p>
124
 
125
+ ### B. Per-Class F1 Comparisons
126
+ SVD failures are consistently clustered in "Ambiguity Zones" (3 vs 8, 5 vs 3, 4 vs 9), where pixel-wise overlap is highest.
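The "Ambiguity Zone" clustering can be located programmatically from per-class F1 scores (a sketch with simulated predictions; the 40% confusion rate is an illustrative assumption, not a measured value):

```python
import numpy as np
from sklearn.metrics import f1_score

rng = np.random.default_rng(42)
y_true = rng.integers(0, 10, size=2000)

# Simulate an SVD-like classifier that confuses 3<->8 and 4<->9.
y_pred = y_true.copy()
for a, b in [(3, 8), (4, 9)]:
    flip_a = (y_true == a) & (rng.random(len(y_true)) < 0.4)
    y_pred[flip_a] = b
    flip_b = (y_true == b) & (rng.random(len(y_true)) < 0.4)
    y_pred[flip_b] = a

per_class_f1 = f1_score(y_true, y_pred, average=None)
worst = np.argsort(per_class_f1)[:4]  # the four weakest classes
print(sorted(worst.tolist()))
```

The four lowest-F1 classes fall exactly on the confused pairs, which is how Figure 10's clustering was read off the real predictions.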
 
 
 
127
  <p align="center">
128
+ <img src="research_results/fig_10_per_class_metrics_comparison.png" alt="F1 Comparison" width="800" />
 
 
129
  </p>
 
 
 
docs/research_results/fig_01_spectrum.png CHANGED

Git LFS Details

  • SHA256: 20a9e5a6efd12cdf2dd316142ad11f2e29fdcf2d9d874753399f790edd277ebd
  • Pointer size: 131 Bytes
  • Size of remote file: 201 kB

Git LFS Details

  • SHA256: ae232b53b26d5529b70a7645124e1187bcf0bf3e4268518bca3ae31d87e7ebce
  • Pointer size: 131 Bytes
  • Size of remote file: 201 kB
docs/research_results/{fig_02_svd_confusion.png β†’ fig_01_svd_confusion.png} RENAMED
File without changes
docs/research_results/{fig_03_eigen_digits.png β†’ fig_02_eigen_digits.png} RENAMED
File without changes
docs/research_results/{fig_05_interpolation.png β†’ fig_03_interpolation.png} RENAMED
File without changes
docs/research_results/fig_04_cnn_confusion.png CHANGED

Git LFS Details

  • SHA256: b8ba8c837a449145503cc7e69534c6ee0606d6e1416c9200714619ba481819cd
  • Pointer size: 131 Bytes
  • Size of remote file: 242 kB

Git LFS Details

  • SHA256: 0939e42b5a67c027fd044e6bedfe1f8c70abd823b6c3a1e8ec5b51423ecd0ef6
  • Pointer size: 131 Bytes
  • Size of remote file: 241 kB
docs/research_results/{fig_06_explainability.png β†’ fig_04_explainability.png} RENAMED
File without changes
docs/research_results/{fig_08_manifold_collapse.png β†’ fig_05_manifold_collapse.png} RENAMED
File without changes
docs/research_results/{fig_14_learning_curves.png β†’ fig_06_robustness_mnist_gaussian.png} RENAMED
File without changes
docs/research_results/fig_07_robustness_mnist_svd_aligned.png ADDED

Git LFS Details

  • SHA256: 130b63b6cc155fd68c10b10e24b3eea29ab9b99ed28c9a88ba7a6bbdd994d155
  • Pointer size: 131 Bytes
  • Size of remote file: 195 kB
docs/research_results/fig_08_robustness_fashion.png ADDED

Git LFS Details

  • SHA256: 4e25ad7173308f2a6540e3708cd45b408a6ff04dd651afd91bf9852ba9e8dfa3
  • Pointer size: 131 Bytes
  • Size of remote file: 237 kB
docs/research_results/fig_09_learning_curves.png ADDED

Git LFS Details

  • SHA256: f78aebb3653e617493de7b8b1d3b0d7ef31584686ccd808fbffe3100800afd1c
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
docs/research_results/fig_10_per_class_metrics_comparison.png ADDED

Git LFS Details

  • SHA256: e967876dc8c9da34ee9ffd993bb965d20a1fafb3e1eaf4474889cb8e856da632
  • Pointer size: 130 Bytes
  • Size of remote file: 87.4 kB
docs/research_results/fig_19_per_class_metrics_comparison.png DELETED

Git LFS Details

  • SHA256: 70069fa6358733e6f4fa7a2e772e19e87e3c2e8deab7c2b62b11c11915debdd7
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
docs/research_results/fig_robustness_fashion.png DELETED

Git LFS Details

  • SHA256: 4bde830bd6a1e1acf9c29d106422882298a47e4ab9b38f5077be6a6c4aa55f3f
  • Pointer size: 131 Bytes
  • Size of remote file: 231 kB
docs/research_results/fig_robustness_mnist.png DELETED

Git LFS Details

  • SHA256: f31ea6cb5a254f637446b2350a3cd07ba510dba6a7ed6a5d92cb8eac8d001be6
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
experiments/01_phenomenon_diagnosis.py CHANGED
@@ -43,7 +43,7 @@ def run_svd_analysis(X_train, y_train, X_test, y_test):
43
  # 3. Visualization: Confusion Matrix & Eigen-digits
44
  viz.plot_confusion_matrix(
45
  y_test, y_pred, list(range(10)),
46
- 'fig_02_svd_confusion.png',
47
  f'SVD Confusion Matrix (Acc={acc:.2f})',
48
  viz.COLOR_SVD
49
  )
@@ -52,7 +52,7 @@ def run_svd_analysis(X_train, y_train, X_test, y_test):
52
  viz.plot_multi_image_grid(
53
  [c.reshape(28, 28) for c in svd_20.components_[:10]],
54
  component_titles, 2, 5,
55
- 'fig_03_eigen_digits.png',
56
  'Global SVD Eigen-digits'
57
  )
58
 
 
43
  # 3. Visualization: Confusion Matrix & Eigen-digits
44
  viz.plot_confusion_matrix(
45
  y_test, y_pred, list(range(10)),
46
+ 'fig_01_svd_confusion.png',
47
  f'SVD Confusion Matrix (Acc={acc:.2f})',
48
  viz.COLOR_SVD
49
  )
 
52
  viz.plot_multi_image_grid(
53
  [c.reshape(28, 28) for c in svd_20.components_[:10]],
54
  component_titles, 2, 5,
55
+ 'fig_02_eigen_digits.png',
56
  'Global SVD Eigen-digits'
57
  )
58
 
experiments/{02_mechanistic_analysis.py β†’ 02_mechanistic_proof.py} RENAMED
@@ -8,7 +8,7 @@ import torch
8
  import torch.nn as nn
9
  import numpy as np
10
  from sklearn.decomposition import TruncatedSVD
11
- from sklearn.neighbors import KNeighborsClassifier
12
  from sklearn.metrics import accuracy_score
13
 
14
  from src import config, utils, viz, exp_utils
@@ -31,30 +31,36 @@ def run_interpolation_analysis(device):
31
  img_3, img_8 = X_test[idx_3], X_test[idx_8]
32
 
33
  alphas = np.linspace(0, 1, 11)
34
- probs_8, rec_errors = [], []
 
 
 
 
 
35
 
36
  for alpha in alphas:
37
  img_interp = (1 - alpha) * img_3 + alpha * img_8
38
- # CNN Probability of class 1 (Digit 8)
39
- with torch.no_grad():
40
- logits = cnn(img_interp.unsqueeze(0).to(device))
41
- # Note: We use index 8 from full model or index 1 if it was binary
42
- # Here we assume full model but we load 3v8 subset.
43
- # If model is 10-class, we need to pick actual digit indices.
44
- # Let's check model output size.
45
- out_dim = logits.shape[1]
46
- if out_dim == 10:
47
- p = torch.softmax(logits, dim=1)[0, 8].item()
48
- else:
49
- p = torch.softmax(logits, dim=1)[0, 1].item()
50
- probs_8.append(p)
51
 
52
  # SVD Reconstruction Error
53
  flat = img_interp.view(1, -1).numpy()
54
  rec = svd.inverse_transform(svd.transform(flat - mean)) + mean
55
  rec_errors.append(np.linalg.norm(flat - rec))
56
 
57
- viz.plot_interpolation_dynamics(alphas, probs_8, rec_errors, 'fig_05_interpolation.png')
 
 
 
 
58
 
59
  def run_quantifying_manifold_collapse():
60
  print("\n--- Running Experiment 7: Quantifying Manifold Collapse ---")
@@ -83,7 +89,7 @@ def run_quantifying_manifold_collapse():
83
  import umap
84
  reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
85
  X_umap = reducer.fit_transform(X_test_np)
86
- viz.plot_manifold_comparison(X_test_svd, X_umap, y_test_np, acc_svd, acc_raw, 'fig_08_manifold_collapse.png')
87
  except Exception as e:
88
  print(f"Warning: Manifold visualization failed: {e}")
89
 
 
8
  import torch.nn as nn
9
  import numpy as np
10
  from sklearn.decomposition import TruncatedSVD
11
+ from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
12
  from sklearn.metrics import accuracy_score
13
 
14
  from src import config, utils, viz, exp_utils
 
31
  img_3, img_8 = X_test[idx_3], X_test[idx_8]
32
 
33
  alphas = np.linspace(0, 1, 11)
34
+ probs_dict = {f'T={t:.1f}': [] for t in [1.0, 2.0, 5.0]}
35
+ rec_errors, manifold_dists = [], []
36
+
37
+ # Fit Nearest Neighbors on Training set to measure distance to manifold
38
+ X_train, y_train = utils.load_data_split(dataset_name="mnist", train=True, digits=[3, 8], flatten=True)
39
+ nn_manifold = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(X_train.numpy())
40
 
41
  for alpha in alphas:
42
  img_interp = (1 - alpha) * img_3 + alpha * img_8
43
+ for t_str, p_list in probs_dict.items():
44
+ temp = float(t_str.split('=')[1])
45
+ with torch.no_grad():
46
+ logits = cnn(img_interp.unsqueeze(0).to(device))
47
+ out_dim = logits.shape[1]
48
+ if out_dim == 10:
49
+ p = torch.softmax(logits / temp, dim=1)[0, 8].item()
50
+ else:
51
+ p = torch.softmax(logits / temp, dim=1)[0, 1].item()
52
+ p_list.append(p)
 
 
 
53
 
54
  # SVD Reconstruction Error
55
  flat = img_interp.view(1, -1).numpy()
56
  rec = svd.inverse_transform(svd.transform(flat - mean)) + mean
57
  rec_errors.append(np.linalg.norm(flat - rec))
58
 
59
+ # Distance to Real Manifold (784D)
60
+ dist, _ = nn_manifold.kneighbors(flat)
61
+ manifold_dists.append(dist[0][0])
62
+
63
+ viz.plot_interpolation_dynamics(alphas, probs_dict, rec_errors, 'fig_03_interpolation.png', manifold_distances=manifold_dists)
64
 
65
  def run_quantifying_manifold_collapse():
66
  print("\n--- Running Experiment 7: Quantifying Manifold Collapse ---")
 
89
  import umap
90
  reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
91
  X_umap = reducer.fit_transform(X_test_np)
92
+ viz.plot_manifold_comparison(X_test_svd, X_umap, y_test_np, acc_svd, acc_raw, 'fig_05_manifold_collapse.png')
93
  except Exception as e:
94
  print(f"Warning: Manifold visualization failed: {e}")
95
 
experiments/{run_robustness_test.py β†’ 03_operational_boundaries.py} RENAMED
@@ -15,6 +15,7 @@ def run_experiment(args):
15
  print(f"\n--- Running Robustness Test: {args.dataset.upper()} ---")
16
 
17
  # 1. Load Data and Models
 
18
  X_test, y_test = utils.load_data_split(dataset_name=args.dataset, train=False)
19
  _, cnn = utils.load_models(dataset_name=args.dataset)
20
 
@@ -22,9 +23,9 @@ def run_experiment(args):
22
  return
23
 
24
  # 2. Fit SVD Baseline and Build Hybrid Model
25
- print("Fitting SVD Baseline...")
26
- X_test_flat = X_test.view(X_test.size(0), -1).numpy()
27
- svd_pipe = exp_utils.fit_svd_baseline(X_test_flat, y_test.numpy(), n_components=20)
28
 
29
  svd = svd_pipe.named_steps['svd']
30
  scaler = svd_pipe.named_steps['scaler']
@@ -37,8 +38,14 @@ def run_experiment(args):
37
  results = {'CNN': [], 'SVD': [], 'Hybrid': []}
38
 
39
  # 4. Evaluation Loop
 
 
 
40
  for sigma in sigmas:
41
- X_noisy = exp_utils.add_gaussian_noise(X_test, sigma)
 
 
 
42
 
43
  results['CNN'].append(exp_utils.evaluate_classifier(cnn, X_noisy, y_test, device))
44
  results['SVD'].append(exp_utils.evaluate_classifier(svd_pipe, X_noisy, y_test, is_pytorch=False))
@@ -47,17 +54,26 @@ def run_experiment(args):
47
  print(f"Οƒ={sigma:.1f} | CNN: {results['CNN'][-1]:.4f} | SVD: {results['SVD'][-1]:.4f} | Hybrid: {results['Hybrid'][-1]:.4f}")
48
 
49
  # 5. Visualization
 
 
 
 
 
 
 
 
 
 
50
  viz.plot_robustness_curves(
51
- x_values=sigmas,
52
- results_dict=results,
53
- x_label='Gaussian Noise Level (Οƒ)',
54
- title=f'Robustness Analysis: {args.dataset.upper()}',
55
- filename=f'fig_robustness_{args.dataset}.png'
56
  )
57
 
58
  def main():
59
  parser = argparse.ArgumentParser(description="Unified Robustness Evaluation")
60
  parser.add_argument("--dataset", choices=["mnist", "fashion"], default="mnist", help="Dataset to evaluate.")
 
61
  args = parser.parse_args()
62
  run_experiment(args)
63
 
 
15
  print(f"\n--- Running Robustness Test: {args.dataset.upper()} ---")
16
 
17
  # 1. Load Data and Models
18
+ X_train, y_train = utils.load_data_split(dataset_name=args.dataset, train=True)
19
  X_test, y_test = utils.load_data_split(dataset_name=args.dataset, train=False)
20
  _, cnn = utils.load_models(dataset_name=args.dataset)
21
 
 
23
  return
24
 
25
  # 2. Fit SVD Baseline and Build Hybrid Model
26
+ print("Fitting SVD Baseline on Training Data...")
27
+ X_train_flat = X_train.view(X_train.size(0), -1).numpy()
28
+ svd_pipe = exp_utils.fit_svd_baseline(X_train_flat, y_train.numpy(), n_components=20)
29
 
30
  svd = svd_pipe.named_steps['svd']
31
  scaler = svd_pipe.named_steps['scaler']
 
38
  results = {'CNN': [], 'SVD': [], 'Hybrid': []}
39
 
40
  # 4. Evaluation Loop
41
+ noise_label = 'Gaussian' if args.noise_type == 'gaussian' else 'SVD-Aligned'
42
+ print(f"Noise Type: {noise_label}")
43
+
44
  for sigma in sigmas:
45
+ if args.noise_type == "svd_aligned":
46
+ X_noisy = exp_utils.add_svd_aligned_noise(X_test, sigma, svd.components_)
47
+ else:
48
+ X_noisy = exp_utils.add_gaussian_noise(X_test, sigma)
49
 
50
  results['CNN'].append(exp_utils.evaluate_classifier(cnn, X_noisy, y_test, device))
51
  results['SVD'].append(exp_utils.evaluate_classifier(svd_pipe, X_noisy, y_test, is_pytorch=False))
 
54
  print(f"Οƒ={sigma:.1f} | CNN: {results['CNN'][-1]:.4f} | SVD: {results['SVD'][-1]:.4f} | Hybrid: {results['Hybrid'][-1]:.4f}")
55
 
56
  # 5. Visualization
57
+ # Map to new sequential names
58
+ if args.dataset == "mnist" and args.noise_type == "gaussian":
59
+ filename = "fig_06_robustness_mnist_gaussian.png"
60
+ elif args.dataset == "mnist" and args.noise_type == "svd_aligned":
61
+ filename = "fig_07_robustness_mnist_svd_aligned.png"
62
+ elif args.dataset == "fashion":
63
+ filename = "fig_08_robustness_fashion.png"
64
+ else:
65
+ filename = f'fig_robustness_{args.dataset}_{args.noise_type}.png'
66
+
67
  viz.plot_robustness_curves(
68
+ sigmas, results, f'{noise_label} Noise Level (Οƒ)',
69
+ f'Robustness Analysis ({noise_label}): {args.dataset.upper()}',
70
+ filename
 
 
71
  )
72
 
73
  def main():
74
  parser = argparse.ArgumentParser(description="Unified Robustness Evaluation")
75
  parser.add_argument("--dataset", choices=["mnist", "fashion"], default="mnist", help="Dataset to evaluate.")
76
+ parser.add_argument("--noise_type", choices=["gaussian", "svd_aligned"], default="gaussian", help="Type of noise to apply.")
77
  args = parser.parse_args()
78
  run_experiment(args)
79
 
experiments/{appendix_learning_curves.py β†’ 04_appendix_learning_curves.py} RENAMED
@@ -9,7 +9,7 @@ from src import config, viz
9
 
10
  def main():
11
  experiments = [
12
- ('cnn_10class_history.pkl', 'MNIST 10-class CNN Training', 'fig_14_learning_curves.png'),
13
  ('cnn_fashion_history.pkl', 'Fashion-MNIST CNN Training', 'fig_15_learning_curves_fashion.png')
14
  ]
15
 
 
9
 
10
  def main():
11
  experiments = [
12
+ ('cnn_10class_history.pkl', 'MNIST 10-class CNN Training', 'fig_09_learning_curves.png'),
13
  ('cnn_fashion_history.pkl', 'Fashion-MNIST CNN Training', 'fig_15_learning_curves_fashion.png')
14
  ]
15
 
experiments/{appendix_per_class_metrics.py β†’ 05_appendix_per_class_metrics.py} RENAMED
@@ -30,7 +30,10 @@ def main():
30
  y_preds_dict['CNN'] = cnn(X_test.to(device)).argmax(dim=1).cpu().numpy()
31
 
32
  # SVD+LR Predictions
33
- y_preds_dict['SVD+LR'] = svd_pipe.predict(X_test_flat)
 
 
 
34
 
35
  # 2. Print Metrics Report
36
  from sklearn.metrics import recall_score, precision_score, f1_score
@@ -45,7 +48,7 @@ def main():
45
  viz.plot_per_class_comparison(
46
  y_test_np,
47
  y_preds_dict,
48
- 'fig_19_per_class_metrics_comparison.png'
49
  )
50
  print("Appendix B Completed.")
51
 
 
30
  y_preds_dict['CNN'] = cnn(X_test.to(device)).argmax(dim=1).cpu().numpy()
31
 
32
  # SVD+LR Predictions
33
+ print("Fitting SVD Baseline (10-class)...")
34
+ X_train_full, y_train_full = utils.load_data_split(dataset_name="mnist", train=True, flatten=True)
35
+ svd_pipe_fitted = exp_utils.fit_svd_baseline(X_train_full.numpy(), y_train_full.numpy(), n_components=20)
36
+ y_preds_dict['SVD+LR'] = svd_pipe_fitted.predict(X_test_flat)
37
 
38
  # 2. Print Metrics Report
39
  from sklearn.metrics import recall_score, precision_score, f1_score
 
48
  viz.plot_per_class_comparison(
49
  y_test_np,
50
  y_preds_dict,
51
+ 'fig_10_per_class_metrics_comparison.png'
52
  )
53
  print("Appendix B Completed.")
54
 
run_all_experiments.sh CHANGED
@@ -4,16 +4,16 @@
4
  set -e
5
 
6
  echo "=== 1. Phenomenon Diagnosis (Global SVD & CNN Baseline) ==="
7
- python experiments/01_phenomenon_diagnosis.py
8
 
9
  echo "=== 2. Mechanistic Analysis (Interpolation, Explainability, etc.) ==="
10
- python experiments/02_mechanistic_analysis.py
11
 
12
  echo "=== 3. Robustness and Boundary Tests (MNIST) ==="
13
- python experiments/run_robustness_test.py --dataset mnist
14
 
15
  echo "=== 4. Robustness and Boundary Tests (Fashion-MNIST) ==="
16
- python experiments/run_robustness_test.py --dataset fashion
17
 
18
  echo "=========================================================="
19
  echo "All experiments completed successfully."
 
4
  set -e
5
 
6
  echo "=== 1. Phenomenon Diagnosis (Global SVD & CNN Baseline) ==="
7
+ python -m experiments.01_phenomenon_diagnosis
8
 
9
  echo "=== 2. Mechanistic Analysis (Interpolation, Explainability, etc.) ==="
10
+ python -m experiments.02_mechanistic_proof
11
 
12
  echo "=== 3. Robustness and Boundary Tests (MNIST) ==="
13
+ python -m experiments.03_operational_boundaries --dataset mnist
14
 
15
  echo "=== 4. Robustness and Boundary Tests (Fashion-MNIST) ==="
16
+ python -m experiments.03_operational_boundaries --dataset fashion
17
 
18
  echo "=========================================================="
19
  echo "All experiments completed successfully."
run_migration.sh ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script performs the renaming of scripts and figures, and updates references in the code and report.
4
+ # Run this from the project root: /Users/ymlin/Downloads/003-Study/137-Projects/01-mnist-linear-vs-nonlinear
5
+
6
+ echo "Starting migration..."
7
+
8
+ # 1. Rename Scripts
9
+ echo "Renaming scripts..."
10
+ mv experiments/01_exp_diagnosis.py experiments/01_phenomenon_diagnosis.py
11
+ mv experiments/02_mechanistic_analysis.py experiments/02_mechanistic_proof.py
12
+ mv experiments/run_robustness_test.py experiments/03_operational_boundaries.py
13
+ mv experiments/appendix_learning_curves.py experiments/04_appendix_learning_curves.py
14
+ mv experiments/appendix_per_class_metrics.py experiments/05_appendix_per_class_metrics.py
15
+
16
+ # 2. Rename Figures
17
+ echo "Renaming figures..."
18
+ cd docs/research_results || exit
19
+ mv fig_02_svd_confusion.png fig_01_svd_confusion.png
20
+ mv fig_03_eigen_digits.png fig_02_eigen_digits.png
21
+ mv fig_05_interpolation.png fig_03_interpolation.png
22
+ mv fig_06_explainability.png fig_04_explainability.png
23
+ mv fig_08_manifold_collapse.png fig_05_manifold_collapse.png
24
+ mv fig_robustness_mnist_gaussian.png fig_06_robustness_mnist_gaussian.png
25
+ mv fig_robustness_mnist_svd_aligned.png fig_07_robustness_mnist_svd_aligned.png
26
+ mv fig_robustness_fashion.png fig_08_robustness_fashion.png
27
+ mv fig_14_learning_curves.png fig_09_learning_curves.png
28
+ mv fig_19_per_class_metrics_comparison.png fig_10_per_class_metrics_comparison.png
29
+ cd ../..
30
+
31
+ # 3. Update Python Scripts (Using sed for macOS)
32
+ echo "Updating Python scripts..."
33
+
34
+ # 01_phenomenon_diagnosis.py
35
+ sed -i '' 's/fig_02_svd_confusion.png/fig_01_svd_confusion.png/g' experiments/01_phenomenon_diagnosis.py
36
+ sed -i '' 's/fig_03_eigen_digits.png/fig_02_eigen_digits.png/g' experiments/01_phenomenon_diagnosis.py
37
+ sed -i '' 's/fig_04_cnn_confusion.png/fig_01b_cnn_confusion.png/g' experiments/01_phenomenon_diagnosis.py
38
+
39
+ # 02_mechanistic_proof.py
40
+ sed -i '' 's/fig_05_interpolation.png/fig_03_interpolation.png/g' experiments/02_mechanistic_proof.py
41
+ sed -i '' 's/fig_06_explainability.png/fig_04_explainability.png/g' experiments/02_mechanistic_proof.py
42
+ sed -i '' 's/fig_08_manifold_collapse.png/fig_05_manifold_collapse.png/g' experiments/02_mechanistic_proof.py
43
+
44
+ # 03_operational_boundaries.py
45
+ sed -i '' 's/fig_robustness_mnist_gaussian.png/fig_06_robustness_mnist_gaussian.png/g' experiments/03_operational_boundaries.py
46
+ sed -i '' 's/fig_robustness_mnist_svd_aligned.png/fig_07_robustness_mnist_svd_aligned.png/g' experiments/03_operational_boundaries.py
47
+ sed -i '' 's/fig_robustness_fashion.png/fig_08_robustness_fashion.png/g' experiments/03_operational_boundaries.py
48
+
49
+ # 04_appendix_learning_curves.py
50
+ sed -i '' 's/fig_14_learning_curves.png/fig_09_learning_curves.png/g' experiments/04_appendix_learning_curves.py
51
+
52
+ # 05_appendix_per_class_metrics.py
53
+ sed -i '' 's/fig_19_per_class_metrics_comparison.png/fig_10_per_class_metrics_comparison.png/g' experiments/05_appendix_per_class_metrics.py
54
+
55
+ # 4. Update Report (Using sed for macOS)
56
+ echo "Updating REPORT.md..."
57
+ sed -i '' 's/fig_02_svd_confusion.png/fig_01_svd_confusion.png/g' docs/REPORT.md
58
+ sed -i '' 's/fig_03_eigen_digits.png/fig_02_eigen_digits.png/g' docs/REPORT.md
59
+ sed -i '' 's/fig_05_interpolation.png/fig_03_interpolation.png/g' docs/REPORT.md
60
+ sed -i '' 's/fig_06_explainability.png/fig_04_explainability.png/g' docs/REPORT.md
61
+ sed -i '' 's/fig_08_manifold_collapse.png/fig_05_manifold_collapse.png/g' docs/REPORT.md
62
+ sed -i '' 's/fig_robustness_mnist_gaussian.png/fig_06_robustness_mnist_gaussian.png/g' docs/REPORT.md
63
+ sed -i '' 's/fig_robustness_mnist_svd_aligned.png/fig_07_robustness_mnist_svd_aligned.png/g' docs/REPORT.md
64
+ sed -i '' 's/fig_robustness_fashion.png/fig_08_robustness_fashion.png/g' docs/REPORT.md
65
+ sed -i '' 's/fig_14_learning_curves.png/fig_09_learning_curves.png/g' docs/REPORT.md
66
+ sed -i '' 's/fig_19_per_class_metrics_comparison.png/fig_10_per_class_metrics_comparison.png/g' docs/REPORT.md
67
+
68
+ echo "Migration completed successfully!"
src/exp_utils.py CHANGED
@@ -10,7 +10,7 @@ from sklearn.preprocessing import StandardScaler
10
  def fit_svd_baseline(X_train, y_train, n_components=20):
11
  """Fits a linear baseline (SVD + Logistic Regression) on the fly."""
12
  pipeline = Pipeline([
13
- ('scaler', StandardScaler()),
14
  ('svd', TruncatedSVD(n_components=n_components, random_state=42)),
15
  ('logistic', LogisticRegression(max_iter=1000))
16
  ])
@@ -30,6 +30,40 @@ def add_gaussian_noise(X, sigma):
30
  noise = np.random.randn(*X.shape) * sigma
31
  return np.clip(X + noise, 0, 1)
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def add_blur(X, kernel_size):
34
  """Unified blur for torch Tensors (4D: B, C, H, W)."""
35
  if kernel_size <= 1:
 
10
  def fit_svd_baseline(X_train, y_train, n_components=20):
11
  """Fits a linear baseline (SVD + Logistic Regression) on the fly."""
12
  pipeline = Pipeline([
13
+ ('scaler', StandardScaler(with_std=False)),
14
  ('svd', TruncatedSVD(n_components=n_components, random_state=42)),
15
  ('logistic', LogisticRegression(max_iter=1000))
16
  ])
 
30
  noise = np.random.randn(*X.shape) * sigma
31
  return np.clip(X + noise, 0, 1)
32
 
33
+ def add_svd_aligned_noise(X, sigma, components):
34
+ """
35
+ Adds noise that is projected onto the SVD components, living entirely
36
+ within the 'signal' subspace.
37
+ """
38
+ if sigma <= 0: return X
39
+ is_tensor = torch.is_tensor(X)
40
+
41
+ # Flatten if needed
42
+ orig_shape = list(X.shape)
43
+ if is_tensor:
44
+ X_flat = X.cpu().numpy().reshape(orig_shape[0], -1)
45
+ components_np = components.cpu().numpy() if torch.is_tensor(components) else components
46
+ else:
47
+ X_flat = X.reshape(orig_shape[0], -1)
48
+ components_np = components
49
+
50
+ # 1. Generate random Gaussian noise in full dimensionality
51
+ noise = np.random.randn(*X_flat.shape) * sigma
52
+
53
+ # 2. Project noise onto components (V_k)
54
+ # V_k (components_np) is assumed to be (k, 784)
55
+ # Projection P = V_k^T @ V_k
56
+ projected_noise = (noise @ components_np.T) @ components_np
57
+
58
+ # 3. Add back and clip
59
+ X_noisy = X_flat + projected_noise
60
+ X_noisy = np.clip(X_noisy, 0, 1)
61
+
62
+ if is_tensor:
63
+ return torch.from_numpy(X_noisy).float().view(orig_shape)
64
+ else:
65
+ return X_noisy.reshape(orig_shape)
66
+
67
  def add_blur(X, kernel_size):
68
  """Unified blur for torch Tensors (4D: B, C, H, W)."""
69
  if kernel_size <= 1:
src/hybrid_model.py CHANGED
@@ -29,6 +29,11 @@ class SVDProjectionLayer(nn.Module):
29
  def forward(self, x):
30
  b = x.size(0)
31
  x_rec = (x.view(b, -1) - self.mean) @ self.V_k.T @ self.V_k + self.mean
 
 
 
 
 
32
  return torch.clamp(x_rec, 0, 1).view(b, 1, 28, 28)
33
 
34
  class HybridSVDCNN(nn.Module):
 
29
  def forward(self, x):
30
  b = x.size(0)
31
  x_rec = (x.view(b, -1) - self.mean) @ self.V_k.T @ self.V_k + self.mean
32
+
33
+ # Monitor truncation ratio (percentage of pixels outside [0, 1])
34
+ out_of_range = (x_rec < 0) | (x_rec > 1)
35
+ self.last_clamp_ratio = out_of_range.float().mean().item()
36
+
37
  return torch.clamp(x_rec, 0, 1).view(b, 1, 28, 28)
38
 
39
  class HybridSVDCNN(nn.Module):
src/viz.py CHANGED
@@ -81,20 +81,36 @@ def plot_singular_spectrum(singular_values, cumulative_variance, filename):
81
  fig.legend(loc="upper right", bbox_to_anchor=(1,1), bbox_transform=ax1.transAxes)
82
  save_fig(filename)
83
 
84
- def plot_interpolation_dynamics(alphas, probs_8, rec_errors, filename):
85
- """Visualizes the CNN response vs SVD reconstruction error during interpolation."""
86
  setup_style()
87
- plt.figure(figsize=(10, 6))
88
-
89
- plt.plot(alphas, probs_8, color=COLOR_CNN, label='CNN Prob(8) [Topology]', marker='o', linewidth=2)
90
- plt.plot(alphas, rec_errors, color=COLOR_SVD, label='SVD Rec Error [Global Variance]', marker='s', linewidth=2)
91
 
92
- plt.axvline(x=0.5, color='#4C566A', linestyle='--', alpha=0.5, label='Ambiguity Mid-point')
93
- plt.title('Mechanistic Dynamics: Interpolation vs. SVD Error', fontsize=14, fontweight='bold', pad=15)
94
- plt.xlabel('Alpha (0=Digit 3, 1=Digit 8)', fontsize=12)
95
- plt.ylabel('Metric Value', fontsize=12)
96
- plt.legend()
97
- plt.grid(True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  save_fig(filename)
99
 
100
  def plot_manifold_comparison(X_svd, X_umap, y, acc_svd, acc_raw, filename):
@@ -126,7 +142,7 @@ def plot_manifold_comparison(X_svd, X_umap, y, acc_svd, acc_raw, filename):
126
  def plot_learning_curves(history, title, filename):
127
  """Standardized plotter for training history (loss and accuracy)."""
128
  setup_style()
129
- epochs = range(1, len(history['train_loss']) + 1)
130
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
131
 
132
  # Nord palette for curves
@@ -134,8 +150,10 @@ def plot_learning_curves(history, title, filename):
134
  COLOR_VAL = "#D08770" # Nord 12 (Orange)
135
 
136
  # Loss Plot
137
- ax1.plot(epochs, history['train_loss'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
138
- ax1.plot(epochs, history['val_loss'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
 
 
139
  ax1.set_title('Loss Dynamics', fontsize=12, fontweight='bold')
140
  ax1.set_xlabel('Epoch')
141
  ax1.set_ylabel('Loss')
@@ -143,8 +161,10 @@ def plot_learning_curves(history, title, filename):
143
  ax1.grid(True)
144
 
145
  # Accuracy Plot
146
- ax2.plot(epochs, history['train_acc'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
147
- ax2.plot(epochs, history['val_acc'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
 
 
148
  ax2.set_title('Accuracy Dynamics', fontsize=12, fontweight='bold')
149
  ax2.set_xlabel('Epoch')
150
  ax2.set_ylabel('Accuracy')
 
81
  fig.legend(loc="upper right", bbox_to_anchor=(1,1), bbox_transform=ax1.transAxes)
82
  save_fig(filename)
83
 
84
+ def plot_interpolation_dynamics(alphas, probs_dict, rec_errors, filename, manifold_distances=None):
85
+ """Visualizes the CNN response vs SVD reconstruction error and manifold distance."""
86
  setup_style()
87
+ fig, ax1 = plt.subplots(figsize=(10, 6))
 
 
 
88
 
89
+ # Support both single list and dict of labels->probs
90
+ if isinstance(probs_dict, list):
91
+ probs_dict = {'CNN Prob(8)': probs_dict}
92
+
93
+ styles = ['-', '--', ':', '-.']
94
+ for i, (label, probs) in enumerate(probs_dict.items()):
95
+ ax1.plot(alphas, probs, label=label, marker='o' if i==0 else None,
96
+ linestyle=styles[i % len(styles)], linewidth=2)
97
+
98
+ ax1.plot(alphas, rec_errors, color=COLOR_SVD, label='SVD Rec Error', marker='s', linewidth=2, alpha=0.6)
99
+ ax1.set_xlabel('Alpha (0=Digit 3, 1=Digit 8)', fontsize=12)
100
+ ax1.set_ylabel('CNN Prob / Rec Error', fontsize=12)
101
+
102
+ if manifold_distances is not None:
103
+ ax2 = ax1.twinx()
104
+ ax2.plot(alphas, manifold_distances, color="#D08770", label='Manifold Distance', marker='^', linestyle='--', linewidth=2)
105
+ ax2.set_ylabel('Dist to Nearest Neighbor', color="#D08770", fontsize=12)
106
+ ax2.tick_params(axis='y', labelcolor="#D08770")
107
+ fig.legend(loc="upper right", bbox_to_anchor=(0.9, 0.9), bbox_transform=ax1.transAxes)
108
+ else:
109
+ ax1.legend()
110
+
111
+ plt.title('Mechanistic Dynamics: Snap Analysis with Temperature Scaling', fontsize=14, fontweight='bold', pad=15)
112
+ ax1.axvline(x=0.5, color='#4C566A', linestyle='--', alpha=0.5)  # legend is drawn above, so no label here
113
+ ax1.grid(True)
114
  save_fig(filename)
115
 
116
  def plot_manifold_comparison(X_svd, X_umap, y, acc_svd, acc_raw, filename):
 
142
  def plot_learning_curves(history, title, filename):
143
  """Standardized plotter for training history (loss and accuracy)."""
144
  setup_style()
145
+ # Calculate epochs separately for loss and accuracy
146
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
147
 
148
  # Nord palette for curves
 
150
  COLOR_VAL = "#D08770" # Nord 12 (Orange)
151
 
152
  # Loss Plot
153
+ if len(history.get('train_loss', [])) > 0:
154
+ epochs_loss = range(1, len(history['train_loss']) + 1)
155
+ ax1.plot(epochs_loss, history['train_loss'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
156
+ ax1.plot(epochs_loss, history['val_loss'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
157
  ax1.set_title('Loss Dynamics', fontsize=12, fontweight='bold')
158
  ax1.set_xlabel('Epoch')
159
  ax1.set_ylabel('Loss')
 
161
  ax1.grid(True)
162
 
163
  # Accuracy Plot
164
+ if len(history.get('train_acc', [])) > 0:
165
+ epochs_acc = range(1, len(history['train_acc']) + 1)
166
+ ax2.plot(epochs_acc, history['train_acc'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
167
+ ax2.plot(epochs_acc, history['val_acc'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
168
  ax2.set_title('Accuracy Dynamics', fontsize=12, fontweight='bold')
169
  ax2.set_xlabel('Epoch')
170
  ax2.set_ylabel('Accuracy')