ymlin105 commited on
Commit
58839b6
·
1 Parent(s): 6849757

feat: initial implementation of MNIST Hybrid SVD-CNN core

Browse files
README.md CHANGED
@@ -10,151 +10,94 @@ app_file: app.py
10
  pinned: false
11
  ---
12
 
13
- # Linear vs. Non-linear Manifold Geometry: A Robustness Analysis
14
 
15
- [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ymlin105/Coconut-MNIST)
16
 
17
- This started from a course assignment on SVD-based MNIST classification. I noticed digit **3 vs 8** was a failure mode, which led to a 9-experiment investigation into *why* linear projections fail and *when* they can still be useful.
18
 
19
- - **Problem** — Linear models like SVD struggle to distinguish handwritten **3s** and **8s** because they prioritize global pixel variance and overlook the "topological gap" that characterizes a 3.
20
- - **Main Question** When does a linear low-rank projection (SVD) improve robustness, and when does it destroy the features a non-linear model needs?
21
- - **What I did** — Diagnosed SVD failure on the **3 vs 8** pair, verified non-linear advantage with a CNN, then built and evaluated a **Hybrid SVD→CNN** pipeline under test-time corruptions.
22
- - **Key Takeaways** — SVD behaves like a low-pass filter; CNN attention is local/non-linear; Hybrid helps under high Gaussian noise on MNIST but fails on texture-heavy Fashion-MNIST.
23
- - **Full technical report →** [REPORT.md](./docs/REPORT.md)
24
 
25
- > **Beyond MNIST:** A denoiser layer is only useful if we can predict its failure modes. In medical imaging, satellite data, and any domain where a linear filter is tempting, knowing the failure boundary prevents accuracy collapse. This question generalizes beyond toy datasets.
26
 
27
- > **Key result:** Hybrid SVD→CNN achieves **+4.8 pp over standalone CNN** at σ=0.7 on MNIST; identifies a **24.6 pp failure boundary** on Fashion-MNIST — pinpointing when linear denoising helps vs. destroys accuracy.
28
 
29
- ![Hybrid SVDCNN Pipeline](docs/research_results/pipeline_diagram.png)
30
- *Figure: The Hybrid SVD→CNN pipeline. SVD reconstruction acts as a data-adapted low-pass filter before CNN feature extraction.*
31
 
32
- ![Geometric Analysis](docs/research_results/fig_06_explainability.png)
33
- *Figure: CNNs (center) focus on the local topological gap, while SVD (right) hallucinates a closed loop to satisfy global variance.*
 
 
 
 
34
 
35
- > **[Try the live demo →](https://huggingface.co/spaces/ymlin105/Coconut-MNIST)** Inject noise/blur in real time and compare SVD vs CNN vs Hybrid predictions side-by-side.
 
 
 
 
 
 
36
 
37
- <details>
38
- <summary><strong>Quick Start & Project Structure</strong></summary>
 
 
39
 
40
- **Tech stack:** Python 3.9 · PyTorch · scikit-learn · Streamlit · UMAP · Plotly
 
 
 
41
 
42
- ```bash
43
- # 1. Create and activate conda environment
44
- conda create -n hybrid-svd python=3.9
45
- conda activate hybrid-svd
46
-
47
- # 2. Install dependencies
48
- pip install -r requirements.txt
49
-
50
- # 3. Train SVD + CNN models (~1 min, now with validation & early stopping!)
51
- python src/train_models.py
52
-
53
- # 4. Launch interactive dashboard
54
- streamlit run app.py
55
-
56
- # Optional: Run additional analysis
57
- python experiments/10_ablation_study.py # Depth vs Non-linearity analysis
58
- python experiments/11_learning_curves.py # Training dynamics visualization
59
- python experiments/12_roc_analysis.py # ROC curves for 3 vs 8
60
- python experiments/13_per_class_metrics.py # Detailed per-class metrics
61
- ```
62
-
63
- > **What's New in v2.0**: Enhanced training with validation set splitting (80/20), early stopping (patience=3), reproducible random seeds (seed=42), and 4 new experiments. Key results: **Non-linearity alone provides +4.91 pp gain** (ablation study), **CNN achieves AUC=1.0 on 3vs8** (ROC analysis), **CNN reduces worst confusion from 6.5% to 2.2%** (per-class metrics). See [`docs/IMPROVEMENTS.md`](docs/IMPROVEMENTS.md) and [`docs/QUICKSTART.md`](docs/QUICKSTART.md) for details.
64
 
 
 
 
 
65
  ```
66
- Project Structure
67
- ├── src/ Core modules: CNN, SVD layer, hybrid pipeline, training
68
- ├── experiments/ 13 self-contained scripts (01–13), ordered by narrative
69
- ├── docs/ Full report (REPORT.md) + 20 figures and JSON metrics
70
- ├── models/ Pretrained checkpoints (CNN for MNIST & Fashion-MNIST)
71
- └── app.py Streamlit dashboard (live demo)
72
- ```
73
-
74
- </details>
75
-
76
- ## Approach
77
-
78
- ```
79
- Diagnosis Mechanism Solution & Boundary
80
- ───────────────────── ───────────────────── ─────────────────────
81
- SVD fails on 3 vs 8 → Why? Grad-CAM + UMAP → Hybrid SVD→CNN pipeline
82
- (Exp 1–3) (Exp 4–7) + Fashion-MNIST stress test
83
- (Exp 8–9)
84
- ```
85
-
86
- The Hybrid pipeline passes each input through a fixed SVD reconstruction (rank $k{=}20$) before the CNN classifier. SVD acts as a data-adapted low-pass filter — suppressing high-frequency noise while retaining structure aligned with the training manifold.
87
-
88
- ## Case Study 1: Success on Low-Rank Manifolds (MNIST)
89
 
90
- > *See [REPORT.md](./docs/REPORT.md#experiment-8-hybrid-architecture-validation-the-solution) for full details.*
 
91
 
92
- On shape-based data, the **Hybrid** architecture acts as a denoiser, filtering noise while preserving structure.
 
 
 
93
 
94
- | Model | Clean | σ=0.3 | σ=0.5 | σ=0.7 |
95
- | ---------------- | ------ | ------ | ------ | ---------------- |
96
- | CNN | 98.74% | 96.36% | 80.44% | 54.34% |
97
- | SVD | 88.12% | 80.87% | 64.60% | 51.30% |
98
- | Blur+CNN | 94.25% | 83.78% | 63.38% | 44.54% |
99
- | **Hybrid** | 91.82% | 88.57% | 79.26% | **59.10%** |
100
-
101
- *Results are mean over 3 random seeds. Bold indicates where Hybrid surpasses the standalone CNN.*
102
-
103
- **Result**: The Hybrid improves over the standalone CNN at high noise ($\sigma=0.7$: 59.10% vs 54.34%), consistent with SVD reconstruction suppressing high-frequency noise before CNN feature extraction. The crossover point where Hybrid surpasses the CNN occurs between $\sigma=0.5$ and $\sigma=0.7$. The Hybrid also outperforms the Gaussian blur baseline (59.10% vs 44.54%), confirming that SVD provides data-adapted denoising beyond generic smoothing.
104
-
105
- ![Robustness Curves](docs/research_results/fig_10_hybrid_robustness.png)
106
- *Figure: Accuracy vs. noise level (σ). The Hybrid (orange) crosses above the standalone CNN (blue) at high noise, confirming the denoising benefit on low-rank data.*
107
-
108
- ## Case Study 2: Failure on Texture-Rich Manifolds (Fashion-MNIST)
109
-
110
- > *See [REPORT.md](./docs/REPORT.md#experiment-9-boundary-analysis-fashion-mnist) for the full boundary analysis.*
111
- > On texture-dependent data, SVD filtering destroys high-frequency details (e.g., collar vs. no collar), causing the Hybrid model to collapse.
112
-
113
- | Model | MNIST (Clean) | Fashion-MNIST (Clean) |
114
- | ---------------- | ------------- | --------------------- |
115
- | **CNN** | 98.74% | **91.04%** |
116
- | **Hybrid** | 91.82% | **67.27%** |
117
-
118
- **Boundary Identified**: This method shows a **~24.6-point** clean-accuracy drop on Fashion-MNIST (texture-dependent), confirming it is primarily suitable for low-rank, shape-defined manifolds.
119
-
120
- <details>
121
- <summary><strong>Geometric Mechanics (why SVD fails and when it helps)</strong></summary>
122
-
123
- - **Feature energy paradox**: discriminative cues can be "low-energy" (gaps, textures) and get wiped out by low-rank projection.
124
- - **Manifold alignment check**: UMAP shows 3/8 are separable when local neighborhood structure is preserved.
125
- - **Subspace denoising**: an SVD reconstruction step can act as a data-adapted low-pass filter before the CNN.
126
-
127
- </details>
128
 
129
- <details>
130
- <summary><strong>Evidence & Reproducibility</strong></summary>
131
 
132
- All figures and metrics are in [`docs/research_results/`](docs/research_results). Each experiment has a single self-contained script in [`experiments/`](experiments/) numbered `01`–`13`, ordered to follow the narrative (diagnosis → mechanism → solution → boundary → validation).
 
 
133
 
134
- <details>
135
- <summary>Full experiment list</summary>
 
 
 
136
 
137
- | # | Script | What it produces |
138
- | ------------ | ---------------------------------- | ------------------------------------------------------ |
139
- | 01 | `phenomenon_discovery.py` | SVD failure analysis + spectrum |
140
- | 02 | `mnist_cnn_confusion.py` | MNIST CNN confusion matrix |
141
- | 03 | `mechanistic_investigation.py` | Interpolation + Grad-CAM vs reconstruction |
142
- | 04 | `robustness_limit.py` | SVD vs CNN degradation curves |
143
- | 05 | `manifold_learning.py` | SVD vs UMAP manifold comparison |
144
- | 06 | `fashion_mnist_baseline.py` | Fashion-MNIST SVD baseline |
145
- | 07 | `fashion_cnn_verification.py` | Fashion-MNIST CNN confusion |
146
- | 08 | `hybrid_robustness.py` | MNIST robustness +`robustness_mnist_noise.json` |
147
- | 09 | `fashion_hybrid_robustness.py` | Fashion robustness +`robustness_fashion_noise.json` |
148
- | **10** | **`ablation_study.py`** | **Depth vs Non-linearity contribution analysis** |
149
- | **11** | **`learning_curves.py`** | **Training/validation dynamics visualization** |
150
- | **12** | **`roc_analysis.py`** | **ROC curves for 3 vs 8 classification** |
151
- | **13** | **`per_class_metrics.py`** | **Precision/Recall/F1 per digit + CSV reports** |
152
 
153
- ## Limitations & Open Questions
 
 
154
 
155
- - **Clean-accuracy penalty**: The Hybrid trades ~7 pp of clean accuracy for high-noise robustness. Can adaptive rank selection ($k$ as a function of input noise estimate) eliminate this penalty?
156
- - **Texture-dependent failure**: The method collapses on Fashion-MNIST. Does this failure boundary generalize to other "texture vs shape" splits (e.g., CIFAR-10, medical imaging)?
157
- - **Fixed rank**: $k{=}20$ was chosen as a round number capturing ~70% variance — not tuned. A learned or input-dependent rank could improve the trade-off.
158
- - **Scope**: This is a mechanistic study on MNIST-scale data, not a production defense. Scaling to higher-resolution images would require rethinking the SVD layer.
 
 
 
 
 
159
 
160
  ---
 
10
  pinned: false
11
  ---
12
 
13
+ # SVD vs CNN: Mechanistic Analysis of Manifold Alignment on MNIST
14
 
15
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ymlin105/Coconut-MNIST) [![Full Report](https://img.shields.io/badge/📖_Read-Full_Report-blue)](./docs/REPORT.md)
16
 
17
+ While it is a known theoretical property that linear dimensionality reduction (SVD) acts as a low-pass filter, this project provides a **concrete, visual, and quantitative mechanistic explanation** of how this property manifests in neural network classification—specifically, why linear subspaces consistently force a "3" to collapse into an "8".
18
 
19
+ <p align="center">
20
+ <img src="./docs/research_results/fig_06_explainability.png" width="600" alt="Mechanistic Analysis of SVD Inductive Bias">
21
+ </p>
 
 
22
 
23
+ By mapping the exact decision boundaries where linear global variance models fail and non-linear topological models (CNNs) succeed, I empirically validate the **inherent trade-offs** of linear denoising in high-stakes domains like medical imaging or satellite data—where a linear filter might suppress critical diagnostic features to minimize noise variance.
24
 
25
+ ## The Solution: Hybrid SVD-CNN
26
 
27
+ I combine SVD's strength as a data-adapted low-pass filter with the CNN's robust feature extraction into a single pipeline.
 
28
 
29
+ ```mermaid
30
+ flowchart TD
31
+ subgraph S1 [I. Noisy Manifold]
32
+ direction LR
33
+ X["Input $X + \eta$"]
34
+ end
35
 
36
+ subgraph S2 [II. Adaptive Projection]
37
+ direction LR
38
+ node_SVD["SVD: $X = U \Sigma V^T$"]
39
+ node_Trunc["$k$-Rank Truncation"]
40
+ node_Recon["$\hat{X} = \sum \sigma_i u_i v_i^T$"]
41
+ node_SVD --> node_Trunc --> node_Recon
42
+ end
43
 
44
+ subgraph S3 [III. CNN Features]
45
+ direction LR
46
+ node_Conv["Conv Layers"] --> node_Pool["Pooling / ReLU"] --> node_Flat["Global Flatten"]
47
+ end
48
 
49
+ subgraph S4 [IV. Latent Mapping]
50
+ direction LR
51
+ node_Soft["Logits / Softmax"] --> node_Pred["Class Prediction"]
52
+ end
53
 
54
+ S1 --> S2
55
+ S2 --> S3
56
+ S3 --> S4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ style S2 fill:#f8f9ff,stroke:#0056b3,stroke-width:2px
59
+ style S3 fill:#f8fff9,stroke:#28a745,stroke-width:2px
60
+ style S1 fill:#fff,stroke:#333
61
+ style S4 fill:#fff,stroke:#333
62
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ ### Key Takeaways
65
+ For full analysis and detailed metrics, see the [Technical Report](./docs/REPORT.md).
66
 
67
+ 1. **The Variance Trap**: Important details (like the gap in a "3") have very little pixel variance. SVD-based linear projections clear them away as noise, forcing distinct digit manifolds to overlap and causing systematic "3-as-8" hallucinations.
68
+ 2. **Local Logic**: UMAP analysis demonstrates that manifolds are topologically distinct when local structure is preserved, but linear variance optimization destroys this neighborhood integrity.
69
+ 3. **Hybrid Advantage**: In high-noise environments ($\sigma=0.7$), a Hybrid architecture acts as a data-adapted denoiser, outperforming standalone CNNs by +4.8 pp.
70
+ 4. **The Boundary**: On texture-rich data (e.g., Fashion-MNIST), SVD reconstruction destroys critical high-frequency features, defining the physical limit of linear denoising.
71
 
72
+ ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ ## Experience it Yourself
 
75
 
76
+ ### Online Demo
77
+ Try the live dashboard to inject noise, adjust SVD rank, and compare model predictions in real-time:
78
+ **[Launch Streamlit App](https://huggingface.co/spaces/ymlin105/Coconut-MNIST)**.
79
 
80
+ ### Local Installation
81
+ ```bash
82
+ # Clone the repository
83
+ git clone https://github.com/ymlin105/mnist-linear-vs-nonlinear.git
84
+ cd mnist-linear-vs-nonlinear
85
 
86
+ # Install dependencies
87
+ pip install -r requirements.txt
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ # Launch the interactive dashboard
90
+ streamlit run app.py
91
+ ```
92
 
93
+ ### Project Structure
94
+ ```
95
+ ├── src/ Core modules (CNN, SVD layer) + Experimental Utils
96
+ ├── experiments/ Theme-based scripts (01 Diagnosis, 02 Analysis, 03 Robustness)
97
+ ├── docs/ Full report (REPORT.md) + figures
98
+ ├── models/ Pretrained checkpoints
99
+ ├── run_all_experiments.sh One-click reproduction script
100
+ └── app.py Streamlit dashboard
101
+ ```
102
 
103
  ---
experiments/01_phenomenon_diagnosis.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Exp 01: Phenomenon Diagnosis
3
+ Combines Global/Focused SVD analysis with CNN baseline comparisons.
4
+ Refactored to use centralized utility modules.
5
+ """
6
+
7
+ import torch
8
+ import numpy as np
9
+ from sklearn.decomposition import TruncatedSVD
10
+ from sklearn.linear_model import LogisticRegression
11
+ from sklearn.metrics import accuracy_score
12
+
13
+ from src import config, utils, viz, exp_utils
14
+
15
+ def run_svd_analysis(X_train, y_train, X_test, y_test):
16
+ print("\n--- Running SVD Spectral Analysis ---")
17
+ mean = np.mean(X_train, axis=0)
18
+ X_centered = X_train - mean
19
+
20
+ n_view = 300
21
+ svd = TruncatedSVD(n_components=n_view, random_state=42)
22
+ svd.fit(X_centered)
23
+
24
+ # 1. Visualization: Spectrum
25
+ viz.plot_singular_spectrum(
26
+ svd.singular_values_,
27
+ np.cumsum(svd.explained_variance_ratio_),
28
+ 'fig_01_spectrum.png'
29
+ )
30
+
31
+ # 2. Classification with k=20
32
+ svd_20 = TruncatedSVD(n_components=20, random_state=42)
33
+ X_train_pca = svd_20.fit_transform(X_train - mean)
34
+ X_test_pca = svd_20.transform(X_test - mean)
35
+
36
+ clf = LogisticRegression(max_iter=1000)
37
+ clf.fit(X_train_pca, y_train)
38
+ y_pred = clf.predict(X_test_pca)
39
+
40
+ acc = accuracy_score(y_test, y_pred)
41
+ print(f"SVD (k=20) Accuracy: {acc*100:.2f}%")
42
+
43
+ # 3. Visualization: Confusion Matrix & Eigen-digits
44
+ viz.plot_confusion_matrix(
45
+ y_test, y_pred, list(range(10)),
46
+ 'fig_02_svd_confusion.png',
47
+ f'SVD Confusion Matrix (Acc={acc:.2f})',
48
+ viz.COLOR_SVD
49
+ )
50
+
51
+ component_titles = [f"Comp {i+1}" for i in range(10)]
52
+ viz.plot_multi_image_grid(
53
+ [c.reshape(28, 28) for c in svd_20.components_[:10]],
54
+ component_titles, 2, 5,
55
+ 'fig_03_eigen_digits.png',
56
+ 'Global SVD Eigen-digits'
57
+ )
58
+
59
+ def run_cnn_baseline(device):
60
+ print("\n--- Running CNN Baseline Diagnosis ---")
61
+ svd_p, cnn = utils.load_models(dataset_name="mnist")
62
+ X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False)
63
+
64
+ acc = exp_utils.evaluate_classifier(cnn, X_test, y_test, device=device, is_pytorch=True)
65
+ print(f"CNN Accuracy: {acc*100:.2f}%")
66
+
67
+ # Predict for confusion matrix
68
+ cnn.eval()
69
+ with torch.no_grad():
70
+ preds = cnn(X_test.to(device)).argmax(dim=1).cpu().numpy()
71
+
72
+ viz.plot_confusion_matrix(
73
+ y_test.numpy(), preds, list(range(10)),
74
+ 'fig_04_cnn_confusion.png',
75
+ f'CNN Confusion Matrix (Acc={acc:.2f})',
76
+ viz.COLOR_CNN
77
+ )
78
+
79
+ def main():
80
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+
82
+ # Load full MNIST (flattened for SVD)
83
+ X_train, y_train = utils.load_data_split(dataset_name="mnist", train=True, flatten=True)
84
+ X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False, flatten=True)
85
+
86
+ run_svd_analysis(X_train.numpy(), y_train.numpy(), X_test.numpy(), y_test.numpy())
87
+ run_cnn_baseline(device)
88
+ print("\nExperiment 01 Diagnosis Completed.")
89
+
90
+ if __name__ == "__main__":
91
+ main()
experiments/01_phenomenon_discovery.py DELETED
@@ -1,226 +0,0 @@
1
-
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
- import seaborn as sns
5
- from sklearn.decomposition import TruncatedSVD
6
- from sklearn.metrics import confusion_matrix, accuracy_score
7
- from sklearn.linear_model import LogisticRegression
8
- import torchvision
9
- import torchvision.transforms as transforms
10
- import torch
11
- import os
12
- from matplotlib.colors import LinearSegmentedColormap
13
- from src import config
14
-
15
- # --- Configuration ---
16
- GRAY_LIGHT = "#D8DEE9"
17
- BLUE_DEEP = "#5E81AC"
18
- ORANGE = "#D08770"
19
-
20
- def load_mnist():
21
- """Load and flatten MNIST data."""
22
- print("Loading MNIST...")
23
- transform = transforms.Compose([transforms.ToTensor()])
24
- # Download to ./data/mnist
25
- trainset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
26
- testset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
27
-
28
- # Flatten: (N, 28, 28) -> (N, 784)
29
- X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
30
- y_train = trainset.targets.numpy()
31
-
32
- X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
33
- y_test = testset.targets.numpy()
34
-
35
- return X_train, y_train, X_test, y_test
36
-
37
- def plot_confusion_matrix(y_true, y_pred, labels, filename, title):
38
- """Draws and saves a confusion matrix (normalized)."""
39
- cm = confusion_matrix(y_true, y_pred, normalize='true') # Normalize by true class (rows)
40
- plt.figure(figsize=(10, 8))
41
- cmap = LinearSegmentedColormap.from_list("NBodyBlue", [GRAY_LIGHT, BLUE_DEEP])
42
- sns.heatmap(cm, annot=True, fmt='.1%', cmap=cmap, xticklabels=labels, yticklabels=labels)
43
- plt.title(title)
44
- plt.xlabel('Predicted')
45
- plt.ylabel('True')
46
- plt.savefig(os.path.join(config.RESULTS_DIR, filename))
47
- plt.close()
48
- print(f"Saved {filename}")
49
-
50
- def plot_eigen_digits(components, filename, title):
51
- """Visualizes the top eigen-digits."""
52
- plt.figure(figsize=(12, 4))
53
- for i in range(min(10, len(components))): # Show top 10
54
- plt.subplot(2, 5, i + 1)
55
- plt.imshow(components[i].reshape(28, 28), cmap='gray')
56
- plt.title(f"Comp {i+1}")
57
- plt.axis('off')
58
- plt.suptitle(title)
59
- plt.savefig(os.path.join(config.RESULTS_DIR, filename))
60
- plt.close()
61
- print(f"Saved {filename}")
62
-
63
- def analyze_spectrum(X, filename, title):
64
- """
65
- Computes and plots the singular value spectrum.
66
- Returns cumulative variance stats.
67
- """
68
- print(f"\nRunning Spectral Analysis on shape {X.shape}...")
69
- # Center the data
70
- X_mean = np.mean(X, axis=0)
71
- X_centered = X - X_mean
72
-
73
- # Compute full SVD (approximation with high k)
74
- n_view = 300 # Increased from 50 to capture >90% variance
75
- svd = TruncatedSVD(n_components=n_view, random_state=42)
76
- svd.fit(X_centered)
77
-
78
- singular_values = svd.singular_values_
79
- explained_variance_ratio = svd.explained_variance_ratio_
80
- cumulative_variance = np.cumsum(explained_variance_ratio)
81
-
82
- # Quantify stats
83
- var_k10 = cumulative_variance[9] * 100
84
-
85
- def get_k(threshold):
86
- idx = np.argmax(cumulative_variance >= threshold)
87
- if cumulative_variance[idx] < threshold: return f">{n_view}"
88
- return idx + 1
89
-
90
- k_90 = get_k(0.90)
91
- k_95 = get_k(0.95)
92
- k_99 = get_k(0.99)
93
-
94
- print(f"Spectral Stats:")
95
- print(f" Variance @ k=10: {var_k10:.2f}%")
96
- print(f" Components for 90% Var: k={k_90}")
97
- print(f" Components for 95% Var: k={k_95}")
98
- print(f" Components for 99% Var: k={k_99}")
99
-
100
- # Plot Scree
101
- fig, ax1 = plt.subplots(figsize=(10, 6))
102
-
103
- color = BLUE_DEEP
104
- ax1.set_xlabel('Principal Component (k)')
105
- ax1.set_ylabel('Singular Value (Log Scale)', color=color)
106
- ax1.semilogy(range(1, n_view+1), singular_values, marker='o', linestyle='-', color=color, markersize=4, label='Singular Values')
107
- ax1.tick_params(axis='y', labelcolor=color)
108
- ax1.grid(True, which="both", ls="-", alpha=0.3)
109
-
110
- ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
111
- color = ORANGE
112
- ax2.set_ylabel('Cumulative Explained Variance', color=color) # we already handled the x-label with ax1
113
- ax2.plot(range(1, n_view+1), cumulative_variance, color=color, linewidth=2, linestyle='--', label='Cumulative Variance')
114
- ax2.tick_params(axis='y', labelcolor=color)
115
- ax2.set_ylim(0, 1.0)
116
-
117
- # Annotate k=10
118
- ax2.axvline(x=10, color='gray', linestyle=':', alpha=0.5)
119
- ax2.text(10.5, 0.4, f'k=10\n({var_k10:.1f}%)', color='black')
120
-
121
- plt.title(title)
122
- plt.savefig(os.path.join(config.RESULTS_DIR, filename))
123
- plt.close()
124
- print(f"Saved {filename}")
125
-
126
- def run_experiment_0(X_train, y_train, X_test, y_test):
127
- """
128
- Experiment 0: Global SVD Analysis (10 Classes)
129
- Hypothesis: 3 and 8 shows significant confusion.
130
- """
131
- print("\n--- Running Experiment 0: Global SVD (10 Classes) ---")
132
-
133
- # 1. SVD Reduction
134
- n_components = 20 # low rank to force dependency on main variance directions
135
- print(f"Reducing dimension to {n_components} using SVD...")
136
- # Mean-center for consistency with hybrid model's SVD layer
137
- mean = np.mean(X_train, axis=0)
138
- X_train_centered = X_train - mean
139
- X_test_centered = X_test - mean
140
- svd = TruncatedSVD(n_components=n_components, random_state=42)
141
- X_train_pca = svd.fit_transform(X_train_centered)
142
- X_test_pca = svd.transform(X_test_centered)
143
-
144
- # 2. Classification (Simple Linear or KNN)
145
- # Using Logistic Regression to simulate linear classification on SVD features
146
- clf = LogisticRegression(max_iter=1000)
147
- clf.fit(X_train_pca, y_train)
148
- y_pred = clf.predict(X_test_pca)
149
-
150
- acc = accuracy_score(y_test, y_pred)
151
- print(f"Global SVD+LR Accuracy: {acc*100:.2f}%")
152
-
153
- # 3. Plot Confusion Matrix
154
- plot_confusion_matrix(y_test, y_pred, list(range(10)),
155
- 'fig_01_global_svd_confusion.png',
156
- f'Global SVD Confusion Matrix (k={n_components}, Acc={acc:.2f})')
157
-
158
- # Analyze specific confusion between 3 and 8
159
- idxs_3 = (y_test == 3)
160
- idxs_8 = (y_test == 8)
161
-
162
- # Confusion 3->8
163
- pred_3 = y_pred[idxs_3]
164
- confused_3_as_8 = np.sum(pred_3 == 8)
165
- print(f"Class 3 samples classified as 8: {confused_3_as_8} / {len(pred_3)} ({confused_3_as_8/len(pred_3)*100:.2f}%)")
166
-
167
- # Confusion 8->3
168
- pred_8 = y_pred[idxs_8]
169
- confused_8_as_3 = np.sum(pred_8 == 3)
170
- print(f"Class 8 samples classified as 3: {confused_8_as_3} / {len(pred_8)} ({confused_8_as_3/len(pred_8)*100:.2f}%)")
171
-
172
- def run_experiment_1(X_train, y_train, X_test, y_test):
173
- """
174
- Experiment 1: Focused 3 vs 8 SVD Analysis
175
- """
176
- print("\n--- Running Experiment 1: Focused SVD (3 vs 8) ---")
177
-
178
- # 1. Filter Data
179
- mask_train = np.logical_or(y_train == 3, y_train == 8)
180
- mask_test = np.logical_or(y_test == 3, y_test == 8)
181
-
182
- X_train_38 = X_train[mask_train]
183
- y_train_38 = y_train[mask_train]
184
- X_test_38 = X_test[mask_test]
185
- y_test_38 = y_test[mask_test]
186
-
187
- print(f"Train samples (3 vs 8): {len(y_train_38)}")
188
- print(f"Test samples (3 vs 8): {len(y_test_38)}")
189
-
190
- # 2. SVD on Subset
191
- n_components = 10
192
- # Mean-center for consistency with hybrid model's SVD layer
193
- mean_38 = np.mean(X_train_38, axis=0)
194
- X_train_38_centered = X_train_38 - mean_38
195
- X_test_38_centered = X_test_38 - mean_38
196
- svd = TruncatedSVD(n_components=n_components, random_state=42)
197
- X_train_pca = svd.fit_transform(X_train_38_centered)
198
- X_test_pca = svd.transform(X_test_38_centered)
199
-
200
- # 3. Classify
201
- clf = LogisticRegression()
202
- clf.fit(X_train_pca, y_train_38)
203
- y_pred = clf.predict(X_test_pca)
204
-
205
- acc = accuracy_score(y_test_38, y_pred)
206
- print(f"Focused SVD(k={n_components}) Accuracy (3 vs 8): {acc*100:.2f}%")
207
-
208
- # 4. Plots
209
- plot_eigen_digits(svd.components_,
210
- 'fig_03_eigen_digits.png',
211
- 'Top 10 Eigen-digits (Principal Components of 3&8)')
212
-
213
- # 5. Spectral Analysis (New)
214
- analyze_spectrum(X_train_38, 'fig_02_scree_plot.png', 'Singular Value Spectrum (3 vs 8 Subset)')
215
-
216
- def main():
217
- X_train, y_train, X_test, y_test = load_mnist()
218
-
219
- run_experiment_0(X_train, y_train, X_test, y_test)
220
- run_experiment_1(X_train, y_train, X_test, y_test)
221
-
222
- print("\nExperiments 0 & 1 Completed.")
223
- print(f"Results saved to {config.RESULTS_DIR}")
224
-
225
- if __name__ == "__main__":
226
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/02_mechanistic_analysis.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Exp 02: Mechanistic Analysis
3
+ Combines Interpolation, Explainability (Grad-CAM), and Quantifying Manifold Collapse.
4
+ Refactored for modularity.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+ from sklearn.decomposition import TruncatedSVD
11
+ from sklearn.neighbors import KNeighborsClassifier
12
+ from sklearn.metrics import accuracy_score
13
+
14
+ from src import config, utils, viz, exp_utils
15
+
16
+ def run_interpolation_analysis(device):
17
+ print("\n--- Running Mechanistic Proof: The Variance vs. Topology Conflict ---")
18
+ X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False, digits=[3, 8])
19
+ _, cnn = utils.load_models(dataset_name="mnist")
20
+
21
+ # Fit SVD baseline for reconstruction analysis
22
+ X_test_flat = X_test.view(X_test.size(0), -1).numpy()
23
+ svd_pipe = exp_utils.fit_svd_baseline(X_test_flat, y_test.numpy(), n_components=10)
24
+
25
+ svd = svd_pipe.named_steps['svd']
26
+ mean = svd_pipe.named_steps['scaler'].mean_
27
+
28
+ # Pick indices for digit 3 and 8
29
+ idx_3 = (y_test == 0).nonzero()[0][0]
30
+ idx_8 = (y_test == 1).nonzero()[0][0]
31
+ img_3, img_8 = X_test[idx_3], X_test[idx_8]
32
+
33
+ alphas = np.linspace(0, 1, 11)
34
+ probs_8, rec_errors = [], []
35
+
36
+ for alpha in alphas:
37
+ img_interp = (1 - alpha) * img_3 + alpha * img_8
38
+ # CNN Probability of class 1 (Digit 8)
39
+ with torch.no_grad():
40
+ logits = cnn(img_interp.unsqueeze(0).to(device))
41
+ # Note: We use index 8 from full model or index 1 if it was binary
42
+ # Here we assume full model but we load 3v8 subset.
43
+ # If model is 10-class, we need to pick actual digit indices.
44
+ # Let's check model output size.
45
+ out_dim = logits.shape[1]
46
+ if out_dim == 10:
47
+ p = torch.softmax(logits, dim=1)[0, 8].item()
48
+ else:
49
+ p = torch.softmax(logits, dim=1)[0, 1].item()
50
+ probs_8.append(p)
51
+
52
+ # SVD Reconstruction Error
53
+ flat = img_interp.view(1, -1).numpy()
54
+ rec = svd.inverse_transform(svd.transform(flat - mean)) + mean
55
+ rec_errors.append(np.linalg.norm(flat - rec))
56
+
57
+ viz.plot_interpolation_dynamics(alphas, probs_8, rec_errors, 'fig_05_interpolation.png')
58
+
59
+ def run_quantifying_manifold_collapse():
60
+ print("\n--- Running Experiment 7: Quantifying Manifold Collapse ---")
61
+ X_train, y_train = utils.load_data_split(dataset_name="mnist", train=True, digits=[3, 8], flatten=True)
62
+ X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False, digits=[3, 8], flatten=True)
63
+
64
+ X_train_np, y_train_np = X_train.numpy(), y_train.numpy()
65
+ X_test_np, y_test_np = X_test.numpy(), y_test.numpy()
66
+
67
+ # 1. k-NN on Raw Pixel Space (784D)
68
+ knn_raw = KNeighborsClassifier(n_neighbors=5)
69
+ knn_raw.fit(X_train_np, y_train_np)
70
+ acc_raw = accuracy_score(y_test_np, knn_raw.predict(X_test_np))
71
+
72
+ # 2. k-NN on SVD-reduced Space (10D)
73
+ svd = TruncatedSVD(n_components=10, random_state=42)
74
+ X_train_svd = svd.fit_transform(X_train_np)
75
+ X_test_svd = svd.transform(X_test_np)
76
+
77
+ knn_svd = KNeighborsClassifier(n_neighbors=5)
78
+ knn_svd.fit(X_train_svd, y_train_np)
79
+ acc_svd = accuracy_score(y_test_np, knn_svd.predict(X_test_svd))
80
+
81
+ # 3. Visualization: SVD vs UMAP
82
+ try:
83
+ import umap
84
+ reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
85
+ X_umap = reducer.fit_transform(X_test_np)
86
+ viz.plot_manifold_comparison(X_test_svd, X_umap, y_test_np, acc_svd, acc_raw, 'fig_08_manifold_collapse.png')
87
+ except Exception as e:
88
+ print(f"Warning: Manifold visualization failed: {e}")
89
+
90
+ print(f"Manifold Collapse Quantification Results:")
91
+ print(f" - Raw 784D k-NN Accuracy: {acc_raw:.4f}")
92
+ print(f" - SVD 10D k-NN Accuracy: {acc_svd:.4f}")
93
+ print(f" - Accuracy Loss: {acc_raw - acc_svd:.4f}")
94
+
95
+ def main():
96
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
97
+ run_interpolation_analysis(device)
98
+ run_quantifying_manifold_collapse()
99
+ print("\nExperiment 02 Mechanistic Analysis (Refined) Completed.")
100
+
101
+ if __name__ == "__main__":
102
+ main()
experiments/02_mnist_cnn_confusion.py DELETED
@@ -1,68 +0,0 @@
1
- # Exp 02 – MNIST 10-class CNN confusion matrix
2
-
3
- import os
4
-
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- import seaborn as sns
8
- import torch
9
- from sklearn.metrics import confusion_matrix, accuracy_score
10
- from torchvision import datasets, transforms
11
- from matplotlib.colors import LinearSegmentedColormap
12
-
13
- from src.hybrid_model import SimpleCNN
14
- from src import config
15
-
16
- GRAY_LIGHT = "#D8DEE9"
17
- BLUE_LIGHT = "#88C0D0"
18
-
19
-
20
- def load_mnist_test():
21
- transform = transforms.Compose([transforms.ToTensor()])
22
- testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
23
- X_test = testset.data.float() / 255.0
24
- y_test = testset.targets
25
- X_test = X_test.unsqueeze(1) # (N, 1, 28, 28)
26
- return X_test, y_test
27
-
28
-
29
- def main():
30
- # Load model
31
- model = SimpleCNN(num_classes=10)
32
- model.load_state_dict(torch.load(config.CNN_MODEL_PATH, map_location="cpu"))
33
- model.eval()
34
-
35
- # Load data
36
- X_test, y_test = load_mnist_test()
37
-
38
- with torch.no_grad():
39
- logits = model(X_test)
40
- preds = torch.argmax(logits, dim=1)
41
-
42
- y_true = y_test.numpy()
43
- y_pred = preds.numpy()
44
-
45
- acc = accuracy_score(y_true, y_pred)
46
- print(f"MNIST CNN Accuracy: {acc*100:.2f}%")
47
-
48
- cm = confusion_matrix(y_true, y_pred, normalize="true")
49
-
50
- plt.figure(figsize=(10, 8))
51
- cmap = LinearSegmentedColormap.from_list(
52
- "NBodyCNN",
53
- [GRAY_LIGHT, BLUE_LIGHT],
54
- )
55
- sns.heatmap(cm, annot=True, fmt=".1%", cmap=cmap, xticklabels=list(range(10)), yticklabels=list(range(10)))
56
- plt.title(f"MNIST CNN Confusion Matrix (Acc={acc:.2%})")
57
- plt.xlabel("Predicted")
58
- plt.ylabel("True")
59
- plt.tight_layout()
60
-
61
- out_path = os.path.join(config.RESULTS_DIR, "fig_04_mnist_cnn_confusion.png")
62
- plt.savefig(out_path, dpi=300)
63
- plt.close()
64
- print(f"Saved {out_path}")
65
-
66
-
67
- if __name__ == "__main__":
68
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/03_mechanistic_investigation.py DELETED
@@ -1,244 +0,0 @@
1
-
2
- import torch
3
- import torch.nn as nn
4
- import torch.optim as optim
5
- from torch.utils.data import TensorDataset, DataLoader
6
- import numpy as np
7
- import matplotlib.pyplot as plt
8
- from sklearn.decomposition import TruncatedSVD
9
- import cv2
10
- import os
11
- import ssl
12
- import torchvision
13
-
14
- from src.hybrid_model import SimpleCNN
15
- from src import config
16
-
17
- # --- Configuration ---
18
- BLUE_LIGHT = "#88C0D0"
19
- BLUE_DEEP = "#5E81AC"
20
- BATCH_SIZE = 64
21
- EPOCHS = 5
22
-
23
-
24
- def load_mnist_38():
25
- """Load MNIST and filter for 3 vs 8."""
26
- ssl._create_default_https_context = ssl._create_unverified_context
27
- transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
28
- trainset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
29
- testset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
30
-
31
- def filter_38(dataset):
32
- mask = (dataset.targets == 3) | (dataset.targets == 8)
33
- data = dataset.data[mask].unsqueeze(1).float() / 255.0
34
- targets = dataset.targets[mask]
35
- targets = torch.where(targets == 3, torch.tensor(0), torch.tensor(1))
36
- return data, targets
37
-
38
- X_train, y_train = filter_38(trainset)
39
- X_test, y_test = filter_38(testset)
40
- return X_train, y_train, X_test, y_test
41
-
42
-
43
- # --- Training Helper ---
44
- def train_model(X_train, y_train):
45
- model = SimpleCNN(num_classes=2)
46
- criterion = nn.CrossEntropyLoss()
47
- optimizer = optim.Adam(model.parameters(), lr=0.001)
48
- dataset = TensorDataset(X_train, y_train)
49
- loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
50
-
51
- print("Training CNN for Analysis...")
52
- model.train()
53
- for epoch in range(EPOCHS):
54
- for inputs, labels in loader:
55
- optimizer.zero_grad()
56
- outputs = model(inputs)
57
- loss = criterion(outputs, labels)
58
- loss.backward()
59
- optimizer.step()
60
- return model
61
-
62
- # --- Experiment 3: Interpolation ---
63
- def run_interpolation_analysis(model, svd, X_test, y_test, svd_mean=None):
64
- print("\n--- Running Exp 3: Interpolation Analysis ---")
65
-
66
- # 1. Find a good pairs of 3 and 8
67
- # We want a 'canonical' 3 and 8
68
- idx_3 = (y_test == 0).nonzero(as_tuple=True)[0][0]
69
- idx_8 = (y_test == 1).nonzero(as_tuple=True)[0][0]
70
-
71
- img_3 = X_test[idx_3] # (1, 28, 28)
72
- img_8 = X_test[idx_8]
73
-
74
- # 2. Generate Interpolation steps
75
- alphas = np.linspace(0, 1, 11)
76
- interpolated_imgs = []
77
- cnn_probs_8 = []
78
- svd_errors = []
79
-
80
- print("Computing metrics along interpolation path...")
81
- for alpha in alphas:
82
- # Linear Interpolation
83
- img_interp = (1 - alpha) * img_3 + alpha * img_8
84
- interpolated_imgs.append(img_interp.squeeze().numpy())
85
-
86
- # CNN Prediction
87
- with torch.no_grad():
88
- img_tensor = img_interp.unsqueeze(0) # (1, 1, 28, 28)
89
- logits = model(img_tensor)
90
- probs = torch.softmax(logits, dim=1)
91
- cnn_probs_8.append(probs[0, 1].item())
92
-
93
- # SVD Reconstruction (using the passed svd model)
94
- # SVD expects (N, 784), mean-centered to match training
95
- img_flat = img_interp.view(1, -1).numpy()
96
- img_centered = img_flat - svd_mean if svd_mean is not None else img_flat
97
- img_pca = svd.transform(img_centered)
98
- img_rec = svd.inverse_transform(img_pca)
99
- if svd_mean is not None:
100
- img_rec = img_rec + svd_mean
101
-
102
- # Reconstruction Error (L2)
103
- rec_err = np.linalg.norm(img_flat - img_rec)
104
- svd_errors.append(rec_err)
105
-
106
- # 3. Plotting
107
- plt.figure(figsize=(12, 6))
108
-
109
- # Plot Images
110
- for i, img in enumerate(interpolated_imgs):
111
- plt.subplot(3, 11, i + 1)
112
- plt.imshow(img, cmap='gray')
113
- plt.axis('off')
114
- if i == 0: plt.title("Start (3)")
115
- if i == 10: plt.title("End (8)")
116
-
117
- # Plot Curves
118
- plt.subplot(3, 1, 2)
119
- plt.plot(alphas, cnn_probs_8, marker='o', color=BLUE_LIGHT, label='CNN Prob(Class=8)')
120
- plt.ylabel('CNN Probability')
121
- plt.grid(True)
122
- plt.legend()
123
-
124
- plt.subplot(3, 1, 3)
125
- plt.plot(alphas, svd_errors, marker='s', color=BLUE_DEEP, label='SVD Rec. Error')
126
- plt.xlabel('Interpolation Alpha (0=3, 1=8)')
127
- plt.ylabel('SVD Error')
128
- plt.grid(True)
129
- plt.legend()
130
-
131
- plt.tight_layout()
132
- plt.savefig(os.path.join(config.RESULTS_DIR, 'fig_05_interpolation_analysis.png'))
133
- plt.close()
134
- print("Saved fig_05_interpolation_analysis.png")
135
-
136
- return interpolated_imgs # Return for Exp 4 use
137
-
138
- # --- Experiment 4: Explainability (Grad-CAM vs SVD) ---
139
- def run_explainability_analysis(model, svd, interpolated_imgs, svd_mean=None):
140
- print("\n--- Running Exp 4: Explainability Analysis ---")
141
-
142
- # Pick the middle ambiguous image (alpha=0.5)
143
- middle_idx = 5
144
- img_ambiguous = torch.tensor(interpolated_imgs[middle_idx]).unsqueeze(0).unsqueeze(0) # (1, 1, 28, 28)
145
-
146
- # 1. CNN Grad-CAM
147
- # Hook into last conv layer
148
- gradients = []
149
- activations = []
150
-
151
- def backward_hook(module, grad_input, grad_output):
152
- gradients.append(grad_output[0])
153
-
154
- def forward_hook(module, input, output):
155
- activations.append(output)
156
-
157
- # Register hooks on conv2
158
- handle_b = model.conv2.register_full_backward_hook(backward_hook)
159
- handle_f = model.conv2.register_forward_hook(forward_hook)
160
-
161
- # Forward & Backward
162
- model.eval()
163
- logits = model(img_ambiguous)
164
- # Target class 8 (index 1) for visualization
165
- logits[0, 1].backward()
166
-
167
- # Generate Heatmap
168
- grads = gradients[0].cpu().data.numpy()[0] # (32, 7, 7)
169
- fmaps = activations[0].cpu().data.numpy()[0] # (32, 7, 7)
170
-
171
- weights = np.mean(grads, axis=(1, 2)) # Global Average Pooling
172
- cam = np.zeros(fmaps.shape[1:], dtype=np.float32)
173
-
174
- for i, w in enumerate(weights):
175
- cam += w * fmaps[i]
176
-
177
- cam = np.maximum(cam, 0)
178
- cam = cv2.resize(cam, (28, 28))
179
- cam = cam - np.min(cam)
180
- cam = cam / np.max(cam)
181
-
182
- # 2. SVD Reconstruction
183
- img_flat = img_ambiguous.view(1, -1).numpy()
184
- img_centered = img_flat - svd_mean if svd_mean is not None else img_flat
185
- img_pca = svd.transform(img_centered)
186
- img_rec = svd.inverse_transform(img_pca)
187
- if svd_mean is not None:
188
- img_rec = img_rec + svd_mean
189
- img_rec = img_rec.reshape(28, 28)
190
-
191
- # 3. Plot Comparison
192
- plt.figure(figsize=(10, 4))
193
-
194
- # Original Ambiguous
195
- plt.subplot(1, 3, 1)
196
- plt.imshow(img_ambiguous.squeeze(), cmap='gray')
197
- plt.title("Ambiguous Input (Alpha=0.5)")
198
- plt.axis('off')
199
-
200
- # CNN Attention
201
- plt.subplot(1, 3, 2)
202
- plt.imshow(img_ambiguous.squeeze(), cmap='gray')
203
- plt.imshow(cam, cmap='jet', alpha=0.5) # Overlay
204
- plt.title("CNN Attention (Grad-CAM)")
205
- plt.axis('off')
206
-
207
- # SVD Reconstruction
208
- plt.subplot(1, 3, 3)
209
- plt.imshow(img_rec, cmap='gray')
210
- plt.title("SVD Reconstruction")
211
- plt.axis('off')
212
-
213
- plt.tight_layout()
214
- plt.savefig(os.path.join(config.RESULTS_DIR, 'fig_06_explainability.png'))
215
- plt.close()
216
- print("Saved fig_06_explainability.png")
217
-
218
- handle_b.remove()
219
- handle_f.remove()
220
-
221
- def main():
222
- # Load Data
223
- X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = load_mnist_38()
224
-
225
- # Train CNN
226
- cnn_model = train_model(X_train_tensor, y_train_tensor)
227
-
228
- # Fit SVD (on train data 3 vs 8)
229
- print("Fitting SVD on 3 vs 8...")
230
- X_train_np = X_train_tensor.view(-1, 784).numpy()
231
- # Mean-center for consistency with hybrid model's SVD layer
232
- svd_mean = np.mean(X_train_np, axis=0)
233
- X_train_centered = X_train_np - svd_mean
234
- svd = TruncatedSVD(n_components=10, random_state=42)
235
- svd.fit(X_train_centered)
236
-
237
- # Run Experiments
238
- interp_imgs = run_interpolation_analysis(cnn_model, svd, X_test_tensor, y_test_tensor, svd_mean)
239
- run_explainability_analysis(cnn_model, svd, interp_imgs, svd_mean)
240
-
241
- print("\nDeep Dive Analysis Completed.")
242
-
243
- if __name__ == "__main__":
244
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/04_robustness_limit.py DELETED
@@ -1,187 +0,0 @@
1
-
2
- import torch
3
- import torch.nn as nn
4
- import torch.optim as optim
5
- from torch.utils.data import TensorDataset, DataLoader
6
- import numpy as np
7
- import matplotlib.pyplot as plt
8
- from sklearn.decomposition import TruncatedSVD
9
- from sklearn.linear_model import LogisticRegression
10
- import torchvision
11
- import torchvision.transforms as transforms
12
- import os
13
-
14
- from src import config
15
-
16
- # --- Configuration ---
17
- BLUE_LIGHT = "#88C0D0"
18
- BLUE_DEEP = "#5E81AC"
19
- BATCH_SIZE = 64
20
-
21
- # --- Model (Same as before) ---
22
- from src.hybrid_model import SimpleCNN
23
-
24
- # --- Comparison Evaluation ---
25
-
26
- def load_data():
27
- transform = transforms.Compose([transforms.ToTensor()])
28
- trainset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
29
- testset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
30
-
31
- def filter_38(dataset):
32
- mask = (dataset.targets == 3) | (dataset.targets == 8)
33
- data = dataset.data[mask].unsqueeze(1).float() / 255.0
34
- targets = dataset.targets[mask]
35
- targets = torch.where(targets == 3, torch.tensor(0), torch.tensor(1))
36
- return data, targets
37
-
38
- X_train, y_train = filter_38(trainset)
39
- X_test, y_test = filter_38(testset)
40
- return X_train, y_train, X_test, y_test
41
-
42
- def add_noise(images, noise_level):
43
- """Add Gaussian noise."""
44
- noise = torch.randn_like(images) * noise_level
45
- noisy_imgs = images + noise
46
- return torch.clamp(noisy_imgs, 0, 1)
47
-
48
- def add_blur(images, kernel_size):
49
- """Add Gaussian blur."""
50
- if kernel_size <= 1: return images
51
- blur_fn = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(0.1 + 0.3 * (kernel_size//2)))
52
- return blur_fn(images)
53
-
54
- def evaluate_models(cnn, svd_clf, svd_transform, X_test, y_test, svd_mean=None):
55
- # CNN Eval
56
- cnn.eval()
57
- with torch.no_grad():
58
- logits = cnn(X_test)
59
- preds = torch.argmax(logits, dim=1)
60
- cnn_acc = (preds == y_test).float().mean().item()
61
-
62
- # SVD Eval
63
- X_flat = X_test.view(X_test.size(0), -1).numpy()
64
- if svd_mean is not None:
65
- X_flat = X_flat - svd_mean
66
- X_pca = svd_transform.transform(X_flat)
67
- y_pred_svd = svd_clf.predict(X_pca)
68
- svd_acc = np.mean(y_pred_svd == y_test.numpy())
69
-
70
- return cnn_acc, svd_acc
71
-
72
- def main():
73
- print("Loading Data...")
74
- X_train, y_train, X_test, y_test = load_data()
75
-
76
- # 1. Train Models (Quickly)
77
- print("Training CNN Baseline...")
78
- cnn = SimpleCNN(num_classes=2)
79
- optimizer = optim.Adam(cnn.parameters(), lr=0.001)
80
- criterion = nn.CrossEntropyLoss()
81
- dataset = TensorDataset(X_train, y_train)
82
- loader = DataLoader(dataset, batch_size=64, shuffle=True)
83
-
84
- cnn.train()
85
- for _ in range(3): # 3 Epochs enough for 99%
86
- for x, y in loader:
87
- optimizer.zero_grad()
88
- loss = criterion(cnn(x), y)
89
- loss.backward()
90
- optimizer.step()
91
-
92
- print("Training SVD Baseline...")
93
- X_train_flat = X_train.view(X_train.size(0), -1).numpy()
94
- # Mean-center for consistency with hybrid model's SVD layer
95
- svd_mean = np.mean(X_train_flat, axis=0)
96
- X_train_centered = X_train_flat - svd_mean
97
- svd = TruncatedSVD(n_components=10, random_state=42)
98
- X_train_pca = svd.fit_transform(X_train_centered)
99
-
100
- clf = LogisticRegression(max_iter=500)
101
- clf.fit(X_train_pca, y_train.numpy())
102
-
103
- # 2. Noise Experiment
104
- noise_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
105
- cnn_accs_noise = []
106
- svd_accs_noise = []
107
-
108
- print("Running Noise Experiments...")
109
- for nl in noise_levels:
110
- X_test_noisy = add_noise(X_test, nl)
111
- ca, sa = evaluate_models(cnn, clf, svd, X_test_noisy, y_test, svd_mean)
112
- cnn_accs_noise.append(ca)
113
- svd_accs_noise.append(sa)
114
- print(f"Noise {nl}: CNN={ca:.2f}, SVD={sa:.2f}")
115
-
116
- # 3. Blur Experiment
117
- blur_kernels = [1, 3, 5, 7, 9, 11]
118
- cnn_accs_blur = []
119
- svd_accs_blur = []
120
-
121
- print("Running Blur Experiments...")
122
- for k in blur_kernels:
123
- X_test_blur = add_blur(X_test, k)
124
- ca, sa = evaluate_models(cnn, clf, svd, X_test_blur, y_test, svd_mean)
125
- cnn_accs_blur.append(ca)
126
- svd_accs_blur.append(sa)
127
- print(f"Blur K={k}: CNN={ca:.2f}, SVD={sa:.2f}")
128
-
129
- # 4. Plots
130
- plt.figure(figsize=(12, 5))
131
-
132
- # Noise Plot
133
- plt.subplot(1, 2, 1)
134
- plt.plot(noise_levels, cnn_accs_noise, marker='o', color=BLUE_LIGHT, label='CNN')
135
- plt.plot(noise_levels, svd_accs_noise, marker='s', color=BLUE_DEEP, label='SVD')
136
- plt.xlabel(r'Gaussian Noise ($\sigma$)')
137
- plt.ylabel('Accuracy')
138
- plt.title('Robustness vs Noise')
139
- plt.legend()
140
- plt.grid(True)
141
-
142
- # Blur Plot
143
- plt.subplot(1, 2, 2)
144
- plt.plot(blur_kernels, cnn_accs_blur, marker='o', color=BLUE_LIGHT, label='CNN')
145
- plt.plot(blur_kernels, svd_accs_blur, marker='s', color=BLUE_DEEP, label='SVD')
146
- plt.xlabel('Blur Kernel Size')
147
- plt.ylabel('Accuracy')
148
- plt.title('Robustness vs Blur')
149
- plt.legend()
150
- plt.grid(True)
151
-
152
- plt.savefig(os.path.join(config.RESULTS_DIR, 'fig_07_degradation_curves.png'))
153
- plt.close()
154
-
155
- # 5. Visualizing Breakdown
156
- # Find a sample that CNN gets right at noise=0.1 but wrong at noise=0.5
157
- print("Generating Breakdown Visuals...")
158
- noise_high = 0.6
159
- X_test_high_noise = add_noise(X_test, noise_high)
160
-
161
- cnn.eval()
162
- logits = cnn(X_test_high_noise)
163
- preds = torch.argmax(logits, dim=1)
164
-
165
- # Find failures
166
- failures = (preds != y_test).nonzero(as_tuple=True)[0]
167
- if len(failures) > 0:
168
- idx = failures[0]
169
- img_clean = X_test[idx].squeeze()
170
- img_noisy = X_test_high_noise[idx].squeeze()
171
-
172
- plt.figure(figsize=(8, 4))
173
- plt.subplot(1, 2, 1)
174
- plt.imshow(img_clean, cmap='gray')
175
- plt.title(f"Clean (True: {y_test[idx]})")
176
-
177
- plt.subplot(1, 2, 2)
178
- plt.imshow(img_noisy, cmap='gray')
179
- plt.title(f"Noisy $\\sigma$={noise_high}\nCNN Pred: {preds[idx]}")
180
-
181
- plt.savefig(os.path.join(config.RESULTS_DIR, 'fig_08_breakdown_point.png'))
182
- plt.close()
183
-
184
- print("Experiment 5 Completed.")
185
-
186
- if __name__ == "__main__":
187
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/05_manifold_learning.py DELETED
@@ -1,103 +0,0 @@
1
-
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
- import seaborn as sns
5
- from sklearn.decomposition import TruncatedSVD
6
- from sklearn.metrics import silhouette_score
7
- import umap
8
- import torch
9
- import torchvision
10
- from torchvision import transforms
11
- import os
12
-
13
- from matplotlib.colors import ListedColormap
14
- from src import config
15
-
16
- BLUE_DEEP = "#5E81AC"
17
- ORANGE = "#D08770"
18
-
19
- # Configure styling
20
- sns.set_style("whitegrid")
21
- plt.rcParams.update({'font.size': 12})
22
-
23
- def load_data():
24
- """Load and filter MNIST for digits 3 and 8."""
25
- print("Loading MNIST data...")
26
- # Fix for Mac SSL certificate issue
27
- import ssl
28
- ssl._create_default_https_context = ssl._create_unverified_context
29
-
30
- transform = transforms.Compose([transforms.ToTensor()])
31
- # Try loading without download first
32
- try:
33
- dataset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=False, transform=transform)
34
- except:
35
- dataset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
36
-
37
- # Filter for 3 and 8
38
- idx = (dataset.targets == 3) | (dataset.targets == 8)
39
- dataset.targets = dataset.targets[idx]
40
- dataset.data = dataset.data[idx]
41
-
42
- # Flatten images (N, 784)
43
- X = dataset.data.numpy().reshape(-1, 28*28).astype(np.float32) / 255.0
44
- y = dataset.targets.numpy()
45
-
46
- print(f"Dataset loaded: {X.shape} samples (Classes: {np.unique(y)})")
47
- return X, y
48
-
49
- def run_experiment():
50
- X, y = load_data()
51
-
52
- # Subsample for UMAP speed if necessary (though MNIST 3/8 is small enough ~12k samples)
53
- # We'll use full set for accurate density
54
-
55
- print("\n--- Running SVD Projection (Linear) ---")
56
- # Mean-center for consistency with project convention
57
- X_centered = X - X.mean(axis=0)
58
- svd = TruncatedSVD(n_components=2, random_state=42)
59
- X_svd = svd.fit_transform(X_centered)
60
- sil_svd = silhouette_score(X_svd, y)
61
- print(f"SVD Silhouette Score: {sil_svd:.4f}")
62
-
63
- print("\n--- Running UMAP Projection (Non-linear) ---")
64
- reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, min_dist=0.1)
65
- X_umap = reducer.fit_transform(X)
66
- sil_umap = silhouette_score(X_umap, y)
67
- print(f"UMAP Silhouette Score: {sil_umap:.4f}")
68
-
69
- # Plotting
70
- print("\nGenerating Comparison Plot...")
71
- fig, axes = plt.subplots(1, 2, figsize=(16, 7))
72
-
73
- # Custom cmap for 3 and 8
74
- # 3 is lower index, 8 is higher. We map 3 -> blue_deep, 8 -> orange
75
- cmap = ListedColormap([BLUE_DEEP, ORANGE])
76
-
77
- # Plot SVD
78
- scatter_svd = axes[0].scatter(X_svd[:, 0], X_svd[:, 1], c=y, cmap=cmap, alpha=0.5, s=2)
79
- axes[0].set_title(f"Linear Projection (SVD)\nSilhouette Score: {sil_svd:.3f}", fontsize=14)
80
- axes[0].set_xlabel("PC 1")
81
- axes[0].set_ylabel("PC 2")
82
-
83
- # Plot UMAP
84
- scatter_umap = axes[1].scatter(X_umap[:, 0], X_umap[:, 1], c=y, cmap=cmap, alpha=0.5, s=2)
85
- axes[1].set_title(f"Manifold Learning (UMAP)\nSilhouette Score: {sil_umap:.3f}", fontsize=14)
86
- axes[1].set_xlabel("UMAP 1")
87
- axes[1].set_ylabel("UMAP 2")
88
-
89
- # Add legend
90
- legend1 = axes[0].legend(*scatter_svd.legend_elements(), title="Digits")
91
- axes[0].add_artist(legend1)
92
- legend2 = axes[1].legend(*scatter_umap.legend_elements(), title="Digits")
93
- axes[1].add_artist(legend2)
94
-
95
- plt.tight_layout()
96
- save_path = os.path.join(config.RESULTS_DIR, "fig_09_manifold_comparison.png")
97
- plt.savefig(save_path, dpi=150)
98
- print(f"Plot saved to: {save_path}")
99
-
100
- return sil_svd, sil_umap
101
-
102
- if __name__ == "__main__":
103
- run_experiment()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/06_fashion_mnist_baseline.py DELETED
@@ -1,115 +0,0 @@
1
- # Exp 06 – Fashion-MNIST SVD baseline (replicating MNIST findings)
2
-
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import seaborn as sns
6
- from sklearn.decomposition import TruncatedSVD
7
- from sklearn.linear_model import LogisticRegression
8
- from sklearn.metrics import confusion_matrix, accuracy_score
9
- from matplotlib.colors import LinearSegmentedColormap
10
- import torchvision
11
- import torchvision.transforms as transforms
12
- import os
13
-
14
- from src import config
15
-
16
- # --- Configuration ---
17
- GRAY_LIGHT = "#D8DEE9"
18
- BLUE_DEEP = "#5E81AC"
19
-
20
- # Fashion-MNIST class names
21
- CLASS_NAMES = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
22
- 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
23
-
24
- def load_fashion_mnist():
25
- """Load and flatten Fashion-MNIST data."""
26
- print("Loading Fashion-MNIST...")
27
- transform = transforms.Compose([transforms.ToTensor()])
28
-
29
- trainset = torchvision.datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=True, download=True, transform=transform)
30
- testset = torchvision.datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=False, download=True, transform=transform)
31
-
32
- X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
33
- y_train = trainset.targets.numpy()
34
-
35
- X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
36
- y_test = testset.targets.numpy()
37
-
38
- return X_train, y_train, X_test, y_test
39
-
40
- def plot_confusion_matrix(y_true, y_pred, labels, filename, title):
41
- """Draws and saves a confusion matrix (normalized by row = recall)."""
42
- cm = confusion_matrix(y_true, y_pred, normalize='true')
43
- plt.figure(figsize=(12, 10))
44
- cmap = LinearSegmentedColormap.from_list("NBodyBlue", [GRAY_LIGHT, BLUE_DEEP])
45
- sns.heatmap(cm, annot=True, fmt='.1%', cmap=cmap,
46
- xticklabels=labels, yticklabels=labels)
47
- plt.title(title)
48
- plt.xlabel('Predicted')
49
- plt.ylabel('True')
50
- plt.tight_layout()
51
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300)
52
- plt.close()
53
- print(f"Saved {filename}")
54
-
55
- def analyze_confusion_pairs(cm, class_names, top_k=5):
56
- """Identify the most confused class pairs."""
57
- n = len(class_names)
58
- confusions = []
59
-
60
- for i in range(n):
61
- for j in range(n):
62
- if i != j:
63
- confusions.append((cm[i, j], class_names[i], class_names[j]))
64
-
65
- confusions.sort(reverse=True)
66
-
67
- print(f"\nTop {top_k} Confused Pairs:")
68
- for rate, true_class, pred_class in confusions[:top_k]:
69
- print(f" {true_class} → {pred_class}: {rate*100:.2f}%")
70
-
71
- return confusions[:top_k]
72
-
73
- def run_svd_baseline(X_train, y_train, X_test, y_test):
74
- """Run SVD + Logistic Regression baseline."""
75
- print("\n--- Running SVD Baseline (Fashion-MNIST) ---")
76
-
77
- n_components = 20
78
- print(f"Reducing dimension to {n_components} using SVD...")
79
-
80
- # Mean-center for consistency with hybrid model's SVD layer
81
- mean = np.mean(X_train, axis=0)
82
- X_train_centered = X_train - mean
83
- X_test_centered = X_test - mean
84
- svd = TruncatedSVD(n_components=n_components, random_state=42)
85
- X_train_svd = svd.fit_transform(X_train_centered)
86
- X_test_svd = svd.transform(X_test_centered)
87
-
88
- clf = LogisticRegression(max_iter=1000)
89
- clf.fit(X_train_svd, y_train)
90
- y_pred = clf.predict(X_test_svd)
91
-
92
- acc = accuracy_score(y_test, y_pred)
93
- print(f"SVD+LR Accuracy: {acc*100:.2f}%")
94
-
95
- # Confusion Matrix
96
- cm = confusion_matrix(y_test, y_pred, normalize='true')
97
- plot_confusion_matrix(y_test, y_pred, CLASS_NAMES,
98
- 'fig_11_fashion_svd_confusion.png',
99
- f'Fashion-MNIST SVD Confusion (k={n_components}, Acc={acc:.2%})')
100
-
101
- # Analyze confusions
102
- analyze_confusion_pairs(cm, CLASS_NAMES)
103
-
104
- return svd, clf
105
-
106
- def main():
107
- X_train, y_train, X_test, y_test = load_fashion_mnist()
108
-
109
- svd, clf = run_svd_baseline(X_train, y_train, X_test, y_test)
110
-
111
- print("\nExperiment 06 Complete.")
112
- print(f"Results saved to {config.RESULTS_DIR}")
113
-
114
- if __name__ == "__main__":
115
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/07_fashion_cnn_verification.py DELETED
@@ -1,145 +0,0 @@
1
- # Exp 07 – Fashion-MNIST CNN vs SVD confusion comparison
2
-
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import seaborn as sns
6
- import torch
7
- import torch.nn as nn
8
- import torch.optim as optim
9
- from torch.utils.data import DataLoader
10
- from torchvision import datasets, transforms
11
- from sklearn.metrics import confusion_matrix, accuracy_score
12
- from matplotlib.colors import LinearSegmentedColormap
13
- import os
14
-
15
- from src import config
16
- from src.hybrid_model import SimpleCNN
17
-
18
- # --- Configuration ---
19
- GRAY_LIGHT = "#D8DEE9"
20
- BLUE_DEEP = "#5E81AC"
21
-
22
- CLASS_NAMES = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
23
- 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
24
-
25
- def train_cnn(train_loader, epochs=10):
26
- """Train CNN on Fashion-MNIST."""
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- model = SimpleCNN(num_classes=10).to(device)
29
-
30
- criterion = nn.CrossEntropyLoss()
31
- optimizer = optim.Adam(model.parameters(), lr=0.001)
32
-
33
- model.train()
34
- for epoch in range(epochs):
35
- running_loss = 0.0
36
- for inputs, labels in train_loader:
37
- inputs, labels = inputs.to(device), labels.to(device)
38
-
39
- optimizer.zero_grad()
40
- outputs = model(inputs)
41
- loss = criterion(outputs, labels)
42
- loss.backward()
43
- optimizer.step()
44
-
45
- running_loss += loss.item()
46
-
47
- print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}")
48
-
49
- return model
50
-
51
- def evaluate_cnn(model, test_loader):
52
- """Evaluate CNN and return predictions."""
53
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
- model.eval()
55
-
56
- all_preds = []
57
- all_targets = []
58
-
59
- with torch.no_grad():
60
- for inputs, labels in test_loader:
61
- inputs = inputs.to(device)
62
- outputs = model(inputs)
63
- preds = outputs.argmax(dim=1)
64
-
65
- all_preds.extend(preds.cpu().numpy())
66
- all_targets.extend(labels.numpy())
67
-
68
- return np.array(all_preds), np.array(all_targets)
69
-
70
- def plot_confusion_matrix(y_true, y_pred, labels, filename, title):
71
- """Draws and saves a confusion matrix (normalized by row = recall)."""
72
- cm = confusion_matrix(y_true, y_pred, normalize='true')
73
- plt.figure(figsize=(12, 10))
74
- cmap = LinearSegmentedColormap.from_list("NBodyBlue", [GRAY_LIGHT, BLUE_DEEP])
75
- sns.heatmap(cm, annot=True, fmt='.1%', cmap=cmap,
76
- xticklabels=labels, yticklabels=labels)
77
- plt.title(title)
78
- plt.xlabel('Predicted')
79
- plt.ylabel('True')
80
- plt.tight_layout()
81
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300)
82
- plt.close()
83
- print(f"Saved {filename}")
84
-
85
- def analyze_confusion_improvement(svd_confusions, cnn_cm, class_names):
86
- """Compare SVD vs CNN on the most confused pairs."""
87
- print("\n--- Confusion Comparison: SVD vs CNN ---")
88
- print(f"{'Pair':<25} {'SVD Error':<12} {'CNN Error':<12} {'Improvement':<12}")
89
- print("-" * 60)
90
-
91
- for svd_rate, true_class, pred_class in svd_confusions:
92
- i = class_names.index(true_class)
93
- j = class_names.index(pred_class)
94
- cnn_rate = cnn_cm[i, j]
95
- improvement = (svd_rate - cnn_rate) / svd_rate * 100 if svd_rate > 0 else 0
96
-
97
- print(f"{true_class} → {pred_class:<10} {svd_rate*100:>8.2f}% {cnn_rate*100:>8.2f}% {improvement:>8.1f}%")
98
-
99
- def main():
100
- print("Loading Fashion-MNIST...")
101
- transform = transforms.Compose([transforms.ToTensor()])
102
-
103
- trainset = datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=True, download=True, transform=transform)
104
- testset = datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=False, download=True, transform=transform)
105
-
106
- train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
107
- test_loader = DataLoader(testset, batch_size=1000, shuffle=False)
108
-
109
- print("\n--- Training CNN on Fashion-MNIST ---")
110
- model = train_cnn(train_loader, epochs=10)
111
-
112
- print("\n--- Evaluating CNN ---")
113
- y_pred, y_true = evaluate_cnn(model, test_loader)
114
-
115
- acc = accuracy_score(y_true, y_pred)
116
- print(f"CNN Accuracy: {acc*100:.2f}%")
117
-
118
- # Confusion Matrix
119
- cm = confusion_matrix(y_true, y_pred, normalize='true')
120
- plot_confusion_matrix(y_true, y_pred, CLASS_NAMES,
121
- 'fig_12_fashion_cnn_confusion.png',
122
- f'Fashion-MNIST CNN Confusion (Acc={acc:.2%})')
123
-
124
- # Save model for later use
125
- model_path = os.path.join("models", "cnn_fashion.pth")
126
- os.makedirs("models", exist_ok=True)
127
- torch.save(model.state_dict(), model_path)
128
- print(f"Model saved to {model_path}")
129
-
130
- # Compare with SVD (hardcoded top confusions from experiment 06)
131
- # These will be updated after running experiment 06
132
- svd_confusions = [
133
- (0.15, 'Shirt', 'T-shirt'),
134
- (0.12, 'Shirt', 'Coat'),
135
- (0.10, 'Pullover', 'Coat'),
136
- (0.08, 'T-shirt', 'Shirt'),
137
- (0.06, 'Coat', 'Pullover'),
138
- ]
139
-
140
- analyze_confusion_improvement(svd_confusions, cm, CLASS_NAMES)
141
-
142
- print("\nExperiment 07 Complete.")
143
-
144
- if __name__ == "__main__":
145
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/08_hybrid_robustness.py DELETED
@@ -1,253 +0,0 @@
1
- # Exp 08 – MNIST 10-class hybrid robustness under Gaussian noise (multi-seed)
2
-
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import torch
6
- from torchvision import datasets, transforms
7
- from sklearn.decomposition import TruncatedSVD
8
- from sklearn.linear_model import LogisticRegression
9
- from sklearn.metrics import accuracy_score
10
- import os
11
- import json
12
- from scipy.ndimage import gaussian_filter
13
-
14
- from src.hybrid_model import SimpleCNN, HybridSVDCNN, create_svd_layer
15
- from src import config
16
-
17
- # --- Configuration ---
18
- BLUE_LIGHT = "#88C0D0"
19
- BLUE_DEEP = "#5E81AC"
20
- ORANGE = "#D08770"
21
- RED = "#BF616A"
22
-
23
- SVD_K = 20
24
- NOISE_LEVELS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
25
- SEEDS = [42, 123, 456]
26
- BLUR_SIGMA = 1.5
27
-
28
-
29
- def load_mnist():
30
- """Load MNIST test data."""
31
- transform = transforms.Compose([transforms.ToTensor()])
32
- testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
33
- trainset = datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
34
-
35
- X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
36
- y_train = trainset.targets.numpy()
37
-
38
- X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
39
- y_test = testset.targets.numpy()
40
-
41
- return X_train, y_train, X_test, y_test
42
-
43
-
44
- def add_gaussian_noise(X, sigma):
45
- """Add Gaussian noise to images."""
46
- noise = np.random.randn(*X.shape) * sigma
47
- X_noisy = X + noise
48
- return np.clip(X_noisy, 0, 1)
49
-
50
-
51
- def set_seeds(seed: int) -> None:
52
- np.random.seed(seed)
53
- torch.manual_seed(seed)
54
-
55
-
56
- def evaluate_svd(svd, clf, X_test, y_test, mean):
57
- """Evaluate SVD+LR model (with mean-centering)."""
58
- X_test_centered = X_test - mean
59
- X_test_svd = svd.transform(X_test_centered)
60
- y_pred = clf.predict(X_test_svd)
61
- return accuracy_score(y_test, y_pred)
62
-
63
-
64
- def evaluate_cnn(model, X_test, y_test, device):
65
- """Evaluate CNN model."""
66
- model.eval()
67
- X_tensor = torch.tensor(X_test.reshape(-1, 1, 28, 28), dtype=torch.float32).to(device)
68
-
69
- with torch.no_grad():
70
- outputs = model(X_tensor)
71
- preds = outputs.argmax(dim=1).cpu().numpy()
72
-
73
- return accuracy_score(y_test, preds)
74
-
75
-
76
- def evaluate_blur_cnn(cnn, X_test, y_test, device, blur_sigma=BLUR_SIGMA):
77
- """Sanity baseline: if SVD is just smoothing, blur should do equally well."""
78
- X_blurred = np.array([
79
- gaussian_filter(img.reshape(28, 28), sigma=blur_sigma).flatten()
80
- for img in X_test
81
- ])
82
- X_blurred = np.clip(X_blurred, 0, 1)
83
-
84
- return evaluate_cnn(cnn, X_blurred, y_test, device)
85
-
86
-
87
- def load_pretrained_cnn(device) -> SimpleCNN:
88
- """Load the pretrained CNN from models/ for stable, reproducible evaluation."""
89
- model = SimpleCNN(num_classes=10).to(device)
90
- model.load_state_dict(torch.load(config.CNN_MODEL_PATH, map_location=device))
91
- model.eval()
92
- return model
93
-
94
-
95
- def train_svd_model(X_train, y_train, n_components=SVD_K):
96
- """SVD+LR baseline. Mean-centered to match hybrid model's SVD layer."""
97
- mean = np.mean(X_train, axis=0)
98
- X_centered = X_train - mean
99
-
100
- svd = TruncatedSVD(n_components=n_components, random_state=42)
101
- X_train_svd = svd.fit_transform(X_centered)
102
-
103
- clf = LogisticRegression(max_iter=1000)
104
- clf.fit(X_train_svd, y_train)
105
-
106
- return svd, clf, mean
107
-
108
-
109
- def plot_robustness_comparison(results, filename, std_results=None):
110
- """Plot robustness curves for all models, optionally with error bands."""
111
- plt.figure(figsize=(10, 6))
112
-
113
- colors = {
114
- 'CNN': BLUE_LIGHT,
115
- 'SVD': BLUE_DEEP,
116
- 'Hybrid': RED,
117
- 'Blur+CNN': ORANGE,
118
- }
119
- markers = {'CNN': 'o', 'SVD': 's', 'Hybrid': '^', 'Blur+CNN': 'D'}
120
-
121
- for model_name, accuracies in results.items():
122
- plt.plot(NOISE_LEVELS, accuracies,
123
- color=colors[model_name],
124
- marker=markers[model_name],
125
- linewidth=2, markersize=8,
126
- label=f'{model_name}')
127
- # Add shaded error band if std is available
128
- if std_results and model_name in std_results:
129
- mean = np.array(accuracies)
130
- std = np.array(std_results[model_name])
131
- plt.fill_between(NOISE_LEVELS, mean - std, mean + std,
132
- color=colors[model_name], alpha=0.15)
133
-
134
- plt.xlabel('Noise Level (σ)', fontsize=12)
135
- plt.ylabel('Accuracy', fontsize=12)
136
- plt.title('Model Robustness Under Gaussian Noise', fontsize=14)
137
- plt.legend(fontsize=11)
138
- plt.grid(True, alpha=0.3)
139
- plt.ylim(0.0, 1.05)
140
-
141
- # Add annotations
142
- plt.axhline(y=0.9, color='gray', linestyle='--', alpha=0.5)
143
- plt.text(0.65, 0.91, '90% threshold', fontsize=10, color='gray')
144
-
145
- plt.tight_layout()
146
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300)
147
- plt.close()
148
- print(f"Saved {filename}")
149
-
150
-
151
- def run_single_seed(seed, X_train, y_train, X_test, y_test, device, cnn):
152
- """Run one full evaluation pass with a given random seed."""
153
- set_seeds(seed)
154
- print(f"\n{'='*50}")
155
- print(f" Seed = {seed}")
156
- print(f"{'='*50}")
157
-
158
- # Train SVD+LR (with mean-centering, consistent with hybrid)
159
- svd, svd_clf, svd_mean = train_svd_model(X_train, y_train)
160
-
161
- # Create Hybrid Model
162
- svd_layer = create_svd_layer(X_train, n_components=SVD_K)
163
- hybrid = HybridSVDCNN(svd_layer, cnn).to(device)
164
-
165
- results = {'CNN': [], 'SVD': [], 'Hybrid': [], 'Blur+CNN': []}
166
-
167
- for sigma in NOISE_LEVELS:
168
- X_test_noisy = add_gaussian_noise(X_test, sigma)
169
-
170
- acc_svd = evaluate_svd(svd, svd_clf, X_test_noisy, y_test, svd_mean)
171
- acc_cnn = evaluate_cnn(cnn, X_test_noisy, y_test, device)
172
- acc_hybrid = evaluate_cnn(hybrid, X_test_noisy, y_test, device)
173
- acc_blur = evaluate_blur_cnn(cnn, X_test_noisy, y_test, device)
174
-
175
- results['SVD'].append(acc_svd)
176
- results['CNN'].append(acc_cnn)
177
- results['Hybrid'].append(acc_hybrid)
178
- results['Blur+CNN'].append(acc_blur)
179
-
180
- print(f" σ={sigma:.1f} SVD={acc_svd*100:5.2f}% CNN={acc_cnn*100:5.2f}% "
181
- f"Hybrid={acc_hybrid*100:5.2f}% Blur+CNN={acc_blur*100:5.2f}%")
182
-
183
- return results
184
-
185
-
186
- def main():
187
- print("Loading MNIST...")
188
- X_train, y_train, X_test, y_test = load_mnist()
189
-
190
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191
- print(f"Using device: {device}")
192
-
193
- # Load pretrained CNN (shared across all seeds)
194
- print("\nLoading pretrained CNN...")
195
- cnn = load_pretrained_cnn(device)
196
-
197
- # Run over multiple seeds
198
- all_runs = [] # list of {model_name: [acc_per_sigma]}
199
- for seed in SEEDS:
200
- run_results = run_single_seed(seed, X_train, y_train, X_test, y_test, device, cnn)
201
- all_runs.append(run_results)
202
-
203
- # Aggregate: compute mean ± std across seeds
204
- model_names = ['CNN', 'SVD', 'Hybrid', 'Blur+CNN']
205
- mean_results = {}
206
- std_results = {}
207
- for name in model_names:
208
- all_accs = np.array([run[name] for run in all_runs]) # (n_seeds, n_noise_levels)
209
- mean_results[name] = all_accs.mean(axis=0).tolist()
210
- std_results[name] = all_accs.std(axis=0).tolist()
211
-
212
- # Plot comparison (mean with error band)
213
- plot_robustness_comparison(mean_results, 'fig_10_hybrid_robustness.png', std_results)
214
-
215
- # Save raw numbers for reproducibility / app usage
216
- out_json = {
217
- "dataset": "MNIST",
218
- "task": "10-class",
219
- "noise_levels": NOISE_LEVELS,
220
- "seeds": SEEDS,
221
- "results_mean": {k: [round(x, 4) for x in v] for k, v in mean_results.items()},
222
- "results_std": {k: [round(x, 4) for x in v] for k, v in std_results.items()},
223
- "results": {k: [round(x, 4) for x in v] for k, v in mean_results.items()}, # backward compat
224
- "svd_components": SVD_K,
225
- "svd_centering": True,
226
- "blur_sigma": BLUR_SIGMA,
227
- "cnn_epochs": 5,
228
- "notes": "Mean over 3 seeds. SVD baseline uses explicit mean-centering for consistency with hybrid layer.",
229
- }
230
- json_path = os.path.join(config.RESULTS_DIR, "robustness_mnist_noise.json")
231
- with open(json_path, "w", encoding="utf-8") as f:
232
- json.dump(out_json, f, indent=2)
233
- print(f"\nSaved robustness JSON to {json_path}")
234
-
235
- # Summary table
236
- print("\n--- Summary (mean ± std across seeds) ---")
237
- header = f"{'Model':<12} {'Clean':<16} {'σ=0.3':<16} {'σ=0.5':<16} {'σ=0.7':<16}"
238
- print(header)
239
- print("-" * len(header))
240
- for name in model_names:
241
- m = mean_results[name]
242
- s = std_results[name]
243
- # indices: 0=clean, 3=0.3, 5=0.5, 7=0.7
244
- print(f"{name:<12} "
245
- f"{m[0]*100:5.2f}±{s[0]*100:.2f}% "
246
- f"{m[3]*100:5.2f}±{s[3]*100:.2f}% "
247
- f"{m[5]*100:5.2f}±{s[5]*100:.2f}% "
248
- f"{m[7]*100:5.2f}±{s[7]*100:.2f}%")
249
-
250
- print("\nExperiment 08 Complete.")
251
-
252
- if __name__ == "__main__":
253
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/09_fashion_hybrid_robustness.py DELETED
@@ -1,189 +0,0 @@
1
- # Exp 09 – Fashion-MNIST hybrid robustness under Gaussian noise
2
-
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import torch
6
- from torchvision import transforms
7
- import torchvision
8
- from sklearn.decomposition import TruncatedSVD
9
- from sklearn.linear_model import LogisticRegression
10
- from sklearn.metrics import accuracy_score
11
- import os
12
- import sys
13
- import json
14
-
15
- from src.hybrid_model import SimpleCNN, HybridSVDCNN, create_svd_layer
16
- from src import config
17
-
18
- # --- Configuration ---
19
- BLUE_LIGHT = "#88C0D0"
20
- BLUE_DEEP = "#5E81AC"
21
- RED = "#BF616A"
22
-
23
- SVD_K = 20
24
- NOISE_LEVELS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
25
-
26
- def load_fashion_mnist():
27
- """Load Fashion-MNIST test data."""
28
- print("Loading Fashion-MNIST...")
29
- transform = transforms.Compose([transforms.ToTensor()])
30
-
31
- # Use ./data/fashion to match other scripts
32
- trainset = torchvision.datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=True, download=True, transform=transform)
33
- testset = torchvision.datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=False, download=True, transform=transform)
34
-
35
- X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
36
- y_train = trainset.targets.numpy()
37
-
38
- X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
39
- y_test = testset.targets.numpy()
40
-
41
- return X_train, y_train, X_test, y_test
42
-
43
- def add_gaussian_noise(X, sigma):
44
- """Add Gaussian noise to images."""
45
- noise = np.random.randn(*X.shape) * sigma
46
- X_noisy = X + noise
47
- return np.clip(X_noisy, 0, 1)
48
-
49
-
50
- def set_seeds(seed: int) -> None:
51
- np.random.seed(seed)
52
- torch.manual_seed(seed)
53
-
54
-
55
- def load_pretrained_cnn(device) -> SimpleCNN:
56
- """Load the pretrained Fashion-MNIST CNN from models/ for stable evaluation."""
57
- model = SimpleCNN(num_classes=10).to(device)
58
- model.load_state_dict(torch.load(config.CNN_FASHION_MODEL_PATH, map_location=device))
59
- model.eval()
60
- return model
61
-
62
- def evaluate_svd(svd, clf, X_test, y_test, mean=None):
63
- """Evaluate SVD+LR model (with optional mean-centering)."""
64
- X = X_test - mean if mean is not None else X_test
65
- X_test_svd = svd.transform(X)
66
- y_pred = clf.predict(X_test_svd)
67
- return accuracy_score(y_test, y_pred)
68
-
69
- def evaluate_cnn(model, X_test, y_test, device):
70
- """Evaluate CNN model."""
71
- model.eval()
72
- X_tensor = torch.tensor(X_test.reshape(-1, 1, 28, 28), dtype=torch.float32).to(device)
73
-
74
- with torch.no_grad():
75
- outputs = model(X_tensor)
76
- preds = outputs.argmax(dim=1).cpu().numpy()
77
-
78
- return accuracy_score(y_test, preds)
79
-
80
- def train_svd_model(X_train, y_train, n_components=SVD_K):
81
- """Train SVD + Logistic Regression with explicit mean-centering."""
82
- print(f"Training SVD (k={n_components}) on shape {X_train.shape}...")
83
- sys.stdout.flush()
84
- # Mean-center for consistency with hybrid model's SVD layer
85
- mean = np.mean(X_train, axis=0)
86
- X_centered = X_train - mean
87
- svd = TruncatedSVD(n_components=n_components, algorithm='randomized', n_iter=5, random_state=42)
88
- X_train_svd = svd.fit_transform(X_centered)
89
- print("SVD fitted. Training Logistic Regression...")
90
-
91
- # Increase iterations to avoid premature convergence warnings on 60k samples
92
- clf = LogisticRegression(max_iter=1000)
93
- clf.fit(X_train_svd, y_train)
94
-
95
- print(f"SVD Explained Variance: {svd.explained_variance_ratio_.sum()*100:.2f}%")
96
- return svd, clf, mean
97
-
98
- def plot_robustness_comparison(results, filename):
99
- """Plot robustness curves for all models."""
100
- plt.figure(figsize=(10, 6))
101
-
102
- colors = {'CNN': BLUE_LIGHT, 'SVD': BLUE_DEEP, 'Hybrid': RED}
103
- markers = {'CNN': 'o', 'SVD': 's', 'Hybrid': '^'}
104
-
105
- for model_name, accuracies in results.items():
106
- plt.plot(NOISE_LEVELS, accuracies,
107
- color=colors[model_name],
108
- marker=markers[model_name],
109
- linewidth=2, markersize=8,
110
- label=f'{model_name}')
111
-
112
- plt.xlabel('Noise Level (σ)', fontsize=12)
113
- plt.ylabel('Accuracy', fontsize=12)
114
- plt.title('Fashion-MNIST: Model Robustness Under Gaussian Noise', fontsize=14)
115
- plt.legend(fontsize=11)
116
- plt.grid(True, alpha=0.3)
117
- plt.ylim(0.0, 1.05)
118
-
119
- plt.tight_layout()
120
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300)
121
- plt.close()
122
- print(f"Saved {filename}")
123
-
124
- def main():
125
- set_seeds(42)
126
- # 1. Load Data
127
- X_train, y_train, X_test, y_test = load_fashion_mnist()
128
-
129
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130
-
131
- # 2. Train Models
132
- # SVD
133
- svd, svd_clf, svd_mean = train_svd_model(X_train, y_train)
134
-
135
- # CNN (pretrained)
136
- cnn = load_pretrained_cnn(device)
137
-
138
- # Hybrid
139
- svd_layer = create_svd_layer(X_train, n_components=SVD_K)
140
- hybrid = HybridSVDCNN(svd_layer, cnn).to(device)
141
-
142
- # 3. Evaluate
143
- print("\n--- Evaluating Robustness on Fashion-MNIST ---")
144
- results = {'CNN': [], 'SVD': [], 'Hybrid': []}
145
-
146
- for sigma in NOISE_LEVELS:
147
- print(f"\nNoise σ = {sigma}")
148
- X_test_noisy = add_gaussian_noise(X_test, sigma)
149
-
150
- # SVD
151
- acc_svd = evaluate_svd(svd, svd_clf, X_test_noisy, y_test, svd_mean)
152
- results['SVD'].append(acc_svd)
153
- print(f" SVD: {acc_svd*100:.2f}%")
154
-
155
- # CNN
156
- acc_cnn = evaluate_cnn(cnn, X_test_noisy, y_test, device)
157
- results['CNN'].append(acc_cnn)
158
- print(f" CNN: {acc_cnn*100:.2f}%")
159
-
160
- # Hybrid
161
- acc_hybrid = evaluate_cnn(hybrid, X_test_noisy, y_test, device)
162
- results['Hybrid'].append(acc_hybrid)
163
- print(f" Hybrid: {acc_hybrid*100:.2f}%")
164
-
165
- # Save raw numbers for reproducibility / app usage
166
- out_json = {
167
- "dataset": "Fashion-MNIST",
168
- "task": "10-class",
169
- "noise_levels": NOISE_LEVELS,
170
- "results": {k: [float(x) for x in v] for k, v in results.items()},
171
- "svd_components": 20,
172
- "cnn_epochs": 5,
173
- "notes": "Numbers are evaluated on Fashion-MNIST test set with test-time Gaussian noise.",
174
- }
175
- json_path = os.path.join(config.RESULTS_DIR, "robustness_fashion_noise.json")
176
- with open(json_path, "w", encoding="utf-8") as f:
177
- json.dump(out_json, f, indent=2)
178
- print(f"Saved robustness JSON to {json_path}")
179
-
180
- # 5. Summary Table
181
- print("\n--- Summary (Fashion-MNIST) ---")
182
- print(f"{'Model':<10} {'Clean':<10} {'σ=0.3':<10} {'σ=0.5':<10} {'σ=0.7':<10}")
183
- print("-" * 50)
184
- for model_name in ['CNN', 'SVD', 'Hybrid']:
185
- accs = results[model_name]
186
- print(f"{model_name:<10} {accs[0]*100:>6.2f}% {accs[3]*100:>6.2f}% {accs[5]*100:>6.2f}% {accs[7]*100:>6.2f}%")
187
-
188
- if __name__ == "__main__":
189
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/10_ablation_study.py DELETED
@@ -1,344 +0,0 @@
1
- # Exp 10 – Ablation Study: Depth vs Non-linearity Contribution
2
- # Systematically test the independent contributions of depth and non-linearity
3
-
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- import torch
7
- import torch.nn as nn
8
- import torch.optim as optim
9
- from torch.utils.data import TensorDataset, DataLoader
10
- from torchvision import datasets, transforms
11
- from sklearn.linear_model import LogisticRegression
12
- from sklearn.metrics import accuracy_score
13
- from sklearn.model_selection import train_test_split
14
- import os
15
-
16
- from src import config
17
-
18
- # --- Configuration ---
19
- BLUE_DEEP = "#5E81AC"
20
- ORANGE = "#D08770"
21
- GREEN = "#A3BE8C"
22
- RED = "#BF616A"
23
- PURPLE = "#B48EAD"
24
-
25
- SEED = 42
26
- BATCH_SIZE = 64
27
- EPOCHS = 10
28
- LR = 0.001
29
-
30
-
31
- def set_seeds(seed):
32
- """Set all random seeds for reproducibility."""
33
- np.random.seed(seed)
34
- torch.manual_seed(seed)
35
- torch.cuda.manual_seed_all(seed)
36
-
37
-
38
- class ShallowLinear(nn.Module):
39
- """Single layer linear model (no activation)."""
40
- def __init__(self, num_classes=10):
41
- super().__init__()
42
- self.fc = nn.Linear(784, num_classes)
43
-
44
- def forward(self, x):
45
- x = x.view(-1, 784)
46
- return self.fc(x)
47
-
48
-
49
- class ShallowNonLinear(nn.Module):
50
- """Single layer with ReLU activation."""
51
- def __init__(self, num_classes=10, hidden_size=128):
52
- super().__init__()
53
- self.fc1 = nn.Linear(784, hidden_size)
54
- self.fc2 = nn.Linear(hidden_size, num_classes)
55
-
56
- def forward(self, x):
57
- x = x.view(-1, 784)
58
- x = torch.relu(self.fc1(x))
59
- return self.fc2(x)
60
-
61
-
62
- class DeepLinear(nn.Module):
63
- """Two hidden layers without activation (identity mapping)."""
64
- def __init__(self, num_classes=10, hidden_size=128):
65
- super().__init__()
66
- self.fc1 = nn.Linear(784, hidden_size)
67
- self.fc2 = nn.Linear(hidden_size, hidden_size)
68
- self.fc3 = nn.Linear(hidden_size, num_classes)
69
-
70
- def forward(self, x):
71
- x = x.view(-1, 784)
72
- x = self.fc1(x) # No activation
73
- x = self.fc2(x) # No activation
74
- return self.fc3(x)
75
-
76
-
77
- class DeepNonLinear(nn.Module):
78
- """Two hidden layers with ReLU activation (similar to CNN complexity)."""
79
- def __init__(self, num_classes=10, hidden_size=128):
80
- super().__init__()
81
- self.fc1 = nn.Linear(784, hidden_size)
82
- self.fc2 = nn.Linear(hidden_size, hidden_size)
83
- self.fc3 = nn.Linear(hidden_size, num_classes)
84
-
85
- def forward(self, x):
86
- x = x.view(-1, 784)
87
- x = torch.relu(self.fc1(x))
88
- x = torch.relu(self.fc2(x))
89
- return self.fc3(x)
90
-
91
-
92
- class SimpleCNN(nn.Module):
93
- """2-conv CNN for comparison (from hybrid_model)."""
94
- def __init__(self, num_classes=10):
95
- super().__init__()
96
- self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
97
- self.pool = nn.MaxPool2d(2, 2)
98
- self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
99
- self.fc1 = nn.Linear(32 * 7 * 7, 128)
100
- self.fc2 = nn.Linear(128, num_classes)
101
-
102
- def forward(self, x):
103
- x = self.pool(torch.relu(self.conv1(x)))
104
- x = self.pool(torch.relu(self.conv2(x)))
105
- x = x.view(-1, 32 * 7 * 7)
106
- x = torch.relu(self.fc1(x))
107
- x = self.fc2(x)
108
- return x
109
-
110
-
111
- def load_mnist():
112
- """Load MNIST train/test data."""
113
- transform = transforms.Compose([transforms.ToTensor()])
114
- trainset = datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
115
- testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
116
-
117
- X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
118
- y_train = trainset.targets.numpy()
119
-
120
- X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
121
- y_test = testset.targets.numpy()
122
-
123
- return X_train, y_train, X_test, y_test
124
-
125
-
126
- def train_model(model, X_train, y_train, X_val, y_val, device, epochs=EPOCHS):
127
- """Train a PyTorch model with validation tracking."""
128
- model = model.to(device)
129
- criterion = nn.CrossEntropyLoss()
130
- optimizer = optim.Adam(model.parameters(), lr=LR)
131
-
132
- # Convert to tensors
133
- X_train_t = torch.tensor(X_train, dtype=torch.float32)
134
- y_train_t = torch.tensor(y_train, dtype=torch.long)
135
- X_val_t = torch.tensor(X_val, dtype=torch.float32)
136
- y_val_t = torch.tensor(y_val, dtype=torch.long)
137
-
138
- # Reshape for CNN if needed
139
- if isinstance(model, SimpleCNN):
140
- X_train_t = X_train_t.view(-1, 1, 28, 28)
141
- X_val_t = X_val_t.view(-1, 1, 28, 28)
142
-
143
- train_dataset = TensorDataset(X_train_t, y_train_t)
144
- train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
145
-
146
- history = {'train_acc': [], 'val_acc': []}
147
-
148
- for epoch in range(epochs):
149
- model.train()
150
- train_correct = 0
151
- train_total = 0
152
-
153
- for inputs, labels in train_loader:
154
- inputs, labels = inputs.to(device), labels.to(device)
155
- optimizer.zero_grad()
156
- outputs = model(inputs)
157
- loss = criterion(outputs, labels)
158
- loss.backward()
159
- optimizer.step()
160
-
161
- _, predicted = outputs.max(1)
162
- train_total += labels.size(0)
163
- train_correct += predicted.eq(labels).sum().item()
164
-
165
- train_acc = 100.0 * train_correct / train_total
166
-
167
- # Validation
168
- model.eval()
169
- with torch.no_grad():
170
- X_val_batch = X_val_t.to(device)
171
- y_val_batch = y_val_t.to(device)
172
- outputs = model(X_val_batch)
173
- _, predicted = outputs.max(1)
174
- val_acc = 100.0 * predicted.eq(y_val_batch).sum().item() / len(y_val_batch)
175
-
176
- history['train_acc'].append(train_acc)
177
- history['val_acc'].append(val_acc)
178
-
179
- if (epoch + 1) % 2 == 0:
180
- print(f" Epoch {epoch+1}/{epochs}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
181
-
182
- return model, history
183
-
184
-
185
- def evaluate_model(model, X_test, y_test, device):
186
- """Evaluate model on test set."""
187
- model.eval()
188
- X_test_t = torch.tensor(X_test, dtype=torch.float32)
189
-
190
- if isinstance(model, SimpleCNN):
191
- X_test_t = X_test_t.view(-1, 1, 28, 28)
192
-
193
- with torch.no_grad():
194
- X_test_batch = X_test_t.to(device)
195
- outputs = model(X_test_batch)
196
- _, predicted = outputs.max(1)
197
- accuracy = 100.0 * predicted.eq(torch.tensor(y_test).to(device)).sum().item() / len(y_test)
198
-
199
- return accuracy
200
-
201
-
202
- def plot_ablation_results(results, filename):
203
- """Plot ablation study results."""
204
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
205
-
206
- # Bar chart: Test accuracy
207
- models = list(results.keys())
208
- test_accs = [results[m]['test_acc'] for m in models]
209
- colors = [BLUE_DEEP, GREEN, PURPLE, ORANGE, RED]
210
-
211
- ax1.bar(range(len(models)), test_accs, color=colors, alpha=0.8)
212
- ax1.set_xticks(range(len(models)))
213
- ax1.set_xticklabels(models, rotation=30, ha='right')
214
- ax1.set_ylabel('Test Accuracy (%)')
215
- ax1.set_title('Ablation Study: Architecture Comparison')
216
- ax1.grid(axis='y', alpha=0.3)
217
- ax1.set_ylim([85, 100])
218
-
219
- # Add value labels
220
- for i, acc in enumerate(test_accs):
221
- ax1.text(i, acc + 0.5, f'{acc:.2f}%', ha='center', va='bottom', fontsize=9)
222
-
223
- # Learning curves
224
- for i, (model_name, data) in enumerate(results.items()):
225
- if 'val_acc' in data:
226
- epochs = range(1, len(data['val_acc']) + 1)
227
- ax2.plot(epochs, data['val_acc'], label=model_name, color=colors[i], linewidth=2)
228
-
229
- ax2.set_xlabel('Epoch')
230
- ax2.set_ylabel('Validation Accuracy (%)')
231
- ax2.set_title('Training Dynamics')
232
- ax2.legend()
233
- ax2.grid(alpha=0.3)
234
-
235
- plt.tight_layout()
236
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
237
- plt.close()
238
- print(f"Saved {filename}")
239
-
240
-
241
- def main():
242
- set_seeds(SEED)
243
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
244
- print(f"Using device: {device}\n")
245
-
246
- # Load data
247
- print("Loading MNIST...")
248
- X_train_full, y_train_full, X_test, y_test = load_mnist()
249
-
250
- # Split train into train/val
251
- X_train, X_val, y_train, y_val = train_test_split(
252
- X_train_full, y_train_full, test_size=0.2, random_state=SEED, stratify=y_train_full
253
- )
254
- print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}\n")
255
-
256
- results = {}
257
-
258
- # 1. Shallow Linear
259
- print("="*60)
260
- print("Training: Shallow Linear (baseline)")
261
- print("="*60)
262
- model = ShallowLinear(num_classes=10)
263
- model, history = train_model(model, X_train, y_train, X_val, y_val, device)
264
- test_acc = evaluate_model(model, X_test, y_test, device)
265
- results['Shallow Linear'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
266
- print(f"Test Accuracy: {test_acc:.2f}%\n")
267
-
268
- # 2. Shallow Non-Linear
269
- print("="*60)
270
- print("Training: Shallow Non-Linear (+ ReLU)")
271
- print("="*60)
272
- model = ShallowNonLinear(num_classes=10)
273
- model, history = train_model(model, X_train, y_train, X_val, y_val, device)
274
- test_acc = evaluate_model(model, X_test, y_test, device)
275
- results['Shallow NonLinear'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
276
- print(f"Test Accuracy: {test_acc:.2f}%\n")
277
-
278
- # 3. Deep Linear
279
- print("="*60)
280
- print("Training: Deep Linear (+ Depth, no ReLU)")
281
- print("="*60)
282
- model = DeepLinear(num_classes=10)
283
- model, history = train_model(model, X_train, y_train, X_val, y_val, device)
284
- test_acc = evaluate_model(model, X_test, y_test, device)
285
- results['Deep Linear'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
286
- print(f"Test Accuracy: {test_acc:.2f}%\n")
287
-
288
- # 4. Deep Non-Linear
289
- print("="*60)
290
- print("Training: Deep Non-Linear (+ Depth + ReLU)")
291
- print("="*60)
292
- model = DeepNonLinear(num_classes=10)
293
- model, history = train_model(model, X_train, y_train, X_val, y_val, device)
294
- test_acc = evaluate_model(model, X_test, y_test, device)
295
- results['Deep NonLinear'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
296
- print(f"Test Accuracy: {test_acc:.2f}%\n")
297
-
298
- # 5. CNN (for reference)
299
- print("="*60)
300
- print("Training: CNN (convolutional + non-linear)")
301
- print("="*60)
302
- model = SimpleCNN(num_classes=10)
303
- # Reshape data for CNN
304
- X_train_cnn = X_train.reshape(-1, 28, 28)
305
- X_val_cnn = X_val.reshape(-1, 28, 28)
306
- X_test_cnn = X_test.reshape(-1, 28, 28)
307
- model, history = train_model(model, X_train_cnn, y_train, X_val_cnn, y_val, device)
308
- test_acc = evaluate_model(model, X_test_cnn, y_test, device)
309
- results['CNN'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
310
- print(f"Test Accuracy: {test_acc:.2f}%\n")
311
-
312
- # Summary
313
- print("="*60)
314
- print("ABLATION STUDY SUMMARY")
315
- print("="*60)
316
- for model_name, data in results.items():
317
- print(f"{model_name:20s}: {data['test_acc']:.2f}%")
318
-
319
- # Analysis
320
- print("\n" + "="*60)
321
- print("KEY INSIGHTS")
322
- print("="*60)
323
- shallow_linear = results['Shallow Linear']['test_acc']
324
- shallow_nonlinear = results['Shallow NonLinear']['test_acc']
325
- deep_linear = results['Deep Linear']['test_acc']
326
- deep_nonlinear = results['Deep NonLinear']['test_acc']
327
-
328
- nonlinearity_gain = shallow_nonlinear - shallow_linear
329
- depth_gain = deep_linear - shallow_linear
330
- combined_gain = deep_nonlinear - shallow_linear
331
-
332
- print(f"Non-linearity alone (shallow): +{nonlinearity_gain:.2f} pp")
333
- print(f"Depth alone (linear): +{depth_gain:.2f} pp")
334
- print(f"Depth + Non-linearity (combined): +{combined_gain:.2f} pp")
335
- print(f"CNN (convolutional structure): +{results['CNN']['test_acc'] - shallow_linear:.2f} pp")
336
-
337
- # Plot results
338
- plot_ablation_results(results, 'fig_13_ablation_study.png')
339
-
340
- print("\n✓ Ablation study complete!")
341
-
342
-
343
- if __name__ == "__main__":
344
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/11_learning_curves.py DELETED
@@ -1,228 +0,0 @@
1
- # Exp 11 – Learning Curves Visualization
2
- # Generate training/validation loss and accuracy curves from saved training history
3
-
4
- import matplotlib.pyplot as plt
5
- import pickle
6
- import os
7
- import numpy as np
8
-
9
- from src import config
10
-
11
- # --- Configuration ---
12
- BLUE_DEEP = "#5E81AC"
13
- ORANGE = "#D08770"
14
- GRAY_LIGHT = "#D8DEE9"
15
-
16
-
17
- def plot_learning_curves(history, title, filename):
18
- """
19
- Plot training and validation curves for loss and accuracy.
20
-
21
- Args:
22
- history: Dictionary with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc'
23
- title: Plot title
24
- filename: Output filename
25
- """
26
- epochs = range(1, len(history['train_loss']) + 1)
27
-
28
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
29
-
30
- # Loss curves
31
- ax1.plot(epochs, history['train_loss'], label='Training Loss',
32
- color=BLUE_DEEP, linewidth=2, marker='o', markersize=4)
33
- ax1.plot(epochs, history['val_loss'], label='Validation Loss',
34
- color=ORANGE, linewidth=2, marker='s', markersize=4)
35
- ax1.set_xlabel('Epoch', fontsize=12)
36
- ax1.set_ylabel('Loss', fontsize=12)
37
- ax1.set_title('Training and Validation Loss', fontsize=13, fontweight='bold')
38
- ax1.legend(loc='best', fontsize=10)
39
- ax1.grid(alpha=0.3)
40
-
41
- # Highlight best validation loss
42
- best_val_epoch = np.argmin(history['val_loss'])
43
- best_val_loss = history['val_loss'][best_val_epoch]
44
- ax1.axvline(x=best_val_epoch + 1, color='red', linestyle='--', alpha=0.5, linewidth=1)
45
- ax1.plot(best_val_epoch + 1, best_val_loss, 'r*', markersize=15,
46
- label=f'Best Val Loss: {best_val_loss:.4f} @ Epoch {best_val_epoch + 1}')
47
- ax1.legend(loc='best', fontsize=9)
48
-
49
- # Accuracy curves
50
- ax2.plot(epochs, history['train_acc'], label='Training Accuracy',
51
- color=BLUE_DEEP, linewidth=2, marker='o', markersize=4)
52
- ax2.plot(epochs, history['val_acc'], label='Validation Accuracy',
53
- color=ORANGE, linewidth=2, marker='s', markersize=4)
54
- ax2.set_xlabel('Epoch', fontsize=12)
55
- ax2.set_ylabel('Accuracy (%)', fontsize=12)
56
- ax2.set_title('Training and Validation Accuracy', fontsize=13, fontweight='bold')
57
- ax2.legend(loc='best', fontsize=10)
58
- ax2.grid(alpha=0.3)
59
-
60
- # Highlight best validation accuracy
61
- best_val_epoch = np.argmax(history['val_acc'])
62
- best_val_acc = history['val_acc'][best_val_epoch]
63
- ax2.axvline(x=best_val_epoch + 1, color='red', linestyle='--', alpha=0.5, linewidth=1)
64
- ax2.plot(best_val_epoch + 1, best_val_acc, 'r*', markersize=15,
65
- label=f'Best Val Acc: {best_val_acc:.2f}% @ Epoch {best_val_epoch + 1}')
66
- ax2.legend(loc='best', fontsize=9)
67
-
68
- plt.suptitle(title, fontsize=15, fontweight='bold', y=1.02)
69
- plt.tight_layout()
70
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
71
- plt.close()
72
- print(f"✓ Saved {filename}")
73
-
74
-
75
- def analyze_overfitting(history):
76
- """Analyze training dynamics for overfitting indicators."""
77
- train_acc = history['train_acc']
78
- val_acc = history['val_acc']
79
- train_loss = history['train_loss']
80
- val_loss = history['val_loss']
81
-
82
- # Calculate gaps
83
- final_acc_gap = train_acc[-1] - val_acc[-1]
84
- final_loss_gap = val_loss[-1] - train_loss[-1]
85
-
86
- # Check for divergence (sign of overfitting)
87
- mid_point = len(train_acc) // 2
88
- early_acc_gap = np.mean(train_acc[:mid_point]) - np.mean(val_acc[:mid_point])
89
- late_acc_gap = np.mean(train_acc[mid_point:]) - np.mean(val_acc[mid_point:])
90
- gap_increase = late_acc_gap - early_acc_gap
91
-
92
- print("\n" + "="*60)
93
- print("OVERFITTING ANALYSIS")
94
- print("="*60)
95
- print(f"Final Train Accuracy: {train_acc[-1]:.2f}%")
96
- print(f"Final Validation Accuracy: {val_acc[-1]:.2f}%")
97
- print(f"Accuracy Gap: {final_acc_gap:.2f} pp")
98
- print(f"Loss Gap: {final_loss_gap:.4f}")
99
- print(f"Gap Increase (early→late): {gap_increase:.2f} pp")
100
-
101
- if final_acc_gap > 5.0:
102
- print("\n⚠️ WARNING: Significant train-val accuracy gap detected (>5 pp)")
103
- print(" Consider: regularization, dropout, or early stopping")
104
- elif gap_increase > 2.0:
105
- print("\n⚠️ WARNING: Train-val gap widening over time")
106
- print(" Model may be starting to overfit")
107
- else:
108
- print("\n✓ No significant overfitting detected")
109
-
110
- # Best epoch analysis
111
- best_epoch = np.argmax(val_acc)
112
- total_epochs = len(val_acc)
113
- print(f"\nBest validation accuracy achieved at epoch {best_epoch + 1}/{total_epochs}")
114
- if best_epoch < total_epochs - 2:
115
- print(f"⚠️ Training continued for {total_epochs - best_epoch - 1} epochs after best model")
116
- print(" Early stopping could have saved training time")
117
-
118
- return {
119
- 'final_acc_gap': final_acc_gap,
120
- 'final_loss_gap': final_loss_gap,
121
- 'gap_increase': gap_increase,
122
- 'best_epoch': best_epoch + 1,
123
- 'total_epochs': total_epochs
124
- }
125
-
126
-
127
- def plot_comparative_curves(histories, labels, filename):
128
- """
129
- Plot multiple models' learning curves for comparison.
130
-
131
- Args:
132
- histories: List of history dictionaries
133
- labels: List of model names
134
- filename: Output filename
135
- """
136
- colors = [BLUE_DEEP, ORANGE, "#A3BE8C", "#BF616A", "#B48EAD"]
137
-
138
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
139
-
140
- for i, (history, label) in enumerate(zip(histories, labels)):
141
- epochs = range(1, len(history['val_loss']) + 1)
142
- color = colors[i % len(colors)]
143
-
144
- # Validation loss
145
- ax1.plot(epochs, history['val_loss'], label=label,
146
- color=color, linewidth=2, marker='o', markersize=3)
147
-
148
- # Validation accuracy
149
- ax2.plot(epochs, history['val_acc'], label=label,
150
- color=color, linewidth=2, marker='o', markersize=3)
151
-
152
- ax1.set_xlabel('Epoch', fontsize=12)
153
- ax1.set_ylabel('Validation Loss', fontsize=12)
154
- ax1.set_title('Validation Loss Comparison', fontsize=13, fontweight='bold')
155
- ax1.legend(loc='best', fontsize=10)
156
- ax1.grid(alpha=0.3)
157
-
158
- ax2.set_xlabel('Epoch', fontsize=12)
159
- ax2.set_ylabel('Validation Accuracy (%)', fontsize=12)
160
- ax2.set_title('Validation Accuracy Comparison', fontsize=13, fontweight='bold')
161
- ax2.legend(loc='best', fontsize=10)
162
- ax2.grid(alpha=0.3)
163
-
164
- plt.tight_layout()
165
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
166
- plt.close()
167
- print(f"✓ Saved {filename}")
168
-
169
-
170
- def main():
171
- print("="*60)
172
- print("LEARNING CURVES VISUALIZATION")
173
- print("="*60)
174
-
175
- # Load CNN training history
176
- history_path = os.path.join(config.MODELS_DIR, 'cnn_10class_history.pkl')
177
-
178
- if os.path.exists(history_path):
179
- print(f"\nLoading training history from {history_path}")
180
- with open(history_path, 'rb') as f:
181
- history = pickle.load(f)
182
-
183
- print(f"Loaded {len(history['train_loss'])} epochs of training data")
184
-
185
- # Plot learning curves
186
- plot_learning_curves(
187
- history,
188
- 'CNN Training Dynamics (MNIST 10-class)',
189
- 'fig_14_learning_curves.png'
190
- )
191
-
192
- # Analyze for overfitting
193
- analysis = analyze_overfitting(history)
194
-
195
- else:
196
- print(f"\n⚠️ Training history not found at {history_path}")
197
- print("Please run 'python src/train_models.py' first to generate training history")
198
- return
199
-
200
- # Check for Fashion-MNIST history
201
- fashion_history_path = os.path.join(config.MODELS_DIR, 'cnn_fashion_history.pkl')
202
- if os.path.exists(fashion_history_path):
203
- print(f"\nLoading Fashion-MNIST training history...")
204
- with open(fashion_history_path, 'rb') as f:
205
- fashion_history = pickle.load(f)
206
-
207
- plot_learning_curves(
208
- fashion_history,
209
- 'CNN Training Dynamics (Fashion-MNIST)',
210
- 'fig_15_learning_curves_fashion.png'
211
- )
212
-
213
- # Comparative plot
214
- print("\nGenerating comparative learning curves...")
215
- plot_comparative_curves(
216
- [history, fashion_history],
217
- ['MNIST', 'Fashion-MNIST'],
218
- 'fig_16_learning_curves_comparison.png'
219
- )
220
-
221
- print("\n" + "="*60)
222
- print("✓ Learning curves visualization complete!")
223
- print("="*60)
224
- print(f"Results saved to: {config.RESULTS_DIR}")
225
-
226
-
227
- if __name__ == "__main__":
228
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/12_roc_analysis.py DELETED
@@ -1,291 +0,0 @@
1
- # Exp 12 – ROC Curve Analysis for Hard Classification Pairs
2
- # Particularly focused on the challenging 3 vs 8 classification
3
-
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- import torch
7
- from torchvision import datasets, transforms
8
- from sklearn.decomposition import TruncatedSVD
9
- from sklearn.linear_model import LogisticRegression
10
- from sklearn.metrics import roc_curve, auc, roc_auc_score
11
- import pickle
12
- import os
13
-
14
- from src.hybrid_model import SimpleCNN
15
- from src import config
16
-
17
- # --- Configuration ---
18
- BLUE_DEEP = "#5E81AC"
19
- ORANGE = "#D08770"
20
- GREEN = "#A3BE8C"
21
- RED = "#BF616A"
22
-
23
- SEED = 42
24
-
25
-
26
- def set_seeds(seed):
27
- """Set random seeds for reproducibility."""
28
- np.random.seed(seed)
29
- torch.manual_seed(seed)
30
-
31
-
32
- def load_mnist_subset(digit_a=3, digit_b=8):
33
- """Load MNIST subset with only two specified digits."""
34
- transform = transforms.Compose([transforms.ToTensor()])
35
- testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
36
-
37
- # Filter for specific digits
38
- mask = (testset.targets == digit_a) | (testset.targets == digit_b)
39
- X = testset.data[mask].numpy().astype(np.float32) / 255.0
40
- y = testset.targets[mask].numpy()
41
-
42
- # Binary labels: 0 for digit_a, 1 for digit_b
43
- y_binary = (y == digit_b).astype(int)
44
-
45
- print(f"Loaded {len(X)} samples: {np.sum(y_binary == 0)} digit-{digit_a}, {np.sum(y_binary == 1)} digit-{digit_b}")
46
-
47
- return X, y_binary, digit_a, digit_b
48
-
49
-
50
- def get_svd_probabilities(X, svd_path=config.SVD_MODEL_PATH):
51
- """Get probability scores from SVD+LR model."""
52
- with open(svd_path, 'rb') as f:
53
- svd = pickle.load(f)
54
-
55
- # Get mean from saved model
56
- if hasattr(svd, '_train_mean'):
57
- mean = svd._train_mean
58
- else:
59
- mean = np.zeros(784)
60
-
61
- X_flat = X.reshape(-1, 784)
62
- X_centered = X_flat - mean
63
- X_svd = svd.transform(X_centered)
64
-
65
- # Train a binary classifier on 3 vs 8
66
- print("Training binary SVD+LR classifier for ROC analysis...")
67
- X_train_full, y_train_full = load_mnist_binary_train()
68
- X_train_centered = X_train_full - mean
69
- X_train_svd = svd.transform(X_train_centered)
70
-
71
- clf = LogisticRegression(random_state=SEED, max_iter=1000)
72
- clf.fit(X_train_svd, y_train_full)
73
-
74
- # Get probability scores (for positive class)
75
- probs = clf.predict_proba(X_svd)[:, 1]
76
-
77
- return probs
78
-
79
-
80
- def load_mnist_binary_train(digit_a=3, digit_b=8):
81
- """Load training data for binary classification."""
82
- transform = transforms.Compose([transforms.ToTensor()])
83
- trainset = datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
84
-
85
- mask = (trainset.targets == digit_a) | (trainset.targets == digit_b)
86
- X = trainset.data[mask].numpy().reshape(-1, 784).astype(np.float32) / 255.0
87
- y = trainset.targets[mask].numpy()
88
- y_binary = (y == digit_b).astype(int)
89
-
90
- return X, y_binary
91
-
92
-
93
- def get_cnn_probabilities(X, cnn_path=config.CNN_MODEL_PATH, digit_a=3, digit_b=8):
94
- """Get probability scores from CNN model."""
95
- device = torch.device('cpu')
96
- cnn = SimpleCNN(num_classes=10)
97
- cnn.load_state_dict(torch.load(cnn_path, map_location=device))
98
- cnn.eval()
99
-
100
- X_tensor = torch.tensor(X, dtype=torch.float32).view(-1, 1, 28, 28)
101
-
102
- with torch.no_grad():
103
- outputs = cnn(X_tensor)
104
- probs = torch.softmax(outputs, dim=1).numpy()
105
-
106
- # Extract probabilities for the two digits of interest
107
- prob_a = probs[:, digit_a]
108
- prob_b = probs[:, digit_b]
109
-
110
- # Normalize to binary probability (probability of digit_b given only these two options)
111
- binary_probs = prob_b / (prob_a + prob_b + 1e-10)
112
-
113
- return binary_probs
114
-
115
-
116
- def plot_roc_curves(results, digit_a, digit_b, filename):
117
- """
118
- Plot ROC curves for different models.
119
-
120
- Args:
121
- results: Dictionary with model names as keys and (fpr, tpr, auc) as values
122
- digit_a, digit_b: The two digits being classified
123
- filename: Output filename
124
- """
125
- plt.figure(figsize=(10, 8))
126
-
127
- colors = [BLUE_DEEP, ORANGE, GREEN, RED]
128
-
129
- for i, (model_name, (fpr, tpr, roc_auc)) in enumerate(results.items()):
130
- plt.plot(fpr, tpr, color=colors[i % len(colors)], linewidth=2.5,
131
- label=f'{model_name} (AUC = {roc_auc:.4f})', marker='o', markersize=4, markevery=20)
132
-
133
- # Random classifier baseline
134
- plt.plot([0, 1], [0, 1], color='gray', linestyle='--', linewidth=2, label='Random Classifier (AUC = 0.5000)')
135
-
136
- plt.xlim([0.0, 1.0])
137
- plt.ylim([0.0, 1.05])
138
- plt.xlabel('False Positive Rate', fontsize=13)
139
- plt.ylabel('True Positive Rate', fontsize=13)
140
- plt.title(f'ROC Curves: Digit {digit_a} vs {digit_b} Classification', fontsize=14, fontweight='bold')
141
- plt.legend(loc='lower right', fontsize=11)
142
- plt.grid(alpha=0.3)
143
-
144
- # Highlight perfect classification region
145
- plt.fill_between([0, 0, 0.1], [0.9, 1, 1], alpha=0.1, color='green', label='_nolegend_')
146
- plt.text(0.02, 0.95, 'Ideal Region', fontsize=9, color='green', alpha=0.7)
147
-
148
- plt.tight_layout()
149
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
150
- plt.close()
151
- print(f"✓ Saved {filename}")
152
-
153
-
154
- def plot_roc_zoom(results, digit_a, digit_b, filename):
155
- """Plot zoomed-in ROC curves focusing on high-sensitivity region."""
156
- plt.figure(figsize=(10, 8))
157
-
158
- colors = [BLUE_DEEP, ORANGE, GREEN, RED]
159
-
160
- for i, (model_name, (fpr, tpr, roc_auc)) in enumerate(results.items()):
161
- plt.plot(fpr, tpr, color=colors[i % len(colors)], linewidth=2.5,
162
- label=f'{model_name} (AUC = {roc_auc:.4f})', marker='o', markersize=5, markevery=10)
163
-
164
- plt.plot([0, 1], [0, 1], color='gray', linestyle='--', linewidth=2, alpha=0.5)
165
-
166
- # Zoom to interesting region
167
- plt.xlim([0.0, 0.2])
168
- plt.ylim([0.8, 1.0])
169
- plt.xlabel('False Positive Rate', fontsize=13)
170
- plt.ylabel('True Positive Rate', fontsize=13)
171
- plt.title(f'ROC Curves (Zoomed): Digit {digit_a} vs {digit_b}', fontsize=14, fontweight='bold')
172
- plt.legend(loc='lower right', fontsize=11)
173
- plt.grid(alpha=0.3)
174
-
175
- plt.tight_layout()
176
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
177
- plt.close()
178
- print(f"✓ Saved {filename}")
179
-
180
-
181
- def analyze_threshold_performance(y_true, y_probs, thresholds=[0.3, 0.5, 0.7, 0.9]):
182
- """Analyze model performance at different decision thresholds."""
183
- print("\n" + "="*60)
184
- print("THRESHOLD SENSITIVITY ANALYSIS")
185
- print("="*60)
186
- print(f"{'Threshold':<12} {'Accuracy':<12} {'TPR':<12} {'FPR':<12} {'Precision':<12}")
187
- print("-"*60)
188
-
189
- for threshold in thresholds:
190
- y_pred = (y_probs >= threshold).astype(int)
191
-
192
- # Calculate metrics
193
- tp = np.sum((y_pred == 1) & (y_true == 1))
194
- tn = np.sum((y_pred == 0) & (y_true == 0))
195
- fp = np.sum((y_pred == 1) & (y_true == 0))
196
- fn = np.sum((y_pred == 0) & (y_true == 1))
197
-
198
- accuracy = (tp + tn) / len(y_true)
199
- tpr = tp / (tp + fn) if (tp + fn) > 0 else 0
200
- fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
201
- precision = tp / (tp + fp) if (tp + fp) > 0 else 0
202
-
203
- print(f"{threshold:<12.1f} {accuracy:<12.4f} {tpr:<12.4f} {fpr:<12.4f} {precision:<12.4f}")
204
-
205
-
206
- def main():
207
- set_seeds(SEED)
208
-
209
- print("="*60)
210
- print("ROC CURVE ANALYSIS: Digit 3 vs 8")
211
- print("="*60)
212
-
213
- # Load test data
214
- print("\nLoading test data...")
215
- X_test, y_test, digit_a, digit_b = load_mnist_subset(digit_a=3, digit_b=8)
216
-
217
- results = {}
218
-
219
- # SVD+LR model
220
- print("\n" + "-"*60)
221
- print("Evaluating SVD+LR model...")
222
- print("-"*60)
223
- try:
224
- svd_probs = get_svd_probabilities(X_test)
225
- fpr_svd, tpr_svd, _ = roc_curve(y_test, svd_probs)
226
- auc_svd = auc(fpr_svd, tpr_svd)
227
- results['SVD+LR'] = (fpr_svd, tpr_svd, auc_svd)
228
- print(f"✓ SVD+LR AUC: {auc_svd:.4f}")
229
-
230
- analyze_threshold_performance(y_test, svd_probs)
231
- except Exception as e:
232
- print(f"⚠️ Could not evaluate SVD model: {e}")
233
-
234
- # CNN model
235
- print("\n" + "-"*60)
236
- print("Evaluating CNN model...")
237
- print("-"*60)
238
- try:
239
- cnn_probs = get_cnn_probabilities(X_test, digit_a=digit_a, digit_b=digit_b)
240
- fpr_cnn, tpr_cnn, _ = roc_curve(y_test, cnn_probs)
241
- auc_cnn = auc(fpr_cnn, tpr_cnn)
242
- results['CNN'] = (fpr_cnn, tpr_cnn, auc_cnn)
243
- print(f"✓ CNN AUC: {auc_cnn:.4f}")
244
-
245
- analyze_threshold_performance(y_test, cnn_probs)
246
- except Exception as e:
247
- print(f"⚠️ Could not evaluate CNN model: {e}")
248
-
249
- # Plot results
250
- if len(results) > 0:
251
- print("\n" + "="*60)
252
- print("Generating ROC visualizations...")
253
- print("="*60)
254
-
255
- plot_roc_curves(results, digit_a, digit_b, 'fig_17_roc_curves.png')
256
- plot_roc_zoom(results, digit_a, digit_b, 'fig_18_roc_curves_zoom.png')
257
-
258
- # Summary
259
- print("\n" + "="*60)
260
- print("SUMMARY")
261
- print("="*60)
262
- for model_name, (_, _, roc_auc) in results.items():
263
- print(f"{model_name:15s}: AUC = {roc_auc:.4f}")
264
-
265
- # Interpretation
266
- print("\n" + "="*60)
267
- print("INTERPRETATION")
268
- print("="*60)
269
- print("• AUC = 1.0: Perfect classifier")
270
- print("• AUC = 0.9-1.0: Excellent")
271
- print("• AUC = 0.8-0.9: Good")
272
- print("• AUC = 0.7-0.8: Fair")
273
- print("• AUC = 0.5: Random classifier")
274
-
275
- if 'CNN' in results and 'SVD+LR' in results:
276
- auc_diff = results['CNN'][2] - results['SVD+LR'][2]
277
- print(f"\nCNN advantage over SVD: {auc_diff:.4f} AUC points")
278
- if auc_diff > 0.05:
279
- print("→ CNN shows substantially better discrimination ability")
280
- elif auc_diff > 0.02:
281
- print("→ CNN shows moderately better discrimination")
282
- else:
283
- print("→ Models show similar discrimination ability")
284
-
285
- print("\n✓ ROC analysis complete!")
286
- else:
287
- print("\n⚠️ No models could be evaluated. Please train models first.")
288
-
289
-
290
- if __name__ == "__main__":
291
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/13_per_class_metrics.py DELETED
@@ -1,366 +0,0 @@
1
- # Exp 13 – Per-Class Performance Metrics
2
- # Generate detailed classification report with Precision, Recall, F1-Score for each digit
3
-
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- import pandas as pd
7
- import seaborn as sns
8
- import torch
9
- from torchvision import datasets, transforms
10
- from sklearn.decomposition import TruncatedSVD
11
- from sklearn.linear_model import LogisticRegression
12
- from sklearn.metrics import classification_report, precision_recall_fscore_support, confusion_matrix
13
- import pickle
14
- import os
15
-
16
- from src.hybrid_model import SimpleCNN
17
- from src import config
18
-
19
- # --- Configuration ---
20
- BLUE_DEEP = "#5E81AC"
21
- ORANGE = "#D08770"
22
- GREEN = "#A3BE8C"
23
- RED = "#BF616A"
24
-
25
- SEED = 42
26
-
27
-
28
- def set_seeds(seed):
29
- """Set random seeds for reproducibility."""
30
- np.random.seed(seed)
31
- torch.manual_seed(seed)
32
-
33
-
34
- def load_mnist():
35
- """Load MNIST test data."""
36
- transform = transforms.Compose([transforms.ToTensor()])
37
- testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
38
-
39
- X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
40
- y_test = testset.targets.numpy()
41
-
42
- return X_test, y_test
43
-
44
-
45
- def load_mnist_train():
46
- """Load MNIST training data for SVD."""
47
- transform = transforms.Compose([transforms.ToTensor()])
48
- trainset = datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
49
-
50
- X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
51
- y_train = trainset.targets.numpy()
52
-
53
- return X_train, y_train
54
-
55
-
56
- def evaluate_svd(X_test, y_test, svd_path=config.SVD_MODEL_PATH):
57
- """Evaluate SVD+LR model."""
58
- print("Evaluating SVD+LR model...")
59
-
60
- # Load SVD model
61
- with open(svd_path, 'rb') as f:
62
- svd = pickle.load(f)
63
-
64
- if hasattr(svd, '_train_mean'):
65
- mean = svd._train_mean
66
- else:
67
- mean = np.zeros(784)
68
-
69
- # Transform test data
70
- X_test_centered = X_test - mean
71
- X_test_svd = svd.transform(X_test_centered)
72
-
73
- # Train classifier
74
- print(" Training LogisticRegression classifier...")
75
- X_train, y_train = load_mnist_train()
76
- X_train_centered = X_train - mean
77
- X_train_svd = svd.transform(X_train_centered)
78
-
79
- clf = LogisticRegression(random_state=SEED, max_iter=1000)
80
- clf.fit(X_train_svd, y_train)
81
-
82
- # Predict
83
- y_pred = clf.predict(X_test_svd)
84
-
85
- return y_pred
86
-
87
-
88
- def evaluate_cnn(X_test, y_test, cnn_path=config.CNN_MODEL_PATH):
89
- """Evaluate CNN model."""
90
- print("Evaluating CNN model...")
91
-
92
- device = torch.device('cpu')
93
- cnn = SimpleCNN(num_classes=10)
94
- cnn.load_state_dict(torch.load(cnn_path, map_location=device))
95
- cnn.eval()
96
-
97
- X_tensor = torch.tensor(X_test, dtype=torch.float32).view(-1, 1, 28, 28)
98
-
99
- with torch.no_grad():
100
- outputs = cnn(X_tensor)
101
- y_pred = outputs.argmax(dim=1).numpy()
102
-
103
- return y_pred
104
-
105
-
106
- def create_metrics_table(y_true, y_pred, model_name):
107
- """Create per-class metrics table."""
108
- precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred, average=None)
109
-
110
- # Create DataFrame
111
- df = pd.DataFrame({
112
- 'Class': [f'Digit {i}' for i in range(10)],
113
- 'Precision': precision,
114
- 'Recall': recall,
115
- 'F1-Score': f1,
116
- 'Support': support
117
- })
118
-
119
- # Add overall metrics
120
- precision_avg, recall_avg, f1_avg, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
121
-
122
- overall = pd.DataFrame({
123
- 'Class': ['Overall (weighted)'],
124
- 'Precision': [precision_avg],
125
- 'Recall': [recall_avg],
126
- 'F1-Score': [f1_avg],
127
- 'Support': [len(y_true)]
128
- })
129
-
130
- df = pd.concat([df, overall], ignore_index=True)
131
-
132
- return df
133
-
134
-
135
- def plot_metrics_comparison(df_svd, df_cnn, filename):
136
- """Plot side-by-side comparison of metrics."""
137
- fig, axes = plt.subplots(1, 3, figsize=(18, 6))
138
-
139
- metrics = ['Precision', 'Recall', 'F1-Score']
140
- colors = [BLUE_DEEP, ORANGE]
141
-
142
- # Exclude overall row for bar chart
143
- df_svd_plot = df_svd.iloc[:-1]
144
- df_cnn_plot = df_cnn.iloc[:-1]
145
-
146
- x = np.arange(10)
147
- width = 0.35
148
-
149
- for i, metric in enumerate(metrics):
150
- ax = axes[i]
151
-
152
- svd_values = df_svd_plot[metric].values
153
- cnn_values = df_cnn_plot[metric].values
154
-
155
- ax.bar(x - width/2, svd_values, width, label='SVD+LR', color=colors[0], alpha=0.8)
156
- ax.bar(x + width/2, cnn_values, width, label='CNN', color=colors[1], alpha=0.8)
157
-
158
- ax.set_xlabel('Digit Class', fontsize=12)
159
- ax.set_ylabel(metric, fontsize=12)
160
- ax.set_title(f'{metric} by Digit', fontsize=13, fontweight='bold')
161
- ax.set_xticks(x)
162
- ax.set_xticklabels([str(i) for i in range(10)])
163
- ax.legend()
164
- ax.grid(axis='y', alpha=0.3)
165
- ax.set_ylim([0.7, 1.0])
166
-
167
- plt.tight_layout()
168
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
169
- plt.close()
170
- print(f"✓ Saved {filename}")
171
-
172
-
173
- def plot_metrics_heatmap(df_svd, df_cnn, filename):
174
- """Create heatmap showing per-class metrics."""
175
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))
176
-
177
- # Prepare data (exclude overall row)
178
- svd_data = df_svd.iloc[:-1][['Precision', 'Recall', 'F1-Score']].T
179
- cnn_data = df_cnn.iloc[:-1][['Precision', 'Recall', 'F1-Score']].T
180
-
181
- svd_data.columns = [str(i) for i in range(10)]
182
- cnn_data.columns = [str(i) for i in range(10)]
183
-
184
- # SVD heatmap
185
- sns.heatmap(svd_data, annot=True, fmt='.3f', cmap='Blues',
186
- vmin=0.7, vmax=1.0, ax=ax1, cbar_kws={'label': 'Score'})
187
- ax1.set_title('SVD+LR Per-Class Metrics', fontsize=13, fontweight='bold')
188
- ax1.set_xlabel('Digit Class', fontsize=12)
189
- ax1.set_ylabel('Metric', fontsize=12)
190
-
191
- # CNN heatmap
192
- sns.heatmap(cnn_data, annot=True, fmt='.3f', cmap='Oranges',
193
- vmin=0.7, vmax=1.0, ax=ax2, cbar_kws={'label': 'Score'})
194
- ax2.set_title('CNN Per-Class Metrics', fontsize=13, fontweight='bold')
195
- ax2.set_xlabel('Digit Class', fontsize=12)
196
- ax2.set_ylabel('Metric', fontsize=12)
197
-
198
- plt.tight_layout()
199
- plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
200
- plt.close()
201
- print(f"✓ Saved {filename}")
202
-
203
-
204
- def identify_hard_pairs(y_true, y_pred_svd, y_pred_cnn):
205
- """Identify digit pairs that are frequently confused."""
206
- print("\n" + "="*60)
207
- print("HARD CLASSIFICATION PAIRS")
208
- print("="*60)
209
-
210
- cm_svd = confusion_matrix(y_true, y_pred_svd, normalize='true')
211
- cm_cnn = confusion_matrix(y_true, y_pred_cnn, normalize='true')
212
-
213
- # Find top confusions (excluding diagonal)
214
- np.fill_diagonal(cm_svd, 0)
215
- np.fill_diagonal(cm_cnn, 0)
216
-
217
- print("\nTop 5 SVD+LR Confusions:")
218
- svd_confusions = []
219
- for i in range(10):
220
- for j in range(10):
221
- if i != j and cm_svd[i, j] > 0.01:
222
- svd_confusions.append((i, j, cm_svd[i, j]))
223
-
224
- svd_confusions.sort(key=lambda x: x[2], reverse=True)
225
- for i, (true_class, pred_class, rate) in enumerate(svd_confusions[:5], 1):
226
- print(f" {i}. {true_class} → {pred_class}: {rate*100:.2f}%")
227
-
228
- print("\nTop 5 CNN Confusions:")
229
- cnn_confusions = []
230
- for i in range(10):
231
- for j in range(10):
232
- if i != j and cm_cnn[i, j] > 0.01:
233
- cnn_confusions.append((i, j, cm_cnn[i, j]))
234
-
235
- cnn_confusions.sort(key=lambda x: x[2], reverse=True)
236
- for i, (true_class, pred_class, rate) in enumerate(cnn_confusions[:5], 1):
237
- print(f" {i}. {true_class} → {pred_class}: {rate*100:.2f}%")
238
-
239
- # Compare improvements
240
- print("\n" + "="*60)
241
- print("CNN IMPROVEMENTS OVER SVD+LR")
242
- print("="*60)
243
-
244
- improvements = []
245
- for i in range(10):
246
- for j in range(10):
247
- if i != j:
248
- improvement = cm_svd[i, j] - cm_cnn[i, j]
249
- if improvement > 0.01: # More than 1% improvement
250
- improvements.append((i, j, improvement))
251
-
252
- improvements.sort(key=lambda x: x[2], reverse=True)
253
-
254
- print("Top pairs where CNN reduced confusion:")
255
- for i, (true_class, pred_class, improvement) in enumerate(improvements[:5], 1):
256
- svd_rate = cm_svd[true_class, pred_class]
257
- cnn_rate = cm_cnn[true_class, pred_class]
258
- print(f" {i}. {true_class} → {pred_class}: {svd_rate*100:.2f}% → {cnn_rate*100:.2f}% "
259
- f"(Δ = -{improvement*100:.2f} pp)")
260
-
261
-
262
- def save_reports_to_csv(df_svd, df_cnn):
263
- """Save detailed reports to CSV files."""
264
- svd_path = os.path.join(config.RESULTS_DIR, 'per_class_metrics_svd.csv')
265
- cnn_path = os.path.join(config.RESULTS_DIR, 'per_class_metrics_cnn.csv')
266
-
267
- df_svd.to_csv(svd_path, index=False, float_format='%.4f')
268
- df_cnn.to_csv(cnn_path, index=False, float_format='%.4f')
269
-
270
- print(f"\n✓ Saved detailed metrics to:")
271
- print(f" - {svd_path}")
272
- print(f" - {cnn_path}")
273
-
274
-
275
- def main():
276
- set_seeds(SEED)
277
-
278
- print("="*60)
279
- print("PER-CLASS PERFORMANCE METRICS ANALYSIS")
280
- print("="*60)
281
-
282
- # Load test data
283
- print("\nLoading MNIST test data...")
284
- X_test, y_test = load_mnist()
285
- print(f"Loaded {len(X_test)} test samples")
286
-
287
- # Evaluate models
288
- try:
289
- print("\n" + "-"*60)
290
- y_pred_svd = evaluate_svd(X_test, y_test)
291
- print("✓ SVD+LR evaluation complete")
292
- except Exception as e:
293
- print(f"⚠️ Could not evaluate SVD model: {e}")
294
- return
295
-
296
- try:
297
- print("\n" + "-"*60)
298
- y_pred_cnn = evaluate_cnn(X_test, y_test)
299
- print("✓ CNN evaluation complete")
300
- except Exception as e:
301
- print(f"⚠️ Could not evaluate CNN model: {e}")
302
- return
303
-
304
- # Generate metrics tables
305
- print("\n" + "="*60)
306
- print("GENERATING METRICS TABLES")
307
- print("="*60)
308
-
309
- df_svd = create_metrics_table(y_test, y_pred_svd, 'SVD+LR')
310
- df_cnn = create_metrics_table(y_test, y_pred_cnn, 'CNN')
311
-
312
- # Display tables
313
- print("\n" + "-"*60)
314
- print("SVD+LR Per-Class Metrics")
315
- print("-"*60)
316
- print(df_svd.to_string(index=False, float_format='%.4f'))
317
-
318
- print("\n" + "-"*60)
319
- print("CNN Per-Class Metrics")
320
- print("-"*60)
321
- print(df_cnn.to_string(index=False, float_format='%.4f'))
322
-
323
- # Identify hard pairs
324
- identify_hard_pairs(y_test, y_pred_svd, y_pred_cnn)
325
-
326
- # Generate visualizations
327
- print("\n" + "="*60)
328
- print("GENERATING VISUALIZATIONS")
329
- print("="*60)
330
-
331
- plot_metrics_comparison(df_svd, df_cnn, 'fig_19_per_class_metrics_comparison.png')
332
- plot_metrics_heatmap(df_svd, df_cnn, 'fig_20_per_class_metrics_heatmap.png')
333
-
334
- # Save to CSV
335
- save_reports_to_csv(df_svd, df_cnn)
336
-
337
- # Summary statistics
338
- print("\n" + "="*60)
339
- print("SUMMARY STATISTICS")
340
- print("="*60)
341
-
342
- svd_overall = df_svd.iloc[-1]
343
- cnn_overall = df_cnn.iloc[-1]
344
-
345
- print(f"\nSVD+LR Overall:")
346
- print(f" Precision: {svd_overall['Precision']:.4f}")
347
- print(f" Recall: {svd_overall['Recall']:.4f}")
348
- print(f" F1-Score: {svd_overall['F1-Score']:.4f}")
349
-
350
- print(f"\nCNN Overall:")
351
- print(f" Precision: {cnn_overall['Precision']:.4f}")
352
- print(f" Recall: {cnn_overall['Recall']:.4f}")
353
- print(f" F1-Score: {cnn_overall['F1-Score']:.4f}")
354
-
355
- print(f"\nCNN Improvement:")
356
- print(f" Precision: +{(cnn_overall['Precision'] - svd_overall['Precision'])*100:.2f} pp")
357
- print(f" Recall: +{(cnn_overall['Recall'] - svd_overall['Recall'])*100:.2f} pp")
358
- print(f" F1-Score: +{(cnn_overall['F1-Score'] - svd_overall['F1-Score'])*100:.2f} pp")
359
-
360
- print("\n" + "="*60)
361
- print("✓ Per-class metrics analysis complete!")
362
- print("="*60)
363
-
364
-
365
- if __name__ == "__main__":
366
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/appendix_learning_curves.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Appendix A – Learning Curves
3
+ Refactored to use centralized viz utilities.
4
+ """
5
+
6
+ import pickle
7
+ import os
8
+ from src import config, viz
9
+
10
+ def main():
11
+ experiments = [
12
+ ('cnn_10class_history.pkl', 'MNIST 10-class CNN Training', 'fig_14_learning_curves.png'),
13
+ ('cnn_fashion_history.pkl', 'Fashion-MNIST CNN Training', 'fig_15_learning_curves_fashion.png')
14
+ ]
15
+
16
+ for f_name, label, out_name in experiments:
17
+ path = os.path.join(config.MODELS_DIR, f_name)
18
+ if os.path.exists(path):
19
+ with open(path, 'rb') as f:
20
+ history = pickle.load(f)
21
+ viz.plot_learning_curves(history, label, out_name)
22
+ else:
23
+ print(f"Skipping {f_name}: Not found at {path}.")
24
+
25
+ if __name__ == "__main__":
26
+ main()
experiments/appendix_per_class_metrics.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Appendix B – Per-Class Performance Metrics (MNIST)
3
+ Refactored to use centralized utility modules.
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from src import utils, viz, exp_utils
9
+
10
+ def main():
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ print("Loading Models and Test Data...")
13
+
14
+ # Load Models (MNIST default)
15
+ svd_pipe, cnn = utils.load_models(dataset_name="mnist")
16
+ if svd_pipe is None or cnn is None:
17
+ return
18
+
19
+ X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False)
20
+ X_test_flat = X_test.view(X_test.size(0), -1).numpy()
21
+ y_test_np = y_test.numpy()
22
+
23
+ # 1. Collect Predictions
24
+ print("Collecting Predictions...")
25
+ y_preds_dict = {}
26
+
27
+ # CNN Predictions
28
+ cnn.eval()
29
+ with torch.no_grad():
30
+ y_preds_dict['CNN'] = cnn(X_test.to(device)).argmax(dim=1).cpu().numpy()
31
+
32
+ # SVD+LR Predictions
33
+ y_preds_dict['SVD+LR'] = svd_pipe.predict(X_test_flat)
34
+
35
+ # 2. Print Metrics Report
36
+ from sklearn.metrics import recall_score, precision_score, f1_score
37
+ for name, y_pred in y_preds_dict.items():
38
+ print(f"\n--- {name} Report (Average Metrics) ---")
39
+ p = precision_score(y_test_np, y_pred, average='macro')
40
+ r = recall_score(y_test_np, y_pred, average='macro')
41
+ f = f1_score(y_test_np, y_pred, average='macro')
42
+ print(f"Macro Average: Precision={p:.3f}, Recall={r:.3f}, F1={f:.3f}")
43
+
44
+ # 3. Visualization: Per-Class F1 Comparison
45
+ viz.plot_per_class_comparison(
46
+ y_test_np,
47
+ y_preds_dict,
48
+ 'fig_19_per_class_metrics_comparison.png'
49
+ )
50
+ print("Appendix B Completed.")
51
+
52
+ if __name__ == "__main__":
53
+ main()
experiments/run_robustness_test.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Robustness Test Script
3
+ Evaluates CNN, SVD, and Hybrid model performance under Gaussian noise.
4
+ Refactored to use centralized src utilities.
5
+ """
6
+
7
+ import argparse
8
+ import torch
9
+ import numpy as np
10
+ from src import config, utils, viz, exp_utils
11
+ from src.hybrid_model import HybridSVDCNN, SVDProjectionLayer
12
+
13
+ def run_experiment(args):
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ print(f"\n--- Running Robustness Test: {args.dataset.upper()} ---")
16
+
17
+ # 1. Load Data and Models
18
+ X_test, y_test = utils.load_data_split(dataset_name=args.dataset, train=False)
19
+ _, cnn = utils.load_models(dataset_name=args.dataset)
20
+
21
+ if cnn is None:
22
+ return
23
+
24
+ # 2. Fit SVD Baseline and Build Hybrid Model
25
+ print("Fitting SVD Baseline...")
26
+ X_test_flat = X_test.view(X_test.size(0), -1).numpy()
27
+ svd_pipe = exp_utils.fit_svd_baseline(X_test_flat, y_test.numpy(), n_components=20)
28
+
29
+ svd = svd_pipe.named_steps['svd']
30
+ scaler = svd_pipe.named_steps['scaler']
31
+ # Hybrid model expects mean from scaler if available
32
+ svd_layer = SVDProjectionLayer(svd.components_, scaler.mean_)
33
+ hybrid = HybridSVDCNN(svd_layer, cnn).to(device)
34
+
35
+ # 3. Define Noise Levels
36
+ sigmas = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
37
+ results = {'CNN': [], 'SVD': [], 'Hybrid': []}
38
+
39
+ # 4. Evaluation Loop
40
+ for sigma in sigmas:
41
+ X_noisy = exp_utils.add_gaussian_noise(X_test, sigma)
42
+
43
+ results['CNN'].append(exp_utils.evaluate_classifier(cnn, X_noisy, y_test, device))
44
+ results['SVD'].append(exp_utils.evaluate_classifier(svd_pipe, X_noisy, y_test, is_pytorch=False))
45
+ results['Hybrid'].append(exp_utils.evaluate_classifier(hybrid, X_noisy, y_test, device))
46
+
47
+ print(f"σ={sigma:.1f} | CNN: {results['CNN'][-1]:.4f} | SVD: {results['SVD'][-1]:.4f} | Hybrid: {results['Hybrid'][-1]:.4f}")
48
+
49
+ # 5. Visualization
50
+ viz.plot_robustness_curves(
51
+ x_values=sigmas,
52
+ results_dict=results,
53
+ x_label='Gaussian Noise Level (σ)',
54
+ title=f'Robustness Analysis: {args.dataset.upper()}',
55
+ filename=f'fig_robustness_{args.dataset}.png'
56
+ )
57
+
58
+ def main():
59
+ parser = argparse.ArgumentParser(description="Unified Robustness Evaluation")
60
+ parser.add_argument("--dataset", choices=["mnist", "fashion"], default="mnist", help="Dataset to evaluate.")
61
+ args = parser.parse_args()
62
+ run_experiment(args)
63
+
64
+ if __name__ == "__main__":
65
+ main()
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # --- Paths ---
4
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
5
+ DATA_DIR = os.path.join(BASE_DIR, "data")
6
+ MODELS_DIR = os.path.join(BASE_DIR, "models")
7
+ RESULTS_DIR = os.path.join(BASE_DIR, "docs", "research_results")
8
+
9
+ for d in [DATA_DIR, MODELS_DIR, RESULTS_DIR]:
10
+ os.makedirs(d, exist_ok=True)
11
+
12
+ SVD_MODEL_PATH = os.path.join(MODELS_DIR, "svd_10class.pkl")
13
+ CNN_MODEL_PATH = os.path.join(MODELS_DIR, "cnn_10class.pth")
14
+ FASHION_SVD_PATH = os.path.join(MODELS_DIR, "svd_fashion.pkl") # Placeholder if not exists
15
+ FASHION_CNN_PATH = os.path.join(MODELS_DIR, "cnn_fashion.pth")
16
+ APP_CACHE_PATH = os.path.join(DATA_DIR, "app_cache.npz")
17
+
src/exp_utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torchvision.transforms as transforms
4
+ from sklearn.metrics import accuracy_score
5
+ from sklearn.decomposition import TruncatedSVD
6
+ from sklearn.linear_model import LogisticRegression
7
+ from sklearn.pipeline import Pipeline
8
+ from sklearn.preprocessing import StandardScaler
9
+
10
+ def fit_svd_baseline(X_train, y_train, n_components=20):
11
+ """Fits a linear baseline (SVD + Logistic Regression) on the fly."""
12
+ pipeline = Pipeline([
13
+ ('scaler', StandardScaler()),
14
+ ('svd', TruncatedSVD(n_components=n_components, random_state=42)),
15
+ ('logistic', LogisticRegression(max_iter=1000))
16
+ ])
17
+ pipeline.fit(X_train, y_train)
18
+ return pipeline
19
+
20
+ def add_gaussian_noise(X, sigma):
21
+ """
22
+ Uniform noise addition for both torch Tensors and numpy arrays.
23
+ Returns the same type as input.
24
+ """
25
+ if sigma <= 0: return X
26
+ if torch.is_tensor(X):
27
+ noise = torch.randn_like(X) * sigma
28
+ return torch.clamp(X + noise, 0, 1)
29
+ else:
30
+ noise = np.random.randn(*X.shape) * sigma
31
+ return np.clip(X + noise, 0, 1)
32
+
33
+ def add_blur(X, kernel_size):
34
+ """Unified blur for torch Tensors (4D: B, C, H, W)."""
35
+ if kernel_size <= 1:
36
+ return X
37
+ sigma = 0.1 + 0.3 * (kernel_size // 2)
38
+ blur_fn = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma))
39
+ return blur_fn(X)
40
+
41
+ def evaluate_classifier(model, X, y, device="cpu", is_pytorch=True):
42
+ """
43
+ Unified evaluation function.
44
+ Handles PyTorch models (CNN, Hybrid) and Sklearn pipelines (SVD+LR).
45
+ """
46
+ if is_pytorch:
47
+ model.eval()
48
+ model.to(device)
49
+ # Ensure X is 4D for CNN (B, 1, 28, 28)
50
+ if len(X.shape) == 2:
51
+ X_t = torch.as_tensor(X.reshape(-1, 1, 28, 28), dtype=torch.float32).to(device)
52
+ else:
53
+ X_t = torch.as_tensor(X, dtype=torch.float32).to(device)
54
+
55
+ y_t = torch.as_tensor(y, dtype=torch.long).to(device)
56
+
57
+ with torch.no_grad():
58
+ logits = model(X_t)
59
+ preds = torch.argmax(logits, dim=1).cpu().numpy()
60
+ return accuracy_score(y, preds)
61
+ else:
62
+ # Sklearn pipeline - Ensure X is flattened 2D numpy
63
+ if torch.is_tensor(X):
64
+ X_np = X.view(X.size(0), -1).cpu().numpy()
65
+ else:
66
+ X_np = X.reshape(X.shape[0], -1)
67
+ preds = model.predict(X_np)
68
+ return accuracy_score(y, preds)
src/hybrid_model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hybrid SVD-CNN Model
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from sklearn.decomposition import TruncatedSVD
7
+
8
+ class SimpleCNN(nn.Module):
9
+ def __init__(self, num_classes=10):
10
+ super().__init__()
11
+ self.features = nn.Sequential(
12
+ nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
13
+ nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
14
+ )
15
+ self.classifier = nn.Sequential(
16
+ nn.Linear(32 * 7 * 7, 128), nn.ReLU(),
17
+ nn.Linear(128, num_classes)
18
+ )
19
+
20
+ def forward(self, x):
21
+ return self.classifier(self.features(x).view(x.size(0), -1))
22
+
23
+ class SVDProjectionLayer(nn.Module):
24
+ def __init__(self, V_k, mean=None):
25
+ super().__init__()
26
+ self.register_buffer('V_k', torch.tensor(V_k, dtype=torch.float32))
27
+ self.register_buffer('mean', torch.tensor(mean, dtype=torch.float32) if mean is not None else torch.zeros(V_k.shape[1]))
28
+
29
+ def forward(self, x):
30
+ b = x.size(0)
31
+ x_rec = (x.view(b, -1) - self.mean) @ self.V_k.T @ self.V_k + self.mean
32
+ return torch.clamp(x_rec, 0, 1).view(b, 1, 28, 28)
33
+
34
+ class HybridSVDCNN(nn.Module):
35
+ def __init__(self, svd_layer, cnn):
36
+ super().__init__()
37
+ self.svd_layer, self.cnn = svd_layer, cnn
38
+
39
+ def forward(self, x):
40
+ return self.cnn(self.svd_layer(x))
41
+
42
+ def create_svd_layer(X_train, n_components=20):
43
+ mean = np.mean(X_train, axis=0)
44
+ svd = TruncatedSVD(n_components=n_components, random_state=42).fit(X_train - mean)
45
+ return SVDProjectionLayer(svd.components_, mean)
src/train_fashion.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision import transforms
4
+ from src.hybrid_model import SimpleCNN
5
+ from src import config
6
+ from src.train_models import train_cnn, set_seed
7
+
8
+ if __name__ == "__main__":
9
+ set_seed()
10
+ print("Loading Fashion-MNIST for training...")
11
+ transform = transforms.Compose([transforms.ToTensor()])
12
+ train_dataset = torchvision.datasets.FashionMNIST(root=config.DATA_DIR, train=True, download=True, transform=transform)
13
+
14
+ # Extract data to tensors for train_cnn
15
+ X = train_dataset.data.float() / 255.0
16
+ y = train_dataset.targets
17
+
18
+ # Temporarily override CNN_MODEL_PATH for fashion
19
+ original_path = config.CNN_MODEL_PATH
20
+ config.CNN_MODEL_PATH = config.CNN_FASHION_MODEL_PATH
21
+
22
+ print(f"Retraining Fashion-MNIST model to {config.CNN_FASHION_MODEL_PATH}...")
23
+ train_cnn(X, y)
24
+
25
+ # Restore (optional but good practice)
26
+ config.CNN_MODEL_PATH = original_path
27
+ print("Fashion-MNIST training completed.")
src/train_models.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import TensorDataset, DataLoader
5
+ from sklearn.decomposition import TruncatedSVD
6
+ from sklearn.model_selection import train_test_split
7
+ import pickle
8
+ import os
9
+ import numpy as np
10
+ import random
11
+ from src.hybrid_model import SimpleCNN
12
+ from src.utils import load_data_split
13
+ from src import config
14
+
15
+ def set_seed(seed=42):
16
+ random.seed(seed); np.random.seed(seed)
17
+ torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
18
+ torch.backends.cudnn.deterministic = True
19
+
20
+ def train_svd(X_flat, n_components=20):
21
+ print(f"Training SVD (k={n_components})...")
22
+ X_np = X_flat.numpy()
23
+ mean = X_np.mean(axis=0)
24
+ svd = TruncatedSVD(n_components=n_components, random_state=42).fit(X_np - mean)
25
+ svd._train_mean = mean
26
+ with open(config.SVD_MODEL_PATH, "wb") as f: pickle.dump(svd, f)
27
+ return svd
28
+
29
+ def train_cnn(X_flat, y, batch_size=64, epochs=5):
30
+ X_train, X_val, y_train, y_val = train_test_split(X_flat.numpy(), y.numpy(), test_size=0.2, random_state=42, stratify=y.numpy())
31
+
32
+ def to_loader(X, y, shuffle=True):
33
+ return DataLoader(TensorDataset(torch.tensor(X).view(-1, 1, 28, 28), torch.tensor(y, dtype=torch.long)), batch_size=batch_size, shuffle=shuffle)
34
+
35
+ train_loader, val_loader = to_loader(X_train, y_train), to_loader(X_val, y_val, False)
36
+ model = SimpleCNN().to("cuda" if torch.cuda.is_available() else "cpu")
37
+ opt = optim.Adam(model.parameters(), lr=0.001)
38
+ crit = nn.CrossEntropyLoss()
39
+
40
+ history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
41
+ best_acc, best_state = 0, None
42
+
43
+ for epoch in range(epochs):
44
+ model.train()
45
+ t_loss, t_corr = 0, 0
46
+ for x, labels in train_loader:
47
+ x, labels = x.to(next(model.parameters()).device), labels.to(next(model.parameters()).device)
48
+ opt.zero_grad(); out = model(x); loss = crit(out, labels); loss.backward(); opt.step()
49
+ t_loss += loss.item(); t_corr += (out.argmax(1) == labels).sum().item()
50
+
51
+ model.eval(); v_loss, v_corr = 0, 0
52
+ with torch.no_grad():
53
+ for x, labels in val_loader:
54
+ x, labels = x.to(next(model.parameters()).device), labels.to(next(model.parameters()).device)
55
+ out = model(x); v_loss += crit(out, labels).item(); v_corr += (out.argmax(1) == labels).sum().item()
56
+
57
+ history['train_acc'].append(100 * t_corr / len(X_train)); history['val_acc'].append(100 * v_corr / len(X_val))
58
+ print(f"Epoch {epoch+1}: Train Acc {history['train_acc'][-1]:.2f}%, Val Acc {history['val_acc'][-1]:.2f}%")
59
+
60
+ if history['val_acc'][-1] > best_acc:
61
+ best_acc, best_state = history['val_acc'][-1], model.state_dict().copy()
62
+
63
+ model.load_state_dict(best_state)
64
+ torch.save(model.cpu().state_dict(), config.CNN_MODEL_PATH)
65
+ with open(config.CNN_MODEL_PATH.replace('.pth', '_history.pkl'), 'wb') as f: pickle.dump(history, f)
66
+ return model, history
67
+
68
+ if __name__ == "__main__":
69
+ set_seed()
70
+ X, y = load_data_split()
71
+ train_svd(X.view(-1, 784))
72
+ train_cnn(X.view(-1, 784), y)
src/utils.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ import torchvision
4
+ import numpy as np
5
+ import os
6
+ import pickle
7
+ import ssl
8
+ from src.hybrid_model import SimpleCNN
9
+ from src import config
10
+ import cv2
11
+
12
+ def load_data_split(dataset_name="mnist", train=True, digits=None, flatten=False):
13
+ """
14
+ Unified entry point for data loading.
15
+ Supports: MNIST, Fashion-MNIST, and custom digit filtering (e.g., [3, 8]).
16
+ """
17
+ # Bypass SSL verification issues for dataset downloads
18
+ ssl._create_default_https_context = ssl._create_unverified_context
19
+
20
+ transform = T.Compose([T.ToTensor()])
21
+
22
+ if dataset_name.lower() == "mnist":
23
+ dataset = torchvision.datasets.MNIST(config.DATA_DIR, train=train, download=True, transform=transform)
24
+ elif dataset_name.lower() == "fashion":
25
+ dataset = torchvision.datasets.FashionMNIST(config.DATA_DIR, train=train, download=True, transform=transform)
26
+ else:
27
+ raise ValueError(f"Unknown dataset: {dataset_name}")
28
+
29
+ X = dataset.data.float() / 255.0
30
+ y = dataset.targets
31
+
32
+ # Filter for specific digits if requested (e.g., [3, 8] for binary analysis)
33
+ if digits is not None:
34
+ mask = torch.zeros(len(y), dtype=torch.bool)
35
+ for d in digits:
36
+ mask |= (y == d)
37
+ X = X[mask]
38
+ y = y[mask]
39
+
40
+ # Remap labels to 0, 1... for binary tasks
41
+ if len(digits) == 2:
42
+ y = torch.where(y == digits[0], torch.tensor(0), torch.tensor(1))
43
+
44
+ # Add channel dimension if not flattened (B, 1, 28, 28)
45
+ if not flatten:
46
+ X = X.unsqueeze(1)
47
+ else:
48
+ X = X.view(X.size(0), -1)
49
+
50
+ return X, y
51
+
52
+ def load_models(dataset_name="mnist"):
53
+ """
54
+ Loads pre-trained SVD transformer and CNN model for a specific dataset.
55
+ Returns (svd, cnn). Either can be None if the file is missing.
56
+ """
57
+ svd_path = config.SVD_MODEL_PATH if dataset_name == "mnist" else config.FASHION_SVD_PATH
58
+ cnn_path = config.CNN_MODEL_PATH if dataset_name == "mnist" else config.FASHION_CNN_PATH
59
+
60
+ svd, cnn = None, None
61
+
62
+ if os.path.exists(svd_path):
63
+ with open(svd_path, "rb") as f:
64
+ svd = pickle.load(f)
65
+ else:
66
+ print(f"Note: SVD model for {dataset_name} not found at {svd_path}")
67
+
68
+ if os.path.exists(cnn_path):
69
+ cnn = SimpleCNN()
70
+ cnn.load_state_dict(torch.load(cnn_path, map_location="cpu"))
71
+ cnn.eval()
72
+ else:
73
+ print(f"Note: CNN model for {dataset_name} not found at {cnn_path}")
74
+
75
+ return svd, cnn
76
+
77
+ # --- Backward Compatibility Aliases ---
78
+ load_data = load_data_split
79
+
80
+ def preprocess_digit(img):
81
+ """
82
+ Original preprocessing logic used by the Streamlit app.
83
+ Crops, resizes (20x20), and pads to 28x28.
84
+ """
85
+ if isinstance(img, torch.Tensor):
86
+ img = img.numpy().astype(np.uint8)
87
+
88
+ # 1. Threshold & Find Bounding Box
89
+ _, thresh = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
90
+ coords = cv2.findNonZero(thresh)
91
+ if coords is None:
92
+ return torch.zeros((28, 28))
93
+ x, y, w, h = cv2.boundingRect(coords)
94
+ img_crop = img[y:y+h, x:x+w]
95
+
96
+ # 2. Resize to fit 20px
97
+ if w > h:
98
+ new_w = 20
99
+ new_h = int(h * (20 / w))
100
+ else:
101
+ new_h = 20
102
+ new_w = int(w * (20 / h))
103
+
104
+ if new_w == 0 or new_h == 0:
105
+ return torch.zeros((28, 28))
106
+
107
+ img_resize = cv2.resize(img_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)
108
+
109
+ # 3. Center in 28x28
110
+ final_img = np.zeros((28, 28), dtype=np.uint8)
111
+ pad_y = (28 - new_h) // 2
112
+ pad_x = (28 - new_w) // 2
113
+ final_img[pad_y:pad_y+new_h, pad_x:pad_x+new_w] = img_resize
114
+
115
+ # 4. Normalize
116
+ return torch.tensor(final_img).float() / 255.0
src/viz.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import seaborn as sns
3
+ import numpy as np
4
+ import os
5
+ from matplotlib.colors import LinearSegmentedColormap
6
+ from src import config
7
+
8
+ # --- Nord Palette Colors ---
9
+ COLOR_SVD = "#5E81AC" # Nord 10 (Blue)
10
+ COLOR_CNN = "#BF616A" # Nord 11 (Red)
11
+ COLOR_HYBRID = "#A3BE8C" # Nord 14 (Green)
12
+ COLOR_TEXT = "#2E3440" # Nord 0 (Dark)
13
+ COLOR_GRID = "#D8DEE9" # Nord 4
14
+
15
+ def setup_style():
16
+ """Standardize matplotlib plots."""
17
+ plt.rcParams['font.family'] = 'sans-serif'
18
+ plt.rcParams['axes.edgecolor'] = COLOR_GRID
19
+ plt.rcParams['grid.alpha'] = 0.3
20
+ plt.rcParams['axes.labelcolor'] = COLOR_TEXT
21
+
22
+ def save_fig(filename, dpi=300):
23
+ """Save plot to results directory."""
24
+ path = os.path.join(config.RESULTS_DIR, filename)
25
+ plt.tight_layout()
26
+ plt.savefig(path, dpi=dpi)
27
+ plt.close()
28
+ print(f"Figure saved to {path}")
29
+
30
+ def plot_robustness_curves(x_values, results_dict, x_label, title, filename):
31
+ """Standardized robustness curve plotter."""
32
+ setup_style()
33
+ plt.figure(figsize=(10, 6))
34
+ colors = {'CNN': COLOR_CNN, 'SVD': COLOR_SVD, 'Hybrid': COLOR_HYBRID}
35
+
36
+ for label, accs in results_dict.items():
37
+ plt.plot(x_values, accs, label=label, marker='o',
38
+ color=colors.get(label, '#4C566A'), linewidth=2)
39
+
40
+ plt.title(title, fontsize=14, fontweight='bold', pad=15)
41
+ plt.xlabel(x_label, fontsize=12)
42
+ plt.ylabel('Accuracy', fontsize=12)
43
+ plt.legend(frameon=True, facecolor='white', framealpha=0.8)
44
+ plt.grid(True)
45
+ save_fig(filename)
46
+
47
+ def plot_confusion_matrix(y_true, y_pred, labels, filename, title, color_end=COLOR_SVD):
48
+ """Normalized confusion matrix with Nord-consistent coloring."""
49
+ from sklearn.metrics import confusion_matrix
50
+ setup_style()
51
+ cm = confusion_matrix(y_true, y_pred, normalize='true')
52
+ plt.figure(figsize=(10, 8))
53
+
54
+ # Custom cmap from Light Gray to Nord Color
55
+ cmap = LinearSegmentedColormap.from_list("NordCustom", ["#ECEFF4", color_end])
56
+
57
+ sns.heatmap(cm, annot=True, fmt='.1%', cmap=cmap, xticklabels=labels, yticklabels=labels)
58
+ plt.title(title, fontsize=14, fontweight='bold', pad=15)
59
+ plt.xlabel('Predicted', fontsize=12)
60
+ plt.ylabel('True', fontsize=12)
61
+ save_fig(filename)
62
+
63
+ def plot_singular_spectrum(singular_values, cumulative_variance, filename):
64
+ """Visualizes singular values and explained variance."""
65
+ setup_style()
66
+ fig, ax1 = plt.subplots(figsize=(10, 6))
67
+
68
+ n = len(singular_values)
69
+ ax1.semilogy(range(1, n+1), singular_values, color=COLOR_SVD, label='Singular Values', linewidth=2)
70
+ ax1.set_xlabel('Principal Component (k)', fontsize=12)
71
+ ax1.set_ylabel('Singular Value (Log)', color=COLOR_SVD, fontsize=12)
72
+ ax1.tick_params(axis='y', labelcolor=COLOR_SVD)
73
+
74
+ ax2 = ax1.twinx()
75
+ ax2.plot(range(1, n+1), cumulative_variance, color=COLOR_CNN, linestyle='--', label='Cum. Var', linewidth=2)
76
+ ax2.set_ylabel('Cumulative Explained Variance', color=COLOR_CNN, fontsize=12)
77
+ ax2.tick_params(axis='y', labelcolor=COLOR_CNN)
78
+ ax2.set_ylim(0, 1.05)
79
+
80
+ plt.title('Singular Value Spectrum & Explained Variance', fontsize=14, fontweight='bold', pad=15)
81
+ fig.legend(loc="upper right", bbox_to_anchor=(1,1), bbox_transform=ax1.transAxes)
82
+ save_fig(filename)
83
+
84
+ def plot_interpolation_dynamics(alphas, probs_8, rec_errors, filename):
85
+ """Visualizes the CNN response vs SVD reconstruction error during interpolation."""
86
+ setup_style()
87
+ plt.figure(figsize=(10, 6))
88
+
89
+ plt.plot(alphas, probs_8, color=COLOR_CNN, label='CNN Prob(8) [Topology]', marker='o', linewidth=2)
90
+ plt.plot(alphas, rec_errors, color=COLOR_SVD, label='SVD Rec Error [Global Variance]', marker='s', linewidth=2)
91
+
92
+ plt.axvline(x=0.5, color='#4C566A', linestyle='--', alpha=0.5, label='Ambiguity Mid-point')
93
+ plt.title('Mechanistic Dynamics: Interpolation vs. SVD Error', fontsize=14, fontweight='bold', pad=15)
94
+ plt.xlabel('Alpha (0=Digit 3, 1=Digit 8)', fontsize=12)
95
+ plt.ylabel('Metric Value', fontsize=12)
96
+ plt.legend()
97
+ plt.grid(True)
98
+ save_fig(filename)
99
+
100
+ def plot_manifold_comparison(X_svd, X_umap, y, acc_svd, acc_raw, filename):
101
+ """Side-by-side comparison of SVD (linear) vs UMAP (non-linear) projections."""
102
+ setup_style()
103
+ fig, axes = plt.subplots(1, 2, figsize=(15, 6))
104
+
105
+ colors = [COLOR_SVD, COLOR_CNN] # 3 vs 8
106
+ labels = ['Digit 3', 'Digit 8']
107
+
108
+ for i in range(2):
109
+ # SVD Plane
110
+ axes[0].scatter(X_svd[y==i, 0], X_svd[y==i, 1], label=labels[i], alpha=0.5, s=15, color=colors[i])
111
+ # UMAP Manifold
112
+ axes[1].scatter(X_umap[y==i, 0], X_umap[y==i, 1], label=labels[i], alpha=0.5, s=15, color=colors[i])
113
+
114
+ axes[0].set_title(f"SVD Projection (2D Subspace)\nk-NN Accuracy: {acc_svd:.2%}", fontsize=12)
115
+ axes[1].set_title(f"UMAP Manifold (Non-linear)\nRaw k-NN Accuracy: {acc_raw:.2%}", fontsize=12)
116
+
117
+ for ax in axes:
118
+ ax.legend()
119
+ ax.set_xticks([])
120
+ ax.set_yticks([])
121
+
122
+ plt.suptitle("Manifold Collapse: Linear SVD Overlap vs. Non-linear Topological Separation",
123
+ fontsize=14, fontweight='bold', y=1.02)
124
+ save_fig(filename)
125
+
126
+ def plot_learning_curves(history, title, filename):
127
+ """Standardized plotter for training history (loss and accuracy)."""
128
+ setup_style()
129
+ epochs = range(1, len(history['train_loss']) + 1)
130
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
131
+
132
+ # Nord palette for curves
133
+ COLOR_TRAIN = COLOR_SVD
134
+ COLOR_VAL = "#D08770" # Nord 12 (Orange)
135
+
136
+ # Loss Plot
137
+ ax1.plot(epochs, history['train_loss'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
138
+ ax1.plot(epochs, history['val_loss'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
139
+ ax1.set_title('Loss Dynamics', fontsize=12, fontweight='bold')
140
+ ax1.set_xlabel('Epoch')
141
+ ax1.set_ylabel('Loss')
142
+ ax1.legend()
143
+ ax1.grid(True)
144
+
145
+ # Accuracy Plot
146
+ ax2.plot(epochs, history['train_acc'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
147
+ ax2.plot(epochs, history['val_acc'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
148
+ ax2.set_title('Accuracy Dynamics', fontsize=12, fontweight='bold')
149
+ ax2.set_xlabel('Epoch')
150
+ ax2.set_ylabel('Accuracy')
151
+ ax2.legend()
152
+ ax2.grid(True)
153
+
154
+ plt.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
155
+ save_fig(filename)
156
+
157
+ def plot_per_class_comparison(y_test, y_preds_dict, filename):
158
+ """Grouped bar chart comparing F1-scores per class for multiple models."""
159
+ from sklearn.metrics import f1_score
160
+ setup_style()
161
+ plt.figure(figsize=(10, 6))
162
+
163
+ x = np.arange(10)
164
+ width = 0.8 / len(y_preds_dict)
165
+
166
+ colors = {
167
+ 'SVD+LR': COLOR_SVD,
168
+ 'CNN': COLOR_CNN,
169
+ 'Hybrid': COLOR_HYBRID
170
+ }
171
+
172
+ for i, (label, y_pred) in enumerate(y_preds_dict.items()):
173
+ f1s = f1_score(y_test, y_pred, average=None)
174
+ plt.bar(x + (i - len(y_preds_dict)/2 + 0.5) * width, f1s, width,
175
+ label=label, color=colors.get(label, '#4C566A'), alpha=0.8)
176
+
177
+ plt.xticks(x)
178
+ plt.xlabel('Digit Class', fontsize=12)
179
+ plt.ylabel('F1-Score', fontsize=12)
180
+ plt.title('Per-Class Performance Comparison (F1-Score)', fontsize=14, fontweight='bold', pad=15)
181
+ plt.legend()
182
+ plt.grid(True, axis='y')
183
+ save_fig(filename)
184
+
185
+ def plot_multi_image_grid(images, titles, rows, cols, filename, suptitle=None):
186
+ """Generic grid plotter for images (e.g., eigen-digits)."""
187
+ plt.figure(figsize=(cols * 2.5, rows * 2.5))
188
+ for i, (img, title) in enumerate(zip(images, titles)):
189
+ plt.subplot(rows, cols, i + 1)
190
+ plt.imshow(img, cmap='gray')
191
+ plt.title(title, fontsize=10)
192
+ plt.axis('off')
193
+ if suptitle:
194
+ plt.suptitle(suptitle, fontsize=14, fontweight='bold')
195
+ save_fig(filename)