Spaces:
Sleeping
Sleeping
feat: Refactor experiments and update report
Browse files- README.md +3 -3
- app.py +57 -12
- docs/REPORT.md +69 -65
- docs/research_results/fig_01_spectrum.png +2 -2
- docs/research_results/{fig_02_svd_confusion.png → fig_01_svd_confusion.png} +2 -2
- docs/research_results/{fig_03_eigen_digits.png → fig_02_eigen_digits.png} +2 -2
- docs/research_results/{fig_05_interpolation.png → fig_03_interpolation.png} +2 -2
- docs/research_results/fig_04_cnn_confusion.png +2 -2
- docs/research_results/{fig_06_explainability.png → fig_04_explainability.png} +0 -0
- docs/research_results/{fig_08_manifold_collapse.png → fig_05_manifold_collapse.png} +2 -2
- docs/research_results/{fig_14_learning_curves.png → fig_06_robustness_mnist_gaussian.png} +2 -2
- docs/research_results/fig_07_robustness_mnist_svd_aligned.png +3 -0
- docs/research_results/fig_08_robustness_fashion.png +3 -0
- docs/research_results/fig_09_learning_curves.png +3 -0
- docs/research_results/fig_10_per_class_metrics_comparison.png +3 -0
- docs/research_results/fig_19_per_class_metrics_comparison.png +0 -3
- docs/research_results/fig_robustness_fashion.png +0 -3
- docs/research_results/fig_robustness_mnist.png +0 -3
- experiments/01_phenomenon_diagnosis.py +2 -2
- experiments/{02_mechanistic_analysis.py → 02_mechanistic_proof.py} +23 -17
- experiments/{run_robustness_test.py → 03_operational_boundaries.py} +25 -9
- experiments/{appendix_learning_curves.py → 04_appendix_learning_curves.py} +1 -1
- experiments/{appendix_per_class_metrics.py → 05_appendix_per_class_metrics.py} +5 -2
- run_all_experiments.sh +4 -4
- run_migration.sh +68 -0
- src/exp_utils.py +35 -1
- src/hybrid_model.py +5 -0
- src/viz.py +37 -17
README.md
CHANGED
|
@@ -10,14 +10,14 @@ app_file: app.py
|
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
|
| 13 |
-
# SVD
|
| 14 |
|
| 15 |
[](https://huggingface.co/spaces/ymlin105/Coconut-MNIST) [](./docs/REPORT.md)
|
| 16 |
|
| 17 |
While it is a known theoretical property that linear dimensionality reduction (SVD) acts as a low-pass filter, this project provides a **concrete, visual, and quantitative mechanistic explanation** of how this property manifests in neural network classificationβspecifically, why linear subspaces consistently force a "3" to collapse into an "8".
|
| 18 |
|
| 19 |
<p align="center">
|
| 20 |
-
<img src="./docs/research_results/
|
| 21 |
</p>
|
| 22 |
|
| 23 |
By mapping the exact decision boundaries where linear global variance models fail and non-linear topological models (CNNs) succeed, I empirically validate the **inherent trade-offs** of linear denoising in high-stakes domains like medical imaging or satellite dataβwhere a linear filter might suppress critical diagnostic features to minimize noise variance.
|
|
@@ -93,7 +93,7 @@ streamlit run app.py
|
|
| 93 |
### Project Structure
|
| 94 |
```
|
| 95 |
βββ src/ Core modules (CNN, SVD layer) + Experimental Utils
|
| 96 |
-
βββ experiments/
|
| 97 |
βββ docs/ Full report (REPORT.md) + figures
|
| 98 |
βββ models/ Pretrained checkpoints
|
| 99 |
βββ run_all_experiments.sh One-click reproduction script
|
|
|
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# Why Does SVD Turn a "3" into an "8"? Linear vs. Non-linear Manifolds on MNIST
|
| 14 |
|
| 15 |
[](https://huggingface.co/spaces/ymlin105/Coconut-MNIST) [](./docs/REPORT.md)
|
| 16 |
|
| 17 |
While it is a known theoretical property that linear dimensionality reduction (SVD) acts as a low-pass filter, this project provides a **concrete, visual, and quantitative mechanistic explanation** of how this property manifests in neural network classificationβspecifically, why linear subspaces consistently force a "3" to collapse into an "8".
|
| 18 |
|
| 19 |
<p align="center">
|
| 20 |
+
<img src="./docs/research_results/fig_04_explainability.png" width="600" alt="Mechanistic Analysis: SVD Blind Spot">
|
| 21 |
</p>
|
| 22 |
|
| 23 |
By mapping the exact decision boundaries where linear global variance models fail and non-linear topological models (CNNs) succeed, I empirically validate the **inherent trade-offs** of linear denoising in high-stakes domains like medical imaging or satellite dataβwhere a linear filter might suppress critical diagnostic features to minimize noise variance.
|
|
|
|
| 93 |
### Project Structure
|
| 94 |
```
|
| 95 |
βββ src/ Core modules (CNN, SVD layer) + Experimental Utils
|
| 96 |
+
βββ experiments/ Sequential scripts (01 Diagnosis, 02 Proof, 03 Boundaries)
|
| 97 |
βββ docs/ Full report (REPORT.md) + figures
|
| 98 |
βββ models/ Pretrained checkpoints
|
| 99 |
βββ run_all_experiments.sh One-click reproduction script
|
app.py
CHANGED
|
@@ -144,7 +144,13 @@ def get_reconstruction(svd_model, img_flat):
|
|
| 144 |
recons = svd_model.inverse_transform(svd_model.transform(flat))
|
| 145 |
if mean is not None:
|
| 146 |
recons = recons + mean
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
# --- UI Sidebar ---
|
|
@@ -161,6 +167,14 @@ with st.sidebar:
|
|
| 161 |
)
|
| 162 |
if noise_mode:
|
| 163 |
st.success("SVD Denoising Active")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
|
| 166 |
# --- Initialization ---
|
|
@@ -204,11 +218,11 @@ with tab1:
|
|
| 204 |
img_s = X_flat[np.where(y_orig == start_digit)[0][0]]
|
| 205 |
img_e = X_flat[np.where(y_orig == end_digit)[0][0]]
|
| 206 |
img_interp = (1 - alpha) * img_s + alpha * img_e
|
| 207 |
-
img_svd = get_reconstruction(svd_model, img_interp)
|
| 208 |
|
| 209 |
with torch.no_grad():
|
| 210 |
logits = cnn_model(img_interp.view(1, 1, 28, 28))
|
| 211 |
-
probs = torch.softmax(logits, dim=1)
|
| 212 |
conf, pred = torch.max(probs, 1)
|
| 213 |
|
| 214 |
# Visual Display
|
|
@@ -219,11 +233,35 @@ with tab1:
|
|
| 219 |
with v2:
|
| 220 |
st.markdown("**SVD Reconstruction**")
|
| 221 |
st.image(img_svd.numpy(), width=150)
|
|
|
|
| 222 |
with v3:
|
| 223 |
st.markdown(f"**CNN Prediction: {pred.item()}**")
|
| 224 |
st.progress(conf.item(), text=f"Confidence: {conf.item():.1%}")
|
| 225 |
st.caption("Note: CNN 'snaps' at the topological midpoint, not a smooth transition.")
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
# --- Tab 2: Robustness (The SVD Advantage) ---
|
| 229 |
with tab2:
|
|
@@ -237,12 +275,12 @@ with tab2:
|
|
| 237 |
|
| 238 |
img_clean = X_flat[np.where(y_orig == noise_digit)[0][0]]
|
| 239 |
img_noisy = torch.clamp(img_clean + torch.randn_like(img_clean) * sigma, 0, 1)
|
| 240 |
-
img_denoised = get_reconstruction(svd_model, img_noisy)
|
| 241 |
|
| 242 |
with col2:
|
| 243 |
res1, res2 = st.columns(2)
|
| 244 |
res1.image(img_noisy.view(28, 28).numpy(), caption="Noisy Input", width=150)
|
| 245 |
-
res2.image(img_denoised.numpy(), caption="SVD
|
| 246 |
|
| 247 |
st.markdown("---")
|
| 248 |
st.markdown("#### Accuracy Breakdown at This Noise Level")
|
|
@@ -257,8 +295,10 @@ with tab2:
|
|
| 257 |
results = data["results"]
|
| 258 |
|
| 259 |
def interp(model_name: str, s: float) -> float:
|
|
|
|
|
|
|
| 260 |
vals = np.array(results[model_name], dtype=float)
|
| 261 |
-
return float(np.interp(
|
| 262 |
|
| 263 |
acc_svd = interp("SVD", sigma)
|
| 264 |
acc_cnn = interp("CNN", sigma)
|
|
@@ -334,7 +374,8 @@ with tab3:
|
|
| 334 |
)
|
| 335 |
fig_svd.update_layout(
|
| 336 |
margin=dict(l=20, r=20, b=40, t=20),
|
| 337 |
-
showlegend=
|
|
|
|
| 338 |
xaxis_title="Component 1",
|
| 339 |
yaxis_title="Component 2"
|
| 340 |
)
|
|
@@ -353,7 +394,8 @@ with tab3:
|
|
| 353 |
)
|
| 354 |
fig_umap.update_layout(
|
| 355 |
margin=dict(l=20, r=20, b=40, t=20),
|
| 356 |
-
showlegend=
|
|
|
|
| 357 |
xaxis_title="UMAP 1",
|
| 358 |
yaxis_title="UMAP 2"
|
| 359 |
)
|
|
@@ -421,7 +463,7 @@ with tab4:
|
|
| 421 |
if add_noise:
|
| 422 |
img_sample = torch.clamp(img_sample + torch.randn_like(img_sample) * noise_sigma, 0, 1)
|
| 423 |
|
| 424 |
-
img_svd_sample = get_reconstruction(svd_model, img_sample)
|
| 425 |
|
| 426 |
if noise_mode:
|
| 427 |
cnn_input = img_svd_sample.view(1, 1, 28, 28)
|
|
@@ -429,7 +471,7 @@ with tab4:
|
|
| 429 |
cnn_input = img_sample.view(1, 1, 28, 28)
|
| 430 |
|
| 431 |
with torch.no_grad():
|
| 432 |
-
probs = torch.softmax(cnn_model(cnn_input), dim=1)
|
| 433 |
conf, pred = torch.max(probs, 1)
|
| 434 |
|
| 435 |
r1, r2, r3 = st.columns(3)
|
|
@@ -441,6 +483,7 @@ with tab4:
|
|
| 441 |
with r2:
|
| 442 |
st.markdown("**SVD Projection**")
|
| 443 |
st.image(img_svd_sample.numpy(), width=120)
|
|
|
|
| 444 |
|
| 445 |
with r3:
|
| 446 |
st.markdown(f"**Prediction: {pred.item()}**")
|
|
@@ -454,6 +497,7 @@ with tab4:
|
|
| 454 |
|
| 455 |
else: # Draw Digit
|
| 456 |
st.info("Draw a digit in the box below. SVD and CNN will analyze it in real-time.")
|
|
|
|
| 457 |
|
| 458 |
col_canvas, col_preview = st.columns([1, 1])
|
| 459 |
|
|
@@ -481,7 +525,7 @@ with tab4:
|
|
| 481 |
img_tensor = torch.tensor(img_resized, dtype=torch.float32) / 255.0
|
| 482 |
|
| 483 |
img_flat_up = img_tensor.view(1, 784)
|
| 484 |
-
img_svd_up = get_reconstruction(svd_model, img_flat_up)
|
| 485 |
|
| 486 |
if noise_mode:
|
| 487 |
cnn_input = img_svd_up.view(1, 1, 28, 28)
|
|
@@ -489,7 +533,7 @@ with tab4:
|
|
| 489 |
cnn_input = img_tensor.view(1, 1, 28, 28)
|
| 490 |
|
| 491 |
with torch.no_grad():
|
| 492 |
-
probs = torch.softmax(cnn_model(cnn_input), dim=1)
|
| 493 |
conf, pred = torch.max(probs, 1)
|
| 494 |
|
| 495 |
with col_preview:
|
|
@@ -503,6 +547,7 @@ with tab4:
|
|
| 503 |
with r1:
|
| 504 |
st.markdown("**SVD View**")
|
| 505 |
st.image(img_svd_up.numpy(), width=100)
|
|
|
|
| 506 |
with r2:
|
| 507 |
if noise_mode:
|
| 508 |
st.caption("Using SVD Denoised Input")
|
|
|
|
| 144 |
recons = svd_model.inverse_transform(svd_model.transform(flat))
|
| 145 |
if mean is not None:
|
| 146 |
recons = recons + mean
|
| 147 |
+
|
| 148 |
+
# Monitor truncation ratio
|
| 149 |
+
out_of_range = (recons < 0) | (recons > 1)
|
| 150 |
+
clamp_ratio = np.mean(out_of_range)
|
| 151 |
+
|
| 152 |
+
recons_tensor = torch.clamp(torch.tensor(recons).float(), 0, 1).view(28, 28)
|
| 153 |
+
return recons_tensor, clamp_ratio
|
| 154 |
|
| 155 |
|
| 156 |
# --- UI Sidebar ---
|
|
|
|
| 167 |
)
|
| 168 |
if noise_mode:
|
| 169 |
st.success("SVD Denoising Active")
|
| 170 |
+
|
| 171 |
+
st.markdown("---")
|
| 172 |
+
st.subheader("Model Calibration")
|
| 173 |
+
temp_scaling = st.slider(
|
| 174 |
+
"Softmax Temperature (T)",
|
| 175 |
+
0.1, 5.0, 1.0, 0.1,
|
| 176 |
+
help="Higher T = smoother transitions (less over-confident), Lower T = sharper 'snaps'."
|
| 177 |
+
)
|
| 178 |
|
| 179 |
|
| 180 |
# --- Initialization ---
|
|
|
|
| 218 |
img_s = X_flat[np.where(y_orig == start_digit)[0][0]]
|
| 219 |
img_e = X_flat[np.where(y_orig == end_digit)[0][0]]
|
| 220 |
img_interp = (1 - alpha) * img_s + alpha * img_e
|
| 221 |
+
img_svd, clamp_ratio = get_reconstruction(svd_model, img_interp)
|
| 222 |
|
| 223 |
with torch.no_grad():
|
| 224 |
logits = cnn_model(img_interp.view(1, 1, 28, 28))
|
| 225 |
+
probs = torch.softmax(logits / temp_scaling, dim=1)
|
| 226 |
conf, pred = torch.max(probs, 1)
|
| 227 |
|
| 228 |
# Visual Display
|
|
|
|
| 233 |
with v2:
|
| 234 |
st.markdown("**SVD Reconstruction**")
|
| 235 |
st.image(img_svd.numpy(), width=150)
|
| 236 |
+
st.caption(f"Truncation: {clamp_ratio:.1%}")
|
| 237 |
with v3:
|
| 238 |
st.markdown(f"**CNN Prediction: {pred.item()}**")
|
| 239 |
st.progress(conf.item(), text=f"Confidence: {conf.item():.1%}")
|
| 240 |
st.caption("Note: CNN 'snaps' at the topological midpoint, not a smooth transition.")
|
| 241 |
|
| 242 |
+
# --- Confidence Curve Visualization ---
|
| 243 |
+
alphas_curve = np.linspace(0, 1, 21)
|
| 244 |
+
curve_probs = []
|
| 245 |
+
|
| 246 |
+
with torch.no_grad():
|
| 247 |
+
for a in alphas_curve:
|
| 248 |
+
img_a = (1 - a) * img_s + a * img_e
|
| 249 |
+
logits_a = cnn_model(img_a.view(1, 1, 28, 28))
|
| 250 |
+
probs_a = torch.softmax(logits_a / temp_scaling, dim=1)
|
| 251 |
+
# Probability of being the "end_digit"
|
| 252 |
+
curve_probs.append(probs_a[0, end_digit].item())
|
| 253 |
+
|
| 254 |
+
st.markdown("---")
|
| 255 |
+
st.markdown(f"#### Confidence Snap: Probability of '{end_digit}'")
|
| 256 |
+
|
| 257 |
+
df_curve = pd.DataFrame({
|
| 258 |
+
"alpha": alphas_curve,
|
| 259 |
+
"Probability": curve_probs
|
| 260 |
+
}).set_index("alpha")
|
| 261 |
+
|
| 262 |
+
st.line_chart(df_curve, height=200)
|
| 263 |
+
st.caption(f"The vertical 'snap' in this curve highlights the non-linear decision boundary. Even as the pixels fade linearly, the CNN's internal representation jumps once a topological threshold is crossed.")
|
| 264 |
+
|
| 265 |
|
| 266 |
# --- Tab 2: Robustness (The SVD Advantage) ---
|
| 267 |
with tab2:
|
|
|
|
| 275 |
|
| 276 |
img_clean = X_flat[np.where(y_orig == noise_digit)[0][0]]
|
| 277 |
img_noisy = torch.clamp(img_clean + torch.randn_like(img_clean) * sigma, 0, 1)
|
| 278 |
+
img_denoised, clamp_ratio_robust = get_reconstruction(svd_model, img_noisy)
|
| 279 |
|
| 280 |
with col2:
|
| 281 |
res1, res2 = st.columns(2)
|
| 282 |
res1.image(img_noisy.view(28, 28).numpy(), caption="Noisy Input", width=150)
|
| 283 |
+
res2.image(img_denoised.numpy(), caption=f"SVD Projection (Trunc: {clamp_ratio_robust:.1%})", width=150)
|
| 284 |
|
| 285 |
st.markdown("---")
|
| 286 |
st.markdown("#### Accuracy Breakdown at This Noise Level")
|
|
|
|
| 295 |
results = data["results"]
|
| 296 |
|
| 297 |
def interp(model_name: str, s: float) -> float:
|
| 298 |
+
# Robust boundary treatment for floating point sigma
|
| 299 |
+
s_clipped = np.clip(s, levels.min(), levels.max())
|
| 300 |
vals = np.array(results[model_name], dtype=float)
|
| 301 |
+
return float(np.interp(s_clipped, levels, vals))
|
| 302 |
|
| 303 |
acc_svd = interp("SVD", sigma)
|
| 304 |
acc_cnn = interp("CNN", sigma)
|
|
|
|
| 374 |
)
|
| 375 |
fig_svd.update_layout(
|
| 376 |
margin=dict(l=20, r=20, b=40, t=20),
|
| 377 |
+
showlegend=True,
|
| 378 |
+
legend=dict(orientation="h", yanchor="top", y=-0.15, xanchor="center", x=0.5),
|
| 379 |
xaxis_title="Component 1",
|
| 380 |
yaxis_title="Component 2"
|
| 381 |
)
|
|
|
|
| 394 |
)
|
| 395 |
fig_umap.update_layout(
|
| 396 |
margin=dict(l=20, r=20, b=40, t=20),
|
| 397 |
+
showlegend=True,
|
| 398 |
+
legend=dict(orientation="h", yanchor="top", y=-0.15, xanchor="center", x=0.5),
|
| 399 |
xaxis_title="UMAP 1",
|
| 400 |
yaxis_title="UMAP 2"
|
| 401 |
)
|
|
|
|
| 463 |
if add_noise:
|
| 464 |
img_sample = torch.clamp(img_sample + torch.randn_like(img_sample) * noise_sigma, 0, 1)
|
| 465 |
|
| 466 |
+
img_svd_sample, clamp_ratio_lab = get_reconstruction(svd_model, img_sample)
|
| 467 |
|
| 468 |
if noise_mode:
|
| 469 |
cnn_input = img_svd_sample.view(1, 1, 28, 28)
|
|
|
|
| 471 |
cnn_input = img_sample.view(1, 1, 28, 28)
|
| 472 |
|
| 473 |
with torch.no_grad():
|
| 474 |
+
probs = torch.softmax(cnn_model(cnn_input) / temp_scaling, dim=1)
|
| 475 |
conf, pred = torch.max(probs, 1)
|
| 476 |
|
| 477 |
r1, r2, r3 = st.columns(3)
|
|
|
|
| 483 |
with r2:
|
| 484 |
st.markdown("**SVD Projection**")
|
| 485 |
st.image(img_svd_sample.numpy(), width=120)
|
| 486 |
+
st.caption(f"Truncation: {clamp_ratio_lab:.1%}")
|
| 487 |
|
| 488 |
with r3:
|
| 489 |
st.markdown(f"**Prediction: {pred.item()}**")
|
|
|
|
| 497 |
|
| 498 |
else: # Draw Digit
|
| 499 |
st.info("Draw a digit in the box below. SVD and CNN will analyze it in real-time.")
|
| 500 |
+
st.caption("*Tip: Draw the digit large and centered for best results.*")
|
| 501 |
|
| 502 |
col_canvas, col_preview = st.columns([1, 1])
|
| 503 |
|
|
|
|
| 525 |
img_tensor = torch.tensor(img_resized, dtype=torch.float32) / 255.0
|
| 526 |
|
| 527 |
img_flat_up = img_tensor.view(1, 784)
|
| 528 |
+
img_svd_up, clamp_ratio_draw = get_reconstruction(svd_model, img_flat_up)
|
| 529 |
|
| 530 |
if noise_mode:
|
| 531 |
cnn_input = img_svd_up.view(1, 1, 28, 28)
|
|
|
|
| 533 |
cnn_input = img_tensor.view(1, 1, 28, 28)
|
| 534 |
|
| 535 |
with torch.no_grad():
|
| 536 |
+
probs = torch.softmax(cnn_model(cnn_input) / temp_scaling, dim=1)
|
| 537 |
conf, pred = torch.max(probs, 1)
|
| 538 |
|
| 539 |
with col_preview:
|
|
|
|
| 547 |
with r1:
|
| 548 |
st.markdown("**SVD View**")
|
| 549 |
st.image(img_svd_up.numpy(), width=100)
|
| 550 |
+
st.caption(f"Truncation: {clamp_ratio_draw:.1%}")
|
| 551 |
with r2:
|
| 552 |
if noise_mode:
|
| 553 |
st.caption("Using SVD Denoised Input")
|
docs/REPORT.md
CHANGED
|
@@ -1,125 +1,129 @@
|
|
| 1 |
-
#
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
---
|
| 6 |
|
| 7 |
-
##
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
---
|
| 14 |
|
| 15 |
-
##
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
###
|
| 20 |
-
|
| 21 |
|
| 22 |
<p align="center">
|
| 23 |
-
<img src="research_results/
|
| 24 |
-
<img src="research_results/
|
| 25 |
<br>
|
| 26 |
-
<em>Figure 1 & 2: SVD
|
| 27 |
</p>
|
| 28 |
|
| 29 |
-
This identifies the **Variance-Discrimination Paradox**: SVD optimizes for reconstruction (global energy) rather than separation (local topology). Since the "8-like" silhouette contains more pixel variance than the tiny gap in a '3', the linear model "hallucinates" a closed loop to minimize reconstruction error.
|
| 30 |
-
|
| 31 |
---
|
| 32 |
|
| 33 |
-
##
|
| 34 |
|
| 35 |
-
To
|
| 36 |
|
| 37 |
-
###
|
| 38 |
-
|
| 39 |
|
| 40 |
<p align="center">
|
| 41 |
-
<img src="research_results/
|
| 42 |
<br>
|
| 43 |
-
<em>Figure 3: CNN
|
| 44 |
</p>
|
| 45 |
|
| 46 |
-
###
|
| 47 |
-
Grad-CAM heatmaps confirm that
|
| 48 |
|
| 49 |
<p align="center">
|
| 50 |
-
<img src="research_results/
|
| 51 |
<br>
|
| 52 |
-
<em>Figure 4:
|
| 53 |
</p>
|
| 54 |
|
| 55 |
-
###
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
* **SVD
|
| 60 |
-
|
| 61 |
-
The error rate **more than doubles (+130% relative increase)** after linear projection. This provides iron-clad proof that SVD structurally forces distinct local neighborhoods of 3s and 8s to overlap, destroying information that is physically present in the pixels.
|
| 62 |
|
| 63 |
<p align="center">
|
| 64 |
-
<img src="research_results/
|
| 65 |
<br>
|
| 66 |
-
<em>Figure 5:
|
| 67 |
</p>
|
| 68 |
|
| 69 |
---
|
| 70 |
|
| 71 |
-
##
|
| 72 |
|
| 73 |
-
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
<p align="center">
|
| 79 |
-
<img src="research_results/
|
|
|
|
| 80 |
<br>
|
| 81 |
-
<em>Figure 6:
|
| 82 |
</p>
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
<p align="center">
|
| 88 |
-
<img src="research_results/
|
| 89 |
<br>
|
| 90 |
-
<em>Figure
|
| 91 |
</p>
|
| 92 |
|
| 93 |
---
|
| 94 |
|
| 95 |
## Conclusion
|
| 96 |
-
This investigation proves that SVD's reliance on global variance conflates discriminative local features with noise. While this serves as an effective low-pass filter for shape-dominated data (MNIST) under extreme corruption, it systematically degrades the manifold boundaries required for precise non-linear classification.
|
| 97 |
|
| 98 |
-
|
| 99 |
|
| 100 |
-
|
| 101 |
|
| 102 |
-
##
|
| 103 |
-
- **Trainer**: Adam optimizer ($lr=10^{-3}$, batch size 64) with stratified 80/20 train/val splits.
|
| 104 |
-
- **Denoising**: All SVD components use explicit mean-centering and a rank of $k=20$.
|
| 105 |
-
- **Early Stopping**: Early stopping with a patience of 3 monitored on validation accuracy prevented overfitting, typically converging in 5-8 epochs.
|
| 106 |
|
|
|
|
|
|
|
| 107 |
<p align="center">
|
| 108 |
-
<img src="research_results/
|
| 109 |
-
<br>
|
| 110 |
-
<em>Figure A1: Standardized learning curves showing convergence and early-stopping preservation.</em>
|
| 111 |
</p>
|
| 112 |
|
| 113 |
-
### B. Per-Class
|
| 114 |
-
SVD
|
| 115 |
-
- **Digit 5 (81.3% F1)**: Systematic confusion with digit 3 due to similar stroke energy.
|
| 116 |
-
- **Digit 9 (83.6% F1)**: Confusion with 4 due to loop vs. open-top similarity.
|
| 117 |
-
|
| 118 |
<p align="center">
|
| 119 |
-
<img src="research_results/
|
| 120 |
-
<br>
|
| 121 |
-
<em>Figure A2: Side-by-side per-class comparison highlighting SVD's failure regions.</em>
|
| 122 |
</p>
|
| 123 |
-
|
| 124 |
-
### C. Grad-CAM Implementation Note
|
| 125 |
-
To ensure accurate saliency mapping, we utilized `register_full_backward_hook` to capture complete gradient tensors from the intermediate convolutional layers, avoiding the gradient-drop issues found in legacy PyTorch hooks.
|
|
|
|
| 1 |
+
# Why Does SVD Turn a "3" into an "8"? A Mechanistic Comparison of Linear vs. Non-linear Manifolds
|
| 2 |
|
| 3 |
+
Why do linear models fail at tasks that seem trivial to a human or a simple neural network? While it is a known property that linear dimensionality reduction (SVD) acts as a low-pass filter, this report provides a **concrete, visual, and mechanistic explanation** of how this manifests in classificationβspecifically, why linear subspaces force a "3" to collapse into an "8".
|
| 4 |
+
|
| 5 |
+
By mapping the decision boundaries where linear global variance models fail and non-linear topological models (CNNs) succeed, we validate the **inherent trade-offs** of linear denoising in high-stakes domains like medical imaging or satellite dataβwhere a linear filter might suppress critical diagnostic features to minimize noise variance.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## TL;DR: The 15-Second Summary
|
| 10 |
+
|
| 11 |
+
- **The Problem (The Variance Trap):** SVD prioritizes global pixel energy. In a '3', the tiny gap that distinguishes it from an '8' has very low variance, so SVD deletes it as "noise."
|
| 12 |
+
- **The Mechanism:** Linear models see **Global Energy** (the overall silhouette), while CNNs see **Local Topology** (the gap). SVD literally "welds" the ends of a '3' together to minimize reconstruction error.
|
| 13 |
+
- **The Solution & Boundary: We built a Hybrid SVD→CNN pipeline.** While SVD fails as a standalone classifier, it works as a powerful **low-pass filter** and defensive shield against high noise ($\sigma=0.7$), provided the data isn't too texture-rich (like Fashion-MNIST).
|
| 14 |
|
| 15 |
---
|
| 16 |
|
| 17 |
+
## The Investigative Approach
|
| 18 |
|
| 19 |
+
```text
|
| 20 |
+
Diagnosis Mechanism Solution & Boundary
|
| 21 |
+
βββββββββββββββββββββ βββββββββββββββββββββ βββββββββββββββββββββ
|
| 22 |
+
SVD fails on 3 vs 8 → Why? Grad-CAM + UMAP → Hybrid SVD→CNN pipeline
|
| 23 |
+
(The Variance Trap) (Global vs. Local) + Texture stress test
|
| 24 |
+
```
|
| 25 |
|
| 26 |
---
|
| 27 |
|
| 28 |
+
## 1. Diagnosis: The "3 vs 8" Failure Mode
|
| 29 |
|
| 30 |
+
Aggregate accuracy metrics often hide the real story. While SVD achieves 88.1% accuracy on MNIST, it systematically fails on digits with high pixel overlap.
|
| 31 |
|
| 32 |
+
### The Variance Trap
|
| 33 |
+
Linear dimensionality reduction (SVD) treats classification like a reconstruction problem. It looks for the directions of maximum variance (total pixel brightness). In the cluster of 3s and 8s, the shared "8-like" outline contains the most energy. The small gap that makes a '3' unique is mathematically ignored.
|
| 34 |
|
| 35 |
<p align="center">
|
| 36 |
+
<img src="research_results/fig_01_svd_confusion.png" alt="SVD Confusion Matrix" width="350" />
|
| 37 |
+
<img src="research_results/fig_02_eigen_digits.png" alt="SVD Eigen-digits" width="350" />
|
| 38 |
<br>
|
| 39 |
+
<em><strong>Figure 1 & 2:</strong> SVD concentrates its errors on ambiguous pairs. The first few "eigen-digits" capture the shared circular structure, smoothing over the critical discriminative gap.</em>
|
| 40 |
</p>
|
| 41 |
|
|
|
|
|
|
|
| 42 |
---
|
| 43 |
|
| 44 |
+
## 2. Mechanism: Global Energy vs. Local Topology
|
| 45 |
|
| 46 |
+
To understand *why* this happens, we compared how a CNN (non-linear) and SVD (linear) "see" the same image.
|
| 47 |
|
| 48 |
+
### Linear Hallucination
|
| 49 |
+
When we interpolate a '3' into an '8', the CNN shows a sharp "snap" in confidenceβit recognizes a topological boundary. In contrast, SVD's reconstruction error peaks at the midpoint because it's trying to fit a "linear bridge" between two distinct manifolds.
|
| 50 |
|
| 51 |
<p align="center">
|
| 52 |
+
<img src="research_results/fig_03_interpolation.png" alt="Decision Boundary Interpolation" width="700" />
|
| 53 |
<br>
|
| 54 |
+
<em><strong>Figure 3:</strong> CNN probability vs. Manifold Distance. The CNN's sharp boundary persists, while SVD creates "ghost" images at the midpoint that don't belong to either digit.</em>
|
| 55 |
</p>
|
| 56 |
|
| 57 |
+
### Local Topology: The Gap is the Signal
|
| 58 |
+
Grad-CAM heatmaps confirm that a CNN focuses exclusively on the **topological gap**. SVD, however, reconstructs a phantom loop. The linear model is "blind" to the gap because it optimizes for global pixel coincidence rather than shape continuity.
|
| 59 |
|
| 60 |
<p align="center">
|
| 61 |
+
<img src="research_results/fig_04_explainability.png" alt="Grad-CAM vs SVD" width="700" />
|
| 62 |
<br>
|
| 63 |
+
<em><strong>Figure 4:</strong> CNN attention (center) vs. SVD reconstruction (right). CNNs classify by shape discontinuity; SVD "hallucinates" a loop to satisfy energy constraints.</em>
|
| 64 |
</p>
|
| 65 |
|
| 66 |
+
### Quantifying Manifold Collapse
|
| 67 |
+
Using $k$-NN as a benchmark, we measured the damage:
|
| 68 |
+
- **Raw Pixel Space:** 98.7% Accuracy
|
| 69 |
+
- **SVD Subspace:** 97.0% Accuracy
|
| 70 |
+
This **130% relative increase in error** proves that SVD physically crushes the separation between 3s and 8s.
|
|
|
|
|
|
|
| 71 |
|
| 72 |
<p align="center">
|
| 73 |
+
<img src="research_results/fig_05_manifold_collapse.png" alt="Manifold Comparison" width="600"/>
|
| 74 |
<br>
|
| 75 |
+
<em><strong>Figure 5:</strong> SVD (left) collapses boundaries to maximize variance, whereas UMAP (right) preserves the topological separation required for high accuracy.</em>
|
| 76 |
</p>
|
| 77 |
|
| 78 |
---
|
| 79 |
|
| 80 |
+
## 3. Solution: Success on Low-Rank Manifolds (MNIST)
|
| 81 |
|
| 82 |
+
If SVD is so bad at discriminating, why use it? Because its "Variance Trap" is a perfect **Low-Pass Filter**.
|
| 83 |
|
| 84 |
+
In high-noise environments ($\sigma=0.7$), a raw CNN's accuracy drops to **30.1%**. A **Hybrid SVD→CNN** pipeline, however, maintains **65.5%** accuracy. By discarding low-variance dimensions, SVD acts as a "data-adapted shield," stripping away random Gaussian noise before it reaches the classifier.
|
| 85 |
+
|
| 86 |
+
However, this shield has a **blind spot**: if the noise is carefully aligned with the data's principal components (**Fig 7**), SVD preserves the noise rather than filtering it, making the model even more vulnerable than a raw CNN.
|
| 87 |
|
| 88 |
<p align="center">
|
| 89 |
+
<img src="research_results/fig_06_robustness_mnist_gaussian.png" alt="Hybrid Robustness" width="450" />
|
| 90 |
+
<img src="research_results/fig_07_robustness_mnist_svd_aligned.png" alt="Subspace Risk" width="450" />
|
| 91 |
<br>
|
| 92 |
+
<em><strong>Figure 6 & 7:</strong> SVD dominates under random noise (left) but becomes a liability if the noise is "aligned" with the data subspace (right), proving the defense is narrow-band.</em>
|
| 93 |
</p>
|
| 94 |
|
| 95 |
+
---
|
| 96 |
+
|
| 97 |
+
## 4. Boundary: Failure on Texture-Rich Manifolds (Fashion-MNIST)
|
| 98 |
+
|
| 99 |
+
The SVD defense only works when the objects are simple silhouettes. On **Fashion-MNIST**, the strategy collapses.
|
| 100 |
+
|
| 101 |
+
Items like Shirts and Pullovers aren't distinguished by global outlines, but by **high-frequency textures** (buttons, zippers, collar stitching). SVD's low-pass bias treats these textures as noise and deletes them. Accuracy drops from 91% (CNN) to 67% (Hybrid), defining the physical limit of linear denoising.
|
| 102 |
|
| 103 |
<p align="center">
|
| 104 |
+
<img src="research_results/fig_08_robustness_fashion.png" alt="Fashion texture collapse" width="500" />
|
| 105 |
<br>
|
| 106 |
+
<em><strong>Figure 8:</strong> On texture-rich data, SVD's "low-pass filter" becomes a "detail-destroyer," failing to preserve the features needed for non-linear models to succeed.</em>
|
| 107 |
</p>
|
| 108 |
|
| 109 |
---
|
| 110 |
|
| 111 |
## Conclusion
|
|
|
|
| 112 |
|
| 113 |
+
This study proves that SVD's fundamental bias toward **Global Variance** makes it a poor standalone classifier but a specialized defensive tool. It excels at denoising simple manifolds but fails catastrophically when locally discriminative details (like a gap in a '3' or a button on a shirt) are suppressed in favor of global energy.
|
| 114 |
|
| 115 |
+
---
|
| 116 |
|
| 117 |
+
## Appendix: Technical Details
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
+
### A. Learning Curves
|
| 120 |
+
Convergence was typically reached within 5-8 epochs using the Adam optimizer.
|
| 121 |
<p align="center">
|
| 122 |
+
<img src="research_results/fig_09_learning_curves.png" alt="Learning Curves" width="450" />
|
|
|
|
|
|
|
| 123 |
</p>
|
| 124 |
|
| 125 |
+
### B. Per-Class F1 Comparisons
|
| 126 |
+
SVD failures are consistently clustered in "Ambiguity Zones" (3 vs 8, 5 vs 3, 4 vs 9), where pixel-wise overlap is highest.
|
|
|
|
|
|
|
|
|
|
| 127 |
<p align="center">
|
| 128 |
+
<img src="research_results/fig_10_per_class_metrics_comparison.png" alt="F1 Comparison" width="800" />
|
|
|
|
|
|
|
| 129 |
</p>
|
|
|
|
|
|
|
|
|
docs/research_results/fig_01_spectrum.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
docs/research_results/{fig_02_svd_confusion.png → fig_01_svd_confusion.png}
RENAMED
|
File without changes
|
docs/research_results/{fig_03_eigen_digits.png → fig_02_eigen_digits.png}
RENAMED
|
File without changes
|
docs/research_results/{fig_05_interpolation.png → fig_03_interpolation.png}
RENAMED
|
File without changes
|
docs/research_results/fig_04_cnn_confusion.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
docs/research_results/{fig_06_explainability.png → fig_04_explainability.png}
RENAMED
|
File without changes
|
docs/research_results/{fig_08_manifold_collapse.png → fig_05_manifold_collapse.png}
RENAMED
|
File without changes
|
docs/research_results/{fig_14_learning_curves.png → fig_06_robustness_mnist_gaussian.png}
RENAMED
|
File without changes
|
docs/research_results/fig_07_robustness_mnist_svd_aligned.png
ADDED
|
Git LFS Details
|
docs/research_results/fig_08_robustness_fashion.png
ADDED
|
Git LFS Details
|
docs/research_results/fig_09_learning_curves.png
ADDED
|
Git LFS Details
|
docs/research_results/fig_10_per_class_metrics_comparison.png
ADDED
|
Git LFS Details
|
docs/research_results/fig_19_per_class_metrics_comparison.png
DELETED
Git LFS Details
|
docs/research_results/fig_robustness_fashion.png
DELETED
Git LFS Details
|
docs/research_results/fig_robustness_mnist.png
DELETED
Git LFS Details
|
experiments/01_phenomenon_diagnosis.py
CHANGED
|
@@ -43,7 +43,7 @@ def run_svd_analysis(X_train, y_train, X_test, y_test):
|
|
| 43 |
# 3. Visualization: Confusion Matrix & Eigen-digits
|
| 44 |
viz.plot_confusion_matrix(
|
| 45 |
y_test, y_pred, list(range(10)),
|
| 46 |
-
'
|
| 47 |
f'SVD Confusion Matrix (Acc={acc:.2f})',
|
| 48 |
viz.COLOR_SVD
|
| 49 |
)
|
|
@@ -52,7 +52,7 @@ def run_svd_analysis(X_train, y_train, X_test, y_test):
|
|
| 52 |
viz.plot_multi_image_grid(
|
| 53 |
[c.reshape(28, 28) for c in svd_20.components_[:10]],
|
| 54 |
component_titles, 2, 5,
|
| 55 |
-
'
|
| 56 |
'Global SVD Eigen-digits'
|
| 57 |
)
|
| 58 |
|
|
|
|
| 43 |
# 3. Visualization: Confusion Matrix & Eigen-digits
|
| 44 |
viz.plot_confusion_matrix(
|
| 45 |
y_test, y_pred, list(range(10)),
|
| 46 |
+
'fig_01_svd_confusion.png',
|
| 47 |
f'SVD Confusion Matrix (Acc={acc:.2f})',
|
| 48 |
viz.COLOR_SVD
|
| 49 |
)
|
|
|
|
| 52 |
viz.plot_multi_image_grid(
|
| 53 |
[c.reshape(28, 28) for c in svd_20.components_[:10]],
|
| 54 |
component_titles, 2, 5,
|
| 55 |
+
'fig_02_eigen_digits.png',
|
| 56 |
'Global SVD Eigen-digits'
|
| 57 |
)
|
| 58 |
|
experiments/{02_mechanistic_analysis.py → 02_mechanistic_proof.py}
RENAMED
|
@@ -8,7 +8,7 @@ import torch
|
|
| 8 |
import torch.nn as nn
|
| 9 |
import numpy as np
|
| 10 |
from sklearn.decomposition import TruncatedSVD
|
| 11 |
-
from sklearn.neighbors import KNeighborsClassifier
|
| 12 |
from sklearn.metrics import accuracy_score
|
| 13 |
|
| 14 |
from src import config, utils, viz, exp_utils
|
|
@@ -31,30 +31,36 @@ def run_interpolation_analysis(device):
|
|
| 31 |
img_3, img_8 = X_test[idx_3], X_test[idx_8]
|
| 32 |
|
| 33 |
alphas = np.linspace(0, 1, 11)
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
for alpha in alphas:
|
| 37 |
img_interp = (1 - alpha) * img_3 + alpha * img_8
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
else:
|
| 49 |
-
p = torch.softmax(logits, dim=1)[0, 1].item()
|
| 50 |
-
probs_8.append(p)
|
| 51 |
|
| 52 |
# SVD Reconstruction Error
|
| 53 |
flat = img_interp.view(1, -1).numpy()
|
| 54 |
rec = svd.inverse_transform(svd.transform(flat - mean)) + mean
|
| 55 |
rec_errors.append(np.linalg.norm(flat - rec))
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
def run_quantifying_manifold_collapse():
|
| 60 |
print("\n--- Running Experiment 7: Quantifying Manifold Collapse ---")
|
|
@@ -83,7 +89,7 @@ def run_quantifying_manifold_collapse():
|
|
| 83 |
import umap
|
| 84 |
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
|
| 85 |
X_umap = reducer.fit_transform(X_test_np)
|
| 86 |
-
viz.plot_manifold_comparison(X_test_svd, X_umap, y_test_np, acc_svd, acc_raw, '
|
| 87 |
except Exception as e:
|
| 88 |
print(f"Warning: Manifold visualization failed: {e}")
|
| 89 |
|
|
|
|
| 8 |
import torch.nn as nn
|
| 9 |
import numpy as np
|
| 10 |
from sklearn.decomposition import TruncatedSVD
|
| 11 |
+
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
|
| 12 |
from sklearn.metrics import accuracy_score
|
| 13 |
|
| 14 |
from src import config, utils, viz, exp_utils
|
|
|
|
| 31 |
img_3, img_8 = X_test[idx_3], X_test[idx_8]
|
| 32 |
|
| 33 |
alphas = np.linspace(0, 1, 11)
|
| 34 |
+
probs_dict = {f'T={t:.1f}': [] for t in [1.0, 2.0, 5.0]}
|
| 35 |
+
rec_errors, manifold_dists = [], []
|
| 36 |
+
|
| 37 |
+
# Fit Nearest Neighbors on Training set to measure distance to manifold
|
| 38 |
+
X_train, y_train = utils.load_data_split(dataset_name="mnist", train=True, digits=[3, 8], flatten=True)
|
| 39 |
+
nn_manifold = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(X_train.numpy())
|
| 40 |
|
| 41 |
for alpha in alphas:
|
| 42 |
img_interp = (1 - alpha) * img_3 + alpha * img_8
|
| 43 |
+
for t_str, p_list in probs_dict.items():
|
| 44 |
+
temp = float(t_str.split('=')[1])
|
| 45 |
+
with torch.no_grad():
|
| 46 |
+
logits = cnn(img_interp.unsqueeze(0).to(device))
|
| 47 |
+
out_dim = logits.shape[1]
|
| 48 |
+
if out_dim == 10:
|
| 49 |
+
p = torch.softmax(logits / temp, dim=1)[0, 8].item()
|
| 50 |
+
else:
|
| 51 |
+
p = torch.softmax(logits / temp, dim=1)[0, 1].item()
|
| 52 |
+
p_list.append(p)
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# SVD Reconstruction Error
|
| 55 |
flat = img_interp.view(1, -1).numpy()
|
| 56 |
rec = svd.inverse_transform(svd.transform(flat - mean)) + mean
|
| 57 |
rec_errors.append(np.linalg.norm(flat - rec))
|
| 58 |
|
| 59 |
+
# Distance to Real Manifold (784D)
|
| 60 |
+
dist, _ = nn_manifold.kneighbors(flat)
|
| 61 |
+
manifold_dists.append(dist[0][0])
|
| 62 |
+
|
| 63 |
+
viz.plot_interpolation_dynamics(alphas, probs_dict, rec_errors, 'fig_03_interpolation.png', manifold_distances=manifold_dists)
|
| 64 |
|
| 65 |
def run_quantifying_manifold_collapse():
|
| 66 |
print("\n--- Running Experiment 7: Quantifying Manifold Collapse ---")
|
|
|
|
| 89 |
import umap
|
| 90 |
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
|
| 91 |
X_umap = reducer.fit_transform(X_test_np)
|
| 92 |
+
viz.plot_manifold_comparison(X_test_svd, X_umap, y_test_np, acc_svd, acc_raw, 'fig_05_manifold_collapse.png')
|
| 93 |
except Exception as e:
|
| 94 |
print(f"Warning: Manifold visualization failed: {e}")
|
| 95 |
|
experiments/{run_robustness_test.py β 03_operational_boundaries.py}
RENAMED
|
@@ -15,6 +15,7 @@ def run_experiment(args):
|
|
| 15 |
print(f"\n--- Running Robustness Test: {args.dataset.upper()} ---")
|
| 16 |
|
| 17 |
# 1. Load Data and Models
|
|
|
|
| 18 |
X_test, y_test = utils.load_data_split(dataset_name=args.dataset, train=False)
|
| 19 |
_, cnn = utils.load_models(dataset_name=args.dataset)
|
| 20 |
|
|
@@ -22,9 +23,9 @@ def run_experiment(args):
|
|
| 22 |
return
|
| 23 |
|
| 24 |
# 2. Fit SVD Baseline and Build Hybrid Model
|
| 25 |
-
print("Fitting SVD Baseline...")
|
| 26 |
-
|
| 27 |
-
svd_pipe = exp_utils.fit_svd_baseline(
|
| 28 |
|
| 29 |
svd = svd_pipe.named_steps['svd']
|
| 30 |
scaler = svd_pipe.named_steps['scaler']
|
|
@@ -37,8 +38,14 @@ def run_experiment(args):
|
|
| 37 |
results = {'CNN': [], 'SVD': [], 'Hybrid': []}
|
| 38 |
|
| 39 |
# 4. Evaluation Loop
|
|
|
|
|
|
|
|
|
|
| 40 |
for sigma in sigmas:
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
results['CNN'].append(exp_utils.evaluate_classifier(cnn, X_noisy, y_test, device))
|
| 44 |
results['SVD'].append(exp_utils.evaluate_classifier(svd_pipe, X_noisy, y_test, is_pytorch=False))
|
|
@@ -47,17 +54,26 @@ def run_experiment(args):
|
|
| 47 |
print(f"Ο={sigma:.1f} | CNN: {results['CNN'][-1]:.4f} | SVD: {results['SVD'][-1]:.4f} | Hybrid: {results['Hybrid'][-1]:.4f}")
|
| 48 |
|
| 49 |
# 5. Visualization
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
viz.plot_robustness_curves(
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
title=f'Robustness Analysis: {args.dataset.upper()}',
|
| 55 |
-
filename=f'fig_robustness_{args.dataset}.png'
|
| 56 |
)
|
| 57 |
|
| 58 |
def main():
|
| 59 |
parser = argparse.ArgumentParser(description="Unified Robustness Evaluation")
|
| 60 |
parser.add_argument("--dataset", choices=["mnist", "fashion"], default="mnist", help="Dataset to evaluate.")
|
|
|
|
| 61 |
args = parser.parse_args()
|
| 62 |
run_experiment(args)
|
| 63 |
|
|
|
|
| 15 |
print(f"\n--- Running Robustness Test: {args.dataset.upper()} ---")
|
| 16 |
|
| 17 |
# 1. Load Data and Models
|
| 18 |
+
X_train, y_train = utils.load_data_split(dataset_name=args.dataset, train=True)
|
| 19 |
X_test, y_test = utils.load_data_split(dataset_name=args.dataset, train=False)
|
| 20 |
_, cnn = utils.load_models(dataset_name=args.dataset)
|
| 21 |
|
|
|
|
| 23 |
return
|
| 24 |
|
| 25 |
# 2. Fit SVD Baseline and Build Hybrid Model
|
| 26 |
+
print("Fitting SVD Baseline on Training Data...")
|
| 27 |
+
X_train_flat = X_train.view(X_train.size(0), -1).numpy()
|
| 28 |
+
svd_pipe = exp_utils.fit_svd_baseline(X_train_flat, y_train.numpy(), n_components=20)
|
| 29 |
|
| 30 |
svd = svd_pipe.named_steps['svd']
|
| 31 |
scaler = svd_pipe.named_steps['scaler']
|
|
|
|
| 38 |
results = {'CNN': [], 'SVD': [], 'Hybrid': []}
|
| 39 |
|
| 40 |
# 4. Evaluation Loop
|
| 41 |
+
noise_label = 'Gaussian' if args.noise_type == 'gaussian' else 'SVD-Aligned'
|
| 42 |
+
print(f"Noise Type: {noise_label}")
|
| 43 |
+
|
| 44 |
for sigma in sigmas:
|
| 45 |
+
if args.noise_type == "svd_aligned":
|
| 46 |
+
X_noisy = exp_utils.add_svd_aligned_noise(X_test, sigma, svd.components_)
|
| 47 |
+
else:
|
| 48 |
+
X_noisy = exp_utils.add_gaussian_noise(X_test, sigma)
|
| 49 |
|
| 50 |
results['CNN'].append(exp_utils.evaluate_classifier(cnn, X_noisy, y_test, device))
|
| 51 |
results['SVD'].append(exp_utils.evaluate_classifier(svd_pipe, X_noisy, y_test, is_pytorch=False))
|
|
|
|
| 54 |
print(f"Ο={sigma:.1f} | CNN: {results['CNN'][-1]:.4f} | SVD: {results['SVD'][-1]:.4f} | Hybrid: {results['Hybrid'][-1]:.4f}")
|
| 55 |
|
| 56 |
# 5. Visualization
|
| 57 |
+
# Map to new sequential names
|
| 58 |
+
if args.dataset == "mnist" and args.noise_type == "gaussian":
|
| 59 |
+
filename = "fig_06_robustness_mnist_gaussian.png"
|
| 60 |
+
elif args.dataset == "mnist" and args.noise_type == "svd_aligned":
|
| 61 |
+
filename = "fig_07_robustness_mnist_svd_aligned.png"
|
| 62 |
+
elif args.dataset == "fashion":
|
| 63 |
+
filename = "fig_08_robustness_fashion.png"
|
| 64 |
+
else:
|
| 65 |
+
filename = f'fig_robustness_{args.dataset}_{args.noise_type}.png'
|
| 66 |
+
|
| 67 |
viz.plot_robustness_curves(
|
| 68 |
+
sigmas, results, f'{noise_label} Noise Level (Ο)',
|
| 69 |
+
f'Robustness Analysis ({noise_label}): {args.dataset.upper()}',
|
| 70 |
+
filename
|
|
|
|
|
|
|
| 71 |
)
|
| 72 |
|
| 73 |
def main():
|
| 74 |
parser = argparse.ArgumentParser(description="Unified Robustness Evaluation")
|
| 75 |
parser.add_argument("--dataset", choices=["mnist", "fashion"], default="mnist", help="Dataset to evaluate.")
|
| 76 |
+
parser.add_argument("--noise_type", choices=["gaussian", "svd_aligned"], default="gaussian", help="Type of noise to apply.")
|
| 77 |
args = parser.parse_args()
|
| 78 |
run_experiment(args)
|
| 79 |
|
experiments/{appendix_learning_curves.py β 04_appendix_learning_curves.py}
RENAMED
|
@@ -9,7 +9,7 @@ from src import config, viz
|
|
| 9 |
|
| 10 |
def main():
|
| 11 |
experiments = [
|
| 12 |
-
('cnn_10class_history.pkl', 'MNIST 10-class CNN Training', '
|
| 13 |
('cnn_fashion_history.pkl', 'Fashion-MNIST CNN Training', 'fig_15_learning_curves_fashion.png')
|
| 14 |
]
|
| 15 |
|
|
|
|
| 9 |
|
| 10 |
def main():
|
| 11 |
experiments = [
|
| 12 |
+
('cnn_10class_history.pkl', 'MNIST 10-class CNN Training', 'fig_09_learning_curves.png'),
|
| 13 |
('cnn_fashion_history.pkl', 'Fashion-MNIST CNN Training', 'fig_15_learning_curves_fashion.png')
|
| 14 |
]
|
| 15 |
|
experiments/{appendix_per_class_metrics.py β 05_appendix_per_class_metrics.py}
RENAMED
|
@@ -30,7 +30,10 @@ def main():
|
|
| 30 |
y_preds_dict['CNN'] = cnn(X_test.to(device)).argmax(dim=1).cpu().numpy()
|
| 31 |
|
| 32 |
# SVD+LR Predictions
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# 2. Print Metrics Report
|
| 36 |
from sklearn.metrics import recall_score, precision_score, f1_score
|
|
@@ -45,7 +48,7 @@ def main():
|
|
| 45 |
viz.plot_per_class_comparison(
|
| 46 |
y_test_np,
|
| 47 |
y_preds_dict,
|
| 48 |
-
'
|
| 49 |
)
|
| 50 |
print("Appendix B Completed.")
|
| 51 |
|
|
|
|
| 30 |
y_preds_dict['CNN'] = cnn(X_test.to(device)).argmax(dim=1).cpu().numpy()
|
| 31 |
|
| 32 |
# SVD+LR Predictions
|
| 33 |
+
print("Fitting SVD Baseline (10-class)...")
|
| 34 |
+
X_train_full, y_train_full = utils.load_data_split(dataset_name="mnist", train=True, flatten=True)
|
| 35 |
+
svd_pipe_fitted = exp_utils.fit_svd_baseline(X_train_full.numpy(), y_train_full.numpy(), n_components=20)
|
| 36 |
+
y_preds_dict['SVD+LR'] = svd_pipe_fitted.predict(X_test_flat)
|
| 37 |
|
| 38 |
# 2. Print Metrics Report
|
| 39 |
from sklearn.metrics import recall_score, precision_score, f1_score
|
|
|
|
| 48 |
viz.plot_per_class_comparison(
|
| 49 |
y_test_np,
|
| 50 |
y_preds_dict,
|
| 51 |
+
'fig_10_per_class_metrics_comparison.png'
|
| 52 |
)
|
| 53 |
print("Appendix B Completed.")
|
| 54 |
|
run_all_experiments.sh
CHANGED
|
@@ -4,16 +4,16 @@
|
|
| 4 |
set -e
|
| 5 |
|
| 6 |
echo "=== 1. Phenomenon Diagnosis (Global SVD & CNN Baseline) ==="
|
| 7 |
-
python experiments
|
| 8 |
|
| 9 |
echo "=== 2. Mechanistic Analysis (Interpolation, Explainability, etc.) ==="
|
| 10 |
-
python experiments
|
| 11 |
|
| 12 |
echo "=== 3. Robustness and Boundary Tests (MNIST) ==="
|
| 13 |
-
python experiments
|
| 14 |
|
| 15 |
echo "=== 4. Robustness and Boundary Tests (Fashion-MNIST) ==="
|
| 16 |
-
python experiments
|
| 17 |
|
| 18 |
echo "=========================================================="
|
| 19 |
echo "All experiments completed successfully."
|
|
|
|
| 4 |
set -e
|
| 5 |
|
| 6 |
echo "=== 1. Phenomenon Diagnosis (Global SVD & CNN Baseline) ==="
|
| 7 |
+
python -m experiments.01_phenomenon_diagnosis
|
| 8 |
|
| 9 |
echo "=== 2. Mechanistic Analysis (Interpolation, Explainability, etc.) ==="
|
| 10 |
+
python -m experiments.02_mechanistic_proof
|
| 11 |
|
| 12 |
echo "=== 3. Robustness and Boundary Tests (MNIST) ==="
|
| 13 |
+
python -m experiments.03_operational_boundaries --dataset mnist
|
| 14 |
|
| 15 |
echo "=== 4. Robustness and Boundary Tests (Fashion-MNIST) ==="
|
| 16 |
+
python -m experiments.03_operational_boundaries --dataset fashion
|
| 17 |
|
| 18 |
echo "=========================================================="
|
| 19 |
echo "All experiments completed successfully."
|
run_migration.sh
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# This script performs the renaming of scripts and figures, and updates references in the code and report.
|
| 4 |
+
# Run this from the project root: /Users/ymlin/Downloads/003-Study/137-Projects/01-mnist-linear-vs-nonlinear
|
| 5 |
+
|
| 6 |
+
echo "Starting migration..."
|
| 7 |
+
|
| 8 |
+
# 1. Rename Scripts
|
| 9 |
+
echo "Renaming scripts..."
|
| 10 |
+
mv experiments/01_exp_diagnosis.py experiments/01_phenomenon_diagnosis.py
|
| 11 |
+
mv experiments/02_mechanistic_analysis.py experiments/02_mechanistic_proof.py
|
| 12 |
+
mv experiments/run_robustness_test.py experiments/03_operational_boundaries.py
|
| 13 |
+
mv experiments/appendix_learning_curves.py experiments/04_appendix_learning_curves.py
|
| 14 |
+
mv experiments/appendix_per_class_metrics.py experiments/05_appendix_per_class_metrics.py
|
| 15 |
+
|
| 16 |
+
# 2. Rename Figures
|
| 17 |
+
echo "Renaming figures..."
|
| 18 |
+
cd docs/research_results || exit
|
| 19 |
+
mv fig_02_svd_confusion.png fig_01_svd_confusion.png
|
| 20 |
+
mv fig_03_eigen_digits.png fig_02_eigen_digits.png
|
| 21 |
+
mv fig_05_interpolation.png fig_03_interpolation.png
|
| 22 |
+
mv fig_06_explainability.png fig_04_explainability.png
|
| 23 |
+
mv fig_08_manifold_collapse.png fig_05_manifold_collapse.png
|
| 24 |
+
mv fig_robustness_mnist_gaussian.png fig_06_robustness_mnist_gaussian.png
|
| 25 |
+
mv fig_robustness_mnist_svd_aligned.png fig_07_robustness_mnist_svd_aligned.png
|
| 26 |
+
mv fig_robustness_fashion.png fig_08_robustness_fashion.png
|
| 27 |
+
mv fig_14_learning_curves.png fig_09_learning_curves.png
|
| 28 |
+
mv fig_19_per_class_metrics_comparison.png fig_10_per_class_metrics_comparison.png
|
| 29 |
+
cd ../..
|
| 30 |
+
|
| 31 |
+
# 3. Update Python Scripts (Using sed for macOS)
|
| 32 |
+
echo "Updating Python scripts..."
|
| 33 |
+
|
| 34 |
+
# 01_phenomenon_diagnosis.py
|
| 35 |
+
sed -i '' 's/fig_02_svd_confusion.png/fig_01_svd_confusion.png/g' experiments/01_phenomenon_diagnosis.py
|
| 36 |
+
sed -i '' 's/fig_03_eigen_digits.png/fig_02_eigen_digits.png/g' experiments/01_phenomenon_diagnosis.py
|
| 37 |
+
sed -i '' 's/fig_04_cnn_confusion.png/fig_01b_cnn_confusion.png/g' experiments/01_phenomenon_diagnosis.py
|
| 38 |
+
|
| 39 |
+
# 02_mechanistic_proof.py
|
| 40 |
+
sed -i '' 's/fig_05_interpolation.png/fig_03_interpolation.png/g' experiments/02_mechanistic_proof.py
|
| 41 |
+
sed -i '' 's/fig_06_explainability.png/fig_04_explainability.png/g' experiments/02_mechanistic_proof.py
|
| 42 |
+
sed -i '' 's/fig_08_manifold_collapse.png/fig_05_manifold_collapse.png/g' experiments/02_mechanistic_proof.py
|
| 43 |
+
|
| 44 |
+
# 03_operational_boundaries.py
|
| 45 |
+
sed -i '' 's/fig_robustness_mnist_gaussian.png/fig_06_robustness_mnist_gaussian.png/g' experiments/03_operational_boundaries.py
|
| 46 |
+
sed -i '' 's/fig_robustness_mnist_svd_aligned.png/fig_07_robustness_mnist_svd_aligned.png/g' experiments/03_operational_boundaries.py
|
| 47 |
+
sed -i '' 's/fig_robustness_fashion.png/fig_08_robustness_fashion.png/g' experiments/03_operational_boundaries.py
|
| 48 |
+
|
| 49 |
+
# 04_appendix_learning_curves.py
|
| 50 |
+
sed -i '' 's/fig_14_learning_curves.png/fig_09_learning_curves.png/g' experiments/04_appendix_learning_curves.py
|
| 51 |
+
|
| 52 |
+
# 05_appendix_per_class_metrics.py
|
| 53 |
+
sed -i '' 's/fig_19_per_class_metrics_comparison.png/fig_10_per_class_metrics_comparison.png/g' experiments/05_appendix_per_class_metrics.py
|
| 54 |
+
|
| 55 |
+
# 4. Update Report (Using sed for macOS)
|
| 56 |
+
echo "Updating REPORT.md..."
|
| 57 |
+
sed -i '' 's/fig_02_svd_confusion.png/fig_01_svd_confusion.png/g' docs/REPORT.md
|
| 58 |
+
sed -i '' 's/fig_03_eigen_digits.png/fig_02_eigen_digits.png/g' docs/REPORT.md
|
| 59 |
+
sed -i '' 's/fig_05_interpolation.png/fig_03_interpolation.png/g' docs/REPORT.md
|
| 60 |
+
sed -i '' 's/fig_06_explainability.png/fig_04_explainability.png/g' docs/REPORT.md
|
| 61 |
+
sed -i '' 's/fig_08_manifold_collapse.png/fig_05_manifold_collapse.png/g' docs/REPORT.md
|
| 62 |
+
sed -i '' 's/fig_robustness_mnist_gaussian.png/fig_06_robustness_mnist_gaussian.png/g' docs/REPORT.md
|
| 63 |
+
sed -i '' 's/fig_robustness_mnist_svd_aligned.png/fig_07_robustness_mnist_svd_aligned.png/g' docs/REPORT.md
|
| 64 |
+
sed -i '' 's/fig_robustness_fashion.png/fig_08_robustness_fashion.png/g' docs/REPORT.md
|
| 65 |
+
sed -i '' 's/fig_14_learning_curves.png/fig_09_learning_curves.png/g' docs/REPORT.md
|
| 66 |
+
sed -i '' 's/fig_19_per_class_metrics_comparison.png/fig_10_per_class_metrics_comparison.png/g' docs/REPORT.md
|
| 67 |
+
|
| 68 |
+
echo "Migration completed successfully!"
|
src/exp_utils.py
CHANGED
|
@@ -10,7 +10,7 @@ from sklearn.preprocessing import StandardScaler
|
|
| 10 |
def fit_svd_baseline(X_train, y_train, n_components=20):
|
| 11 |
"""Fits a linear baseline (SVD + Logistic Regression) on the fly."""
|
| 12 |
pipeline = Pipeline([
|
| 13 |
-
('scaler', StandardScaler()),
|
| 14 |
('svd', TruncatedSVD(n_components=n_components, random_state=42)),
|
| 15 |
('logistic', LogisticRegression(max_iter=1000))
|
| 16 |
])
|
|
@@ -30,6 +30,40 @@ def add_gaussian_noise(X, sigma):
|
|
| 30 |
noise = np.random.randn(*X.shape) * sigma
|
| 31 |
return np.clip(X + noise, 0, 1)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def add_blur(X, kernel_size):
|
| 34 |
"""Unified blur for torch Tensors (4D: B, C, H, W)."""
|
| 35 |
if kernel_size <= 1:
|
|
|
|
| 10 |
def fit_svd_baseline(X_train, y_train, n_components=20):
|
| 11 |
"""Fits a linear baseline (SVD + Logistic Regression) on the fly."""
|
| 12 |
pipeline = Pipeline([
|
| 13 |
+
('scaler', StandardScaler(with_std=False)),
|
| 14 |
('svd', TruncatedSVD(n_components=n_components, random_state=42)),
|
| 15 |
('logistic', LogisticRegression(max_iter=1000))
|
| 16 |
])
|
|
|
|
| 30 |
noise = np.random.randn(*X.shape) * sigma
|
| 31 |
return np.clip(X + noise, 0, 1)
|
| 32 |
|
| 33 |
+
def add_svd_aligned_noise(X, sigma, components):
|
| 34 |
+
"""
|
| 35 |
+
Adds noise that is projected onto the SVD components, living entirely
|
| 36 |
+
within the 'signal' subspace.
|
| 37 |
+
"""
|
| 38 |
+
if sigma <= 0: return X
|
| 39 |
+
is_tensor = torch.is_tensor(X)
|
| 40 |
+
|
| 41 |
+
# Flatten if needed
|
| 42 |
+
orig_shape = list(X.shape)
|
| 43 |
+
if is_tensor:
|
| 44 |
+
X_flat = X.cpu().numpy().reshape(orig_shape[0], -1)
|
| 45 |
+
components_np = components.cpu().numpy() if torch.is_tensor(components) else components
|
| 46 |
+
else:
|
| 47 |
+
X_flat = X.reshape(orig_shape[0], -1)
|
| 48 |
+
components_np = components
|
| 49 |
+
|
| 50 |
+
# 1. Generate random Gaussian noise in full dimensionality
|
| 51 |
+
noise = np.random.randn(*X_flat.shape) * sigma
|
| 52 |
+
|
| 53 |
+
# 2. Project noise onto components (V_k)
|
| 54 |
+
# V_k (components_np) is assumed to be (k, 784)
|
| 55 |
+
# Projection P = V_k^T @ V_k
|
| 56 |
+
projected_noise = (noise @ components_np.T) @ components_np
|
| 57 |
+
|
| 58 |
+
# 3. Add back and clip
|
| 59 |
+
X_noisy = X_flat + projected_noise
|
| 60 |
+
X_noisy = np.clip(X_noisy, 0, 1)
|
| 61 |
+
|
| 62 |
+
if is_tensor:
|
| 63 |
+
return torch.from_numpy(X_noisy).float().view(orig_shape)
|
| 64 |
+
else:
|
| 65 |
+
return X_noisy.reshape(orig_shape)
|
| 66 |
+
|
| 67 |
def add_blur(X, kernel_size):
|
| 68 |
"""Unified blur for torch Tensors (4D: B, C, H, W)."""
|
| 69 |
if kernel_size <= 1:
|
src/hybrid_model.py
CHANGED
|
@@ -29,6 +29,11 @@ class SVDProjectionLayer(nn.Module):
|
|
| 29 |
def forward(self, x):
|
| 30 |
b = x.size(0)
|
| 31 |
x_rec = (x.view(b, -1) - self.mean) @ self.V_k.T @ self.V_k + self.mean
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
return torch.clamp(x_rec, 0, 1).view(b, 1, 28, 28)
|
| 33 |
|
| 34 |
class HybridSVDCNN(nn.Module):
|
|
|
|
| 29 |
def forward(self, x):
|
| 30 |
b = x.size(0)
|
| 31 |
x_rec = (x.view(b, -1) - self.mean) @ self.V_k.T @ self.V_k + self.mean
|
| 32 |
+
|
| 33 |
+
# Monitor truncation ratio (percentage of pixels outside [0, 1])
|
| 34 |
+
out_of_range = (x_rec < 0) | (x_rec > 1)
|
| 35 |
+
self.last_clamp_ratio = out_of_range.float().mean().item()
|
| 36 |
+
|
| 37 |
return torch.clamp(x_rec, 0, 1).view(b, 1, 28, 28)
|
| 38 |
|
| 39 |
class HybridSVDCNN(nn.Module):
|
src/viz.py
CHANGED
|
@@ -81,20 +81,36 @@ def plot_singular_spectrum(singular_values, cumulative_variance, filename):
|
|
| 81 |
fig.legend(loc="upper right", bbox_to_anchor=(1,1), bbox_transform=ax1.transAxes)
|
| 82 |
save_fig(filename)
|
| 83 |
|
| 84 |
-
def plot_interpolation_dynamics(alphas,
|
| 85 |
-
"""Visualizes the CNN response vs SVD reconstruction error
|
| 86 |
setup_style()
|
| 87 |
-
plt.
|
| 88 |
-
|
| 89 |
-
plt.plot(alphas, probs_8, color=COLOR_CNN, label='CNN Prob(8) [Topology]', marker='o', linewidth=2)
|
| 90 |
-
plt.plot(alphas, rec_errors, color=COLOR_SVD, label='SVD Rec Error [Global Variance]', marker='s', linewidth=2)
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
save_fig(filename)
|
| 99 |
|
| 100 |
def plot_manifold_comparison(X_svd, X_umap, y, acc_svd, acc_raw, filename):
|
|
@@ -126,7 +142,7 @@ def plot_manifold_comparison(X_svd, X_umap, y, acc_svd, acc_raw, filename):
|
|
| 126 |
def plot_learning_curves(history, title, filename):
|
| 127 |
"""Standardized plotter for training history (loss and accuracy)."""
|
| 128 |
setup_style()
|
| 129 |
-
epochs
|
| 130 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
| 131 |
|
| 132 |
# Nord palette for curves
|
|
@@ -134,8 +150,10 @@ def plot_learning_curves(history, title, filename):
|
|
| 134 |
COLOR_VAL = "#D08770" # Nord 12 (Orange)
|
| 135 |
|
| 136 |
# Loss Plot
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
| 139 |
ax1.set_title('Loss Dynamics', fontsize=12, fontweight='bold')
|
| 140 |
ax1.set_xlabel('Epoch')
|
| 141 |
ax1.set_ylabel('Loss')
|
|
@@ -143,8 +161,10 @@ def plot_learning_curves(history, title, filename):
|
|
| 143 |
ax1.grid(True)
|
| 144 |
|
| 145 |
# Accuracy Plot
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
| 148 |
ax2.set_title('Accuracy Dynamics', fontsize=12, fontweight='bold')
|
| 149 |
ax2.set_xlabel('Epoch')
|
| 150 |
ax2.set_ylabel('Accuracy')
|
|
|
|
| 81 |
fig.legend(loc="upper right", bbox_to_anchor=(1,1), bbox_transform=ax1.transAxes)
|
| 82 |
save_fig(filename)
|
| 83 |
|
| 84 |
+
def plot_interpolation_dynamics(alphas, probs_dict, rec_errors, filename, manifold_distances=None):
|
| 85 |
+
"""Visualizes the CNN response vs SVD reconstruction error and manifold distance."""
|
| 86 |
setup_style()
|
| 87 |
+
fig, ax1 = plt.subplots(figsize=(10, 6))
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
# Support both single list and dict of labels->probs
|
| 90 |
+
if isinstance(probs_dict, list):
|
| 91 |
+
probs_dict = {'CNN Prob(8)': probs_dict}
|
| 92 |
+
|
| 93 |
+
styles = ['-', '--', ':', '-.']
|
| 94 |
+
for i, (label, probs) in enumerate(probs_dict.items()):
|
| 95 |
+
ax1.plot(alphas, probs, label=label, marker='o' if i==0 else None,
|
| 96 |
+
linestyle=styles[i % len(styles)], linewidth=2)
|
| 97 |
+
|
| 98 |
+
ax1.plot(alphas, rec_errors, color=COLOR_SVD, label='SVD Rec Error', marker='s', linewidth=2, alpha=0.6)
|
| 99 |
+
ax1.set_xlabel('Alpha (0=Digit 3, 1=Digit 8)', fontsize=12)
|
| 100 |
+
ax1.set_ylabel('CNN Prob / Rec Error', fontsize=12)
|
| 101 |
+
|
| 102 |
+
if manifold_distances is not None:
|
| 103 |
+
ax2 = ax1.twinx()
|
| 104 |
+
ax2.plot(alphas, manifold_distances, color="#D08770", label='Manifold Distance', marker='^', linestyle='--', linewidth=2)
|
| 105 |
+
ax2.set_ylabel('Dist to Nearest Neighbor', color="#D08770", fontsize=12)
|
| 106 |
+
ax2.tick_params(axis='y', labelcolor="#D08770")
|
| 107 |
+
fig.legend(loc="upper right", bbox_to_anchor=(0.9, 0.9), bbox_transform=ax1.transAxes)
|
| 108 |
+
else:
|
| 109 |
+
ax1.legend()
|
| 110 |
+
|
| 111 |
+
plt.title('Mechanistic Dynamics: Snap Analysis with Temperature Scaling', fontsize=14, fontweight='bold', pad=15)
|
| 112 |
+
ax1.axvline(x=0.5, color='#4C566A', linestyle='--', alpha=0.5, label='Ambiguity Mid-point')
|
| 113 |
+
ax1.grid(True)
|
| 114 |
save_fig(filename)
|
| 115 |
|
| 116 |
def plot_manifold_comparison(X_svd, X_umap, y, acc_svd, acc_raw, filename):
|
|
|
|
| 142 |
def plot_learning_curves(history, title, filename):
|
| 143 |
"""Standardized plotter for training history (loss and accuracy)."""
|
| 144 |
setup_style()
|
| 145 |
+
# Calculate epochs separately for loss and accuracy
|
| 146 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
| 147 |
|
| 148 |
# Nord palette for curves
|
|
|
|
| 150 |
COLOR_VAL = "#D08770" # Nord 12 (Orange)
|
| 151 |
|
| 152 |
# Loss Plot
|
| 153 |
+
if len(history.get('train_loss', [])) > 0:
|
| 154 |
+
epochs_loss = range(1, len(history['train_loss']) + 1)
|
| 155 |
+
ax1.plot(epochs_loss, history['train_loss'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
|
| 156 |
+
ax1.plot(epochs_loss, history['val_loss'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
|
| 157 |
ax1.set_title('Loss Dynamics', fontsize=12, fontweight='bold')
|
| 158 |
ax1.set_xlabel('Epoch')
|
| 159 |
ax1.set_ylabel('Loss')
|
|
|
|
| 161 |
ax1.grid(True)
|
| 162 |
|
| 163 |
# Accuracy Plot
|
| 164 |
+
if len(history.get('train_acc', [])) > 0:
|
| 165 |
+
epochs_acc = range(1, len(history['train_acc']) + 1)
|
| 166 |
+
ax2.plot(epochs_acc, history['train_acc'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
|
| 167 |
+
ax2.plot(epochs_acc, history['val_acc'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
|
| 168 |
ax2.set_title('Accuracy Dynamics', fontsize=12, fontweight='bold')
|
| 169 |
ax2.set_xlabel('Epoch')
|
| 170 |
ax2.set_ylabel('Accuracy')
|