Spaces:
Sleeping
Sleeping
refactor: optimize project structure and documentation
Browse files- .gitattributes +0 -5
- README.md +9 -6
- app.py +157 -42
- docs/REPORT.md +142 -67
- docs/research_results/fig_01_spectrum.png +2 -2
- docs/research_results/fig_01_svd_confusion.png +2 -2
- docs/research_results/fig_02_eigen_digits.png +2 -2
- docs/research_results/fig_03_interpolation.png +2 -2
- docs/research_results/fig_04_cnn_confusion.png +0 -3
- docs/research_results/fig_05_manifold_collapse.png +2 -2
- docs/research_results/fig_06_robustness_mnist_gaussian.png +2 -2
- docs/research_results/fig_07_robustness_mnist_svd_aligned.png +0 -3
- docs/research_results/fig_08_robustness_fashion.png +2 -2
- docs/research_results/fig_09_learning_curves.png +0 -3
- docs/research_results/fig_10_per_class_metrics_comparison.png +0 -3
- experiments/03_operational_boundaries.py +2 -2
- experiments/04_appendix_learning_curves.py +0 -26
- experiments/05_appendix_per_class_metrics.py +0 -56
- run_migration.sh +0 -68
- src/viz.py +20 -11
.gitattributes
CHANGED
|
@@ -1,6 +1 @@
|
|
| 1 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.png filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
| 1 |
*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -14,13 +14,13 @@ pinned: false
|
|
| 14 |
|
| 15 |
[](https://huggingface.co/spaces/ymlin105/Coconut-MNIST) [](./docs/REPORT.md)
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
<p align="center">
|
| 20 |
<img src="./docs/research_results/fig_04_explainability.png" width="600" alt="Mechanistic Analysis: SVD Blind Spot">
|
| 21 |
</p>
|
| 22 |
|
| 23 |
-
|
| 24 |
|
| 25 |
## The Solution: Hybrid SVD-CNN
|
| 26 |
|
|
@@ -64,10 +64,13 @@ flowchart TD
|
|
| 64 |
### Key Takeaways
|
| 65 |
For full analysis and detailed metrics, see the [Technical Report](./docs/REPORT.md).
|
| 66 |
|
| 67 |
-
1. **The Variance Trap**:
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
---
|
| 73 |
|
|
|
|
| 14 |
|
| 15 |
[](https://huggingface.co/spaces/ymlin105/Coconut-MNIST) [](./docs/REPORT.md)
|
| 16 |
|
| 17 |
+
This project investigates **why SVD systematically misclassifies digit 3 as 8**, revealing fundamental differences between linear (variance-based) and non-linear (topology-based) representations. Through mechanistic analysis and empirical validation, we show that SVD and CNN optimize different objectives, leading to **complementary strengths and failure modes**.
|
| 18 |
|
| 19 |
<p align="center">
|
| 20 |
<img src="./docs/research_results/fig_04_explainability.png" width="600" alt="Mechanistic Analysis: SVD Blind Spot">
|
| 21 |
</p>
|
| 22 |
|
| 23 |
+
**Key Finding**: SVD's low-pass filtering property provides complementary benefits under realistic noise conditions (σ ∈ [0, 0.3]), but becomes destructive on texture-rich data. Methods succeed in different regimes based on their optimization objectives, not universally.
|
| 24 |
|
| 25 |
## The Solution: Hybrid SVD-CNN
|
| 26 |
|
|
|
|
| 64 |
### Key Takeaways
|
| 65 |
For full analysis and detailed metrics, see the [Technical Report](./docs/REPORT.md).
|
| 66 |
|
| 67 |
+
1. **The Variance Trap**: SVD's optimization for global pixel variance treats the topological gap distinguishing 3 from 8 as low-variance noise, discarding it during dimensionality reduction. This causes systematic manifold collapse (98.74% k-NN raw pixels → 96.98% in SVD subspace).
|
| 68 |
+
|
| 69 |
+
2. **Mechanistic Proof**: Grad-CAM visualization shows CNN focuses on topological boundaries (the gap), while SVD reconstructs phantom features (a closed loop). UMAP analysis confirms manifold overlap in SVD subspace but separation in raw pixel space.
|
| 70 |
+
|
| 71 |
+
3. **Complementary Strength**: Under realistic Gaussian noise (σ ∈ [0, 0.3]), Hybrid SVD→CNN maintains 90.02% accuracy at σ=0.3 while CNN drops to 95.67%, validating SVD as an adaptive low-pass filter that enables CNN to learn from cleaner input.
|
| 72 |
+
|
| 73 |
+
4. **Data-Dependent Boundary**: On texture-rich Fashion-MNIST, the hybrid approach fails (CNN 89.79% → Hybrid 71.78%) because SVD destroys high-frequency features that distinguish clothing items, proving complementarity requires silhouette-based structure.
|
| 74 |
|
| 75 |
---
|
| 76 |
|
app.py
CHANGED
|
@@ -24,43 +24,104 @@ st.set_page_config(
|
|
| 24 |
initial_sidebar_state="expanded"
|
| 25 |
)
|
| 26 |
|
| 27 |
-
# --- Custom CSS for Clean
|
| 28 |
st.markdown("""
|
| 29 |
<style>
|
| 30 |
-
/*
|
| 31 |
.block-container {
|
| 32 |
-
max-width:
|
| 33 |
-
padding-top:
|
| 34 |
padding-bottom: 5rem;
|
| 35 |
margin: 0 auto;
|
| 36 |
}
|
| 37 |
-
|
| 38 |
/* Premium Typography */
|
| 39 |
-
h1, h2, h3 {
|
| 40 |
-
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
|
| 41 |
-
font-weight: 700;
|
|
|
|
| 42 |
}
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
.stTabs [data-baseweb="tab-list"] {
|
| 46 |
justify-content: center;
|
| 47 |
-
gap:
|
|
|
|
| 48 |
}
|
|
|
|
| 49 |
.stTabs [data-baseweb="tab"] {
|
| 50 |
height: 3rem;
|
| 51 |
white-space: pre-wrap;
|
| 52 |
background-color: transparent;
|
| 53 |
-
border-radius:
|
| 54 |
-
padding-top:
|
| 55 |
-
padding-bottom:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
}
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
.stMetric {
|
| 60 |
-
background-color: #
|
| 61 |
-
padding:
|
| 62 |
-
border-radius:
|
| 63 |
-
border:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
}
|
| 65 |
</style>
|
| 66 |
""", unsafe_allow_html=True)
|
|
@@ -153,29 +214,45 @@ def get_reconstruction(svd_model, img_flat):
|
|
| 153 |
return recons_tensor, clamp_ratio
|
| 154 |
|
| 155 |
|
| 156 |
-
# --- UI Sidebar ---
|
| 157 |
with st.sidebar:
|
| 158 |
-
st.markdown("## Coconut MNIST")
|
|
|
|
| 159 |
st.markdown("---")
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
# Global Noise Control for Hybrid Analysis
|
| 163 |
-
st.
|
| 164 |
noise_mode = st.toggle(
|
| 165 |
-
"Enable SVD Denoiser",
|
| 166 |
-
help="
|
| 167 |
)
|
| 168 |
if noise_mode:
|
| 169 |
-
st.success("SVD Denoising Active")
|
| 170 |
-
|
| 171 |
st.markdown("---")
|
| 172 |
-
st.
|
| 173 |
temp_scaling = st.slider(
|
| 174 |
-
"Softmax Temperature
|
| 175 |
0.1, 5.0, 1.0, 0.1,
|
| 176 |
-
help="Higher T =
|
| 177 |
)
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
# --- Initialization ---
|
| 181 |
X, y_orig, svd_model, cnn_model = get_app_resources()
|
|
@@ -183,8 +260,14 @@ X_flat = X.view(-1, 784)
|
|
| 183 |
|
| 184 |
|
| 185 |
# --- Main Page Header ---
|
| 186 |
-
st.title("Coconut MNIST:
|
| 187 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
st.markdown("---")
|
| 189 |
|
| 190 |
|
|
@@ -199,8 +282,14 @@ tab1, tab2, tab3, tab4 = st.tabs([
|
|
| 199 |
|
| 200 |
# --- Tab 1: Interpolation (The Story of the 3 vs 8) ---
|
| 201 |
with tab1:
|
| 202 |
-
st.
|
| 203 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
c1, c2, c3 = st.columns([1, 1, 2])
|
| 206 |
with c1:
|
|
@@ -265,8 +354,17 @@ with tab1:
|
|
| 265 |
|
| 266 |
# --- Tab 2: Robustness (The SVD Advantage) ---
|
| 267 |
with tab2:
|
| 268 |
-
st.
|
| 269 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
col1, col2 = st.columns([1, 2])
|
| 272 |
with col1:
|
|
@@ -316,8 +414,16 @@ with tab2:
|
|
| 316 |
|
| 317 |
# --- Tab 3: Manifold Explorer (SVD vs UMAP comparison) ---
|
| 318 |
with tab3:
|
| 319 |
-
st.
|
| 320 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
# Try loading cached data
|
| 323 |
emb_svd_cached, emb_umap_cached, y_sub_cached = load_embeddings()
|
|
@@ -436,8 +542,17 @@ with tab3:
|
|
| 436 |
|
| 437 |
# --- Tab 4: Live Lab ---
|
| 438 |
with tab4:
|
| 439 |
-
st.
|
| 440 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
|
| 442 |
# Two modes: sample browser or upload
|
| 443 |
input_mode = st.radio("Input Mode", ["Browse Dataset", "Draw Digit"], horizontal=True)
|
|
|
|
| 24 |
initial_sidebar_state="expanded"
|
| 25 |
)
|
| 26 |
|
| 27 |
+
# --- Custom CSS for Clean Professional Theme (Nord Palette) ---
|
| 28 |
st.markdown("""
|
| 29 |
<style>
|
| 30 |
+
/* Main container */
|
| 31 |
.block-container {
|
| 32 |
+
max-width: 1200px;
|
| 33 |
+
padding-top: 2.5rem;
|
| 34 |
padding-bottom: 5rem;
|
| 35 |
margin: 0 auto;
|
| 36 |
}
|
| 37 |
+
|
| 38 |
/* Premium Typography */
|
| 39 |
+
h1, h2, h3 {
|
| 40 |
+
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
|
| 41 |
+
font-weight: 700;
|
| 42 |
+
letter-spacing: -0.5px;
|
| 43 |
}
|
| 44 |
+
|
| 45 |
+
h1 {
|
| 46 |
+
color: #2E3440;
|
| 47 |
+
margin-bottom: 0.5rem;
|
| 48 |
+
font-size: 2.2rem;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
h2 {
|
| 52 |
+
color: #3B4252;
|
| 53 |
+
margin-top: 1.5rem;
|
| 54 |
+
margin-bottom: 1rem;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
/* Tab styling with Nord colors */
|
| 58 |
.stTabs [data-baseweb="tab-list"] {
|
| 59 |
justify-content: center;
|
| 60 |
+
gap: 2.5rem;
|
| 61 |
+
background-color: transparent;
|
| 62 |
}
|
| 63 |
+
|
| 64 |
.stTabs [data-baseweb="tab"] {
|
| 65 |
height: 3rem;
|
| 66 |
white-space: pre-wrap;
|
| 67 |
background-color: transparent;
|
| 68 |
+
border-radius: 0px;
|
| 69 |
+
padding-top: 0.75rem;
|
| 70 |
+
padding-bottom: 0.75rem;
|
| 71 |
+
border-bottom: 2px solid transparent;
|
| 72 |
+
color: #4C566A;
|
| 73 |
+
font-weight: 600;
|
| 74 |
+
transition: all 0.3s ease;
|
| 75 |
}
|
| 76 |
+
|
| 77 |
+
.stTabs [aria-selected="true"] {
|
| 78 |
+
border-bottom-color: #5E81AC !important;
|
| 79 |
+
color: #5E81AC !important;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
/* Metric Cards with Nord styling */
|
| 83 |
.stMetric {
|
| 84 |
+
background-color: #ECEFF4;
|
| 85 |
+
padding: 1.25rem;
|
| 86 |
+
border-radius: 8px;
|
| 87 |
+
border-left: 4px solid #5E81AC;
|
| 88 |
+
box-shadow: 0 2px 4px rgba(46, 52, 64, 0.1);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
/* Section dividers */
|
| 92 |
+
hr {
|
| 93 |
+
border: 0;
|
| 94 |
+
height: 1px;
|
| 95 |
+
background: linear-gradient(to right, #ECEFF4, #D8DEE9, #ECEFF4);
|
| 96 |
+
margin: 2rem 0;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
/* Info/Warning/Error boxes */
|
| 100 |
+
.stAlert {
|
| 101 |
+
padding: 1.25rem;
|
| 102 |
+
border-radius: 8px;
|
| 103 |
+
border-left: 4px solid;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
/* Image container spacing */
|
| 107 |
+
.stImage {
|
| 108 |
+
margin: 0.75rem 0;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
/* Column spacing */
|
| 112 |
+
.stColumn {
|
| 113 |
+
padding: 0 0.75rem;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
/* Slider and input styling */
|
| 117 |
+
.stSlider > div > div > div {
|
| 118 |
+
background: linear-gradient(to right, #BF616A, #A3BE8C);
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
/* Radio button styling */
|
| 122 |
+
.stRadio > label {
|
| 123 |
+
font-weight: 500;
|
| 124 |
+
color: #2E3440;
|
| 125 |
}
|
| 126 |
</style>
|
| 127 |
""", unsafe_allow_html=True)
|
|
|
|
| 214 |
return recons_tensor, clamp_ratio
|
| 215 |
|
| 216 |
|
| 217 |
+
# --- UI Sidebar (Nord Palette) ---
|
| 218 |
with st.sidebar:
|
| 219 |
+
st.markdown("## 🥥 Coconut MNIST")
|
| 220 |
+
st.markdown("**Linear vs. Non-Linear Representations**")
|
| 221 |
st.markdown("---")
|
| 222 |
+
|
| 223 |
+
st.info("""
|
| 224 |
+
**Key Finding:**\n
|
| 225 |
+
SVD optimizes global variance → fails at local topological features (the 3/8 gap).\n
|
| 226 |
+
CNN captures discriminative boundaries → sensitive to noise.
|
| 227 |
+
""")
|
| 228 |
+
|
| 229 |
# Global Noise Control for Hybrid Analysis
|
| 230 |
+
st.markdown("### ⚙️ Experimental Controls")
|
| 231 |
noise_mode = st.toggle(
|
| 232 |
+
"🔧 Enable SVD Denoiser",
|
| 233 |
+
help="Hybrid pipeline: SVD preprocessing → CNN classification"
|
| 234 |
)
|
| 235 |
if noise_mode:
|
| 236 |
+
st.success("✓ SVD Denoising Active", icon="✓")
|
| 237 |
+
|
| 238 |
st.markdown("---")
|
| 239 |
+
st.markdown("### 🎚️ Model Calibration")
|
| 240 |
temp_scaling = st.slider(
|
| 241 |
+
"Softmax Temperature",
|
| 242 |
0.1, 5.0, 1.0, 0.1,
|
| 243 |
+
help="Higher T = smooth probs | Lower T = sharp decision boundaries"
|
| 244 |
)
|
| 245 |
|
| 246 |
+
st.markdown("---")
|
| 247 |
+
st.markdown("### 📊 About This Tool")
|
| 248 |
+
st.caption("""
|
| 249 |
+
**Tabs:**
|
| 250 |
+
- Topology: Decision boundary snap analysis
|
| 251 |
+
- Robustness: Noise filtering comparison
|
| 252 |
+
- Manifold: Linear vs non-linear projections
|
| 253 |
+
- Lab: Interactive testing
|
| 254 |
+
""")
|
| 255 |
+
|
| 256 |
|
| 257 |
# --- Initialization ---
|
| 258 |
X, y_orig, svd_model, cnn_model = get_app_resources()
|
|
|
|
| 260 |
|
| 261 |
|
| 262 |
# --- Main Page Header ---
|
| 263 |
+
st.title("🥥 Coconut MNIST: Why SVD Misclassifies 3 as 8")
|
| 264 |
+
st.markdown("""
|
| 265 |
+
### Mechanistic Analysis: Linear vs. Non-Linear Representations
|
| 266 |
+
|
| 267 |
+
Explore how SVD's global variance optimization and CNN's local feature extraction lead to **complementary strengths and failure modes**.
|
| 268 |
+
|
| 269 |
+
**Start your exploration below →**
|
| 270 |
+
""")
|
| 271 |
st.markdown("---")
|
| 272 |
|
| 273 |
|
|
|
|
| 282 |
|
| 283 |
# --- Tab 1: Interpolation (The Story of the 3 vs 8) ---
|
| 284 |
with tab1:
|
| 285 |
+
st.markdown("### 🔍 Topological Decision Boundaries")
|
| 286 |
+
st.markdown("""
|
| 287 |
+
Smoothly interpolate between two digits and observe:
|
| 288 |
+
- **CNN's behavior**: Sharp phase transition at manifold boundary (topological snap)
|
| 289 |
+
- **SVD's behavior**: Gradual reconstruction error increase (tries to "bridge" manifolds)
|
| 290 |
+
|
| 291 |
+
This reveals the fundamental difference: CNN sees discrete topology, SVD sees continuous variance.
|
| 292 |
+
""")
|
| 293 |
|
| 294 |
c1, c2, c3 = st.columns([1, 1, 2])
|
| 295 |
with c1:
|
|
|
|
| 354 |
|
| 355 |
# --- Tab 2: Robustness (The SVD Advantage) ---
|
| 356 |
with tab2:
|
| 357 |
+
st.markdown("### 🛡️ SVD as Adaptive Denoiser")
|
| 358 |
+
st.markdown("""
|
| 359 |
+
**Key insight**: While SVD fails on clean MNIST (destroys the 3-8 gap), it becomes powerful under noise.
|
| 360 |
+
|
| 361 |
+
**Mechanism**: By keeping only top-20 variance directions, SVD acts as a low-pass filter that:
|
| 362 |
+
- ✓ Preserves class-relevant structure
|
| 363 |
+
- ✓ Suppresses high-frequency Gaussian noise
|
| 364 |
+
- ✗ Cannot recover from information already lost to noise
|
| 365 |
+
|
| 366 |
+
**Trade-off**: SVD + CNN maintains accuracy under moderate noise better than CNN alone.
|
| 367 |
+
""")
|
| 368 |
|
| 369 |
col1, col2 = st.columns([1, 2])
|
| 370 |
with col1:
|
|
|
|
| 414 |
|
| 415 |
# --- Tab 3: Manifold Explorer (SVD vs UMAP comparison) ---
|
| 416 |
with tab3:
|
| 417 |
+
st.markdown("### 📊 Manifold Projection Comparison")
|
| 418 |
+
st.markdown("""
|
| 419 |
+
**Question**: How different are linear (SVD) vs non-linear (UMAP) projections?
|
| 420 |
+
|
| 421 |
+
**Observations**:
|
| 422 |
+
- **SVD (Blue regions)**: Classes overlap → global variance loses local structure
|
| 423 |
+
- **UMAP (Colorful clusters)**: Classes separate → preserves topological neighborhoods
|
| 424 |
+
|
| 425 |
+
This visualizes why CNN (non-linear) works while SVD fails on the 3-8 pair.
|
| 426 |
+
""")
|
| 427 |
|
| 428 |
# Try loading cached data
|
| 429 |
emb_svd_cached, emb_umap_cached, y_sub_cached = load_embeddings()
|
|
|
|
| 542 |
|
| 543 |
# --- Tab 4: Live Lab ---
|
| 544 |
with tab4:
|
| 545 |
+
st.markdown("### 🧪 Interactive Testing")
|
| 546 |
+
st.markdown("""
|
| 547 |
+
**Experiment 1: Dataset Browser**
|
| 548 |
+
- Pick a digit and add Gaussian noise
|
| 549 |
+
- See how SVD denoises before CNN classification
|
| 550 |
+
- Compare predictions with/without SVD preprocessing
|
| 551 |
+
|
| 552 |
+
**Experiment 2: Draw Your Own**
|
| 553 |
+
- Sketch a digit and watch both methods analyze it in real-time
|
| 554 |
+
- Observe the difference between CNN's sharp boundary detection and SVD's smooth reconstruction
|
| 555 |
+
""")
|
| 556 |
|
| 557 |
# Two modes: sample browser or upload
|
| 558 |
input_mode = st.radio("Input Mode", ["Browse Dataset", "Draw Digit"], horizontal=True)
|
docs/REPORT.md
CHANGED
|
@@ -1,129 +1,204 @@
|
|
| 1 |
-
#
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
---
|
| 8 |
|
| 9 |
-
##
|
| 10 |
|
| 11 |
-
|
| 12 |
-
- **The Mechanism:** Linear models see **Global Energy** (the overall silhouette), while CNNs see **Local Topology** (the gap). SVD literally "welds" the ends of a '3' together to minimize reconstruction error.
|
| 13 |
-
- **The Solution & Boundary: We built a Hybrid SVD→CNN pipeline.** While SVD fails as a standalone classifier, it works as a powerful **low-pass filter** and defensive shield against high noise ($\sigma=0.7$), provided the data isn't too texture-rich (like Fashion-MNIST).
|
| 14 |
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
Diagnosis Mechanism Solution & Boundary
|
| 21 |
-
───────────────────── ───────────────────── ─────────────────────
|
| 22 |
-
SVD fails on 3 vs 8 → Why? Grad-CAM + UMAP → Hybrid SVD→CNN pipeline
|
| 23 |
-
(The Variance Trap) (Global vs. Local) + Texture stress test
|
| 24 |
-
```
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
##
|
| 29 |
|
| 30 |
-
|
|
|
|
| 31 |
|
| 32 |
-
|
| 33 |
-
Linear dimensionality reduction (SVD) treats classification like a reconstruction problem. It looks for the directions of maximum variance (total pixel brightness). In the cluster of 3s and 8s, the shared "8-like" outline contains the most energy. The small gap that makes a '3' unique is mathematically ignored.
|
| 34 |
|
| 35 |
<p align="center">
|
| 36 |
-
<img src="research_results/
|
| 37 |
-
<
|
|
|
|
| 38 |
<br>
|
| 39 |
-
<em><strong>Figure
|
| 40 |
</p>
|
| 41 |
|
| 42 |
-
-
|
| 43 |
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
|
| 48 |
-
###
|
| 49 |
-
When we interpolate a '3' into an '8', the CNN shows a sharp "snap" in confidence—it recognizes a topological boundary. In contrast, SVD's reconstruction error peaks at the midpoint because it's trying to fit a "linear bridge" between two distinct manifolds.
|
| 50 |
|
| 51 |
<p align="center">
|
| 52 |
-
<img src="research_results/
|
| 53 |
<br>
|
| 54 |
-
<em><strong>Figure
|
| 55 |
</p>
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
<p align="center">
|
| 61 |
-
<img src="research_results/
|
| 62 |
<br>
|
| 63 |
-
<em><strong>Figure
|
| 64 |
</p>
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
- **
|
| 69 |
-
|
| 70 |
-
|
| 71 |
|
| 72 |
<p align="center">
|
| 73 |
-
<img src="research_results/
|
| 74 |
<br>
|
| 75 |
-
<em><strong>Figure
|
| 76 |
</p>
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
---
|
| 79 |
|
| 80 |
-
##
|
| 81 |
|
| 82 |
-
|
| 83 |
|
| 84 |
-
|
| 85 |
|
| 86 |
-
|
| 87 |
|
| 88 |
<p align="center">
|
| 89 |
-
<img src="research_results/fig_06_robustness_mnist_gaussian.png" alt="
|
| 90 |
-
<img src="research_results/fig_07_robustness_mnist_svd_aligned.png" alt="Subspace Risk" width="450" />
|
| 91 |
<br>
|
| 92 |
-
<em><strong>Figure
|
| 93 |
</p>
|
| 94 |
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
## 4.
|
| 98 |
|
| 99 |
-
|
|
|
|
| 100 |
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
<p align="center">
|
| 104 |
-
<img src="research_results/fig_08_robustness_fashion.png" alt="Fashion
|
| 105 |
<br>
|
| 106 |
-
<em><strong>Figure 8:</strong>
|
| 107 |
</p>
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
---
|
| 110 |
|
| 111 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
|
|
|
|
| 114 |
|
| 115 |
---
|
| 116 |
|
| 117 |
-
##
|
| 118 |
|
| 119 |
-
|
| 120 |
-
Convergence was typically reached within 5-8 epochs using the Adam optimizer.
|
| 121 |
-
<p align="center">
|
| 122 |
-
<img src="research_results/fig_09_learning_curves.png" alt="Learning Curves" width="450" />
|
| 123 |
-
</p>
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
| 1 |
+
# SVD vs CNN on MNIST: A Study of Complementary Representations
|
| 2 |
|
| 3 |
+
## 1. Initial Observation
|
| 4 |
|
| 5 |
+
While implementing SVD-based digit classification on MNIST, we observed systematic confusion between digits 3 and 8. The confusion matrix reveals:
|
| 6 |
+
- Digit 8 misclassified as 3: **3.4%**
|
| 7 |
+
- Digit 3 misclassified as 8: **2.5%**
|
| 8 |
+
|
| 9 |
+
This asymmetric but correlated failure pattern warranted investigation into the fundamental mechanisms driving the two methods' behaviors.
|
| 10 |
+
|
| 11 |
+
<p align="center">
|
| 12 |
+
<img src="research_results/fig_01_svd_confusion.png" alt="SVD Confusion Matrix" width="500" />
|
| 13 |
+
<br>
|
| 14 |
+
<em><strong>Figure 1:</strong> SVD Confusion Matrix (Accuracy: 88.13%). Errors concentrate on visually similar pairs: 3 ↔ 8, 5 ↔ 3, 4 ↔ 9.</em>
|
| 15 |
+
</p>
|
| 16 |
|
| 17 |
---
|
| 18 |
|
| 19 |
+
## 2. Diagnosis: The Variance Trap
|
| 20 |
|
| 21 |
+
### 2.1 Overall Performance (Clean Data)
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
<div align="center">
|
| 24 |
|
| 25 |
+
| Method | Accuracy |
|
| 26 |
+
|--------|----------|
|
| 27 |
+
| SVD | 88.13% |
|
| 28 |
+
| CNN | 98.55% |
|
| 29 |
|
| 30 |
+
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
SVD's 10% accuracy gap is not uniformly distributed. Confusion concentrates on visually ambiguous pairs (as shown in Figure 1):
|
| 33 |
+
- 3 ↔ 8 (2.5% + 3.4%)
|
| 34 |
+
- 5 ↔ 3 (6.4% + 0.9%)
|
| 35 |
+
- 4 ↔ 9 (1.7% + 5.8%)
|
| 36 |
+
|
| 37 |
+
Other digit pairs show error rates < 1.5%.
|
| 38 |
|
| 39 |
+
### 2.2 Root Cause: SVD Optimizes for Global Variance
|
| 40 |
|
| 41 |
+
SVD solves:
|
| 42 |
+
$$X = U \Sigma V^T$$
|
| 43 |
|
| 44 |
+
where $\Sigma$ contains singular values sorted in decreasing order. Truncation to rank $k=20$ retains only the $k$ dimensions with highest variance.
|
|
|
|
| 45 |
|
| 46 |
<p align="center">
|
| 47 |
+
<img src="research_results/fig_01_spectrum.png" alt="Singular Value Spectrum" width="400" />
|
| 48 |
+
<br>
|
| 49 |
+
<img src="research_results/fig_02_eigen_digits.png" alt="Eigen-digits" width="400" />
|
| 50 |
<br>
|
| 51 |
+
<em><strong>Figure 2 & 3:</strong> Left: Singular value decay showing rapid drop after k≈5. Right: First 10 eigen-digits (principal components). SVD reconstructs shared circular silhouettes, smoothing over discriminative gaps.</em>
|
| 52 |
</p>
|
| 53 |
|
| 54 |
+
**The problem**: The topological gap distinguishing 3 from 8 has low pixel variance (few pixels differ). SVD treats it as noise and discards it during dimensionality reduction. The reconstructed 3 appears closer to an 8-like silhouette.
|
| 55 |
|
| 56 |
+
---
|
| 57 |
|
| 58 |
+
## 3. Mechanistic Proof
|
| 59 |
|
| 60 |
+
### 3.1 Grad-CAM Visualization
|
|
|
|
| 61 |
|
| 62 |
<p align="center">
|
| 63 |
+
<img src="research_results/fig_04_explainability.png" alt="Grad-CAM vs SVD Reconstruction" width="700" />
|
| 64 |
<br>
|
| 65 |
+
<em><strong>Figure 4:</strong> Left: CNN Grad-CAM attention (red = high focus). Center: Original digit 3. Right: SVD reconstruction. CNN focuses on the gap; SVD hallucinates a closed loop to minimize reconstruction error.</em>
|
| 66 |
</p>
|
| 67 |
|
| 68 |
+
**CNN** attention heatmap: Focuses exclusively on the topological boundary (gap in digit 3).
|
| 69 |
+
|
| 70 |
+
**SVD** reconstruction: Smooth, closed loop at the 3-8 ambiguity zone, indicating the linear model reconstructs a phantom feature to minimize overall error.
|
| 71 |
+
|
| 72 |
+
### 3.2 UMAP Manifold Analysis
|
| 73 |
|
| 74 |
<p align="center">
|
| 75 |
+
<img src="research_results/fig_05_manifold_collapse.png" alt="Manifold Comparison: Raw vs SVD Subspace" width="600" />
|
| 76 |
<br>
|
| 77 |
+
<em><strong>Figure 5:</strong> Left: UMAP of raw pixel space (3 and 8 clearly separated). Right: UMAP of SVD 20-component subspace (clusters overlap significantly).</em>
|
| 78 |
</p>
|
| 79 |
|
| 80 |
+
- **Raw pixel space**: Digit 3 and 8 clusters are clearly separated (98.74% k-NN accuracy).
|
| 81 |
+
- **SVD 20-component subspace**: Clusters overlap significantly (96.98% k-NN accuracy, 1.76% loss).
|
| 82 |
+
- **Interpretation**: SVD projection collapses the manifold boundaries that discriminate these digits.
|
| 83 |
+
|
| 84 |
+
### 3.3 Interpolation Boundary
|
| 85 |
|
| 86 |
<p align="center">
|
| 87 |
+
<img src="research_results/fig_03_interpolation.png" alt="Decision Boundary Interpolation" width="700" />
|
| 88 |
<br>
|
| 89 |
+
<em><strong>Figure 6:</strong> Interpolating from digit 3 to 8. Top: CNN class probability (sharp transition at manifold boundary). Bottom: SVD reconstruction error (peaks at midpoint where linear model struggles to bridge two manifolds).</em>
|
| 90 |
</p>
|
| 91 |
|
| 92 |
+
Interpolating smoothly from digit 3 to digit 8:
|
| 93 |
+
- **CNN confidence**: Sharp phase transition at the midpoint (topological boundary detected).
|
| 94 |
+
- **SVD reconstruction error**: Peaks at midpoint (linear model struggles to bridge two distinct manifolds).
|
| 95 |
+
|
| 96 |
---
|
| 97 |
|
| 98 |
+
## 4. Complementarity: SVD as Denoising Filter
|
| 99 |
|
| 100 |
+
While SVD fails as a classifier on clean data, its low-pass filtering property reveals complementary benefits under realistic noise conditions.
|
| 101 |
|
| 102 |
+
### 4.1 Robustness Under Gaussian Noise (σ ∈ [0, 0.3])
|
| 103 |
|
| 104 |
+
Test regime: Add Gaussian noise $\mathcal{N}(0, \sigma^2)$ to test images (image range normalized to [0, 1]).
|
| 105 |
|
| 106 |
<p align="center">
|
| 107 |
+
<img src="research_results/fig_06_robustness_mnist_gaussian.png" alt="Robustness: Realistic Gaussian Noise on MNIST" width="500" />
|
|
|
|
| 108 |
<br>
|
| 109 |
+
<em><strong>Figure 7:</strong> Accuracy under Gaussian noise (σ ∈ [0, 0.3]). Hybrid (SVD→CNN) maintains stable performance, outperforming CNN at σ=0.2 and beyond.</em>
|
| 110 |
</p>
|
| 111 |
|
| 112 |
+
<div align="center">
|
| 113 |
+
|
| 114 |
+
| σ | CNN | SVD | Hybrid |
|
| 115 |
+
|---|-----|-----|--------|
|
| 116 |
+
| 0.0 | 98.55% | 88.13% | 91.98% |
|
| 117 |
+
| 0.1 | 98.48% | 87.18% | 91.84% |
|
| 118 |
+
| 0.2 | 97.94% | 86.37% | 91.24% |
|
| 119 |
+
| 0.3 | 95.67% | 80.64% | 90.02% |
|
| 120 |
+
|
| 121 |
+
</div>
|
| 122 |
+
|
| 123 |
+
**Key finding**:
|
| 124 |
+
- **Clean data**: CNN >> Hybrid >> SVD
|
| 125 |
+
- **At σ=0.3**: CNN drops to 95.67%, but Hybrid remains at 90.02%
|
| 126 |
+
- **Hybrid advantage**: Maintains relative stability by filtering noise before feature extraction
|
| 127 |
|
| 128 |
+
### 4.2 Mechanism: Selective Feature Preservation
|
| 129 |
|
| 130 |
+
SVD truncation to rank $k=20$ acts as an adaptive low-pass filter:
|
| 131 |
+
$$\text{Noisy Image} \xrightarrow{\text{SVD Project}} \text{Denoised} \xrightarrow{\text{CNN}} \text{Class}$$
|
| 132 |
|
| 133 |
+
By discarding low-variance dimensions, SVD naturally suppresses high-frequency Gaussian noise while preserving the primary class-discriminative structure. CNN then works with cleaner input.
|
| 134 |
+
|
| 135 |
+
### 4.3 Implication
|
| 136 |
+
|
| 137 |
+
SVD's complementary benefit is **narrowly applicable**: it helps when:
|
| 138 |
+
1. Noise is Gaussian (random, not aligned with data)
|
| 139 |
+
2. Noise level is moderate (σ ≤ 0.3, images still recognizable)
|
| 140 |
+
3. Data is simple/silhouette-based (MNIST works; texture-based data may not)
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
## 5. Boundary: Failure on Texture-Rich Data
|
| 145 |
+
|
| 146 |
+
On **Fashion-MNIST**, SVD's low-pass filtering becomes destructive.
|
| 147 |
|
| 148 |
<p align="center">
|
| 149 |
+
<img src="research_results/fig_08_robustness_fashion.png" alt="Fashion-MNIST: SVD Filter Destroys Textures" width="500" />
|
| 150 |
<br>
|
| 151 |
+
<em><strong>Figure 8:</strong> Performance on Fashion-MNIST under noise (σ ∈ [0, 0.3]). CNN performance degrades rapidly, but SVD (which preserves structure in MNIST) performs even worse than Hybrid here, revealing data-dependent behavior.</em>
|
| 152 |
</p>
|
| 153 |
|
| 154 |
+
**Clean data (σ=0)**:
|
| 155 |
+
|
| 156 |
+
<div align="center">
|
| 157 |
+
|
| 158 |
+
| Method | Accuracy |
|
| 159 |
+
|--------|----------|
|
| 160 |
+
| CNN | 89.79% |
|
| 161 |
+
| SVD | 80.30% |
|
| 162 |
+
| Hybrid | 71.78% |
|
| 163 |
+
|
| 164 |
+
</div>
|
| 165 |
+
|
| 166 |
+
**Why Hybrid fails worst (71.78%)**:
|
| 167 |
+
1. SVD destroys high-frequency textures (buttons, zippers, stitching) that distinguish clothing items
|
| 168 |
+
2. CNN receives a "simplified" image that has already lost class-relevant information
|
| 169 |
+
3. CNN cannot recover from this information loss, performing worse than SVD alone
|
| 170 |
+
|
| 171 |
+
**Implication**: SVD's denoising benefit is restricted to **silhouette-based datasets** where low-frequency structure dominates. On texture-rich data, the hybrid approach becomes a liability.
|
| 172 |
+
|
| 173 |
---
|
| 174 |
|
| 175 |
+
## 6. Summary: Method Applicability by Data Regime
|
| 176 |
+
|
| 177 |
+
<div align="center">
|
| 178 |
+
|
| 179 |
+
| Scenario | Best Choice | Why |
|
| 180 |
+
|----------|------------|-----|
|
| 181 |
+
| **Clean MNIST** | CNN (98.55%) | No noise; SVD's simplification is pure loss |
|
| 182 |
+
| **Noisy MNIST (σ=0.2-0.3)** | Hybrid (91.24%) | SVD filters Gaussian noise; CNN learns from cleaner input |
|
| 183 |
+
| **Clean Fashion-MNIST** | CNN (89.79%) | Textures require non-linear feature extraction |
|
| 184 |
+
| **Texture-rich + Noise** | CNN alone | SVD destroys high-freq features before noise filtering helps |
|
| 185 |
+
|
| 186 |
+
</div>
|
| 187 |
+
|
| 188 |
+
**No universal winner**: Methods succeed in different regimes based on their optimization objectives:
|
| 189 |
|
| 190 |
+
- **SVD** optimizes: Global variance preservation → low-pass filter → stable on silhouette-based data
|
| 191 |
+
- **CNN** optimizes: Discriminative feature learning → sensitive to noise, but powerful on complex data
|
| 192 |
|
| 193 |
---
|
| 194 |
|
| 195 |
+
## 7. Conclusion
|
| 196 |
|
| 197 |
+
This study demonstrates that methodological "limitations" are not flaws but **manifestations of optimization objectives**. SVD and CNN optimize different criteria—global reconstruction vs. local discrimination—leading to complementary failure modes and strengths.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
+
**Key insight**: Understanding a method's optimization target enables **predicting its applicability** rather than treating it as a black box. The choice of method should depend on:
|
| 200 |
+
1. **Data characteristics** (silhouette vs. texture)
|
| 201 |
+
2. **Noise conditions** (Gaussian vs. aligned; moderate vs. extreme)
|
| 202 |
+
3. **Accuracy requirements** (marginal vs. acceptable loss)
|
| 203 |
+
|
| 204 |
+
Rather than seeking universal solutions, practitioners should match methods to specific problem regimes.
|
docs/research_results/fig_01_spectrum.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
docs/research_results/fig_01_svd_confusion.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
docs/research_results/fig_02_eigen_digits.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
docs/research_results/fig_03_interpolation.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
docs/research_results/fig_04_cnn_confusion.png
DELETED
Git LFS Details
|
docs/research_results/fig_05_manifold_collapse.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
docs/research_results/fig_06_robustness_mnist_gaussian.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
docs/research_results/fig_07_robustness_mnist_svd_aligned.png
DELETED
Git LFS Details
|
docs/research_results/fig_08_robustness_fashion.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
docs/research_results/fig_09_learning_curves.png
DELETED
Git LFS Details
|
docs/research_results/fig_10_per_class_metrics_comparison.png
DELETED
Git LFS Details
|
experiments/03_operational_boundaries.py
CHANGED
|
@@ -33,8 +33,8 @@ def run_experiment(args):
|
|
| 33 |
svd_layer = SVDProjectionLayer(svd.components_, scaler.mean_)
|
| 34 |
hybrid = HybridSVDCNN(svd_layer, cnn).to(device)
|
| 35 |
|
| 36 |
-
# 3. Define Noise Levels
|
| 37 |
-
sigmas = [0.0, 0.
|
| 38 |
results = {'CNN': [], 'SVD': [], 'Hybrid': []}
|
| 39 |
|
| 40 |
# 4. Evaluation Loop
|
|
|
|
| 33 |
svd_layer = SVDProjectionLayer(svd.components_, scaler.mean_)
|
| 34 |
hybrid = HybridSVDCNN(svd_layer, cnn).to(device)
|
| 35 |
|
| 36 |
+
# 3. Define Noise Levels (realistic range: σ ∈ [0, 0.3])
|
| 37 |
+
sigmas = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
|
| 38 |
results = {'CNN': [], 'SVD': [], 'Hybrid': []}
|
| 39 |
|
| 40 |
# 4. Evaluation Loop
|
experiments/04_appendix_learning_curves.py
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Appendix A – Learning Curves
|
| 3 |
-
Refactored to use centralized viz utilities.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import pickle
|
| 7 |
-
import os
|
| 8 |
-
from src import config, viz
|
| 9 |
-
|
| 10 |
-
def main():
|
| 11 |
-
experiments = [
|
| 12 |
-
('cnn_10class_history.pkl', 'MNIST 10-class CNN Training', 'fig_09_learning_curves.png'),
|
| 13 |
-
('cnn_fashion_history.pkl', 'Fashion-MNIST CNN Training', 'fig_15_learning_curves_fashion.png')
|
| 14 |
-
]
|
| 15 |
-
|
| 16 |
-
for f_name, label, out_name in experiments:
|
| 17 |
-
path = os.path.join(config.MODELS_DIR, f_name)
|
| 18 |
-
if os.path.exists(path):
|
| 19 |
-
with open(path, 'rb') as f:
|
| 20 |
-
history = pickle.load(f)
|
| 21 |
-
viz.plot_learning_curves(history, label, out_name)
|
| 22 |
-
else:
|
| 23 |
-
print(f"Skipping {f_name}: Not found at {path}.")
|
| 24 |
-
|
| 25 |
-
if __name__ == "__main__":
|
| 26 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/05_appendix_per_class_metrics.py
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Appendix B – Per-Class Performance Metrics (MNIST)
|
| 3 |
-
Refactored to use centralized utility modules.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import numpy as np
|
| 8 |
-
from src import utils, viz, exp_utils
|
| 9 |
-
|
| 10 |
-
def main():
|
| 11 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 12 |
-
print("Loading Models and Test Data...")
|
| 13 |
-
|
| 14 |
-
# Load Models (MNIST default)
|
| 15 |
-
svd_pipe, cnn = utils.load_models(dataset_name="mnist")
|
| 16 |
-
if svd_pipe is None or cnn is None:
|
| 17 |
-
return
|
| 18 |
-
|
| 19 |
-
X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False)
|
| 20 |
-
X_test_flat = X_test.view(X_test.size(0), -1).numpy()
|
| 21 |
-
y_test_np = y_test.numpy()
|
| 22 |
-
|
| 23 |
-
# 1. Collect Predictions
|
| 24 |
-
print("Collecting Predictions...")
|
| 25 |
-
y_preds_dict = {}
|
| 26 |
-
|
| 27 |
-
# CNN Predictions
|
| 28 |
-
cnn.eval()
|
| 29 |
-
with torch.no_grad():
|
| 30 |
-
y_preds_dict['CNN'] = cnn(X_test.to(device)).argmax(dim=1).cpu().numpy()
|
| 31 |
-
|
| 32 |
-
# SVD+LR Predictions
|
| 33 |
-
print("Fitting SVD Baseline (10-class)...")
|
| 34 |
-
X_train_full, y_train_full = utils.load_data_split(dataset_name="mnist", train=True, flatten=True)
|
| 35 |
-
svd_pipe_fitted = exp_utils.fit_svd_baseline(X_train_full.numpy(), y_train_full.numpy(), n_components=20)
|
| 36 |
-
y_preds_dict['SVD+LR'] = svd_pipe_fitted.predict(X_test_flat)
|
| 37 |
-
|
| 38 |
-
# 2. Print Metrics Report
|
| 39 |
-
from sklearn.metrics import recall_score, precision_score, f1_score
|
| 40 |
-
for name, y_pred in y_preds_dict.items():
|
| 41 |
-
print(f"\n--- {name} Report (Average Metrics) ---")
|
| 42 |
-
p = precision_score(y_test_np, y_pred, average='macro')
|
| 43 |
-
r = recall_score(y_test_np, y_pred, average='macro')
|
| 44 |
-
f = f1_score(y_test_np, y_pred, average='macro')
|
| 45 |
-
print(f"Macro Average: Precision={p:.3f}, Recall={r:.3f}, F1={f:.3f}")
|
| 46 |
-
|
| 47 |
-
# 3. Visualization: Per-Class F1 Comparison
|
| 48 |
-
viz.plot_per_class_comparison(
|
| 49 |
-
y_test_np,
|
| 50 |
-
y_preds_dict,
|
| 51 |
-
'fig_10_per_class_metrics_comparison.png'
|
| 52 |
-
)
|
| 53 |
-
print("Appendix B Completed.")
|
| 54 |
-
|
| 55 |
-
if __name__ == "__main__":
|
| 56 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
run_migration.sh
DELETED
|
@@ -1,68 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
|
| 3 |
-
# This script performs the renaming of scripts and figures, and updates references in the code and report.
|
| 4 |
-
# Run this from the project root: /Users/ymlin/Downloads/003-Study/137-Projects/01-mnist-linear-vs-nonlinear
|
| 5 |
-
|
| 6 |
-
echo "Starting migration..."
|
| 7 |
-
|
| 8 |
-
# 1. Rename Scripts
|
| 9 |
-
echo "Renaming scripts..."
|
| 10 |
-
mv experiments/01_exp_diagnosis.py experiments/01_phenomenon_diagnosis.py
|
| 11 |
-
mv experiments/02_mechanistic_analysis.py experiments/02_mechanistic_proof.py
|
| 12 |
-
mv experiments/run_robustness_test.py experiments/03_operational_boundaries.py
|
| 13 |
-
mv experiments/appendix_learning_curves.py experiments/04_appendix_learning_curves.py
|
| 14 |
-
mv experiments/appendix_per_class_metrics.py experiments/05_appendix_per_class_metrics.py
|
| 15 |
-
|
| 16 |
-
# 2. Rename Figures
|
| 17 |
-
echo "Renaming figures..."
|
| 18 |
-
cd docs/research_results || exit
|
| 19 |
-
mv fig_02_svd_confusion.png fig_01_svd_confusion.png
|
| 20 |
-
mv fig_03_eigen_digits.png fig_02_eigen_digits.png
|
| 21 |
-
mv fig_05_interpolation.png fig_03_interpolation.png
|
| 22 |
-
mv fig_06_explainability.png fig_04_explainability.png
|
| 23 |
-
mv fig_08_manifold_collapse.png fig_05_manifold_collapse.png
|
| 24 |
-
mv fig_robustness_mnist_gaussian.png fig_06_robustness_mnist_gaussian.png
|
| 25 |
-
mv fig_robustness_mnist_svd_aligned.png fig_07_robustness_mnist_svd_aligned.png
|
| 26 |
-
mv fig_robustness_fashion.png fig_08_robustness_fashion.png
|
| 27 |
-
mv fig_14_learning_curves.png fig_09_learning_curves.png
|
| 28 |
-
mv fig_19_per_class_metrics_comparison.png fig_10_per_class_metrics_comparison.png
|
| 29 |
-
cd ../..
|
| 30 |
-
|
| 31 |
-
# 3. Update Python Scripts (Using sed for macOS)
|
| 32 |
-
echo "Updating Python scripts..."
|
| 33 |
-
|
| 34 |
-
# 01_phenomenon_diagnosis.py
|
| 35 |
-
sed -i '' 's/fig_02_svd_confusion.png/fig_01_svd_confusion.png/g' experiments/01_phenomenon_diagnosis.py
|
| 36 |
-
sed -i '' 's/fig_03_eigen_digits.png/fig_02_eigen_digits.png/g' experiments/01_phenomenon_diagnosis.py
|
| 37 |
-
sed -i '' 's/fig_04_cnn_confusion.png/fig_01b_cnn_confusion.png/g' experiments/01_phenomenon_diagnosis.py
|
| 38 |
-
|
| 39 |
-
# 02_mechanistic_proof.py
|
| 40 |
-
sed -i '' 's/fig_05_interpolation.png/fig_03_interpolation.png/g' experiments/02_mechanistic_proof.py
|
| 41 |
-
sed -i '' 's/fig_06_explainability.png/fig_04_explainability.png/g' experiments/02_mechanistic_proof.py
|
| 42 |
-
sed -i '' 's/fig_08_manifold_collapse.png/fig_05_manifold_collapse.png/g' experiments/02_mechanistic_proof.py
|
| 43 |
-
|
| 44 |
-
# 03_operational_boundaries.py
|
| 45 |
-
sed -i '' 's/fig_robustness_mnist_gaussian.png/fig_06_robustness_mnist_gaussian.png/g' experiments/03_operational_boundaries.py
|
| 46 |
-
sed -i '' 's/fig_robustness_mnist_svd_aligned.png/fig_07_robustness_mnist_svd_aligned.png/g' experiments/03_operational_boundaries.py
|
| 47 |
-
sed -i '' 's/fig_robustness_fashion.png/fig_08_robustness_fashion.png/g' experiments/03_operational_boundaries.py
|
| 48 |
-
|
| 49 |
-
# 04_appendix_learning_curves.py
|
| 50 |
-
sed -i '' 's/fig_14_learning_curves.png/fig_09_learning_curves.png/g' experiments/04_appendix_learning_curves.py
|
| 51 |
-
|
| 52 |
-
# 05_appendix_per_class_metrics.py
|
| 53 |
-
sed -i '' 's/fig_19_per_class_metrics_comparison.png/fig_10_per_class_metrics_comparison.png/g' experiments/05_appendix_per_class_metrics.py
|
| 54 |
-
|
| 55 |
-
# 4. Update Report (Using sed for macOS)
|
| 56 |
-
echo "Updating REPORT.md..."
|
| 57 |
-
sed -i '' 's/fig_02_svd_confusion.png/fig_01_svd_confusion.png/g' docs/REPORT.md
|
| 58 |
-
sed -i '' 's/fig_03_eigen_digits.png/fig_02_eigen_digits.png/g' docs/REPORT.md
|
| 59 |
-
sed -i '' 's/fig_05_interpolation.png/fig_03_interpolation.png/g' docs/REPORT.md
|
| 60 |
-
sed -i '' 's/fig_06_explainability.png/fig_04_explainability.png/g' docs/REPORT.md
|
| 61 |
-
sed -i '' 's/fig_08_manifold_collapse.png/fig_05_manifold_collapse.png/g' docs/REPORT.md
|
| 62 |
-
sed -i '' 's/fig_robustness_mnist_gaussian.png/fig_06_robustness_mnist_gaussian.png/g' docs/REPORT.md
|
| 63 |
-
sed -i '' 's/fig_robustness_mnist_svd_aligned.png/fig_07_robustness_mnist_svd_aligned.png/g' docs/REPORT.md
|
| 64 |
-
sed -i '' 's/fig_robustness_fashion.png/fig_08_robustness_fashion.png/g' docs/REPORT.md
|
| 65 |
-
sed -i '' 's/fig_14_learning_curves.png/fig_09_learning_curves.png/g' docs/REPORT.md
|
| 66 |
-
sed -i '' 's/fig_19_per_class_metrics_comparison.png/fig_10_per_class_metrics_comparison.png/g' docs/REPORT.md
|
| 67 |
-
|
| 68 |
-
echo "Migration completed successfully!"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/viz.py
CHANGED
|
@@ -28,20 +28,29 @@ def save_fig(filename, dpi=300):
|
|
| 28 |
print(f"Figure saved to {path}")
|
| 29 |
|
| 30 |
def plot_robustness_curves(x_values, results_dict, x_label, title, filename):
|
| 31 |
-
"""Standardized robustness curve plotter."""
|
| 32 |
setup_style()
|
| 33 |
-
plt.
|
| 34 |
colors = {'CNN': COLOR_CNN, 'SVD': COLOR_SVD, 'Hybrid': COLOR_HYBRID}
|
| 35 |
-
|
|
|
|
|
|
|
| 36 |
for label, accs in results_dict.items():
|
| 37 |
-
|
| 38 |
-
color=colors.get(label, '#4C566A'), linewidth=2)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
save_fig(filename)
|
| 46 |
|
| 47 |
def plot_confusion_matrix(y_true, y_pred, labels, filename, title, color_end=COLOR_SVD):
|
|
|
|
| 28 |
print(f"Figure saved to {path}")
|
| 29 |
|
| 30 |
def plot_robustness_curves(x_values, results_dict, x_label, title, filename):
|
| 31 |
+
"""Standardized robustness curve plotter with consistent sizing and styling."""
|
| 32 |
setup_style()
|
| 33 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 34 |
colors = {'CNN': COLOR_CNN, 'SVD': COLOR_SVD, 'Hybrid': COLOR_HYBRID}
|
| 35 |
+
|
| 36 |
+
# Plot with consistent markers and linewidth
|
| 37 |
+
marker_map = {'CNN': 'o', 'SVD': 's', 'Hybrid': '^'}
|
| 38 |
for label, accs in results_dict.items():
|
| 39 |
+
ax.plot(x_values, accs, label=label, marker=marker_map.get(label, 'o'),
|
| 40 |
+
color=colors.get(label, '#4C566A'), linewidth=2.5, markersize=7, alpha=0.85)
|
| 41 |
+
|
| 42 |
+
ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
|
| 43 |
+
ax.set_xlabel(x_label, fontsize=12, fontweight='bold')
|
| 44 |
+
ax.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
|
| 45 |
+
|
| 46 |
+
# Auto-scale y-axis with padding based on data range
|
| 47 |
+
all_values = [val for vals in results_dict.values() for val in vals]
|
| 48 |
+
y_min, y_max = min(all_values), max(all_values)
|
| 49 |
+
y_padding = (y_max - y_min) * 0.1
|
| 50 |
+
ax.set_ylim([max(0, y_min - y_padding), min(1.0, y_max + y_padding)])
|
| 51 |
+
|
| 52 |
+
ax.legend(frameon=True, facecolor='white', framealpha=0.95, fontsize=11, loc='best')
|
| 53 |
+
ax.grid(True, alpha=0.3, linestyle='--')
|
| 54 |
save_fig(filename)
|
| 55 |
|
| 56 |
def plot_confusion_matrix(y_true, y_pred, labels, filename, title, color_end=COLOR_SVD):
|