Spaces:
Sleeping
Sleeping
feat: initial implementation of MNIST Hybrid SVD-CNN core
Browse files- README.md +67 -124
- experiments/01_phenomenon_diagnosis.py +91 -0
- experiments/01_phenomenon_discovery.py +0 -226
- experiments/02_mechanistic_analysis.py +102 -0
- experiments/02_mnist_cnn_confusion.py +0 -68
- experiments/03_mechanistic_investigation.py +0 -244
- experiments/04_robustness_limit.py +0 -187
- experiments/05_manifold_learning.py +0 -103
- experiments/06_fashion_mnist_baseline.py +0 -115
- experiments/07_fashion_cnn_verification.py +0 -145
- experiments/08_hybrid_robustness.py +0 -253
- experiments/09_fashion_hybrid_robustness.py +0 -189
- experiments/10_ablation_study.py +0 -344
- experiments/11_learning_curves.py +0 -228
- experiments/12_roc_analysis.py +0 -291
- experiments/13_per_class_metrics.py +0 -366
- experiments/appendix_learning_curves.py +26 -0
- experiments/appendix_per_class_metrics.py +53 -0
- experiments/run_robustness_test.py +65 -0
- src/__init__.py +0 -0
- src/config.py +17 -0
- src/exp_utils.py +68 -0
- src/hybrid_model.py +45 -0
- src/train_fashion.py +27 -0
- src/train_models.py +72 -0
- src/utils.py +116 -0
- src/viz.py +195 -0
README.md
CHANGED
|
@@ -10,151 +10,94 @@ app_file: app.py
|
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
|
| 13 |
-
#
|
| 14 |
|
| 15 |
-
[](https://huggingface.co/spaces/ymlin105/Coconut-MNIST)
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
- **Key Takeaways** — SVD behaves like a low-pass filter; CNN attention is local/non-linear; Hybrid helps under high Gaussian noise on MNIST but fails on texture-heavy Fashion-MNIST.
|
| 23 |
-
- **Full technical report →** [REPORT.md](./docs/REPORT.md)
|
| 24 |
|
| 25 |
-
|
| 26 |
|
| 27 |
-
|
| 28 |
|
| 29 |
-
|
| 30 |
-
*Figure: The Hybrid SVD→CNN pipeline. SVD reconstruction acts as a data-adapted low-pass filter before CNN feature extraction.*
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
conda activate hybrid-svd
|
| 46 |
-
|
| 47 |
-
# 2. Install dependencies
|
| 48 |
-
pip install -r requirements.txt
|
| 49 |
-
|
| 50 |
-
# 3. Train SVD + CNN models (~1 min, now with validation & early stopping!)
|
| 51 |
-
python src/train_models.py
|
| 52 |
-
|
| 53 |
-
# 4. Launch interactive dashboard
|
| 54 |
-
streamlit run app.py
|
| 55 |
-
|
| 56 |
-
# Optional: Run additional analysis
|
| 57 |
-
python experiments/10_ablation_study.py # Depth vs Non-linearity analysis
|
| 58 |
-
python experiments/11_learning_curves.py # Training dynamics visualization
|
| 59 |
-
python experiments/12_roc_analysis.py # ROC curves for 3 vs 8
|
| 60 |
-
python experiments/13_per_class_metrics.py # Detailed per-class metrics
|
| 61 |
-
```
|
| 62 |
-
|
| 63 |
-
> **What's New in v2.0**: Enhanced training with validation set splitting (80/20), early stopping (patience=3), reproducible random seeds (seed=42), and 4 new experiments. Key results: **Non-linearity alone provides +4.91 pp gain** (ablation study), **CNN achieves AUC=1.0 on 3vs8** (ROC analysis), **CNN reduces worst confusion from 6.5% to 2.2%** (per-class metrics). See [`docs/IMPROVEMENTS.md`](docs/IMPROVEMENTS.md) and [`docs/QUICKSTART.md`](docs/QUICKSTART.md) for details.
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
```
|
| 66 |
-
Project Structure
|
| 67 |
-
├── src/ Core modules: CNN, SVD layer, hybrid pipeline, training
|
| 68 |
-
├── experiments/ 13 self-contained scripts (01–13), ordered by narrative
|
| 69 |
-
├── docs/ Full report (REPORT.md) + 20 figures and JSON metrics
|
| 70 |
-
├── models/ Pretrained checkpoints (CNN for MNIST & Fashion-MNIST)
|
| 71 |
-
└── app.py Streamlit dashboard (live demo)
|
| 72 |
-
```
|
| 73 |
-
|
| 74 |
-
</details>
|
| 75 |
-
|
| 76 |
-
## Approach
|
| 77 |
-
|
| 78 |
-
```
|
| 79 |
-
Diagnosis Mechanism Solution & Boundary
|
| 80 |
-
───────────────────── ───────────────────── ─────────────────────
|
| 81 |
-
SVD fails on 3 vs 8 → Why? Grad-CAM + UMAP → Hybrid SVD→CNN pipeline
|
| 82 |
-
(Exp 1–3) (Exp 4–7) + Fashion-MNIST stress test
|
| 83 |
-
(Exp 8–9)
|
| 84 |
-
```
|
| 85 |
-
|
| 86 |
-
The Hybrid pipeline passes each input through a fixed SVD reconstruction (rank $k{=}20$) before the CNN classifier. SVD acts as a data-adapted low-pass filter — suppressing high-frequency noise while retaining structure aligned with the training manifold.
|
| 87 |
-
|
| 88 |
-
## Case Study 1: Success on Low-Rank Manifolds (MNIST)
|
| 89 |
|
| 90 |
-
|
|
|
|
| 91 |
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
|
| 95 |
-
| ---------------- | ------ | ------ | ------ | ---------------- |
|
| 96 |
-
| CNN | 98.74% | 96.36% | 80.44% | 54.34% |
|
| 97 |
-
| SVD | 88.12% | 80.87% | 64.60% | 51.30% |
|
| 98 |
-
| Blur+CNN | 94.25% | 83.78% | 63.38% | 44.54% |
|
| 99 |
-
| **Hybrid** | 91.82% | 88.57% | 79.26% | **59.10%** |
|
| 100 |
-
|
| 101 |
-
*Results are mean over 3 random seeds. Bold indicates where Hybrid surpasses the standalone CNN.*
|
| 102 |
-
|
| 103 |
-
**Result**: The Hybrid improves over the standalone CNN at high noise ($\sigma=0.7$: 59.10% vs 54.34%), consistent with SVD reconstruction suppressing high-frequency noise before CNN feature extraction. The crossover point where Hybrid surpasses the CNN occurs between $\sigma=0.5$ and $\sigma=0.7$. The Hybrid also outperforms the Gaussian blur baseline (59.10% vs 44.54%), confirming that SVD provides data-adapted denoising beyond generic smoothing.
|
| 104 |
-
|
| 105 |
-

|
| 106 |
-
*Figure: Accuracy vs. noise level (σ). The Hybrid (orange) crosses above the standalone CNN (blue) at high noise, confirming the denoising benefit on low-rank data.*
|
| 107 |
-
|
| 108 |
-
## Case Study 2: Failure on Texture-Rich Manifolds (Fashion-MNIST)
|
| 109 |
-
|
| 110 |
-
> *See [REPORT.md](./docs/REPORT.md#experiment-9-boundary-analysis-fashion-mnist) for the full boundary analysis.*
|
| 111 |
-
> On texture-dependent data, SVD filtering destroys high-frequency details (e.g., collar vs. no collar), causing the Hybrid model to collapse.
|
| 112 |
-
|
| 113 |
-
| Model | MNIST (Clean) | Fashion-MNIST (Clean) |
|
| 114 |
-
| ---------------- | ------------- | --------------------- |
|
| 115 |
-
| **CNN** | 98.74% | **91.04%** |
|
| 116 |
-
| **Hybrid** | 91.82% | **67.27%** |
|
| 117 |
-
|
| 118 |
-
**Boundary Identified**: This method shows a **~24.6-point** clean-accuracy drop on Fashion-MNIST (texture-dependent), confirming it is primarily suitable for low-rank, shape-defined manifolds.
|
| 119 |
-
|
| 120 |
-
<details>
|
| 121 |
-
<summary><strong>Geometric Mechanics (why SVD fails and when it helps)</strong></summary>
|
| 122 |
-
|
| 123 |
-
- **Feature energy paradox**: discriminative cues can be "low-energy" (gaps, textures) and get wiped out by low-rank projection.
|
| 124 |
-
- **Manifold alignment check**: UMAP shows 3/8 are separable when local neighborhood structure is preserved.
|
| 125 |
-
- **Subspace denoising**: an SVD reconstruction step can act as a data-adapted low-pass filter before the CNN.
|
| 126 |
-
|
| 127 |
-
</details>
|
| 128 |
|
| 129 |
-
|
| 130 |
-
<summary><strong>Evidence & Reproducibility</strong></summary>
|
| 131 |
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
| 01 | `phenomenon_discovery.py` | SVD failure analysis + spectrum |
|
| 140 |
-
| 02 | `mnist_cnn_confusion.py` | MNIST CNN confusion matrix |
|
| 141 |
-
| 03 | `mechanistic_investigation.py` | Interpolation + Grad-CAM vs reconstruction |
|
| 142 |
-
| 04 | `robustness_limit.py` | SVD vs CNN degradation curves |
|
| 143 |
-
| 05 | `manifold_learning.py` | SVD vs UMAP manifold comparison |
|
| 144 |
-
| 06 | `fashion_mnist_baseline.py` | Fashion-MNIST SVD baseline |
|
| 145 |
-
| 07 | `fashion_cnn_verification.py` | Fashion-MNIST CNN confusion |
|
| 146 |
-
| 08 | `hybrid_robustness.py` | MNIST robustness +`robustness_mnist_noise.json` |
|
| 147 |
-
| 09 | `fashion_hybrid_robustness.py` | Fashion robustness +`robustness_fashion_noise.json` |
|
| 148 |
-
| **10** | **`ablation_study.py`** | **Depth vs Non-linearity contribution analysis** |
|
| 149 |
-
| **11** | **`learning_curves.py`** | **Training/validation dynamics visualization** |
|
| 150 |
-
| **12** | **`roc_analysis.py`** | **ROC curves for 3 vs 8 classification** |
|
| 151 |
-
| **13** | **`per_class_metrics.py`** | **Precision/Recall/F1 per digit + CSV reports** |
|
| 152 |
|
| 153 |
-
#
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
---
|
|
|
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# SVD vs CNN: Mechanistic Analysis of Manifold Alignment on MNIST
|
| 14 |
|
| 15 |
+
[](https://huggingface.co/spaces/ymlin105/Coconut-MNIST) [](./docs/REPORT.md)
|
| 16 |
|
| 17 |
+
While it is a known theoretical property that linear dimensionality reduction (SVD) acts as a low-pass filter, this project provides a **concrete, visual, and quantitative mechanistic explanation** of how this property manifests in neural network classification—specifically, why linear subspaces consistently force a "3" to collapse into an "8".
|
| 18 |
|
| 19 |
+
<p align="center">
|
| 20 |
+
<img src="./docs/research_results/fig_06_explainability.png" width="600" alt="Mechanistic Analysis of SVD Inductive Bias">
|
| 21 |
+
</p>
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
By mapping the exact decision boundaries where linear global variance models fail and non-linear topological models (CNNs) succeed, I empirically validate the **inherent trade-offs** of linear denoising in high-stakes domains like medical imaging or satellite data—where a linear filter might suppress critical diagnostic features to minimize noise variance.
|
| 24 |
|
| 25 |
+
## The Solution: Hybrid SVD-CNN
|
| 26 |
|
| 27 |
+
I combine SVD's strength as a data-adapted low-pass filter with the CNN's robust feature extraction into a single pipeline.
|
|
|
|
| 28 |
|
| 29 |
+
```mermaid
|
| 30 |
+
flowchart TD
|
| 31 |
+
subgraph S1 [I. Noisy Manifold]
|
| 32 |
+
direction LR
|
| 33 |
+
X["Input $X + \eta$"]
|
| 34 |
+
end
|
| 35 |
|
| 36 |
+
subgraph S2 [II. Adaptive Projection]
|
| 37 |
+
direction LR
|
| 38 |
+
node_SVD["SVD: $X = U \Sigma V^T$"]
|
| 39 |
+
node_Trunc["$k$-Rank Truncation"]
|
| 40 |
+
node_Recon["$\hat{X} = \sum \sigma_i u_i v_i^T$"]
|
| 41 |
+
node_SVD --> node_Trunc --> node_Recon
|
| 42 |
+
end
|
| 43 |
|
| 44 |
+
subgraph S3 [III. CNN Features]
|
| 45 |
+
direction LR
|
| 46 |
+
node_Conv["Conv Layers"] --> node_Pool["Pooling / ReLU"] --> node_Flat["Global Flatten"]
|
| 47 |
+
end
|
| 48 |
|
| 49 |
+
subgraph S4 [IV. Latent Mapping]
|
| 50 |
+
direction LR
|
| 51 |
+
node_Soft["Logits / Softmax"] --> node_Pred["Class Prediction"]
|
| 52 |
+
end
|
| 53 |
|
| 54 |
+
S1 --> S2
|
| 55 |
+
S2 --> S3
|
| 56 |
+
S3 --> S4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
style S2 fill:#f8f9ff,stroke:#0056b3,stroke-width:2px
|
| 59 |
+
style S3 fill:#f8fff9,stroke:#28a745,stroke-width:2px
|
| 60 |
+
style S1 fill:#fff,stroke:#333
|
| 61 |
+
style S4 fill:#fff,stroke:#333
|
| 62 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
### Key Takeaways
|
| 65 |
+
For full analysis and detailed metrics, see the [Technical Report](./docs/REPORT.md).
|
| 66 |
|
| 67 |
+
1. **The Variance Trap**: Important details (like the gap in a "3") have very little pixel variance. SVD-based linear projections clear them away as noise, forcing distinct digit manifolds to overlap and causing systematic "3-as-8" hallucinations.
|
| 68 |
+
2. **Local Logic**: UMAP analysis demonstrates that manifolds are topologically distinct when local structure is preserved, but linear variance optimization destroys this neighborhood integrity.
|
| 69 |
+
3. **Hybrid Advantage**: In high-noise environments ($\sigma=0.7$), a Hybrid architecture acts as a data-adapted denoiser, outperforming standalone CNNs by +4.8 pp.
|
| 70 |
+
4. **The Boundary**: On texture-rich data (e.g., Fashion-MNIST), SVD reconstruction destroys critical high-frequency features, defining the physical limit of linear denoising.
|
| 71 |
|
| 72 |
+
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
+
## Experience it Yourself
|
|
|
|
| 75 |
|
| 76 |
+
### Online Demo
|
| 77 |
+
Try the live dashboard to inject noise, adjust SVD rank, and compare model predictions in real-time:
|
| 78 |
+
**[Launch Streamlit App](https://huggingface.co/spaces/ymlin105/Coconut-MNIST)**.
|
| 79 |
|
| 80 |
+
### Local Installation
|
| 81 |
+
```bash
|
| 82 |
+
# Clone the repository
|
| 83 |
+
git clone https://github.com/ymlin105/mnist-linear-vs-nonlinear.git
|
| 84 |
+
cd mnist-linear-vs-nonlinear
|
| 85 |
|
| 86 |
+
# Install dependencies
|
| 87 |
+
pip install -r requirements.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
# Launch the interactive dashboard
|
| 90 |
+
streamlit run app.py
|
| 91 |
+
```
|
| 92 |
|
| 93 |
+
### Project Structure
|
| 94 |
+
```
|
| 95 |
+
├── src/ Core modules (CNN, SVD layer) + Experimental Utils
|
| 96 |
+
├── experiments/ Theme-based scripts (01 Diagnosis, 02 Analysis, 03 Robustness)
|
| 97 |
+
├── docs/ Full report (REPORT.md) + figures
|
| 98 |
+
├── models/ Pretrained checkpoints
|
| 99 |
+
├── run_all_experiments.sh One-click reproduction script
|
| 100 |
+
└── app.py Streamlit dashboard
|
| 101 |
+
```
|
| 102 |
|
| 103 |
---
|
experiments/01_phenomenon_diagnosis.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Exp 01: Phenomenon Diagnosis
|
| 3 |
+
Combines Global/Focused SVD analysis with CNN baseline comparisons.
|
| 4 |
+
Refactored to use centralized utility modules.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from sklearn.decomposition import TruncatedSVD
|
| 10 |
+
from sklearn.linear_model import LogisticRegression
|
| 11 |
+
from sklearn.metrics import accuracy_score
|
| 12 |
+
|
| 13 |
+
from src import config, utils, viz, exp_utils
|
| 14 |
+
|
| 15 |
+
def run_svd_analysis(X_train, y_train, X_test, y_test):
|
| 16 |
+
print("\n--- Running SVD Spectral Analysis ---")
|
| 17 |
+
mean = np.mean(X_train, axis=0)
|
| 18 |
+
X_centered = X_train - mean
|
| 19 |
+
|
| 20 |
+
n_view = 300
|
| 21 |
+
svd = TruncatedSVD(n_components=n_view, random_state=42)
|
| 22 |
+
svd.fit(X_centered)
|
| 23 |
+
|
| 24 |
+
# 1. Visualization: Spectrum
|
| 25 |
+
viz.plot_singular_spectrum(
|
| 26 |
+
svd.singular_values_,
|
| 27 |
+
np.cumsum(svd.explained_variance_ratio_),
|
| 28 |
+
'fig_01_spectrum.png'
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# 2. Classification with k=20
|
| 32 |
+
svd_20 = TruncatedSVD(n_components=20, random_state=42)
|
| 33 |
+
X_train_pca = svd_20.fit_transform(X_train - mean)
|
| 34 |
+
X_test_pca = svd_20.transform(X_test - mean)
|
| 35 |
+
|
| 36 |
+
clf = LogisticRegression(max_iter=1000)
|
| 37 |
+
clf.fit(X_train_pca, y_train)
|
| 38 |
+
y_pred = clf.predict(X_test_pca)
|
| 39 |
+
|
| 40 |
+
acc = accuracy_score(y_test, y_pred)
|
| 41 |
+
print(f"SVD (k=20) Accuracy: {acc*100:.2f}%")
|
| 42 |
+
|
| 43 |
+
# 3. Visualization: Confusion Matrix & Eigen-digits
|
| 44 |
+
viz.plot_confusion_matrix(
|
| 45 |
+
y_test, y_pred, list(range(10)),
|
| 46 |
+
'fig_02_svd_confusion.png',
|
| 47 |
+
f'SVD Confusion Matrix (Acc={acc:.2f})',
|
| 48 |
+
viz.COLOR_SVD
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
component_titles = [f"Comp {i+1}" for i in range(10)]
|
| 52 |
+
viz.plot_multi_image_grid(
|
| 53 |
+
[c.reshape(28, 28) for c in svd_20.components_[:10]],
|
| 54 |
+
component_titles, 2, 5,
|
| 55 |
+
'fig_03_eigen_digits.png',
|
| 56 |
+
'Global SVD Eigen-digits'
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def run_cnn_baseline(device):
|
| 60 |
+
print("\n--- Running CNN Baseline Diagnosis ---")
|
| 61 |
+
svd_p, cnn = utils.load_models(dataset_name="mnist")
|
| 62 |
+
X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False)
|
| 63 |
+
|
| 64 |
+
acc = exp_utils.evaluate_classifier(cnn, X_test, y_test, device=device, is_pytorch=True)
|
| 65 |
+
print(f"CNN Accuracy: {acc*100:.2f}%")
|
| 66 |
+
|
| 67 |
+
# Predict for confusion matrix
|
| 68 |
+
cnn.eval()
|
| 69 |
+
with torch.no_grad():
|
| 70 |
+
preds = cnn(X_test.to(device)).argmax(dim=1).cpu().numpy()
|
| 71 |
+
|
| 72 |
+
viz.plot_confusion_matrix(
|
| 73 |
+
y_test.numpy(), preds, list(range(10)),
|
| 74 |
+
'fig_04_cnn_confusion.png',
|
| 75 |
+
f'CNN Confusion Matrix (Acc={acc:.2f})',
|
| 76 |
+
viz.COLOR_CNN
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def main():
|
| 80 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 81 |
+
|
| 82 |
+
# Load full MNIST (flattened for SVD)
|
| 83 |
+
X_train, y_train = utils.load_data_split(dataset_name="mnist", train=True, flatten=True)
|
| 84 |
+
X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False, flatten=True)
|
| 85 |
+
|
| 86 |
+
run_svd_analysis(X_train.numpy(), y_train.numpy(), X_test.numpy(), y_test.numpy())
|
| 87 |
+
run_cnn_baseline(device)
|
| 88 |
+
print("\nExperiment 01 Diagnosis Completed.")
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
|
| 91 |
+
main()
|
experiments/01_phenomenon_discovery.py
DELETED
|
@@ -1,226 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import numpy as np
|
| 3 |
-
import matplotlib.pyplot as plt
|
| 4 |
-
import seaborn as sns
|
| 5 |
-
from sklearn.decomposition import TruncatedSVD
|
| 6 |
-
from sklearn.metrics import confusion_matrix, accuracy_score
|
| 7 |
-
from sklearn.linear_model import LogisticRegression
|
| 8 |
-
import torchvision
|
| 9 |
-
import torchvision.transforms as transforms
|
| 10 |
-
import torch
|
| 11 |
-
import os
|
| 12 |
-
from matplotlib.colors import LinearSegmentedColormap
|
| 13 |
-
from src import config
|
| 14 |
-
|
| 15 |
-
# --- Configuration ---
|
| 16 |
-
GRAY_LIGHT = "#D8DEE9"
|
| 17 |
-
BLUE_DEEP = "#5E81AC"
|
| 18 |
-
ORANGE = "#D08770"
|
| 19 |
-
|
| 20 |
-
def load_mnist():
|
| 21 |
-
"""Load and flatten MNIST data."""
|
| 22 |
-
print("Loading MNIST...")
|
| 23 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 24 |
-
# Download to ./data/mnist
|
| 25 |
-
trainset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
|
| 26 |
-
testset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
|
| 27 |
-
|
| 28 |
-
# Flatten: (N, 28, 28) -> (N, 784)
|
| 29 |
-
X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 30 |
-
y_train = trainset.targets.numpy()
|
| 31 |
-
|
| 32 |
-
X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 33 |
-
y_test = testset.targets.numpy()
|
| 34 |
-
|
| 35 |
-
return X_train, y_train, X_test, y_test
|
| 36 |
-
|
| 37 |
-
def plot_confusion_matrix(y_true, y_pred, labels, filename, title):
|
| 38 |
-
"""Draws and saves a confusion matrix (normalized)."""
|
| 39 |
-
cm = confusion_matrix(y_true, y_pred, normalize='true') # Normalize by true class (rows)
|
| 40 |
-
plt.figure(figsize=(10, 8))
|
| 41 |
-
cmap = LinearSegmentedColormap.from_list("NBodyBlue", [GRAY_LIGHT, BLUE_DEEP])
|
| 42 |
-
sns.heatmap(cm, annot=True, fmt='.1%', cmap=cmap, xticklabels=labels, yticklabels=labels)
|
| 43 |
-
plt.title(title)
|
| 44 |
-
plt.xlabel('Predicted')
|
| 45 |
-
plt.ylabel('True')
|
| 46 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename))
|
| 47 |
-
plt.close()
|
| 48 |
-
print(f"Saved {filename}")
|
| 49 |
-
|
| 50 |
-
def plot_eigen_digits(components, filename, title):
|
| 51 |
-
"""Visualizes the top eigen-digits."""
|
| 52 |
-
plt.figure(figsize=(12, 4))
|
| 53 |
-
for i in range(min(10, len(components))): # Show top 10
|
| 54 |
-
plt.subplot(2, 5, i + 1)
|
| 55 |
-
plt.imshow(components[i].reshape(28, 28), cmap='gray')
|
| 56 |
-
plt.title(f"Comp {i+1}")
|
| 57 |
-
plt.axis('off')
|
| 58 |
-
plt.suptitle(title)
|
| 59 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename))
|
| 60 |
-
plt.close()
|
| 61 |
-
print(f"Saved {filename}")
|
| 62 |
-
|
| 63 |
-
def analyze_spectrum(X, filename, title):
|
| 64 |
-
"""
|
| 65 |
-
Computes and plots the singular value spectrum.
|
| 66 |
-
Returns cumulative variance stats.
|
| 67 |
-
"""
|
| 68 |
-
print(f"\nRunning Spectral Analysis on shape {X.shape}...")
|
| 69 |
-
# Center the data
|
| 70 |
-
X_mean = np.mean(X, axis=0)
|
| 71 |
-
X_centered = X - X_mean
|
| 72 |
-
|
| 73 |
-
# Compute full SVD (approximation with high k)
|
| 74 |
-
n_view = 300 # Increased from 50 to capture >90% variance
|
| 75 |
-
svd = TruncatedSVD(n_components=n_view, random_state=42)
|
| 76 |
-
svd.fit(X_centered)
|
| 77 |
-
|
| 78 |
-
singular_values = svd.singular_values_
|
| 79 |
-
explained_variance_ratio = svd.explained_variance_ratio_
|
| 80 |
-
cumulative_variance = np.cumsum(explained_variance_ratio)
|
| 81 |
-
|
| 82 |
-
# Quantify stats
|
| 83 |
-
var_k10 = cumulative_variance[9] * 100
|
| 84 |
-
|
| 85 |
-
def get_k(threshold):
|
| 86 |
-
idx = np.argmax(cumulative_variance >= threshold)
|
| 87 |
-
if cumulative_variance[idx] < threshold: return f">{n_view}"
|
| 88 |
-
return idx + 1
|
| 89 |
-
|
| 90 |
-
k_90 = get_k(0.90)
|
| 91 |
-
k_95 = get_k(0.95)
|
| 92 |
-
k_99 = get_k(0.99)
|
| 93 |
-
|
| 94 |
-
print(f"Spectral Stats:")
|
| 95 |
-
print(f" Variance @ k=10: {var_k10:.2f}%")
|
| 96 |
-
print(f" Components for 90% Var: k={k_90}")
|
| 97 |
-
print(f" Components for 95% Var: k={k_95}")
|
| 98 |
-
print(f" Components for 99% Var: k={k_99}")
|
| 99 |
-
|
| 100 |
-
# Plot Scree
|
| 101 |
-
fig, ax1 = plt.subplots(figsize=(10, 6))
|
| 102 |
-
|
| 103 |
-
color = BLUE_DEEP
|
| 104 |
-
ax1.set_xlabel('Principal Component (k)')
|
| 105 |
-
ax1.set_ylabel('Singular Value (Log Scale)', color=color)
|
| 106 |
-
ax1.semilogy(range(1, n_view+1), singular_values, marker='o', linestyle='-', color=color, markersize=4, label='Singular Values')
|
| 107 |
-
ax1.tick_params(axis='y', labelcolor=color)
|
| 108 |
-
ax1.grid(True, which="both", ls="-", alpha=0.3)
|
| 109 |
-
|
| 110 |
-
ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
|
| 111 |
-
color = ORANGE
|
| 112 |
-
ax2.set_ylabel('Cumulative Explained Variance', color=color) # we already handled the x-label with ax1
|
| 113 |
-
ax2.plot(range(1, n_view+1), cumulative_variance, color=color, linewidth=2, linestyle='--', label='Cumulative Variance')
|
| 114 |
-
ax2.tick_params(axis='y', labelcolor=color)
|
| 115 |
-
ax2.set_ylim(0, 1.0)
|
| 116 |
-
|
| 117 |
-
# Annotate k=10
|
| 118 |
-
ax2.axvline(x=10, color='gray', linestyle=':', alpha=0.5)
|
| 119 |
-
ax2.text(10.5, 0.4, f'k=10\n({var_k10:.1f}%)', color='black')
|
| 120 |
-
|
| 121 |
-
plt.title(title)
|
| 122 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename))
|
| 123 |
-
plt.close()
|
| 124 |
-
print(f"Saved {filename}")
|
| 125 |
-
|
| 126 |
-
def run_experiment_0(X_train, y_train, X_test, y_test):
|
| 127 |
-
"""
|
| 128 |
-
Experiment 0: Global SVD Analysis (10 Classes)
|
| 129 |
-
Hypothesis: 3 and 8 shows significant confusion.
|
| 130 |
-
"""
|
| 131 |
-
print("\n--- Running Experiment 0: Global SVD (10 Classes) ---")
|
| 132 |
-
|
| 133 |
-
# 1. SVD Reduction
|
| 134 |
-
n_components = 20 # low rank to force dependency on main variance directions
|
| 135 |
-
print(f"Reducing dimension to {n_components} using SVD...")
|
| 136 |
-
# Mean-center for consistency with hybrid model's SVD layer
|
| 137 |
-
mean = np.mean(X_train, axis=0)
|
| 138 |
-
X_train_centered = X_train - mean
|
| 139 |
-
X_test_centered = X_test - mean
|
| 140 |
-
svd = TruncatedSVD(n_components=n_components, random_state=42)
|
| 141 |
-
X_train_pca = svd.fit_transform(X_train_centered)
|
| 142 |
-
X_test_pca = svd.transform(X_test_centered)
|
| 143 |
-
|
| 144 |
-
# 2. Classification (Simple Linear or KNN)
|
| 145 |
-
# Using Logistic Regression to simulate linear classification on SVD features
|
| 146 |
-
clf = LogisticRegression(max_iter=1000)
|
| 147 |
-
clf.fit(X_train_pca, y_train)
|
| 148 |
-
y_pred = clf.predict(X_test_pca)
|
| 149 |
-
|
| 150 |
-
acc = accuracy_score(y_test, y_pred)
|
| 151 |
-
print(f"Global SVD+LR Accuracy: {acc*100:.2f}%")
|
| 152 |
-
|
| 153 |
-
# 3. Plot Confusion Matrix
|
| 154 |
-
plot_confusion_matrix(y_test, y_pred, list(range(10)),
|
| 155 |
-
'fig_01_global_svd_confusion.png',
|
| 156 |
-
f'Global SVD Confusion Matrix (k={n_components}, Acc={acc:.2f})')
|
| 157 |
-
|
| 158 |
-
# Analyze specific confusion between 3 and 8
|
| 159 |
-
idxs_3 = (y_test == 3)
|
| 160 |
-
idxs_8 = (y_test == 8)
|
| 161 |
-
|
| 162 |
-
# Confusion 3->8
|
| 163 |
-
pred_3 = y_pred[idxs_3]
|
| 164 |
-
confused_3_as_8 = np.sum(pred_3 == 8)
|
| 165 |
-
print(f"Class 3 samples classified as 8: {confused_3_as_8} / {len(pred_3)} ({confused_3_as_8/len(pred_3)*100:.2f}%)")
|
| 166 |
-
|
| 167 |
-
# Confusion 8->3
|
| 168 |
-
pred_8 = y_pred[idxs_8]
|
| 169 |
-
confused_8_as_3 = np.sum(pred_8 == 3)
|
| 170 |
-
print(f"Class 8 samples classified as 3: {confused_8_as_3} / {len(pred_8)} ({confused_8_as_3/len(pred_8)*100:.2f}%)")
|
| 171 |
-
|
| 172 |
-
def run_experiment_1(X_train, y_train, X_test, y_test):
|
| 173 |
-
"""
|
| 174 |
-
Experiment 1: Focused 3 vs 8 SVD Analysis
|
| 175 |
-
"""
|
| 176 |
-
print("\n--- Running Experiment 1: Focused SVD (3 vs 8) ---")
|
| 177 |
-
|
| 178 |
-
# 1. Filter Data
|
| 179 |
-
mask_train = np.logical_or(y_train == 3, y_train == 8)
|
| 180 |
-
mask_test = np.logical_or(y_test == 3, y_test == 8)
|
| 181 |
-
|
| 182 |
-
X_train_38 = X_train[mask_train]
|
| 183 |
-
y_train_38 = y_train[mask_train]
|
| 184 |
-
X_test_38 = X_test[mask_test]
|
| 185 |
-
y_test_38 = y_test[mask_test]
|
| 186 |
-
|
| 187 |
-
print(f"Train samples (3 vs 8): {len(y_train_38)}")
|
| 188 |
-
print(f"Test samples (3 vs 8): {len(y_test_38)}")
|
| 189 |
-
|
| 190 |
-
# 2. SVD on Subset
|
| 191 |
-
n_components = 10
|
| 192 |
-
# Mean-center for consistency with hybrid model's SVD layer
|
| 193 |
-
mean_38 = np.mean(X_train_38, axis=0)
|
| 194 |
-
X_train_38_centered = X_train_38 - mean_38
|
| 195 |
-
X_test_38_centered = X_test_38 - mean_38
|
| 196 |
-
svd = TruncatedSVD(n_components=n_components, random_state=42)
|
| 197 |
-
X_train_pca = svd.fit_transform(X_train_38_centered)
|
| 198 |
-
X_test_pca = svd.transform(X_test_38_centered)
|
| 199 |
-
|
| 200 |
-
# 3. Classify
|
| 201 |
-
clf = LogisticRegression()
|
| 202 |
-
clf.fit(X_train_pca, y_train_38)
|
| 203 |
-
y_pred = clf.predict(X_test_pca)
|
| 204 |
-
|
| 205 |
-
acc = accuracy_score(y_test_38, y_pred)
|
| 206 |
-
print(f"Focused SVD(k={n_components}) Accuracy (3 vs 8): {acc*100:.2f}%")
|
| 207 |
-
|
| 208 |
-
# 4. Plots
|
| 209 |
-
plot_eigen_digits(svd.components_,
|
| 210 |
-
'fig_03_eigen_digits.png',
|
| 211 |
-
'Top 10 Eigen-digits (Principal Components of 3&8)')
|
| 212 |
-
|
| 213 |
-
# 5. Spectral Analysis (New)
|
| 214 |
-
analyze_spectrum(X_train_38, 'fig_02_scree_plot.png', 'Singular Value Spectrum (3 vs 8 Subset)')
|
| 215 |
-
|
| 216 |
-
def main():
|
| 217 |
-
X_train, y_train, X_test, y_test = load_mnist()
|
| 218 |
-
|
| 219 |
-
run_experiment_0(X_train, y_train, X_test, y_test)
|
| 220 |
-
run_experiment_1(X_train, y_train, X_test, y_test)
|
| 221 |
-
|
| 222 |
-
print("\nExperiments 0 & 1 Completed.")
|
| 223 |
-
print(f"Results saved to {config.RESULTS_DIR}")
|
| 224 |
-
|
| 225 |
-
if __name__ == "__main__":
|
| 226 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/02_mechanistic_analysis.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Exp 02: Mechanistic Analysis
|
| 3 |
+
Combines Interpolation, Explainability (Grad-CAM), and Quantifying Manifold Collapse.
|
| 4 |
+
Refactored for modularity.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import numpy as np
|
| 10 |
+
from sklearn.decomposition import TruncatedSVD
|
| 11 |
+
from sklearn.neighbors import KNeighborsClassifier
|
| 12 |
+
from sklearn.metrics import accuracy_score
|
| 13 |
+
|
| 14 |
+
from src import config, utils, viz, exp_utils
|
| 15 |
+
|
| 16 |
+
def run_interpolation_analysis(device):
|
| 17 |
+
print("\n--- Running Mechanistic Proof: The Variance vs. Topology Conflict ---")
|
| 18 |
+
X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False, digits=[3, 8])
|
| 19 |
+
_, cnn = utils.load_models(dataset_name="mnist")
|
| 20 |
+
|
| 21 |
+
# Fit SVD baseline for reconstruction analysis
|
| 22 |
+
X_test_flat = X_test.view(X_test.size(0), -1).numpy()
|
| 23 |
+
svd_pipe = exp_utils.fit_svd_baseline(X_test_flat, y_test.numpy(), n_components=10)
|
| 24 |
+
|
| 25 |
+
svd = svd_pipe.named_steps['svd']
|
| 26 |
+
mean = svd_pipe.named_steps['scaler'].mean_
|
| 27 |
+
|
| 28 |
+
# Pick indices for digit 3 and 8
|
| 29 |
+
idx_3 = (y_test == 0).nonzero()[0][0]
|
| 30 |
+
idx_8 = (y_test == 1).nonzero()[0][0]
|
| 31 |
+
img_3, img_8 = X_test[idx_3], X_test[idx_8]
|
| 32 |
+
|
| 33 |
+
alphas = np.linspace(0, 1, 11)
|
| 34 |
+
probs_8, rec_errors = [], []
|
| 35 |
+
|
| 36 |
+
for alpha in alphas:
|
| 37 |
+
img_interp = (1 - alpha) * img_3 + alpha * img_8
|
| 38 |
+
# CNN Probability of class 1 (Digit 8)
|
| 39 |
+
with torch.no_grad():
|
| 40 |
+
logits = cnn(img_interp.unsqueeze(0).to(device))
|
| 41 |
+
# Note: We use index 8 from full model or index 1 if it was binary
|
| 42 |
+
# Here we assume full model but we load 3v8 subset.
|
| 43 |
+
# If model is 10-class, we need to pick actual digit indices.
|
| 44 |
+
# Let's check model output size.
|
| 45 |
+
out_dim = logits.shape[1]
|
| 46 |
+
if out_dim == 10:
|
| 47 |
+
p = torch.softmax(logits, dim=1)[0, 8].item()
|
| 48 |
+
else:
|
| 49 |
+
p = torch.softmax(logits, dim=1)[0, 1].item()
|
| 50 |
+
probs_8.append(p)
|
| 51 |
+
|
| 52 |
+
# SVD Reconstruction Error
|
| 53 |
+
flat = img_interp.view(1, -1).numpy()
|
| 54 |
+
rec = svd.inverse_transform(svd.transform(flat - mean)) + mean
|
| 55 |
+
rec_errors.append(np.linalg.norm(flat - rec))
|
| 56 |
+
|
| 57 |
+
viz.plot_interpolation_dynamics(alphas, probs_8, rec_errors, 'fig_05_interpolation.png')
|
| 58 |
+
|
| 59 |
+
def run_quantifying_manifold_collapse():
|
| 60 |
+
print("\n--- Running Experiment 7: Quantifying Manifold Collapse ---")
|
| 61 |
+
X_train, y_train = utils.load_data_split(dataset_name="mnist", train=True, digits=[3, 8], flatten=True)
|
| 62 |
+
X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False, digits=[3, 8], flatten=True)
|
| 63 |
+
|
| 64 |
+
X_train_np, y_train_np = X_train.numpy(), y_train.numpy()
|
| 65 |
+
X_test_np, y_test_np = X_test.numpy(), y_test.numpy()
|
| 66 |
+
|
| 67 |
+
# 1. k-NN on Raw Pixel Space (784D)
|
| 68 |
+
knn_raw = KNeighborsClassifier(n_neighbors=5)
|
| 69 |
+
knn_raw.fit(X_train_np, y_train_np)
|
| 70 |
+
acc_raw = accuracy_score(y_test_np, knn_raw.predict(X_test_np))
|
| 71 |
+
|
| 72 |
+
# 2. k-NN on SVD-reduced Space (10D)
|
| 73 |
+
svd = TruncatedSVD(n_components=10, random_state=42)
|
| 74 |
+
X_train_svd = svd.fit_transform(X_train_np)
|
| 75 |
+
X_test_svd = svd.transform(X_test_np)
|
| 76 |
+
|
| 77 |
+
knn_svd = KNeighborsClassifier(n_neighbors=5)
|
| 78 |
+
knn_svd.fit(X_train_svd, y_train_np)
|
| 79 |
+
acc_svd = accuracy_score(y_test_np, knn_svd.predict(X_test_svd))
|
| 80 |
+
|
| 81 |
+
# 3. Visualization: SVD vs UMAP
|
| 82 |
+
try:
|
| 83 |
+
import umap
|
| 84 |
+
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
|
| 85 |
+
X_umap = reducer.fit_transform(X_test_np)
|
| 86 |
+
viz.plot_manifold_comparison(X_test_svd, X_umap, y_test_np, acc_svd, acc_raw, 'fig_08_manifold_collapse.png')
|
| 87 |
+
except Exception as e:
|
| 88 |
+
print(f"Warning: Manifold visualization failed: {e}")
|
| 89 |
+
|
| 90 |
+
print(f"Manifold Collapse Quantification Results:")
|
| 91 |
+
print(f" - Raw 784D k-NN Accuracy: {acc_raw:.4f}")
|
| 92 |
+
print(f" - SVD 10D k-NN Accuracy: {acc_svd:.4f}")
|
| 93 |
+
print(f" - Accuracy Loss: {acc_raw - acc_svd:.4f}")
|
| 94 |
+
|
| 95 |
+
def main():
|
| 96 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 97 |
+
run_interpolation_analysis(device)
|
| 98 |
+
run_quantifying_manifold_collapse()
|
| 99 |
+
print("\nExperiment 02 Mechanistic Analysis (Refined) Completed.")
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
main()
|
experiments/02_mnist_cnn_confusion.py
DELETED
|
@@ -1,68 +0,0 @@
|
|
| 1 |
-
# Exp 02 – MNIST 10-class CNN confusion matrix
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import matplotlib.pyplot as plt
|
| 7 |
-
import seaborn as sns
|
| 8 |
-
import torch
|
| 9 |
-
from sklearn.metrics import confusion_matrix, accuracy_score
|
| 10 |
-
from torchvision import datasets, transforms
|
| 11 |
-
from matplotlib.colors import LinearSegmentedColormap
|
| 12 |
-
|
| 13 |
-
from src.hybrid_model import SimpleCNN
|
| 14 |
-
from src import config
|
| 15 |
-
|
| 16 |
-
GRAY_LIGHT = "#D8DEE9"
|
| 17 |
-
BLUE_LIGHT = "#88C0D0"
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def load_mnist_test():
|
| 21 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 22 |
-
testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
|
| 23 |
-
X_test = testset.data.float() / 255.0
|
| 24 |
-
y_test = testset.targets
|
| 25 |
-
X_test = X_test.unsqueeze(1) # (N, 1, 28, 28)
|
| 26 |
-
return X_test, y_test
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def main():
|
| 30 |
-
# Load model
|
| 31 |
-
model = SimpleCNN(num_classes=10)
|
| 32 |
-
model.load_state_dict(torch.load(config.CNN_MODEL_PATH, map_location="cpu"))
|
| 33 |
-
model.eval()
|
| 34 |
-
|
| 35 |
-
# Load data
|
| 36 |
-
X_test, y_test = load_mnist_test()
|
| 37 |
-
|
| 38 |
-
with torch.no_grad():
|
| 39 |
-
logits = model(X_test)
|
| 40 |
-
preds = torch.argmax(logits, dim=1)
|
| 41 |
-
|
| 42 |
-
y_true = y_test.numpy()
|
| 43 |
-
y_pred = preds.numpy()
|
| 44 |
-
|
| 45 |
-
acc = accuracy_score(y_true, y_pred)
|
| 46 |
-
print(f"MNIST CNN Accuracy: {acc*100:.2f}%")
|
| 47 |
-
|
| 48 |
-
cm = confusion_matrix(y_true, y_pred, normalize="true")
|
| 49 |
-
|
| 50 |
-
plt.figure(figsize=(10, 8))
|
| 51 |
-
cmap = LinearSegmentedColormap.from_list(
|
| 52 |
-
"NBodyCNN",
|
| 53 |
-
[GRAY_LIGHT, BLUE_LIGHT],
|
| 54 |
-
)
|
| 55 |
-
sns.heatmap(cm, annot=True, fmt=".1%", cmap=cmap, xticklabels=list(range(10)), yticklabels=list(range(10)))
|
| 56 |
-
plt.title(f"MNIST CNN Confusion Matrix (Acc={acc:.2%})")
|
| 57 |
-
plt.xlabel("Predicted")
|
| 58 |
-
plt.ylabel("True")
|
| 59 |
-
plt.tight_layout()
|
| 60 |
-
|
| 61 |
-
out_path = os.path.join(config.RESULTS_DIR, "fig_04_mnist_cnn_confusion.png")
|
| 62 |
-
plt.savefig(out_path, dpi=300)
|
| 63 |
-
plt.close()
|
| 64 |
-
print(f"Saved {out_path}")
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
if __name__ == "__main__":
|
| 68 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/03_mechanistic_investigation.py
DELETED
|
@@ -1,244 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.optim as optim
|
| 5 |
-
from torch.utils.data import TensorDataset, DataLoader
|
| 6 |
-
import numpy as np
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
-
from sklearn.decomposition import TruncatedSVD
|
| 9 |
-
import cv2
|
| 10 |
-
import os
|
| 11 |
-
import ssl
|
| 12 |
-
import torchvision
|
| 13 |
-
|
| 14 |
-
from src.hybrid_model import SimpleCNN
|
| 15 |
-
from src import config
|
| 16 |
-
|
| 17 |
-
# --- Configuration ---
|
| 18 |
-
BLUE_LIGHT = "#88C0D0"
|
| 19 |
-
BLUE_DEEP = "#5E81AC"
|
| 20 |
-
BATCH_SIZE = 64
|
| 21 |
-
EPOCHS = 5
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def load_mnist_38():
|
| 25 |
-
"""Load MNIST and filter for 3 vs 8."""
|
| 26 |
-
ssl._create_default_https_context = ssl._create_unverified_context
|
| 27 |
-
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
|
| 28 |
-
trainset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
|
| 29 |
-
testset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
|
| 30 |
-
|
| 31 |
-
def filter_38(dataset):
|
| 32 |
-
mask = (dataset.targets == 3) | (dataset.targets == 8)
|
| 33 |
-
data = dataset.data[mask].unsqueeze(1).float() / 255.0
|
| 34 |
-
targets = dataset.targets[mask]
|
| 35 |
-
targets = torch.where(targets == 3, torch.tensor(0), torch.tensor(1))
|
| 36 |
-
return data, targets
|
| 37 |
-
|
| 38 |
-
X_train, y_train = filter_38(trainset)
|
| 39 |
-
X_test, y_test = filter_38(testset)
|
| 40 |
-
return X_train, y_train, X_test, y_test
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
# --- Training Helper ---
|
| 44 |
-
def train_model(X_train, y_train):
|
| 45 |
-
model = SimpleCNN(num_classes=2)
|
| 46 |
-
criterion = nn.CrossEntropyLoss()
|
| 47 |
-
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
| 48 |
-
dataset = TensorDataset(X_train, y_train)
|
| 49 |
-
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
| 50 |
-
|
| 51 |
-
print("Training CNN for Analysis...")
|
| 52 |
-
model.train()
|
| 53 |
-
for epoch in range(EPOCHS):
|
| 54 |
-
for inputs, labels in loader:
|
| 55 |
-
optimizer.zero_grad()
|
| 56 |
-
outputs = model(inputs)
|
| 57 |
-
loss = criterion(outputs, labels)
|
| 58 |
-
loss.backward()
|
| 59 |
-
optimizer.step()
|
| 60 |
-
return model
|
| 61 |
-
|
| 62 |
-
# --- Experiment 3: Interpolation ---
|
| 63 |
-
def run_interpolation_analysis(model, svd, X_test, y_test, svd_mean=None):
|
| 64 |
-
print("\n--- Running Exp 3: Interpolation Analysis ---")
|
| 65 |
-
|
| 66 |
-
# 1. Find a good pairs of 3 and 8
|
| 67 |
-
# We want a 'canonical' 3 and 8
|
| 68 |
-
idx_3 = (y_test == 0).nonzero(as_tuple=True)[0][0]
|
| 69 |
-
idx_8 = (y_test == 1).nonzero(as_tuple=True)[0][0]
|
| 70 |
-
|
| 71 |
-
img_3 = X_test[idx_3] # (1, 28, 28)
|
| 72 |
-
img_8 = X_test[idx_8]
|
| 73 |
-
|
| 74 |
-
# 2. Generate Interpolation steps
|
| 75 |
-
alphas = np.linspace(0, 1, 11)
|
| 76 |
-
interpolated_imgs = []
|
| 77 |
-
cnn_probs_8 = []
|
| 78 |
-
svd_errors = []
|
| 79 |
-
|
| 80 |
-
print("Computing metrics along interpolation path...")
|
| 81 |
-
for alpha in alphas:
|
| 82 |
-
# Linear Interpolation
|
| 83 |
-
img_interp = (1 - alpha) * img_3 + alpha * img_8
|
| 84 |
-
interpolated_imgs.append(img_interp.squeeze().numpy())
|
| 85 |
-
|
| 86 |
-
# CNN Prediction
|
| 87 |
-
with torch.no_grad():
|
| 88 |
-
img_tensor = img_interp.unsqueeze(0) # (1, 1, 28, 28)
|
| 89 |
-
logits = model(img_tensor)
|
| 90 |
-
probs = torch.softmax(logits, dim=1)
|
| 91 |
-
cnn_probs_8.append(probs[0, 1].item())
|
| 92 |
-
|
| 93 |
-
# SVD Reconstruction (using the passed svd model)
|
| 94 |
-
# SVD expects (N, 784), mean-centered to match training
|
| 95 |
-
img_flat = img_interp.view(1, -1).numpy()
|
| 96 |
-
img_centered = img_flat - svd_mean if svd_mean is not None else img_flat
|
| 97 |
-
img_pca = svd.transform(img_centered)
|
| 98 |
-
img_rec = svd.inverse_transform(img_pca)
|
| 99 |
-
if svd_mean is not None:
|
| 100 |
-
img_rec = img_rec + svd_mean
|
| 101 |
-
|
| 102 |
-
# Reconstruction Error (L2)
|
| 103 |
-
rec_err = np.linalg.norm(img_flat - img_rec)
|
| 104 |
-
svd_errors.append(rec_err)
|
| 105 |
-
|
| 106 |
-
# 3. Plotting
|
| 107 |
-
plt.figure(figsize=(12, 6))
|
| 108 |
-
|
| 109 |
-
# Plot Images
|
| 110 |
-
for i, img in enumerate(interpolated_imgs):
|
| 111 |
-
plt.subplot(3, 11, i + 1)
|
| 112 |
-
plt.imshow(img, cmap='gray')
|
| 113 |
-
plt.axis('off')
|
| 114 |
-
if i == 0: plt.title("Start (3)")
|
| 115 |
-
if i == 10: plt.title("End (8)")
|
| 116 |
-
|
| 117 |
-
# Plot Curves
|
| 118 |
-
plt.subplot(3, 1, 2)
|
| 119 |
-
plt.plot(alphas, cnn_probs_8, marker='o', color=BLUE_LIGHT, label='CNN Prob(Class=8)')
|
| 120 |
-
plt.ylabel('CNN Probability')
|
| 121 |
-
plt.grid(True)
|
| 122 |
-
plt.legend()
|
| 123 |
-
|
| 124 |
-
plt.subplot(3, 1, 3)
|
| 125 |
-
plt.plot(alphas, svd_errors, marker='s', color=BLUE_DEEP, label='SVD Rec. Error')
|
| 126 |
-
plt.xlabel('Interpolation Alpha (0=3, 1=8)')
|
| 127 |
-
plt.ylabel('SVD Error')
|
| 128 |
-
plt.grid(True)
|
| 129 |
-
plt.legend()
|
| 130 |
-
|
| 131 |
-
plt.tight_layout()
|
| 132 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, 'fig_05_interpolation_analysis.png'))
|
| 133 |
-
plt.close()
|
| 134 |
-
print("Saved fig_05_interpolation_analysis.png")
|
| 135 |
-
|
| 136 |
-
return interpolated_imgs # Return for Exp 4 use
|
| 137 |
-
|
| 138 |
-
# --- Experiment 4: Explainability (Grad-CAM vs SVD) ---
|
| 139 |
-
def run_explainability_analysis(model, svd, interpolated_imgs, svd_mean=None):
|
| 140 |
-
print("\n--- Running Exp 4: Explainability Analysis ---")
|
| 141 |
-
|
| 142 |
-
# Pick the middle ambiguous image (alpha=0.5)
|
| 143 |
-
middle_idx = 5
|
| 144 |
-
img_ambiguous = torch.tensor(interpolated_imgs[middle_idx]).unsqueeze(0).unsqueeze(0) # (1, 1, 28, 28)
|
| 145 |
-
|
| 146 |
-
# 1. CNN Grad-CAM
|
| 147 |
-
# Hook into last conv layer
|
| 148 |
-
gradients = []
|
| 149 |
-
activations = []
|
| 150 |
-
|
| 151 |
-
def backward_hook(module, grad_input, grad_output):
|
| 152 |
-
gradients.append(grad_output[0])
|
| 153 |
-
|
| 154 |
-
def forward_hook(module, input, output):
|
| 155 |
-
activations.append(output)
|
| 156 |
-
|
| 157 |
-
# Register hooks on conv2
|
| 158 |
-
handle_b = model.conv2.register_full_backward_hook(backward_hook)
|
| 159 |
-
handle_f = model.conv2.register_forward_hook(forward_hook)
|
| 160 |
-
|
| 161 |
-
# Forward & Backward
|
| 162 |
-
model.eval()
|
| 163 |
-
logits = model(img_ambiguous)
|
| 164 |
-
# Target class 8 (index 1) for visualization
|
| 165 |
-
logits[0, 1].backward()
|
| 166 |
-
|
| 167 |
-
# Generate Heatmap
|
| 168 |
-
grads = gradients[0].cpu().data.numpy()[0] # (32, 7, 7)
|
| 169 |
-
fmaps = activations[0].cpu().data.numpy()[0] # (32, 7, 7)
|
| 170 |
-
|
| 171 |
-
weights = np.mean(grads, axis=(1, 2)) # Global Average Pooling
|
| 172 |
-
cam = np.zeros(fmaps.shape[1:], dtype=np.float32)
|
| 173 |
-
|
| 174 |
-
for i, w in enumerate(weights):
|
| 175 |
-
cam += w * fmaps[i]
|
| 176 |
-
|
| 177 |
-
cam = np.maximum(cam, 0)
|
| 178 |
-
cam = cv2.resize(cam, (28, 28))
|
| 179 |
-
cam = cam - np.min(cam)
|
| 180 |
-
cam = cam / np.max(cam)
|
| 181 |
-
|
| 182 |
-
# 2. SVD Reconstruction
|
| 183 |
-
img_flat = img_ambiguous.view(1, -1).numpy()
|
| 184 |
-
img_centered = img_flat - svd_mean if svd_mean is not None else img_flat
|
| 185 |
-
img_pca = svd.transform(img_centered)
|
| 186 |
-
img_rec = svd.inverse_transform(img_pca)
|
| 187 |
-
if svd_mean is not None:
|
| 188 |
-
img_rec = img_rec + svd_mean
|
| 189 |
-
img_rec = img_rec.reshape(28, 28)
|
| 190 |
-
|
| 191 |
-
# 3. Plot Comparison
|
| 192 |
-
plt.figure(figsize=(10, 4))
|
| 193 |
-
|
| 194 |
-
# Original Ambiguous
|
| 195 |
-
plt.subplot(1, 3, 1)
|
| 196 |
-
plt.imshow(img_ambiguous.squeeze(), cmap='gray')
|
| 197 |
-
plt.title("Ambiguous Input (Alpha=0.5)")
|
| 198 |
-
plt.axis('off')
|
| 199 |
-
|
| 200 |
-
# CNN Attention
|
| 201 |
-
plt.subplot(1, 3, 2)
|
| 202 |
-
plt.imshow(img_ambiguous.squeeze(), cmap='gray')
|
| 203 |
-
plt.imshow(cam, cmap='jet', alpha=0.5) # Overlay
|
| 204 |
-
plt.title("CNN Attention (Grad-CAM)")
|
| 205 |
-
plt.axis('off')
|
| 206 |
-
|
| 207 |
-
# SVD Reconstruction
|
| 208 |
-
plt.subplot(1, 3, 3)
|
| 209 |
-
plt.imshow(img_rec, cmap='gray')
|
| 210 |
-
plt.title("SVD Reconstruction")
|
| 211 |
-
plt.axis('off')
|
| 212 |
-
|
| 213 |
-
plt.tight_layout()
|
| 214 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, 'fig_06_explainability.png'))
|
| 215 |
-
plt.close()
|
| 216 |
-
print("Saved fig_06_explainability.png")
|
| 217 |
-
|
| 218 |
-
handle_b.remove()
|
| 219 |
-
handle_f.remove()
|
| 220 |
-
|
| 221 |
-
def main():
|
| 222 |
-
# Load Data
|
| 223 |
-
X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = load_mnist_38()
|
| 224 |
-
|
| 225 |
-
# Train CNN
|
| 226 |
-
cnn_model = train_model(X_train_tensor, y_train_tensor)
|
| 227 |
-
|
| 228 |
-
# Fit SVD (on train data 3 vs 8)
|
| 229 |
-
print("Fitting SVD on 3 vs 8...")
|
| 230 |
-
X_train_np = X_train_tensor.view(-1, 784).numpy()
|
| 231 |
-
# Mean-center for consistency with hybrid model's SVD layer
|
| 232 |
-
svd_mean = np.mean(X_train_np, axis=0)
|
| 233 |
-
X_train_centered = X_train_np - svd_mean
|
| 234 |
-
svd = TruncatedSVD(n_components=10, random_state=42)
|
| 235 |
-
svd.fit(X_train_centered)
|
| 236 |
-
|
| 237 |
-
# Run Experiments
|
| 238 |
-
interp_imgs = run_interpolation_analysis(cnn_model, svd, X_test_tensor, y_test_tensor, svd_mean)
|
| 239 |
-
run_explainability_analysis(cnn_model, svd, interp_imgs, svd_mean)
|
| 240 |
-
|
| 241 |
-
print("\nDeep Dive Analysis Completed.")
|
| 242 |
-
|
| 243 |
-
if __name__ == "__main__":
|
| 244 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/04_robustness_limit.py
DELETED
|
@@ -1,187 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.optim as optim
|
| 5 |
-
from torch.utils.data import TensorDataset, DataLoader
|
| 6 |
-
import numpy as np
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
-
from sklearn.decomposition import TruncatedSVD
|
| 9 |
-
from sklearn.linear_model import LogisticRegression
|
| 10 |
-
import torchvision
|
| 11 |
-
import torchvision.transforms as transforms
|
| 12 |
-
import os
|
| 13 |
-
|
| 14 |
-
from src import config
|
| 15 |
-
|
| 16 |
-
# --- Configuration ---
|
| 17 |
-
BLUE_LIGHT = "#88C0D0"
|
| 18 |
-
BLUE_DEEP = "#5E81AC"
|
| 19 |
-
BATCH_SIZE = 64
|
| 20 |
-
|
| 21 |
-
# --- Model (Same as before) ---
|
| 22 |
-
from src.hybrid_model import SimpleCNN
|
| 23 |
-
|
| 24 |
-
# --- Comparison Evaluation ---
|
| 25 |
-
|
| 26 |
-
def load_data():
|
| 27 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 28 |
-
trainset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
|
| 29 |
-
testset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
|
| 30 |
-
|
| 31 |
-
def filter_38(dataset):
|
| 32 |
-
mask = (dataset.targets == 3) | (dataset.targets == 8)
|
| 33 |
-
data = dataset.data[mask].unsqueeze(1).float() / 255.0
|
| 34 |
-
targets = dataset.targets[mask]
|
| 35 |
-
targets = torch.where(targets == 3, torch.tensor(0), torch.tensor(1))
|
| 36 |
-
return data, targets
|
| 37 |
-
|
| 38 |
-
X_train, y_train = filter_38(trainset)
|
| 39 |
-
X_test, y_test = filter_38(testset)
|
| 40 |
-
return X_train, y_train, X_test, y_test
|
| 41 |
-
|
| 42 |
-
def add_noise(images, noise_level):
|
| 43 |
-
"""Add Gaussian noise."""
|
| 44 |
-
noise = torch.randn_like(images) * noise_level
|
| 45 |
-
noisy_imgs = images + noise
|
| 46 |
-
return torch.clamp(noisy_imgs, 0, 1)
|
| 47 |
-
|
| 48 |
-
def add_blur(images, kernel_size):
|
| 49 |
-
"""Add Gaussian blur."""
|
| 50 |
-
if kernel_size <= 1: return images
|
| 51 |
-
blur_fn = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(0.1 + 0.3 * (kernel_size//2)))
|
| 52 |
-
return blur_fn(images)
|
| 53 |
-
|
| 54 |
-
def evaluate_models(cnn, svd_clf, svd_transform, X_test, y_test, svd_mean=None):
|
| 55 |
-
# CNN Eval
|
| 56 |
-
cnn.eval()
|
| 57 |
-
with torch.no_grad():
|
| 58 |
-
logits = cnn(X_test)
|
| 59 |
-
preds = torch.argmax(logits, dim=1)
|
| 60 |
-
cnn_acc = (preds == y_test).float().mean().item()
|
| 61 |
-
|
| 62 |
-
# SVD Eval
|
| 63 |
-
X_flat = X_test.view(X_test.size(0), -1).numpy()
|
| 64 |
-
if svd_mean is not None:
|
| 65 |
-
X_flat = X_flat - svd_mean
|
| 66 |
-
X_pca = svd_transform.transform(X_flat)
|
| 67 |
-
y_pred_svd = svd_clf.predict(X_pca)
|
| 68 |
-
svd_acc = np.mean(y_pred_svd == y_test.numpy())
|
| 69 |
-
|
| 70 |
-
return cnn_acc, svd_acc
|
| 71 |
-
|
| 72 |
-
def main():
|
| 73 |
-
print("Loading Data...")
|
| 74 |
-
X_train, y_train, X_test, y_test = load_data()
|
| 75 |
-
|
| 76 |
-
# 1. Train Models (Quickly)
|
| 77 |
-
print("Training CNN Baseline...")
|
| 78 |
-
cnn = SimpleCNN(num_classes=2)
|
| 79 |
-
optimizer = optim.Adam(cnn.parameters(), lr=0.001)
|
| 80 |
-
criterion = nn.CrossEntropyLoss()
|
| 81 |
-
dataset = TensorDataset(X_train, y_train)
|
| 82 |
-
loader = DataLoader(dataset, batch_size=64, shuffle=True)
|
| 83 |
-
|
| 84 |
-
cnn.train()
|
| 85 |
-
for _ in range(3): # 3 Epochs enough for 99%
|
| 86 |
-
for x, y in loader:
|
| 87 |
-
optimizer.zero_grad()
|
| 88 |
-
loss = criterion(cnn(x), y)
|
| 89 |
-
loss.backward()
|
| 90 |
-
optimizer.step()
|
| 91 |
-
|
| 92 |
-
print("Training SVD Baseline...")
|
| 93 |
-
X_train_flat = X_train.view(X_train.size(0), -1).numpy()
|
| 94 |
-
# Mean-center for consistency with hybrid model's SVD layer
|
| 95 |
-
svd_mean = np.mean(X_train_flat, axis=0)
|
| 96 |
-
X_train_centered = X_train_flat - svd_mean
|
| 97 |
-
svd = TruncatedSVD(n_components=10, random_state=42)
|
| 98 |
-
X_train_pca = svd.fit_transform(X_train_centered)
|
| 99 |
-
|
| 100 |
-
clf = LogisticRegression(max_iter=500)
|
| 101 |
-
clf.fit(X_train_pca, y_train.numpy())
|
| 102 |
-
|
| 103 |
-
# 2. Noise Experiment
|
| 104 |
-
noise_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
|
| 105 |
-
cnn_accs_noise = []
|
| 106 |
-
svd_accs_noise = []
|
| 107 |
-
|
| 108 |
-
print("Running Noise Experiments...")
|
| 109 |
-
for nl in noise_levels:
|
| 110 |
-
X_test_noisy = add_noise(X_test, nl)
|
| 111 |
-
ca, sa = evaluate_models(cnn, clf, svd, X_test_noisy, y_test, svd_mean)
|
| 112 |
-
cnn_accs_noise.append(ca)
|
| 113 |
-
svd_accs_noise.append(sa)
|
| 114 |
-
print(f"Noise {nl}: CNN={ca:.2f}, SVD={sa:.2f}")
|
| 115 |
-
|
| 116 |
-
# 3. Blur Experiment
|
| 117 |
-
blur_kernels = [1, 3, 5, 7, 9, 11]
|
| 118 |
-
cnn_accs_blur = []
|
| 119 |
-
svd_accs_blur = []
|
| 120 |
-
|
| 121 |
-
print("Running Blur Experiments...")
|
| 122 |
-
for k in blur_kernels:
|
| 123 |
-
X_test_blur = add_blur(X_test, k)
|
| 124 |
-
ca, sa = evaluate_models(cnn, clf, svd, X_test_blur, y_test, svd_mean)
|
| 125 |
-
cnn_accs_blur.append(ca)
|
| 126 |
-
svd_accs_blur.append(sa)
|
| 127 |
-
print(f"Blur K={k}: CNN={ca:.2f}, SVD={sa:.2f}")
|
| 128 |
-
|
| 129 |
-
# 4. Plots
|
| 130 |
-
plt.figure(figsize=(12, 5))
|
| 131 |
-
|
| 132 |
-
# Noise Plot
|
| 133 |
-
plt.subplot(1, 2, 1)
|
| 134 |
-
plt.plot(noise_levels, cnn_accs_noise, marker='o', color=BLUE_LIGHT, label='CNN')
|
| 135 |
-
plt.plot(noise_levels, svd_accs_noise, marker='s', color=BLUE_DEEP, label='SVD')
|
| 136 |
-
plt.xlabel(r'Gaussian Noise ($\sigma$)')
|
| 137 |
-
plt.ylabel('Accuracy')
|
| 138 |
-
plt.title('Robustness vs Noise')
|
| 139 |
-
plt.legend()
|
| 140 |
-
plt.grid(True)
|
| 141 |
-
|
| 142 |
-
# Blur Plot
|
| 143 |
-
plt.subplot(1, 2, 2)
|
| 144 |
-
plt.plot(blur_kernels, cnn_accs_blur, marker='o', color=BLUE_LIGHT, label='CNN')
|
| 145 |
-
plt.plot(blur_kernels, svd_accs_blur, marker='s', color=BLUE_DEEP, label='SVD')
|
| 146 |
-
plt.xlabel('Blur Kernel Size')
|
| 147 |
-
plt.ylabel('Accuracy')
|
| 148 |
-
plt.title('Robustness vs Blur')
|
| 149 |
-
plt.legend()
|
| 150 |
-
plt.grid(True)
|
| 151 |
-
|
| 152 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, 'fig_07_degradation_curves.png'))
|
| 153 |
-
plt.close()
|
| 154 |
-
|
| 155 |
-
# 5. Visualizing Breakdown
|
| 156 |
-
# Find a sample that CNN gets right at noise=0.1 but wrong at noise=0.5
|
| 157 |
-
print("Generating Breakdown Visuals...")
|
| 158 |
-
noise_high = 0.6
|
| 159 |
-
X_test_high_noise = add_noise(X_test, noise_high)
|
| 160 |
-
|
| 161 |
-
cnn.eval()
|
| 162 |
-
logits = cnn(X_test_high_noise)
|
| 163 |
-
preds = torch.argmax(logits, dim=1)
|
| 164 |
-
|
| 165 |
-
# Find failures
|
| 166 |
-
failures = (preds != y_test).nonzero(as_tuple=True)[0]
|
| 167 |
-
if len(failures) > 0:
|
| 168 |
-
idx = failures[0]
|
| 169 |
-
img_clean = X_test[idx].squeeze()
|
| 170 |
-
img_noisy = X_test_high_noise[idx].squeeze()
|
| 171 |
-
|
| 172 |
-
plt.figure(figsize=(8, 4))
|
| 173 |
-
plt.subplot(1, 2, 1)
|
| 174 |
-
plt.imshow(img_clean, cmap='gray')
|
| 175 |
-
plt.title(f"Clean (True: {y_test[idx]})")
|
| 176 |
-
|
| 177 |
-
plt.subplot(1, 2, 2)
|
| 178 |
-
plt.imshow(img_noisy, cmap='gray')
|
| 179 |
-
plt.title(f"Noisy $\\sigma$={noise_high}\nCNN Pred: {preds[idx]}")
|
| 180 |
-
|
| 181 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, 'fig_08_breakdown_point.png'))
|
| 182 |
-
plt.close()
|
| 183 |
-
|
| 184 |
-
print("Experiment 5 Completed.")
|
| 185 |
-
|
| 186 |
-
if __name__ == "__main__":
|
| 187 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/05_manifold_learning.py
DELETED
|
@@ -1,103 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import numpy as np
|
| 3 |
-
import matplotlib.pyplot as plt
|
| 4 |
-
import seaborn as sns
|
| 5 |
-
from sklearn.decomposition import TruncatedSVD
|
| 6 |
-
from sklearn.metrics import silhouette_score
|
| 7 |
-
import umap
|
| 8 |
-
import torch
|
| 9 |
-
import torchvision
|
| 10 |
-
from torchvision import transforms
|
| 11 |
-
import os
|
| 12 |
-
|
| 13 |
-
from matplotlib.colors import ListedColormap
|
| 14 |
-
from src import config
|
| 15 |
-
|
| 16 |
-
BLUE_DEEP = "#5E81AC"
|
| 17 |
-
ORANGE = "#D08770"
|
| 18 |
-
|
| 19 |
-
# Configure styling
|
| 20 |
-
sns.set_style("whitegrid")
|
| 21 |
-
plt.rcParams.update({'font.size': 12})
|
| 22 |
-
|
| 23 |
-
def load_data():
|
| 24 |
-
"""Load and filter MNIST for digits 3 and 8."""
|
| 25 |
-
print("Loading MNIST data...")
|
| 26 |
-
# Fix for Mac SSL certificate issue
|
| 27 |
-
import ssl
|
| 28 |
-
ssl._create_default_https_context = ssl._create_unverified_context
|
| 29 |
-
|
| 30 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 31 |
-
# Try loading without download first
|
| 32 |
-
try:
|
| 33 |
-
dataset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=False, transform=transform)
|
| 34 |
-
except:
|
| 35 |
-
dataset = torchvision.datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
|
| 36 |
-
|
| 37 |
-
# Filter for 3 and 8
|
| 38 |
-
idx = (dataset.targets == 3) | (dataset.targets == 8)
|
| 39 |
-
dataset.targets = dataset.targets[idx]
|
| 40 |
-
dataset.data = dataset.data[idx]
|
| 41 |
-
|
| 42 |
-
# Flatten images (N, 784)
|
| 43 |
-
X = dataset.data.numpy().reshape(-1, 28*28).astype(np.float32) / 255.0
|
| 44 |
-
y = dataset.targets.numpy()
|
| 45 |
-
|
| 46 |
-
print(f"Dataset loaded: {X.shape} samples (Classes: {np.unique(y)})")
|
| 47 |
-
return X, y
|
| 48 |
-
|
| 49 |
-
def run_experiment():
|
| 50 |
-
X, y = load_data()
|
| 51 |
-
|
| 52 |
-
# Subsample for UMAP speed if necessary (though MNIST 3/8 is small enough ~12k samples)
|
| 53 |
-
# We'll use full set for accurate density
|
| 54 |
-
|
| 55 |
-
print("\n--- Running SVD Projection (Linear) ---")
|
| 56 |
-
# Mean-center for consistency with project convention
|
| 57 |
-
X_centered = X - X.mean(axis=0)
|
| 58 |
-
svd = TruncatedSVD(n_components=2, random_state=42)
|
| 59 |
-
X_svd = svd.fit_transform(X_centered)
|
| 60 |
-
sil_svd = silhouette_score(X_svd, y)
|
| 61 |
-
print(f"SVD Silhouette Score: {sil_svd:.4f}")
|
| 62 |
-
|
| 63 |
-
print("\n--- Running UMAP Projection (Non-linear) ---")
|
| 64 |
-
reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, min_dist=0.1)
|
| 65 |
-
X_umap = reducer.fit_transform(X)
|
| 66 |
-
sil_umap = silhouette_score(X_umap, y)
|
| 67 |
-
print(f"UMAP Silhouette Score: {sil_umap:.4f}")
|
| 68 |
-
|
| 69 |
-
# Plotting
|
| 70 |
-
print("\nGenerating Comparison Plot...")
|
| 71 |
-
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
|
| 72 |
-
|
| 73 |
-
# Custom cmap for 3 and 8
|
| 74 |
-
# 3 is lower index, 8 is higher. We map 3 -> blue_deep, 8 -> orange
|
| 75 |
-
cmap = ListedColormap([BLUE_DEEP, ORANGE])
|
| 76 |
-
|
| 77 |
-
# Plot SVD
|
| 78 |
-
scatter_svd = axes[0].scatter(X_svd[:, 0], X_svd[:, 1], c=y, cmap=cmap, alpha=0.5, s=2)
|
| 79 |
-
axes[0].set_title(f"Linear Projection (SVD)\nSilhouette Score: {sil_svd:.3f}", fontsize=14)
|
| 80 |
-
axes[0].set_xlabel("PC 1")
|
| 81 |
-
axes[0].set_ylabel("PC 2")
|
| 82 |
-
|
| 83 |
-
# Plot UMAP
|
| 84 |
-
scatter_umap = axes[1].scatter(X_umap[:, 0], X_umap[:, 1], c=y, cmap=cmap, alpha=0.5, s=2)
|
| 85 |
-
axes[1].set_title(f"Manifold Learning (UMAP)\nSilhouette Score: {sil_umap:.3f}", fontsize=14)
|
| 86 |
-
axes[1].set_xlabel("UMAP 1")
|
| 87 |
-
axes[1].set_ylabel("UMAP 2")
|
| 88 |
-
|
| 89 |
-
# Add legend
|
| 90 |
-
legend1 = axes[0].legend(*scatter_svd.legend_elements(), title="Digits")
|
| 91 |
-
axes[0].add_artist(legend1)
|
| 92 |
-
legend2 = axes[1].legend(*scatter_umap.legend_elements(), title="Digits")
|
| 93 |
-
axes[1].add_artist(legend2)
|
| 94 |
-
|
| 95 |
-
plt.tight_layout()
|
| 96 |
-
save_path = os.path.join(config.RESULTS_DIR, "fig_09_manifold_comparison.png")
|
| 97 |
-
plt.savefig(save_path, dpi=150)
|
| 98 |
-
print(f"Plot saved to: {save_path}")
|
| 99 |
-
|
| 100 |
-
return sil_svd, sil_umap
|
| 101 |
-
|
| 102 |
-
if __name__ == "__main__":
|
| 103 |
-
run_experiment()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/06_fashion_mnist_baseline.py
DELETED
|
@@ -1,115 +0,0 @@
|
|
| 1 |
-
# Exp 06 – Fashion-MNIST SVD baseline (replicating MNIST findings)
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import seaborn as sns
|
| 6 |
-
from sklearn.decomposition import TruncatedSVD
|
| 7 |
-
from sklearn.linear_model import LogisticRegression
|
| 8 |
-
from sklearn.metrics import confusion_matrix, accuracy_score
|
| 9 |
-
from matplotlib.colors import LinearSegmentedColormap
|
| 10 |
-
import torchvision
|
| 11 |
-
import torchvision.transforms as transforms
|
| 12 |
-
import os
|
| 13 |
-
|
| 14 |
-
from src import config
|
| 15 |
-
|
| 16 |
-
# --- Configuration ---
|
| 17 |
-
GRAY_LIGHT = "#D8DEE9"
|
| 18 |
-
BLUE_DEEP = "#5E81AC"
|
| 19 |
-
|
| 20 |
-
# Fashion-MNIST class names
|
| 21 |
-
CLASS_NAMES = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
|
| 22 |
-
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
|
| 23 |
-
|
| 24 |
-
def load_fashion_mnist():
|
| 25 |
-
"""Load and flatten Fashion-MNIST data."""
|
| 26 |
-
print("Loading Fashion-MNIST...")
|
| 27 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 28 |
-
|
| 29 |
-
trainset = torchvision.datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=True, download=True, transform=transform)
|
| 30 |
-
testset = torchvision.datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=False, download=True, transform=transform)
|
| 31 |
-
|
| 32 |
-
X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 33 |
-
y_train = trainset.targets.numpy()
|
| 34 |
-
|
| 35 |
-
X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 36 |
-
y_test = testset.targets.numpy()
|
| 37 |
-
|
| 38 |
-
return X_train, y_train, X_test, y_test
|
| 39 |
-
|
| 40 |
-
def plot_confusion_matrix(y_true, y_pred, labels, filename, title):
|
| 41 |
-
"""Draws and saves a confusion matrix (normalized by row = recall)."""
|
| 42 |
-
cm = confusion_matrix(y_true, y_pred, normalize='true')
|
| 43 |
-
plt.figure(figsize=(12, 10))
|
| 44 |
-
cmap = LinearSegmentedColormap.from_list("NBodyBlue", [GRAY_LIGHT, BLUE_DEEP])
|
| 45 |
-
sns.heatmap(cm, annot=True, fmt='.1%', cmap=cmap,
|
| 46 |
-
xticklabels=labels, yticklabels=labels)
|
| 47 |
-
plt.title(title)
|
| 48 |
-
plt.xlabel('Predicted')
|
| 49 |
-
plt.ylabel('True')
|
| 50 |
-
plt.tight_layout()
|
| 51 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300)
|
| 52 |
-
plt.close()
|
| 53 |
-
print(f"Saved {filename}")
|
| 54 |
-
|
| 55 |
-
def analyze_confusion_pairs(cm, class_names, top_k=5):
|
| 56 |
-
"""Identify the most confused class pairs."""
|
| 57 |
-
n = len(class_names)
|
| 58 |
-
confusions = []
|
| 59 |
-
|
| 60 |
-
for i in range(n):
|
| 61 |
-
for j in range(n):
|
| 62 |
-
if i != j:
|
| 63 |
-
confusions.append((cm[i, j], class_names[i], class_names[j]))
|
| 64 |
-
|
| 65 |
-
confusions.sort(reverse=True)
|
| 66 |
-
|
| 67 |
-
print(f"\nTop {top_k} Confused Pairs:")
|
| 68 |
-
for rate, true_class, pred_class in confusions[:top_k]:
|
| 69 |
-
print(f" {true_class} → {pred_class}: {rate*100:.2f}%")
|
| 70 |
-
|
| 71 |
-
return confusions[:top_k]
|
| 72 |
-
|
| 73 |
-
def run_svd_baseline(X_train, y_train, X_test, y_test):
|
| 74 |
-
"""Run SVD + Logistic Regression baseline."""
|
| 75 |
-
print("\n--- Running SVD Baseline (Fashion-MNIST) ---")
|
| 76 |
-
|
| 77 |
-
n_components = 20
|
| 78 |
-
print(f"Reducing dimension to {n_components} using SVD...")
|
| 79 |
-
|
| 80 |
-
# Mean-center for consistency with hybrid model's SVD layer
|
| 81 |
-
mean = np.mean(X_train, axis=0)
|
| 82 |
-
X_train_centered = X_train - mean
|
| 83 |
-
X_test_centered = X_test - mean
|
| 84 |
-
svd = TruncatedSVD(n_components=n_components, random_state=42)
|
| 85 |
-
X_train_svd = svd.fit_transform(X_train_centered)
|
| 86 |
-
X_test_svd = svd.transform(X_test_centered)
|
| 87 |
-
|
| 88 |
-
clf = LogisticRegression(max_iter=1000)
|
| 89 |
-
clf.fit(X_train_svd, y_train)
|
| 90 |
-
y_pred = clf.predict(X_test_svd)
|
| 91 |
-
|
| 92 |
-
acc = accuracy_score(y_test, y_pred)
|
| 93 |
-
print(f"SVD+LR Accuracy: {acc*100:.2f}%")
|
| 94 |
-
|
| 95 |
-
# Confusion Matrix
|
| 96 |
-
cm = confusion_matrix(y_test, y_pred, normalize='true')
|
| 97 |
-
plot_confusion_matrix(y_test, y_pred, CLASS_NAMES,
|
| 98 |
-
'fig_11_fashion_svd_confusion.png',
|
| 99 |
-
f'Fashion-MNIST SVD Confusion (k={n_components}, Acc={acc:.2%})')
|
| 100 |
-
|
| 101 |
-
# Analyze confusions
|
| 102 |
-
analyze_confusion_pairs(cm, CLASS_NAMES)
|
| 103 |
-
|
| 104 |
-
return svd, clf
|
| 105 |
-
|
| 106 |
-
def main():
|
| 107 |
-
X_train, y_train, X_test, y_test = load_fashion_mnist()
|
| 108 |
-
|
| 109 |
-
svd, clf = run_svd_baseline(X_train, y_train, X_test, y_test)
|
| 110 |
-
|
| 111 |
-
print("\nExperiment 06 Complete.")
|
| 112 |
-
print(f"Results saved to {config.RESULTS_DIR}")
|
| 113 |
-
|
| 114 |
-
if __name__ == "__main__":
|
| 115 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/07_fashion_cnn_verification.py
DELETED
|
@@ -1,145 +0,0 @@
|
|
| 1 |
-
# Exp 07 – Fashion-MNIST CNN vs SVD confusion comparison
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import seaborn as sns
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
import torch.optim as optim
|
| 9 |
-
from torch.utils.data import DataLoader
|
| 10 |
-
from torchvision import datasets, transforms
|
| 11 |
-
from sklearn.metrics import confusion_matrix, accuracy_score
|
| 12 |
-
from matplotlib.colors import LinearSegmentedColormap
|
| 13 |
-
import os
|
| 14 |
-
|
| 15 |
-
from src import config
|
| 16 |
-
from src.hybrid_model import SimpleCNN
|
| 17 |
-
|
| 18 |
-
# --- Configuration ---
|
| 19 |
-
GRAY_LIGHT = "#D8DEE9"
|
| 20 |
-
BLUE_DEEP = "#5E81AC"
|
| 21 |
-
|
| 22 |
-
CLASS_NAMES = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
|
| 23 |
-
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
|
| 24 |
-
|
| 25 |
-
def train_cnn(train_loader, epochs=10):
|
| 26 |
-
"""Train CNN on Fashion-MNIST."""
|
| 27 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 28 |
-
model = SimpleCNN(num_classes=10).to(device)
|
| 29 |
-
|
| 30 |
-
criterion = nn.CrossEntropyLoss()
|
| 31 |
-
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
| 32 |
-
|
| 33 |
-
model.train()
|
| 34 |
-
for epoch in range(epochs):
|
| 35 |
-
running_loss = 0.0
|
| 36 |
-
for inputs, labels in train_loader:
|
| 37 |
-
inputs, labels = inputs.to(device), labels.to(device)
|
| 38 |
-
|
| 39 |
-
optimizer.zero_grad()
|
| 40 |
-
outputs = model(inputs)
|
| 41 |
-
loss = criterion(outputs, labels)
|
| 42 |
-
loss.backward()
|
| 43 |
-
optimizer.step()
|
| 44 |
-
|
| 45 |
-
running_loss += loss.item()
|
| 46 |
-
|
| 47 |
-
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}")
|
| 48 |
-
|
| 49 |
-
return model
|
| 50 |
-
|
| 51 |
-
def evaluate_cnn(model, test_loader):
|
| 52 |
-
"""Evaluate CNN and return predictions."""
|
| 53 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 54 |
-
model.eval()
|
| 55 |
-
|
| 56 |
-
all_preds = []
|
| 57 |
-
all_targets = []
|
| 58 |
-
|
| 59 |
-
with torch.no_grad():
|
| 60 |
-
for inputs, labels in test_loader:
|
| 61 |
-
inputs = inputs.to(device)
|
| 62 |
-
outputs = model(inputs)
|
| 63 |
-
preds = outputs.argmax(dim=1)
|
| 64 |
-
|
| 65 |
-
all_preds.extend(preds.cpu().numpy())
|
| 66 |
-
all_targets.extend(labels.numpy())
|
| 67 |
-
|
| 68 |
-
return np.array(all_preds), np.array(all_targets)
|
| 69 |
-
|
| 70 |
-
def plot_confusion_matrix(y_true, y_pred, labels, filename, title):
|
| 71 |
-
"""Draws and saves a confusion matrix (normalized by row = recall)."""
|
| 72 |
-
cm = confusion_matrix(y_true, y_pred, normalize='true')
|
| 73 |
-
plt.figure(figsize=(12, 10))
|
| 74 |
-
cmap = LinearSegmentedColormap.from_list("NBodyBlue", [GRAY_LIGHT, BLUE_DEEP])
|
| 75 |
-
sns.heatmap(cm, annot=True, fmt='.1%', cmap=cmap,
|
| 76 |
-
xticklabels=labels, yticklabels=labels)
|
| 77 |
-
plt.title(title)
|
| 78 |
-
plt.xlabel('Predicted')
|
| 79 |
-
plt.ylabel('True')
|
| 80 |
-
plt.tight_layout()
|
| 81 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300)
|
| 82 |
-
plt.close()
|
| 83 |
-
print(f"Saved {filename}")
|
| 84 |
-
|
| 85 |
-
def analyze_confusion_improvement(svd_confusions, cnn_cm, class_names):
|
| 86 |
-
"""Compare SVD vs CNN on the most confused pairs."""
|
| 87 |
-
print("\n--- Confusion Comparison: SVD vs CNN ---")
|
| 88 |
-
print(f"{'Pair':<25} {'SVD Error':<12} {'CNN Error':<12} {'Improvement':<12}")
|
| 89 |
-
print("-" * 60)
|
| 90 |
-
|
| 91 |
-
for svd_rate, true_class, pred_class in svd_confusions:
|
| 92 |
-
i = class_names.index(true_class)
|
| 93 |
-
j = class_names.index(pred_class)
|
| 94 |
-
cnn_rate = cnn_cm[i, j]
|
| 95 |
-
improvement = (svd_rate - cnn_rate) / svd_rate * 100 if svd_rate > 0 else 0
|
| 96 |
-
|
| 97 |
-
print(f"{true_class} → {pred_class:<10} {svd_rate*100:>8.2f}% {cnn_rate*100:>8.2f}% {improvement:>8.1f}%")
|
| 98 |
-
|
| 99 |
-
def main():
|
| 100 |
-
print("Loading Fashion-MNIST...")
|
| 101 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 102 |
-
|
| 103 |
-
trainset = datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=True, download=True, transform=transform)
|
| 104 |
-
testset = datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=False, download=True, transform=transform)
|
| 105 |
-
|
| 106 |
-
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
|
| 107 |
-
test_loader = DataLoader(testset, batch_size=1000, shuffle=False)
|
| 108 |
-
|
| 109 |
-
print("\n--- Training CNN on Fashion-MNIST ---")
|
| 110 |
-
model = train_cnn(train_loader, epochs=10)
|
| 111 |
-
|
| 112 |
-
print("\n--- Evaluating CNN ---")
|
| 113 |
-
y_pred, y_true = evaluate_cnn(model, test_loader)
|
| 114 |
-
|
| 115 |
-
acc = accuracy_score(y_true, y_pred)
|
| 116 |
-
print(f"CNN Accuracy: {acc*100:.2f}%")
|
| 117 |
-
|
| 118 |
-
# Confusion Matrix
|
| 119 |
-
cm = confusion_matrix(y_true, y_pred, normalize='true')
|
| 120 |
-
plot_confusion_matrix(y_true, y_pred, CLASS_NAMES,
|
| 121 |
-
'fig_12_fashion_cnn_confusion.png',
|
| 122 |
-
f'Fashion-MNIST CNN Confusion (Acc={acc:.2%})')
|
| 123 |
-
|
| 124 |
-
# Save model for later use
|
| 125 |
-
model_path = os.path.join("models", "cnn_fashion.pth")
|
| 126 |
-
os.makedirs("models", exist_ok=True)
|
| 127 |
-
torch.save(model.state_dict(), model_path)
|
| 128 |
-
print(f"Model saved to {model_path}")
|
| 129 |
-
|
| 130 |
-
# Compare with SVD (hardcoded top confusions from experiment 06)
|
| 131 |
-
# These will be updated after running experiment 06
|
| 132 |
-
svd_confusions = [
|
| 133 |
-
(0.15, 'Shirt', 'T-shirt'),
|
| 134 |
-
(0.12, 'Shirt', 'Coat'),
|
| 135 |
-
(0.10, 'Pullover', 'Coat'),
|
| 136 |
-
(0.08, 'T-shirt', 'Shirt'),
|
| 137 |
-
(0.06, 'Coat', 'Pullover'),
|
| 138 |
-
]
|
| 139 |
-
|
| 140 |
-
analyze_confusion_improvement(svd_confusions, cm, CLASS_NAMES)
|
| 141 |
-
|
| 142 |
-
print("\nExperiment 07 Complete.")
|
| 143 |
-
|
| 144 |
-
if __name__ == "__main__":
|
| 145 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/08_hybrid_robustness.py
DELETED
|
@@ -1,253 +0,0 @@
|
|
| 1 |
-
# Exp 08 – MNIST 10-class hybrid robustness under Gaussian noise (multi-seed)
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import torch
|
| 6 |
-
from torchvision import datasets, transforms
|
| 7 |
-
from sklearn.decomposition import TruncatedSVD
|
| 8 |
-
from sklearn.linear_model import LogisticRegression
|
| 9 |
-
from sklearn.metrics import accuracy_score
|
| 10 |
-
import os
|
| 11 |
-
import json
|
| 12 |
-
from scipy.ndimage import gaussian_filter
|
| 13 |
-
|
| 14 |
-
from src.hybrid_model import SimpleCNN, HybridSVDCNN, create_svd_layer
|
| 15 |
-
from src import config
|
| 16 |
-
|
| 17 |
-
# --- Configuration ---
|
| 18 |
-
BLUE_LIGHT = "#88C0D0"
|
| 19 |
-
BLUE_DEEP = "#5E81AC"
|
| 20 |
-
ORANGE = "#D08770"
|
| 21 |
-
RED = "#BF616A"
|
| 22 |
-
|
| 23 |
-
SVD_K = 20
|
| 24 |
-
NOISE_LEVELS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
| 25 |
-
SEEDS = [42, 123, 456]
|
| 26 |
-
BLUR_SIGMA = 1.5
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def load_mnist():
|
| 30 |
-
"""Load MNIST test data."""
|
| 31 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 32 |
-
testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
|
| 33 |
-
trainset = datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
|
| 34 |
-
|
| 35 |
-
X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 36 |
-
y_train = trainset.targets.numpy()
|
| 37 |
-
|
| 38 |
-
X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 39 |
-
y_test = testset.targets.numpy()
|
| 40 |
-
|
| 41 |
-
return X_train, y_train, X_test, y_test
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def add_gaussian_noise(X, sigma):
|
| 45 |
-
"""Add Gaussian noise to images."""
|
| 46 |
-
noise = np.random.randn(*X.shape) * sigma
|
| 47 |
-
X_noisy = X + noise
|
| 48 |
-
return np.clip(X_noisy, 0, 1)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def set_seeds(seed: int) -> None:
|
| 52 |
-
np.random.seed(seed)
|
| 53 |
-
torch.manual_seed(seed)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def evaluate_svd(svd, clf, X_test, y_test, mean):
|
| 57 |
-
"""Evaluate SVD+LR model (with mean-centering)."""
|
| 58 |
-
X_test_centered = X_test - mean
|
| 59 |
-
X_test_svd = svd.transform(X_test_centered)
|
| 60 |
-
y_pred = clf.predict(X_test_svd)
|
| 61 |
-
return accuracy_score(y_test, y_pred)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def evaluate_cnn(model, X_test, y_test, device):
|
| 65 |
-
"""Evaluate CNN model."""
|
| 66 |
-
model.eval()
|
| 67 |
-
X_tensor = torch.tensor(X_test.reshape(-1, 1, 28, 28), dtype=torch.float32).to(device)
|
| 68 |
-
|
| 69 |
-
with torch.no_grad():
|
| 70 |
-
outputs = model(X_tensor)
|
| 71 |
-
preds = outputs.argmax(dim=1).cpu().numpy()
|
| 72 |
-
|
| 73 |
-
return accuracy_score(y_test, preds)
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def evaluate_blur_cnn(cnn, X_test, y_test, device, blur_sigma=BLUR_SIGMA):
|
| 77 |
-
"""Sanity baseline: if SVD is just smoothing, blur should do equally well."""
|
| 78 |
-
X_blurred = np.array([
|
| 79 |
-
gaussian_filter(img.reshape(28, 28), sigma=blur_sigma).flatten()
|
| 80 |
-
for img in X_test
|
| 81 |
-
])
|
| 82 |
-
X_blurred = np.clip(X_blurred, 0, 1)
|
| 83 |
-
|
| 84 |
-
return evaluate_cnn(cnn, X_blurred, y_test, device)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def load_pretrained_cnn(device) -> SimpleCNN:
|
| 88 |
-
"""Load the pretrained CNN from models/ for stable, reproducible evaluation."""
|
| 89 |
-
model = SimpleCNN(num_classes=10).to(device)
|
| 90 |
-
model.load_state_dict(torch.load(config.CNN_MODEL_PATH, map_location=device))
|
| 91 |
-
model.eval()
|
| 92 |
-
return model
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def train_svd_model(X_train, y_train, n_components=SVD_K):
|
| 96 |
-
"""SVD+LR baseline. Mean-centered to match hybrid model's SVD layer."""
|
| 97 |
-
mean = np.mean(X_train, axis=0)
|
| 98 |
-
X_centered = X_train - mean
|
| 99 |
-
|
| 100 |
-
svd = TruncatedSVD(n_components=n_components, random_state=42)
|
| 101 |
-
X_train_svd = svd.fit_transform(X_centered)
|
| 102 |
-
|
| 103 |
-
clf = LogisticRegression(max_iter=1000)
|
| 104 |
-
clf.fit(X_train_svd, y_train)
|
| 105 |
-
|
| 106 |
-
return svd, clf, mean
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def plot_robustness_comparison(results, filename, std_results=None):
|
| 110 |
-
"""Plot robustness curves for all models, optionally with error bands."""
|
| 111 |
-
plt.figure(figsize=(10, 6))
|
| 112 |
-
|
| 113 |
-
colors = {
|
| 114 |
-
'CNN': BLUE_LIGHT,
|
| 115 |
-
'SVD': BLUE_DEEP,
|
| 116 |
-
'Hybrid': RED,
|
| 117 |
-
'Blur+CNN': ORANGE,
|
| 118 |
-
}
|
| 119 |
-
markers = {'CNN': 'o', 'SVD': 's', 'Hybrid': '^', 'Blur+CNN': 'D'}
|
| 120 |
-
|
| 121 |
-
for model_name, accuracies in results.items():
|
| 122 |
-
plt.plot(NOISE_LEVELS, accuracies,
|
| 123 |
-
color=colors[model_name],
|
| 124 |
-
marker=markers[model_name],
|
| 125 |
-
linewidth=2, markersize=8,
|
| 126 |
-
label=f'{model_name}')
|
| 127 |
-
# Add shaded error band if std is available
|
| 128 |
-
if std_results and model_name in std_results:
|
| 129 |
-
mean = np.array(accuracies)
|
| 130 |
-
std = np.array(std_results[model_name])
|
| 131 |
-
plt.fill_between(NOISE_LEVELS, mean - std, mean + std,
|
| 132 |
-
color=colors[model_name], alpha=0.15)
|
| 133 |
-
|
| 134 |
-
plt.xlabel('Noise Level (σ)', fontsize=12)
|
| 135 |
-
plt.ylabel('Accuracy', fontsize=12)
|
| 136 |
-
plt.title('Model Robustness Under Gaussian Noise', fontsize=14)
|
| 137 |
-
plt.legend(fontsize=11)
|
| 138 |
-
plt.grid(True, alpha=0.3)
|
| 139 |
-
plt.ylim(0.0, 1.05)
|
| 140 |
-
|
| 141 |
-
# Add annotations
|
| 142 |
-
plt.axhline(y=0.9, color='gray', linestyle='--', alpha=0.5)
|
| 143 |
-
plt.text(0.65, 0.91, '90% threshold', fontsize=10, color='gray')
|
| 144 |
-
|
| 145 |
-
plt.tight_layout()
|
| 146 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300)
|
| 147 |
-
plt.close()
|
| 148 |
-
print(f"Saved {filename}")
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def run_single_seed(seed, X_train, y_train, X_test, y_test, device, cnn):
|
| 152 |
-
"""Run one full evaluation pass with a given random seed."""
|
| 153 |
-
set_seeds(seed)
|
| 154 |
-
print(f"\n{'='*50}")
|
| 155 |
-
print(f" Seed = {seed}")
|
| 156 |
-
print(f"{'='*50}")
|
| 157 |
-
|
| 158 |
-
# Train SVD+LR (with mean-centering, consistent with hybrid)
|
| 159 |
-
svd, svd_clf, svd_mean = train_svd_model(X_train, y_train)
|
| 160 |
-
|
| 161 |
-
# Create Hybrid Model
|
| 162 |
-
svd_layer = create_svd_layer(X_train, n_components=SVD_K)
|
| 163 |
-
hybrid = HybridSVDCNN(svd_layer, cnn).to(device)
|
| 164 |
-
|
| 165 |
-
results = {'CNN': [], 'SVD': [], 'Hybrid': [], 'Blur+CNN': []}
|
| 166 |
-
|
| 167 |
-
for sigma in NOISE_LEVELS:
|
| 168 |
-
X_test_noisy = add_gaussian_noise(X_test, sigma)
|
| 169 |
-
|
| 170 |
-
acc_svd = evaluate_svd(svd, svd_clf, X_test_noisy, y_test, svd_mean)
|
| 171 |
-
acc_cnn = evaluate_cnn(cnn, X_test_noisy, y_test, device)
|
| 172 |
-
acc_hybrid = evaluate_cnn(hybrid, X_test_noisy, y_test, device)
|
| 173 |
-
acc_blur = evaluate_blur_cnn(cnn, X_test_noisy, y_test, device)
|
| 174 |
-
|
| 175 |
-
results['SVD'].append(acc_svd)
|
| 176 |
-
results['CNN'].append(acc_cnn)
|
| 177 |
-
results['Hybrid'].append(acc_hybrid)
|
| 178 |
-
results['Blur+CNN'].append(acc_blur)
|
| 179 |
-
|
| 180 |
-
print(f" σ={sigma:.1f} SVD={acc_svd*100:5.2f}% CNN={acc_cnn*100:5.2f}% "
|
| 181 |
-
f"Hybrid={acc_hybrid*100:5.2f}% Blur+CNN={acc_blur*100:5.2f}%")
|
| 182 |
-
|
| 183 |
-
return results
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
def main():
|
| 187 |
-
print("Loading MNIST...")
|
| 188 |
-
X_train, y_train, X_test, y_test = load_mnist()
|
| 189 |
-
|
| 190 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 191 |
-
print(f"Using device: {device}")
|
| 192 |
-
|
| 193 |
-
# Load pretrained CNN (shared across all seeds)
|
| 194 |
-
print("\nLoading pretrained CNN...")
|
| 195 |
-
cnn = load_pretrained_cnn(device)
|
| 196 |
-
|
| 197 |
-
# Run over multiple seeds
|
| 198 |
-
all_runs = [] # list of {model_name: [acc_per_sigma]}
|
| 199 |
-
for seed in SEEDS:
|
| 200 |
-
run_results = run_single_seed(seed, X_train, y_train, X_test, y_test, device, cnn)
|
| 201 |
-
all_runs.append(run_results)
|
| 202 |
-
|
| 203 |
-
# Aggregate: compute mean ± std across seeds
|
| 204 |
-
model_names = ['CNN', 'SVD', 'Hybrid', 'Blur+CNN']
|
| 205 |
-
mean_results = {}
|
| 206 |
-
std_results = {}
|
| 207 |
-
for name in model_names:
|
| 208 |
-
all_accs = np.array([run[name] for run in all_runs]) # (n_seeds, n_noise_levels)
|
| 209 |
-
mean_results[name] = all_accs.mean(axis=0).tolist()
|
| 210 |
-
std_results[name] = all_accs.std(axis=0).tolist()
|
| 211 |
-
|
| 212 |
-
# Plot comparison (mean with error band)
|
| 213 |
-
plot_robustness_comparison(mean_results, 'fig_10_hybrid_robustness.png', std_results)
|
| 214 |
-
|
| 215 |
-
# Save raw numbers for reproducibility / app usage
|
| 216 |
-
out_json = {
|
| 217 |
-
"dataset": "MNIST",
|
| 218 |
-
"task": "10-class",
|
| 219 |
-
"noise_levels": NOISE_LEVELS,
|
| 220 |
-
"seeds": SEEDS,
|
| 221 |
-
"results_mean": {k: [round(x, 4) for x in v] for k, v in mean_results.items()},
|
| 222 |
-
"results_std": {k: [round(x, 4) for x in v] for k, v in std_results.items()},
|
| 223 |
-
"results": {k: [round(x, 4) for x in v] for k, v in mean_results.items()}, # backward compat
|
| 224 |
-
"svd_components": SVD_K,
|
| 225 |
-
"svd_centering": True,
|
| 226 |
-
"blur_sigma": BLUR_SIGMA,
|
| 227 |
-
"cnn_epochs": 5,
|
| 228 |
-
"notes": "Mean over 3 seeds. SVD baseline uses explicit mean-centering for consistency with hybrid layer.",
|
| 229 |
-
}
|
| 230 |
-
json_path = os.path.join(config.RESULTS_DIR, "robustness_mnist_noise.json")
|
| 231 |
-
with open(json_path, "w", encoding="utf-8") as f:
|
| 232 |
-
json.dump(out_json, f, indent=2)
|
| 233 |
-
print(f"\nSaved robustness JSON to {json_path}")
|
| 234 |
-
|
| 235 |
-
# Summary table
|
| 236 |
-
print("\n--- Summary (mean ± std across seeds) ---")
|
| 237 |
-
header = f"{'Model':<12} {'Clean':<16} {'σ=0.3':<16} {'σ=0.5':<16} {'σ=0.7':<16}"
|
| 238 |
-
print(header)
|
| 239 |
-
print("-" * len(header))
|
| 240 |
-
for name in model_names:
|
| 241 |
-
m = mean_results[name]
|
| 242 |
-
s = std_results[name]
|
| 243 |
-
# indices: 0=clean, 3=0.3, 5=0.5, 7=0.7
|
| 244 |
-
print(f"{name:<12} "
|
| 245 |
-
f"{m[0]*100:5.2f}±{s[0]*100:.2f}% "
|
| 246 |
-
f"{m[3]*100:5.2f}±{s[3]*100:.2f}% "
|
| 247 |
-
f"{m[5]*100:5.2f}±{s[5]*100:.2f}% "
|
| 248 |
-
f"{m[7]*100:5.2f}±{s[7]*100:.2f}%")
|
| 249 |
-
|
| 250 |
-
print("\nExperiment 08 Complete.")
|
| 251 |
-
|
| 252 |
-
if __name__ == "__main__":
|
| 253 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/09_fashion_hybrid_robustness.py
DELETED
|
@@ -1,189 +0,0 @@
|
|
| 1 |
-
# Exp 09 – Fashion-MNIST hybrid robustness under Gaussian noise
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import torch
|
| 6 |
-
from torchvision import transforms
|
| 7 |
-
import torchvision
|
| 8 |
-
from sklearn.decomposition import TruncatedSVD
|
| 9 |
-
from sklearn.linear_model import LogisticRegression
|
| 10 |
-
from sklearn.metrics import accuracy_score
|
| 11 |
-
import os
|
| 12 |
-
import sys
|
| 13 |
-
import json
|
| 14 |
-
|
| 15 |
-
from src.hybrid_model import SimpleCNN, HybridSVDCNN, create_svd_layer
|
| 16 |
-
from src import config
|
| 17 |
-
|
| 18 |
-
# --- Configuration ---
|
| 19 |
-
BLUE_LIGHT = "#88C0D0"
|
| 20 |
-
BLUE_DEEP = "#5E81AC"
|
| 21 |
-
RED = "#BF616A"
|
| 22 |
-
|
| 23 |
-
SVD_K = 20
|
| 24 |
-
NOISE_LEVELS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
| 25 |
-
|
| 26 |
-
def load_fashion_mnist():
|
| 27 |
-
"""Load Fashion-MNIST test data."""
|
| 28 |
-
print("Loading Fashion-MNIST...")
|
| 29 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 30 |
-
|
| 31 |
-
# Use ./data/fashion to match other scripts
|
| 32 |
-
trainset = torchvision.datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=True, download=True, transform=transform)
|
| 33 |
-
testset = torchvision.datasets.FashionMNIST(root=config.FASHION_MNIST_DIR, train=False, download=True, transform=transform)
|
| 34 |
-
|
| 35 |
-
X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 36 |
-
y_train = trainset.targets.numpy()
|
| 37 |
-
|
| 38 |
-
X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 39 |
-
y_test = testset.targets.numpy()
|
| 40 |
-
|
| 41 |
-
return X_train, y_train, X_test, y_test
|
| 42 |
-
|
| 43 |
-
def add_gaussian_noise(X, sigma):
|
| 44 |
-
"""Add Gaussian noise to images."""
|
| 45 |
-
noise = np.random.randn(*X.shape) * sigma
|
| 46 |
-
X_noisy = X + noise
|
| 47 |
-
return np.clip(X_noisy, 0, 1)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def set_seeds(seed: int) -> None:
|
| 51 |
-
np.random.seed(seed)
|
| 52 |
-
torch.manual_seed(seed)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def load_pretrained_cnn(device) -> SimpleCNN:
|
| 56 |
-
"""Load the pretrained Fashion-MNIST CNN from models/ for stable evaluation."""
|
| 57 |
-
model = SimpleCNN(num_classes=10).to(device)
|
| 58 |
-
model.load_state_dict(torch.load(config.CNN_FASHION_MODEL_PATH, map_location=device))
|
| 59 |
-
model.eval()
|
| 60 |
-
return model
|
| 61 |
-
|
| 62 |
-
def evaluate_svd(svd, clf, X_test, y_test, mean=None):
|
| 63 |
-
"""Evaluate SVD+LR model (with optional mean-centering)."""
|
| 64 |
-
X = X_test - mean if mean is not None else X_test
|
| 65 |
-
X_test_svd = svd.transform(X)
|
| 66 |
-
y_pred = clf.predict(X_test_svd)
|
| 67 |
-
return accuracy_score(y_test, y_pred)
|
| 68 |
-
|
| 69 |
-
def evaluate_cnn(model, X_test, y_test, device):
|
| 70 |
-
"""Evaluate CNN model."""
|
| 71 |
-
model.eval()
|
| 72 |
-
X_tensor = torch.tensor(X_test.reshape(-1, 1, 28, 28), dtype=torch.float32).to(device)
|
| 73 |
-
|
| 74 |
-
with torch.no_grad():
|
| 75 |
-
outputs = model(X_tensor)
|
| 76 |
-
preds = outputs.argmax(dim=1).cpu().numpy()
|
| 77 |
-
|
| 78 |
-
return accuracy_score(y_test, preds)
|
| 79 |
-
|
| 80 |
-
def train_svd_model(X_train, y_train, n_components=SVD_K):
|
| 81 |
-
"""Train SVD + Logistic Regression with explicit mean-centering."""
|
| 82 |
-
print(f"Training SVD (k={n_components}) on shape {X_train.shape}...")
|
| 83 |
-
sys.stdout.flush()
|
| 84 |
-
# Mean-center for consistency with hybrid model's SVD layer
|
| 85 |
-
mean = np.mean(X_train, axis=0)
|
| 86 |
-
X_centered = X_train - mean
|
| 87 |
-
svd = TruncatedSVD(n_components=n_components, algorithm='randomized', n_iter=5, random_state=42)
|
| 88 |
-
X_train_svd = svd.fit_transform(X_centered)
|
| 89 |
-
print("SVD fitted. Training Logistic Regression...")
|
| 90 |
-
|
| 91 |
-
# Increase iterations to avoid premature convergence warnings on 60k samples
|
| 92 |
-
clf = LogisticRegression(max_iter=1000)
|
| 93 |
-
clf.fit(X_train_svd, y_train)
|
| 94 |
-
|
| 95 |
-
print(f"SVD Explained Variance: {svd.explained_variance_ratio_.sum()*100:.2f}%")
|
| 96 |
-
return svd, clf, mean
|
| 97 |
-
|
| 98 |
-
def plot_robustness_comparison(results, filename):
|
| 99 |
-
"""Plot robustness curves for all models."""
|
| 100 |
-
plt.figure(figsize=(10, 6))
|
| 101 |
-
|
| 102 |
-
colors = {'CNN': BLUE_LIGHT, 'SVD': BLUE_DEEP, 'Hybrid': RED}
|
| 103 |
-
markers = {'CNN': 'o', 'SVD': 's', 'Hybrid': '^'}
|
| 104 |
-
|
| 105 |
-
for model_name, accuracies in results.items():
|
| 106 |
-
plt.plot(NOISE_LEVELS, accuracies,
|
| 107 |
-
color=colors[model_name],
|
| 108 |
-
marker=markers[model_name],
|
| 109 |
-
linewidth=2, markersize=8,
|
| 110 |
-
label=f'{model_name}')
|
| 111 |
-
|
| 112 |
-
plt.xlabel('Noise Level (σ)', fontsize=12)
|
| 113 |
-
plt.ylabel('Accuracy', fontsize=12)
|
| 114 |
-
plt.title('Fashion-MNIST: Model Robustness Under Gaussian Noise', fontsize=14)
|
| 115 |
-
plt.legend(fontsize=11)
|
| 116 |
-
plt.grid(True, alpha=0.3)
|
| 117 |
-
plt.ylim(0.0, 1.05)
|
| 118 |
-
|
| 119 |
-
plt.tight_layout()
|
| 120 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300)
|
| 121 |
-
plt.close()
|
| 122 |
-
print(f"Saved {filename}")
|
| 123 |
-
|
| 124 |
-
def main():
|
| 125 |
-
set_seeds(42)
|
| 126 |
-
# 1. Load Data
|
| 127 |
-
X_train, y_train, X_test, y_test = load_fashion_mnist()
|
| 128 |
-
|
| 129 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 130 |
-
|
| 131 |
-
# 2. Train Models
|
| 132 |
-
# SVD
|
| 133 |
-
svd, svd_clf, svd_mean = train_svd_model(X_train, y_train)
|
| 134 |
-
|
| 135 |
-
# CNN (pretrained)
|
| 136 |
-
cnn = load_pretrained_cnn(device)
|
| 137 |
-
|
| 138 |
-
# Hybrid
|
| 139 |
-
svd_layer = create_svd_layer(X_train, n_components=SVD_K)
|
| 140 |
-
hybrid = HybridSVDCNN(svd_layer, cnn).to(device)
|
| 141 |
-
|
| 142 |
-
# 3. Evaluate
|
| 143 |
-
print("\n--- Evaluating Robustness on Fashion-MNIST ---")
|
| 144 |
-
results = {'CNN': [], 'SVD': [], 'Hybrid': []}
|
| 145 |
-
|
| 146 |
-
for sigma in NOISE_LEVELS:
|
| 147 |
-
print(f"\nNoise σ = {sigma}")
|
| 148 |
-
X_test_noisy = add_gaussian_noise(X_test, sigma)
|
| 149 |
-
|
| 150 |
-
# SVD
|
| 151 |
-
acc_svd = evaluate_svd(svd, svd_clf, X_test_noisy, y_test, svd_mean)
|
| 152 |
-
results['SVD'].append(acc_svd)
|
| 153 |
-
print(f" SVD: {acc_svd*100:.2f}%")
|
| 154 |
-
|
| 155 |
-
# CNN
|
| 156 |
-
acc_cnn = evaluate_cnn(cnn, X_test_noisy, y_test, device)
|
| 157 |
-
results['CNN'].append(acc_cnn)
|
| 158 |
-
print(f" CNN: {acc_cnn*100:.2f}%")
|
| 159 |
-
|
| 160 |
-
# Hybrid
|
| 161 |
-
acc_hybrid = evaluate_cnn(hybrid, X_test_noisy, y_test, device)
|
| 162 |
-
results['Hybrid'].append(acc_hybrid)
|
| 163 |
-
print(f" Hybrid: {acc_hybrid*100:.2f}%")
|
| 164 |
-
|
| 165 |
-
# Save raw numbers for reproducibility / app usage
|
| 166 |
-
out_json = {
|
| 167 |
-
"dataset": "Fashion-MNIST",
|
| 168 |
-
"task": "10-class",
|
| 169 |
-
"noise_levels": NOISE_LEVELS,
|
| 170 |
-
"results": {k: [float(x) for x in v] for k, v in results.items()},
|
| 171 |
-
"svd_components": 20,
|
| 172 |
-
"cnn_epochs": 5,
|
| 173 |
-
"notes": "Numbers are evaluated on Fashion-MNIST test set with test-time Gaussian noise.",
|
| 174 |
-
}
|
| 175 |
-
json_path = os.path.join(config.RESULTS_DIR, "robustness_fashion_noise.json")
|
| 176 |
-
with open(json_path, "w", encoding="utf-8") as f:
|
| 177 |
-
json.dump(out_json, f, indent=2)
|
| 178 |
-
print(f"Saved robustness JSON to {json_path}")
|
| 179 |
-
|
| 180 |
-
# 5. Summary Table
|
| 181 |
-
print("\n--- Summary (Fashion-MNIST) ---")
|
| 182 |
-
print(f"{'Model':<10} {'Clean':<10} {'σ=0.3':<10} {'σ=0.5':<10} {'σ=0.7':<10}")
|
| 183 |
-
print("-" * 50)
|
| 184 |
-
for model_name in ['CNN', 'SVD', 'Hybrid']:
|
| 185 |
-
accs = results[model_name]
|
| 186 |
-
print(f"{model_name:<10} {accs[0]*100:>6.2f}% {accs[3]*100:>6.2f}% {accs[5]*100:>6.2f}% {accs[7]*100:>6.2f}%")
|
| 187 |
-
|
| 188 |
-
if __name__ == "__main__":
|
| 189 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/10_ablation_study.py
DELETED
|
@@ -1,344 +0,0 @@
|
|
| 1 |
-
# Exp 10 – Ablation Study: Depth vs Non-linearity Contribution
|
| 2 |
-
# Systematically test the independent contributions of depth and non-linearity
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import matplotlib.pyplot as plt
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
import torch.optim as optim
|
| 9 |
-
from torch.utils.data import TensorDataset, DataLoader
|
| 10 |
-
from torchvision import datasets, transforms
|
| 11 |
-
from sklearn.linear_model import LogisticRegression
|
| 12 |
-
from sklearn.metrics import accuracy_score
|
| 13 |
-
from sklearn.model_selection import train_test_split
|
| 14 |
-
import os
|
| 15 |
-
|
| 16 |
-
from src import config
|
| 17 |
-
|
| 18 |
-
# --- Configuration ---
|
| 19 |
-
BLUE_DEEP = "#5E81AC"
|
| 20 |
-
ORANGE = "#D08770"
|
| 21 |
-
GREEN = "#A3BE8C"
|
| 22 |
-
RED = "#BF616A"
|
| 23 |
-
PURPLE = "#B48EAD"
|
| 24 |
-
|
| 25 |
-
SEED = 42
|
| 26 |
-
BATCH_SIZE = 64
|
| 27 |
-
EPOCHS = 10
|
| 28 |
-
LR = 0.001
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def set_seeds(seed):
|
| 32 |
-
"""Set all random seeds for reproducibility."""
|
| 33 |
-
np.random.seed(seed)
|
| 34 |
-
torch.manual_seed(seed)
|
| 35 |
-
torch.cuda.manual_seed_all(seed)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class ShallowLinear(nn.Module):
|
| 39 |
-
"""Single layer linear model (no activation)."""
|
| 40 |
-
def __init__(self, num_classes=10):
|
| 41 |
-
super().__init__()
|
| 42 |
-
self.fc = nn.Linear(784, num_classes)
|
| 43 |
-
|
| 44 |
-
def forward(self, x):
|
| 45 |
-
x = x.view(-1, 784)
|
| 46 |
-
return self.fc(x)
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class ShallowNonLinear(nn.Module):
|
| 50 |
-
"""Single layer with ReLU activation."""
|
| 51 |
-
def __init__(self, num_classes=10, hidden_size=128):
|
| 52 |
-
super().__init__()
|
| 53 |
-
self.fc1 = nn.Linear(784, hidden_size)
|
| 54 |
-
self.fc2 = nn.Linear(hidden_size, num_classes)
|
| 55 |
-
|
| 56 |
-
def forward(self, x):
|
| 57 |
-
x = x.view(-1, 784)
|
| 58 |
-
x = torch.relu(self.fc1(x))
|
| 59 |
-
return self.fc2(x)
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
class DeepLinear(nn.Module):
|
| 63 |
-
"""Two hidden layers without activation (identity mapping)."""
|
| 64 |
-
def __init__(self, num_classes=10, hidden_size=128):
|
| 65 |
-
super().__init__()
|
| 66 |
-
self.fc1 = nn.Linear(784, hidden_size)
|
| 67 |
-
self.fc2 = nn.Linear(hidden_size, hidden_size)
|
| 68 |
-
self.fc3 = nn.Linear(hidden_size, num_classes)
|
| 69 |
-
|
| 70 |
-
def forward(self, x):
|
| 71 |
-
x = x.view(-1, 784)
|
| 72 |
-
x = self.fc1(x) # No activation
|
| 73 |
-
x = self.fc2(x) # No activation
|
| 74 |
-
return self.fc3(x)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
class DeepNonLinear(nn.Module):
|
| 78 |
-
"""Two hidden layers with ReLU activation (similar to CNN complexity)."""
|
| 79 |
-
def __init__(self, num_classes=10, hidden_size=128):
|
| 80 |
-
super().__init__()
|
| 81 |
-
self.fc1 = nn.Linear(784, hidden_size)
|
| 82 |
-
self.fc2 = nn.Linear(hidden_size, hidden_size)
|
| 83 |
-
self.fc3 = nn.Linear(hidden_size, num_classes)
|
| 84 |
-
|
| 85 |
-
def forward(self, x):
|
| 86 |
-
x = x.view(-1, 784)
|
| 87 |
-
x = torch.relu(self.fc1(x))
|
| 88 |
-
x = torch.relu(self.fc2(x))
|
| 89 |
-
return self.fc3(x)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
class SimpleCNN(nn.Module):
|
| 93 |
-
"""2-conv CNN for comparison (from hybrid_model)."""
|
| 94 |
-
def __init__(self, num_classes=10):
|
| 95 |
-
super().__init__()
|
| 96 |
-
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
|
| 97 |
-
self.pool = nn.MaxPool2d(2, 2)
|
| 98 |
-
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
|
| 99 |
-
self.fc1 = nn.Linear(32 * 7 * 7, 128)
|
| 100 |
-
self.fc2 = nn.Linear(128, num_classes)
|
| 101 |
-
|
| 102 |
-
def forward(self, x):
|
| 103 |
-
x = self.pool(torch.relu(self.conv1(x)))
|
| 104 |
-
x = self.pool(torch.relu(self.conv2(x)))
|
| 105 |
-
x = x.view(-1, 32 * 7 * 7)
|
| 106 |
-
x = torch.relu(self.fc1(x))
|
| 107 |
-
x = self.fc2(x)
|
| 108 |
-
return x
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def load_mnist():
|
| 112 |
-
"""Load MNIST train/test data."""
|
| 113 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 114 |
-
trainset = datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
|
| 115 |
-
testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
|
| 116 |
-
|
| 117 |
-
X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 118 |
-
y_train = trainset.targets.numpy()
|
| 119 |
-
|
| 120 |
-
X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 121 |
-
y_test = testset.targets.numpy()
|
| 122 |
-
|
| 123 |
-
return X_train, y_train, X_test, y_test
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
def train_model(model, X_train, y_train, X_val, y_val, device, epochs=EPOCHS):
|
| 127 |
-
"""Train a PyTorch model with validation tracking."""
|
| 128 |
-
model = model.to(device)
|
| 129 |
-
criterion = nn.CrossEntropyLoss()
|
| 130 |
-
optimizer = optim.Adam(model.parameters(), lr=LR)
|
| 131 |
-
|
| 132 |
-
# Convert to tensors
|
| 133 |
-
X_train_t = torch.tensor(X_train, dtype=torch.float32)
|
| 134 |
-
y_train_t = torch.tensor(y_train, dtype=torch.long)
|
| 135 |
-
X_val_t = torch.tensor(X_val, dtype=torch.float32)
|
| 136 |
-
y_val_t = torch.tensor(y_val, dtype=torch.long)
|
| 137 |
-
|
| 138 |
-
# Reshape for CNN if needed
|
| 139 |
-
if isinstance(model, SimpleCNN):
|
| 140 |
-
X_train_t = X_train_t.view(-1, 1, 28, 28)
|
| 141 |
-
X_val_t = X_val_t.view(-1, 1, 28, 28)
|
| 142 |
-
|
| 143 |
-
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 144 |
-
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
| 145 |
-
|
| 146 |
-
history = {'train_acc': [], 'val_acc': []}
|
| 147 |
-
|
| 148 |
-
for epoch in range(epochs):
|
| 149 |
-
model.train()
|
| 150 |
-
train_correct = 0
|
| 151 |
-
train_total = 0
|
| 152 |
-
|
| 153 |
-
for inputs, labels in train_loader:
|
| 154 |
-
inputs, labels = inputs.to(device), labels.to(device)
|
| 155 |
-
optimizer.zero_grad()
|
| 156 |
-
outputs = model(inputs)
|
| 157 |
-
loss = criterion(outputs, labels)
|
| 158 |
-
loss.backward()
|
| 159 |
-
optimizer.step()
|
| 160 |
-
|
| 161 |
-
_, predicted = outputs.max(1)
|
| 162 |
-
train_total += labels.size(0)
|
| 163 |
-
train_correct += predicted.eq(labels).sum().item()
|
| 164 |
-
|
| 165 |
-
train_acc = 100.0 * train_correct / train_total
|
| 166 |
-
|
| 167 |
-
# Validation
|
| 168 |
-
model.eval()
|
| 169 |
-
with torch.no_grad():
|
| 170 |
-
X_val_batch = X_val_t.to(device)
|
| 171 |
-
y_val_batch = y_val_t.to(device)
|
| 172 |
-
outputs = model(X_val_batch)
|
| 173 |
-
_, predicted = outputs.max(1)
|
| 174 |
-
val_acc = 100.0 * predicted.eq(y_val_batch).sum().item() / len(y_val_batch)
|
| 175 |
-
|
| 176 |
-
history['train_acc'].append(train_acc)
|
| 177 |
-
history['val_acc'].append(val_acc)
|
| 178 |
-
|
| 179 |
-
if (epoch + 1) % 2 == 0:
|
| 180 |
-
print(f" Epoch {epoch+1}/{epochs}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
|
| 181 |
-
|
| 182 |
-
return model, history
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
def evaluate_model(model, X_test, y_test, device):
|
| 186 |
-
"""Evaluate model on test set."""
|
| 187 |
-
model.eval()
|
| 188 |
-
X_test_t = torch.tensor(X_test, dtype=torch.float32)
|
| 189 |
-
|
| 190 |
-
if isinstance(model, SimpleCNN):
|
| 191 |
-
X_test_t = X_test_t.view(-1, 1, 28, 28)
|
| 192 |
-
|
| 193 |
-
with torch.no_grad():
|
| 194 |
-
X_test_batch = X_test_t.to(device)
|
| 195 |
-
outputs = model(X_test_batch)
|
| 196 |
-
_, predicted = outputs.max(1)
|
| 197 |
-
accuracy = 100.0 * predicted.eq(torch.tensor(y_test).to(device)).sum().item() / len(y_test)
|
| 198 |
-
|
| 199 |
-
return accuracy
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def plot_ablation_results(results, filename):
|
| 203 |
-
"""Plot ablation study results."""
|
| 204 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
|
| 205 |
-
|
| 206 |
-
# Bar chart: Test accuracy
|
| 207 |
-
models = list(results.keys())
|
| 208 |
-
test_accs = [results[m]['test_acc'] for m in models]
|
| 209 |
-
colors = [BLUE_DEEP, GREEN, PURPLE, ORANGE, RED]
|
| 210 |
-
|
| 211 |
-
ax1.bar(range(len(models)), test_accs, color=colors, alpha=0.8)
|
| 212 |
-
ax1.set_xticks(range(len(models)))
|
| 213 |
-
ax1.set_xticklabels(models, rotation=30, ha='right')
|
| 214 |
-
ax1.set_ylabel('Test Accuracy (%)')
|
| 215 |
-
ax1.set_title('Ablation Study: Architecture Comparison')
|
| 216 |
-
ax1.grid(axis='y', alpha=0.3)
|
| 217 |
-
ax1.set_ylim([85, 100])
|
| 218 |
-
|
| 219 |
-
# Add value labels
|
| 220 |
-
for i, acc in enumerate(test_accs):
|
| 221 |
-
ax1.text(i, acc + 0.5, f'{acc:.2f}%', ha='center', va='bottom', fontsize=9)
|
| 222 |
-
|
| 223 |
-
# Learning curves
|
| 224 |
-
for i, (model_name, data) in enumerate(results.items()):
|
| 225 |
-
if 'val_acc' in data:
|
| 226 |
-
epochs = range(1, len(data['val_acc']) + 1)
|
| 227 |
-
ax2.plot(epochs, data['val_acc'], label=model_name, color=colors[i], linewidth=2)
|
| 228 |
-
|
| 229 |
-
ax2.set_xlabel('Epoch')
|
| 230 |
-
ax2.set_ylabel('Validation Accuracy (%)')
|
| 231 |
-
ax2.set_title('Training Dynamics')
|
| 232 |
-
ax2.legend()
|
| 233 |
-
ax2.grid(alpha=0.3)
|
| 234 |
-
|
| 235 |
-
plt.tight_layout()
|
| 236 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
|
| 237 |
-
plt.close()
|
| 238 |
-
print(f"Saved {filename}")
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
def main():
|
| 242 |
-
set_seeds(SEED)
|
| 243 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 244 |
-
print(f"Using device: {device}\n")
|
| 245 |
-
|
| 246 |
-
# Load data
|
| 247 |
-
print("Loading MNIST...")
|
| 248 |
-
X_train_full, y_train_full, X_test, y_test = load_mnist()
|
| 249 |
-
|
| 250 |
-
# Split train into train/val
|
| 251 |
-
X_train, X_val, y_train, y_val = train_test_split(
|
| 252 |
-
X_train_full, y_train_full, test_size=0.2, random_state=SEED, stratify=y_train_full
|
| 253 |
-
)
|
| 254 |
-
print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}\n")
|
| 255 |
-
|
| 256 |
-
results = {}
|
| 257 |
-
|
| 258 |
-
# 1. Shallow Linear
|
| 259 |
-
print("="*60)
|
| 260 |
-
print("Training: Shallow Linear (baseline)")
|
| 261 |
-
print("="*60)
|
| 262 |
-
model = ShallowLinear(num_classes=10)
|
| 263 |
-
model, history = train_model(model, X_train, y_train, X_val, y_val, device)
|
| 264 |
-
test_acc = evaluate_model(model, X_test, y_test, device)
|
| 265 |
-
results['Shallow Linear'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
|
| 266 |
-
print(f"Test Accuracy: {test_acc:.2f}%\n")
|
| 267 |
-
|
| 268 |
-
# 2. Shallow Non-Linear
|
| 269 |
-
print("="*60)
|
| 270 |
-
print("Training: Shallow Non-Linear (+ ReLU)")
|
| 271 |
-
print("="*60)
|
| 272 |
-
model = ShallowNonLinear(num_classes=10)
|
| 273 |
-
model, history = train_model(model, X_train, y_train, X_val, y_val, device)
|
| 274 |
-
test_acc = evaluate_model(model, X_test, y_test, device)
|
| 275 |
-
results['Shallow NonLinear'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
|
| 276 |
-
print(f"Test Accuracy: {test_acc:.2f}%\n")
|
| 277 |
-
|
| 278 |
-
# 3. Deep Linear
|
| 279 |
-
print("="*60)
|
| 280 |
-
print("Training: Deep Linear (+ Depth, no ReLU)")
|
| 281 |
-
print("="*60)
|
| 282 |
-
model = DeepLinear(num_classes=10)
|
| 283 |
-
model, history = train_model(model, X_train, y_train, X_val, y_val, device)
|
| 284 |
-
test_acc = evaluate_model(model, X_test, y_test, device)
|
| 285 |
-
results['Deep Linear'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
|
| 286 |
-
print(f"Test Accuracy: {test_acc:.2f}%\n")
|
| 287 |
-
|
| 288 |
-
# 4. Deep Non-Linear
|
| 289 |
-
print("="*60)
|
| 290 |
-
print("Training: Deep Non-Linear (+ Depth + ReLU)")
|
| 291 |
-
print("="*60)
|
| 292 |
-
model = DeepNonLinear(num_classes=10)
|
| 293 |
-
model, history = train_model(model, X_train, y_train, X_val, y_val, device)
|
| 294 |
-
test_acc = evaluate_model(model, X_test, y_test, device)
|
| 295 |
-
results['Deep NonLinear'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
|
| 296 |
-
print(f"Test Accuracy: {test_acc:.2f}%\n")
|
| 297 |
-
|
| 298 |
-
# 5. CNN (for reference)
|
| 299 |
-
print("="*60)
|
| 300 |
-
print("Training: CNN (convolutional + non-linear)")
|
| 301 |
-
print("="*60)
|
| 302 |
-
model = SimpleCNN(num_classes=10)
|
| 303 |
-
# Reshape data for CNN
|
| 304 |
-
X_train_cnn = X_train.reshape(-1, 28, 28)
|
| 305 |
-
X_val_cnn = X_val.reshape(-1, 28, 28)
|
| 306 |
-
X_test_cnn = X_test.reshape(-1, 28, 28)
|
| 307 |
-
model, history = train_model(model, X_train_cnn, y_train, X_val_cnn, y_val, device)
|
| 308 |
-
test_acc = evaluate_model(model, X_test_cnn, y_test, device)
|
| 309 |
-
results['CNN'] = {'test_acc': test_acc, 'val_acc': history['val_acc']}
|
| 310 |
-
print(f"Test Accuracy: {test_acc:.2f}%\n")
|
| 311 |
-
|
| 312 |
-
# Summary
|
| 313 |
-
print("="*60)
|
| 314 |
-
print("ABLATION STUDY SUMMARY")
|
| 315 |
-
print("="*60)
|
| 316 |
-
for model_name, data in results.items():
|
| 317 |
-
print(f"{model_name:20s}: {data['test_acc']:.2f}%")
|
| 318 |
-
|
| 319 |
-
# Analysis
|
| 320 |
-
print("\n" + "="*60)
|
| 321 |
-
print("KEY INSIGHTS")
|
| 322 |
-
print("="*60)
|
| 323 |
-
shallow_linear = results['Shallow Linear']['test_acc']
|
| 324 |
-
shallow_nonlinear = results['Shallow NonLinear']['test_acc']
|
| 325 |
-
deep_linear = results['Deep Linear']['test_acc']
|
| 326 |
-
deep_nonlinear = results['Deep NonLinear']['test_acc']
|
| 327 |
-
|
| 328 |
-
nonlinearity_gain = shallow_nonlinear - shallow_linear
|
| 329 |
-
depth_gain = deep_linear - shallow_linear
|
| 330 |
-
combined_gain = deep_nonlinear - shallow_linear
|
| 331 |
-
|
| 332 |
-
print(f"Non-linearity alone (shallow): +{nonlinearity_gain:.2f} pp")
|
| 333 |
-
print(f"Depth alone (linear): +{depth_gain:.2f} pp")
|
| 334 |
-
print(f"Depth + Non-linearity (combined): +{combined_gain:.2f} pp")
|
| 335 |
-
print(f"CNN (convolutional structure): +{results['CNN']['test_acc'] - shallow_linear:.2f} pp")
|
| 336 |
-
|
| 337 |
-
# Plot results
|
| 338 |
-
plot_ablation_results(results, 'fig_13_ablation_study.png')
|
| 339 |
-
|
| 340 |
-
print("\n✓ Ablation study complete!")
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
if __name__ == "__main__":
|
| 344 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/11_learning_curves.py
DELETED
|
@@ -1,228 +0,0 @@
|
|
| 1 |
-
# Exp 11 – Learning Curves Visualization
|
| 2 |
-
# Generate training/validation loss and accuracy curves from saved training history
|
| 3 |
-
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import pickle
|
| 6 |
-
import os
|
| 7 |
-
import numpy as np
|
| 8 |
-
|
| 9 |
-
from src import config
|
| 10 |
-
|
| 11 |
-
# --- Configuration ---
|
| 12 |
-
BLUE_DEEP = "#5E81AC"
|
| 13 |
-
ORANGE = "#D08770"
|
| 14 |
-
GRAY_LIGHT = "#D8DEE9"
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def plot_learning_curves(history, title, filename):
|
| 18 |
-
"""
|
| 19 |
-
Plot training and validation curves for loss and accuracy.
|
| 20 |
-
|
| 21 |
-
Args:
|
| 22 |
-
history: Dictionary with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc'
|
| 23 |
-
title: Plot title
|
| 24 |
-
filename: Output filename
|
| 25 |
-
"""
|
| 26 |
-
epochs = range(1, len(history['train_loss']) + 1)
|
| 27 |
-
|
| 28 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
|
| 29 |
-
|
| 30 |
-
# Loss curves
|
| 31 |
-
ax1.plot(epochs, history['train_loss'], label='Training Loss',
|
| 32 |
-
color=BLUE_DEEP, linewidth=2, marker='o', markersize=4)
|
| 33 |
-
ax1.plot(epochs, history['val_loss'], label='Validation Loss',
|
| 34 |
-
color=ORANGE, linewidth=2, marker='s', markersize=4)
|
| 35 |
-
ax1.set_xlabel('Epoch', fontsize=12)
|
| 36 |
-
ax1.set_ylabel('Loss', fontsize=12)
|
| 37 |
-
ax1.set_title('Training and Validation Loss', fontsize=13, fontweight='bold')
|
| 38 |
-
ax1.legend(loc='best', fontsize=10)
|
| 39 |
-
ax1.grid(alpha=0.3)
|
| 40 |
-
|
| 41 |
-
# Highlight best validation loss
|
| 42 |
-
best_val_epoch = np.argmin(history['val_loss'])
|
| 43 |
-
best_val_loss = history['val_loss'][best_val_epoch]
|
| 44 |
-
ax1.axvline(x=best_val_epoch + 1, color='red', linestyle='--', alpha=0.5, linewidth=1)
|
| 45 |
-
ax1.plot(best_val_epoch + 1, best_val_loss, 'r*', markersize=15,
|
| 46 |
-
label=f'Best Val Loss: {best_val_loss:.4f} @ Epoch {best_val_epoch + 1}')
|
| 47 |
-
ax1.legend(loc='best', fontsize=9)
|
| 48 |
-
|
| 49 |
-
# Accuracy curves
|
| 50 |
-
ax2.plot(epochs, history['train_acc'], label='Training Accuracy',
|
| 51 |
-
color=BLUE_DEEP, linewidth=2, marker='o', markersize=4)
|
| 52 |
-
ax2.plot(epochs, history['val_acc'], label='Validation Accuracy',
|
| 53 |
-
color=ORANGE, linewidth=2, marker='s', markersize=4)
|
| 54 |
-
ax2.set_xlabel('Epoch', fontsize=12)
|
| 55 |
-
ax2.set_ylabel('Accuracy (%)', fontsize=12)
|
| 56 |
-
ax2.set_title('Training and Validation Accuracy', fontsize=13, fontweight='bold')
|
| 57 |
-
ax2.legend(loc='best', fontsize=10)
|
| 58 |
-
ax2.grid(alpha=0.3)
|
| 59 |
-
|
| 60 |
-
# Highlight best validation accuracy
|
| 61 |
-
best_val_epoch = np.argmax(history['val_acc'])
|
| 62 |
-
best_val_acc = history['val_acc'][best_val_epoch]
|
| 63 |
-
ax2.axvline(x=best_val_epoch + 1, color='red', linestyle='--', alpha=0.5, linewidth=1)
|
| 64 |
-
ax2.plot(best_val_epoch + 1, best_val_acc, 'r*', markersize=15,
|
| 65 |
-
label=f'Best Val Acc: {best_val_acc:.2f}% @ Epoch {best_val_epoch + 1}')
|
| 66 |
-
ax2.legend(loc='best', fontsize=9)
|
| 67 |
-
|
| 68 |
-
plt.suptitle(title, fontsize=15, fontweight='bold', y=1.02)
|
| 69 |
-
plt.tight_layout()
|
| 70 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
|
| 71 |
-
plt.close()
|
| 72 |
-
print(f"✓ Saved {filename}")
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def analyze_overfitting(history):
|
| 76 |
-
"""Analyze training dynamics for overfitting indicators."""
|
| 77 |
-
train_acc = history['train_acc']
|
| 78 |
-
val_acc = history['val_acc']
|
| 79 |
-
train_loss = history['train_loss']
|
| 80 |
-
val_loss = history['val_loss']
|
| 81 |
-
|
| 82 |
-
# Calculate gaps
|
| 83 |
-
final_acc_gap = train_acc[-1] - val_acc[-1]
|
| 84 |
-
final_loss_gap = val_loss[-1] - train_loss[-1]
|
| 85 |
-
|
| 86 |
-
# Check for divergence (sign of overfitting)
|
| 87 |
-
mid_point = len(train_acc) // 2
|
| 88 |
-
early_acc_gap = np.mean(train_acc[:mid_point]) - np.mean(val_acc[:mid_point])
|
| 89 |
-
late_acc_gap = np.mean(train_acc[mid_point:]) - np.mean(val_acc[mid_point:])
|
| 90 |
-
gap_increase = late_acc_gap - early_acc_gap
|
| 91 |
-
|
| 92 |
-
print("\n" + "="*60)
|
| 93 |
-
print("OVERFITTING ANALYSIS")
|
| 94 |
-
print("="*60)
|
| 95 |
-
print(f"Final Train Accuracy: {train_acc[-1]:.2f}%")
|
| 96 |
-
print(f"Final Validation Accuracy: {val_acc[-1]:.2f}%")
|
| 97 |
-
print(f"Accuracy Gap: {final_acc_gap:.2f} pp")
|
| 98 |
-
print(f"Loss Gap: {final_loss_gap:.4f}")
|
| 99 |
-
print(f"Gap Increase (early→late): {gap_increase:.2f} pp")
|
| 100 |
-
|
| 101 |
-
if final_acc_gap > 5.0:
|
| 102 |
-
print("\n⚠️ WARNING: Significant train-val accuracy gap detected (>5 pp)")
|
| 103 |
-
print(" Consider: regularization, dropout, or early stopping")
|
| 104 |
-
elif gap_increase > 2.0:
|
| 105 |
-
print("\n⚠️ WARNING: Train-val gap widening over time")
|
| 106 |
-
print(" Model may be starting to overfit")
|
| 107 |
-
else:
|
| 108 |
-
print("\n✓ No significant overfitting detected")
|
| 109 |
-
|
| 110 |
-
# Best epoch analysis
|
| 111 |
-
best_epoch = np.argmax(val_acc)
|
| 112 |
-
total_epochs = len(val_acc)
|
| 113 |
-
print(f"\nBest validation accuracy achieved at epoch {best_epoch + 1}/{total_epochs}")
|
| 114 |
-
if best_epoch < total_epochs - 2:
|
| 115 |
-
print(f"⚠️ Training continued for {total_epochs - best_epoch - 1} epochs after best model")
|
| 116 |
-
print(" Early stopping could have saved training time")
|
| 117 |
-
|
| 118 |
-
return {
|
| 119 |
-
'final_acc_gap': final_acc_gap,
|
| 120 |
-
'final_loss_gap': final_loss_gap,
|
| 121 |
-
'gap_increase': gap_increase,
|
| 122 |
-
'best_epoch': best_epoch + 1,
|
| 123 |
-
'total_epochs': total_epochs
|
| 124 |
-
}
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def plot_comparative_curves(histories, labels, filename):
|
| 128 |
-
"""
|
| 129 |
-
Plot multiple models' learning curves for comparison.
|
| 130 |
-
|
| 131 |
-
Args:
|
| 132 |
-
histories: List of history dictionaries
|
| 133 |
-
labels: List of model names
|
| 134 |
-
filename: Output filename
|
| 135 |
-
"""
|
| 136 |
-
colors = [BLUE_DEEP, ORANGE, "#A3BE8C", "#BF616A", "#B48EAD"]
|
| 137 |
-
|
| 138 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
|
| 139 |
-
|
| 140 |
-
for i, (history, label) in enumerate(zip(histories, labels)):
|
| 141 |
-
epochs = range(1, len(history['val_loss']) + 1)
|
| 142 |
-
color = colors[i % len(colors)]
|
| 143 |
-
|
| 144 |
-
# Validation loss
|
| 145 |
-
ax1.plot(epochs, history['val_loss'], label=label,
|
| 146 |
-
color=color, linewidth=2, marker='o', markersize=3)
|
| 147 |
-
|
| 148 |
-
# Validation accuracy
|
| 149 |
-
ax2.plot(epochs, history['val_acc'], label=label,
|
| 150 |
-
color=color, linewidth=2, marker='o', markersize=3)
|
| 151 |
-
|
| 152 |
-
ax1.set_xlabel('Epoch', fontsize=12)
|
| 153 |
-
ax1.set_ylabel('Validation Loss', fontsize=12)
|
| 154 |
-
ax1.set_title('Validation Loss Comparison', fontsize=13, fontweight='bold')
|
| 155 |
-
ax1.legend(loc='best', fontsize=10)
|
| 156 |
-
ax1.grid(alpha=0.3)
|
| 157 |
-
|
| 158 |
-
ax2.set_xlabel('Epoch', fontsize=12)
|
| 159 |
-
ax2.set_ylabel('Validation Accuracy (%)', fontsize=12)
|
| 160 |
-
ax2.set_title('Validation Accuracy Comparison', fontsize=13, fontweight='bold')
|
| 161 |
-
ax2.legend(loc='best', fontsize=10)
|
| 162 |
-
ax2.grid(alpha=0.3)
|
| 163 |
-
|
| 164 |
-
plt.tight_layout()
|
| 165 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
|
| 166 |
-
plt.close()
|
| 167 |
-
print(f"✓ Saved {filename}")
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def main():
|
| 171 |
-
print("="*60)
|
| 172 |
-
print("LEARNING CURVES VISUALIZATION")
|
| 173 |
-
print("="*60)
|
| 174 |
-
|
| 175 |
-
# Load CNN training history
|
| 176 |
-
history_path = os.path.join(config.MODELS_DIR, 'cnn_10class_history.pkl')
|
| 177 |
-
|
| 178 |
-
if os.path.exists(history_path):
|
| 179 |
-
print(f"\nLoading training history from {history_path}")
|
| 180 |
-
with open(history_path, 'rb') as f:
|
| 181 |
-
history = pickle.load(f)
|
| 182 |
-
|
| 183 |
-
print(f"Loaded {len(history['train_loss'])} epochs of training data")
|
| 184 |
-
|
| 185 |
-
# Plot learning curves
|
| 186 |
-
plot_learning_curves(
|
| 187 |
-
history,
|
| 188 |
-
'CNN Training Dynamics (MNIST 10-class)',
|
| 189 |
-
'fig_14_learning_curves.png'
|
| 190 |
-
)
|
| 191 |
-
|
| 192 |
-
# Analyze for overfitting
|
| 193 |
-
analysis = analyze_overfitting(history)
|
| 194 |
-
|
| 195 |
-
else:
|
| 196 |
-
print(f"\n⚠️ Training history not found at {history_path}")
|
| 197 |
-
print("Please run 'python src/train_models.py' first to generate training history")
|
| 198 |
-
return
|
| 199 |
-
|
| 200 |
-
# Check for Fashion-MNIST history
|
| 201 |
-
fashion_history_path = os.path.join(config.MODELS_DIR, 'cnn_fashion_history.pkl')
|
| 202 |
-
if os.path.exists(fashion_history_path):
|
| 203 |
-
print(f"\nLoading Fashion-MNIST training history...")
|
| 204 |
-
with open(fashion_history_path, 'rb') as f:
|
| 205 |
-
fashion_history = pickle.load(f)
|
| 206 |
-
|
| 207 |
-
plot_learning_curves(
|
| 208 |
-
fashion_history,
|
| 209 |
-
'CNN Training Dynamics (Fashion-MNIST)',
|
| 210 |
-
'fig_15_learning_curves_fashion.png'
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
# Comparative plot
|
| 214 |
-
print("\nGenerating comparative learning curves...")
|
| 215 |
-
plot_comparative_curves(
|
| 216 |
-
[history, fashion_history],
|
| 217 |
-
['MNIST', 'Fashion-MNIST'],
|
| 218 |
-
'fig_16_learning_curves_comparison.png'
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
print("\n" + "="*60)
|
| 222 |
-
print("✓ Learning curves visualization complete!")
|
| 223 |
-
print("="*60)
|
| 224 |
-
print(f"Results saved to: {config.RESULTS_DIR}")
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
if __name__ == "__main__":
|
| 228 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/12_roc_analysis.py
DELETED
|
@@ -1,291 +0,0 @@
|
|
| 1 |
-
# Exp 12 – ROC Curve Analysis for Hard Classification Pairs
|
| 2 |
-
# Particularly focused on the challenging 3 vs 8 classification
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import matplotlib.pyplot as plt
|
| 6 |
-
import torch
|
| 7 |
-
from torchvision import datasets, transforms
|
| 8 |
-
from sklearn.decomposition import TruncatedSVD
|
| 9 |
-
from sklearn.linear_model import LogisticRegression
|
| 10 |
-
from sklearn.metrics import roc_curve, auc, roc_auc_score
|
| 11 |
-
import pickle
|
| 12 |
-
import os
|
| 13 |
-
|
| 14 |
-
from src.hybrid_model import SimpleCNN
|
| 15 |
-
from src import config
|
| 16 |
-
|
| 17 |
-
# --- Configuration ---
|
| 18 |
-
BLUE_DEEP = "#5E81AC"
|
| 19 |
-
ORANGE = "#D08770"
|
| 20 |
-
GREEN = "#A3BE8C"
|
| 21 |
-
RED = "#BF616A"
|
| 22 |
-
|
| 23 |
-
SEED = 42
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def set_seeds(seed):
|
| 27 |
-
"""Set random seeds for reproducibility."""
|
| 28 |
-
np.random.seed(seed)
|
| 29 |
-
torch.manual_seed(seed)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def load_mnist_subset(digit_a=3, digit_b=8):
|
| 33 |
-
"""Load MNIST subset with only two specified digits."""
|
| 34 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 35 |
-
testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
|
| 36 |
-
|
| 37 |
-
# Filter for specific digits
|
| 38 |
-
mask = (testset.targets == digit_a) | (testset.targets == digit_b)
|
| 39 |
-
X = testset.data[mask].numpy().astype(np.float32) / 255.0
|
| 40 |
-
y = testset.targets[mask].numpy()
|
| 41 |
-
|
| 42 |
-
# Binary labels: 0 for digit_a, 1 for digit_b
|
| 43 |
-
y_binary = (y == digit_b).astype(int)
|
| 44 |
-
|
| 45 |
-
print(f"Loaded {len(X)} samples: {np.sum(y_binary == 0)} digit-{digit_a}, {np.sum(y_binary == 1)} digit-{digit_b}")
|
| 46 |
-
|
| 47 |
-
return X, y_binary, digit_a, digit_b
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def get_svd_probabilities(X, svd_path=config.SVD_MODEL_PATH):
|
| 51 |
-
"""Get probability scores from SVD+LR model."""
|
| 52 |
-
with open(svd_path, 'rb') as f:
|
| 53 |
-
svd = pickle.load(f)
|
| 54 |
-
|
| 55 |
-
# Get mean from saved model
|
| 56 |
-
if hasattr(svd, '_train_mean'):
|
| 57 |
-
mean = svd._train_mean
|
| 58 |
-
else:
|
| 59 |
-
mean = np.zeros(784)
|
| 60 |
-
|
| 61 |
-
X_flat = X.reshape(-1, 784)
|
| 62 |
-
X_centered = X_flat - mean
|
| 63 |
-
X_svd = svd.transform(X_centered)
|
| 64 |
-
|
| 65 |
-
# Train a binary classifier on 3 vs 8
|
| 66 |
-
print("Training binary SVD+LR classifier for ROC analysis...")
|
| 67 |
-
X_train_full, y_train_full = load_mnist_binary_train()
|
| 68 |
-
X_train_centered = X_train_full - mean
|
| 69 |
-
X_train_svd = svd.transform(X_train_centered)
|
| 70 |
-
|
| 71 |
-
clf = LogisticRegression(random_state=SEED, max_iter=1000)
|
| 72 |
-
clf.fit(X_train_svd, y_train_full)
|
| 73 |
-
|
| 74 |
-
# Get probability scores (for positive class)
|
| 75 |
-
probs = clf.predict_proba(X_svd)[:, 1]
|
| 76 |
-
|
| 77 |
-
return probs
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def load_mnist_binary_train(digit_a=3, digit_b=8):
|
| 81 |
-
"""Load training data for binary classification."""
|
| 82 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 83 |
-
trainset = datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
|
| 84 |
-
|
| 85 |
-
mask = (trainset.targets == digit_a) | (trainset.targets == digit_b)
|
| 86 |
-
X = trainset.data[mask].numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 87 |
-
y = trainset.targets[mask].numpy()
|
| 88 |
-
y_binary = (y == digit_b).astype(int)
|
| 89 |
-
|
| 90 |
-
return X, y_binary
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def get_cnn_probabilities(X, cnn_path=config.CNN_MODEL_PATH, digit_a=3, digit_b=8):
|
| 94 |
-
"""Get probability scores from CNN model."""
|
| 95 |
-
device = torch.device('cpu')
|
| 96 |
-
cnn = SimpleCNN(num_classes=10)
|
| 97 |
-
cnn.load_state_dict(torch.load(cnn_path, map_location=device))
|
| 98 |
-
cnn.eval()
|
| 99 |
-
|
| 100 |
-
X_tensor = torch.tensor(X, dtype=torch.float32).view(-1, 1, 28, 28)
|
| 101 |
-
|
| 102 |
-
with torch.no_grad():
|
| 103 |
-
outputs = cnn(X_tensor)
|
| 104 |
-
probs = torch.softmax(outputs, dim=1).numpy()
|
| 105 |
-
|
| 106 |
-
# Extract probabilities for the two digits of interest
|
| 107 |
-
prob_a = probs[:, digit_a]
|
| 108 |
-
prob_b = probs[:, digit_b]
|
| 109 |
-
|
| 110 |
-
# Normalize to binary probability (probability of digit_b given only these two options)
|
| 111 |
-
binary_probs = prob_b / (prob_a + prob_b + 1e-10)
|
| 112 |
-
|
| 113 |
-
return binary_probs
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def plot_roc_curves(results, digit_a, digit_b, filename):
|
| 117 |
-
"""
|
| 118 |
-
Plot ROC curves for different models.
|
| 119 |
-
|
| 120 |
-
Args:
|
| 121 |
-
results: Dictionary with model names as keys and (fpr, tpr, auc) as values
|
| 122 |
-
digit_a, digit_b: The two digits being classified
|
| 123 |
-
filename: Output filename
|
| 124 |
-
"""
|
| 125 |
-
plt.figure(figsize=(10, 8))
|
| 126 |
-
|
| 127 |
-
colors = [BLUE_DEEP, ORANGE, GREEN, RED]
|
| 128 |
-
|
| 129 |
-
for i, (model_name, (fpr, tpr, roc_auc)) in enumerate(results.items()):
|
| 130 |
-
plt.plot(fpr, tpr, color=colors[i % len(colors)], linewidth=2.5,
|
| 131 |
-
label=f'{model_name} (AUC = {roc_auc:.4f})', marker='o', markersize=4, markevery=20)
|
| 132 |
-
|
| 133 |
-
# Random classifier baseline
|
| 134 |
-
plt.plot([0, 1], [0, 1], color='gray', linestyle='--', linewidth=2, label='Random Classifier (AUC = 0.5000)')
|
| 135 |
-
|
| 136 |
-
plt.xlim([0.0, 1.0])
|
| 137 |
-
plt.ylim([0.0, 1.05])
|
| 138 |
-
plt.xlabel('False Positive Rate', fontsize=13)
|
| 139 |
-
plt.ylabel('True Positive Rate', fontsize=13)
|
| 140 |
-
plt.title(f'ROC Curves: Digit {digit_a} vs {digit_b} Classification', fontsize=14, fontweight='bold')
|
| 141 |
-
plt.legend(loc='lower right', fontsize=11)
|
| 142 |
-
plt.grid(alpha=0.3)
|
| 143 |
-
|
| 144 |
-
# Highlight perfect classification region
|
| 145 |
-
plt.fill_between([0, 0, 0.1], [0.9, 1, 1], alpha=0.1, color='green', label='_nolegend_')
|
| 146 |
-
plt.text(0.02, 0.95, 'Ideal Region', fontsize=9, color='green', alpha=0.7)
|
| 147 |
-
|
| 148 |
-
plt.tight_layout()
|
| 149 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
|
| 150 |
-
plt.close()
|
| 151 |
-
print(f"✓ Saved {filename}")
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def plot_roc_zoom(results, digit_a, digit_b, filename):
|
| 155 |
-
"""Plot zoomed-in ROC curves focusing on high-sensitivity region."""
|
| 156 |
-
plt.figure(figsize=(10, 8))
|
| 157 |
-
|
| 158 |
-
colors = [BLUE_DEEP, ORANGE, GREEN, RED]
|
| 159 |
-
|
| 160 |
-
for i, (model_name, (fpr, tpr, roc_auc)) in enumerate(results.items()):
|
| 161 |
-
plt.plot(fpr, tpr, color=colors[i % len(colors)], linewidth=2.5,
|
| 162 |
-
label=f'{model_name} (AUC = {roc_auc:.4f})', marker='o', markersize=5, markevery=10)
|
| 163 |
-
|
| 164 |
-
plt.plot([0, 1], [0, 1], color='gray', linestyle='--', linewidth=2, alpha=0.5)
|
| 165 |
-
|
| 166 |
-
# Zoom to interesting region
|
| 167 |
-
plt.xlim([0.0, 0.2])
|
| 168 |
-
plt.ylim([0.8, 1.0])
|
| 169 |
-
plt.xlabel('False Positive Rate', fontsize=13)
|
| 170 |
-
plt.ylabel('True Positive Rate', fontsize=13)
|
| 171 |
-
plt.title(f'ROC Curves (Zoomed): Digit {digit_a} vs {digit_b}', fontsize=14, fontweight='bold')
|
| 172 |
-
plt.legend(loc='lower right', fontsize=11)
|
| 173 |
-
plt.grid(alpha=0.3)
|
| 174 |
-
|
| 175 |
-
plt.tight_layout()
|
| 176 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
|
| 177 |
-
plt.close()
|
| 178 |
-
print(f"✓ Saved {filename}")
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
def analyze_threshold_performance(y_true, y_probs, thresholds=[0.3, 0.5, 0.7, 0.9]):
|
| 182 |
-
"""Analyze model performance at different decision thresholds."""
|
| 183 |
-
print("\n" + "="*60)
|
| 184 |
-
print("THRESHOLD SENSITIVITY ANALYSIS")
|
| 185 |
-
print("="*60)
|
| 186 |
-
print(f"{'Threshold':<12} {'Accuracy':<12} {'TPR':<12} {'FPR':<12} {'Precision':<12}")
|
| 187 |
-
print("-"*60)
|
| 188 |
-
|
| 189 |
-
for threshold in thresholds:
|
| 190 |
-
y_pred = (y_probs >= threshold).astype(int)
|
| 191 |
-
|
| 192 |
-
# Calculate metrics
|
| 193 |
-
tp = np.sum((y_pred == 1) & (y_true == 1))
|
| 194 |
-
tn = np.sum((y_pred == 0) & (y_true == 0))
|
| 195 |
-
fp = np.sum((y_pred == 1) & (y_true == 0))
|
| 196 |
-
fn = np.sum((y_pred == 0) & (y_true == 1))
|
| 197 |
-
|
| 198 |
-
accuracy = (tp + tn) / len(y_true)
|
| 199 |
-
tpr = tp / (tp + fn) if (tp + fn) > 0 else 0
|
| 200 |
-
fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
|
| 201 |
-
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
| 202 |
-
|
| 203 |
-
print(f"{threshold:<12.1f} {accuracy:<12.4f} {tpr:<12.4f} {fpr:<12.4f} {precision:<12.4f}")
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
def main():
|
| 207 |
-
set_seeds(SEED)
|
| 208 |
-
|
| 209 |
-
print("="*60)
|
| 210 |
-
print("ROC CURVE ANALYSIS: Digit 3 vs 8")
|
| 211 |
-
print("="*60)
|
| 212 |
-
|
| 213 |
-
# Load test data
|
| 214 |
-
print("\nLoading test data...")
|
| 215 |
-
X_test, y_test, digit_a, digit_b = load_mnist_subset(digit_a=3, digit_b=8)
|
| 216 |
-
|
| 217 |
-
results = {}
|
| 218 |
-
|
| 219 |
-
# SVD+LR model
|
| 220 |
-
print("\n" + "-"*60)
|
| 221 |
-
print("Evaluating SVD+LR model...")
|
| 222 |
-
print("-"*60)
|
| 223 |
-
try:
|
| 224 |
-
svd_probs = get_svd_probabilities(X_test)
|
| 225 |
-
fpr_svd, tpr_svd, _ = roc_curve(y_test, svd_probs)
|
| 226 |
-
auc_svd = auc(fpr_svd, tpr_svd)
|
| 227 |
-
results['SVD+LR'] = (fpr_svd, tpr_svd, auc_svd)
|
| 228 |
-
print(f"✓ SVD+LR AUC: {auc_svd:.4f}")
|
| 229 |
-
|
| 230 |
-
analyze_threshold_performance(y_test, svd_probs)
|
| 231 |
-
except Exception as e:
|
| 232 |
-
print(f"⚠️ Could not evaluate SVD model: {e}")
|
| 233 |
-
|
| 234 |
-
# CNN model
|
| 235 |
-
print("\n" + "-"*60)
|
| 236 |
-
print("Evaluating CNN model...")
|
| 237 |
-
print("-"*60)
|
| 238 |
-
try:
|
| 239 |
-
cnn_probs = get_cnn_probabilities(X_test, digit_a=digit_a, digit_b=digit_b)
|
| 240 |
-
fpr_cnn, tpr_cnn, _ = roc_curve(y_test, cnn_probs)
|
| 241 |
-
auc_cnn = auc(fpr_cnn, tpr_cnn)
|
| 242 |
-
results['CNN'] = (fpr_cnn, tpr_cnn, auc_cnn)
|
| 243 |
-
print(f"✓ CNN AUC: {auc_cnn:.4f}")
|
| 244 |
-
|
| 245 |
-
analyze_threshold_performance(y_test, cnn_probs)
|
| 246 |
-
except Exception as e:
|
| 247 |
-
print(f"⚠️ Could not evaluate CNN model: {e}")
|
| 248 |
-
|
| 249 |
-
# Plot results
|
| 250 |
-
if len(results) > 0:
|
| 251 |
-
print("\n" + "="*60)
|
| 252 |
-
print("Generating ROC visualizations...")
|
| 253 |
-
print("="*60)
|
| 254 |
-
|
| 255 |
-
plot_roc_curves(results, digit_a, digit_b, 'fig_17_roc_curves.png')
|
| 256 |
-
plot_roc_zoom(results, digit_a, digit_b, 'fig_18_roc_curves_zoom.png')
|
| 257 |
-
|
| 258 |
-
# Summary
|
| 259 |
-
print("\n" + "="*60)
|
| 260 |
-
print("SUMMARY")
|
| 261 |
-
print("="*60)
|
| 262 |
-
for model_name, (_, _, roc_auc) in results.items():
|
| 263 |
-
print(f"{model_name:15s}: AUC = {roc_auc:.4f}")
|
| 264 |
-
|
| 265 |
-
# Interpretation
|
| 266 |
-
print("\n" + "="*60)
|
| 267 |
-
print("INTERPRETATION")
|
| 268 |
-
print("="*60)
|
| 269 |
-
print("• AUC = 1.0: Perfect classifier")
|
| 270 |
-
print("• AUC = 0.9-1.0: Excellent")
|
| 271 |
-
print("• AUC = 0.8-0.9: Good")
|
| 272 |
-
print("• AUC = 0.7-0.8: Fair")
|
| 273 |
-
print("• AUC = 0.5: Random classifier")
|
| 274 |
-
|
| 275 |
-
if 'CNN' in results and 'SVD+LR' in results:
|
| 276 |
-
auc_diff = results['CNN'][2] - results['SVD+LR'][2]
|
| 277 |
-
print(f"\nCNN advantage over SVD: {auc_diff:.4f} AUC points")
|
| 278 |
-
if auc_diff > 0.05:
|
| 279 |
-
print("→ CNN shows substantially better discrimination ability")
|
| 280 |
-
elif auc_diff > 0.02:
|
| 281 |
-
print("→ CNN shows moderately better discrimination")
|
| 282 |
-
else:
|
| 283 |
-
print("→ Models show similar discrimination ability")
|
| 284 |
-
|
| 285 |
-
print("\n✓ ROC analysis complete!")
|
| 286 |
-
else:
|
| 287 |
-
print("\n⚠️ No models could be evaluated. Please train models first.")
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
if __name__ == "__main__":
|
| 291 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/13_per_class_metrics.py
DELETED
|
@@ -1,366 +0,0 @@
|
|
| 1 |
-
# Exp 13 – Per-Class Performance Metrics
|
| 2 |
-
# Generate detailed classification report with Precision, Recall, F1-Score for each digit
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import matplotlib.pyplot as plt
|
| 6 |
-
import pandas as pd
|
| 7 |
-
import seaborn as sns
|
| 8 |
-
import torch
|
| 9 |
-
from torchvision import datasets, transforms
|
| 10 |
-
from sklearn.decomposition import TruncatedSVD
|
| 11 |
-
from sklearn.linear_model import LogisticRegression
|
| 12 |
-
from sklearn.metrics import classification_report, precision_recall_fscore_support, confusion_matrix
|
| 13 |
-
import pickle
|
| 14 |
-
import os
|
| 15 |
-
|
| 16 |
-
from src.hybrid_model import SimpleCNN
|
| 17 |
-
from src import config
|
| 18 |
-
|
| 19 |
-
# --- Configuration ---
|
| 20 |
-
BLUE_DEEP = "#5E81AC"
|
| 21 |
-
ORANGE = "#D08770"
|
| 22 |
-
GREEN = "#A3BE8C"
|
| 23 |
-
RED = "#BF616A"
|
| 24 |
-
|
| 25 |
-
SEED = 42
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def set_seeds(seed):
|
| 29 |
-
"""Set random seeds for reproducibility."""
|
| 30 |
-
np.random.seed(seed)
|
| 31 |
-
torch.manual_seed(seed)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def load_mnist():
|
| 35 |
-
"""Load MNIST test data."""
|
| 36 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 37 |
-
testset = datasets.MNIST(root=config.MNIST_DIR, train=False, download=True, transform=transform)
|
| 38 |
-
|
| 39 |
-
X_test = testset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 40 |
-
y_test = testset.targets.numpy()
|
| 41 |
-
|
| 42 |
-
return X_test, y_test
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def load_mnist_train():
|
| 46 |
-
"""Load MNIST training data for SVD."""
|
| 47 |
-
transform = transforms.Compose([transforms.ToTensor()])
|
| 48 |
-
trainset = datasets.MNIST(root=config.MNIST_DIR, train=True, download=True, transform=transform)
|
| 49 |
-
|
| 50 |
-
X_train = trainset.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
|
| 51 |
-
y_train = trainset.targets.numpy()
|
| 52 |
-
|
| 53 |
-
return X_train, y_train
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def evaluate_svd(X_test, y_test, svd_path=config.SVD_MODEL_PATH):
|
| 57 |
-
"""Evaluate SVD+LR model."""
|
| 58 |
-
print("Evaluating SVD+LR model...")
|
| 59 |
-
|
| 60 |
-
# Load SVD model
|
| 61 |
-
with open(svd_path, 'rb') as f:
|
| 62 |
-
svd = pickle.load(f)
|
| 63 |
-
|
| 64 |
-
if hasattr(svd, '_train_mean'):
|
| 65 |
-
mean = svd._train_mean
|
| 66 |
-
else:
|
| 67 |
-
mean = np.zeros(784)
|
| 68 |
-
|
| 69 |
-
# Transform test data
|
| 70 |
-
X_test_centered = X_test - mean
|
| 71 |
-
X_test_svd = svd.transform(X_test_centered)
|
| 72 |
-
|
| 73 |
-
# Train classifier
|
| 74 |
-
print(" Training LogisticRegression classifier...")
|
| 75 |
-
X_train, y_train = load_mnist_train()
|
| 76 |
-
X_train_centered = X_train - mean
|
| 77 |
-
X_train_svd = svd.transform(X_train_centered)
|
| 78 |
-
|
| 79 |
-
clf = LogisticRegression(random_state=SEED, max_iter=1000)
|
| 80 |
-
clf.fit(X_train_svd, y_train)
|
| 81 |
-
|
| 82 |
-
# Predict
|
| 83 |
-
y_pred = clf.predict(X_test_svd)
|
| 84 |
-
|
| 85 |
-
return y_pred
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def evaluate_cnn(X_test, y_test, cnn_path=config.CNN_MODEL_PATH):
|
| 89 |
-
"""Evaluate CNN model."""
|
| 90 |
-
print("Evaluating CNN model...")
|
| 91 |
-
|
| 92 |
-
device = torch.device('cpu')
|
| 93 |
-
cnn = SimpleCNN(num_classes=10)
|
| 94 |
-
cnn.load_state_dict(torch.load(cnn_path, map_location=device))
|
| 95 |
-
cnn.eval()
|
| 96 |
-
|
| 97 |
-
X_tensor = torch.tensor(X_test, dtype=torch.float32).view(-1, 1, 28, 28)
|
| 98 |
-
|
| 99 |
-
with torch.no_grad():
|
| 100 |
-
outputs = cnn(X_tensor)
|
| 101 |
-
y_pred = outputs.argmax(dim=1).numpy()
|
| 102 |
-
|
| 103 |
-
return y_pred
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
def create_metrics_table(y_true, y_pred, model_name):
|
| 107 |
-
"""Create per-class metrics table."""
|
| 108 |
-
precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred, average=None)
|
| 109 |
-
|
| 110 |
-
# Create DataFrame
|
| 111 |
-
df = pd.DataFrame({
|
| 112 |
-
'Class': [f'Digit {i}' for i in range(10)],
|
| 113 |
-
'Precision': precision,
|
| 114 |
-
'Recall': recall,
|
| 115 |
-
'F1-Score': f1,
|
| 116 |
-
'Support': support
|
| 117 |
-
})
|
| 118 |
-
|
| 119 |
-
# Add overall metrics
|
| 120 |
-
precision_avg, recall_avg, f1_avg, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
|
| 121 |
-
|
| 122 |
-
overall = pd.DataFrame({
|
| 123 |
-
'Class': ['Overall (weighted)'],
|
| 124 |
-
'Precision': [precision_avg],
|
| 125 |
-
'Recall': [recall_avg],
|
| 126 |
-
'F1-Score': [f1_avg],
|
| 127 |
-
'Support': [len(y_true)]
|
| 128 |
-
})
|
| 129 |
-
|
| 130 |
-
df = pd.concat([df, overall], ignore_index=True)
|
| 131 |
-
|
| 132 |
-
return df
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
def plot_metrics_comparison(df_svd, df_cnn, filename):
|
| 136 |
-
"""Plot side-by-side comparison of metrics."""
|
| 137 |
-
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
|
| 138 |
-
|
| 139 |
-
metrics = ['Precision', 'Recall', 'F1-Score']
|
| 140 |
-
colors = [BLUE_DEEP, ORANGE]
|
| 141 |
-
|
| 142 |
-
# Exclude overall row for bar chart
|
| 143 |
-
df_svd_plot = df_svd.iloc[:-1]
|
| 144 |
-
df_cnn_plot = df_cnn.iloc[:-1]
|
| 145 |
-
|
| 146 |
-
x = np.arange(10)
|
| 147 |
-
width = 0.35
|
| 148 |
-
|
| 149 |
-
for i, metric in enumerate(metrics):
|
| 150 |
-
ax = axes[i]
|
| 151 |
-
|
| 152 |
-
svd_values = df_svd_plot[metric].values
|
| 153 |
-
cnn_values = df_cnn_plot[metric].values
|
| 154 |
-
|
| 155 |
-
ax.bar(x - width/2, svd_values, width, label='SVD+LR', color=colors[0], alpha=0.8)
|
| 156 |
-
ax.bar(x + width/2, cnn_values, width, label='CNN', color=colors[1], alpha=0.8)
|
| 157 |
-
|
| 158 |
-
ax.set_xlabel('Digit Class', fontsize=12)
|
| 159 |
-
ax.set_ylabel(metric, fontsize=12)
|
| 160 |
-
ax.set_title(f'{metric} by Digit', fontsize=13, fontweight='bold')
|
| 161 |
-
ax.set_xticks(x)
|
| 162 |
-
ax.set_xticklabels([str(i) for i in range(10)])
|
| 163 |
-
ax.legend()
|
| 164 |
-
ax.grid(axis='y', alpha=0.3)
|
| 165 |
-
ax.set_ylim([0.7, 1.0])
|
| 166 |
-
|
| 167 |
-
plt.tight_layout()
|
| 168 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
|
| 169 |
-
plt.close()
|
| 170 |
-
print(f"✓ Saved {filename}")
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def plot_metrics_heatmap(df_svd, df_cnn, filename):
|
| 174 |
-
"""Create heatmap showing per-class metrics."""
|
| 175 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))
|
| 176 |
-
|
| 177 |
-
# Prepare data (exclude overall row)
|
| 178 |
-
svd_data = df_svd.iloc[:-1][['Precision', 'Recall', 'F1-Score']].T
|
| 179 |
-
cnn_data = df_cnn.iloc[:-1][['Precision', 'Recall', 'F1-Score']].T
|
| 180 |
-
|
| 181 |
-
svd_data.columns = [str(i) for i in range(10)]
|
| 182 |
-
cnn_data.columns = [str(i) for i in range(10)]
|
| 183 |
-
|
| 184 |
-
# SVD heatmap
|
| 185 |
-
sns.heatmap(svd_data, annot=True, fmt='.3f', cmap='Blues',
|
| 186 |
-
vmin=0.7, vmax=1.0, ax=ax1, cbar_kws={'label': 'Score'})
|
| 187 |
-
ax1.set_title('SVD+LR Per-Class Metrics', fontsize=13, fontweight='bold')
|
| 188 |
-
ax1.set_xlabel('Digit Class', fontsize=12)
|
| 189 |
-
ax1.set_ylabel('Metric', fontsize=12)
|
| 190 |
-
|
| 191 |
-
# CNN heatmap
|
| 192 |
-
sns.heatmap(cnn_data, annot=True, fmt='.3f', cmap='Oranges',
|
| 193 |
-
vmin=0.7, vmax=1.0, ax=ax2, cbar_kws={'label': 'Score'})
|
| 194 |
-
ax2.set_title('CNN Per-Class Metrics', fontsize=13, fontweight='bold')
|
| 195 |
-
ax2.set_xlabel('Digit Class', fontsize=12)
|
| 196 |
-
ax2.set_ylabel('Metric', fontsize=12)
|
| 197 |
-
|
| 198 |
-
plt.tight_layout()
|
| 199 |
-
plt.savefig(os.path.join(config.RESULTS_DIR, filename), dpi=300, bbox_inches='tight')
|
| 200 |
-
plt.close()
|
| 201 |
-
print(f"✓ Saved {filename}")
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
def identify_hard_pairs(y_true, y_pred_svd, y_pred_cnn):
|
| 205 |
-
"""Identify digit pairs that are frequently confused."""
|
| 206 |
-
print("\n" + "="*60)
|
| 207 |
-
print("HARD CLASSIFICATION PAIRS")
|
| 208 |
-
print("="*60)
|
| 209 |
-
|
| 210 |
-
cm_svd = confusion_matrix(y_true, y_pred_svd, normalize='true')
|
| 211 |
-
cm_cnn = confusion_matrix(y_true, y_pred_cnn, normalize='true')
|
| 212 |
-
|
| 213 |
-
# Find top confusions (excluding diagonal)
|
| 214 |
-
np.fill_diagonal(cm_svd, 0)
|
| 215 |
-
np.fill_diagonal(cm_cnn, 0)
|
| 216 |
-
|
| 217 |
-
print("\nTop 5 SVD+LR Confusions:")
|
| 218 |
-
svd_confusions = []
|
| 219 |
-
for i in range(10):
|
| 220 |
-
for j in range(10):
|
| 221 |
-
if i != j and cm_svd[i, j] > 0.01:
|
| 222 |
-
svd_confusions.append((i, j, cm_svd[i, j]))
|
| 223 |
-
|
| 224 |
-
svd_confusions.sort(key=lambda x: x[2], reverse=True)
|
| 225 |
-
for i, (true_class, pred_class, rate) in enumerate(svd_confusions[:5], 1):
|
| 226 |
-
print(f" {i}. {true_class} → {pred_class}: {rate*100:.2f}%")
|
| 227 |
-
|
| 228 |
-
print("\nTop 5 CNN Confusions:")
|
| 229 |
-
cnn_confusions = []
|
| 230 |
-
for i in range(10):
|
| 231 |
-
for j in range(10):
|
| 232 |
-
if i != j and cm_cnn[i, j] > 0.01:
|
| 233 |
-
cnn_confusions.append((i, j, cm_cnn[i, j]))
|
| 234 |
-
|
| 235 |
-
cnn_confusions.sort(key=lambda x: x[2], reverse=True)
|
| 236 |
-
for i, (true_class, pred_class, rate) in enumerate(cnn_confusions[:5], 1):
|
| 237 |
-
print(f" {i}. {true_class} → {pred_class}: {rate*100:.2f}%")
|
| 238 |
-
|
| 239 |
-
# Compare improvements
|
| 240 |
-
print("\n" + "="*60)
|
| 241 |
-
print("CNN IMPROVEMENTS OVER SVD+LR")
|
| 242 |
-
print("="*60)
|
| 243 |
-
|
| 244 |
-
improvements = []
|
| 245 |
-
for i in range(10):
|
| 246 |
-
for j in range(10):
|
| 247 |
-
if i != j:
|
| 248 |
-
improvement = cm_svd[i, j] - cm_cnn[i, j]
|
| 249 |
-
if improvement > 0.01: # More than 1% improvement
|
| 250 |
-
improvements.append((i, j, improvement))
|
| 251 |
-
|
| 252 |
-
improvements.sort(key=lambda x: x[2], reverse=True)
|
| 253 |
-
|
| 254 |
-
print("Top pairs where CNN reduced confusion:")
|
| 255 |
-
for i, (true_class, pred_class, improvement) in enumerate(improvements[:5], 1):
|
| 256 |
-
svd_rate = cm_svd[true_class, pred_class]
|
| 257 |
-
cnn_rate = cm_cnn[true_class, pred_class]
|
| 258 |
-
print(f" {i}. {true_class} → {pred_class}: {svd_rate*100:.2f}% → {cnn_rate*100:.2f}% "
|
| 259 |
-
f"(Δ = -{improvement*100:.2f} pp)")
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
def save_reports_to_csv(df_svd, df_cnn):
|
| 263 |
-
"""Save detailed reports to CSV files."""
|
| 264 |
-
svd_path = os.path.join(config.RESULTS_DIR, 'per_class_metrics_svd.csv')
|
| 265 |
-
cnn_path = os.path.join(config.RESULTS_DIR, 'per_class_metrics_cnn.csv')
|
| 266 |
-
|
| 267 |
-
df_svd.to_csv(svd_path, index=False, float_format='%.4f')
|
| 268 |
-
df_cnn.to_csv(cnn_path, index=False, float_format='%.4f')
|
| 269 |
-
|
| 270 |
-
print(f"\n✓ Saved detailed metrics to:")
|
| 271 |
-
print(f" - {svd_path}")
|
| 272 |
-
print(f" - {cnn_path}")
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
def main():
|
| 276 |
-
set_seeds(SEED)
|
| 277 |
-
|
| 278 |
-
print("="*60)
|
| 279 |
-
print("PER-CLASS PERFORMANCE METRICS ANALYSIS")
|
| 280 |
-
print("="*60)
|
| 281 |
-
|
| 282 |
-
# Load test data
|
| 283 |
-
print("\nLoading MNIST test data...")
|
| 284 |
-
X_test, y_test = load_mnist()
|
| 285 |
-
print(f"Loaded {len(X_test)} test samples")
|
| 286 |
-
|
| 287 |
-
# Evaluate models
|
| 288 |
-
try:
|
| 289 |
-
print("\n" + "-"*60)
|
| 290 |
-
y_pred_svd = evaluate_svd(X_test, y_test)
|
| 291 |
-
print("✓ SVD+LR evaluation complete")
|
| 292 |
-
except Exception as e:
|
| 293 |
-
print(f"⚠️ Could not evaluate SVD model: {e}")
|
| 294 |
-
return
|
| 295 |
-
|
| 296 |
-
try:
|
| 297 |
-
print("\n" + "-"*60)
|
| 298 |
-
y_pred_cnn = evaluate_cnn(X_test, y_test)
|
| 299 |
-
print("✓ CNN evaluation complete")
|
| 300 |
-
except Exception as e:
|
| 301 |
-
print(f"⚠️ Could not evaluate CNN model: {e}")
|
| 302 |
-
return
|
| 303 |
-
|
| 304 |
-
# Generate metrics tables
|
| 305 |
-
print("\n" + "="*60)
|
| 306 |
-
print("GENERATING METRICS TABLES")
|
| 307 |
-
print("="*60)
|
| 308 |
-
|
| 309 |
-
df_svd = create_metrics_table(y_test, y_pred_svd, 'SVD+LR')
|
| 310 |
-
df_cnn = create_metrics_table(y_test, y_pred_cnn, 'CNN')
|
| 311 |
-
|
| 312 |
-
# Display tables
|
| 313 |
-
print("\n" + "-"*60)
|
| 314 |
-
print("SVD+LR Per-Class Metrics")
|
| 315 |
-
print("-"*60)
|
| 316 |
-
print(df_svd.to_string(index=False, float_format='%.4f'))
|
| 317 |
-
|
| 318 |
-
print("\n" + "-"*60)
|
| 319 |
-
print("CNN Per-Class Metrics")
|
| 320 |
-
print("-"*60)
|
| 321 |
-
print(df_cnn.to_string(index=False, float_format='%.4f'))
|
| 322 |
-
|
| 323 |
-
# Identify hard pairs
|
| 324 |
-
identify_hard_pairs(y_test, y_pred_svd, y_pred_cnn)
|
| 325 |
-
|
| 326 |
-
# Generate visualizations
|
| 327 |
-
print("\n" + "="*60)
|
| 328 |
-
print("GENERATING VISUALIZATIONS")
|
| 329 |
-
print("="*60)
|
| 330 |
-
|
| 331 |
-
plot_metrics_comparison(df_svd, df_cnn, 'fig_19_per_class_metrics_comparison.png')
|
| 332 |
-
plot_metrics_heatmap(df_svd, df_cnn, 'fig_20_per_class_metrics_heatmap.png')
|
| 333 |
-
|
| 334 |
-
# Save to CSV
|
| 335 |
-
save_reports_to_csv(df_svd, df_cnn)
|
| 336 |
-
|
| 337 |
-
# Summary statistics
|
| 338 |
-
print("\n" + "="*60)
|
| 339 |
-
print("SUMMARY STATISTICS")
|
| 340 |
-
print("="*60)
|
| 341 |
-
|
| 342 |
-
svd_overall = df_svd.iloc[-1]
|
| 343 |
-
cnn_overall = df_cnn.iloc[-1]
|
| 344 |
-
|
| 345 |
-
print(f"\nSVD+LR Overall:")
|
| 346 |
-
print(f" Precision: {svd_overall['Precision']:.4f}")
|
| 347 |
-
print(f" Recall: {svd_overall['Recall']:.4f}")
|
| 348 |
-
print(f" F1-Score: {svd_overall['F1-Score']:.4f}")
|
| 349 |
-
|
| 350 |
-
print(f"\nCNN Overall:")
|
| 351 |
-
print(f" Precision: {cnn_overall['Precision']:.4f}")
|
| 352 |
-
print(f" Recall: {cnn_overall['Recall']:.4f}")
|
| 353 |
-
print(f" F1-Score: {cnn_overall['F1-Score']:.4f}")
|
| 354 |
-
|
| 355 |
-
print(f"\nCNN Improvement:")
|
| 356 |
-
print(f" Precision: +{(cnn_overall['Precision'] - svd_overall['Precision'])*100:.2f} pp")
|
| 357 |
-
print(f" Recall: +{(cnn_overall['Recall'] - svd_overall['Recall'])*100:.2f} pp")
|
| 358 |
-
print(f" F1-Score: +{(cnn_overall['F1-Score'] - svd_overall['F1-Score'])*100:.2f} pp")
|
| 359 |
-
|
| 360 |
-
print("\n" + "="*60)
|
| 361 |
-
print("✓ Per-class metrics analysis complete!")
|
| 362 |
-
print("="*60)
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
if __name__ == "__main__":
|
| 366 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/appendix_learning_curves.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Appendix A – Learning Curves
|
| 3 |
+
Refactored to use centralized viz utilities.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pickle
|
| 7 |
+
import os
|
| 8 |
+
from src import config, viz
|
| 9 |
+
|
| 10 |
+
def main():
|
| 11 |
+
experiments = [
|
| 12 |
+
('cnn_10class_history.pkl', 'MNIST 10-class CNN Training', 'fig_14_learning_curves.png'),
|
| 13 |
+
('cnn_fashion_history.pkl', 'Fashion-MNIST CNN Training', 'fig_15_learning_curves_fashion.png')
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
for f_name, label, out_name in experiments:
|
| 17 |
+
path = os.path.join(config.MODELS_DIR, f_name)
|
| 18 |
+
if os.path.exists(path):
|
| 19 |
+
with open(path, 'rb') as f:
|
| 20 |
+
history = pickle.load(f)
|
| 21 |
+
viz.plot_learning_curves(history, label, out_name)
|
| 22 |
+
else:
|
| 23 |
+
print(f"Skipping {f_name}: Not found at {path}.")
|
| 24 |
+
|
| 25 |
+
if __name__ == "__main__":
|
| 26 |
+
main()
|
experiments/appendix_per_class_metrics.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Appendix B – Per-Class Performance Metrics (MNIST)
|
| 3 |
+
Refactored to use centralized utility modules.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from src import utils, viz, exp_utils
|
| 9 |
+
|
| 10 |
+
def main():
|
| 11 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 12 |
+
print("Loading Models and Test Data...")
|
| 13 |
+
|
| 14 |
+
# Load Models (MNIST default)
|
| 15 |
+
svd_pipe, cnn = utils.load_models(dataset_name="mnist")
|
| 16 |
+
if svd_pipe is None or cnn is None:
|
| 17 |
+
return
|
| 18 |
+
|
| 19 |
+
X_test, y_test = utils.load_data_split(dataset_name="mnist", train=False)
|
| 20 |
+
X_test_flat = X_test.view(X_test.size(0), -1).numpy()
|
| 21 |
+
y_test_np = y_test.numpy()
|
| 22 |
+
|
| 23 |
+
# 1. Collect Predictions
|
| 24 |
+
print("Collecting Predictions...")
|
| 25 |
+
y_preds_dict = {}
|
| 26 |
+
|
| 27 |
+
# CNN Predictions
|
| 28 |
+
cnn.eval()
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
y_preds_dict['CNN'] = cnn(X_test.to(device)).argmax(dim=1).cpu().numpy()
|
| 31 |
+
|
| 32 |
+
# SVD+LR Predictions
|
| 33 |
+
y_preds_dict['SVD+LR'] = svd_pipe.predict(X_test_flat)
|
| 34 |
+
|
| 35 |
+
# 2. Print Metrics Report
|
| 36 |
+
from sklearn.metrics import recall_score, precision_score, f1_score
|
| 37 |
+
for name, y_pred in y_preds_dict.items():
|
| 38 |
+
print(f"\n--- {name} Report (Average Metrics) ---")
|
| 39 |
+
p = precision_score(y_test_np, y_pred, average='macro')
|
| 40 |
+
r = recall_score(y_test_np, y_pred, average='macro')
|
| 41 |
+
f = f1_score(y_test_np, y_pred, average='macro')
|
| 42 |
+
print(f"Macro Average: Precision={p:.3f}, Recall={r:.3f}, F1={f:.3f}")
|
| 43 |
+
|
| 44 |
+
# 3. Visualization: Per-Class F1 Comparison
|
| 45 |
+
viz.plot_per_class_comparison(
|
| 46 |
+
y_test_np,
|
| 47 |
+
y_preds_dict,
|
| 48 |
+
'fig_19_per_class_metrics_comparison.png'
|
| 49 |
+
)
|
| 50 |
+
print("Appendix B Completed.")
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
|
| 53 |
+
main()
|
experiments/run_robustness_test.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Robustness Test Script
|
| 3 |
+
Evaluates CNN, SVD, and Hybrid model performance under Gaussian noise.
|
| 4 |
+
Refactored to use centralized src utilities.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from src import config, utils, viz, exp_utils
|
| 11 |
+
from src.hybrid_model import HybridSVDCNN, SVDProjectionLayer
|
| 12 |
+
|
| 13 |
+
def run_experiment(args):
|
| 14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
+
print(f"\n--- Running Robustness Test: {args.dataset.upper()} ---")
|
| 16 |
+
|
| 17 |
+
# 1. Load Data and Models
|
| 18 |
+
X_test, y_test = utils.load_data_split(dataset_name=args.dataset, train=False)
|
| 19 |
+
_, cnn = utils.load_models(dataset_name=args.dataset)
|
| 20 |
+
|
| 21 |
+
if cnn is None:
|
| 22 |
+
return
|
| 23 |
+
|
| 24 |
+
# 2. Fit SVD Baseline and Build Hybrid Model
|
| 25 |
+
print("Fitting SVD Baseline...")
|
| 26 |
+
X_test_flat = X_test.view(X_test.size(0), -1).numpy()
|
| 27 |
+
svd_pipe = exp_utils.fit_svd_baseline(X_test_flat, y_test.numpy(), n_components=20)
|
| 28 |
+
|
| 29 |
+
svd = svd_pipe.named_steps['svd']
|
| 30 |
+
scaler = svd_pipe.named_steps['scaler']
|
| 31 |
+
# Hybrid model expects mean from scaler if available
|
| 32 |
+
svd_layer = SVDProjectionLayer(svd.components_, scaler.mean_)
|
| 33 |
+
hybrid = HybridSVDCNN(svd_layer, cnn).to(device)
|
| 34 |
+
|
| 35 |
+
# 3. Define Noise Levels
|
| 36 |
+
sigmas = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
| 37 |
+
results = {'CNN': [], 'SVD': [], 'Hybrid': []}
|
| 38 |
+
|
| 39 |
+
# 4. Evaluation Loop
|
| 40 |
+
for sigma in sigmas:
|
| 41 |
+
X_noisy = exp_utils.add_gaussian_noise(X_test, sigma)
|
| 42 |
+
|
| 43 |
+
results['CNN'].append(exp_utils.evaluate_classifier(cnn, X_noisy, y_test, device))
|
| 44 |
+
results['SVD'].append(exp_utils.evaluate_classifier(svd_pipe, X_noisy, y_test, is_pytorch=False))
|
| 45 |
+
results['Hybrid'].append(exp_utils.evaluate_classifier(hybrid, X_noisy, y_test, device))
|
| 46 |
+
|
| 47 |
+
print(f"σ={sigma:.1f} | CNN: {results['CNN'][-1]:.4f} | SVD: {results['SVD'][-1]:.4f} | Hybrid: {results['Hybrid'][-1]:.4f}")
|
| 48 |
+
|
| 49 |
+
# 5. Visualization
|
| 50 |
+
viz.plot_robustness_curves(
|
| 51 |
+
x_values=sigmas,
|
| 52 |
+
results_dict=results,
|
| 53 |
+
x_label='Gaussian Noise Level (σ)',
|
| 54 |
+
title=f'Robustness Analysis: {args.dataset.upper()}',
|
| 55 |
+
filename=f'fig_robustness_{args.dataset}.png'
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def main():
|
| 59 |
+
parser = argparse.ArgumentParser(description="Unified Robustness Evaluation")
|
| 60 |
+
parser.add_argument("--dataset", choices=["mnist", "fashion"], default="mnist", help="Dataset to evaluate.")
|
| 61 |
+
args = parser.parse_args()
|
| 62 |
+
run_experiment(args)
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
main()
|
src/__init__.py
ADDED
|
File without changes
|
src/config.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# --- Paths ---
|
| 4 |
+
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 5 |
+
DATA_DIR = os.path.join(BASE_DIR, "data")
|
| 6 |
+
MODELS_DIR = os.path.join(BASE_DIR, "models")
|
| 7 |
+
RESULTS_DIR = os.path.join(BASE_DIR, "docs", "research_results")
|
| 8 |
+
|
| 9 |
+
for d in [DATA_DIR, MODELS_DIR, RESULTS_DIR]:
|
| 10 |
+
os.makedirs(d, exist_ok=True)
|
| 11 |
+
|
| 12 |
+
SVD_MODEL_PATH = os.path.join(MODELS_DIR, "svd_10class.pkl")
|
| 13 |
+
CNN_MODEL_PATH = os.path.join(MODELS_DIR, "cnn_10class.pth")
|
| 14 |
+
FASHION_SVD_PATH = os.path.join(MODELS_DIR, "svd_fashion.pkl") # Placeholder if not exists
|
| 15 |
+
FASHION_CNN_PATH = os.path.join(MODELS_DIR, "cnn_fashion.pth")
|
| 16 |
+
APP_CACHE_PATH = os.path.join(DATA_DIR, "app_cache.npz")
|
| 17 |
+
|
src/exp_utils.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
from sklearn.metrics import accuracy_score
|
| 5 |
+
from sklearn.decomposition import TruncatedSVD
|
| 6 |
+
from sklearn.linear_model import LogisticRegression
|
| 7 |
+
from sklearn.pipeline import Pipeline
|
| 8 |
+
from sklearn.preprocessing import StandardScaler
|
| 9 |
+
|
| 10 |
+
def fit_svd_baseline(X_train, y_train, n_components=20):
|
| 11 |
+
"""Fits a linear baseline (SVD + Logistic Regression) on the fly."""
|
| 12 |
+
pipeline = Pipeline([
|
| 13 |
+
('scaler', StandardScaler()),
|
| 14 |
+
('svd', TruncatedSVD(n_components=n_components, random_state=42)),
|
| 15 |
+
('logistic', LogisticRegression(max_iter=1000))
|
| 16 |
+
])
|
| 17 |
+
pipeline.fit(X_train, y_train)
|
| 18 |
+
return pipeline
|
| 19 |
+
|
| 20 |
+
def add_gaussian_noise(X, sigma):
|
| 21 |
+
"""
|
| 22 |
+
Uniform noise addition for both torch Tensors and numpy arrays.
|
| 23 |
+
Returns the same type as input.
|
| 24 |
+
"""
|
| 25 |
+
if sigma <= 0: return X
|
| 26 |
+
if torch.is_tensor(X):
|
| 27 |
+
noise = torch.randn_like(X) * sigma
|
| 28 |
+
return torch.clamp(X + noise, 0, 1)
|
| 29 |
+
else:
|
| 30 |
+
noise = np.random.randn(*X.shape) * sigma
|
| 31 |
+
return np.clip(X + noise, 0, 1)
|
| 32 |
+
|
| 33 |
+
def add_blur(X, kernel_size):
|
| 34 |
+
"""Unified blur for torch Tensors (4D: B, C, H, W)."""
|
| 35 |
+
if kernel_size <= 1:
|
| 36 |
+
return X
|
| 37 |
+
sigma = 0.1 + 0.3 * (kernel_size // 2)
|
| 38 |
+
blur_fn = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma))
|
| 39 |
+
return blur_fn(X)
|
| 40 |
+
|
| 41 |
+
def evaluate_classifier(model, X, y, device="cpu", is_pytorch=True):
|
| 42 |
+
"""
|
| 43 |
+
Unified evaluation function.
|
| 44 |
+
Handles PyTorch models (CNN, Hybrid) and Sklearn pipelines (SVD+LR).
|
| 45 |
+
"""
|
| 46 |
+
if is_pytorch:
|
| 47 |
+
model.eval()
|
| 48 |
+
model.to(device)
|
| 49 |
+
# Ensure X is 4D for CNN (B, 1, 28, 28)
|
| 50 |
+
if len(X.shape) == 2:
|
| 51 |
+
X_t = torch.as_tensor(X.reshape(-1, 1, 28, 28), dtype=torch.float32).to(device)
|
| 52 |
+
else:
|
| 53 |
+
X_t = torch.as_tensor(X, dtype=torch.float32).to(device)
|
| 54 |
+
|
| 55 |
+
y_t = torch.as_tensor(y, dtype=torch.long).to(device)
|
| 56 |
+
|
| 57 |
+
with torch.no_grad():
|
| 58 |
+
logits = model(X_t)
|
| 59 |
+
preds = torch.argmax(logits, dim=1).cpu().numpy()
|
| 60 |
+
return accuracy_score(y, preds)
|
| 61 |
+
else:
|
| 62 |
+
# Sklearn pipeline - Ensure X is flattened 2D numpy
|
| 63 |
+
if torch.is_tensor(X):
|
| 64 |
+
X_np = X.view(X.size(0), -1).cpu().numpy()
|
| 65 |
+
else:
|
| 66 |
+
X_np = X.reshape(X.shape[0], -1)
|
| 67 |
+
preds = model.predict(X_np)
|
| 68 |
+
return accuracy_score(y, preds)
|
src/hybrid_model.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hybrid SVD-CNN Model
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import numpy as np
|
| 6 |
+
from sklearn.decomposition import TruncatedSVD
|
| 7 |
+
|
| 8 |
+
class SimpleCNN(nn.Module):
|
| 9 |
+
def __init__(self, num_classes=10):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.features = nn.Sequential(
|
| 12 |
+
nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
|
| 13 |
+
nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
|
| 14 |
+
)
|
| 15 |
+
self.classifier = nn.Sequential(
|
| 16 |
+
nn.Linear(32 * 7 * 7, 128), nn.ReLU(),
|
| 17 |
+
nn.Linear(128, num_classes)
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
def forward(self, x):
|
| 21 |
+
return self.classifier(self.features(x).view(x.size(0), -1))
|
| 22 |
+
|
| 23 |
+
class SVDProjectionLayer(nn.Module):
|
| 24 |
+
def __init__(self, V_k, mean=None):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.register_buffer('V_k', torch.tensor(V_k, dtype=torch.float32))
|
| 27 |
+
self.register_buffer('mean', torch.tensor(mean, dtype=torch.float32) if mean is not None else torch.zeros(V_k.shape[1]))
|
| 28 |
+
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
b = x.size(0)
|
| 31 |
+
x_rec = (x.view(b, -1) - self.mean) @ self.V_k.T @ self.V_k + self.mean
|
| 32 |
+
return torch.clamp(x_rec, 0, 1).view(b, 1, 28, 28)
|
| 33 |
+
|
| 34 |
+
class HybridSVDCNN(nn.Module):
|
| 35 |
+
def __init__(self, svd_layer, cnn):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.svd_layer, self.cnn = svd_layer, cnn
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
return self.cnn(self.svd_layer(x))
|
| 41 |
+
|
| 42 |
+
def create_svd_layer(X_train, n_components=20):
|
| 43 |
+
mean = np.mean(X_train, axis=0)
|
| 44 |
+
svd = TruncatedSVD(n_components=n_components, random_state=42).fit(X_train - mean)
|
| 45 |
+
return SVDProjectionLayer(svd.components_, mean)
|
src/train_fashion.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
from src.hybrid_model import SimpleCNN
|
| 5 |
+
from src import config
|
| 6 |
+
from src.train_models import train_cnn, set_seed
|
| 7 |
+
|
| 8 |
+
if __name__ == "__main__":
|
| 9 |
+
set_seed()
|
| 10 |
+
print("Loading Fashion-MNIST for training...")
|
| 11 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
| 12 |
+
train_dataset = torchvision.datasets.FashionMNIST(root=config.DATA_DIR, train=True, download=True, transform=transform)
|
| 13 |
+
|
| 14 |
+
# Extract data to tensors for train_cnn
|
| 15 |
+
X = train_dataset.data.float() / 255.0
|
| 16 |
+
y = train_dataset.targets
|
| 17 |
+
|
| 18 |
+
# Temporarily override CNN_MODEL_PATH for fashion
|
| 19 |
+
original_path = config.CNN_MODEL_PATH
|
| 20 |
+
config.CNN_MODEL_PATH = config.CNN_FASHION_MODEL_PATH
|
| 21 |
+
|
| 22 |
+
print(f"Retraining Fashion-MNIST model to {config.CNN_FASHION_MODEL_PATH}...")
|
| 23 |
+
train_cnn(X, y)
|
| 24 |
+
|
| 25 |
+
# Restore (optional but good practice)
|
| 26 |
+
config.CNN_MODEL_PATH = original_path
|
| 27 |
+
print("Fashion-MNIST training completed.")
|
src/train_models.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from torch.utils.data import TensorDataset, DataLoader
|
| 5 |
+
from sklearn.decomposition import TruncatedSVD
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
import pickle
|
| 8 |
+
import os
|
| 9 |
+
import numpy as np
|
| 10 |
+
import random
|
| 11 |
+
from src.hybrid_model import SimpleCNN
|
| 12 |
+
from src.utils import load_data_split
|
| 13 |
+
from src import config
|
| 14 |
+
|
| 15 |
+
def set_seed(seed=42):
|
| 16 |
+
random.seed(seed); np.random.seed(seed)
|
| 17 |
+
torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
|
| 18 |
+
torch.backends.cudnn.deterministic = True
|
| 19 |
+
|
| 20 |
+
def train_svd(X_flat, n_components=20):
|
| 21 |
+
print(f"Training SVD (k={n_components})...")
|
| 22 |
+
X_np = X_flat.numpy()
|
| 23 |
+
mean = X_np.mean(axis=0)
|
| 24 |
+
svd = TruncatedSVD(n_components=n_components, random_state=42).fit(X_np - mean)
|
| 25 |
+
svd._train_mean = mean
|
| 26 |
+
with open(config.SVD_MODEL_PATH, "wb") as f: pickle.dump(svd, f)
|
| 27 |
+
return svd
|
| 28 |
+
|
| 29 |
+
def train_cnn(X_flat, y, batch_size=64, epochs=5):
|
| 30 |
+
X_train, X_val, y_train, y_val = train_test_split(X_flat.numpy(), y.numpy(), test_size=0.2, random_state=42, stratify=y.numpy())
|
| 31 |
+
|
| 32 |
+
def to_loader(X, y, shuffle=True):
|
| 33 |
+
return DataLoader(TensorDataset(torch.tensor(X).view(-1, 1, 28, 28), torch.tensor(y, dtype=torch.long)), batch_size=batch_size, shuffle=shuffle)
|
| 34 |
+
|
| 35 |
+
train_loader, val_loader = to_loader(X_train, y_train), to_loader(X_val, y_val, False)
|
| 36 |
+
model = SimpleCNN().to("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
+
opt = optim.Adam(model.parameters(), lr=0.001)
|
| 38 |
+
crit = nn.CrossEntropyLoss()
|
| 39 |
+
|
| 40 |
+
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
|
| 41 |
+
best_acc, best_state = 0, None
|
| 42 |
+
|
| 43 |
+
for epoch in range(epochs):
|
| 44 |
+
model.train()
|
| 45 |
+
t_loss, t_corr = 0, 0
|
| 46 |
+
for x, labels in train_loader:
|
| 47 |
+
x, labels = x.to(next(model.parameters()).device), labels.to(next(model.parameters()).device)
|
| 48 |
+
opt.zero_grad(); out = model(x); loss = crit(out, labels); loss.backward(); opt.step()
|
| 49 |
+
t_loss += loss.item(); t_corr += (out.argmax(1) == labels).sum().item()
|
| 50 |
+
|
| 51 |
+
model.eval(); v_loss, v_corr = 0, 0
|
| 52 |
+
with torch.no_grad():
|
| 53 |
+
for x, labels in val_loader:
|
| 54 |
+
x, labels = x.to(next(model.parameters()).device), labels.to(next(model.parameters()).device)
|
| 55 |
+
out = model(x); v_loss += crit(out, labels).item(); v_corr += (out.argmax(1) == labels).sum().item()
|
| 56 |
+
|
| 57 |
+
history['train_acc'].append(100 * t_corr / len(X_train)); history['val_acc'].append(100 * v_corr / len(X_val))
|
| 58 |
+
print(f"Epoch {epoch+1}: Train Acc {history['train_acc'][-1]:.2f}%, Val Acc {history['val_acc'][-1]:.2f}%")
|
| 59 |
+
|
| 60 |
+
if history['val_acc'][-1] > best_acc:
|
| 61 |
+
best_acc, best_state = history['val_acc'][-1], model.state_dict().copy()
|
| 62 |
+
|
| 63 |
+
model.load_state_dict(best_state)
|
| 64 |
+
torch.save(model.cpu().state_dict(), config.CNN_MODEL_PATH)
|
| 65 |
+
with open(config.CNN_MODEL_PATH.replace('.pth', '_history.pkl'), 'wb') as f: pickle.dump(history, f)
|
| 66 |
+
return model, history
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
set_seed()
|
| 70 |
+
X, y = load_data_split()
|
| 71 |
+
train_svd(X.view(-1, 784))
|
| 72 |
+
train_cnn(X.view(-1, 784), y)
|
src/utils.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision.transforms as T
|
| 3 |
+
import torchvision
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import pickle
|
| 7 |
+
import ssl
|
| 8 |
+
from src.hybrid_model import SimpleCNN
|
| 9 |
+
from src import config
|
| 10 |
+
import cv2
|
| 11 |
+
|
| 12 |
+
def load_data_split(dataset_name="mnist", train=True, digits=None, flatten=False):
|
| 13 |
+
"""
|
| 14 |
+
Unified entry point for data loading.
|
| 15 |
+
Supports: MNIST, Fashion-MNIST, and custom digit filtering (e.g., [3, 8]).
|
| 16 |
+
"""
|
| 17 |
+
# Bypass SSL verification issues for dataset downloads
|
| 18 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
| 19 |
+
|
| 20 |
+
transform = T.Compose([T.ToTensor()])
|
| 21 |
+
|
| 22 |
+
if dataset_name.lower() == "mnist":
|
| 23 |
+
dataset = torchvision.datasets.MNIST(config.DATA_DIR, train=train, download=True, transform=transform)
|
| 24 |
+
elif dataset_name.lower() == "fashion":
|
| 25 |
+
dataset = torchvision.datasets.FashionMNIST(config.DATA_DIR, train=train, download=True, transform=transform)
|
| 26 |
+
else:
|
| 27 |
+
raise ValueError(f"Unknown dataset: {dataset_name}")
|
| 28 |
+
|
| 29 |
+
X = dataset.data.float() / 255.0
|
| 30 |
+
y = dataset.targets
|
| 31 |
+
|
| 32 |
+
# Filter for specific digits if requested (e.g., [3, 8] for binary analysis)
|
| 33 |
+
if digits is not None:
|
| 34 |
+
mask = torch.zeros(len(y), dtype=torch.bool)
|
| 35 |
+
for d in digits:
|
| 36 |
+
mask |= (y == d)
|
| 37 |
+
X = X[mask]
|
| 38 |
+
y = y[mask]
|
| 39 |
+
|
| 40 |
+
# Remap labels to 0, 1... for binary tasks
|
| 41 |
+
if len(digits) == 2:
|
| 42 |
+
y = torch.where(y == digits[0], torch.tensor(0), torch.tensor(1))
|
| 43 |
+
|
| 44 |
+
# Add channel dimension if not flattened (B, 1, 28, 28)
|
| 45 |
+
if not flatten:
|
| 46 |
+
X = X.unsqueeze(1)
|
| 47 |
+
else:
|
| 48 |
+
X = X.view(X.size(0), -1)
|
| 49 |
+
|
| 50 |
+
return X, y
|
| 51 |
+
|
| 52 |
+
def load_models(dataset_name="mnist"):
|
| 53 |
+
"""
|
| 54 |
+
Loads pre-trained SVD transformer and CNN model for a specific dataset.
|
| 55 |
+
Returns (svd, cnn). Either can be None if the file is missing.
|
| 56 |
+
"""
|
| 57 |
+
svd_path = config.SVD_MODEL_PATH if dataset_name == "mnist" else config.FASHION_SVD_PATH
|
| 58 |
+
cnn_path = config.CNN_MODEL_PATH if dataset_name == "mnist" else config.FASHION_CNN_PATH
|
| 59 |
+
|
| 60 |
+
svd, cnn = None, None
|
| 61 |
+
|
| 62 |
+
if os.path.exists(svd_path):
|
| 63 |
+
with open(svd_path, "rb") as f:
|
| 64 |
+
svd = pickle.load(f)
|
| 65 |
+
else:
|
| 66 |
+
print(f"Note: SVD model for {dataset_name} not found at {svd_path}")
|
| 67 |
+
|
| 68 |
+
if os.path.exists(cnn_path):
|
| 69 |
+
cnn = SimpleCNN()
|
| 70 |
+
cnn.load_state_dict(torch.load(cnn_path, map_location="cpu"))
|
| 71 |
+
cnn.eval()
|
| 72 |
+
else:
|
| 73 |
+
print(f"Note: CNN model for {dataset_name} not found at {cnn_path}")
|
| 74 |
+
|
| 75 |
+
return svd, cnn
|
| 76 |
+
|
| 77 |
+
# --- Backward Compatibility Aliases ---
|
| 78 |
+
load_data = load_data_split
|
| 79 |
+
|
| 80 |
+
def preprocess_digit(img):
|
| 81 |
+
"""
|
| 82 |
+
Original preprocessing logic used by the Streamlit app.
|
| 83 |
+
Crops, resizes (20x20), and pads to 28x28.
|
| 84 |
+
"""
|
| 85 |
+
if isinstance(img, torch.Tensor):
|
| 86 |
+
img = img.numpy().astype(np.uint8)
|
| 87 |
+
|
| 88 |
+
# 1. Threshold & Find Bounding Box
|
| 89 |
+
_, thresh = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
|
| 90 |
+
coords = cv2.findNonZero(thresh)
|
| 91 |
+
if coords is None:
|
| 92 |
+
return torch.zeros((28, 28))
|
| 93 |
+
x, y, w, h = cv2.boundingRect(coords)
|
| 94 |
+
img_crop = img[y:y+h, x:x+w]
|
| 95 |
+
|
| 96 |
+
# 2. Resize to fit 20px
|
| 97 |
+
if w > h:
|
| 98 |
+
new_w = 20
|
| 99 |
+
new_h = int(h * (20 / w))
|
| 100 |
+
else:
|
| 101 |
+
new_h = 20
|
| 102 |
+
new_w = int(w * (20 / h))
|
| 103 |
+
|
| 104 |
+
if new_w == 0 or new_h == 0:
|
| 105 |
+
return torch.zeros((28, 28))
|
| 106 |
+
|
| 107 |
+
img_resize = cv2.resize(img_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 108 |
+
|
| 109 |
+
# 3. Center in 28x28
|
| 110 |
+
final_img = np.zeros((28, 28), dtype=np.uint8)
|
| 111 |
+
pad_y = (28 - new_h) // 2
|
| 112 |
+
pad_x = (28 - new_w) // 2
|
| 113 |
+
final_img[pad_y:pad_y+new_h, pad_x:pad_x+new_w] = img_resize
|
| 114 |
+
|
| 115 |
+
# 4. Normalize
|
| 116 |
+
return torch.tensor(final_img).float() / 255.0
|
src/viz.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 6 |
+
from src import config
|
| 7 |
+
|
| 8 |
+
# --- Nord Palette Colors ---
|
| 9 |
+
COLOR_SVD = "#5E81AC" # Nord 10 (Blue)
|
| 10 |
+
COLOR_CNN = "#BF616A" # Nord 11 (Red)
|
| 11 |
+
COLOR_HYBRID = "#A3BE8C" # Nord 14 (Green)
|
| 12 |
+
COLOR_TEXT = "#2E3440" # Nord 0 (Dark)
|
| 13 |
+
COLOR_GRID = "#D8DEE9" # Nord 4
|
| 14 |
+
|
| 15 |
+
def setup_style():
|
| 16 |
+
"""Standardize matplotlib plots."""
|
| 17 |
+
plt.rcParams['font.family'] = 'sans-serif'
|
| 18 |
+
plt.rcParams['axes.edgecolor'] = COLOR_GRID
|
| 19 |
+
plt.rcParams['grid.alpha'] = 0.3
|
| 20 |
+
plt.rcParams['axes.labelcolor'] = COLOR_TEXT
|
| 21 |
+
|
| 22 |
+
def save_fig(filename, dpi=300):
|
| 23 |
+
"""Save plot to results directory."""
|
| 24 |
+
path = os.path.join(config.RESULTS_DIR, filename)
|
| 25 |
+
plt.tight_layout()
|
| 26 |
+
plt.savefig(path, dpi=dpi)
|
| 27 |
+
plt.close()
|
| 28 |
+
print(f"Figure saved to {path}")
|
| 29 |
+
|
| 30 |
+
def plot_robustness_curves(x_values, results_dict, x_label, title, filename):
|
| 31 |
+
"""Standardized robustness curve plotter."""
|
| 32 |
+
setup_style()
|
| 33 |
+
plt.figure(figsize=(10, 6))
|
| 34 |
+
colors = {'CNN': COLOR_CNN, 'SVD': COLOR_SVD, 'Hybrid': COLOR_HYBRID}
|
| 35 |
+
|
| 36 |
+
for label, accs in results_dict.items():
|
| 37 |
+
plt.plot(x_values, accs, label=label, marker='o',
|
| 38 |
+
color=colors.get(label, '#4C566A'), linewidth=2)
|
| 39 |
+
|
| 40 |
+
plt.title(title, fontsize=14, fontweight='bold', pad=15)
|
| 41 |
+
plt.xlabel(x_label, fontsize=12)
|
| 42 |
+
plt.ylabel('Accuracy', fontsize=12)
|
| 43 |
+
plt.legend(frameon=True, facecolor='white', framealpha=0.8)
|
| 44 |
+
plt.grid(True)
|
| 45 |
+
save_fig(filename)
|
| 46 |
+
|
| 47 |
+
def plot_confusion_matrix(y_true, y_pred, labels, filename, title, color_end=COLOR_SVD):
|
| 48 |
+
"""Normalized confusion matrix with Nord-consistent coloring."""
|
| 49 |
+
from sklearn.metrics import confusion_matrix
|
| 50 |
+
setup_style()
|
| 51 |
+
cm = confusion_matrix(y_true, y_pred, normalize='true')
|
| 52 |
+
plt.figure(figsize=(10, 8))
|
| 53 |
+
|
| 54 |
+
# Custom cmap from Light Gray to Nord Color
|
| 55 |
+
cmap = LinearSegmentedColormap.from_list("NordCustom", ["#ECEFF4", color_end])
|
| 56 |
+
|
| 57 |
+
sns.heatmap(cm, annot=True, fmt='.1%', cmap=cmap, xticklabels=labels, yticklabels=labels)
|
| 58 |
+
plt.title(title, fontsize=14, fontweight='bold', pad=15)
|
| 59 |
+
plt.xlabel('Predicted', fontsize=12)
|
| 60 |
+
plt.ylabel('True', fontsize=12)
|
| 61 |
+
save_fig(filename)
|
| 62 |
+
|
| 63 |
+
def plot_singular_spectrum(singular_values, cumulative_variance, filename):
|
| 64 |
+
"""Visualizes singular values and explained variance."""
|
| 65 |
+
setup_style()
|
| 66 |
+
fig, ax1 = plt.subplots(figsize=(10, 6))
|
| 67 |
+
|
| 68 |
+
n = len(singular_values)
|
| 69 |
+
ax1.semilogy(range(1, n+1), singular_values, color=COLOR_SVD, label='Singular Values', linewidth=2)
|
| 70 |
+
ax1.set_xlabel('Principal Component (k)', fontsize=12)
|
| 71 |
+
ax1.set_ylabel('Singular Value (Log)', color=COLOR_SVD, fontsize=12)
|
| 72 |
+
ax1.tick_params(axis='y', labelcolor=COLOR_SVD)
|
| 73 |
+
|
| 74 |
+
ax2 = ax1.twinx()
|
| 75 |
+
ax2.plot(range(1, n+1), cumulative_variance, color=COLOR_CNN, linestyle='--', label='Cum. Var', linewidth=2)
|
| 76 |
+
ax2.set_ylabel('Cumulative Explained Variance', color=COLOR_CNN, fontsize=12)
|
| 77 |
+
ax2.tick_params(axis='y', labelcolor=COLOR_CNN)
|
| 78 |
+
ax2.set_ylim(0, 1.05)
|
| 79 |
+
|
| 80 |
+
plt.title('Singular Value Spectrum & Explained Variance', fontsize=14, fontweight='bold', pad=15)
|
| 81 |
+
fig.legend(loc="upper right", bbox_to_anchor=(1,1), bbox_transform=ax1.transAxes)
|
| 82 |
+
save_fig(filename)
|
| 83 |
+
|
| 84 |
+
def plot_interpolation_dynamics(alphas, probs_8, rec_errors, filename):
|
| 85 |
+
"""Visualizes the CNN response vs SVD reconstruction error during interpolation."""
|
| 86 |
+
setup_style()
|
| 87 |
+
plt.figure(figsize=(10, 6))
|
| 88 |
+
|
| 89 |
+
plt.plot(alphas, probs_8, color=COLOR_CNN, label='CNN Prob(8) [Topology]', marker='o', linewidth=2)
|
| 90 |
+
plt.plot(alphas, rec_errors, color=COLOR_SVD, label='SVD Rec Error [Global Variance]', marker='s', linewidth=2)
|
| 91 |
+
|
| 92 |
+
plt.axvline(x=0.5, color='#4C566A', linestyle='--', alpha=0.5, label='Ambiguity Mid-point')
|
| 93 |
+
plt.title('Mechanistic Dynamics: Interpolation vs. SVD Error', fontsize=14, fontweight='bold', pad=15)
|
| 94 |
+
plt.xlabel('Alpha (0=Digit 3, 1=Digit 8)', fontsize=12)
|
| 95 |
+
plt.ylabel('Metric Value', fontsize=12)
|
| 96 |
+
plt.legend()
|
| 97 |
+
plt.grid(True)
|
| 98 |
+
save_fig(filename)
|
| 99 |
+
|
| 100 |
+
def plot_manifold_comparison(X_svd, X_umap, y, acc_svd, acc_raw, filename):
|
| 101 |
+
"""Side-by-side comparison of SVD (linear) vs UMAP (non-linear) projections."""
|
| 102 |
+
setup_style()
|
| 103 |
+
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
|
| 104 |
+
|
| 105 |
+
colors = [COLOR_SVD, COLOR_CNN] # 3 vs 8
|
| 106 |
+
labels = ['Digit 3', 'Digit 8']
|
| 107 |
+
|
| 108 |
+
for i in range(2):
|
| 109 |
+
# SVD Plane
|
| 110 |
+
axes[0].scatter(X_svd[y==i, 0], X_svd[y==i, 1], label=labels[i], alpha=0.5, s=15, color=colors[i])
|
| 111 |
+
# UMAP Manifold
|
| 112 |
+
axes[1].scatter(X_umap[y==i, 0], X_umap[y==i, 1], label=labels[i], alpha=0.5, s=15, color=colors[i])
|
| 113 |
+
|
| 114 |
+
axes[0].set_title(f"SVD Projection (2D Subspace)\nk-NN Accuracy: {acc_svd:.2%}", fontsize=12)
|
| 115 |
+
axes[1].set_title(f"UMAP Manifold (Non-linear)\nRaw k-NN Accuracy: {acc_raw:.2%}", fontsize=12)
|
| 116 |
+
|
| 117 |
+
for ax in axes:
|
| 118 |
+
ax.legend()
|
| 119 |
+
ax.set_xticks([])
|
| 120 |
+
ax.set_yticks([])
|
| 121 |
+
|
| 122 |
+
plt.suptitle("Manifold Collapse: Linear SVD Overlap vs. Non-linear Topological Separation",
|
| 123 |
+
fontsize=14, fontweight='bold', y=1.02)
|
| 124 |
+
save_fig(filename)
|
| 125 |
+
|
| 126 |
+
def plot_learning_curves(history, title, filename):
|
| 127 |
+
"""Standardized plotter for training history (loss and accuracy)."""
|
| 128 |
+
setup_style()
|
| 129 |
+
epochs = range(1, len(history['train_loss']) + 1)
|
| 130 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
| 131 |
+
|
| 132 |
+
# Nord palette for curves
|
| 133 |
+
COLOR_TRAIN = COLOR_SVD
|
| 134 |
+
COLOR_VAL = "#D08770" # Nord 12 (Orange)
|
| 135 |
+
|
| 136 |
+
# Loss Plot
|
| 137 |
+
ax1.plot(epochs, history['train_loss'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
|
| 138 |
+
ax1.plot(epochs, history['val_loss'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
|
| 139 |
+
ax1.set_title('Loss Dynamics', fontsize=12, fontweight='bold')
|
| 140 |
+
ax1.set_xlabel('Epoch')
|
| 141 |
+
ax1.set_ylabel('Loss')
|
| 142 |
+
ax1.legend()
|
| 143 |
+
ax1.grid(True)
|
| 144 |
+
|
| 145 |
+
# Accuracy Plot
|
| 146 |
+
ax2.plot(epochs, history['train_acc'], label='Train', color=COLOR_TRAIN, marker='o', markersize=4, linewidth=1.5)
|
| 147 |
+
ax2.plot(epochs, history['val_acc'], label='Val', color=COLOR_VAL, marker='s', markersize=4, linewidth=1.5)
|
| 148 |
+
ax2.set_title('Accuracy Dynamics', fontsize=12, fontweight='bold')
|
| 149 |
+
ax2.set_xlabel('Epoch')
|
| 150 |
+
ax2.set_ylabel('Accuracy')
|
| 151 |
+
ax2.legend()
|
| 152 |
+
ax2.grid(True)
|
| 153 |
+
|
| 154 |
+
plt.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
|
| 155 |
+
save_fig(filename)
|
| 156 |
+
|
| 157 |
+
def plot_per_class_comparison(y_test, y_preds_dict, filename):
|
| 158 |
+
"""Grouped bar chart comparing F1-scores per class for multiple models."""
|
| 159 |
+
from sklearn.metrics import f1_score
|
| 160 |
+
setup_style()
|
| 161 |
+
plt.figure(figsize=(10, 6))
|
| 162 |
+
|
| 163 |
+
x = np.arange(10)
|
| 164 |
+
width = 0.8 / len(y_preds_dict)
|
| 165 |
+
|
| 166 |
+
colors = {
|
| 167 |
+
'SVD+LR': COLOR_SVD,
|
| 168 |
+
'CNN': COLOR_CNN,
|
| 169 |
+
'Hybrid': COLOR_HYBRID
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
for i, (label, y_pred) in enumerate(y_preds_dict.items()):
|
| 173 |
+
f1s = f1_score(y_test, y_pred, average=None)
|
| 174 |
+
plt.bar(x + (i - len(y_preds_dict)/2 + 0.5) * width, f1s, width,
|
| 175 |
+
label=label, color=colors.get(label, '#4C566A'), alpha=0.8)
|
| 176 |
+
|
| 177 |
+
plt.xticks(x)
|
| 178 |
+
plt.xlabel('Digit Class', fontsize=12)
|
| 179 |
+
plt.ylabel('F1-Score', fontsize=12)
|
| 180 |
+
plt.title('Per-Class Performance Comparison (F1-Score)', fontsize=14, fontweight='bold', pad=15)
|
| 181 |
+
plt.legend()
|
| 182 |
+
plt.grid(True, axis='y')
|
| 183 |
+
save_fig(filename)
|
| 184 |
+
|
| 185 |
+
def plot_multi_image_grid(images, titles, rows, cols, filename, suptitle=None):
|
| 186 |
+
"""Generic grid plotter for images (e.g., eigen-digits)."""
|
| 187 |
+
plt.figure(figsize=(cols * 2.5, rows * 2.5))
|
| 188 |
+
for i, (img, title) in enumerate(zip(images, titles)):
|
| 189 |
+
plt.subplot(rows, cols, i + 1)
|
| 190 |
+
plt.imshow(img, cmap='gray')
|
| 191 |
+
plt.title(title, fontsize=10)
|
| 192 |
+
plt.axis('off')
|
| 193 |
+
if suptitle:
|
| 194 |
+
plt.suptitle(suptitle, fontsize=14, fontweight='bold')
|
| 195 |
+
save_fig(filename)
|