zhuoranyang commited on
Commit
b753304
·
verified ·
1 Parent(s): 20509b6

Deploy app with precomputed results for p=15,23,29,31

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +41 -0
  2. .gitignore +18 -0
  3. README.md +200 -12
  4. hf_app/app.py +1375 -0
  5. hf_app/requirements.txt +9 -0
  6. precompute/README.md +337 -0
  7. precompute/__init__.py +0 -0
  8. precompute/generate_analytical.py +358 -0
  9. precompute/generate_plots.py +2192 -0
  10. precompute/grokking_stage_detector.py +55 -0
  11. precompute/neuron_selector.py +98 -0
  12. precompute/prime_config.py +135 -0
  13. precompute/run_all.sh +35 -0
  14. precompute/run_pipeline.sh +60 -0
  15. precompute/train_all.py +290 -0
  16. precomputed_results/p_015/p015_full_training_para_origin.png +0 -0
  17. precomputed_results/p_015/p015_lineplot_in.png +3 -0
  18. precomputed_results/p_015/p015_lineplot_out.png +3 -0
  19. precomputed_results/p_015/p015_logits_interactive.json +1 -0
  20. precomputed_results/p_015/p015_lottery_beta_contour.png +0 -0
  21. precomputed_results/p_015/p015_lottery_mech_magnitude.png +0 -0
  22. precomputed_results/p_015/p015_lottery_mech_phase.png +0 -0
  23. precomputed_results/p_015/p015_magnitude_distribution.png +0 -0
  24. precomputed_results/p_015/p015_metadata.json +82 -0
  25. precomputed_results/p_015/p015_neuron_spectra.json +1 -0
  26. precomputed_results/p_015/p015_output_logits.png +0 -0
  27. precomputed_results/p_015/p015_overview.json +1 -0
  28. precomputed_results/p_015/p015_overview_loss_ipr.png +0 -0
  29. precomputed_results/p_015/p015_overview_phase_scatter.png +0 -0
  30. precomputed_results/p_015/p015_phase_align_approx1.png +3 -0
  31. precomputed_results/p_015/p015_phase_align_approx2.png +3 -0
  32. precomputed_results/p_015/p015_phase_align_quad.png +0 -0
  33. precomputed_results/p_015/p015_phase_align_relu.png +0 -0
  34. precomputed_results/p_015/p015_phase_distribution.png +0 -0
  35. precomputed_results/p_015/p015_phase_relationship.png +0 -0
  36. precomputed_results/p_015/p015_single_freq_quad.png +3 -0
  37. precomputed_results/p_015/p015_single_freq_relu.png +3 -0
  38. precomputed_results/p_015/p015_training_log.json +0 -0
  39. precomputed_results/p_023/p023_full_training_para_origin.png +3 -0
  40. precomputed_results/p_023/p023_grokk_abs_phase_diff.png +0 -0
  41. precomputed_results/p_023/p023_grokk_acc.json +1 -0
  42. precomputed_results/p_023/p023_grokk_acc.png +0 -0
  43. precomputed_results/p_023/p023_grokk_avg_ipr.png +0 -0
  44. precomputed_results/p_023/p023_grokk_decoded_weights_dynamic.png +3 -0
  45. precomputed_results/p_023/p023_grokk_epoch_data.json +1 -0
  46. precomputed_results/p_023/p023_grokk_loss.json +0 -0
  47. precomputed_results/p_023/p023_grokk_loss.png +0 -0
  48. precomputed_results/p_023/p023_grokk_memorization_accuracy.png +0 -0
  49. precomputed_results/p_023/p023_grokk_memorization_common_to_rare.png +3 -0
  50. precomputed_results/p_023/p023_lineplot_in.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,44 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ precomputed_results/p_015/p015_lineplot_in.png filter=lfs diff=lfs merge=lfs -text
37
+ precomputed_results/p_015/p015_lineplot_out.png filter=lfs diff=lfs merge=lfs -text
38
+ precomputed_results/p_015/p015_phase_align_approx1.png filter=lfs diff=lfs merge=lfs -text
39
+ precomputed_results/p_015/p015_phase_align_approx2.png filter=lfs diff=lfs merge=lfs -text
40
+ precomputed_results/p_015/p015_single_freq_quad.png filter=lfs diff=lfs merge=lfs -text
41
+ precomputed_results/p_015/p015_single_freq_relu.png filter=lfs diff=lfs merge=lfs -text
42
+ precomputed_results/p_023/p023_full_training_para_origin.png filter=lfs diff=lfs merge=lfs -text
43
+ precomputed_results/p_023/p023_grokk_decoded_weights_dynamic.png filter=lfs diff=lfs merge=lfs -text
44
+ precomputed_results/p_023/p023_grokk_memorization_common_to_rare.png filter=lfs diff=lfs merge=lfs -text
45
+ precomputed_results/p_023/p023_lineplot_in.png filter=lfs diff=lfs merge=lfs -text
46
+ precomputed_results/p_023/p023_lineplot_out.png filter=lfs diff=lfs merge=lfs -text
47
+ precomputed_results/p_023/p023_output_logits.png filter=lfs diff=lfs merge=lfs -text
48
+ precomputed_results/p_023/p023_overview_loss_ipr.png filter=lfs diff=lfs merge=lfs -text
49
+ precomputed_results/p_023/p023_phase_align_approx1.png filter=lfs diff=lfs merge=lfs -text
50
+ precomputed_results/p_023/p023_phase_align_approx2.png filter=lfs diff=lfs merge=lfs -text
51
+ precomputed_results/p_023/p023_single_freq_quad.png filter=lfs diff=lfs merge=lfs -text
52
+ precomputed_results/p_023/p023_single_freq_relu.png filter=lfs diff=lfs merge=lfs -text
53
+ precomputed_results/p_029/p029_full_training_para_origin.png filter=lfs diff=lfs merge=lfs -text
54
+ precomputed_results/p_029/p029_grokk_decoded_weights_dynamic.png filter=lfs diff=lfs merge=lfs -text
55
+ precomputed_results/p_029/p029_grokk_memorization_accuracy.png filter=lfs diff=lfs merge=lfs -text
56
+ precomputed_results/p_029/p029_grokk_memorization_common_to_rare.png filter=lfs diff=lfs merge=lfs -text
57
+ precomputed_results/p_029/p029_lineplot_in.png filter=lfs diff=lfs merge=lfs -text
58
+ precomputed_results/p_029/p029_lineplot_out.png filter=lfs diff=lfs merge=lfs -text
59
+ precomputed_results/p_029/p029_output_logits.png filter=lfs diff=lfs merge=lfs -text
60
+ precomputed_results/p_029/p029_overview_loss_ipr.png filter=lfs diff=lfs merge=lfs -text
61
+ precomputed_results/p_029/p029_phase_align_approx1.png filter=lfs diff=lfs merge=lfs -text
62
+ precomputed_results/p_029/p029_phase_align_approx2.png filter=lfs diff=lfs merge=lfs -text
63
+ precomputed_results/p_029/p029_single_freq_quad.png filter=lfs diff=lfs merge=lfs -text
64
+ precomputed_results/p_029/p029_single_freq_relu.png filter=lfs diff=lfs merge=lfs -text
65
+ precomputed_results/p_031/p031_full_training_para_origin.png filter=lfs diff=lfs merge=lfs -text
66
+ precomputed_results/p_031/p031_grokk_decoded_weights_dynamic.png filter=lfs diff=lfs merge=lfs -text
67
+ precomputed_results/p_031/p031_grokk_memorization_accuracy.png filter=lfs diff=lfs merge=lfs -text
68
+ precomputed_results/p_031/p031_grokk_memorization_common_to_rare.png filter=lfs diff=lfs merge=lfs -text
69
+ precomputed_results/p_031/p031_lineplot_in.png filter=lfs diff=lfs merge=lfs -text
70
+ precomputed_results/p_031/p031_lineplot_out.png filter=lfs diff=lfs merge=lfs -text
71
+ precomputed_results/p_031/p031_output_logits.png filter=lfs diff=lfs merge=lfs -text
72
+ precomputed_results/p_031/p031_overview_loss_ipr.png filter=lfs diff=lfs merge=lfs -text
73
+ precomputed_results/p_031/p031_phase_align_approx1.png filter=lfs diff=lfs merge=lfs -text
74
+ precomputed_results/p_031/p031_phase_align_approx2.png filter=lfs diff=lfs merge=lfs -text
75
+ precomputed_results/p_031/p031_single_freq_quad.png filter=lfs diff=lfs merge=lfs -text
76
+ precomputed_results/p_031/p031_single_freq_relu.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src/wandb/
2
+ notebooks/simulate_dynamics.ipynb
3
+
4
+ # Claude AI files
5
+ .claude/
6
+
7
+ # Python cache files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # Model checkpoints (too large for git; regenerate with precompute/run_pipeline.sh)
13
+ trained_models/
14
+ saved_models/
15
+
16
+ # OS files
17
+ .DS_Store
18
+ tmp/
README.md CHANGED
@@ -1,14 +1,202 @@
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: Modular Addition Feature Learning
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 6.6.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Interactive Demo of Paper on Modular Addition
12
- ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # On the Mechanism and Dynamics of Modular Addition
2
+
3
+ ### Fourier Features, Lottery Ticket, and Grokking
4
+
5
+ **Jianliang He, Leda Wang, Siyu Chen, Zhuoran Yang**
6
+ *Department of Statistics and Data Science, Yale University*
7
+
8
+ [[arXiv (coming soon)](#)] [[Blog (coming soon)](#)] [[Interactive Demo](https://huggingface.co/spaces/y-agent/modular-addition-feature-learning)]
9
+
10
  ---
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ ## Overview
13
+
14
+ This repository provides the code for studying how a two-layer neural network learns modular arithmetic $f(x,y) = (x+y) \bmod p$. We analyze three phenomena:
15
+
16
+ 1. **Fourier Feature Learning** — Each neuron independently discovers a cosine wave at a single frequency, collectively implementing a discrete Fourier transform that the network was never taught.
17
+ 2. **Lottery Ticket Dynamics** — Random initialization determines which frequency each neuron will specialize in: the frequency with the best initial phase alignment wins a winner-take-all competition.
18
+ 3. **Grokking** — Under partial data with weight decay, the network first memorizes, then suddenly generalizes through a three-stage process: memorization → sparsification → cleanup.
19
+
20
+ ## Interactive Demo
21
+
22
+ An interactive Gradio app visualizes all results with math explanations and interactive Plotly charts:
23
+
24
+ - **9 analysis tabs** covering mechanism, dynamics, grokking, and analytical simulations
25
+ - **Interactive features**: neuron frequency inspector, logit explorer, grokking epoch slider
26
+ - **On-demand training**: generate results for any odd $p \geq 3$ directly from the app
27
+ - **Pre-computed examples** included for $p = 15, 23, 29, 31$
28
+
29
+ ### Launch Locally
30
+
31
+ ```bash
32
+ pip install -r requirements.txt
33
+ python hf_app/app.py
34
+ # Opens at http://localhost:7860
35
+ ```
36
+
37
+ ### Deploy to Hugging Face Spaces
38
+
39
+ 1. Create a new Space at [huggingface.co/new-space](https://huggingface.co/new-space) (SDK: Gradio)
40
+ 2. Push the repo:
41
+ ```bash
42
+ git remote add hf https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
43
+ git push hf main
44
+ ```
45
+ 3. The app reads from `precomputed_results/` — the included examples (p=15, 23, 29, 31) work out of the box
46
+ 4. Users can generate results for additional $p$ values on-demand via the "Generate" button. New results are auto-committed back to the Space repo so they persist.
47
+
48
+ > **Tip:** For GPU-accelerated on-demand training, select a GPU runtime in your Space settings.
49
+
50
+ ## Pre-computation Pipeline
51
+
52
+ The `precompute/` directory trains 5 model configurations per modulus and generates all plots + interactive JSON data. See [`precompute/README.md`](precompute/README.md) for full documentation.
53
+
54
+ ### Quick Start
55
+
56
+ ```bash
57
+ # Full pipeline for a single modulus (train → plots → analytical → verify)
58
+ bash precompute/run_pipeline.sh 23
59
+
60
+ # With custom d_mlp
61
+ bash precompute/run_pipeline.sh 23 --d_mlp 128
62
+
63
+ # Delete checkpoints after generating plots (saves disk space)
64
+ CLEANUP=1 bash precompute/run_pipeline.sh 23
65
+
66
+ # Batch: all odd p in [3, 99]
67
+ bash precompute/run_all.sh
68
+
69
+ # Or up to p=199
70
+ MAX_P=199 bash precompute/run_all.sh
71
+ ```
72
+
73
+ ### Manual Steps
74
+
75
+ ```bash
76
+ # Step 1: Train all 5 configurations
77
+ python precompute/train_all.py --p 23 --output ./trained_models --resume
78
+
79
+ # Step 2: Generate model-based plots (21 PNGs + 7 JSONs)
80
+ python precompute/generate_plots.py --p 23 --input ./trained_models --output ./precomputed_results
81
+
82
+ # Step 3: Generate analytical simulation plots (2 PNGs, no model needed)
83
+ python precompute/generate_analytical.py --p 23 --output ./precomputed_results
84
+ ```
85
+
86
+ ### Output
87
+
88
+ Each modulus produces ~33 files in `precomputed_results/p_XXX/`:
89
+
90
+ | Category | Files | Description |
91
+ |----------|-------|-------------|
92
+ | Overview (Tab 1) | 2 PNGs + 1 JSON | Loss, IPR, phase scatter |
93
+ | Fourier Weights (Tab 2) | 3 PNGs + 1 JSON | DFT heatmaps, cosine fits, neuron spectra |
94
+ | Phase Analysis (Tab 3) | 3 PNGs | Phase distribution, alignment, magnitudes |
95
+ | Output Logits (Tab 4) | 1 PNG + 1 JSON | Logit heatmap, interactive explorer |
96
+ | Lottery Mechanism (Tab 5) | 3 PNGs | Magnitude race, phase convergence, contour |
97
+ | Grokking (Tab 6) | 5 PNGs + 3 JSONs | Loss/acc curves, memorization, weight evolution |
98
+ | Gradient Dynamics (Tab 7) | 4 PNGs | Phase alignment + DFT for Quad and ReLU |
99
+ | Decoupled Simulation (Tab 8) | 2 PNGs | Analytical ODE integration |
100
+ | Metadata | 2 JSONs | Config + training log |
101
+
102
+ > **Note:** Grokking results (Tab 6) require $p \geq 19$. Smaller values of $p$ have too few data points for a meaningful train/test split.
103
+
104
+ ## The 5 Training Configurations
105
+
106
+ | Config | Activation | Optimizer | LR | Weight Decay | Data | Epochs | Used In |
107
+ |--------|-----------|-----------|-----|-------------|------|--------|---------|
108
+ | `standard` | ReLU | AdamW | 5e-5 | 0 | 100% | 5,000 | Tabs 1–4 |
109
+ | `grokking` | ReLU | AdamW | 1e-4 | 2.0 | 75% | 50,000 | Tabs 1, 6 |
110
+ | `quad_random` | Quad | AdamW | 5e-5 | 0 | 100% | 5,000 | Tab 5 |
111
+ | `quad_single_freq` | Quad | SGD | 0.1 | 0 | 100% | 5,000 | Tab 7 |
112
+ | `relu_single_freq` | ReLU | SGD | 0.01 | 0 | 100% | 5,000 | Tab 7 |
113
+
114
+ ## Running a Single Experiment
115
+
116
+ For custom experiments outside the pre-computation pipeline:
117
+
118
+ ```bash
119
+ cd src
120
+
121
+ # Train with default config (p=97, d_mlp=1024, ReLU, 5000 epochs)
122
+ python module_nn.py
123
+
124
+ # Train with specific parameters
125
+ python module_nn.py --p 23 --d_mlp 512 --num_epochs 5000 --lr 5e-5
126
+
127
+ # Dry run: see config without training
128
+ python module_nn.py --dry_run --p 23 --d_mlp 512
129
+ ```
130
+
131
+ ## Notebooks
132
+
133
+ Interactive analysis notebooks in `notebooks/`:
134
+
135
+ | Notebook | Description |
136
+ |----------|-------------|
137
+ | `empirical_insight_standard.ipynb` | Fourier weight analysis, phase distributions, output logits |
138
+ | `empirical_insight_grokk.ipynb` | Grokking stages, weight dynamics, IPR evolution |
139
+ | `lottery_mechanism.ipynb` | Neuron specialization, frequency magnitude/phase tracking |
140
+ | `interprete_gd_dynamics.ipynb` | Phase alignment under single-frequency initialization |
141
+ | `decouple_dynamics_simulation.ipynb` | Analytical gradient flow simulation |
142
+
143
+ ## Setup
144
+
145
+ ### Requirements
146
+
147
+ - Python 3.8+
148
+ - PyTorch 2.0+
149
+ - CUDA-capable GPU (recommended for $p > 50$; CPU works for small $p$)
150
+
151
+ ### Installation
152
+
153
+ ```bash
154
+ git clone https://github.com/Y-Agent/modular-addition-feature-learning.git
155
+ cd modular-addition-feature-learning
156
+ pip install -r requirements.txt
157
+ ```
158
+
159
+ ## Project Structure
160
+
161
+ ```
162
+ modular-addition-feature-learning/
163
+ ├── src/ # Core source code
164
+ │ ├── module_nn.py # Training script with CLI
165
+ │ ├── nnTrainer.py # Training loop and optimization
166
+ │ ├── model_base.py # Neural network architecture (EmbedMLP)
167
+ │ ├── mechanism_base.py # Fourier analysis and decomposition
168
+ │ ├── utils.py # Configuration and helpers
169
+ │ └── configs.yaml # Default hyperparameters
170
+ ├── precompute/ # Batch training and plot generation
171
+ │ ├── run_pipeline.sh # Full pipeline for one modulus
172
+ │ ├── run_all.sh # Batch pipeline for all odd p
173
+ │ ├── train_all.py # Train 5 configurations
174
+ │ ├── generate_plots.py # Generate model-based plots + JSONs
175
+ │ ├── generate_analytical.py # Analytical ODE simulation plots
176
+ │ └── prime_config.py # Configurations and sizing formula
177
+ ├── hf_app/ # Gradio web application
178
+ │ └── app.py # Interactive visualization app
179
+ ├── precomputed_results/ # Pre-computed plots and data
180
+ │ ├── p_015/ # Results for p=15
181
+ │ ├── p_023/ # Results for p=23
182
+ │ ├── p_029/ # Results for p=29
183
+ │ └── p_031/ # Results for p=31
184
+ ├── notebooks/ # Analysis and visualization notebooks
185
+ ├── requirements.txt # Python dependencies
186
+ └── README.md
187
+ ```
188
+
189
+ ## Citation
190
+
191
+ ```bibtex
192
+ @article{he2025modular,
193
+ title={On the Mechanism and Dynamics of Modular Addition: Fourier Features, Lottery Ticket, and Grokking},
194
+ author={He, Jianliang and Wang, Leda and Chen, Siyu and Yang, Zhuoran},
195
+ journal={arXiv preprint arXiv:XXXX.XXXXX},
196
+ year={2025}
197
+ }
198
+ ```
199
+
200
+ ## License
201
+
202
+ [MIT License](LICENSE)
hf_app/app.py ADDED
@@ -0,0 +1,1375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio app for Modular Addition Feature Learning visualization.
4
+ Serves pre-computed results for odd moduli p in [3, 199].
5
+
6
+ All results are pre-computed as PNG images and JSON data files.
7
+ No GPU needed at serving time.
8
+
9
+ Tab structure:
10
+ Core Interpretability:
11
+ 1. Training Overview -- loss + IPR sparsity
12
+ 2. Fourier Weights -- decoded W_in/W_out heatmaps + line plots + neuron inspector
13
+ 3. Phase Analysis -- phase distribution, 2phi vs psi, magnitudes
14
+ 4. Output Logits -- predicted logit heatmap + interactive logit explorer
15
+ 5. Lottery Mechanism -- neuron specialization, magnitude/phase, contour
16
+ Grokking:
17
+ 6. Grokking -- loss/acc, phase alignment, IPR, memorization, epoch slider
18
+ Theory:
19
+ 7. Gradient Dynamics -- phase alignment for Quad & ReLU single-freq init
20
+ 8. Decoupled Simulation -- analytical gradient flow (no model needed)
21
+ Diagnostics:
22
+ 9. Training Log -- per-run hyperparameters and epoch-by-epoch metrics
23
+ """
24
+ import gradio as gr
25
+ import json
26
+ import logging
27
+ import os
28
+ import shutil
29
+ import subprocess
30
+ import sys
31
+
32
+ import numpy as np
33
+
34
+ logger = logging.getLogger(__name__)
35
+ # Force pandas to be fully imported before plotly lazily imports it
36
+ # (avoids "partially initialized module 'pandas'" in threaded callbacks)
37
+ import pandas # noqa: F401
38
+ import plotly.graph_objects as go
39
+
40
+ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
41
+ RESULTS_DIR = os.path.join(PROJECT_ROOT, "precomputed_results")
42
+ TRAINED_MODELS_DIR = os.path.join(PROJECT_ROOT, "trained_models")
43
+
44
+ # Max p for on-demand training (d_mlp grows as O(p^2), memory limit)
45
+ MAX_P_ON_DEMAND = 97
46
+
47
+ COLORS = ['#0D2758', '#60656F', '#DEA54B', '#A32015', '#347186']
48
+ STAGE_COLORS = ['rgba(212,175,55,0.15)', 'rgba(139,115,85,0.15)', 'rgba(192,192,192,0.15)']
49
+
50
+ # KaTeX delimiters for Gradio Markdown
51
+ LATEX_DELIMITERS = [
52
+ {"left": "$$", "right": "$$", "display": True},
53
+ {"left": "$", "right": "$", "display": False},
54
+ ]
55
+
56
+ # Custom CSS for Palatino font and styling
57
+ CUSTOM_CSS = r"""
58
+ @import url('https://fonts.googleapis.com/css2?family=Libre+Baskerville:ital,wght@0,400;0,700;1,400&display=swap');
59
+
60
+ * {
61
+ font-family: "Palatino Linotype", "Book Antiqua", Palatino, "Libre Baskerville", Georgia, serif !important;
62
+ }
63
+ code, pre, .code, .monospace {
64
+ font-family: "Menlo", "Consolas", "Monaco", monospace !important;
65
+ }
66
+ .katex, .katex * {
67
+ font-family: KaTeX_Main, "Times New Roman", serif !important;
68
+ }
69
+ h1 {
70
+ font-family: "Palatino Linotype", "Book Antiqua", Palatino, "Libre Baskerville", Georgia, serif !important;
71
+ text-align: center !important;
72
+ margin-bottom: 0.1em !important;
73
+ }
74
+ h3 {
75
+ font-family: "Palatino Linotype", "Book Antiqua", Palatino, "Libre Baskerville", Georgia, serif !important;
76
+ text-align: center !important;
77
+ color: var(--neutral-500) !important;
78
+ font-weight: normal !important;
79
+ margin-top: 0 !important;
80
+ }
81
+ h2, h4 {
82
+ font-family: "Palatino Linotype", "Book Antiqua", Palatino, "Libre Baskerville", Georgia, serif !important;
83
+ }
84
+ blockquote {
85
+ border-left: 3px solid var(--color-accent) !important;
86
+ background-color: var(--block-background-fill) !important;
87
+ padding: 0.5em 1em !important;
88
+ margin: 0.5em 0 !important;
89
+ }
90
+ """
91
+
92
+ # ---------------------------------------------------------------------------
93
+ # Math explanation text for each tab (following the paper precisely)
94
+ # ---------------------------------------------------------------------------
95
+
96
+ MATH_TAB1 = r"""
97
+ ### Overview
98
+
99
+ We study how a two-layer neural network learns to compute modular addition $f(x,y) = (x+y) \bmod p$. The network has $M$ hidden neurons. Each input integer $x$ is represented as a one-hot vector, and the network produces a score for each of the $p$ possible answers. During training, the network learns two weight vectors per neuron: an **input weight** $\theta_m$ and an **output weight** $\xi_m$, both vectors of length $p$.
100
+
101
+ #### Two Training Setups
102
+
103
+ 1. **Full-data (Tabs 1--5, 7).** Train on all $p^2$ input pairs with no held-out data and no regularization. This produces clean features ideal for studying what the network learns and how.
104
+
105
+ 2. **Grokking (Tab 6).** Train on only 75% of input pairs with weight decay $\lambda = 2.0$ (a penalty that shrinks weights over time). These two ingredients -- incomplete data + weight decay -- cause the network to first memorize, then suddenly generalize, a phenomenon called **grokking**.
106
+
107
+ #### What the Network Learns
108
+
109
+ Each neuron's weight vectors turn into **cosine waves** at a single frequency -- the network independently rediscovers the Discrete Fourier Transform. The neurons collectively cover all frequencies with balanced strengths, enabling them to "vote" together and identify the correct answer $(x+y) \bmod p$.
110
+
111
+ #### How It Learns (Dynamics)
112
+
113
+ Frequencies **compete** within each neuron during training. The frequency whose input and output phases happen to start best-aligned grows fastest -- a **lottery ticket mechanism** where the random initialization determines the outcome before training begins.
114
+
115
+ #### Grokking (Three Stages)
116
+
117
+ When trained on partial data with weight decay: **(I) Memorization** -- the network fits the training data using noisy, multi-frequency features. **(II) Generalization** -- weight decay prunes away the noise, leaving clean single-frequency features; test accuracy jumps. **(III) Cleanup** -- weight decay slowly polishes the solution.
118
+
119
+ #### Progress Measures on These Plots
120
+
121
+ - **Loss**: Cross-entropy loss (lower = better predictions). We show both training loss and test loss.
122
+
123
+ - **IPR (Inverse Participation Ratio)**: Measures how concentrated a neuron's energy is across frequencies. We decompose each neuron's weights into Fourier components, measure the strength $A_k$ at each frequency $k$, and compute:
124
+
125
+ $$\text{IPR} = \frac{\sum_k A_k^4}{\left(\sum_k A_k^2\right)^2}.$$
126
+
127
+ When a neuron uses only **one frequency**, IPR $= 1$ (fully specialized). When energy is spread across **many frequencies**, IPR is close to $0$. Watching IPR rise toward 1 during training shows the network specializing.
128
+
129
+ - **Phase scatter**: Each neuron has an input phase $\phi_m$ and output phase $\psi_m$. The theory predicts the output phase equals twice the input phase ($\psi_m = 2\phi_m$). The scatter plot checks this: all points should fall on the diagonal.
130
+ """
131
+
132
+ MATH_TAB2 = r"""
133
+ ### Every Neuron is a Cosine Wave
134
+ > **Setup:** ReLU activation, full data, no weight decay.
135
+
136
+ After training, each neuron's weight vectors become clean **cosine waves** at a single frequency. Concretely, the input weight of neuron $m$ looks like:
137
+
138
+ $$\underbrace{\theta_m[j]}_{\text{input weight at position } j} = \underbrace{\alpha_m}_{\text{input magnitude}} \cdot \cos\!\left(\underbrace{\frac{2\pi k}{p}}_{\text{frequency}} \cdot j + \underbrace{\phi_m}_{\text{input phase}}\right),$$
139
+
140
+ and the output weight has the same form with its own magnitude $\beta_m$ (output magnitude) and phase $\psi_m$ (output phase). Each neuron picks **one frequency** $k$ out of the $(p{-}1)/2$ possible frequencies. No one told the network about Fourier analysis -- it rediscovered this representation on its own through training.
141
+
142
+ **Heatmap**: Each row is a neuron, each column is a Fourier component (cosine and sine at each frequency). If a row has only one bright cell, that neuron is using a single frequency -- and that's exactly what we see.
143
+
144
+ **Line Plots**: The dots are the actual learned weights; the dashed curves are best-fit cosines. The near-perfect fits confirm each neuron is well-described by a single cosine at a single frequency.
145
+
146
+ **Neuron Inspector**: Select a neuron from the dropdown to see how its energy is distributed across all frequencies (for both input and output weights).
147
+ """
148
+
149
+ MATH_TAB3 = r"""
150
+ ### Phase Alignment and Collective Diversification
151
+ > **Setup:** ReLU activation, full data, no weight decay.
152
+
153
+ #### The Input and Output Phases Lock Together
154
+
155
+ Each neuron has an input phase $\phi_m$ and an output phase $\psi_m$ (the "shift" of each cosine wave). These are not independent -- training drives them into a precise relationship:
156
+
157
+ $$\underbrace{\psi_m}_{\text{output phase}} = 2 \times \underbrace{\phi_m}_{\text{input phase}}.$$
158
+
159
+ **Why "doubled"?** The activation function squares (or, for ReLU, roughly squares) the sum of two cosines. Squaring a cosine at phase $\phi$ naturally produces terms at phase $2\phi$. The output layer learns to match this by setting its own phase to $2\phi$, so the two layers work together coherently.
160
+
161
+ The **scatter plot** checks this: we plot $2\phi_m$ (horizontal) vs. $\psi_m$ (vertical) for every neuron. If the relationship holds, all points land on the diagonal. This relationship is not built into the architecture -- it **emerges from training** (see Tab 7 for why).
162
+
163
+ #### Neurons Organize Themselves into a Balanced Ensemble
164
+
165
+ The neurons don't just specialize to single frequencies -- they also organize *collectively*:
166
+
167
+ 1. **Frequency balance:** Every frequency gets roughly the same number of neurons.
168
+ 2. **Phase spread:** Within each frequency group, the phases are spread uniformly around the circle. This is what enables **noise cancellation** -- the random noise from individual neurons averages out when their phases are evenly spaced.
169
+ 3. **Magnitude balance:** All neurons contribute roughly equally to the output (no single neuron dominates).
170
+
171
+ The **polar plot** shows phases at multiples ($1\times, 2\times, 3\times, 4\times$) on concentric rings -- uniform spread confirms the cancellation condition. The **violin plots** show the distribution of input magnitudes ($\alpha$) and output magnitudes ($\beta$) -- tight concentration confirms magnitude balance.
172
+ """
173
+
174
+ MATH_TAB4 = r"""
175
+ ### The Mechanism: Majority Voting in Fourier Space
176
+ > **Setup:** ReLU activation, full data, no weight decay.
177
+
178
+ #### How Neurons Vote for the Correct Answer
179
+
180
+ Each neuron produces a score for every possible output $j \in \{0, 1, \ldots, p{-}1\}$. Thanks to the phase alignment ($\psi = 2\phi$, see Tab 3), each neuron's score has a **signal** component that peaks at the correct answer $j = (x+y) \bmod p$, plus **noise** that depends on that neuron's particular phase.
181
+
182
+ When we sum over many neurons within a frequency group, the signal adds up (every neuron agrees on the right answer) while the noise cancels out (different neurons have different phases, and the phase spread from Tab 3 ensures the noise averages to zero). This is **majority voting** -- each neuron casts a noisy vote, but the consensus is correct.
183
+
184
+ #### The "Flawed Indicator"
185
+
186
+ After summing over all neurons and all frequency groups, the network's output simplifies to:
187
+
188
+ $$\text{score for answer } j \;\propto\; \underbrace{\frac{p}{2} \cdot \mathbf{1}[j = (x{+}y) \bmod p]}_{\text{correct answer (strongest)}} \;+\; \underbrace{\frac{p}{4} \cdot \bigl(\mathbf{1}[j = 2x \bmod p] + \mathbf{1}[j = 2y \bmod p]\bigr)}_{\text{two "ghost" peaks (half strength)}}.$$
189
+
190
+ The correct answer gets score $p/2$, but two **spurious ghost peaks** appear at $2x \bmod p$ and $2y \bmod p$ with score $p/4$. The correct answer always wins because $p/2 > p/4$, so the network always predicts correctly despite the ghosts.
191
+
192
+ **Heatmap**: The network's output scores for all inputs with $x = 0$. The bright diagonal is the correct answer. The faint lines are the ghost peaks.
193
+
194
+ **Logit Explorer**: Pick an input pair $(x, y)$ to see the full score distribution. The correct answer (highlighted) should be the tallest bar.
195
+ """
196
+
197
+ MATH_TAB5 = r"""
198
+ ### The Lottery Ticket: How Each Neuron Picks Its Frequency
199
+ > **Setup:** Quadratic activation ($\sigma(x) = x^2$), full data, random initialization.
200
+
201
+ #### The Competition
202
+
203
+ At the start of training, every neuron has a tiny bit of energy at **every** frequency -- nothing is specialized yet. But the input and output phases at each frequency start at random values, so some frequencies happen to be better aligned (input phase and output phase closer to the $\psi = 2\phi$ relationship) than others.
204
+
205
+ The key insight: **a frequency grows faster when its phases are better aligned.** The growth rate of a frequency's magnitude depends on how close it is to alignment:
206
+
207
+ $$\text{growth rate} \;\propto\; \cos(\underbrace{2\phi - \psi}_{\text{phase misalignment }\mathcal{D}}).$$
208
+
209
+ When the misalignment $\mathcal{D}$ is small (phases nearly aligned), $\cos(\mathcal{D}) \approx 1$ and the frequency grows quickly. When $\mathcal{D}$ is large, growth stalls.
210
+
211
+ #### Winner Takes All
212
+
213
+ This creates a **positive feedback loop**: the best-aligned frequency grows a little, which helps it align even better, which makes it grow even faster. The gap compounds exponentially until one frequency completely dominates -- **the winner takes all.**
214
+
215
+ The winning frequency is simply the one that started closest to alignment:
216
+
217
+ $$\text{winning frequency} = \text{the } k \text{ with smallest initial misalignment } |\mathcal{D}_m^k|.$$
218
+
219
+ This is a **lottery ticket**: the outcome is determined by the random initialization before training even begins. Since each neuron draws independent random phases, different neurons pick different winning frequencies, naturally producing the balanced frequency coverage seen in Tab 3.
220
+
221
+ **Phase plot:** Shows how the misalignment $\mathcal{D}$ evolves over training for each frequency within one neuron. The winner (red) converges to zero first; the others barely move.
222
+
223
+ **Magnitude plot:** Shows how the output magnitude $\beta$ (strength of each frequency) evolves. All start equal. Once the winner aligns, it grows explosively while the others stay frozen.
224
+
225
+ **Contour plot:** Final magnitude as a function of (initial magnitude, initial misalignment). Largest values appear at small misalignment -- confirming that alignment determines the winner.
226
+ """
227
+
228
+ MATH_TAB6 = r"""
229
+ ### Grokking: From Memorization to Generalization
230
+ > **Setup:** ReLU activation, 75% training fraction, weight decay $\lambda = 2.0$.
231
+
232
+ Under the train-test split setup, the network quickly memorizes the training set but takes much longer to generalize. Our analysis reveals grokking is a **three-stage process**, each driven by a different balance of forces.
233
+
234
+ **Stage I -- Memorization (loss gradient dominates).** The loss gradient dominates and the network rapidly memorizes training data. Training accuracy reaches 100% while test accuracy reaches only ~70%. The ~70% figure (not ~50%) arises because the architecture is symmetric in $x$ and $y$: since $\theta_m[x] + \theta_m[y]$ is invariant under swapping $(x,y) \leftrightarrow (y,x)$, memorizing $(x,y)$ automatically gives the correct answer for $(y,x)$. The lottery mechanism runs on incomplete data, producing a "noisy" multi-frequency representation. We also observe a **common-to-rare ordering**: the network first memorizes symmetric pairs (both $(i,j)$ and $(j,i)$ in training) while actively *suppressing* rare pairs, before eventually memorizing them too.
235
+
236
+ **Stage II -- Fast Generalization (loss + weight decay).** Weight decay penalizes all magnitudes equally, but the dominant frequency has much larger magnitude and can "afford" the penalty. Non-feature frequencies are driven to zero -- a **sparsification** effect visible as the sharp IPR increase. This transforms the noisy memorization solution into clean single-frequency-per-neuron features. Test accuracy jumps steeply.
237
+
238
+ **Stage III -- Slow Cleanup (weight decay dominates).** The loss gradient becomes negligible (both losses $\approx 0$). Weight decay alone slowly shrinks norms at rate $\partial_t \|w\| = -\lambda \|w\|$. The feature frequencies are already identified; this stage fine-tunes magnitudes. The network transitions from a lookup table to a generalizing algorithm implementing the indicator function from the mechanism (Tab 4).
239
+
240
+ **Four progress measures**: (a) Loss -- train drops in Stage I, test drops in Stage II. (b) Accuracy -- train reaches 100% early, test jumps in Stage II. (c) Phase alignment -- $|\sin(\mathcal{D}_m^\star)|$ decreases throughout. (d) IPR + parameter norms -- IPR increases sharply in Stage II, norms shrink in Stage III.
241
+
242
+ **Epoch Slider**: Use the slider below to see how the accuracy grid evolves across the three stages.
243
+ """
244
+
245
+ MATH_TAB7 = r"""
246
+ ### Training Dynamics: Phase Alignment and Single-Frequency Preservation
247
+ > **Setup:** Quadratic and ReLU activations, full data, single-frequency initialization, SGD.
248
+
249
+ #### The Four-Variable ODE
250
+
251
+ Under small initialization ($\kappa_{\mathrm{init}} \ll 1$), the dynamics decouple: each neuron evolves independently, and within each neuron, different Fourier modes evolve independently (because $\sum_{x \in \mathbb{Z}_p} \cos(\omega_k x) \cos(\omega_\tau x) = \frac{p}{2}\delta_{k,\tau}$). The full dynamics reduce to independent four-variable ODEs per (neuron, frequency):
252
+
253
+ $$\partial_t \alpha \approx 2p \cdot \alpha \cdot \beta \cdot \cos(\mathcal{D}), \qquad \partial_t \beta \approx p \cdot \alpha^2 \cdot \cos(\mathcal{D}),$$
254
+ $$\partial_t \phi \approx 2p \cdot \beta \cdot \sin(\mathcal{D}), \qquad \partial_t \psi \approx -p \cdot \frac{\alpha^2}{\beta} \cdot \sin(\mathcal{D}),$$
255
+
256
+ where $\mathcal{D} = (2\phi - \psi) \bmod 2\pi$ is the **phase misalignment**. This system has a clear physical interpretation: **magnitudes grow when phases are aligned** ($\cos(\mathcal{D}) \approx 1$), and **phases rotate toward alignment** ($\sin(\mathcal{D}) \to 0$). The dynamics self-coordinate: phases align first (while magnitudes are small), then magnitudes explode.
257
+
258
+ #### Phase Alignment Theorem
259
+
260
+ $\mathcal{D}(t) \to 0$ from any initial condition except the measure-zero unstable point $\mathcal{D} = \pi$. The dynamics on the circle behave like an **overdamped pendulum**: $\mathcal{D} = 0$ is a stable attractor, $\mathcal{D} = \pi$ is an unstable repeller. This is not a coincidence or a property of initialization -- it is an **inevitable consequence of the training dynamics**. It explains Observation 2 ($\psi = 2\phi$).
261
+
262
+ #### Single-Frequency Preservation Theorem
263
+
264
+ Under the decoupled flow, if a neuron starts at a single frequency, it remains there for all time. The Fourier orthogonality on $\mathbb{Z}_p$ prevents energy from leaking between modes.
265
+
266
+ **Quadratic** (left panels): Theory matches experiment almost exactly. The DFT heatmap shows the dominant frequency growing while all others stay at zero.
267
+
268
+ **ReLU** (right panels): Same qualitative behavior with minor quantitative differences. Small energy "leaks" to harmonic multiples ($3k^\star, 5k^\star, \ldots$ for input; $2k^\star, 3k^\star, \ldots$ for output). The leakage decays as $O(r^{-2})$ where $r$ is the harmonic order (third harmonic has $1/9$ the strength, fifth has $1/25$), keeping the dominant frequency overwhelmingly dominant.
269
+ """
270
+
271
+ MATH_TAB9 = r"""
272
+ ### Training Log
273
+
274
+ This tab shows the training logs for each of the 5 configurations run for the selected modulo $p$. Select a run from the dropdown to view its hyperparameters and per-epoch training metrics.
275
+
276
+ The 5 training runs are:
277
+ - **standard**: ReLU, full data, no weight decay -- produces the clean Fourier features analyzed in Tabs 1--5
278
+ - **grokking**: ReLU, 75% data, weight decay $\lambda = 2.0$ -- demonstrates the memorization $\to$ generalization transition (Tab 6)
279
+ - **quad_random**: Quadratic activation, full data, random init -- the lottery ticket mechanism (Tab 5)
280
+ - **quad_single_freq**: Quadratic activation, single-frequency init, SGD -- verifies single-frequency preservation (Tab 7)
281
+ - **relu_single_freq**: ReLU, single-frequency init, SGD -- ReLU variant of the dynamics (Tab 7)
282
+ """
283
+
284
+ MATH_TAB8 = r"""
285
+ ### Decoupled Gradient Flow Simulation
286
+ > **Setup:** Analytical ODE integration (no neural network training).
287
+
288
+ This tab shows a pure mathematical simulation of the multi-frequency gradient flow, **without training any neural network**. We numerically integrate the four-variable ODEs for all frequency modes simultaneously within a single neuron:
289
+
290
+ $$\partial_t \alpha_k \approx 2p \cdot \alpha_k \cdot \beta_k \cdot \cos(\mathcal{D}_k), \qquad \partial_t \beta_k \approx p \cdot \alpha_k^2 \cdot \cos(\mathcal{D}_k),$$
291
+ $$\partial_t \phi_k \approx 2p \cdot \beta_k \cdot \sin(\mathcal{D}_k), \qquad \partial_t \psi_k \approx -p \cdot \frac{\alpha_k^2}{\beta_k} \cdot \sin(\mathcal{D}_k),$$
292
+
293
+ for each frequency $k = 1, \ldots, (p{-}1)/2$, with random initial conditions.
294
+
295
+ The simulation confirms the theoretical predictions from Tab 7:
296
+
297
+ - **Phase alignment:** Phase misalignments $\mathcal{D}_k = (2\phi_k - \psi_k) \bmod 2\pi$ converge to $0$ for most frequencies, or linger near $\pi$ (the unstable repeller) before eventually escaping.
298
+ - **Magnitude competition:** Magnitudes grow explosively for the frequency where $\mathcal{D}_k \approx 0$ first, while others remain near their initial level.
299
+ - **Lottery outcome:** The winning frequency (smallest initial $\mathcal{D}_k$) dominates all others, reproducing the full lottery ticket mechanism without any neural network -- just ODEs.
300
+
301
+ Two cases are shown with different initial conditions to illustrate that the mechanism is robust: regardless of the random starting point, the frequency with the best initial phase alignment always wins.
302
+ """
303
+
304
+
305
+ # ---------------------------------------------------------------------------
306
+ # Data loading helpers
307
+ # ---------------------------------------------------------------------------
308
+
309
+ MIN_P = 3 # p=2 has 0 non-DC Fourier frequencies; analysis is degenerate
310
+
311
+
312
+ def get_available_moduli():
313
+ """Discover which p values have pre-computed results (odd p >= 3)."""
314
+ moduli = []
315
+ if os.path.exists(RESULTS_DIR):
316
+ for d in sorted(os.listdir(RESULTS_DIR)):
317
+ if d.startswith("p_"):
318
+ try:
319
+ p = int(d.split("_")[1])
320
+ if p >= MIN_P:
321
+ moduli.append(p)
322
+ except ValueError:
323
+ pass
324
+ return moduli
325
+
326
+
327
+ def _prime_dir(p):
328
+ return os.path.join(RESULTS_DIR, f"p_{p:03d}")
329
+
330
+
331
+ def load_json_file(p, filename):
332
+ """Load a JSON file from the prime's directory."""
333
+ path = os.path.join(_prime_dir(p), f"p{p:03d}_{filename}")
334
+ if os.path.exists(path):
335
+ with open(path) as f:
336
+ return json.load(f)
337
+ return None
338
+
339
+
340
+ def safe_img(p, filename):
341
+ """Return image path or None (Gradio handles None gracefully)."""
342
+ path = os.path.join(_prime_dir(p), f"p{p:03d}_{filename}")
343
+ return path if os.path.exists(path) else None
344
+
345
+
346
+ # ---------------------------------------------------------------------------
347
+ # Interactive Plotly chart builders
348
+ # ---------------------------------------------------------------------------
349
+
350
+ def _to_np(v):
351
+ """Convert a list/value to a numpy array (bypasses plotly's pandas check)."""
352
+ if v is None:
353
+ return None
354
+ return np.asarray(v)
355
+
356
+
357
+ def make_loss_chart(data, title="Training Loss"):
358
+ """Build an interactive Plotly loss chart from JSON data."""
359
+ if data is None:
360
+ return None
361
+ fig = go.Figure()
362
+ n = len(data.get('train_losses', []))
363
+ epochs = np.arange(n)
364
+
365
+ fig.add_trace(go.Scatter(
366
+ x=epochs, y=_to_np(data['train_losses']),
367
+ name='Train Loss', line=dict(color=COLORS[0]),
368
+ ))
369
+ if 'test_losses' in data:
370
+ fig.add_trace(go.Scatter(
371
+ x=epochs, y=_to_np(data['test_losses']),
372
+ name='Test Loss', line=dict(color=COLORS[3]),
373
+ ))
374
+
375
+ s1 = data.get('stage1_end')
376
+ s2 = data.get('stage2_end')
377
+ if s1 is not None:
378
+ fig.add_vrect(x0=0, x1=s1, fillcolor=STAGE_COLORS[0],
379
+ line_width=0, annotation_text="Memorization",
380
+ annotation_position="top left")
381
+ if s1 is not None and s2 is not None:
382
+ fig.add_vrect(x0=s1, x1=s2, fillcolor=STAGE_COLORS[1],
383
+ line_width=0, annotation_text="Transition",
384
+ annotation_position="top left")
385
+ if s2 is not None:
386
+ fig.add_vrect(x0=s2, x1=n, fillcolor=STAGE_COLORS[2],
387
+ line_width=0, annotation_text="Generalization",
388
+ annotation_position="top left")
389
+
390
+ fig.update_layout(
391
+ title=title, xaxis_title='Epoch', yaxis_title='Loss',
392
+ template='plotly_white', height=400,
393
+ legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
394
+ )
395
+ return fig
396
+
397
+
398
+ def make_acc_chart(data, title="Training Accuracy"):
399
+ """Build an interactive Plotly accuracy chart."""
400
+ if data is None:
401
+ return None
402
+ fig = go.Figure()
403
+ epochs = _to_np(data.get('epochs', list(range(len(data.get('train_accs', []))))))
404
+
405
+ fig.add_trace(go.Scatter(
406
+ x=epochs, y=_to_np(data['train_accs']),
407
+ name='Train Acc', line=dict(color=COLORS[0]),
408
+ ))
409
+ if 'test_accs' in data:
410
+ fig.add_trace(go.Scatter(
411
+ x=epochs, y=_to_np(data['test_accs']),
412
+ name='Test Acc', line=dict(color=COLORS[3]),
413
+ ))
414
+
415
+ s1 = data.get('stage1_end')
416
+ s2 = data.get('stage2_end')
417
+ if s1 is not None:
418
+ fig.add_vrect(x0=0, x1=s1, fillcolor=STAGE_COLORS[0], line_width=0)
419
+ if s1 is not None and s2 is not None:
420
+ fig.add_vrect(x0=s1, x1=s2, fillcolor=STAGE_COLORS[1], line_width=0)
421
+ if s2 is not None:
422
+ n = int(epochs.max()) if len(epochs) > 0 else len(data.get('train_accs', []))
423
+ fig.add_vrect(x0=s2, x1=n, fillcolor=STAGE_COLORS[2], line_width=0)
424
+
425
+ fig.update_layout(
426
+ title=title, xaxis_title='Epoch', yaxis_title='Accuracy',
427
+ template='plotly_white', height=400,
428
+ legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
429
+ )
430
+ return fig
431
+
432
+
433
+
434
+ def make_neuron_spectrum_chart(data, neuron_key):
435
+ """Build a Plotly bar chart for a single neuron's Fourier spectrum."""
436
+ if data is None or neuron_key not in data.get('neurons', {}):
437
+ return None
438
+ neuron = data['neurons'][neuron_key]
439
+ names = data.get('fourier_basis_names', [])
440
+ mags_in = _to_np(neuron['fourier_magnitudes_in'])
441
+ mags_out = _to_np(neuron['fourier_magnitudes_out'])
442
+ dom_freq = neuron.get('dominant_freq', '?')
443
+
444
+ fig = go.Figure()
445
+ fig.add_trace(go.Bar(
446
+ x=names, y=mags_in, name='W_in magnitude',
447
+ marker_color=COLORS[0], opacity=0.8,
448
+ ))
449
+ fig.add_trace(go.Bar(
450
+ x=names, y=mags_out, name='W_out magnitude',
451
+ marker_color=COLORS[3], opacity=0.8,
452
+ ))
453
+ fig.update_layout(
454
+ title=f"Neuron {neuron_key} (dominant freq={dom_freq})",
455
+ xaxis_title='Fourier Component',
456
+ yaxis_title='Magnitude',
457
+ barmode='group',
458
+ template='plotly_white', height=350,
459
+ )
460
+ return fig
461
+
462
+
463
+ def make_logit_bar_chart(data, pair_index):
464
+ """Build a Plotly bar chart of logits for a specific (a,b) pair."""
465
+ if data is None:
466
+ return None
467
+ pairs = data.get('pairs', [])
468
+ logits_all = data.get('logits', [])
469
+ correct = data.get('correct_answers', [])
470
+ classes = data.get('output_classes', [])
471
+
472
+ if pair_index >= len(pairs):
473
+ return None
474
+
475
+ a, b = pairs[pair_index]
476
+ logits = _to_np(logits_all[pair_index])
477
+ correct_ans = correct[pair_index]
478
+
479
+ bar_colors = [COLORS[3] if c == correct_ans else COLORS[0] for c in classes]
480
+
481
+ fig = go.Figure()
482
+ fig.add_trace(go.Bar(
483
+ x=[str(c) for c in classes], y=logits,
484
+ marker_color=bar_colors,
485
+ hovertemplate='Class %{x}: %{y:.3f}<extra></extra>',
486
+ ))
487
+ fig.update_layout(
488
+ title=f"Logits for ({a}, {b}) -- correct = {correct_ans}",
489
+ xaxis_title='Output Class',
490
+ yaxis_title='Logit Value',
491
+ template='plotly_white', height=350,
492
+ )
493
+ return fig
494
+
495
+
496
+ def make_grokk_heatmap(data, epoch_index):
497
+ """Build a Plotly heatmap of accuracy grid at a grokking checkpoint."""
498
+ if data is None:
499
+ return None
500
+ epochs = data.get('epochs', [])
501
+ grids = data.get('grids', [])
502
+ if epoch_index >= len(grids):
503
+ return None
504
+
505
+ grid = _to_np(grids[epoch_index])
506
+ ep = epochs[epoch_index]
507
+
508
+ fig = go.Figure(data=go.Heatmap(
509
+ z=grid,
510
+ colorscale=[[0, 'white'], [1, COLORS[0]]],
511
+ zmin=0, zmax=1,
512
+ hovertemplate='a=%{y}, b=%{x}: %{z:.0f}<extra></extra>',
513
+ ))
514
+ fig.update_layout(
515
+ title=f"Accuracy Grid at Epoch {ep}",
516
+ xaxis_title='Second Input (b)',
517
+ yaxis_title='First Input (a)',
518
+ template='plotly_white', height=450,
519
+ yaxis=dict(autorange='reversed'),
520
+ )
521
+ return fig
522
+
523
+
524
+ # ---------------------------------------------------------------------------
525
+ # Tab update functions
526
+ # ---------------------------------------------------------------------------
527
+
528
+ def update_tab1(p):
529
+ """Overview: standard + grokking loss/IPR, phase scatter."""
530
+ img_overview = safe_img(p, "overview_loss_ipr.png")
531
+ img_phase = safe_img(p, "overview_phase_scatter.png")
532
+ # Also build interactive charts from overview.json
533
+ data = load_json_file(p, "overview.json")
534
+ std_loss_chart = None
535
+ grokk_loss_chart = None
536
+ std_ipr_chart = None
537
+ grokk_ipr_chart = None
538
+
539
+ if data:
540
+ # Standard loss chart
541
+ std_ep = data.get('std_epochs', [])
542
+ std_tl = data.get('std_train_loss', [])
543
+ if std_tl:
544
+ fig = go.Figure()
545
+ fig.add_trace(go.Scatter(
546
+ x=_to_np(std_ep[:len(std_tl)]), y=_to_np(std_tl),
547
+ name='Train Loss', line=dict(color=COLORS[0]),
548
+ ))
549
+ fig.update_layout(
550
+ title='Standard: Training Loss (ReLU, full data)',
551
+ xaxis_title='Step', yaxis_title='Loss',
552
+ template='plotly_white', height=350,
553
+ )
554
+ std_loss_chart = fig
555
+
556
+ # Standard IPR chart
557
+ std_ipr = data.get('std_ipr', [])
558
+ if std_ipr:
559
+ fig = go.Figure()
560
+ fig.add_trace(go.Scatter(
561
+ x=_to_np(std_ep[:len(std_ipr)]), y=_to_np(std_ipr),
562
+ name='Avg IPR', line=dict(color=COLORS[3]),
563
+ ))
564
+ fig.update_layout(
565
+ title='Standard: IPR (Fourier Sparsity)',
566
+ xaxis_title='Step', yaxis_title='IPR',
567
+ yaxis=dict(range=[0, 1.05]),
568
+ template='plotly_white', height=350,
569
+ )
570
+ std_ipr_chart = fig
571
+
572
+ # Grokking loss chart
573
+ grokk_ep = data.get('grokk_epochs', [])
574
+ grokk_tl = data.get('grokk_train_loss', [])
575
+ grokk_tel = data.get('grokk_test_loss', [])
576
+ if grokk_tl or grokk_tel:
577
+ fig = go.Figure()
578
+ if grokk_tl:
579
+ fig.add_trace(go.Scatter(
580
+ x=_to_np(grokk_ep[:len(grokk_tl)]), y=_to_np(grokk_tl),
581
+ name='Train Loss', line=dict(color=COLORS[0]),
582
+ ))
583
+ if grokk_tel:
584
+ fig.add_trace(go.Scatter(
585
+ x=_to_np(grokk_ep[:len(grokk_tel)]), y=_to_np(grokk_tel),
586
+ name='Test Loss', line=dict(color=COLORS[3]),
587
+ ))
588
+ fig.update_layout(
589
+ title='Grokking: Loss (ReLU, 75% data, WD)',
590
+ xaxis_title='Step', yaxis_title='Loss',
591
+ template='plotly_white', height=350,
592
+ )
593
+ grokk_loss_chart = fig
594
+
595
+ # Grokking IPR chart
596
+ grokk_ipr = data.get('grokk_ipr', [])
597
+ if grokk_ipr:
598
+ fig = go.Figure()
599
+ fig.add_trace(go.Scatter(
600
+ x=_to_np(grokk_ep[:len(grokk_ipr)]), y=_to_np(grokk_ipr),
601
+ name='Avg IPR', line=dict(color=COLORS[3]),
602
+ ))
603
+ fig.update_layout(
604
+ title='Grokking: IPR (weight decay drives sparsification)',
605
+ xaxis_title='Step', yaxis_title='IPR',
606
+ yaxis=dict(range=[0, 1.05]),
607
+ template='plotly_white', height=350,
608
+ )
609
+ grokk_ipr_chart = fig
610
+
611
+ return (img_overview, std_loss_chart, grokk_loss_chart,
612
+ std_ipr_chart, grokk_ipr_chart, img_phase)
613
+
614
+
615
+ def update_tab2(p):
616
+ """Fourier Weights: heatmap + line plots."""
617
+ return (
618
+ safe_img(p, "full_training_para_origin.png"),
619
+ safe_img(p, "lineplot_in.png"),
620
+ safe_img(p, "lineplot_out.png"),
621
+ )
622
+
623
+
624
+ def update_tab3(p):
625
+ """Phase Analysis: distribution, relationship, magnitude."""
626
+ return (
627
+ safe_img(p, "phase_distribution.png"),
628
+ safe_img(p, "phase_relationship.png"),
629
+ safe_img(p, "magnitude_distribution.png"),
630
+ )
631
+
632
+
633
+ def update_tab4(p):
634
+ """Output Logits."""
635
+ return safe_img(p, "output_logits.png")
636
+
637
+
638
+ def update_tab5(p):
639
+ """Lottery Mechanism: magnitude, phase, contour."""
640
+ return (
641
+ safe_img(p, "lottery_mech_magnitude.png"),
642
+ safe_img(p, "lottery_mech_phase.png"),
643
+ safe_img(p, "lottery_beta_contour.png"),
644
+ )
645
+
646
+
647
+ def update_tab6(p):
648
+ """Grokking: loss/acc charts + analysis images."""
649
+ loss_data = load_json_file(p, "grokk_loss.json")
650
+ acc_data = load_json_file(p, "grokk_acc.json")
651
+ loss_chart = make_loss_chart(loss_data, title="Grokking: Loss")
652
+ acc_chart = make_acc_chart(acc_data, title="Grokking: Accuracy")
653
+ return (
654
+ loss_chart,
655
+ acc_chart,
656
+ safe_img(p, "grokk_abs_phase_diff.png"),
657
+ safe_img(p, "grokk_avg_ipr.png"),
658
+ safe_img(p, "grokk_memorization_accuracy.png"),
659
+ safe_img(p, "grokk_memorization_common_to_rare.png"),
660
+ safe_img(p, "grokk_decoded_weights_dynamic.png"),
661
+ )
662
+
663
+
664
+ def update_tab7(p):
665
+ """Gradient Dynamics: Quad and ReLU single-freq."""
666
+ return (
667
+ safe_img(p, "phase_align_quad.png"),
668
+ safe_img(p, "single_freq_quad.png"),
669
+ safe_img(p, "phase_align_relu.png"),
670
+ safe_img(p, "single_freq_relu.png"),
671
+ )
672
+
673
+
674
+ def update_tab8(p):
675
+ """Decoupled Simulation: 2 analytical gradient flow plots."""
676
+ return (
677
+ safe_img(p, "phase_align_approx1.png"),
678
+ safe_img(p, "phase_align_approx2.png"),
679
+ )
680
+
681
+
682
+ def update_tab9(p):
683
+ """Training Log: return available run names and initial log."""
684
+ data = load_json_file(p, "training_log.json")
685
+ if data is None:
686
+ return [], None, "", ""
687
+ run_names = list(data.keys())
688
+ # Show first run by default
689
+ first_run = run_names[0] if run_names else None
690
+ if first_run:
691
+ run_data = data[first_run]
692
+ config = run_data.get('config', {})
693
+ config_text = _format_config_md(first_run, config)
694
+ log_text = run_data.get('log_text', 'No log available.')
695
+ else:
696
+ config_text = ""
697
+ log_text = ""
698
+ return run_names, first_run, config_text, log_text
699
+
700
+
701
+ def _format_config_md(run_name, config):
702
+ """Format a run config as a Markdown summary."""
703
+ lines = [f"**Run: {run_name}**\n"]
704
+ key_labels = {
705
+ 'prime': 'Modulo (p)', 'd_mlp': 'd_mlp',
706
+ 'act_type': 'Activation', 'init_type': 'Init Type',
707
+ 'init_scale': 'Init Scale', 'optimizer': 'Optimizer',
708
+ 'lr': 'Learning Rate', 'weight_decay': 'Weight Decay',
709
+ 'frac_train': 'Frac Train', 'num_epochs': 'Num Epochs',
710
+ 'seed': 'Seed',
711
+ }
712
+ for key, label in key_labels.items():
713
+ val = config.get(key, 'N/A')
714
+ lines.append(f"- **{label}**: `{val}`")
715
+ return "\n".join(lines)
716
+
717
+
718
+ def update_info(p):
719
+ meta = load_json_file(p, "metadata.json")
720
+ if not meta:
721
+ return f"**p = {p}** | No metadata available"
722
+ d_mlp = meta.get('d_mlp', '?')
723
+ parts = [f"**p = {p}**", f"d_mlp = {d_mlp}"]
724
+ std_metrics = meta.get('final_metrics', {}).get('standard', {})
725
+ if 'train_acc' in std_metrics:
726
+ parts.append(f"Train Acc = {std_metrics['train_acc']:.4f}")
727
+ if 'test_acc' in std_metrics:
728
+ parts.append(f"Test Acc = {std_metrics['test_acc']:.4f}")
729
+ if 'train_loss' in std_metrics:
730
+ parts.append(f"Train Loss = {std_metrics['train_loss']:.6f}")
731
+ return " | ".join(parts)
732
+
733
+
734
+ # ---------------------------------------------------------------------------
735
+ # Interactive callback helpers
736
+ # ---------------------------------------------------------------------------
737
+
738
+ def _get_neuron_choices(p):
739
+ """Return list of neuron keys from neuron_spectra.json."""
740
+ data = load_json_file(p, "neuron_spectra.json")
741
+ if data is None:
742
+ return []
743
+ return list(data.get('neurons', {}).keys())
744
+
745
+
746
+ def _get_pair_choices(p):
747
+ """Return list of (a,b) pair labels from logits_interactive.json."""
748
+ data = load_json_file(p, "logits_interactive.json")
749
+ if data is None:
750
+ return []
751
+ pairs = data.get('pairs', [])
752
+ return [f"({a}, {b})" for a, b in pairs]
753
+
754
+
755
+ def _get_grokk_epochs(p):
756
+ """Return list of epoch values from grokk_epoch_data.json."""
757
+ data = load_json_file(p, "grokk_epoch_data.json")
758
+ if data is None:
759
+ return []
760
+ return data.get('epochs', [])
761
+
762
+
763
+ # ---------------------------------------------------------------------------
764
+ # Markdown helper -- ensures latex_delimiters are set
765
+ # ---------------------------------------------------------------------------
766
+
767
+ def _md(text, **kwargs):
768
+ """Create a gr.Markdown with KaTeX delimiters enabled."""
769
+ return gr.Markdown(text, latex_delimiters=LATEX_DELIMITERS, **kwargs)
770
+
771
+
772
+ # ---------------------------------------------------------------------------
773
+ # Main app
774
+ # ---------------------------------------------------------------------------
775
+
776
+ def on_p_change(p_str):
777
+ """Called when the prime dropdown changes. Returns all outputs."""
778
+ p = int(p_str)
779
+
780
+ info = update_info(p)
781
+
782
+ # Overview
783
+ (t1_img_overview, t1_std_loss, t1_grokk_loss,
784
+ t1_std_ipr, t1_grokk_ipr, t1_phase_scatter) = update_tab1(p)
785
+ # Core Interpretability
786
+ t2_heatmap, t2_line_in, t2_line_out = update_tab2(p)
787
+ t3_phase_dist, t3_phase_rel, t3_magnitude = update_tab3(p)
788
+ t4_logits = update_tab4(p)
789
+ t5_mag, t5_phase, t5_contour = update_tab5(p)
790
+ # Grokking
791
+ (t6_loss, t6_acc, t6_phase_diff, t6_ipr,
792
+ t6_memo, t6_memo_rare, t6_decoded) = update_tab6(p)
793
+ # Theory
794
+ t7_pa_quad, t7_sf_quad, t7_pa_relu, t7_sf_relu = update_tab7(p)
795
+ t8_approx1, t8_approx2 = update_tab8(p)
796
+
797
+ # Training Log
798
+ t9_run_names, t9_default_run, t9_config_text, t9_log = update_tab9(p)
799
+ t9_run_dd_update = gr.update(
800
+ choices=t9_run_names,
801
+ value=t9_default_run,
802
+ )
803
+
804
+ # Interactive widget updates
805
+ neuron_choices = _get_neuron_choices(p)
806
+ neuron_dd_update = gr.update(
807
+ choices=neuron_choices,
808
+ value=neuron_choices[0] if neuron_choices else None,
809
+ )
810
+ neuron_spectra_data = load_json_file(p, "neuron_spectra.json")
811
+ neuron_chart = make_neuron_spectrum_chart(
812
+ neuron_spectra_data, neuron_choices[0]
813
+ ) if neuron_choices else None
814
+
815
+ pair_choices = _get_pair_choices(p)
816
+ pair_dd_update = gr.update(
817
+ choices=pair_choices,
818
+ value=pair_choices[0] if pair_choices else None,
819
+ )
820
+ logit_data = load_json_file(p, "logits_interactive.json")
821
+ logit_chart = make_logit_bar_chart(logit_data, 0) if pair_choices else None
822
+
823
+ grokk_epochs = _get_grokk_epochs(p)
824
+ if grokk_epochs:
825
+ slider_update = gr.update(
826
+ minimum=0, maximum=len(grokk_epochs) - 1, value=0, step=1,
827
+ visible=True,
828
+ )
829
+ else:
830
+ slider_update = gr.update(minimum=0, maximum=0, value=0, visible=False)
831
+ grokk_slider_data = load_json_file(p, "grokk_epoch_data.json")
832
+ grokk_heatmap = make_grokk_heatmap(grokk_slider_data, 0) if grokk_epochs else None
833
+ epoch_label = f"Epoch: {grokk_epochs[0]}" if grokk_epochs else ""
834
+
835
+ return [
836
+ info,
837
+ # Tab 1: Overview
838
+ t1_img_overview, t1_std_loss, t1_grokk_loss,
839
+ t1_std_ipr, t1_grokk_ipr, t1_phase_scatter,
840
+ # Tab 2: Fourier Weights
841
+ t2_heatmap, t2_line_in, t2_line_out,
842
+ neuron_dd_update, neuron_chart,
843
+ # Tab 3: Phase Analysis
844
+ t3_phase_dist, t3_phase_rel, t3_magnitude,
845
+ # Tab 4: Output Logits
846
+ t4_logits,
847
+ pair_dd_update, logit_chart,
848
+ # Tab 5: Lottery Mechanism
849
+ t5_mag, t5_phase, t5_contour,
850
+ # Tab 6: Grokking
851
+ t6_loss, t6_acc, t6_phase_diff, t6_ipr,
852
+ t6_memo, t6_memo_rare, t6_decoded,
853
+ slider_update, grokk_heatmap, epoch_label,
854
+ # Tab 7: Gradient Dynamics
855
+ t7_pa_quad, t7_sf_quad, t7_pa_relu, t7_sf_relu,
856
+ # Tab 8: Decoupled Simulation
857
+ t8_approx1, t8_approx2,
858
+ # Tab 9: Training Log
859
+ t9_run_dd_update, t9_config_text, t9_log,
860
+ ]
861
+
862
+
863
+ def _commit_results_to_repo(p):
864
+ """Try to commit new precomputed results back to the HF Space repo.
865
+
866
+ On HF Spaces, the repo is writable via the huggingface_hub API.
867
+ This allows results to accumulate as users generate them.
868
+ Returns (success, message).
869
+ """
870
+ try:
871
+ from huggingface_hub import HfApi
872
+ except ImportError:
873
+ return False, "huggingface_hub not installed"
874
+
875
+ space_id = os.environ.get("SPACE_ID") # e.g. "username/space-name"
876
+ if not space_id:
877
+ return False, "Not running on HF Spaces (SPACE_ID not set)"
878
+
879
+ result_dir = os.path.join(RESULTS_DIR, f"p_{p:03d}")
880
+ if not os.path.isdir(result_dir):
881
+ return False, "No results directory found"
882
+
883
+ try:
884
+ api = HfApi()
885
+ api.upload_folder(
886
+ folder_path=result_dir,
887
+ path_in_repo=f"precomputed_results/p_{p:03d}",
888
+ repo_id=space_id,
889
+ repo_type="space",
890
+ commit_message=f"Add precomputed results for p={p}",
891
+ )
892
+ return True, f"Committed results for p={p} to {space_id}"
893
+ except Exception as e:
894
+ logger.warning(f"Failed to commit results for p={p}: {e}")
895
+ return False, str(e)
896
+
897
+
898
+ def _run_step_streaming(cmd, env, label):
899
+ """Run a subprocess, yielding (line, error_flag) for each output line."""
900
+ proc = subprocess.Popen(
901
+ cmd, cwd=PROJECT_ROOT, env=env,
902
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
903
+ text=True, bufsize=1,
904
+ )
905
+ for line in proc.stdout:
906
+ yield line.rstrip("\n"), False
907
+ proc.wait()
908
+ if proc.returncode != 0:
909
+ yield f"[ERROR] {label} failed (exit code {proc.returncode})", True
910
+
911
+
912
+ def run_pipeline_for_p_streaming(p):
913
+ """Generator: run full pipeline for p, yielding log lines.
914
+
915
+ Yields (log_line: str, is_error: bool, is_done: bool).
916
+ Deletes model checkpoints after plot generation to save space.
917
+ """
918
+ if p < 3 or p % 2 == 0:
919
+ yield f"Error: p must be an odd number >= 3, got {p}", True, True
920
+ return
921
+ if p > MAX_P_ON_DEMAND:
922
+ yield f"Error: p={p} exceeds on-demand limit of {MAX_P_ON_DEMAND}", True, True
923
+ return
924
+
925
+ result_dir = os.path.join(RESULTS_DIR, f"p_{p:03d}")
926
+ if os.path.isdir(result_dir) and len(os.listdir(result_dir)) > 5:
927
+ yield f"Results for p={p} already exist ({len(os.listdir(result_dir))} files)", False, True
928
+ return
929
+
930
+ env = os.environ.copy()
931
+ env["PYTHONPATH"] = PROJECT_ROOT + ":" + env.get("PYTHONPATH", "")
932
+
933
+ steps = [
934
+ ("Step 1/3: Training 5 configurations", [
935
+ sys.executable, "precompute/train_all.py",
936
+ "--p", str(p), "--output", TRAINED_MODELS_DIR, "--resume",
937
+ ]),
938
+ ("Step 2/3: Generating model-based plots", [
939
+ sys.executable, "precompute/generate_plots.py",
940
+ "--p", str(p), "--input", TRAINED_MODELS_DIR,
941
+ "--output", RESULTS_DIR,
942
+ ]),
943
+ ("Step 3/3: Generating analytical plots", [
944
+ sys.executable, "precompute/generate_analytical.py",
945
+ "--p", str(p), "--output", RESULTS_DIR,
946
+ ]),
947
+ ]
948
+
949
+ for label, cmd in steps:
950
+ yield f"\n{'='*60}", False, False
951
+ yield f" {label} (p={p})", False, False
952
+ yield f"{'='*60}", False, False
953
+ for line, is_err in _run_step_streaming(cmd, env, label):
954
+ if is_err:
955
+ yield line, True, True
956
+ return
957
+ yield line, False, False
958
+
959
+ # Cleanup checkpoints
960
+ model_dir = os.path.join(TRAINED_MODELS_DIR, f"p_{p:03d}")
961
+ if os.path.isdir(model_dir):
962
+ shutil.rmtree(model_dir)
963
+ yield f"Cleaned up checkpoints: {model_dir}", False, False
964
+
965
+ n_files = len(os.listdir(result_dir)) if os.path.isdir(result_dir) else 0
966
+
967
+ # Try to commit results back to the HF repo
968
+ ok_commit, commit_msg = _commit_results_to_repo(p)
969
+ if ok_commit:
970
+ yield f"Results saved to HF repo.", False, False
971
+
972
+ yield f"\nDone! Generated {n_files} files for p={p}.", False, True
973
+
974
+
975
+ def create_app():
976
+ moduli = get_available_moduli()
977
+ p_choices = [str(p) for p in moduli]
978
+ default_p = p_choices[0] if p_choices else None
979
+
980
+ with gr.Blocks(
981
+ title="Modular Addition Feature Learning",
982
+ ) as app:
983
+ _md(
984
+ r"# On the Mechanism and Dynamics of Modular Addition" "\n"
985
+ r"### Fourier Features, Lottery Ticket, and Grokking" "\n\n"
986
+ r"**Jianliang He, Leda Wang, Siyu Chen, Zhuoran Yang**" "\n"
987
+ r"*Department of Statistics and Data Science, Yale University*" "\n\n"
988
+ r'<a href="#">[arXiv]</a> &nbsp; '
989
+ r'<a href="#">[Blog]</a> &nbsp; '
990
+ r'<a href="https://github.com/Y-Agent/modular-addition-feature-learning">[Code]</a>' "\n\n"
991
+ r"---" "\n\n"
992
+ r"This interactive explorer visualizes how a two-layer neural network "
993
+ r"learns modular arithmetic $f(x,y) = (x + y) \bmod p$ through "
994
+ r"Fourier feature learning, lottery ticket dynamics, and the grokking "
995
+ r"phenomenon. Select a modulo $p$ (any odd number $\geq 3$) below to view pre-computed results." "\n\n"
996
+ r"> **Note:** Grokking experiments (Tab 6) require $p \geq 19$ to have enough data for a meaningful train/test split. "
997
+ r"For $p < 19$, grokking plots will not be generated."
998
+ )
999
+
1000
+ # Hidden state for current modulo
1001
+ current_p = gr.State(value=int(default_p) if default_p else 3)
1002
+
1003
+ with gr.Row():
1004
+ p_dropdown = gr.Dropdown(
1005
+ choices=p_choices,
1006
+ value=default_p,
1007
+ label="Select Modulo (p)",
1008
+ interactive=True,
1009
+ scale=2,
1010
+ )
1011
+ info_md = _md(
1012
+ update_info(int(default_p)) if default_p else ""
1013
+ )
1014
+
1015
+ with gr.Accordion("Generate results for a new p", open=False):
1016
+ _md(
1017
+ f"Enter any odd number $p \\geq 3$ (max {MAX_P_ON_DEMAND} "
1018
+ f"for on-demand training). This will train 5 model "
1019
+ f"configurations and generate all plots. Training logs "
1020
+ f"are streamed below in real time."
1021
+ )
1022
+ with gr.Row():
1023
+ new_p_input = gr.Number(
1024
+ value=None, label="New p (odd, ≥ 3)",
1025
+ precision=0, scale=1,
1026
+ )
1027
+ generate_btn = gr.Button(
1028
+ "Generate", variant="primary", scale=1,
1029
+ )
1030
+ generate_status = _md("")
1031
+ generate_log = gr.Code(
1032
+ value="", language=None, label="Pipeline Log",
1033
+ lines=15, interactive=False, visible=False,
1034
+ )
1035
+
1036
+ # ----- Tabs -----
1037
+ with gr.Tabs():
1038
+
1039
+ # === Core Interpretability ===
1040
+
1041
+ # Tab 1: Overview
1042
+ with gr.Tab("1. Overview"):
1043
+ _md(MATH_TAB1)
1044
+ t1_img_overview = gr.Image(
1045
+ label="Loss & IPR Overview (Static)", type="filepath"
1046
+ )
1047
+ with gr.Row():
1048
+ t1_std_loss = gr.Plot(label="Standard: Loss")
1049
+ t1_grokk_loss = gr.Plot(label="Grokking: Loss")
1050
+ with gr.Row():
1051
+ t1_std_ipr = gr.Plot(label="Standard: IPR")
1052
+ t1_grokk_ipr = gr.Plot(label="Grokking: IPR")
1053
+ t1_phase_scatter = gr.Image(
1054
+ label="Phase Alignment: \u03c8 = 2\u03c6", type="filepath"
1055
+ )
1056
+
1057
+ # Tab 2: Fourier Weights
1058
+ with gr.Tab("2. Fourier Weights"):
1059
+ _md(MATH_TAB2)
1060
+ t2_heatmap = gr.Image(label="Decoded W_in / W_out Heatmap", type="filepath")
1061
+ with gr.Row():
1062
+ t2_line_in = gr.Image(label="First-Layer Line Plots (with cosine fit)", type="filepath")
1063
+ t2_line_out = gr.Image(label="Second-Layer Line Plots (with cosine fit)", type="filepath")
1064
+ _md("#### Neuron Frequency Inspector")
1065
+ t2_neuron_dd = gr.Dropdown(
1066
+ choices=[], value=None,
1067
+ label="Select Neuron", interactive=True,
1068
+ )
1069
+ t2_neuron_chart = gr.Plot(label="Neuron Fourier Spectrum")
1070
+
1071
+ # Tab 3: Phase Analysis
1072
+ with gr.Tab("3. Phase Analysis"):
1073
+ _md(MATH_TAB3)
1074
+ with gr.Row():
1075
+ t3_phase_dist = gr.Image(label="Phase Distribution", type="filepath")
1076
+ t3_phase_rel = gr.Image(
1077
+ label="Phase Relationship (2\u03c6 vs \u03c8)", type="filepath"
1078
+ )
1079
+ t3_magnitude = gr.Image(label="Magnitude Distribution", type="filepath")
1080
+
1081
+ # Tab 4: Output Logits
1082
+ with gr.Tab("4. Output Logits"):
1083
+ _md(MATH_TAB4)
1084
+ t4_logits = gr.Image(label="Output Logits Heatmap", type="filepath")
1085
+ _md("#### Logit Explorer")
1086
+ t4_pair_dd = gr.Dropdown(
1087
+ choices=[], value=None,
1088
+ label="Select Input Pair (a, b)", interactive=True,
1089
+ )
1090
+ t4_logit_chart = gr.Plot(label="Logit Distribution")
1091
+
1092
+ # Tab 5: Lottery Mechanism
1093
+ with gr.Tab("5. Lottery Mechanism"):
1094
+ _md(MATH_TAB5)
1095
+ _md(r"""**Magnitude plot** below: Each curve tracks one frequency's output magnitude $\beta_k$ within a single neuron over training. All frequencies start with equal magnitude (from random initialization). The winning frequency (best initial phase alignment) grows explosively while others remain frozen.""")
1096
+ t5_mag = gr.Image(label="Frequency Magnitude Evolution", type="filepath")
1097
+ _md(r"""**Phase plot** below: Each curve shows the phase misalignment $\mathcal{D}_k = 2\phi_k - \psi_k$ for one frequency within the same neuron. The winning frequency (colored) converges to $\mathcal{D} = 0$ (perfect alignment) first; other frequencies barely change because their magnitudes remain small.""")
1098
+ t5_phase = gr.Image(label="Phase Misalignment Convergence", type="filepath")
1099
+ _md(r"""**Contour plot** below: Final output magnitude as a function of initial magnitude and initial phase misalignment, across all neurons. The largest final magnitudes (brightest regions) appear at small initial misalignment $|\mathcal{D}|$, confirming that initial phase alignment -- not initial magnitude -- determines which frequency wins.""")
1100
+ t5_contour = gr.Image(label="Final Magnitude Contour", type="filepath")
1101
+
1102
+ # === Grokking ===
1103
+
1104
+ # Tab 6: Grokking
1105
+ with gr.Tab("6. Grokking"):
1106
+ _md(MATH_TAB6)
1107
+
1108
+ _md(r"""#### (a) Loss and (b) Accuracy
1109
+
1110
+ **(a) Loss:** Training loss (blue) drops rapidly in Stage I as the network memorizes training data. Test loss (red) stays high until Stage II, when weight decay forces the network to find a generalizing solution, causing test loss to plummet. The three colored bands mark the three stages.
1111
+
1112
+ **(b) Accuracy:** Training accuracy reaches 100% early (Stage I). Test accuracy stays at ~70% during memorization (not 50% -- the built-in symmetry $f(a,b) = f(b,a)$ gives "free" correct answers for the swapped pair). Test accuracy jumps sharply in Stage II when the network transitions from memorization to Fourier features.""")
1113
+ with gr.Row():
1114
+ t6_loss = gr.Plot(label="Grokking Loss (Interactive)")
1115
+ t6_acc = gr.Plot(label="Grokking Accuracy (Interactive)")
1116
+
1117
+ _md(r"""#### (c) Phase Alignment and (d) IPR & Norms
1118
+
1119
+ **(c) Phase alignment:** Average $|\sin(\mathcal{D}_m^\star)|$ over all neurons, where $\mathcal{D}_m^\star = 2\phi_m^\star - \psi_m^\star$ is the phase misalignment at each neuron's dominant frequency. This measures how far the network is from the ideal relationship $\psi = 2\phi$. It decreases throughout training as phases align, with the steepest drop during Stage II.
1120
+
1121
+ **(d) IPR and parameter norms:** IPR (Fourier sparsity) increases sharply in Stage II -- this is the "aha" moment where multi-frequency noise collapses into clean single-frequency features. Parameter norms shrink steadily in Stage III as weight decay slowly polishes the solution.""")
1122
+ with gr.Row():
1123
+ t6_phase_diff = gr.Image(
1124
+ label="Phase Difference |sin(D*)|", type="filepath"
1125
+ )
1126
+ t6_ipr = gr.Image(label="IPR & Parameter Norms", type="filepath")
1127
+
1128
+ _md(r"""#### (e) Memorization Accuracy Grid
1129
+
1130
+ Each cell $(i,j)$ in the grid shows whether the network correctly predicts $(i+j) \bmod p$ at a given training epoch. **White = correct, dark = incorrect.** Training pairs are marked with dots.
1131
+
1132
+ During Stage I, the network first memorizes **symmetric pairs** -- pairs where both $(i,j)$ and $(j,i)$ are in the training set (they appear on both sides of the diagonal). These are learned first because the architecture treats inputs symmetrically: $\theta_m[i] + \theta_m[j] = \theta_m[j] + \theta_m[i]$, so learning one automatically gives the other.
1133
+
1134
+ **Asymmetric pairs** (where only one of $(i,j)$ or $(j,i)$ is in training) are harder to memorize and are learned later. Some test pairs may even be *actively suppressed* (the network gets them wrong on purpose) before eventually being memorized.""")
1135
+ t6_memo = gr.Image(label="Memorization Accuracy", type="filepath")
1136
+
1137
+ _md(r"""#### (f) Common-to-Rare Ordering
1138
+
1139
+ This plot reorders the accuracy grid to reveal the **memorization sequence**. Instead of plotting by input value, it sorts pairs by how "common" they are in the training set:
1140
+
1141
+ - **Common pairs** (top-left): Both $(i,j)$ and $(j,i)$ in training set. These are memorized first.
1142
+ - **Rare pairs** (bottom-right): Only one ordering in training set. These are memorized last, and may be temporarily suppressed before being learned.
1143
+
1144
+ The plot shows a clear **top-left to bottom-right** progression, confirming that the network memorizes common pairs before rare ones.""")
1145
+ t6_memo_rare = gr.Image(label="Memorization: Common to Rare", type="filepath")
1146
+
1147
+ _md(r"""#### (g) Decoded Weights Across Stages
1148
+
1149
+ DFT heatmaps of the network's weights at key epochs through the three stages. Each row is a neuron; each column is a Fourier frequency component.
1150
+
1151
+ - **Stage I (memorization):** Weights are noisy with energy spread across many frequencies -- the network is using all available capacity to memorize.
1152
+ - **Stage II (generalization):** Weight decay kills the weak frequencies. Each neuron's energy concentrates into a single frequency -- clean Fourier features emerge.
1153
+ - **Stage III (cleanup):** Features are already clean; weight decay slowly shrinks overall magnitude without changing the structure.""")
1154
+ t6_decoded = gr.Image(label="Decoded Weights Across Stages", type="filepath")
1155
+
1156
+ _md(r"""#### Accuracy Grid Across Training (Interactive)
1157
+
1158
+ Use the slider to scrub through training epochs and watch the accuracy grid evolve. In Stage I, you'll see the symmetric pairs (along both diagonals) light up first, then asymmetric pairs fill in, and finally the entire grid becomes white in Stage II as the network generalizes.""")
1159
+ t6_slider = gr.Slider(
1160
+ minimum=0, maximum=0, value=0, step=1,
1161
+ label="Epoch Snapshot Index", interactive=True,
1162
+ visible=False,
1163
+ )
1164
+ t6_heatmap = gr.Plot(label="Accuracy Heatmap")
1165
+ t6_epoch_label = _md("")
1166
+
1167
+ # === Theory ===
1168
+
1169
+ # Tab 7: Gradient Dynamics
1170
+ with gr.Tab("7. Gradient Dynamics"):
1171
+ _md(MATH_TAB7)
1172
+ _md(r"""#### Quadratic Activation ($\sigma(x) = x^2$)
1173
+
1174
+ **Left -- Phase alignment:** Tracks the input phase $\phi_m^\star$, output phase $\psi_m^\star$, and doubled input phase $2\phi_m^\star$ of the dominant frequency in a single neuron over training. The theory predicts $\psi \to 2\phi$; here we see $\psi$ (red) and $2\phi$ (blue) converge and overlap, confirming phase alignment. The phases lock in early while magnitudes are still small.
1175
+
1176
+ **Right -- DFT heatmaps:** Decoded weights in Fourier space at key training steps. At step 0, the neuron starts with energy at a single frequency (by construction -- single-frequency initialization). At later steps, the dominant frequency grows while all other frequencies stay at zero. This confirms the **single-frequency preservation theorem**: Fourier orthogonality prevents energy leakage between modes.""")
1177
+ with gr.Row():
1178
+ t7_pa_quad = gr.Image(label="Phase Alignment (Quad)", type="filepath")
1179
+ t7_sf_quad = gr.Image(label="Decoded Weights (Quad)", type="filepath")
1180
+ _md(r"""#### ReLU Activation ($\sigma(x) = \max(0, x)$)
1181
+
1182
+ **Left -- Phase alignment:** Same as quadratic above, but with ReLU. The qualitative behavior is identical: $\psi$ converges to $2\phi$. Minor quantitative differences arise because ReLU is not exactly $x^2$.
1183
+
1184
+ **Right -- DFT heatmaps:** Unlike quadratic, ReLU leaks small amounts of energy to **harmonic multiples** of the dominant frequency ($3k^\star, 5k^\star, \ldots$ for input weights; $2k^\star, 3k^\star, \ldots$ for output weights). This leakage decays as $O(r^{-2})$ where $r$ is the harmonic order, so the dominant frequency remains overwhelmingly dominant. The faint "stripes" at harmonic positions are this leakage.""")
1185
+ with gr.Row():
1186
+ t7_pa_relu = gr.Image(label="Phase Alignment (ReLU)", type="filepath")
1187
+ t7_sf_relu = gr.Image(label="Decoded Weights (ReLU)", type="filepath")
1188
+
1189
+ # Tab 8: Decoupled Simulation
1190
+ with gr.Tab("8. Decoupled Simulation"):
1191
+ _md(MATH_TAB8)
1192
+ _md(r"""Each 3-panel figure below shows one simulation run. The gray curves are non-winning frequencies; the colored curves are the winning frequency $k^\star$.
1193
+
1194
+ **Top panel -- Phase alignment:** $\psi_{k^\star}$ (red) and $2\phi_{k^\star}$ (blue) converge toward each other, confirming that training drives phases into the $\psi = 2\phi$ relationship even in this pure ODE setting (no neural network).
1195
+
1196
+ **Middle panel -- Phase difference $D_{k^\star}$:** The misalignment $\mathcal{D}_{k^\star} = 2\phi_{k^\star} - \psi_{k^\star}$ converges toward $0$ (or $\pi/2$ transiently in Case 1). The dashed horizontal line marks $\pi/2$. Non-winning frequencies (gray) remain scattered because their magnitudes are too small to drive phase alignment.
1197
+
1198
+ **Bottom panel -- Magnitude evolution:** The winning frequency's magnitudes ($\alpha_{k^\star}$ and $\beta_{k^\star}$) grow explosively once phase alignment is achieved, while all other frequencies remain near their initial values. This is the lottery ticket effect in pure form.""")
1199
+ with gr.Row():
1200
+ t8_approx1 = gr.Image(
1201
+ label="Gradient Flow (Case 1: with annotations)", type="filepath"
1202
+ )
1203
+ t8_approx2 = gr.Image(label="Gradient Flow (Case 2)", type="filepath")
1204
+
1205
+ # Tab 9: Training Log
1206
+ with gr.Tab("9. Training Log"):
1207
+ _md(MATH_TAB9)
1208
+ t9_run_dd = gr.Dropdown(
1209
+ choices=[], value=None,
1210
+ label="Select Training Run", interactive=True,
1211
+ )
1212
+ t9_config_md = _md("")
1213
+ t9_log_text = gr.Code(
1214
+ value="", language=None, label="Training Log",
1215
+ lines=30, interactive=False,
1216
+ )
1217
+
1218
+ # All outputs for prime change
1219
+ all_outputs = [
1220
+ info_md,
1221
+ # Tab 1: Overview
1222
+ t1_img_overview, t1_std_loss, t1_grokk_loss,
1223
+ t1_std_ipr, t1_grokk_ipr, t1_phase_scatter,
1224
+ # Tab 2
1225
+ t2_heatmap, t2_line_in, t2_line_out,
1226
+ t2_neuron_dd, t2_neuron_chart,
1227
+ # Tab 3
1228
+ t3_phase_dist, t3_phase_rel, t3_magnitude,
1229
+ # Tab 4
1230
+ t4_logits,
1231
+ t4_pair_dd, t4_logit_chart,
1232
+ # Tab 5
1233
+ t5_mag, t5_phase, t5_contour,
1234
+ # Tab 6
1235
+ t6_loss, t6_acc, t6_phase_diff, t6_ipr,
1236
+ t6_memo, t6_memo_rare, t6_decoded,
1237
+ t6_slider, t6_heatmap, t6_epoch_label,
1238
+ # Tab 7
1239
+ t7_pa_quad, t7_sf_quad, t7_pa_relu, t7_sf_relu,
1240
+ # Tab 8
1241
+ t8_approx1, t8_approx2,
1242
+ # Tab 9
1243
+ t9_run_dd, t9_config_md, t9_log_text,
1244
+ ]
1245
+
1246
+ # --- Prime change handler ---
1247
+ def p_change_and_store(p_str):
1248
+ p = int(p_str)
1249
+ results = on_p_change(p_str)
1250
+ return [p] + results
1251
+
1252
+ p_dropdown.change(
1253
+ fn=p_change_and_store,
1254
+ inputs=[p_dropdown],
1255
+ outputs=[current_p] + all_outputs,
1256
+ )
1257
+
1258
+ # --- Neuron dropdown handler ---
1259
+ def on_neuron_change(neuron_key, p):
1260
+ data = load_json_file(p, "neuron_spectra.json")
1261
+ return make_neuron_spectrum_chart(data, neuron_key)
1262
+
1263
+ t2_neuron_dd.change(
1264
+ fn=on_neuron_change,
1265
+ inputs=[t2_neuron_dd, current_p],
1266
+ outputs=[t2_neuron_chart],
1267
+ )
1268
+
1269
+ # --- Logit pair dropdown handler ---
1270
+ def on_pair_change(pair_str, p):
1271
+ data = load_json_file(p, "logits_interactive.json")
1272
+ if data is None or not pair_str:
1273
+ return None
1274
+ pairs = data.get('pairs', [])
1275
+ pair_labels = [f"({a}, {b})" for a, b in pairs]
1276
+ if pair_str in pair_labels:
1277
+ idx = pair_labels.index(pair_str)
1278
+ else:
1279
+ idx = 0
1280
+ return make_logit_bar_chart(data, idx)
1281
+
1282
+ t4_pair_dd.change(
1283
+ fn=on_pair_change,
1284
+ inputs=[t4_pair_dd, current_p],
1285
+ outputs=[t4_logit_chart],
1286
+ )
1287
+
1288
+ # --- Grokking slider handler ---
1289
+ def on_grokk_slider(slider_val, p):
1290
+ data = load_json_file(p, "grokk_epoch_data.json")
1291
+ if data is None:
1292
+ return None, ""
1293
+ idx = int(slider_val)
1294
+ epochs = data.get('epochs', [])
1295
+ label = f"**Epoch: {epochs[idx]}**" if idx < len(epochs) else ""
1296
+ return make_grokk_heatmap(data, idx), label
1297
+
1298
+ t6_slider.change(
1299
+ fn=on_grokk_slider,
1300
+ inputs=[t6_slider, current_p],
1301
+ outputs=[t6_heatmap, t6_epoch_label],
1302
+ )
1303
+
1304
+ # --- Training log run dropdown handler ---
1305
+ def on_log_run_change(run_name, p):
1306
+ data = load_json_file(p, "training_log.json")
1307
+ if data is None or run_name not in data:
1308
+ return "", ""
1309
+ run_data = data[run_name]
1310
+ config = run_data.get('config', {})
1311
+ config_text = _format_config_md(run_name, config)
1312
+ log_text = run_data.get('log_text', 'No log available.')
1313
+ return config_text, log_text
1314
+
1315
+ t9_run_dd.change(
1316
+ fn=on_log_run_change,
1317
+ inputs=[t9_run_dd, current_p],
1318
+ outputs=[t9_config_md, t9_log_text],
1319
+ )
1320
+
1321
+ # --- On-demand training handler (streaming) ---
1322
+ def on_generate_click(new_p):
1323
+ if new_p is None:
1324
+ yield (
1325
+ gr.update(), gr.update(),
1326
+ "Enter a value for p.",
1327
+ gr.update(visible=False, value=""),
1328
+ )
1329
+ return
1330
+ p = int(new_p)
1331
+ log_lines = []
1332
+ yield (
1333
+ gr.update(), gr.update(),
1334
+ f"**Running pipeline for p={p}...**",
1335
+ gr.update(visible=True, value="Starting...\n"),
1336
+ )
1337
+ for line, is_err, is_done in run_pipeline_for_p_streaming(p):
1338
+ log_lines.append(line)
1339
+ # Keep last 200 lines to avoid memory bloat
1340
+ display = "\n".join(log_lines[-200:])
1341
+ if is_err:
1342
+ yield (
1343
+ gr.update(), gr.update(),
1344
+ f"**Error:** {line}",
1345
+ gr.update(value=display),
1346
+ )
1347
+ return
1348
+ if is_done:
1349
+ new_moduli = get_available_moduli()
1350
+ new_choices = [str(v) for v in new_moduli]
1351
+ yield (
1352
+ gr.update(choices=new_choices, value=str(p)),
1353
+ gr.update(),
1354
+ f"**Success:** {line}",
1355
+ gr.update(value=display),
1356
+ )
1357
+ return
1358
+ yield (
1359
+ gr.update(), gr.update(),
1360
+ f"**Running pipeline for p={p}...**",
1361
+ gr.update(value=display),
1362
+ )
1363
+
1364
+ generate_btn.click(
1365
+ fn=on_generate_click,
1366
+ inputs=[new_p_input],
1367
+ outputs=[p_dropdown, current_p, generate_status, generate_log],
1368
+ )
1369
+
1370
+ return app
1371
+
1372
+
1373
+ if __name__ == "__main__":
1374
+ app = create_app()
1375
+ app.launch(theme=gr.themes.Soft(), css=CUSTOM_CSS)
hf_app/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0
2
+ torch>=2.0.0
3
+ --extra-index-url https://download.pytorch.org/whl/cpu
4
+ numpy>=1.24
5
+ matplotlib>=3.7
6
+ Pillow>=9.0
7
+ plotly>=5.0
8
+ einops>=0.6
9
+ scipy>=1.10
precompute/README.md ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pre-computation Pipeline
2
+
3
+ Batch training and plot generation for all odd moduli $p$ in [3, 199]. Trains 5 model configurations per $p$ and generates publication-quality figures plus interactive JSON data files covering the paper's core results.
4
+
5
+ All commands are run from the **project root directory**.
6
+
7
+ ## Quick Start (Shell Script)
8
+
9
+ The easiest way to run the full pipeline for a single modulus:
10
+
11
+ ```bash
12
+ # Run the complete pipeline for p=23
13
+ bash precompute/run_pipeline.sh 23
14
+
15
+ # Or using an environment variable
16
+ P=23 bash precompute/run_pipeline.sh
17
+ ```
18
+
19
+ This runs training, plot generation, analytical simulation, and verification in sequence.
20
+
21
+ ## Complete Pipeline (Single Modulus, Manual Steps)
22
+
23
+ ```bash
24
+ # Step 1: Train all 5 model configurations
25
+ python precompute/train_all.py --p 23 --output ./trained_models
26
+
27
+ # Step 2: Generate model-based plots (21 PNGs + 6 JSONs + metadata)
28
+ python precompute/generate_plots.py --p 23 --input ./trained_models --output ./precomputed_results
29
+
30
+ # Step 3: Generate analytical simulation plots (2 PNGs, no model needed)
31
+ python precompute/generate_analytical.py --p 23 --output ./precomputed_results
32
+
33
+ # Step 4: Verify
34
+ ls precomputed_results/p_023/
35
+ ```
36
+
37
+ ## Complete Pipeline (All Odd p)
38
+
39
+ ```bash
40
+ # Train everything (225 runs total). Use --resume to skip completed runs.
41
+ python precompute/train_all.py --all --output ./trained_models --resume
42
+
43
+ # Generate all plots
44
+ python precompute/generate_plots.py --all --input ./trained_models --output ./precomputed_results
45
+
46
+ # Generate all analytical plots
47
+ python precompute/generate_analytical.py --all --output ./precomputed_results
48
+ ```
49
+
50
+ ---
51
+
52
+ ## The 5 Model Configurations
53
+
54
+ Each modulus is trained with 5 configurations that correspond to different sections of the paper:
55
+
56
+ ### 1. Standard Training (`standard`)
57
+
58
+ The baseline experiment for Parts I--II (Mechanism & Dynamics). Demonstrates Fourier feature learning: neurons decompose modular addition into sparse frequency components with phase alignment (ψ ≈ 2φ).
59
+
60
+ | Parameter | Value |
61
+ |-----------|-------|
62
+ | Activation | ReLU |
63
+ | Initialization | random |
64
+ | Optimizer | AdamW |
65
+ | Learning rate | 5e-5 |
66
+ | Weight decay | 0 |
67
+ | Train fraction | 1.0 (all p² pairs) |
68
+ | Epochs | 5,000 |
69
+ | Init scale | 0.1 |
70
+
71
+ **Used by:** Tab 1 (Overview), Tab 2 (Fourier Weights), Tab 3 (Phase Analysis), Tab 4 (Output Logits)
72
+
73
+ ### 2. Grokking (`grokking`)
74
+
75
+ Reproduces the grokking phenomenon (Part III). The model memorizes training data first, then suddenly generalizes. Requires partial training data + weight decay.
76
+
77
+ | Parameter | Value |
78
+ |-----------|-------|
79
+ | Activation | ReLU |
80
+ | Initialization | random |
81
+ | Optimizer | AdamW |
82
+ | Learning rate | 1e-4 |
83
+ | Weight decay | **2.0** |
84
+ | Train fraction | **0.75** |
85
+ | Epochs | **50,000** |
86
+ | Init scale | 0.1 |
87
+
88
+ **Used by:** Tab 1 (Overview, grokking curves), Tab 6 (Grokking)
89
+ **Note:** Only runs for p ≥ 19 (smaller $p$ have too few test points for meaningful grokking).
90
+
91
+ ### 3. Quadratic Activation (`quad_random`)
92
+
93
+ Uses σ(x) = x² activation. The quadratic nonlinearity directly implements the frequency factorization mechanism from the theory, enabling clean analysis of the lottery ticket mechanism.
94
+
95
+ | Parameter | Value |
96
+ |-----------|-------|
97
+ | Activation | **Quad** |
98
+ | Initialization | random |
99
+ | Optimizer | AdamW |
100
+ | Learning rate | 5e-5 |
101
+ | Weight decay | 0 |
102
+ | Train fraction | 1.0 |
103
+ | Epochs | 5,000 |
104
+ | Init scale | 0.1 |
105
+
106
+ **Used by:** Tab 5 (Lottery Mechanism)
107
+
108
+ ### 4. Single-Frequency Quad (`quad_single_freq`)
109
+
110
+ Initializes neurons at specific frequencies to study gradient dynamics under controlled conditions. Validates the phase alignment theorem and single-frequency preservation theorem.
111
+
112
+ | Parameter | Value |
113
+ |-----------|-------|
114
+ | Activation | **Quad** |
115
+ | Initialization | **single-freq** |
116
+ | Optimizer | **SGD** |
117
+ | Learning rate | **0.1** |
118
+ | Weight decay | 0 |
119
+ | Train fraction | 1.0 |
120
+ | Epochs | 5,000 |
121
+ | Init scale | **0.02** |
122
+
123
+ **Used by:** Tab 7 (Gradient Dynamics, quadratic panels)
124
+
125
+ ### 5. Single-Frequency ReLU (`relu_single_freq`)
126
+
127
+ Same as above but with ReLU activation. Shows that the theoretical results (proved for quadratic) hold approximately for ReLU with small harmonic leakage.
128
+
129
+ | Parameter | Value |
130
+ |-----------|-------|
131
+ | Activation | **ReLU** |
132
+ | Initialization | **single-freq** |
133
+ | Optimizer | **SGD** |
134
+ | Learning rate | **0.01** |
135
+ | Weight decay | 0 |
136
+ | Train fraction | 1.0 |
137
+ | Epochs | 5,000 |
138
+ | Init scale | **0.002** |
139
+
140
+ **Used by:** Tab 7 (Gradient Dynamics, ReLU panels)
141
+
142
+ ---
143
+
144
+ ## Neuron Sizing
145
+
146
+ The number of hidden neurons scales with $p$ to maintain the ratio from the baseline experiment ($p=23$, $d_\text{mlp}=512$):
147
+
148
+ ```
149
+ d_mlp = max(512, ceil(512/529 * p²))
150
+ ```
151
+
152
+ Examples: $p=3 \to 512$, $p=23 \to 512$, $p=53 \to 2720$, $p=97 \to 9108$, $p=199 \to 38329$.
153
+
154
+ ---
155
+
156
+ ## Blog Figure → Pipeline Output Mapping
157
+
158
+ The table below maps every figure in the blog post to the corresponding file generated by the pipeline. Each figure is reproduced for every $p$, allowing users to verify the paper's claims across different moduli.
159
+
160
+ ### Part I: Mechanism (Tabs 2--4, standard run)
161
+
162
+ | Blog Figure | Description | Pipeline Output | Tab | Verified? |
163
+ |-------------|-------------|----------------|-----|-----------|
164
+ | **Fig. 2** — Fourier sparsity of learned weights | DFT heatmap: each row is a neuron, each column is a Fourier mode (cos k, sin k). Sparse = one bright cell per row, confirming single-frequency specialization. | `pXXX_full_training_para_origin.png` | 2 | The heatmap applies `W_in @ fourier_basis.T` and `W_out.T @ fourier_basis.T` to show DFT coefficients. X-axis labels are Fourier mode names (Const, cos 1, sin 1, ...). Sparsity is visible as one dominant pair per neuron row. |
165
+ | **Fig. 3** — Cosine fits to individual neurons | Raw learned weight values (dots) vs. best-fit cosine (dashed) for 3 representative neurons. Left: input weights θ_m. Right: output weights ξ_m. | `pXXX_lineplot_in.png`, `pXXX_lineplot_out.png` | 2 | Projects raw weights into Fourier space, keeps top-2 components, projects back. The fit quality demonstrates that each neuron is well-described by a single cosine. |
166
+ | **Fig. 4** — Phase alignment ψ = 2φ | Scatter plot of (2φ_m mod 2π) vs (ψ_m mod 2π). All points lie on the diagonal y = x. | `pXXX_phase_relationship.png` | 3 | Computed via `compute_neuron()` for every neuron. The diagonal pattern is Observation 2 from the paper. |
167
+ | **Fig. 5** — Higher-order phase symmetry | Polar plot: phase angles ι·φ_m on concentric rings for ι = 1, 2, 3, 4. Uniform spread confirms the cancellation condition Σ exp(i·ι·φ_m) ≈ 0. | `pXXX_phase_distribution.png` | 3 | Shows phases for the most common frequency group. For large p with many neurons, the uniform spread is clearly visible. |
168
+ | **Fig. 6** — Magnitude homogeneity | Violin plots of α_m (input) and β_m (output) across all neurons. Tight concentration confirms magnitude homogeneity (Observation 3c). | `pXXX_magnitude_distribution.png` | 3 | Uses `compute_neuron()` to extract scale for every neuron. |
169
+ | **Fig. 7** — Output logits (flawed indicator) | Heatmap of f(x,y)[j] for x=0. Bright red diagonal at j=(x+y) mod p (correct answer, coefficient p/2). Faint pink at j=2x mod p and j=2y mod p (spurious peaks, coefficient p/4). | `pXXX_output_logits.png` | 4 | Forward pass through the trained model with **matching activation** (ReLU for standard run). Rectangles highlight the correct answer and spurious peak positions. |
170
+
171
+ ### Part II: Dynamics (Tabs 5, 7, 8)
172
+
173
+ | Blog Figure | Description | Pipeline Output | Tab | Verified? |
174
+ |-------------|-------------|----------------|-----|-----------|
175
+ | **Fig. 8** — Phase alignment dynamics | Phase trajectories (φ, ψ, 2φ) and magnitude growth (α, β) over training. Left: Quad activation. Right: ReLU. Shows ψ → 2φ convergence. | `pXXX_phase_align_quad.png`, `pXXX_phase_align_relu.png` | 7 | Tracks the neuron with largest final scale across all checkpoints. Shows phases converging and magnitudes growing. |
176
+ | **Fig. 9** — Lottery ticket race | Left: phase misalignment D_m^k(t) for all frequencies within one neuron. The winner (smallest initial D) converges first. Right: magnitude β_m^k(t). Winner grows explosively. | `pXXX_lottery_mech_phase.png`, `pXXX_lottery_mech_magnitude.png` | 5 | Tracks all frequency components of a single neuron via `decode_scales_phis()` across checkpoints from the `quad_random` run. The winning frequency (highlighted in red) has the smallest initial misalignment. |
177
+ | **Fig. 10** — Lottery outcome contour | Final magnitude β as a function of (initial magnitude, initial phase difference 2φ₀). Largest values at small D, symmetric about D = π. | `pXXX_lottery_beta_contour.png` | 5 | Simulates gradient flow on a 30×30 grid of initial conditions. Each point runs 100 steps of the analytical ODE. |
178
+ | **Fig. 11** — Single-frequency preservation (Quad) | DFT heatmap at multiple training timepoints. The initialized frequency retains all energy; no cross-frequency leakage. | `pXXX_single_freq_quad.png` | 7 | Shows DFT of weights at 3 timepoints (step 0, mid, final) for the `quad_single_freq` run. Each column is a Fourier mode; sparsity confirms preservation. |
179
+ | **Fig. 12a** — Single-frequency preservation (ReLU) | Same as Fig. 11 but with ReLU. Small harmonic leakage visible at 3k*, 5k* (input) and 2k*, 3k* (output), decaying as O(r⁻²). | `pXXX_single_freq_relu.png` | 7 | Shows DFT at 2 timepoints (step 0, final) for the `relu_single_freq` run. Dominant frequency overwhelms harmonics. |
180
+ | **Fig. 12b** — Phase alignment under ReLU | Phase and magnitude trajectories for ReLU single-frequency init. Same zero-attractor behavior as Quad. | `pXXX_phase_align_relu.png` | 7 | Same as Fig. 8 right panel. |
181
+ | — Decoupled ODE simulation | Pure ODE integration (no neural network) showing phase convergence and magnitude competition for all frequencies within one neuron. Two cases with different initial conditions. | `pXXX_phase_align_approx1.png`, `pXXX_phase_align_approx2.png` | 8 | Generated by `generate_analytical.py`, not `generate_plots.py`. Validates the theory without any training. |
182
+
183
+ ### Part III: Grokking (Tab 6)
184
+
185
+ | Blog Figure | Description | Pipeline Output | Tab | Verified? |
186
+ |-------------|-------------|----------------|-----|-----------|
187
+ | **Fig. 13a** — Grokking loss curves | Training and test loss over 50k epochs. Three stages: (I) train loss drops, (II) test loss drops, (III) both near zero. Stage boundaries marked. | `pXXX_grokk_loss.json` → interactive Plotly chart | 6 | From `training_curves.json`. Stage boundaries detected by `grokking_stage_detector.py`. Shaded regions distinguish the three stages. |
188
+ | **Fig. 13b** — Grokking accuracy curves | Training and test accuracy. Train → 100% in Stage I, test jumps in Stage II. | `pXXX_grokk_acc.json` → interactive Plotly chart | 6 | Computed by running forward pass on train/test data at each checkpoint. |
189
+ | **Fig. 13c** — Phase alignment progress | Average |sin(D_m*)| over training. Decreases throughout, steepest in Stage II. | `pXXX_grokk_abs_phase_diff.png` | 6 | Computed via `decode_weights()` + `compute_neuron()` at each grokking checkpoint. |
190
+ | **Fig. 13d** — IPR and parameter norm | Dual-axis: IPR (Fourier sparsity) increases sharply in Stage II; parameter norm shrinks in Stage III. | `pXXX_grokk_avg_ipr.png` | 6 | IPR uses the corrected per-frequency magnitude formula: A_k = sqrt(c_k² + s_k²), IPR = Σ A_k⁴ / (Σ A_k²)². Parameter norms from `training_curves.json`. |
191
+ | **Fig. 14** — Memorization accuracy heatmap | Three panels at end of Stage I: (1) training data distribution under symmetry, (2) accuracy grid, (3) softmax probability at ground truth. Red rectangles = true test pairs. | `pXXX_grokk_memorization_accuracy.png` | 6 | Forward pass at the checkpoint closest to stage1_end. The symmetric architecture guarantees ~70% test accuracy during memorization. |
192
+ | **Fig. 15** — Common-to-rare memorization | Four panels: training data distribution + accuracy at 3 timepoints during Stage I. Shows common pairs (both (i,j) and (j,i) in train) memorized before rare pairs (only one ordering). | `pXXX_grokk_memorization_common_to_rare.png` | 6 | Epochs selected at 0, stage1_end/2, stage1_end. Red rectangles mark asymmetric training pairs. |
193
+ | **Fig. 16** — Weight evolution during grokking | 2×3 DFT heatmap grid showing θ_m and ξ_m at Step 0 (random init), end of Stage I (noisy multi-frequency), and end of Stage II (clean single-frequency). | `pXXX_grokk_decoded_weights_dynamic.png` | 6 | DFT coefficients `W @ fourier_basis.T` at 3 key epochs. The transition from diffuse to sparse confirms the sparsification narrative. |
194
+
195
+ ### Overview (Tab 1)
196
+
197
+ | Blog Figure | Description | Pipeline Output | Tab | Verified? |
198
+ |-------------|-------------|----------------|-----|-----------|
199
+ | — Overview dashboard | 2×2 grid: standard loss + grokking loss (top), standard IPR + grokking IPR (bottom). Plus phase scatter from standard final checkpoint. | `pXXX_overview_loss_ipr.png`, `pXXX_overview_phase_scatter.png`, `pXXX_overview.json` | 1 | Combines data from standard and grokking runs. The phase scatter uses the same computation as Fig. 4. For $p < 19$, only the standard column is shown. |
200
+
201
+ ### Not Currently Generated (Blog-Only Figures)
202
+
203
+ | Blog Figure | Description | Status |
204
+ |-------------|-------------|--------|
205
+ | **Fig. 1** — Architecture illustration | Schematic of the two-layer network, DFT decomposition, and mechanism overview. | Static illustration, not dependent on p. Could be included as a fixed image in the app. |
206
+
207
+ ---
208
+
209
+ ## Interactive JSON Data Files
210
+
211
+ In addition to static PNG plots, the pipeline generates JSON files for interactive Plotly charts in the Gradio app:
212
+
213
+ | File | Content | Used In |
214
+ |------|---------|---------|
215
+ | `pXXX_overview.json` | Standard loss/IPR + grokking loss/IPR time series | Tab 1: Interactive loss and IPR charts |
216
+ | `pXXX_neuron_spectra.json` | Per-neuron Fourier magnitudes (W_in and W_out) for top-20 neurons, sorted by frequency | Tab 2: Neuron Inspector dropdown → bar chart of frequency decomposition |
217
+ | `pXXX_logits_interactive.json` | Output logits for p representative (a,b) pairs, plus correct answers | Tab 4: Logit Explorer dropdown → bar chart with correct answer highlighted |
218
+ | `pXXX_grokk_loss.json` | Full training/test loss curves + stage boundaries | Tab 6: Interactive loss chart with stage shading |
219
+ | `pXXX_grokk_acc.json` | Accuracy at each checkpoint epoch + stage boundaries | Tab 6: Interactive accuracy chart with stage shading |
220
+ | `pXXX_grokk_epoch_data.json` | p×p accuracy grids at ~10 evenly-spaced grokking epochs | Tab 6: Epoch Slider → heatmap animation across training |
221
+ | `pXXX_metadata.json` | Config for all 5 runs + final metrics (loss, accuracy) | Displayed in the app's info panel for the selected $p$ |
222
+
223
+ ---
224
+
225
+ ## Output Structure
226
+
227
+ All plots for a modulus are saved in a single flat directory. Each file is prefixed with `pXXX_` so the folder is self-contained and browsable:
228
+
229
+ ```
230
+ precomputed_results/p_023/
231
+ # Metadata
232
+ p023_metadata.json
233
+
234
+ # Tab 1: Overview (Blog: summary of standard + grokking)
235
+ p023_overview_loss_ipr.png # 2×2 grid: loss + IPR for both setups
236
+ p023_overview_phase_scatter.png # Phase alignment scatter (same as Fig. 4)
237
+ p023_overview.json # Interactive data
238
+
239
+ # Tab 2: Fourier Weights (Blog: Figures 2, 3)
240
+ p023_full_training_para_origin.png # DFT heatmap (Fig. 2)
241
+ p023_lineplot_in.png # Cosine fits, input layer (Fig. 3 left)
242
+ p023_lineplot_out.png # Cosine fits, output layer (Fig. 3 right)
243
+ p023_neuron_spectra.json # Interactive: neuron inspector
244
+
245
+ # Tab 3: Phase Analysis (Blog: Figures 4, 5, 6)
246
+ p023_phase_distribution.png # Polar phase plot (Fig. 5)
247
+ p023_phase_relationship.png # 2φ vs ψ scatter (Fig. 4)
248
+ p023_magnitude_distribution.png # Violin plots (Fig. 6)
249
+
250
+ # Tab 4: Output Logits (Blog: Figure 7)
251
+ p023_output_logits.png # Logit heatmap (Fig. 7)
252
+ p023_logits_interactive.json # Interactive: logit explorer
253
+
254
+ # Tab 5: Lottery Mechanism (Blog: Figures 9, 10)
255
+ p023_lottery_mech_magnitude.png # Magnitude race (Fig. 9 right)
256
+ p023_lottery_mech_phase.png # Phase misalignment race (Fig. 9 left)
257
+ p023_lottery_beta_contour.png # Contour plot (Fig. 10)
258
+
259
+ # Tab 6: Grokking (Blog: Figures 13, 14, 15, 16)
260
+ p023_grokk_loss.json # Interactive loss curves (Fig. 13a)
261
+ p023_grokk_acc.json # Interactive accuracy curves (Fig. 13b)
262
+ p023_grokk_abs_phase_diff.png # Phase alignment progress (Fig. 13c)
263
+ p023_grokk_avg_ipr.png # IPR + param norms (Fig. 13d)
264
+ p023_grokk_memorization_accuracy.png # 3-panel heatmap (Fig. 14)
265
+ p023_grokk_memorization_common_to_rare.png # 4-panel sequence (Fig. 15)
266
+ p023_grokk_decoded_weights_dynamic.png # DFT evolution (Fig. 16)
267
+ p023_grokk_epoch_data.json # Interactive: epoch slider
268
+
269
+ # Tab 7: Gradient Dynamics (Blog: Figures 8, 11, 12)
270
+ p023_phase_align_quad.png # Phase + magnitude, Quad (Fig. 8 left)
271
+ p023_single_freq_quad.png # DFT heatmap over time, Quad (Fig. 11)
272
+ p023_phase_align_relu.png # Phase + magnitude, ReLU (Fig. 8 right / 12b)
273
+ p023_single_freq_relu.png # DFT heatmap over time, ReLU (Fig. 12a)
274
+
275
+ # Tab 8: Decoupled Simulation (no blog figure number)
276
+ p023_phase_align_approx1.png # ODE simulation case 1
277
+ p023_phase_align_approx2.png # ODE simulation case 2
278
+ ```
279
+
280
+ **29 files per $p$:** 21 PNGs + 6 JSONs from trained models, 2 PNGs from analytical simulation.
281
+
282
+ ---
283
+
284
+ ## Correctness Verification
285
+
286
+ ### How each computation matches the paper
287
+
288
+ 1. **DFT Decomposition (Figs. 2, 11, 12a, 16):** We compute `W @ fourier_basis.T` where `fourier_basis` is the orthonormal DFT basis from `get_fourier_basis(p)`. The basis has rows: [Const, cos 1, sin 1, cos 2, sin 2, ..., cos K, sin K] with K = (p-1)/2 for odd $p$. Each row is L2-normalized. This matches the standard real DFT on Z_p.
289
+
290
+ 2. **Phase extraction (Figs. 4, 8, 9, 13c):** For frequency k, the DFT gives coefficients (c_k, s_k) at indices (2k-1, 2k). The magnitude is α = sqrt(2/p) · sqrt(c_k² + s_k²), and the phase is φ = arctan2(-s_k, c_k). This convention matches the paper's θ_m[j] = α cos(ω_k j + φ) representation.
291
+
292
+ 3. **IPR (Figs. 13d, Overview):** Uses the corrected per-frequency magnitude formula: A_k = sqrt(c_k² + s_k²) (combining cos/sin pairs), then IPR = Σ A_k⁴ / (Σ A_k²)². This gives IPR → 1 for perfect single-frequency neurons, matching the paper's definition.
293
+
294
+ 4. **Phase alignment (Fig. 4):** The doubled-phase relationship ψ_m = 2φ_m is verified by extracting φ from W_in and ψ from W_out using the same `compute_neuron()` function, then plotting (2φ mod 2π) vs (ψ mod 2π).
295
+
296
+ 5. **Output logits (Fig. 7):** Forward pass uses the **same activation function** as training (ReLU for standard run). The flawed indicator structure (main diagonal + two ghost diagonals) is visible because the standard run trains to 100% accuracy with clean Fourier features.
297
+
298
+ 6. **Lottery mechanism (Figs. 9, 10):** Uses the `quad_random` run (quadratic activation, random init) which matches the theoretical setting. `decode_scales_phis()` extracts per-frequency magnitudes and phases at each checkpoint. The winning frequency is the one with smallest initial |D| = |2φ - ψ|.
299
+
300
+ 7. **Grokking stages (Figs. 13--16):** `grokking_stage_detector.py` identifies stage boundaries from training curves. Stage I ends when train accuracy ≈ 1.0, Stage II ends when test accuracy ≈ 1.0. Memorization heatmaps use forward pass at the closest checkpoint to stage1_end.
301
+
302
+ 8. **Analytical simulation (Tab 8):** Numerically integrates the four-variable ODE system from Section 5.3 of the paper. No neural network is involved — this validates the theoretical dynamics directly.
303
+
304
+ ### Why results generalize across $p$
305
+
306
+ The paper's theory is stated for general odd $p$. Key properties that scale:
307
+
308
+ - **Fourier basis:** Always has (p-1)/2 non-DC frequencies for any odd $p$.
309
+ - **Phase alignment:** The ψ = 2φ relationship is a consequence of the gradient dynamics, independent of p.
310
+ - **Lottery mechanism:** Random initial misalignments are uniform on [0, 2π) for any p.
311
+ - **Grokking three stages:** The stage structure depends on the balance of loss gradient vs. weight decay, not on p specifically (though the stage durations and test accuracy during memorization may vary).
312
+ - **Network width:** d_mlp scales as O(p²) to maintain the neuron-to-frequency ratio, ensuring enough neurons per frequency for diversification.
313
+
314
+ ---
315
+
316
+ ## Scripts
317
+
318
+ | Script | Purpose |
319
+ |--------|---------|
320
+ | `run_pipeline.sh` | Runs the complete pipeline (train + plots + analytical + verify) for a single modulus. |
321
+ | `train_all.py` | Trains all 5 model configurations. Saves checkpoints + `training_curves.json`. |
322
+ | `generate_plots.py` | Loads trained models and generates all model-dependent plots (Tabs 1--7) plus interactive JSONs and metadata. |
323
+ | `generate_analytical.py` | Runs gradient flow simulations to generate theory plots (Tab 8). No model needed. |
324
+ | `prime_config.py` | Configuration: moduli list, d_mlp formula, training run parameters. |
325
+ | `neuron_selector.py` | Automated neuron selection for plots (replaces hardcoded indices from notebooks). |
326
+ | `grokking_stage_detector.py` | Detects memorization/transition/generalization stage boundaries from training curves. |
327
+
328
+ ---
329
+
330
+ ## Analytical Simulations (No Model Needed)
331
+
332
+ `generate_analytical.py` produces 2 plots per $p$ by simulating gradient flow on decoupled frequency dynamics. These validate the theoretical analysis without training any model.
333
+
334
+ - **Case 1**: Shows phase difference D* converging from initial conditions (φ₀=1.5, ψ₀=0.18)
335
+ - **Case 2**: Different initial conditions (φ₀=-0.72, ψ₀=-2.91) showing convergence from the other side
336
+
337
+ Both cases confirm the phase alignment theorem: D → 0 is the stable attractor, D → π is unstable.
precompute/__init__.py ADDED
File without changes
precompute/generate_analytical.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate "Decoupled Simulation" plots -- analytical gradient flow simulations
4
+ that don't require trained models.
5
+
6
+ Produces 2 plots per p, saved to {output_dir}/p_{p:03d}/:
7
+ 1. p{p:03d}_phase_align_approx1.png -- case 1: longer simulation with annotations
8
+ 2. p{p:03d}_phase_align_approx2.png -- case 2: shorter simulation
9
+
10
+ Usage:
11
+ python generate_analytical.py --all
12
+ python generate_analytical.py --p 23
13
+ python generate_analytical.py --p 23 --output ./my_output
14
+ """
15
+ import argparse
16
+ import os
17
+ import sys
18
+
19
+ import numpy as np
20
+ import torch
21
+ import matplotlib
22
+ matplotlib.use('Agg')
23
+ import matplotlib.pyplot as plt
24
+
25
+ # Add project root to path so we can import src modules
26
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
27
+ from mechanism_base import get_fourier_basis, normalize_to_pi
28
+ from prime_config import get_moduli, ANALYTICAL_CONFIGS
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Style constants
32
+ # ---------------------------------------------------------------------------
33
+ COLORS = ['#0D2758', '#60656F', '#DEA54B', '#A32015', '#347186']
34
+ DPI = 150
35
+
36
+
37
+ # ===========================================================================
38
+ # Decouple dynamics simulation
39
+ # ===========================================================================
40
+
41
+ def gradient_update(theta, xi, p, device):
42
+ """
43
+ Compute the sum of gradients over all frequency modes k.
44
+
45
+ For each frequency k from 1 to (p-1)//2, project theta and xi onto the
46
+ Fourier basis to obtain 2-coefficient vectors, then compute alpha, phi,
47
+ beta, psi and the corresponding gradient contributions.
48
+ """
49
+ fourier_basis, _ = get_fourier_basis(p, device)
50
+ fourier_basis = fourier_basis.to(theta.dtype)
51
+ theta_coeff = fourier_basis @ theta
52
+ xi_coeff = fourier_basis @ xi
53
+
54
+ total_grad_theta = torch.zeros_like(theta)
55
+ total_grad_xi = torch.zeros_like(xi)
56
+
57
+ j_values = torch.arange(p, device=device, dtype=theta.dtype)
58
+ factor = np.sqrt(2.0 / p)
59
+
60
+ for k in range(1, p // 2 + 1):
61
+ coeff_indices = [k * 2 - 1, k * 2]
62
+ neuron_coeff_theta = theta_coeff[coeff_indices]
63
+ neuron_coeff_xi = xi_coeff[coeff_indices]
64
+
65
+ alpha = factor * torch.norm(neuron_coeff_theta, dim=0)
66
+ phi = torch.arctan2(-neuron_coeff_theta[1], neuron_coeff_theta[0])
67
+
68
+ beta = factor * torch.norm(neuron_coeff_xi, dim=0)
69
+ psi = torch.arctan2(-neuron_coeff_xi[1], neuron_coeff_xi[0])
70
+
71
+ w_k = 2 * np.pi * k / p
72
+ grad_theta_k = 2 * p * alpha * beta * torch.cos(w_k * j_values + psi - phi)
73
+ grad_xi_k = p * alpha.pow(2) * torch.cos(w_k * j_values + 2 * phi)
74
+
75
+ total_grad_theta += grad_theta_k / p ** 2
76
+ total_grad_xi += grad_xi_k / p ** 2
77
+
78
+ return total_grad_theta, total_grad_xi
79
+
80
+
81
+ def simulate_gradient_flow(theta, xi, p, num_steps, learning_rate, device):
82
+ """Euler integration of the coupled gradient-flow ODEs."""
83
+ theta_history = [theta.clone()]
84
+ xi_history = [xi.clone()]
85
+
86
+ for _ in range(num_steps):
87
+ grad_theta, grad_xi = gradient_update(theta, xi, p, device)
88
+ theta = theta + learning_rate * grad_theta
89
+ xi = xi + learning_rate * grad_xi
90
+ theta_history.append(theta.clone())
91
+ xi_history.append(xi.clone())
92
+
93
+ return theta_history, xi_history
94
+
95
+
96
+ def analyze_history(theta_history, xi_history, p, fourier_basis):
97
+ """
98
+ Extract time series of alpha, phi, beta, psi, delta for every frequency k.
99
+ """
100
+ theta_hist_tensor = torch.stack(theta_history)
101
+ xi_hist_tensor = torch.stack(xi_history)
102
+
103
+ theta_coeffs_hist = fourier_basis @ theta_hist_tensor.T
104
+ xi_coeffs_hist = fourier_basis @ xi_hist_tensor.T
105
+
106
+ results = {
107
+ 'alphas': {}, 'phis': {}, 'betas': {}, 'psis': {}, 'deltas': {}
108
+ }
109
+ factor = np.sqrt(2.0 / p)
110
+
111
+ for k in range(1, p // 2 + 1):
112
+ idx = [k * 2 - 1, k * 2]
113
+ neuron_theta_hist = theta_coeffs_hist[idx, :]
114
+ neuron_xi_hist = xi_coeffs_hist[idx, :]
115
+
116
+ alphas_k = factor * torch.norm(neuron_theta_hist, dim=0)
117
+ phis_k = torch.atan2(-neuron_theta_hist[1, :], neuron_theta_hist[0, :])
118
+
119
+ betas_k = factor * torch.norm(neuron_xi_hist, dim=0)
120
+ psis_k = torch.atan2(-neuron_xi_hist[1, :], neuron_xi_hist[0, :])
121
+
122
+ deltas_k = normalize_to_pi(2 * phis_k - psis_k)
123
+
124
+ results['alphas'][k] = alphas_k.numpy()
125
+ results['phis'][k] = phis_k.numpy()
126
+ results['betas'][k] = betas_k.numpy()
127
+ results['psis'][k] = psis_k.numpy()
128
+ results['deltas'][k] = deltas_k.numpy()
129
+
130
+ return results
131
+
132
+
133
+ def _run_decouple_simulation(p, init_k, num_steps, lr, init_phi, init_psi,
134
+ amplitude, device):
135
+ """Initialize and run a single decouple-dynamics simulation."""
136
+ fourier_basis, _ = get_fourier_basis(p, device)
137
+ fourier_basis = fourier_basis.to(torch.float64)
138
+ w_k = 2 * np.pi * init_k / p
139
+
140
+ theta_init = amplitude * torch.tensor(
141
+ [np.cos(w_k * j + init_phi) for j in range(p)],
142
+ device=device, dtype=torch.float64,
143
+ )
144
+ xi_init = amplitude * torch.tensor(
145
+ [np.cos(w_k * j + init_psi) for j in range(p)],
146
+ device=device, dtype=torch.float64,
147
+ )
148
+
149
+ theta_history, xi_history = simulate_gradient_flow(
150
+ theta_init, xi_init, p, num_steps, lr, device,
151
+ )
152
+ results = analyze_history(theta_history, xi_history, p, fourier_basis)
153
+ return results
154
+
155
+
156
+ def _plot_decouple(results, p, num_steps, lr, init_k, save_path,
157
+ show_vline=True, vline_x=500):
158
+ """
159
+ Publication-quality 3-panel figure:
160
+ Top: psi_k* and 2*phi_k* vs time
161
+ Middle: D_k* (phase difference) vs time, horizontal line at pi/2
162
+ Bottom: alpha_k* and beta_k* vs time
163
+ """
164
+ plt.rcParams['mathtext.fontset'] = 'cm'
165
+
166
+ alphas = np.array(results['alphas'][init_k])
167
+ betas = np.array(results['betas'][init_k])
168
+ deltas = np.array(results['deltas'][init_k])
169
+ phis = np.array(results['phis'][init_k])
170
+ psis = np.array(results['psis'][init_k])
171
+
172
+ # Phase wrapping fix: normalize 2*phi to [-pi,pi], adjust psi to
173
+ # stay within pi of 2*phi, then unwrap the time series so there
174
+ # are no discontinuous jumps at +-pi boundaries.
175
+ def _fix_phase_pair(two_phi_raw, psi_raw):
176
+ two_phi = normalize_to_pi(two_phi_raw)
177
+ psi_fixed = normalize_to_pi(psi_raw).copy()
178
+ diff = psi_fixed - two_phi
179
+ psi_fixed[diff > np.pi] -= 2 * np.pi
180
+ psi_fixed[diff < -np.pi] += 2 * np.pi
181
+ return np.unwrap(two_phi), np.unwrap(psi_fixed)
182
+
183
+ phis2_plot, psis_plot = _fix_phase_pair(2 * phis, psis)
184
+
185
+ x = np.arange(num_steps + 1) * lr
186
+ vline_kwargs = dict(color='gray', linestyle='--', linewidth=1.5)
187
+
188
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(7, 9), sharex=True)
189
+ fig.suptitle(f'Decoupled Gradient Flow (p={p})', fontsize=20, y=1.01)
190
+
191
+ # --- Top: phase alignment ---
192
+ for k in range(1, (p - 1) // 2 + 1):
193
+ if k != init_k:
194
+ bg_2phi, bg_psi = _fix_phase_pair(
195
+ 2 * np.array(results['phis'][k]),
196
+ np.array(results['psis'][k]),
197
+ )
198
+ ax1.plot(x, bg_psi, lw=1.5, alpha=0.4, color='gray')
199
+ ax1.plot(x, bg_2phi, lw=1.5, alpha=0.4, color='gray',
200
+ linestyle='--')
201
+ ax1.plot(x, psis_plot, color=COLORS[3], linewidth=2.5,
202
+ label=r"$\psi_{k^\star}$")
203
+ ax1.plot(x, phis2_plot, linewidth=2.5, color=COLORS[0],
204
+ label=r"$2\phi_{k^\star}$")
205
+ if show_vline:
206
+ ax1.axvline(x=vline_x, **vline_kwargs)
207
+ ax1.set_title('Dynamics of Phase Alignment', fontsize=18)
208
+ ax1.set_ylabel('Phase (radians)', fontsize=14)
209
+ ax1.legend(fontsize=18)
210
+ ax1.grid(True)
211
+
212
+ # --- Middle: phase difference ---
213
+ for k in range(1, (p - 1) // 2 + 1):
214
+ if k != init_k:
215
+ ax2.plot(x, np.array(results['deltas'][k]),
216
+ lw=1.5, alpha=0.4, color='gray')
217
+ ax2.plot(x, deltas, color=COLORS[0], linewidth=2.5,
218
+ label=r"$D_{k^\star}$")
219
+ if show_vline:
220
+ ax2.axvline(x=vline_x, **vline_kwargs)
221
+ ax2.axhline(y=np.pi / 2, **vline_kwargs)
222
+ ax2.text(x=max(x) * 0.05, y=np.pi / 2 - 0.45,
223
+ s=r"$D^\star_{k^\star}=\pi/2$", fontsize=16, color='black')
224
+ ax2.set_title('Dynamics of Phase Difference', fontsize=18)
225
+ ax2.set_ylabel('Phase (radians)', fontsize=14)
226
+ ax2.legend(fontsize=18)
227
+ ax2.grid(True)
228
+
229
+ # --- Bottom: magnitude evolution ---
230
+ for k in range(1, (p - 1) // 2 + 1):
231
+ if k != init_k:
232
+ ax3.plot(x, np.array(results['alphas'][k]),
233
+ lw=1.5, alpha=0.4, color='gray')
234
+ ax3.plot(x, np.array(results['betas'][k]),
235
+ lw=1.5, alpha=0.4, color='gray', linestyle='--')
236
+ ax3.plot(x, alphas, linewidth=2.5, color=COLORS[0],
237
+ label=r"$\alpha_{k^\star}$")
238
+ ax3.plot(x, betas, linewidth=2.5, color=COLORS[3],
239
+ label=r"$\beta_{k^\star}$")
240
+ if show_vline:
241
+ ax3.axvline(x=vline_x, **vline_kwargs)
242
+ ax3.set_title('Magnitude Evolution', fontsize=18)
243
+ ax3.set_xlabel('Time', fontsize=18)
244
+ ax3.set_ylabel('Magnitude', fontsize=14)
245
+ ax3.legend(fontsize=18)
246
+ ax3.grid(True)
247
+
248
+ plt.tight_layout()
249
+ plt.savefig(save_path, dpi=DPI, bbox_inches='tight')
250
+ plt.close(fig)
251
+ print(f" Saved {save_path}")
252
+
253
+
254
+ def generate_decouple_dynamics(p, output_dir):
255
+ """Generate the two decouple-dynamics phase-alignment plots."""
256
+ max_freq = (p - 1) // 2
257
+ if max_freq < 1:
258
+ print(f" SKIP: p={p} has no non-DC frequencies for analytical simulation")
259
+ return
260
+
261
+ cfg = ANALYTICAL_CONFIGS["decouple_dynamics"]
262
+ device = torch.device("cpu")
263
+ init_k = min(cfg["init_k"], max_freq)
264
+ amplitude = cfg["amplitude"]
265
+
266
+ # Case 1: longer simulation with vline annotations
267
+ print(f" Running decouple dynamics case 1 (p={p}) ...")
268
+ results1 = _run_decouple_simulation(
269
+ p, init_k,
270
+ num_steps=cfg["num_steps_case1"],
271
+ lr=cfg["learning_rate_case1"],
272
+ init_phi=cfg["init_phi_case1"],
273
+ init_psi=cfg["init_psi_case1"],
274
+ amplitude=amplitude,
275
+ device=device,
276
+ )
277
+ _plot_decouple(
278
+ results1, p,
279
+ num_steps=cfg["num_steps_case1"],
280
+ lr=cfg["learning_rate_case1"],
281
+ init_k=init_k,
282
+ save_path=os.path.join(output_dir, f"p{p:03d}_phase_align_approx1.png"),
283
+ show_vline=True,
284
+ vline_x=cfg["num_steps_case1"] * cfg["learning_rate_case1"] * 0.36,
285
+ )
286
+
287
+ # Case 2: shorter simulation without vline annotations
288
+ print(f" Running decouple dynamics case 2 (p={p}) ...")
289
+ results2 = _run_decouple_simulation(
290
+ p, init_k,
291
+ num_steps=cfg["num_steps_case2"],
292
+ lr=cfg["learning_rate_case2"],
293
+ init_phi=cfg["init_phi_case2"],
294
+ init_psi=cfg["init_psi_case2"],
295
+ amplitude=amplitude,
296
+ device=device,
297
+ )
298
+ _plot_decouple(
299
+ results2, p,
300
+ num_steps=cfg["num_steps_case2"],
301
+ lr=cfg["learning_rate_case2"],
302
+ init_k=init_k,
303
+ save_path=os.path.join(output_dir, f"p{p:03d}_phase_align_approx2.png"),
304
+ show_vline=False,
305
+ )
306
+
307
+
308
+ # ===========================================================================
309
+ # Entry point
310
+ # ===========================================================================
311
+
312
+ def generate_all_for_prime(p, output_base):
313
+ """Generate the 2 decoupled simulation plots for a single prime."""
314
+ output_dir = os.path.join(output_base, f"p_{p:03d}")
315
+ os.makedirs(output_dir, exist_ok=True)
316
+
317
+ print(f"\n{'='*60}")
318
+ print(f"Generating decoupled simulation plots for p={p}")
319
+ print(f"Output: {output_dir}")
320
+ print(f"{'='*60}")
321
+
322
+ # Use float64 globally for numerical precision in simulations
323
+ prev_dtype = torch.get_default_dtype()
324
+ torch.set_default_dtype(torch.float64)
325
+
326
+ try:
327
+ generate_decouple_dynamics(p, output_dir)
328
+ finally:
329
+ torch.set_default_dtype(prev_dtype)
330
+
331
+ print(f"[DONE] p={p}: 2 plots written to {output_dir}")
332
+
333
+
334
+ def main():
335
+ parser = argparse.ArgumentParser(
336
+ description='Generate decoupled simulation plots (analytical, no model needed)'
337
+ )
338
+ parser.add_argument('--all', action='store_true',
339
+ help='Generate plots for all odd p in [3, 199]')
340
+ parser.add_argument('--p', type=int,
341
+ help='Generate plots for a specific p')
342
+ parser.add_argument('--output', type=str, default='./precomputed_results',
343
+ help='Base output directory (default: ./precomputed_results)')
344
+ args = parser.parse_args()
345
+
346
+ if not args.all and args.p is None:
347
+ parser.error("Specify --all or --p P")
348
+
349
+ moduli = [args.p] if args.p else get_moduli()
350
+
351
+ for p in moduli:
352
+ generate_all_for_prime(p, args.output)
353
+
354
+ print(f"\nAll done. Processed {len(moduli)} value(s) of p.")
355
+
356
+
357
+ if __name__ == "__main__":
358
+ main()
precompute/generate_plots.py ADDED
@@ -0,0 +1,2192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Main plot generation script for the HF app.
4
+ Creates all model-dependent plots (Tabs 1-7) from trained checkpoints.
5
+
6
+ Usage:
7
+ python generate_plots.py --all # Generate for all primes
8
+ python generate_plots.py --p 23 # Generate for a specific p
9
+ python generate_plots.py --p 23 --input ./trained_models --output ./hf_app/precomputed_results
10
+ """
11
+ import matplotlib
12
+ matplotlib.use('Agg')
13
+
14
+ import argparse
15
+ import json
16
+ import math
17
+ import os
18
+ import sys
19
+ import traceback
20
+
21
+ import matplotlib.colors as mcolors
22
+ import matplotlib.patches as patches
23
+ import matplotlib.pyplot as plt
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from matplotlib.colors import LinearSegmentedColormap
28
+ from matplotlib.ticker import FuncFormatter
29
+
30
+ # Add project root to path so we can import src modules
31
+ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
32
+ sys.path.insert(0, PROJECT_ROOT)
33
+ sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src'))
34
+ sys.path.insert(0, os.path.dirname(__file__))
35
+
36
+ from src.mechanism_base import (
37
+ get_fourier_basis,
38
+ decode_weights,
39
+ compute_neuron,
40
+ decode_scales_phis,
41
+ normalize_to_pi,
42
+ )
43
+ from src.model_base import EmbedMLP
44
+ from src.utils import cross_entropy_high_precision, acc_rate
45
+ from precompute.neuron_selector import (
46
+ select_top_neurons_by_frequency,
47
+ select_lineplot_neurons,
48
+ select_phase_frequency,
49
+ select_lottery_neuron,
50
+ )
51
+ from precompute.grokking_stage_detector import detect_grokking_stages
52
+ from precompute.prime_config import compute_d_mlp, TRAINING_RUNS, MIN_P_GROKKING
53
+
54
+ # ---------- Lightweight train/test data regeneration ----------
55
+
56
+ def _gen_train_test(p, frac_train=0.75, seed=42):
57
+ """
58
+ Regenerate train/test split deterministically without needing a Config object.
59
+ Mirrors the logic in utils.gen_train_test for the 'add' function.
60
+ Returns (train_data, test_data) where each is a tensor of shape (N, 2).
61
+ """
62
+ import random as _random
63
+ all_pairs = []
64
+ for i in range(p):
65
+ for j in range(p):
66
+ all_pairs.append((i, j))
67
+ data_tensor = torch.tensor(all_pairs, dtype=torch.long)
68
+ _random.seed(seed)
69
+ indices = torch.randperm(len(all_pairs))
70
+ data_tensor = data_tensor[indices]
71
+ if frac_train >= 1.0:
72
+ return data_tensor, data_tensor
73
+ div = int(frac_train * len(all_pairs))
74
+ return data_tensor[:div], data_tensor[div:]
75
+
76
+
77
+ # ---------- Style constants ----------
78
+ COLORS = ['#0D2758', '#60656F', '#DEA54B', '#A32015', '#347186']
79
+ CMAP_DIVERGING = LinearSegmentedColormap.from_list(
80
+ 'cividis_white_center', ['#0D2758', 'white', '#A32015'], N=256
81
+ )
82
+ CMAP_SEQUENTIAL = LinearSegmentedColormap.from_list(
83
+ 'cividis_white_seq', ['white', '#0D2758'], N=256
84
+ )
85
+ DPI = 150
86
+ plt.rcParams['mathtext.fontset'] = 'cm'
87
+
88
+
89
+ def _save_fig(fig, path):
90
+ """Save a figure and close it."""
91
+ fig.savefig(path, dpi=DPI, bbox_inches='tight', format='png')
92
+ plt.close(fig)
93
+
94
+
95
+ # ======================================================================
96
+ # Helpers for loading checkpoints
97
+ # ======================================================================
98
+
99
+ def _find_run_dir(base_dir):
100
+ """
101
+ Given a run type directory (e.g. trained_models/p_023/standard/),
102
+ find the actual checkpoint directory. It may be a timestamped
103
+ subdirectory, or the checkpoints may live directly in base_dir.
104
+ Returns the path that contains the .pth checkpoint files.
105
+ """
106
+ if not os.path.isdir(base_dir):
107
+ return None
108
+
109
+ # Check if .pth files live directly here
110
+ pth_files = [f for f in os.listdir(base_dir)
111
+ if f.endswith('.pth') and f not in ('train_data.pth', 'test_data.pth')]
112
+ if pth_files:
113
+ return base_dir
114
+
115
+ # Otherwise look for a single timestamped subdirectory
116
+ subdirs = [d for d in os.listdir(base_dir)
117
+ if os.path.isdir(os.path.join(base_dir, d))]
118
+ for sd in sorted(subdirs):
119
+ candidate = os.path.join(base_dir, sd)
120
+ files = os.listdir(candidate)
121
+ if any(f.endswith('.pth') for f in files):
122
+ return candidate
123
+ return None
124
+
125
+
126
+ def _load_checkpoints(run_dir, device='cpu'):
127
+ """
128
+ Load all numbered checkpoints from run_dir.
129
+ Returns dict {epoch_int: state_dict} sorted by epoch.
130
+ """
131
+ loaded = {}
132
+ exclude = {'final.pth', 'test_data.pth', 'train_data.pth'}
133
+ for fname in os.listdir(run_dir):
134
+ fpath = os.path.join(run_dir, fname)
135
+ if (os.path.isfile(fpath) and fname.endswith('.pth')
136
+ and fname not in exclude):
137
+ try:
138
+ epoch = int(os.path.splitext(fname)[0])
139
+ except ValueError:
140
+ continue
141
+ data = torch.load(fpath, weights_only=True, map_location=device)
142
+ if isinstance(data, dict) and 'model' in data:
143
+ loaded[epoch] = data['model']
144
+ else:
145
+ loaded[epoch] = data
146
+ return {k: loaded[k] for k in sorted(loaded)}
147
+
148
+
149
+ def _load_final(run_dir, device='cpu'):
150
+ """Load the final.pth model data dict."""
151
+ fpath = os.path.join(run_dir, 'final.pth')
152
+ if not os.path.exists(fpath):
153
+ # Fall back to largest epoch checkpoint
154
+ ckpts = _load_checkpoints(run_dir, device)
155
+ if ckpts:
156
+ max_epoch = max(ckpts.keys())
157
+ return {'model': ckpts[max_epoch]}
158
+ return None
159
+ return torch.load(fpath, weights_only=True, map_location=device)
160
+
161
+
162
+ def _load_training_curves(run_type_dir):
163
+ """Load training_curves.json from the run type directory."""
164
+ path = os.path.join(run_type_dir, 'training_curves.json')
165
+ if os.path.exists(path):
166
+ with open(path) as f:
167
+ return json.load(f)
168
+ # Fall back: check inside the checkpoint subdirectory
169
+ run_dir = _find_run_dir(run_type_dir)
170
+ if run_dir and run_dir != run_type_dir:
171
+ path = os.path.join(run_dir, 'training_curves.json')
172
+ if os.path.exists(path):
173
+ with open(path) as f:
174
+ return json.load(f)
175
+ # Fall back: try loading from final.pth
176
+ if run_dir:
177
+ final_path = os.path.join(run_dir, 'final.pth')
178
+ if os.path.exists(final_path):
179
+ data = torch.load(final_path, weights_only=True, map_location='cpu')
180
+ if isinstance(data, dict):
181
+ curves = {}
182
+ for key in ('train_losses', 'test_losses', 'train_accs', 'test_accs',
183
+ 'grad_norms', 'param_norms'):
184
+ if key in data:
185
+ val = data[key]
186
+ if isinstance(val, torch.Tensor):
187
+ val = val.cpu().tolist()
188
+ curves[key] = val
189
+ if curves:
190
+ return curves
191
+ return None
192
+
193
+
194
+ # ======================================================================
195
+ # PlotGenerator
196
+ # ======================================================================
197
+
198
+ class PlotGenerator:
199
+ """
200
+ Generates all model-dependent plots for a single prime p.
201
+
202
+ Parameters
203
+ ----------
204
+ p : int
205
+ The prime modulus.
206
+ input_dir : str
207
+ Path to trained_models/p_PPP/ containing run-type subdirectories.
208
+ output_dir : str
209
+ Path to hf_app/precomputed_results/p_PPP/ where plots are saved.
210
+ """
211
+
212
+ def __init__(self, p, input_dir, output_dir):
213
+ self.p = p
214
+ self.input_dir = input_dir
215
+ self.output_dir = output_dir
216
+ self.device = 'cpu'
217
+ self.d_vocab = p
218
+ self.d_model = p
219
+
220
+ os.makedirs(output_dir, exist_ok=True)
221
+
222
+ # Infer d_mlp from checkpoint weights; fall back to formula
223
+ self.d_mlp = self._infer_d_mlp() or compute_d_mlp(p)
224
+
225
+ # Fourier basis (mechanism_base version with device arg)
226
+ self.fourier_basis, self.fourier_basis_names = get_fourier_basis(p, self.device)
227
+
228
+ # All (a,b) pairs and labels
229
+ self.all_data = torch.tensor(
230
+ [(i, j) for i in range(p) for j in range(p)], dtype=torch.long
231
+ )
232
+ self.all_labels = torch.tensor(
233
+ [(i + j) % p for i in range(p) for j in range(p)], dtype=torch.long
234
+ )
235
+
236
+ def _infer_d_mlp(self):
237
+ """Infer d_mlp from the first available checkpoint's weight shape."""
238
+ for run_name in TRAINING_RUNS:
239
+ run_type_dir = os.path.join(self.input_dir, run_name)
240
+ run_dir = _find_run_dir(run_type_dir)
241
+ if run_dir is None:
242
+ continue
243
+ final = _load_final(run_dir, 'cpu')
244
+ if final and 'model' in final and 'mlp.W_in' in final['model']:
245
+ d_mlp = final['model']['mlp.W_in'].shape[0]
246
+ print(f" Inferred d_mlp={d_mlp} from {run_name} checkpoint")
247
+ return d_mlp
248
+ return None
249
+
250
+ # ------------------------------------------------------------------
251
+ # Path helpers
252
+ # ------------------------------------------------------------------
253
+
254
+ def _run_type_dir(self, run_name):
255
+ return os.path.join(self.input_dir, run_name)
256
+
257
+ def _run_dir(self, run_name):
258
+ return _find_run_dir(self._run_type_dir(run_name))
259
+
260
+ def _out(self, filename):
261
+ # Prefix every file with pXXX_ so folders are self-contained and browsable
262
+ return os.path.join(self.output_dir, f"p{self.p:03d}_{filename}")
263
+
264
+ # ------------------------------------------------------------------
265
+ # ------------------------------------------------------------------
266
+ # Shared IPR helper
267
+ # ------------------------------------------------------------------
268
+
269
+ def _compute_freq_ipr(self, W_dec):
270
+ """IPR over per-frequency magnitudes (combines cos+sin pairs).
271
+
272
+ IPR = sum_k A_k^4 / (sum_k A_k^2)^2, where A_k = sqrt(c_k^2 + s_k^2).
273
+ IPR → 1 means all energy at a single frequency.
274
+ Returns mean IPR across neurons.
275
+ """
276
+ K = (self.p - 1) // 2
277
+ A2 = torch.zeros(W_dec.shape[0], K)
278
+ for k in range(1, K + 1):
279
+ A2[:, k - 1] = W_dec[:, 2 * k - 1].pow(2) + W_dec[:, 2 * k].pow(2)
280
+ A4 = A2.pow(2)
281
+ denom = A2.sum(dim=1).pow(2)
282
+ valid = denom > 0
283
+ ipr = torch.zeros(W_dec.shape[0])
284
+ ipr[valid] = A4[valid].sum(dim=1) / denom[valid]
285
+ return ipr.mean()
286
+
287
+ def _ipr_at_checkpoint(self, model_sd):
288
+ """Compute average IPR (across both layers) for a single checkpoint."""
289
+ W_in_d, W_out_d, _ = decode_weights(model_sd, self.fourier_basis)
290
+ return ((self._compute_freq_ipr(W_in_d)
291
+ + self._compute_freq_ipr(W_out_d)) / 2).item()
292
+
293
+ # ------------------------------------------------------------------
294
+ # Tab 1: Overview (standard loss+IPR, grokking loss+IPR, phase plot)
295
+ # ------------------------------------------------------------------
296
+
297
+ def generate_tab1(self):
298
+ """Generate overview plots: standard + grokking loss/IPR, plus phase scatter."""
299
+ print(f" [Tab 1] Overview for p={self.p}")
300
+
301
+ # ---- Standard run: loss + IPR ----
302
+ std_dir = self._run_dir('standard')
303
+ std_epochs, std_loss, std_ipr = [], [], []
304
+ if std_dir is not None:
305
+ std_curves = _load_training_curves(self._run_type_dir('standard'))
306
+ std_ckpts = _load_checkpoints(std_dir, self.device)
307
+ if std_ckpts:
308
+ std_epochs = sorted(std_ckpts.keys())
309
+ std_ipr = [self._ipr_at_checkpoint(std_ckpts[ep]) for ep in std_epochs]
310
+ if std_curves and 'train_losses' in std_curves:
311
+ se = std_epochs[1] - std_epochs[0] if len(std_epochs) > 1 else 200
312
+ std_loss = std_curves['train_losses'][::se][:len(std_epochs)]
313
+
314
+ # ---- Grokking run: train/test loss + IPR ----
315
+ grokk_epochs, grokk_train_loss, grokk_test_loss, grokk_ipr = [], [], [], []
316
+ has_grokk = self.p >= MIN_P_GROKKING
317
+ if has_grokk:
318
+ grokk_dir = self._run_dir('grokking')
319
+ if grokk_dir is not None:
320
+ grokk_curves = _load_training_curves(self._run_type_dir('grokking'))
321
+ grokk_ckpts = _load_checkpoints(grokk_dir, self.device)
322
+ if grokk_ckpts:
323
+ grokk_epochs = sorted(grokk_ckpts.keys())
324
+ grokk_ipr = [self._ipr_at_checkpoint(grokk_ckpts[ep])
325
+ for ep in grokk_epochs]
326
+ if grokk_curves:
327
+ se = grokk_epochs[1] - grokk_epochs[0] if len(grokk_epochs) > 1 else 200
328
+ if 'train_losses' in grokk_curves:
329
+ grokk_train_loss = grokk_curves['train_losses'][::se][:len(grokk_epochs)]
330
+ if 'test_losses' in grokk_curves:
331
+ grokk_test_loss = grokk_curves['test_losses'][::se][:len(grokk_epochs)]
332
+
333
+ if not std_epochs and not grokk_epochs:
334
+ print(" SKIP: no checkpoints found for standard or grokking run")
335
+ return
336
+
337
+ # ---- Static plot: 2×2 grid (std loss, grokk loss, std IPR, grokk IPR) ----
338
+ n_cols = 2 if has_grokk and grokk_epochs else 1
339
+ fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 7),
340
+ constrained_layout=True)
341
+ if n_cols == 1:
342
+ axes = axes.reshape(2, 1)
343
+
344
+ # Standard loss (top-left)
345
+ ax = axes[0, 0]
346
+ if std_loss:
347
+ ax.plot(std_epochs[:len(std_loss)], std_loss,
348
+ color=COLORS[0], linewidth=1.5, label="Train Loss")
349
+ ax.set_title('Standard (ReLU, full data)', fontsize=14)
350
+ ax.set_ylabel('Loss', fontsize=13)
351
+ ax.legend(fontsize=11)
352
+ ax.grid(True, alpha=0.4)
353
+
354
+ # Standard IPR (bottom-left)
355
+ ax = axes[1, 0]
356
+ if std_ipr:
357
+ ax.plot(std_epochs[:len(std_ipr)], std_ipr,
358
+ color=COLORS[3], linewidth=1.5, label="Avg. IPR")
359
+ ax.set_title('Standard IPR', fontsize=14)
360
+ ax.set_xlabel('Step', fontsize=13)
361
+ ax.set_ylabel('IPR', fontsize=13)
362
+ ax.set_ylim([0, 1.05])
363
+ ax.legend(fontsize=11)
364
+ ax.grid(True, alpha=0.4)
365
+
366
+ if n_cols == 2:
367
+ # Grokking loss (top-right)
368
+ ax = axes[0, 1]
369
+ gx = grokk_epochs
370
+ if grokk_train_loss:
371
+ ax.plot(gx[:len(grokk_train_loss)], grokk_train_loss,
372
+ color=COLORS[0], linewidth=1.5, label="Train Loss")
373
+ if grokk_test_loss:
374
+ ax.plot(gx[:len(grokk_test_loss)], grokk_test_loss,
375
+ color=COLORS[3], linewidth=1.5, label="Test Loss")
376
+ ax.set_title('Grokking (ReLU, 75% data, WD)', fontsize=14)
377
+ ax.legend(fontsize=11)
378
+ ax.grid(True, alpha=0.4)
379
+
380
+ # Grokking IPR (bottom-right)
381
+ ax = axes[1, 1]
382
+ if grokk_ipr:
383
+ ax.plot(gx[:len(grokk_ipr)], grokk_ipr,
384
+ color=COLORS[3], linewidth=1.5, label="Avg. IPR")
385
+ ax.set_title('Grokking IPR', fontsize=14)
386
+ ax.set_xlabel('Step', fontsize=13)
387
+ ax.set_ylim([0, 1.05])
388
+ ax.legend(fontsize=11)
389
+ ax.grid(True, alpha=0.4)
390
+
391
+ _save_fig(fig, self._out('overview_loss_ipr.png'))
392
+
393
+ # ---- Phase relationship scatter from standard final checkpoint ----
394
+ if std_ckpts:
395
+ final_ep = max(std_ckpts.keys())
396
+ model_sd = std_ckpts[final_ep]
397
+ W_in_d, W_out_d, mfl = decode_weights(model_sd, self.fourier_basis)
398
+ n_neurons = W_in_d.shape[0]
399
+ phis_2, psis = [], []
400
+ for neuron in range(n_neurons):
401
+ _, phi = compute_neuron(neuron, mfl, W_in_d)
402
+ _, psi = compute_neuron(neuron, mfl, W_out_d)
403
+ two_phi = normalize_to_pi(2 * phi)
404
+ psi_n = normalize_to_pi(psi)
405
+ # Fix ±π wrap: keep ψ within π of 2φ
406
+ if psi_n - two_phi > np.pi:
407
+ psi_n -= 2 * np.pi
408
+ elif psi_n - two_phi < -np.pi:
409
+ psi_n += 2 * np.pi
410
+ phis_2.append(two_phi)
411
+ psis.append(psi_n)
412
+
413
+ fig, ax = plt.subplots(figsize=(5, 5))
414
+ ax.plot([-np.pi, np.pi], [-np.pi, np.pi], 'r-',
415
+ linewidth=3, alpha=0.8,
416
+ label=r'$\psi_m = 2\phi_m$', zorder=1)
417
+ ax.scatter(phis_2, psis, s=12, alpha=0.6, color=COLORS[0], zorder=2)
418
+ ax.legend(fontsize=12, loc='upper left')
419
+ ax.set_xlabel(r'$2\phi_m$', fontsize=14)
420
+ ax.set_ylabel(r'$\psi_m$', fontsize=14)
421
+ ax.set_title(r'Phase Alignment: $\psi_m = 2\phi_m$', fontsize=14)
422
+ ax.set_xlim([-np.pi, np.pi])
423
+ ax.set_ylim([-np.pi, np.pi])
424
+ ax.set_aspect('equal')
425
+ ax.grid(True, alpha=0.3)
426
+ _save_fig(fig, self._out('overview_phase_scatter.png'))
427
+
428
+ # ---- JSON for interactive Plotly charts ----
429
+ payload = {
430
+ 'std_epochs': [int(e) for e in std_epochs],
431
+ 'std_ipr': std_ipr,
432
+ }
433
+ if std_loss:
434
+ payload['std_train_loss'] = [float(v) for v in std_loss]
435
+
436
+ if has_grokk and grokk_epochs:
437
+ payload['grokk_epochs'] = [int(e) for e in grokk_epochs]
438
+ payload['grokk_ipr'] = grokk_ipr
439
+ if grokk_train_loss:
440
+ payload['grokk_train_loss'] = [float(v) for v in grokk_train_loss]
441
+ if grokk_test_loss:
442
+ payload['grokk_test_loss'] = [float(v) for v in grokk_test_loss]
443
+
444
+ with open(self._out('overview.json'), 'w') as f:
445
+ json.dump(payload, f)
446
+
447
+ files = ['overview_loss_ipr.png', 'overview.json']
448
+ if std_ckpts:
449
+ files.append('overview_phase_scatter.png')
450
+ print(f" Saved {', '.join(files)}")
451
+
452
+ # ------------------------------------------------------------------
453
+ # Tab 2: Fourier Weights (heatmap + lineplots)
454
+ # ------------------------------------------------------------------
455
+
456
+ def generate_tab2(self):
457
+ """Generate full_training_para_origin.png, lineplot_in.png, lineplot_out.png."""
458
+ print(f" [Tab 2] Fourier Weights for p={self.p}")
459
+ run_dir = self._run_dir('standard')
460
+ if run_dir is None:
461
+ print(" SKIP: standard run directory not found")
462
+ return
463
+
464
+ final_data = _load_final(run_dir, self.device)
465
+ if final_data is None:
466
+ print(" SKIP: no final checkpoint")
467
+ return
468
+ model_load = final_data['model']
469
+
470
+ W_in_decode, W_out_decode, max_freq_ls = decode_weights(
471
+ model_load, self.fourier_basis
472
+ )
473
+ d_mlp = W_in_decode.shape[0]
474
+ num_neurons = min(20, d_mlp)
475
+
476
+ # Sort neurons by frequency
477
+ sorted_indices = select_top_neurons_by_frequency(
478
+ max_freq_ls, W_in_decode, n=num_neurons
479
+ )
480
+
481
+ freq_ls = np.array([max_freq_ls[i] for i in sorted_indices])
482
+
483
+ # DFT coefficients for heatmap (matches blog Figure 2)
484
+ W_in_dft = W_in_decode[sorted_indices, :]
485
+ W_out_dft = W_out_decode[sorted_indices, :]
486
+ # Raw weights for line plots (matches blog Figure 3)
487
+ W_in_raw = model_load['mlp.W_in'][sorted_indices, :]
488
+ W_out_raw = model_load['mlp.W_out'].T[sorted_indices, :]
489
+
490
+ # Sort within selected set by frequency
491
+ sort_order = np.argsort(freq_ls)
492
+ ranked_W_in_dft = W_in_dft[sort_order, :]
493
+ ranked_W_out_dft = W_out_dft[sort_order, :]
494
+ ranked_W_in_raw = W_in_raw[sort_order, :]
495
+ ranked_W_out_raw = W_out_raw[sort_order, :]
496
+
497
+ # ---- Heatmap plot (DFT coefficients, matching blog Figure 2) ----
498
+ fb_names = self.fourier_basis_names
499
+ n_modes = len(fb_names)
500
+ fig_w = max(8, n_modes * 0.4)
501
+ fig_h = max(8, num_neurons * 0.35 + 3)
502
+ fig, axes = plt.subplots(
503
+ 2, 1, figsize=(fig_w, fig_h), constrained_layout=True,
504
+ gridspec_kw={"hspace": 0.15}
505
+ )
506
+
507
+ # W_in DFT
508
+ ax_in = axes[0]
509
+ W_in_np = ranked_W_in_dft.detach().cpu().numpy()
510
+ abs_max_in = np.abs(W_in_np).max()
511
+ im_in = ax_in.imshow(
512
+ W_in_np,
513
+ cmap=CMAP_DIVERGING, vmin=-abs_max_in, vmax=abs_max_in,
514
+ aspect='auto'
515
+ )
516
+ ax_in.set_title(r'First-Layer $\theta_m$ after DFT', fontsize=18)
517
+ fig.colorbar(im_in, ax=ax_in, shrink=0.8)
518
+ y_locs = np.arange(num_neurons)
519
+ ax_in.set_yticks(y_locs)
520
+ ax_in.set_yticklabels(y_locs, fontsize=10)
521
+ ax_in.set_ylabel('Neuron #', fontsize=14)
522
+ x_locs = np.arange(n_modes)
523
+ ax_in.set_xticks(x_locs)
524
+ ax_in.set_xticklabels(fb_names, rotation=90, fontsize=10)
525
+
526
+ # W_out DFT
527
+ ax_out = axes[1]
528
+ W_out_np = ranked_W_out_dft.detach().cpu().numpy()
529
+ abs_max_out = np.abs(W_out_np).max()
530
+ im_out = ax_out.imshow(
531
+ W_out_np,
532
+ cmap=CMAP_DIVERGING, vmin=-abs_max_out, vmax=abs_max_out,
533
+ aspect='auto'
534
+ )
535
+ ax_out.set_title(r'Second-Layer $\xi_m$ after DFT', fontsize=18)
536
+ fig.colorbar(im_out, ax=ax_out, shrink=0.8)
537
+ ax_out.set_yticks(y_locs)
538
+ ax_out.set_yticklabels(y_locs, fontsize=10)
539
+ ax_out.set_ylabel('Neuron #', fontsize=14)
540
+ ax_out.set_xticks(x_locs)
541
+ ax_out.set_xticklabels(fb_names, rotation=90, fontsize=10)
542
+ ax_out.set_xlabel('Fourier Component', fontsize=14)
543
+
544
+ _save_fig(fig, self._out('full_training_para_origin.png'))
545
+
546
+ # ---- Line plots (raw weights + cosine fits, matching blog Figure 3) ----
547
+ lineplot_idx = select_lineplot_neurons(list(range(num_neurons)), n=3)
548
+ fb = self.fourier_basis
549
+ positions = np.arange(ranked_W_in_raw.shape[1])
550
+
551
+ for tag, weight_data, title_tex in [
552
+ ('lineplot_in', ranked_W_in_raw, r'First-Layer Parameters $\theta_m$'),
553
+ ('lineplot_out', ranked_W_out_raw, r'Second-Layer Parameters $\xi_m$'),
554
+ ]:
555
+ if hasattr(weight_data, 'detach'):
556
+ weight_np = weight_data.detach().cpu()
557
+ else:
558
+ weight_np = weight_data
559
+
560
+ top3 = weight_np[lineplot_idx]
561
+
562
+ lp_w = max(8, self.p * 0.35)
563
+ fig, axes_lp = plt.subplots(
564
+ nrows=3, ncols=1, figsize=(lp_w, 8),
565
+ constrained_layout=True,
566
+ gridspec_kw={'hspace': 0.08}
567
+ )
568
+
569
+ for i, ax in enumerate(axes_lp):
570
+ data = top3[i]
571
+ if isinstance(data, torch.Tensor):
572
+ data_t = data.float()
573
+ else:
574
+ data_t = torch.tensor(data, dtype=torch.float32)
575
+ # Project into Fourier space, keep top 2 components, project back
576
+ proj = data_t @ fb.T
577
+ abs_proj = torch.abs(proj)
578
+ _, top2_idx = torch.topk(abs_proj, 2)
579
+ mask = torch.zeros_like(proj)
580
+ mask[top2_idx] = proj[top2_idx]
581
+ data_est = mask @ fb
582
+ data_np = data_t.numpy()
583
+ data_est_np = data_est.numpy()
584
+
585
+ ax.plot(data_np, marker='o', markersize=5,
586
+ color=COLORS[0], linewidth=1.5, linestyle=':',
587
+ label="Actual")
588
+ ax.plot(data_est_np, marker='o', markersize=5,
589
+ color=COLORS[3], linewidth=1.5, linestyle=':',
590
+ alpha=0.7, label="Fitted")
591
+ ax.set_ylim(-0.9, 0.9)
592
+ ax.set_ylabel(f'Neuron #{i+1}', fontsize=14)
593
+ ax.set_xticks(positions)
594
+ ax.grid(True, which='major', axis='both',
595
+ linestyle='--', linewidth=0.5, alpha=0.6)
596
+ if i < len(axes_lp) - 1:
597
+ ax.set_xticklabels([])
598
+ ax.legend(fontsize=12, loc="upper right")
599
+
600
+ axes_lp[-1].set_xlabel('Input Dimension', fontsize=14)
601
+ axes_lp[-1].set_xticks(positions)
602
+ axes_lp[-1].set_xticklabels(
603
+ np.arange(ranked_W_in_raw.shape[1]), rotation=0, fontsize=10
604
+ )
605
+ axes_lp[0].set_title(title_tex, fontsize=18)
606
+
607
+ _save_fig(fig, self._out(f'{tag}.png'))
608
+
609
+ print(" Saved full_training_para_origin.png, lineplot_in.png, lineplot_out.png")
610
+
611
+ # ------------------------------------------------------------------
612
+ # Tab 3: Phase Analysis
613
+ # ------------------------------------------------------------------
614
+
615
+ def generate_tab3(self):
616
+ """Generate phase_distribution.png, phase_relationship.png, magnitude_distribution.png."""
617
+ print(f" [Tab 3] Phase Analysis for p={self.p}")
618
+ run_dir = self._run_dir('standard')
619
+ if run_dir is None:
620
+ print(" SKIP: standard run directory not found")
621
+ return
622
+
623
+ final_data = _load_final(run_dir, self.device)
624
+ if final_data is None:
625
+ print(" SKIP: no final checkpoint")
626
+ return
627
+ model_load = final_data['model']
628
+
629
+ W_in_decode, W_out_decode, max_freq_ls = decode_weights(
630
+ model_load, self.fourier_basis
631
+ )
632
+ d_mlp = W_in_decode.shape[0]
633
+
634
+ # Compute all neuron phases and magnitudes
635
+ coeff_in_scale_ls = []
636
+ coeff_out_scale_ls = []
637
+ coeff_phi_ls = []
638
+ coeff_psi_ls = []
639
+
640
+ for neuron in range(d_mlp):
641
+ s_in, phi_in = compute_neuron(neuron, max_freq_ls, W_in_decode)
642
+ s_out, phi_out = compute_neuron(neuron, max_freq_ls, W_out_decode)
643
+ coeff_in_scale_ls.append(s_in)
644
+ coeff_out_scale_ls.append(s_out)
645
+ coeff_phi_ls.append(phi_in)
646
+ coeff_psi_ls.append(phi_out)
647
+
648
+ coeff_phi_arr = np.array(coeff_phi_ls)
649
+ coeff_psi_arr = np.array(coeff_psi_ls)
650
+
651
+ # ---- Phase distribution on concentric circles ----
652
+ # Select the most common non-zero frequency for phase analysis
653
+ target_freq = select_phase_frequency(max_freq_ls, self.p)
654
+ freq_neurons = [i for i, f in enumerate(max_freq_ls) if f == target_freq]
655
+ phi_subset = np.array([coeff_phi_ls[i] for i in freq_neurons])
656
+
657
+ theta = np.linspace(0, 2 * np.pi, 300)
658
+ multipliers = [1, 2, 3, 4]
659
+ radii = [1.0, 0.88, 0.76, 0.64]
660
+
661
+ fig, ax = plt.subplots(figsize=(4, 4))
662
+ for m, r in zip(multipliers, radii):
663
+ x_c, y_c = r * np.cos(theta), r * np.sin(theta)
664
+ ax.plot(x_c, y_c, linewidth=0.8, color='gray', alpha=0.6)
665
+
666
+ x_pts = r * np.cos(m * phi_subset)
667
+ y_pts = r * np.sin(m * phi_subset)
668
+ label = fr'$\phi_m$' if m == 1 else fr'${m}\phi_m$'
669
+ ax.scatter(x_pts, y_pts, s=20, marker='o',
670
+ color=COLORS[m - 1], label=label)
671
+
672
+ ax.legend(
673
+ fontsize=15, loc='upper center', columnspacing=0.2,
674
+ handletextpad=0.1, bbox_to_anchor=(0.5, 1.15), ncol=4, frameon=False
675
+ )
676
+ ax.set_xlabel(r'$\cos(\phi_m)$', fontsize=19)
677
+ ax.set_ylabel(r'$\sin(\phi_m)$', fontsize=19)
678
+ ax.set_xticks([])
679
+ ax.set_yticks([])
680
+ for spine in ax.spines.values():
681
+ spine.set_visible(False)
682
+
683
+ _save_fig(fig, self._out('phase_distribution.png'))
684
+
685
+ # ---- Phase relationship: 2*phi vs psi ----
686
+ coeff_2phi_arr = np.array([normalize_to_pi(2 * phi) for phi in coeff_phi_arr])
687
+ coeff_psi_plot = coeff_psi_arr.copy()
688
+ # Fix ±π wrap: keep ψ within π of 2φ so boundary points stay on diagonal
689
+ diff = coeff_psi_plot - coeff_2phi_arr
690
+ coeff_psi_plot[diff > np.pi] -= 2 * np.pi
691
+ coeff_psi_plot[diff < -np.pi] += 2 * np.pi
692
+
693
+ fig, ax = plt.subplots(figsize=(5, 5))
694
+ ax.plot([-np.pi, np.pi], [-np.pi, np.pi], 'r-', linewidth=3, alpha=0.8,
695
+ label=r'$\psi_m = 2\phi_m$', zorder=1)
696
+ ax.scatter(
697
+ coeff_2phi_arr, coeff_psi_plot,
698
+ marker='.', color=COLORS[0], s=20, zorder=2
699
+ )
700
+ ax.legend(fontsize=12, loc='upper left')
701
+ ax.set_xlabel(r'$2\phi_m$', fontsize=14)
702
+ ax.set_ylabel(r'$\psi_m$', fontsize=14)
703
+ ax.set_title(r'Phase Alignment: $\psi_m = 2\phi_m$', fontsize=14)
704
+ ax.set_xlim(-np.pi * 1.1, np.pi * 1.1)
705
+ ax.set_ylim(-np.pi * 1.1, np.pi * 1.1)
706
+ ax.set_aspect('equal')
707
+ ax.grid(True, alpha=0.3)
708
+
709
+ _save_fig(fig, self._out('phase_relationship.png'))
710
+
711
+ # ---- Magnitude distribution (violin) ----
712
+ fig, ax = plt.subplots(figsize=(4, 4))
713
+ data_for_plot = [coeff_in_scale_ls, coeff_out_scale_ls]
714
+ positions = [1, 2]
715
+
716
+ parts = ax.violinplot(
717
+ data_for_plot, positions=positions, widths=0.6,
718
+ showmeans=True, showmedians=True, showextrema=True
719
+ )
720
+ for pc in parts['bodies']:
721
+ pc.set_facecolor(COLORS[0])
722
+ pc.set_alpha(0.7)
723
+ parts['cmedians'].set_color(COLORS[2])
724
+ parts['cmedians'].set_linewidth(2)
725
+ parts['cmeans'].set_color(COLORS[2])
726
+ parts['cmeans'].set_linewidth(2)
727
+ parts['cbars'].set_color(COLORS[0])
728
+ parts['cbars'].set_linewidth(1.5)
729
+ parts['cmaxes'].set_color(COLORS[0])
730
+ parts['cmins'].set_color(COLORS[0])
731
+
732
+ ax.set_xticks(positions)
733
+ ax.set_xticklabels(['First-Layer', 'Second-Layer'], fontsize=14)
734
+ ax.set_ylabel('Magnitude', fontsize=19)
735
+ ax.grid(True, alpha=0.3)
736
+ plt.tight_layout()
737
+
738
+ _save_fig(fig, self._out('magnitude_distribution.png'))
739
+
740
+ print(" Saved phase_distribution.png, phase_relationship.png, magnitude_distribution.png")
741
+
742
+ # ------------------------------------------------------------------
743
+ # Tab 4: Output Logits
744
+ # ------------------------------------------------------------------
745
+
746
+ def generate_tab4(self):
747
+ """Generate output_logits.png."""
748
+ print(f" [Tab 4] Output Logits for p={self.p}")
749
+ run_dir = self._run_dir('standard')
750
+ if run_dir is None:
751
+ print(" SKIP: standard run directory not found")
752
+ return
753
+
754
+ final_data = _load_final(run_dir, self.device)
755
+ if final_data is None:
756
+ print(" SKIP: no final checkpoint")
757
+ return
758
+ model_load = final_data['model']
759
+
760
+ p = self.p
761
+ act_type = TRAINING_RUNS['standard']['act_type']
762
+ model = EmbedMLP(
763
+ d_vocab=self.d_vocab,
764
+ d_model=self.d_model,
765
+ d_mlp=self.d_mlp,
766
+ act_type=act_type,
767
+ use_cache=False
768
+ )
769
+ model.to(self.device)
770
+ model.load_state_dict(model_load)
771
+ model.eval()
772
+
773
+ with torch.no_grad():
774
+ logits = model(self.all_data).squeeze(1)
775
+
776
+ logits_np = logits.cpu().numpy()
777
+
778
+ # Show first p pairs (first row of the input grid)
779
+ interval_start = 0
780
+ interval_end = p
781
+ logits_interval = logits_np[interval_start:interval_end]
782
+ selected_pairs = self.all_data[interval_start:interval_end]
783
+
784
+ fig, ax = plt.subplots(figsize=(7, 6))
785
+ abs_max = np.abs(logits_np).max() * 0.8
786
+ im = ax.imshow(
787
+ logits_interval.T, cmap=CMAP_DIVERGING, aspect='auto',
788
+ vmin=-abs_max, vmax=abs_max
789
+ )
790
+
791
+ # Highlight target positions with rectangles
792
+ for i, (x_val_t, y_val_t) in enumerate(selected_pairs):
793
+ x_val = x_val_t.item()
794
+ y_val = y_val_t.item()
795
+ target_2x = (2 * x_val) % p
796
+ target_2y = (2 * y_val) % p
797
+ target_sum = (x_val + y_val) % p
798
+
799
+ rect_2x = patches.Rectangle(
800
+ (i - 0.5, target_2x - 0.5), 1, 1,
801
+ linewidth=1.6, edgecolor='#0D2758', facecolor='none', alpha=0.9
802
+ )
803
+ ax.add_patch(rect_2x)
804
+ if target_2y != target_2x:
805
+ rect_2y = patches.Rectangle(
806
+ (i - 0.5, target_2y - 0.5), 1, 1,
807
+ linewidth=1.6, edgecolor='#0D2758', facecolor='none', alpha=0.9
808
+ )
809
+ ax.add_patch(rect_2y)
810
+ rect_sum = patches.Rectangle(
811
+ (i - 0.5, target_sum - 0.5), 1, 1,
812
+ linewidth=1.6, edgecolor='#0D2758', facecolor='none', alpha=0.9
813
+ )
814
+ ax.add_patch(rect_sum)
815
+
816
+ n_pairs = interval_end - interval_start
817
+ if n_pairs <= 50:
818
+ x_positions = np.arange(n_pairs)
819
+ x_labels = [f"({selected_pairs[i][0].item()},{selected_pairs[i][1].item()})"
820
+ for i in range(n_pairs)]
821
+ ax.set_xticks(x_positions)
822
+ ax.set_xticklabels(x_labels, rotation=90, ha='right', fontsize=14)
823
+ else:
824
+ n_labels = min(25, n_pairs)
825
+ step = n_pairs // n_labels
826
+ x_positions = np.arange(0, n_pairs, step)
827
+ x_labels = [f"({selected_pairs[i][0].item()},{selected_pairs[i][1].item()})"
828
+ for i in x_positions]
829
+ ax.set_xticks(x_positions)
830
+ ax.set_xticklabels(x_labels, rotation=90, ha='right', fontsize=14)
831
+
832
+ ax.set_yticks(np.arange(p))
833
+ ax.set_yticklabels(np.arange(p), fontsize=14)
834
+ ax.set_xlabel("Input Pair", fontsize=18)
835
+ ax.set_ylabel("Output", fontsize=18)
836
+ plt.colorbar(im, ax=ax)
837
+ ax.grid(True, alpha=0.2, linestyle=':', linewidth=0.5, axis='x')
838
+ plt.tight_layout()
839
+
840
+ _save_fig(fig, self._out('output_logits.png'))
841
+ print(" Saved output_logits.png")
842
+
843
+ # ------------------------------------------------------------------
844
+ # Tab 5: Grokking
845
+ # ------------------------------------------------------------------
846
+
847
+ def generate_tab5(self):
848
+ """Generate grokking-related plots."""
849
+ print(f" [Tab 5] Grokking for p={self.p}")
850
+ if self.p < MIN_P_GROKKING:
851
+ print(f" SKIP: p={self.p} < {MIN_P_GROKKING} (too few test points for grokking)")
852
+ return
853
+ run_dir = self._run_dir('grokking')
854
+ if run_dir is None:
855
+ print(" SKIP: grokking run directory not found")
856
+ return
857
+
858
+ curves = _load_training_curves(self._run_type_dir('grokking'))
859
+ checkpoints = _load_checkpoints(run_dir, self.device)
860
+
861
+ if not checkpoints:
862
+ print(" SKIP: no grokking checkpoints")
863
+ return
864
+
865
+ epochs = sorted(checkpoints.keys())
866
+ p = self.p
867
+ d_mlp = self.d_mlp
868
+ act_type = TRAINING_RUNS['grokking']['act_type']
869
+
870
+ # Load train/test data
871
+ train_data_path = os.path.join(run_dir, 'train_data.pth')
872
+ test_data_path = os.path.join(run_dir, 'test_data.pth')
873
+ train_data = None
874
+ test_data = None
875
+ train_labels = None
876
+ test_labels = None
877
+ if os.path.exists(train_data_path):
878
+ raw = torch.load(train_data_path, weights_only=False,
879
+ map_location=self.device)
880
+ # Handle both formats: plain tensor or (pairs, labels) tuple
881
+ if isinstance(raw, (tuple, list)):
882
+ train_data, train_labels = raw[0], raw[1]
883
+ else:
884
+ train_data = raw
885
+ if os.path.exists(test_data_path):
886
+ raw = torch.load(test_data_path, weights_only=False,
887
+ map_location=self.device)
888
+ if isinstance(raw, (tuple, list)):
889
+ test_data, test_labels = raw[0], raw[1]
890
+ else:
891
+ test_data = raw
892
+
893
+ # Fallback: regenerate data deterministically if files are missing
894
+ if train_data is None or test_data is None:
895
+ grokk_cfg = TRAINING_RUNS['grokking']
896
+ frac = grokk_cfg['frac_train']
897
+ seed = grokk_cfg['seed']
898
+ print(f" Regenerating train/test data (frac={frac}, seed={seed})")
899
+ train_data, test_data = _gen_train_test(p, frac_train=frac, seed=seed)
900
+
901
+ # Compute labels from pairs if not loaded directly
902
+ if train_labels is None and train_data is not None:
903
+ train_labels = torch.tensor(
904
+ [(train_data[i, 0].item() + train_data[i, 1].item()) % p
905
+ for i in range(train_data.shape[0])],
906
+ dtype=torch.long
907
+ )
908
+ if test_labels is None and test_data is not None:
909
+ test_labels = torch.tensor(
910
+ [(test_data[i, 0].item() + test_data[i, 1].item()) % p
911
+ for i in range(test_data.shape[0])],
912
+ dtype=torch.long
913
+ )
914
+
915
+ # Detect stage boundaries
916
+ train_losses = curves.get('train_losses', []) if curves else []
917
+ test_losses = curves.get('test_losses', []) if curves else []
918
+ train_accs_curve = curves.get('train_accs', None) if curves else None
919
+ test_accs_curve = curves.get('test_accs', None) if curves else None
920
+
921
+ stage1_end, stage2_end = detect_grokking_stages(
922
+ train_losses, test_losses, train_accs_curve, test_accs_curve
923
+ )
924
+ if stage1_end is None:
925
+ stage1_end = len(epochs) // 5
926
+ if stage2_end is None:
927
+ stage2_end = len(epochs) * 3 // 5
928
+
929
+ # ---- Loss JSON + static PNG ----
930
+ if train_losses:
931
+ loss_data = {
932
+ 'train_losses': train_losses,
933
+ 'test_losses': test_losses,
934
+ 'stage1_end': stage1_end,
935
+ 'stage2_end': stage2_end,
936
+ }
937
+ with open(self._out('grokk_loss.json'), 'w') as f:
938
+ json.dump(loss_data, f)
939
+
940
+ # Static loss PNG (matches blog Figure 13a)
941
+ max_step = min(len(train_losses), len(test_losses)) if test_losses else len(train_losses)
942
+ fig, ax = plt.subplots(figsize=(4, 4))
943
+ ax.plot(train_losses[:max_step], color='#0D2758', linewidth=2, label='Train')
944
+ if test_losses:
945
+ ax.plot(test_losses[:max_step], color='#A32015', linewidth=2, label='Test')
946
+ ax.axvspan(0, stage1_end, alpha=0.15, color='#D4AF37')
947
+ ax.axvspan(stage1_end, stage2_end, alpha=0.15, color='#8B7355')
948
+ ax.axvspan(stage2_end, max_step, alpha=0.15, color='#60656F')
949
+ ax.axvline(x=stage1_end, color='black', linestyle='--', linewidth=1)
950
+ ax.axvline(x=stage2_end, color='black', linestyle='--', linewidth=1)
951
+ ax.set_xlabel('Step', fontsize=16)
952
+ ax.set_ylabel('Loss', fontsize=16)
953
+ ax.legend(fontsize=16, loc='upper right')
954
+ ax.grid(True, linestyle='--', alpha=0.5)
955
+ plt.tight_layout()
956
+ _save_fig(fig, self._out('grokk_loss.png'))
957
+
958
+ # ---- Accuracy: compute from checkpoints if not in curves ----
959
+ train_accs = []
960
+ test_accs = []
961
+ if train_data is not None and test_data is not None:
962
+ for ep in epochs:
963
+ model = EmbedMLP(
964
+ d_vocab=self.d_vocab, d_model=self.d_model,
965
+ d_mlp=d_mlp, act_type=act_type, use_cache=False
966
+ ).to(self.device)
967
+ model.load_state_dict(checkpoints[ep])
968
+ model.eval()
969
+ with torch.no_grad():
970
+ tr_logits = model(train_data)
971
+ te_logits = model(test_data)
972
+ train_accs.append(acc_rate(tr_logits, train_labels))
973
+ test_accs.append(acc_rate(te_logits, test_labels))
974
+ elif train_accs_curve is not None:
975
+ # Use curves data, subsample to match checkpoint epochs
976
+ save_every = epochs[1] - epochs[0] if len(epochs) > 1 else 200
977
+ train_accs = train_accs_curve[::save_every][:len(epochs)]
978
+ test_accs = test_accs_curve[::save_every][:len(epochs)]
979
+
980
+ acc_data = {
981
+ 'epochs': epochs,
982
+ 'train_accs': train_accs,
983
+ 'test_accs': test_accs,
984
+ 'stage1_end': stage1_end,
985
+ 'stage2_end': stage2_end,
986
+ }
987
+ with open(self._out('grokk_acc.json'), 'w') as f:
988
+ json.dump(acc_data, f)
989
+
990
+ # Static accuracy PNG (matches blog Figure 13b)
991
+ if train_accs and test_accs:
992
+ fig, ax = plt.subplots(figsize=(4, 4))
993
+ ax.axvspan(0, stage1_end, alpha=0.15, color='#D4AF37')
994
+ ax.axvspan(stage1_end, stage2_end, alpha=0.15, color='#8B7355')
995
+ ax.axvspan(stage2_end, epochs[-1] if epochs else stage2_end,
996
+ alpha=0.15, color='#60656F')
997
+ ax.axvline(x=stage1_end, color='black', linestyle='--', linewidth=1)
998
+ ax.axvline(x=stage2_end, color='black', linestyle='--', linewidth=1)
999
+ ax.plot(epochs[:len(train_accs)], train_accs,
1000
+ label='Train', color='#0D2758', linewidth=2.5)
1001
+ ax.plot(epochs[:len(test_accs)], test_accs,
1002
+ label='Test', color='#A32015', linewidth=2.5)
1003
+ ax.set_xlabel('Step', fontsize=16)
1004
+ ax.set_ylabel('Accuracy', fontsize=16)
1005
+ ax.legend(fontsize=16, loc='lower right')
1006
+ ax.grid(True, linestyle='--', alpha=0.5)
1007
+ plt.tight_layout()
1008
+ _save_fig(fig, self._out('grokk_acc.png'))
1009
+
1010
+ # ---- Phase difference |sin(D*)| ----
1011
+ abs_phase_diff = []
1012
+ sparse_level = []
1013
+
1014
+ for ep in epochs:
1015
+ model_sd = checkpoints[ep]
1016
+ W_in_d, W_out_d, mfl = decode_weights(model_sd, self.fourier_basis)
1017
+
1018
+ sparse_level.append(self._ipr_at_checkpoint(model_sd))
1019
+
1020
+ phase_diffs = []
1021
+ for neuron in range(W_in_d.shape[0]):
1022
+ _, phi_in = compute_neuron(neuron, mfl, W_in_d)
1023
+ _, phi_out = compute_neuron(neuron, mfl, W_out_d)
1024
+ phase_diffs.append(normalize_to_pi(phi_out - 2 * phi_in))
1025
+ phase_diffs = np.array(phase_diffs)
1026
+ abs_phase_diff.append(np.mean(np.abs(np.sin(phase_diffs))))
1027
+
1028
+ # Limit to reasonable number of points for plotting
1029
+ n_plot = min(len(epochs), 100)
1030
+ x_phase = np.array(epochs[:n_plot])
1031
+
1032
+ fig, ax = plt.subplots(figsize=(4, 4))
1033
+ ax.axvspan(0, stage1_end, alpha=0.15, color='#D4AF37')
1034
+ ax.axvspan(stage1_end, min(stage2_end, x_phase[-1] if len(x_phase) else stage2_end),
1035
+ alpha=0.15, color='#8B7355')
1036
+ if len(x_phase):
1037
+ ax.axvspan(stage2_end, x_phase[-1], alpha=0.15, color='#60656F')
1038
+ ax.axvline(x=stage1_end, color='black', linestyle='--', linewidth=1)
1039
+ ax.axvline(x=stage2_end, color='black', linestyle='--', linewidth=1)
1040
+ ax.plot(x_phase, abs_phase_diff[:n_plot], marker='x', markersize=5,
1041
+ color='#986d56', label=r"Avg. $|\sin(D_m^\star)|$", linewidth=1.5)
1042
+ ax.set_xlabel('Step', fontsize=16)
1043
+ ax.set_ylabel('Average Value', fontsize=16)
1044
+ ax.set_ylim([0, 0.65])
1045
+ ax.legend(fontsize=16, loc="upper right")
1046
+ ax.grid(True, alpha=0.5, linestyle='--')
1047
+ plt.tight_layout()
1048
+ _save_fig(fig, self._out('grokk_abs_phase_diff.png'))
1049
+
1050
+ # ---- IPR + param norms (dual axis) ----
1051
+ x_all = np.array(epochs)
1052
+ param_norms = []
1053
+ if curves and 'param_norms' in curves:
1054
+ save_every = epochs[1] - epochs[0] if len(epochs) > 1 else 200
1055
+ param_norms = curves['param_norms'][::save_every][:len(epochs)]
1056
+
1057
+ fig, ax1 = plt.subplots(figsize=(4, 4))
1058
+ ax1.axvspan(0, stage1_end, alpha=0.15, color='#D4AF37')
1059
+ ax1.axvspan(stage1_end, min(stage2_end, x_all[-1] if len(x_all) else stage2_end),
1060
+ alpha=0.15, color='#8B7355')
1061
+ if len(x_all):
1062
+ ax1.axvspan(stage2_end, x_all[-1], alpha=0.15, color='#60656F')
1063
+ ax1.axvline(x=stage1_end, color='black', linestyle='--', linewidth=1)
1064
+ ax1.axvline(x=stage2_end, color='black', linestyle='--', linewidth=1)
1065
+
1066
+ line1 = ax1.plot(x_all, sparse_level, marker='x', markersize=3,
1067
+ color='#986d56', label=r"Avg. IPR", linewidth=1.5)
1068
+ ax1.set_xlabel('Step', fontsize=16)
1069
+ ax1.tick_params(axis='y')
1070
+ ax1.set_ylim([0, 1.05])
1071
+
1072
+ if param_norms:
1073
+ ax2 = ax1.twinx()
1074
+ line2 = ax2.plot(x_all[:len(param_norms)], param_norms,
1075
+ marker='o', markersize=3, color='#2E5266',
1076
+ label=r"Param. Norm", linewidth=1.5)
1077
+ ax2.tick_params(axis='y')
1078
+ lines = line1 + line2
1079
+ labels = [l.get_label() for l in lines]
1080
+ ax1.legend(lines, labels, fontsize=16, loc="lower right")
1081
+ else:
1082
+ ax1.legend(fontsize=16, loc="lower right")
1083
+
1084
+ ax1.grid(True, alpha=0.5, linestyle='--')
1085
+ plt.tight_layout()
1086
+ _save_fig(fig, self._out('grokk_avg_ipr.png'))
1087
+
1088
+ # ---- Memorization accuracy (3-panel) ----
1089
+ if train_data is not None:
1090
+ # Find a checkpoint near stage1_end
1091
+ closest_epoch = min(epochs, key=lambda e: abs(e - stage1_end))
1092
+ model_sd = checkpoints[closest_epoch]
1093
+
1094
+ model = EmbedMLP(
1095
+ d_vocab=self.d_vocab, d_model=self.d_model,
1096
+ d_mlp=d_mlp, act_type=act_type, use_cache=False
1097
+ ).to(self.device)
1098
+ model.load_state_dict(model_sd)
1099
+ model.eval()
1100
+
1101
+ with torch.no_grad():
1102
+ logits = model(self.all_data).squeeze(1)
1103
+
1104
+ train_set = set([(int(i), int(j)) for i, j in train_data])
1105
+ true_test_points = []
1106
+
1107
+ train_mask = torch.zeros(p, p)
1108
+ for i in range(p):
1109
+ for j in range(p):
1110
+ if (i, j) in train_set:
1111
+ train_mask[i, j] = 1.0
1112
+ elif (j, i) in train_set:
1113
+ train_mask[i, j] = 0.65
1114
+ else:
1115
+ train_mask[i, j] = 0.0
1116
+ true_test_points.append((i, j))
1117
+
1118
+ predicted = torch.argmax(logits, dim=1).view(p, p)
1119
+ gt_grid = self.all_labels.view(p, p)
1120
+ accuracy_mask = (predicted == gt_grid).float()
1121
+
1122
+ probs = torch.softmax(logits, dim=1)
1123
+ gt_probs = torch.zeros(p * p)
1124
+ for idx in range(p * p):
1125
+ i_val = self.all_data[idx, 0].item()
1126
+ j_val = self.all_data[idx, 1].item()
1127
+ correct = (i_val + j_val) % p
1128
+ gt_probs[idx] = probs[idx, correct]
1129
+ gt_probs_grid = gt_probs.view(p, p)
1130
+
1131
+ fig = plt.figure(figsize=(20, 6))
1132
+ gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 1.1], wspace=0.15)
1133
+
1134
+ ax1 = fig.add_subplot(gs[0])
1135
+ ax2 = fig.add_subplot(gs[1])
1136
+ ax3 = fig.add_subplot(gs[2])
1137
+
1138
+ # Train mask
1139
+ im1 = ax1.imshow(train_mask.numpy(), cmap=CMAP_SEQUENTIAL,
1140
+ vmin=0, vmax=1, aspect='equal')
1141
+ ax1.set_title('Training Data under Symmetry', fontsize=21)
1142
+ ax1.set_ylabel('First Input', fontsize=18)
1143
+ ax1.set_xlabel('Second Input', fontsize=18)
1144
+ locs = np.arange(p)
1145
+ ax1.set_xticks(locs)
1146
+ ax1.set_yticks(locs)
1147
+ ax1.set_xticklabels(locs, fontsize=11)
1148
+ ax1.set_yticklabels(locs, fontsize=11)
1149
+ for ti, tj in true_test_points:
1150
+ rect = plt.Rectangle((tj - 0.5, ti - 0.5), 1, 1,
1151
+ linewidth=2.5, edgecolor='red', facecolor='none')
1152
+ ax1.add_patch(rect)
1153
+
1154
+ # Accuracy mask
1155
+ im2 = ax2.imshow(accuracy_mask.numpy(), cmap=CMAP_SEQUENTIAL,
1156
+ vmin=0, vmax=1, aspect='equal')
1157
+ ax2.set_title('Accuracy before Grokking', fontsize=21)
1158
+ ax2.set_xlabel('Second Input', fontsize=18)
1159
+ ax2.set_xticks(locs)
1160
+ ax2.set_yticks(locs)
1161
+ ax2.set_xticklabels(locs, fontsize=11)
1162
+ ax2.set_yticklabels(locs, fontsize=11)
1163
+ for ti, tj in true_test_points:
1164
+ rect = plt.Rectangle((tj - 0.5, ti - 0.5), 1, 1,
1165
+ linewidth=2.5, edgecolor='red', facecolor='none')
1166
+ ax2.add_patch(rect)
1167
+
1168
+ # Softmax probability
1169
+ prob_max = gt_probs_grid.max().item()
1170
+ im3 = ax3.imshow(gt_probs_grid.detach().numpy(), cmap=CMAP_SEQUENTIAL,
1171
+ vmin=0, vmax=prob_max, aspect='equal')
1172
+ ax3.set_title('Softmax Weight at Ground-Truth', fontsize=21)
1173
+ ax3.set_xlabel('Second Input', fontsize=18)
1174
+ ax3.set_xticks(locs)
1175
+ ax3.set_yticks(locs)
1176
+ ax3.set_xticklabels(locs, fontsize=11)
1177
+ ax3.set_yticklabels(locs, fontsize=11)
1178
+ for ti, tj in true_test_points:
1179
+ rect = plt.Rectangle((tj - 0.5, ti - 0.5), 1, 1,
1180
+ linewidth=2.5, edgecolor='red', facecolor='none')
1181
+ ax3.add_patch(rect)
1182
+ cbar3 = fig.colorbar(im3, ax=ax3, fraction=0.046, pad=0.04)
1183
+ cbar3.ax.tick_params(labelsize=12)
1184
+ plt.tight_layout()
1185
+ _save_fig(fig, self._out('grokk_memorization_accuracy.png'))
1186
+
1187
+ # ---- Memorization common-to-rare (4-panel) ----
1188
+ if train_data is not None:
1189
+ train_set = set([(int(i), int(j)) for i, j in train_data])
1190
+ asymmetric_train_points = []
1191
+ train_mask_dist = torch.zeros(p, p)
1192
+ for i in range(p):
1193
+ for j in range(p):
1194
+ if (i, j) in train_set and (j, i) in train_set:
1195
+ train_mask_dist[i, j] = 1.0
1196
+ elif (i, j) in train_set and (j, i) not in train_set:
1197
+ train_mask_dist[i, j] = 0.5
1198
+ asymmetric_train_points.append((i, j))
1199
+ else:
1200
+ train_mask_dist[i, j] = 0.0
1201
+
1202
+ # Pick 3 epochs: 0, ~stage1/2, ~stage1
1203
+ selected_epochs = [0]
1204
+ mid_epoch = min(epochs, key=lambda e: abs(e - stage1_end // 2))
1205
+ end_epoch = min(epochs, key=lambda e: abs(e - stage1_end))
1206
+ if mid_epoch not in selected_epochs:
1207
+ selected_epochs.append(mid_epoch)
1208
+ if end_epoch not in selected_epochs:
1209
+ selected_epochs.append(end_epoch)
1210
+ # Ensure we have exactly 3 + distribution = 4 panels
1211
+ while len(selected_epochs) < 3:
1212
+ selected_epochs.append(epochs[min(len(epochs) - 1, 2)])
1213
+
1214
+ fig = plt.figure(figsize=(26, 6))
1215
+ gs = fig.add_gridspec(
1216
+ 1, 4, width_ratios=[1, 1, 1, 1.1], wspace=0.15
1217
+ )
1218
+
1219
+ # Panel 1: training data distribution
1220
+ ax_d = fig.add_subplot(gs[0])
1221
+ ax_d.imshow(train_mask_dist.numpy(), cmap=CMAP_SEQUENTIAL,
1222
+ vmin=0, vmax=1, aspect='equal')
1223
+ ax_d.set_title('Training Data Distribution', fontsize=21)
1224
+ ax_d.set_ylabel('First Input', fontsize=18)
1225
+ ax_d.set_xlabel('Second Input', fontsize=18)
1226
+ locs = np.arange(p)
1227
+ ax_d.set_xticks(locs)
1228
+ ax_d.set_yticks(locs)
1229
+ ax_d.set_xticklabels(locs, fontsize=11)
1230
+ ax_d.set_yticklabels(locs, fontsize=11)
1231
+ for ti, tj in asymmetric_train_points:
1232
+ rect = plt.Rectangle((tj - 0.5, ti - 0.5), 1, 1,
1233
+ linewidth=2.0, edgecolor='red', facecolor='none')
1234
+ ax_d.add_patch(rect)
1235
+
1236
+ # Panels 2-4: accuracy at selected epochs
1237
+ for panel_idx, sel_ep in enumerate(selected_epochs):
1238
+ ax_p = fig.add_subplot(gs[panel_idx + 1])
1239
+ model_p = EmbedMLP(
1240
+ d_vocab=self.d_vocab, d_model=self.d_model,
1241
+ d_mlp=d_mlp, act_type=act_type, use_cache=False
1242
+ ).to(self.device)
1243
+ model_p.load_state_dict(checkpoints[sel_ep])
1244
+ model_p.eval()
1245
+ with torch.no_grad():
1246
+ logits_p = model_p(self.all_data).squeeze(1)
1247
+ pred_p = torch.argmax(logits_p, dim=1).view(p, p)
1248
+ acc_p = (pred_p == gt_grid).float()
1249
+
1250
+ ax_p.imshow(acc_p.numpy(), cmap=CMAP_SEQUENTIAL,
1251
+ vmin=0, vmax=1, aspect='equal')
1252
+ ax_p.set_title(f'Accuracy at Step {sel_ep}', fontsize=21)
1253
+ ax_p.set_xlabel('Second Input', fontsize=18)
1254
+ ax_p.set_xticks(locs)
1255
+ ax_p.set_yticks(locs)
1256
+ ax_p.set_xticklabels(locs, fontsize=11)
1257
+ ax_p.set_yticklabels(locs, fontsize=11)
1258
+ for ti, tj in asymmetric_train_points:
1259
+ rect = plt.Rectangle((tj - 0.5, ti - 0.5), 1, 1,
1260
+ linewidth=2.0, edgecolor='red', facecolor='none')
1261
+ ax_p.add_patch(rect)
1262
+
1263
+ plt.tight_layout()
1264
+ _save_fig(fig, self._out('grokk_memorization_common_to_rare.png'))
1265
+
1266
+ # ---- Decoded weights dynamic (3 timepoints) ----
1267
+ # Pick 3 representative epochs: 0, stage1, stage2
1268
+ key_epochs = [0]
1269
+ ep_s1 = min(epochs, key=lambda e: abs(e - stage1_end))
1270
+ ep_s2 = min(epochs, key=lambda e: abs(e - stage2_end))
1271
+ if ep_s1 not in key_epochs:
1272
+ key_epochs.append(ep_s1)
1273
+ if ep_s2 not in key_epochs:
1274
+ key_epochs.append(ep_s2)
1275
+ while len(key_epochs) < 3:
1276
+ key_epochs.append(epochs[-1])
1277
+
1278
+ num_components = min(20, d_mlp)
1279
+ n = len(key_epochs)
1280
+ fig, axes = plt.subplots(
1281
+ 2, n, figsize=(18, 3.3 * n),
1282
+ gridspec_kw={"hspace": 0.05}, constrained_layout=True
1283
+ )
1284
+ if n == 1:
1285
+ axes = axes.reshape(2, 1)
1286
+
1287
+ x_locs = np.arange(len(self.fourier_basis_names))
1288
+ y_locs = np.arange(num_components)
1289
+
1290
+ for col, key in enumerate(key_epochs):
1291
+ W_in = checkpoints[key]['mlp.W_in']
1292
+ W_out = checkpoints[key]['mlp.W_out']
1293
+
1294
+ data_in = (W_in @ self.fourier_basis.T)[:num_components]
1295
+ data_in_np = data_in.detach().cpu().numpy()
1296
+ abs_max_in = np.abs(data_in_np).max()
1297
+ ax_in = axes[0, col]
1298
+ im_in = ax_in.imshow(
1299
+ data_in_np, cmap=CMAP_DIVERGING,
1300
+ vmin=-abs_max_in, vmax=abs_max_in, aspect='auto'
1301
+ )
1302
+ ax_in.set_title(rf'Step {key}, $\theta_m$ after DFT', fontsize=18)
1303
+ ax_in.set_xticks(x_locs)
1304
+ ax_in.set_xticklabels(self.fourier_basis_names, rotation=90, fontsize=11)
1305
+ ax_in.set_yticks(y_locs)
1306
+ ax_in.set_yticklabels(y_locs)
1307
+ if col == 0:
1308
+ ax_in.set_ylabel('Neuron #', fontsize=16)
1309
+ fig.colorbar(im_in, ax=ax_in)
1310
+
1311
+ data_out = (W_out.T @ self.fourier_basis.T)[:num_components]
1312
+ data_out_np = data_out.detach().cpu().numpy()
1313
+ abs_max_out = np.abs(data_out_np).max() * 0.85
1314
+ ax_out = axes[1, col]
1315
+ im_out = ax_out.imshow(
1316
+ data_out_np, cmap=CMAP_DIVERGING,
1317
+ vmin=-abs_max_out, vmax=abs_max_out, aspect='auto'
1318
+ )
1319
+ ax_out.set_title(rf'Step {key}, $\xi_m$ after DFT', fontsize=18)
1320
+ ax_out.set_xticks(x_locs)
1321
+ ax_out.set_xticklabels(self.fourier_basis_names, rotation=90, fontsize=11)
1322
+ ax_out.set_yticks(y_locs)
1323
+ ax_out.set_yticklabels(y_locs)
1324
+ if col == 0:
1325
+ ax_out.set_ylabel('Neuron #', fontsize=16)
1326
+ fig.colorbar(im_out, ax=ax_out)
1327
+
1328
+ _save_fig(fig, self._out('grokk_decoded_weights_dynamic.png'))
1329
+
1330
+ print(" Saved grokk_loss.json, grokk_loss.png, grokk_acc.json, grokk_acc.png, "
1331
+ "grokk_abs_phase_diff.png, grokk_avg_ipr.png, "
1332
+ "grokk_memorization_accuracy.png, "
1333
+ "grokk_memorization_common_to_rare.png, grokk_decoded_weights_dynamic.png")
1334
+
1335
+ # ------------------------------------------------------------------
1336
+ # Tab 6: Lottery Mechanism
1337
+ # ------------------------------------------------------------------
1338
+
1339
+ def generate_tab6(self):
1340
+ """Generate lottery mechanism plots."""
1341
+ print(f" [Tab 6] Lottery Mechanism for p={self.p}")
1342
+ run_dir = self._run_dir('quad_random')
1343
+ if run_dir is None:
1344
+ print(" SKIP: quad_random run directory not found")
1345
+ return
1346
+
1347
+ checkpoints = _load_checkpoints(run_dir, self.device)
1348
+ if not checkpoints:
1349
+ print(" SKIP: no quad_random checkpoints")
1350
+ return
1351
+
1352
+ final_data = _load_final(run_dir, self.device)
1353
+ if final_data is None:
1354
+ print(" SKIP: no final quad_random checkpoint")
1355
+ return
1356
+ model_load_final = final_data['model']
1357
+
1358
+ # Select best neuron
1359
+ neuron_id = select_lottery_neuron(
1360
+ model_load_final, self.fourier_basis, decode_scales_phis
1361
+ )
1362
+
1363
+ epochs = sorted(checkpoints.keys())
1364
+ p = self.p
1365
+
1366
+ # Collect per-checkpoint scales and phase diffs for the selected neuron
1367
+ scales_list = []
1368
+ diff_list = []
1369
+ for ep in epochs:
1370
+ scales, phis, psis = decode_scales_phis(
1371
+ checkpoints[ep], self.fourier_basis
1372
+ )
1373
+ scales_list.append(scales[neuron_id])
1374
+ diff_list.append(normalize_to_pi(
1375
+ psis[neuron_id] - 2 * phis[neuron_id]
1376
+ ))
1377
+
1378
+ # Stack: [num_checkpoints, K+1], skip DC
1379
+ scales_all = torch.stack(scales_list, dim=0)[:, 1:]
1380
+ diff_all = torch.stack(diff_list, dim=0)[:, 1:]
1381
+
1382
+ # Determine which frequency this neuron specializes in
1383
+ _, _, max_freq_ls = decode_weights(model_load_final, self.fourier_basis)
1384
+ max_freq = max_freq_ls[neuron_id] - 1 # 0-indexed into scales_all
1385
+
1386
+ scales_np = scales_all.cpu().numpy()
1387
+ diff_np = diff_all.cpu().numpy()
1388
+ num_models, num_freqs = scales_np.shape
1389
+ n_plot = min(num_models, 160)
1390
+ scales_np = scales_np[:n_plot]
1391
+ diff_np = diff_np[:n_plot]
1392
+ x_idx = np.arange(n_plot)
1393
+
1394
+ # Color gradient for non-highlighted frequencies
1395
+ base_rgb = np.array(mcolors.to_rgb(COLORS[0]))
1396
+ gray_rgb = np.array(mcolors.to_rgb('white'))
1397
+ highlight_color = COLORS[3]
1398
+
1399
+ nonmax = [f for f in range(num_freqs) if f != max_freq]
1400
+ final_scales = scales_np[-1]
1401
+ sorted_nonmax = sorted(nonmax, key=lambda f: final_scales[f])
1402
+ M = len(sorted_nonmax)
1403
+
1404
+ # Compute save_every for x-axis formatter
1405
+ save_every = epochs[1] - epochs[0] if len(epochs) > 1 else 200
1406
+
1407
+ # ---- Magnitude plot ----
1408
+ fig, ax = plt.subplots(figsize=(4, 4))
1409
+ for idx, f in enumerate(sorted_nonmax):
1410
+ blend = idx / (M - 1) if M > 1 else 0.0
1411
+ col_rgb = (1 - blend - 0.05) * gray_rgb + (blend + 0.05) * base_rgb
1412
+ ax.plot(x_idx, scales_np[:, f], color=col_rgb, linestyle=':',
1413
+ marker='x', linewidth=3.5, markersize=1.5,
1414
+ label=f"Freq. {f + 1}")
1415
+
1416
+ ax.plot(x_idx, scales_np[:, max_freq], color=highlight_color,
1417
+ linestyle=':', marker='x', linewidth=3.5, markersize=1.5,
1418
+ label=f"Freq. {max_freq + 1}")
1419
+
1420
+ ax.xaxis.set_major_formatter(
1421
+ FuncFormatter(lambda val, pos: f"{int(val * save_every)}")
1422
+ )
1423
+ ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1),
1424
+ borderaxespad=0.2, frameon=False, fontsize=13)
1425
+ ax.set_xlabel("Step", fontsize=16)
1426
+ ax.set_ylabel("Magnitude", fontsize=16)
1427
+ ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
1428
+ _save_fig(fig, self._out('lottery_mech_magnitude.png'))
1429
+
1430
+ # ---- Phase misalignment plot ----
1431
+ fig, ax = plt.subplots(figsize=(4, 4))
1432
+ ax.axhline(y=0, color='black', linewidth=1, linestyle='dotted')
1433
+ for idx, f in enumerate(sorted_nonmax):
1434
+ blend = idx / (M - 1) if M > 1 else 0.0
1435
+ col_rgb = (1 - blend - 0.05) * gray_rgb + (blend + 0.05) * base_rgb
1436
+ ax.plot(x_idx, diff_np[:, f], linestyle=':', marker='x',
1437
+ linewidth=3.5, markersize=1.5, color=col_rgb,
1438
+ label=f"Freq. {f}")
1439
+
1440
+ ax.plot(x_idx, diff_np[:, max_freq], linestyle=':', marker='x',
1441
+ linewidth=3.5, markersize=1.5, color=highlight_color,
1442
+ label=f"Freq. {max_freq}")
1443
+
1444
+ ax.xaxis.set_major_formatter(
1445
+ FuncFormatter(lambda val, pos: f"{int(val * save_every)}")
1446
+ )
1447
+ ax.set_xlabel("Step", fontsize=16)
1448
+ ax.set_ylabel("Misalignment", fontsize=16)
1449
+ ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
1450
+ _save_fig(fig, self._out('lottery_mech_phase.png'))
1451
+
1452
+ # ---- Beta contour: simulate gradient flow ----
1453
+ self._generate_lottery_contour()
1454
+
1455
+ print(" Saved lottery_mech_magnitude.png, lottery_mech_phase.png, "
1456
+ "lottery_beta_contour.png")
1457
+
1458
+ def _generate_lottery_contour(self):
1459
+ """Simulate gradient flow for a grid of (init_magnitude, init_phase_diff)."""
1460
+ p = self.p
1461
+ device = self.device
1462
+ init_k = 1
1463
+ init_psi = 0.0
1464
+ num_steps = 100
1465
+ learning_rate = 0.01
1466
+
1467
+ fourier_basis, _ = get_fourier_basis(p, device)
1468
+ fourier_basis = fourier_basis.to(torch.get_default_dtype())
1469
+
1470
+ initial_scales = np.linspace(0.01, 0.02, num=30)
1471
+ phi0_vals = np.linspace(0, np.pi, num=30)
1472
+
1473
+ results = []
1474
+ for scale in initial_scales:
1475
+ for phi0 in phi0_vals:
1476
+ w_k = 2 * np.pi * init_k / p
1477
+ theta = scale * torch.tensor(
1478
+ [np.cos(w_k * j + phi0) for j in range(p)],
1479
+ device=device
1480
+ )
1481
+ xi = scale * torch.tensor(
1482
+ [np.cos(w_k * j + init_psi) for j in range(p)],
1483
+ device=device
1484
+ )
1485
+
1486
+ # Run gradient flow simulation
1487
+ for _ in range(num_steps):
1488
+ theta, xi = self._gradient_flow_step(
1489
+ theta, xi, init_k, p, learning_rate, fourier_basis
1490
+ )
1491
+
1492
+ # Compute final beta
1493
+ coeffs_xi = fourier_basis.to(xi.dtype) @ xi
1494
+ idx = [init_k * 2 - 1, init_k * 2]
1495
+ xi_n = coeffs_xi[idx]
1496
+ beta_f = torch.norm(xi_n).item() * np.sqrt(2 / p)
1497
+
1498
+ results.append({
1499
+ "init_scale": scale,
1500
+ "init_diff": 2 * phi0,
1501
+ "beta_f": beta_f,
1502
+ })
1503
+
1504
+ # Pivot into grid
1505
+ n_scales = len(initial_scales)
1506
+ n_phis = len(phi0_vals)
1507
+ Z = np.zeros((n_phis, n_scales))
1508
+ for i, r in enumerate(results):
1509
+ row = i % n_phis
1510
+ col = i // n_phis
1511
+ Z[row, col] = r['beta_f']
1512
+
1513
+ X, Y = np.meshgrid(initial_scales, 2 * phi0_vals)
1514
+
1515
+ fig = plt.figure(figsize=(4.5, 4))
1516
+ cf = plt.contourf(X, Y, Z, levels=12, cmap=CMAP_DIVERGING, extend='both')
1517
+ plt.axhline(y=np.pi, color='white', linewidth=1, linestyle=':')
1518
+ plt.xlabel("Initial Magnitude", fontsize=16)
1519
+ plt.ylabel("Initial Phase Difference", fontsize=16)
1520
+ plt.title("Contour of Final Magnitude", fontsize=16)
1521
+ plt.colorbar(cf)
1522
+ plt.tight_layout()
1523
+ _save_fig(fig, self._out('lottery_beta_contour.png'))
1524
+
1525
+ @staticmethod
1526
+ def _gradient_flow_step(theta, xi, init_k, p, lr, fourier_basis):
1527
+ """One step of analytical gradient flow."""
1528
+ fb = fourier_basis.to(theta.dtype)
1529
+ theta_coeff = fb @ theta
1530
+ xi_coeff = fb @ xi
1531
+
1532
+ neuron_coeff_theta = theta_coeff[[init_k * 2 - 1, init_k * 2]]
1533
+ alpha = np.sqrt(2 / p) * torch.sqrt(
1534
+ torch.sum(neuron_coeff_theta.pow(2))
1535
+ ).item()
1536
+ phi = np.arctan2(
1537
+ -neuron_coeff_theta[1].item(), neuron_coeff_theta[0].item()
1538
+ )
1539
+
1540
+ neuron_coeff_xi = xi_coeff[[init_k * 2 - 1, init_k * 2]]
1541
+ beta = np.sqrt(2 / p) * torch.sqrt(
1542
+ torch.sum(neuron_coeff_xi.pow(2))
1543
+ ).item()
1544
+ psi = np.arctan2(
1545
+ -neuron_coeff_xi[1].item(), neuron_coeff_xi[0].item()
1546
+ )
1547
+
1548
+ w_k = 2 * np.pi * init_k / p
1549
+ grad_theta = torch.tensor(
1550
+ [2 * p * alpha * beta * np.cos(w_k * j + psi - phi)
1551
+ for j in range(p)],
1552
+ device=theta.device
1553
+ )
1554
+ grad_xi = torch.tensor(
1555
+ [p * alpha ** 2 * np.cos(w_k * j + 2 * phi)
1556
+ for j in range(p)],
1557
+ device=theta.device
1558
+ )
1559
+
1560
+ theta = theta + lr * grad_theta
1561
+ xi = xi + lr * grad_xi
1562
+ return theta, xi
1563
+
1564
+ # ------------------------------------------------------------------
1565
+ # Tab 7: Gradient Dynamics
1566
+ # ------------------------------------------------------------------
1567
+
1568
+ def generate_tab7(self):
1569
+ """Generate gradient dynamics plots for quad_single_freq and relu_single_freq."""
1570
+ print(f" [Tab 7] Gradient Dynamics for p={self.p}")
1571
+
1572
+ for run_name, act_name, prefix in [
1573
+ ('quad_single_freq', 'Quad', 'quad'),
1574
+ ('relu_single_freq', 'ReLU', 'relu'),
1575
+ ]:
1576
+ run_dir = self._run_dir(run_name)
1577
+ if run_dir is None:
1578
+ print(f" SKIP: {run_name} run directory not found")
1579
+ continue
1580
+
1581
+ checkpoints = _load_checkpoints(run_dir, self.device)
1582
+ if not checkpoints:
1583
+ print(f" SKIP: no {run_name} checkpoints")
1584
+ continue
1585
+
1586
+ epochs = sorted(checkpoints.keys())
1587
+ d_mlp = self.d_mlp
1588
+
1589
+ # Build all neuron records across epochs
1590
+ all_neuron_records = []
1591
+ for ep in epochs:
1592
+ model_sd = checkpoints[ep]
1593
+ W_in_d, W_out_d, mfl = decode_weights(model_sd, self.fourier_basis)
1594
+ for neuron in range(W_in_d.shape[0]):
1595
+ s_in, phi_in = compute_neuron(neuron, mfl, W_in_d)
1596
+ s_out, phi_out = compute_neuron(neuron, mfl, W_out_d)
1597
+ all_neuron_records.append({
1598
+ 'epoch': ep,
1599
+ 'neuron': neuron,
1600
+ 'scale_in': s_in,
1601
+ 'phi_in': phi_in,
1602
+ 'scale_out': s_out,
1603
+ 'phi_out': phi_out,
1604
+ })
1605
+
1606
+ # Select a neuron that shows clear phase alignment
1607
+ # Pick neuron with largest final scale
1608
+ final_records = [r for r in all_neuron_records if r['epoch'] == epochs[-1]]
1609
+ if not final_records:
1610
+ continue
1611
+ best_neuron = max(final_records, key=lambda r: r['scale_in'])['neuron']
1612
+
1613
+ # Extract trajectory for this neuron
1614
+ neuron_records = [r for r in all_neuron_records if r['neuron'] == best_neuron]
1615
+ # Remove last few points if noisy (as notebooks do)
1616
+ trim = max(0, len(neuron_records) - 4) if prefix == 'relu' else max(0, len(neuron_records) - 14)
1617
+ neuron_records = neuron_records[:trim] if trim > 0 else neuron_records
1618
+
1619
+ phi_in_raw = [r['phi_in'] for r in neuron_records]
1620
+ phi_out_raw = [r['phi_out'] for r in neuron_records]
1621
+ scale_in_list = [r['scale_in'] for r in neuron_records]
1622
+ scale_out_list = [r['scale_out'] for r in neuron_records]
1623
+
1624
+ # Phase wrapping fix: normalize 2*phi to [-pi, pi], then adjust
1625
+ # psi to stay within pi of 2*phi (same fix as Tab 3 scatter).
1626
+ phi2_in_list = [normalize_to_pi(2 * v) for v in phi_in_raw]
1627
+ phi_out_list = []
1628
+ for two_phi, psi in zip(phi2_in_list, phi_out_raw):
1629
+ psi_n = normalize_to_pi(psi)
1630
+ if psi_n - two_phi > np.pi:
1631
+ psi_n -= 2 * np.pi
1632
+ elif psi_n - two_phi < -np.pi:
1633
+ psi_n += 2 * np.pi
1634
+ phi_out_list.append(psi_n)
1635
+
1636
+ # Unwrap time series to remove remaining jumps at +-pi boundary
1637
+ phi_in_list = list(np.unwrap(phi_in_raw))
1638
+ phi2_in_list = list(np.unwrap(phi2_in_list))
1639
+ phi_out_list = list(np.unwrap(phi_out_list))
1640
+
1641
+ x = np.arange(len(phi_in_list)) * (epochs[1] - epochs[0] if len(epochs) > 1 else 200)
1642
+
1643
+ # ---- Phase alignment + magnitude plot ----
1644
+ fig_width = 8 if prefix == 'quad' else 5
1645
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(fig_width, 6), sharex=True)
1646
+
1647
+ ax1.plot(x, phi_in_list, marker='o', markersize=4,
1648
+ color=COLORS[1], label=r"$\phi_m^\star$")
1649
+ ax1.plot(x, phi_out_list, marker='x', markersize=4,
1650
+ color=COLORS[3], label=r"$\psi_m^\star$")
1651
+ ax1.plot(x, phi2_in_list, marker='^', markersize=4,
1652
+ color=COLORS[0], label=r"$2\phi_m^\star$")
1653
+ ax1.set_title('Phase Alignment of Neuron $m$', fontsize=16)
1654
+ ax1.legend(fontsize=18, loc="upper right")
1655
+ ax1.grid(True)
1656
+
1657
+ ax2.plot(x, scale_in_list, marker='o', markersize=4,
1658
+ color=COLORS[0], label=r"$\alpha_m^\star$")
1659
+ ax2.plot(x, scale_out_list, marker='x', markersize=4,
1660
+ color=COLORS[3], label=r"$\beta_m^\star$")
1661
+ ax2.set_title('Magnitude Growth of Neuron $m$', fontsize=16)
1662
+ ax2.set_xlabel('Step', fontsize=16)
1663
+ ax2.legend(fontsize=18, loc="upper left")
1664
+ ax2.grid(True)
1665
+
1666
+ plt.tight_layout()
1667
+ _save_fig(fig, self._out(f'phase_align_{prefix}.png'))
1668
+
1669
+ # ---- Decoded weights at timepoints ----
1670
+ if prefix == 'quad':
1671
+ keys = [0]
1672
+ mid = min(epochs, key=lambda e: abs(e - 1000))
1673
+ end = epochs[-1]
1674
+ if mid not in keys:
1675
+ keys.append(mid)
1676
+ if end not in keys:
1677
+ keys.append(end)
1678
+ else:
1679
+ keys = [0, epochs[-1]]
1680
+
1681
+ num_components = min(20, d_mlp)
1682
+ n = len(keys)
1683
+ fig, axes = plt.subplots(
1684
+ 2, n, figsize=(12 if n <= 2 else 18, 4 * n if n <= 2 else 3.3 * n),
1685
+ gridspec_kw={"hspace": 0.05}, constrained_layout=True
1686
+ )
1687
+ if n == 1:
1688
+ axes = axes.reshape(2, 1)
1689
+
1690
+ x_locs = np.arange(len(self.fourier_basis_names))
1691
+ y_locs = np.arange(num_components)
1692
+
1693
+ for col, key in enumerate(keys):
1694
+ if key not in checkpoints:
1695
+ key = min(checkpoints.keys(), key=lambda e: abs(e - key))
1696
+ W_in = checkpoints[key]['mlp.W_in']
1697
+ W_out = checkpoints[key]['mlp.W_out']
1698
+
1699
+ data_in = (W_in @ self.fourier_basis.T)[:num_components]
1700
+ data_in_np = data_in.detach().cpu().numpy()
1701
+ abs_max_in = np.abs(data_in_np).max()
1702
+ ax_in = axes[0, col]
1703
+ im_in = ax_in.imshow(
1704
+ data_in_np, cmap=CMAP_DIVERGING,
1705
+ vmin=-abs_max_in, vmax=abs_max_in, aspect='auto'
1706
+ )
1707
+ ax_in.set_title(rf'Step {key}, $\theta_m$ after DFT', fontsize=18)
1708
+ ax_in.set_xticks(x_locs)
1709
+ ax_in.set_xticklabels(self.fourier_basis_names, rotation=90, fontsize=11)
1710
+ ax_in.set_yticks(y_locs)
1711
+ ax_in.set_yticklabels(y_locs)
1712
+ if col == 0:
1713
+ ax_in.set_ylabel('Neuron #', fontsize=16)
1714
+ fig.colorbar(im_in, ax=ax_in)
1715
+
1716
+ data_out = (W_out.T @ self.fourier_basis.T)[:num_components]
1717
+ data_out_np = data_out.detach().cpu().numpy()
1718
+ abs_max_out = np.abs(data_out_np).max()
1719
+ ax_out = axes[1, col]
1720
+ im_out = ax_out.imshow(
1721
+ data_out_np, cmap=CMAP_DIVERGING,
1722
+ vmin=-abs_max_out, vmax=abs_max_out, aspect='auto'
1723
+ )
1724
+ ax_out.set_title(rf'Step {key}, $\xi_m$ after DFT', fontsize=18)
1725
+ ax_out.set_xticks(x_locs)
1726
+ ax_out.set_xticklabels(self.fourier_basis_names, rotation=90, fontsize=11)
1727
+ ax_out.set_yticks(y_locs)
1728
+ ax_out.set_yticklabels(y_locs)
1729
+ if col == 0:
1730
+ ax_out.set_ylabel('Neuron #', fontsize=16)
1731
+ fig.colorbar(im_out, ax=ax_out)
1732
+
1733
+ _save_fig(fig, self._out(f'single_freq_{prefix}.png'))
1734
+
1735
+ print(f" Saved phase_align_{prefix}.png, single_freq_{prefix}.png")
1736
+
1737
+ # ------------------------------------------------------------------
1738
+ # Metadata JSON
1739
+ # ------------------------------------------------------------------
1740
+
1741
+ def _save_metadata(self):
1742
+ """Save a metadata JSON summarizing config and final metrics."""
1743
+ print(f" [Meta] Saving metadata for p={self.p}")
1744
+ meta = {
1745
+ 'prime': self.p,
1746
+ 'd_mlp': self.d_mlp,
1747
+ 'training_runs': {},
1748
+ 'final_metrics': {},
1749
+ }
1750
+ for run_name, params in TRAINING_RUNS.items():
1751
+ meta['training_runs'][run_name] = {
1752
+ 'act_type': params['act_type'],
1753
+ 'lr': params['lr'],
1754
+ 'weight_decay': params['weight_decay'],
1755
+ 'num_epochs': params['num_epochs'],
1756
+ 'frac_train': params['frac_train'],
1757
+ 'init_type': params['init_type'],
1758
+ 'init_scale': params['init_scale'],
1759
+ 'optimizer': params['optimizer'],
1760
+ }
1761
+ curves = _load_training_curves(self._run_type_dir(run_name))
1762
+ if curves:
1763
+ metrics = {}
1764
+ if 'train_accs' in curves and curves['train_accs']:
1765
+ metrics['train_acc'] = curves['train_accs'][-1]
1766
+ if 'test_accs' in curves and curves['test_accs']:
1767
+ metrics['test_acc'] = curves['test_accs'][-1]
1768
+ if 'train_losses' in curves and curves['train_losses']:
1769
+ metrics['train_loss'] = curves['train_losses'][-1]
1770
+ if 'test_losses' in curves and curves['test_losses']:
1771
+ metrics['test_loss'] = curves['test_losses'][-1]
1772
+ if metrics:
1773
+ meta['final_metrics'][run_name] = metrics
1774
+
1775
+ with open(self._out('metadata.json'), 'w') as f:
1776
+ json.dump(meta, f, indent=2)
1777
+ print(" Saved metadata.json")
1778
+
1779
+ # ------------------------------------------------------------------
1780
+ # Interactive JSON precomputation
1781
+ # ------------------------------------------------------------------
1782
+
1783
+ def _precompute_neuron_spectra(self):
1784
+ """Precompute per-neuron Fourier magnitude spectra for top-20 neurons."""
1785
+ print(f" [Interactive] Neuron spectra for p={self.p}")
1786
+ run_dir = self._run_dir('standard')
1787
+ if run_dir is None:
1788
+ print(" SKIP: standard run directory not found")
1789
+ return
1790
+
1791
+ final_data = _load_final(run_dir, self.device)
1792
+ if final_data is None:
1793
+ print(" SKIP: no final checkpoint")
1794
+ return
1795
+ model_load = final_data['model']
1796
+
1797
+ W_in_decode, W_out_decode, max_freq_ls = decode_weights(
1798
+ model_load, self.fourier_basis
1799
+ )
1800
+ d_mlp = W_in_decode.shape[0]
1801
+ num_neurons = min(20, d_mlp)
1802
+
1803
+ sorted_indices = select_top_neurons_by_frequency(
1804
+ max_freq_ls, W_in_decode, n=num_neurons
1805
+ )
1806
+
1807
+ fb_names = self.fourier_basis_names
1808
+ spectra = {}
1809
+ for rank, neuron_idx in enumerate(sorted_indices):
1810
+ # Fourier magnitudes for W_in
1811
+ magnitudes_in = W_in_decode[neuron_idx].abs().cpu().tolist()
1812
+ magnitudes_out = W_out_decode[neuron_idx].abs().cpu().tolist()
1813
+ spectra[f"neuron_{rank}"] = {
1814
+ 'global_index': int(neuron_idx),
1815
+ 'dominant_freq': int(max_freq_ls[neuron_idx]),
1816
+ 'fourier_magnitudes_in': magnitudes_in,
1817
+ 'fourier_magnitudes_out': magnitudes_out,
1818
+ }
1819
+
1820
+ payload = {
1821
+ 'fourier_basis_names': fb_names,
1822
+ 'neurons': spectra,
1823
+ }
1824
+ with open(self._out('neuron_spectra.json'), 'w') as f:
1825
+ json.dump(payload, f)
1826
+ print(" Saved neuron_spectra.json")
1827
+
1828
+ def _precompute_logit_explorer(self):
1829
+ """Precompute logits for representative (a,b) pairs."""
1830
+ print(f" [Interactive] Logit explorer for p={self.p}")
1831
+ run_dir = self._run_dir('standard')
1832
+ if run_dir is None:
1833
+ print(" SKIP: standard run directory not found")
1834
+ return
1835
+
1836
+ final_data = _load_final(run_dir, self.device)
1837
+ if final_data is None:
1838
+ print(" SKIP: no final checkpoint")
1839
+ return
1840
+ model_load = final_data['model']
1841
+
1842
+ p = self.p
1843
+ act_type = TRAINING_RUNS['standard']['act_type']
1844
+ model = EmbedMLP(
1845
+ d_vocab=self.d_vocab, d_model=self.d_model,
1846
+ d_mlp=self.d_mlp, act_type=act_type, use_cache=False
1847
+ )
1848
+ model.to(self.device)
1849
+ model.load_state_dict(model_load)
1850
+ model.eval()
1851
+
1852
+ # Select p representative pairs: (0,0), (1,2), (3,5), ... spread across inputs
1853
+ pairs = []
1854
+ step = max(1, (p * p) // p)
1855
+ for idx in range(0, p * p, step):
1856
+ a = idx // p
1857
+ b = idx % p
1858
+ pairs.append((a, b))
1859
+ if len(pairs) >= p:
1860
+ break
1861
+
1862
+ pair_tensor = torch.tensor(pairs, dtype=torch.long, device=self.device)
1863
+ with torch.no_grad():
1864
+ logits = model(pair_tensor).squeeze(1) # [n_pairs, p]
1865
+
1866
+ payload = {
1867
+ 'pairs': pairs,
1868
+ 'correct_answers': [(a + b) % p for a, b in pairs],
1869
+ 'logits': logits.cpu().tolist(),
1870
+ 'output_classes': list(range(p)),
1871
+ }
1872
+ with open(self._out('logits_interactive.json'), 'w') as f:
1873
+ json.dump(payload, f)
1874
+ print(" Saved logits_interactive.json")
1875
+
1876
+ def _precompute_grokk_slider(self):
1877
+ """Precompute accuracy grids at ~10 grokking checkpoints for epoch slider."""
1878
+ print(f" [Interactive] Grokking epoch slider for p={self.p}")
1879
+ if self.p < MIN_P_GROKKING:
1880
+ print(f" SKIP: p={self.p} < {MIN_P_GROKKING}")
1881
+ return
1882
+ run_dir = self._run_dir('grokking')
1883
+ if run_dir is None:
1884
+ print(" SKIP: grokking run directory not found")
1885
+ return
1886
+
1887
+ checkpoints = _load_checkpoints(run_dir, self.device)
1888
+ if not checkpoints:
1889
+ print(" SKIP: no grokking checkpoints")
1890
+ return
1891
+
1892
+ epochs = sorted(checkpoints.keys())
1893
+ p = self.p
1894
+ d_mlp = self.d_mlp
1895
+ act_type = TRAINING_RUNS['grokking']['act_type']
1896
+ gt_grid = self.all_labels.view(p, p)
1897
+
1898
+ # Subsample ~10 epochs evenly
1899
+ n_snapshots = min(10, len(epochs))
1900
+ indices = np.linspace(0, len(epochs) - 1, n_snapshots, dtype=int)
1901
+ selected_epochs = [epochs[i] for i in indices]
1902
+
1903
+ epoch_data = []
1904
+ for ep in selected_epochs:
1905
+ model = EmbedMLP(
1906
+ d_vocab=self.d_vocab, d_model=self.d_model,
1907
+ d_mlp=d_mlp, act_type=act_type, use_cache=False
1908
+ ).to(self.device)
1909
+ model.load_state_dict(checkpoints[ep])
1910
+ model.eval()
1911
+ with torch.no_grad():
1912
+ logits = model(self.all_data).squeeze(1)
1913
+ predicted = torch.argmax(logits, dim=1).view(p, p)
1914
+ accuracy_grid = (predicted == gt_grid).float().cpu().tolist()
1915
+ epoch_data.append({
1916
+ 'epoch': int(ep),
1917
+ 'accuracy_grid': accuracy_grid,
1918
+ })
1919
+
1920
+ payload = {
1921
+ 'prime': p,
1922
+ 'epochs': [d['epoch'] for d in epoch_data],
1923
+ 'grids': [d['accuracy_grid'] for d in epoch_data],
1924
+ }
1925
+ with open(self._out('grokk_epoch_data.json'), 'w') as f:
1926
+ json.dump(payload, f)
1927
+ print(" Saved grokk_epoch_data.json")
1928
+
1929
+ # ------------------------------------------------------------------
1930
+ # Training Log consolidation
1931
+ # ------------------------------------------------------------------
1932
+
1933
+ def _save_training_log(self):
1934
+ """Consolidate training logs from all runs into a precomputed JSON.
1935
+
1936
+ For each run, includes:
1937
+ - config: hyperparameters
1938
+ - log_text: human-readable formatted log
1939
+ - table: subsampled per-epoch metrics for display
1940
+ """
1941
+ print(f" [Log] Saving training log for p={self.p}")
1942
+ all_runs = {}
1943
+
1944
+ for run_name, params in TRAINING_RUNS.items():
1945
+ run_type_dir = self._run_type_dir(run_name)
1946
+ curves = _load_training_curves(run_type_dir)
1947
+ if curves is None:
1948
+ continue
1949
+
1950
+ # Also check for a pre-saved training_log.txt
1951
+ log_text_path = os.path.join(run_type_dir, "training_log.txt")
1952
+ if os.path.exists(log_text_path):
1953
+ with open(log_text_path) as f:
1954
+ log_text = f.read()
1955
+ else:
1956
+ # Reconstruct from curves data
1957
+ log_text = self._reconstruct_log_text(
1958
+ run_name, params, curves
1959
+ )
1960
+
1961
+ # Build a subsampled table (~100 rows max)
1962
+ n_epochs = len(curves.get('train_losses', []))
1963
+ step = max(1, n_epochs // 100)
1964
+ indices = list(range(0, n_epochs, step))
1965
+ if n_epochs > 0 and (n_epochs - 1) not in indices:
1966
+ indices.append(n_epochs - 1)
1967
+
1968
+ table = []
1969
+ for i in indices:
1970
+ row = {'epoch': i}
1971
+ for key in ('train_losses', 'test_losses', 'train_accs',
1972
+ 'test_accs', 'grad_norms', 'param_norms'):
1973
+ vals = curves.get(key, [])
1974
+ row[key.replace('_', '_')] = (
1975
+ round(vals[i], 6) if i < len(vals) else None
1976
+ )
1977
+ table.append(row)
1978
+
1979
+ all_runs[run_name] = {
1980
+ 'config': {
1981
+ 'prime': self.p,
1982
+ 'd_mlp': self.d_mlp,
1983
+ 'act_type': params['act_type'],
1984
+ 'init_type': params['init_type'],
1985
+ 'init_scale': params['init_scale'],
1986
+ 'optimizer': params['optimizer'],
1987
+ 'lr': params['lr'],
1988
+ 'weight_decay': params['weight_decay'],
1989
+ 'frac_train': params['frac_train'],
1990
+ 'num_epochs': params['num_epochs'],
1991
+ 'seed': params['seed'],
1992
+ },
1993
+ 'log_text': log_text,
1994
+ 'table': table,
1995
+ 'total_epochs': n_epochs,
1996
+ }
1997
+
1998
+ if all_runs:
1999
+ with open(self._out('training_log.json'), 'w') as f:
2000
+ json.dump(all_runs, f)
2001
+ print(f" Saved training_log.json ({len(all_runs)} runs)")
2002
+ else:
2003
+ print(" SKIP: no training curves found")
2004
+
2005
+ def _reconstruct_log_text(self, run_name, params, curves):
2006
+ """Reconstruct a human-readable training log from curves data."""
2007
+ lines = []
2008
+ lines.append(f"{'=' * 70}")
2009
+ lines.append(f"Training Log: p={self.p}, run={run_name}")
2010
+ lines.append(f"{'=' * 70}")
2011
+ lines.append("")
2012
+ lines.append("Configuration:")
2013
+ lines.append(f" prime (p) = {self.p}")
2014
+ lines.append(f" d_mlp = {self.d_mlp}")
2015
+ lines.append(f" activation = {params['act_type']}")
2016
+ lines.append(f" init_type = {params['init_type']}")
2017
+ lines.append(f" init_scale = {params['init_scale']}")
2018
+ lines.append(f" optimizer = {params['optimizer']}")
2019
+ lines.append(f" learning_rate = {params['lr']}")
2020
+ lines.append(f" weight_decay = {params['weight_decay']}")
2021
+ lines.append(f" frac_train = {params['frac_train']}")
2022
+ lines.append(f" num_epochs = {params['num_epochs']}")
2023
+ lines.append(f" seed = {params['seed']}")
2024
+ lines.append("")
2025
+ lines.append(f"{'─' * 70}")
2026
+ lines.append(
2027
+ f"{'Epoch':>8s} {'Train Loss':>12s} {'Test Loss':>12s} "
2028
+ f"{'Train Acc':>10s} {'Test Acc':>10s} "
2029
+ f"{'Grad Norm':>10s} {'Param Norm':>11s}"
2030
+ )
2031
+ lines.append(f"{'─' * 70}")
2032
+
2033
+ train_losses = curves.get('train_losses', [])
2034
+ test_losses = curves.get('test_losses', [])
2035
+ train_accs = curves.get('train_accs', [])
2036
+ test_accs = curves.get('test_accs', [])
2037
+ grad_norms = curves.get('grad_norms', [])
2038
+ param_norms = curves.get('param_norms', [])
2039
+ n_epochs = len(train_losses)
2040
+
2041
+ step = max(1, n_epochs // 100)
2042
+ indices = list(range(0, n_epochs, step))
2043
+ if n_epochs > 0 and (n_epochs - 1) not in indices:
2044
+ indices.append(n_epochs - 1)
2045
+
2046
+ for i in indices:
2047
+ tl = f"{train_losses[i]:.6f}" if i < len(train_losses) else "N/A"
2048
+ tel = f"{test_losses[i]:.6f}" if i < len(test_losses) else "N/A"
2049
+ ta = f"{train_accs[i]:.4f}" if i < len(train_accs) else "N/A"
2050
+ tea = f"{test_accs[i]:.4f}" if i < len(test_accs) else "N/A"
2051
+ gn = f"{grad_norms[i]:.4f}" if i < len(grad_norms) else "N/A"
2052
+ pn = f"{param_norms[i]:.4f}" if i < len(param_norms) else "N/A"
2053
+ lines.append(
2054
+ f"{i:>8d} {tl:>12s} {tel:>12s} "
2055
+ f"{ta:>10s} {tea:>10s} "
2056
+ f"{gn:>10s} {pn:>11s}"
2057
+ )
2058
+
2059
+ lines.append(f"{'─' * 70}")
2060
+ lines.append("")
2061
+ lines.append("Final Results:")
2062
+ if train_losses:
2063
+ lines.append(f" Train Loss = {train_losses[-1]:.6f}")
2064
+ if test_losses:
2065
+ lines.append(f" Test Loss = {test_losses[-1]:.6f}")
2066
+ if train_accs:
2067
+ lines.append(f" Train Acc = {train_accs[-1]:.4f}")
2068
+ if test_accs:
2069
+ lines.append(f" Test Acc = {test_accs[-1]:.4f}")
2070
+ if param_norms:
2071
+ lines.append(f" Param Norm = {param_norms[-1]:.4f}")
2072
+ lines.append(f"\nTotal epochs trained: {n_epochs}")
2073
+ return "\n".join(lines)
2074
+
2075
+ # ------------------------------------------------------------------
2076
+ # Generate all
2077
+ # ------------------------------------------------------------------
2078
+
2079
+ def generate_all(self):
2080
+ """Generate all tab plots with error handling."""
2081
+ print(f"\n{'=' * 60}")
2082
+ print(f"Generating plots for p={self.p}")
2083
+ print(f" Input: {self.input_dir}")
2084
+ print(f" Output: {self.output_dir}")
2085
+ print(f"{'=' * 60}")
2086
+
2087
+ # Save metadata and training logs first
2088
+ try:
2089
+ self._save_metadata()
2090
+ except Exception as e:
2091
+ print(f" [ERROR] metadata failed: {e}")
2092
+ traceback.print_exc()
2093
+
2094
+ try:
2095
+ self._save_training_log()
2096
+ except Exception as e:
2097
+ print(f" [ERROR] training log failed: {e}")
2098
+ traceback.print_exc()
2099
+
2100
+ generators = [
2101
+ ('Tab 1', self.generate_tab1),
2102
+ ('Tab 2', self.generate_tab2),
2103
+ ('Tab 3', self.generate_tab3),
2104
+ ('Tab 4', self.generate_tab4),
2105
+ ('Tab 5', self.generate_tab5),
2106
+ ('Tab 6', self.generate_tab6),
2107
+ ('Tab 7', self.generate_tab7),
2108
+ ]
2109
+
2110
+ for name, gen_fn in generators:
2111
+ try:
2112
+ gen_fn()
2113
+ except Exception as e:
2114
+ print(f" [ERROR] {name} failed: {e}")
2115
+ traceback.print_exc()
2116
+
2117
+ # Precompute interactive JSON data
2118
+ interactive = [
2119
+ ('Neuron Spectra', self._precompute_neuron_spectra),
2120
+ ('Logit Explorer', self._precompute_logit_explorer),
2121
+ ('Grokking Slider', self._precompute_grokk_slider),
2122
+ ]
2123
+ for name, fn in interactive:
2124
+ try:
2125
+ fn()
2126
+ except Exception as e:
2127
+ print(f" [ERROR] {name} failed: {e}")
2128
+ traceback.print_exc()
2129
+
2130
+ print(f"\nDone generating plots for p={self.p}")
2131
+
2132
+
2133
+ # ======================================================================
2134
+ # CLI
2135
+ # ======================================================================
2136
+
2137
+ def main():
2138
+ parser = argparse.ArgumentParser(
2139
+ description='Generate all model-dependent plots for the HF app.'
2140
+ )
2141
+ parser.add_argument('--all', action='store_true',
2142
+ help='Generate plots for all p found in input dir')
2143
+ parser.add_argument('--p', type=int,
2144
+ help='Generate plots for a specific p')
2145
+ parser.add_argument('--input', type=str, default='./trained_models',
2146
+ help='Base input directory containing p_PPP subdirs')
2147
+ parser.add_argument('--output', type=str,
2148
+ default='./precomputed_results',
2149
+ help='Base output directory for precomputed results')
2150
+ args = parser.parse_args()
2151
+
2152
+ if not args.all and args.p is None:
2153
+ parser.error("Specify --all or --p P")
2154
+
2155
+ if args.p:
2156
+ moduli = [args.p]
2157
+ else:
2158
+ # Discover moduli from input directory
2159
+ moduli = []
2160
+ if os.path.isdir(args.input):
2161
+ for d in sorted(os.listdir(args.input)):
2162
+ if d.startswith('p_'):
2163
+ try:
2164
+ p = int(d.split('_')[1])
2165
+ moduli.append(p)
2166
+ except (ValueError, IndexError):
2167
+ pass
2168
+ if not moduli:
2169
+ print(f"No p_PPP directories found in {args.input}")
2170
+ sys.exit(1)
2171
+
2172
+ total = len(moduli)
2173
+ for i, p in enumerate(moduli):
2174
+ print(f"\n[{i + 1}/{total}] Processing p={p}")
2175
+ # Handle both p_23 and p_023 naming conventions
2176
+ input_dir = os.path.join(args.input, f'p_{p:03d}')
2177
+ if not os.path.isdir(input_dir):
2178
+ input_dir = os.path.join(args.input, f'p_{p}')
2179
+ if not os.path.isdir(input_dir):
2180
+ print(f" Input directory not found: {input_dir}")
2181
+ continue
2182
+
2183
+ output_dir = os.path.join(args.output, f'p_{p:03d}')
2184
+
2185
+ gen = PlotGenerator(p=p, input_dir=input_dir, output_dir=output_dir)
2186
+ gen.generate_all()
2187
+
2188
+ print(f"\nAll done. Processed {total} prime(s).")
2189
+
2190
+
2191
+ if __name__ == '__main__':
2192
+ main()
precompute/grokking_stage_detector.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Automatic detection of grokking stage boundaries from training curves.
3
+
4
+ Three stages:
5
+ Stage 1 (Memorization): Train accuracy rises, test accuracy stays low
6
+ Stage 2 (Transition): Test accuracy starts climbing
7
+ Stage 3 (Generalization): Test accuracy near 1.0
8
+
9
+ Returns (stage1_end, stage2_end) as epoch indices, or (None, None) if
10
+ grokking is not detected.
11
+ """
12
+
13
+
14
+ def detect_grokking_stages(train_losses, test_losses, train_accs=None, test_accs=None):
15
+ """
16
+ Detect memorization -> transition -> generalization boundaries.
17
+
18
+ Heuristic:
19
+ - stage1_end: first epoch where train accuracy >= 0.95 (memorization complete)
20
+ - stage2_end: first epoch where test accuracy >= 0.95 (generalization reached)
21
+
22
+ Fallback (if accuracy curves not available):
23
+ - stage1_end: first epoch where train loss < 0.1
24
+ - stage2_end: first epoch where test loss < 0.1
25
+ """
26
+ if train_accs is not None and test_accs is not None:
27
+ stage1_end = None
28
+ for i, a in enumerate(train_accs):
29
+ if a >= 0.95:
30
+ stage1_end = i
31
+ break
32
+
33
+ stage2_end = None
34
+ for i, a in enumerate(test_accs):
35
+ if a >= 0.95:
36
+ stage2_end = i
37
+ break
38
+
39
+ # Sanity: stage1 should come before stage2
40
+ if stage1_end is not None and stage2_end is not None and stage1_end >= stage2_end:
41
+ stage1_end = stage2_end // 3
42
+ else:
43
+ stage1_end = None
44
+ for i, loss in enumerate(train_losses):
45
+ if loss < 0.1:
46
+ stage1_end = i
47
+ break
48
+
49
+ stage2_end = None
50
+ for i, loss in enumerate(test_losses):
51
+ if loss < 0.1:
52
+ stage2_end = i
53
+ break
54
+
55
+ return stage1_end, stage2_end
precompute/neuron_selector.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Automated neuron selection strategies for all primes.
3
+ Replaces hard-coded neuron indices from the analysis notebooks.
4
+ """
5
+ import torch
6
+ import numpy as np
7
+ from collections import Counter
8
+
9
+
10
+ def select_top_neurons_by_frequency(max_freq_ls, W_in_decode, n=20):
11
+ """
12
+ Select top N neurons covering all frequencies (round-robin).
13
+ Used for heatmap plots (Tab 2).
14
+
15
+ Picks the highest-magnitude neuron from each frequency in turn,
16
+ cycling through frequencies until n neurons are selected. This ensures
17
+ the heatmap shows diversification across all frequencies, matching
18
+ the blog's Figure 2.
19
+
20
+ Returns list of neuron indices into the original d_mlp-sized arrays.
21
+ """
22
+ d_mlp = W_in_decode.shape[0]
23
+ magnitudes = W_in_decode.abs().max(dim=1).values
24
+
25
+ # Group neurons by their dominant frequency, sorted by magnitude (descending)
26
+ from collections import defaultdict
27
+ freq_groups = defaultdict(list)
28
+ for i in range(d_mlp):
29
+ f = max_freq_ls[i]
30
+ if f > 0: # skip DC neurons
31
+ freq_groups[f].append((magnitudes[i].item(), i))
32
+
33
+ # Sort each group by magnitude descending
34
+ for f in freq_groups:
35
+ freq_groups[f].sort(key=lambda x: -x[0])
36
+
37
+ # Round-robin across frequencies (ascending order)
38
+ freqs_sorted = sorted(freq_groups.keys())
39
+ selected = []
40
+ pointers = {f: 0 for f in freqs_sorted}
41
+
42
+ while len(selected) < min(n, d_mlp) and freqs_sorted:
43
+ exhausted = []
44
+ for f in freqs_sorted:
45
+ if len(selected) >= n:
46
+ break
47
+ if pointers[f] < len(freq_groups[f]):
48
+ _, idx = freq_groups[f][pointers[f]]
49
+ selected.append(idx)
50
+ pointers[f] += 1
51
+ else:
52
+ exhausted.append(f)
53
+ for f in exhausted:
54
+ freqs_sorted.remove(f)
55
+
56
+ return selected
57
+
58
+
59
+ def select_lineplot_neurons(sorted_indices, n=3):
60
+ """
61
+ Select first N neurons from the frequency-sorted set for line plots (Tab 2).
62
+ Picks neurons evenly spaced through the sorted list to show diverse frequencies.
63
+ """
64
+ if len(sorted_indices) <= n:
65
+ return list(range(len(sorted_indices)))
66
+ step = len(sorted_indices) // n
67
+ return [i * step for i in range(n)]
68
+
69
+
70
+ def select_phase_frequency(max_freq_ls, p):
71
+ """
72
+ Choose the frequency for phase distribution analysis (Tab 3).
73
+ Picks the frequency with the most neurons assigned to it (mode),
74
+ excluding frequency 0 (DC component).
75
+ """
76
+ freq_counts = Counter(f for f in max_freq_ls if f > 0)
77
+ if not freq_counts:
78
+ return 1
79
+ return freq_counts.most_common(1)[0][0]
80
+
81
+
82
+ def select_lottery_neuron(model_load, fourier_basis, decode_scales_phis_fn):
83
+ """
84
+ Find the neuron with the clearest frequency specialization (Tab 6).
85
+ Picks the neuron with the highest ratio of dominant frequency scale
86
+ to second-highest frequency scale.
87
+ """
88
+ scales, _, _ = decode_scales_phis_fn(model_load, fourier_basis)
89
+ # scales: [n_neurons, K+1], skip DC at index 0
90
+ scales_no_dc = scales[:, 1:]
91
+
92
+ if scales_no_dc.shape[1] < 2:
93
+ return 0
94
+
95
+ sorted_scales, _ = torch.sort(scales_no_dc, dim=1, descending=True)
96
+ ratio = sorted_scales[:, 0] / (sorted_scales[:, 1] + 1e-10)
97
+
98
+ return ratio.argmax().item()
precompute/prime_config.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration for all moduli and training runs.
3
+ Defines the d_mlp sizing formula and the 5 training run configurations.
4
+ p can be any odd number >= 3 (not restricted to primes).
5
+ """
6
+ import math
7
+
8
+
9
+ def get_moduli(low=3, high=199):
10
+ """Return all odd numbers in [low, high]."""
11
+ moduli = []
12
+ for n in range(low, high + 1):
13
+ if n >= 3 and n % 2 == 1:
14
+ moduli.append(n)
15
+ return moduli
16
+
17
+
18
+ # Keep old name as alias for backward compatibility
19
+ get_primes = get_moduli
20
+
21
+
22
+ def compute_d_mlp(p: int) -> int:
23
+ """
24
+ Compute d_mlp maintaining the ratio from p=23, d_mlp=512.
25
+ Formula: d_mlp = max(512, ceil(512/529 * p^2))
26
+ Can have more neurons but not less than the ratio dictates.
27
+ """
28
+ ratio = 512 / (23 ** 2) # 512/529 ≈ 0.9679
29
+ return max(512, math.ceil(ratio * p * p))
30
+
31
+
32
+ # Minimum p overall (p=2 has 0 non-DC frequencies, making Fourier analysis degenerate)
33
+ MIN_P = 3
34
+
35
+ # Minimum p for grokking experiments (need enough test data for meaningful split)
36
+ MIN_P_GROKKING = 19
37
+
38
+ # Backward-compatible aliases
39
+ MIN_PRIME = MIN_P
40
+ MIN_PRIME_GROKKING = MIN_P_GROKKING
41
+
42
+ # 5 training run configurations per p
43
+ TRAINING_RUNS = {
44
+ "standard": {
45
+ "embed_type": "one_hot",
46
+ "init_type": "random",
47
+ "optimizer": "AdamW",
48
+ "act_type": "ReLU",
49
+ "lr": 5e-5,
50
+ "weight_decay": 0,
51
+ "frac_train": 1.0,
52
+ "num_epochs": 5000,
53
+ "save_every": 200,
54
+ "init_scale": 0.1,
55
+ "save_models": True,
56
+ "batch_style": "full",
57
+ "seed": 42,
58
+ },
59
+ "grokking": {
60
+ "embed_type": "one_hot",
61
+ "init_type": "random",
62
+ "optimizer": "AdamW",
63
+ "act_type": "ReLU",
64
+ "lr": 1e-4,
65
+ "weight_decay": 2.0,
66
+ "frac_train": 0.75,
67
+ "num_epochs": 50000,
68
+ "save_every": 200,
69
+ "init_scale": 0.1,
70
+ "save_models": True,
71
+ "batch_style": "full",
72
+ "seed": 42,
73
+ },
74
+ "quad_random": {
75
+ "embed_type": "one_hot",
76
+ "init_type": "random",
77
+ "optimizer": "AdamW",
78
+ "act_type": "Quad",
79
+ "lr": 5e-5,
80
+ "weight_decay": 0,
81
+ "frac_train": 1.0,
82
+ "num_epochs": 5000,
83
+ "save_every": 200,
84
+ "init_scale": 0.1,
85
+ "save_models": True,
86
+ "batch_style": "full",
87
+ "seed": 42,
88
+ },
89
+ "quad_single_freq": {
90
+ "embed_type": "one_hot",
91
+ "init_type": "single-freq",
92
+ "optimizer": "SGD",
93
+ "act_type": "Quad",
94
+ "lr": 0.1,
95
+ "weight_decay": 0,
96
+ "frac_train": 1.0,
97
+ "num_epochs": 5000,
98
+ "save_every": 200,
99
+ "init_scale": 0.02,
100
+ "save_models": True,
101
+ "batch_style": "full",
102
+ "seed": 42,
103
+ },
104
+ "relu_single_freq": {
105
+ "embed_type": "one_hot",
106
+ "init_type": "single-freq",
107
+ "optimizer": "SGD",
108
+ "act_type": "ReLU",
109
+ "lr": 0.01,
110
+ "weight_decay": 0,
111
+ "frac_train": 1.0,
112
+ "num_epochs": 5000,
113
+ "save_every": 200,
114
+ "init_scale": 0.002,
115
+ "save_models": True,
116
+ "batch_style": "full",
117
+ "seed": 42,
118
+ },
119
+ }
120
+
121
+ # Analytical computation configs (no training needed)
122
+ ANALYTICAL_CONFIGS = {
123
+ "decouple_dynamics": {
124
+ "init_k": 2,
125
+ "num_steps_case1": 1400,
126
+ "learning_rate_case1": 1,
127
+ "init_phi_case1": 1.5,
128
+ "init_psi_case1": 0.18,
129
+ "num_steps_case2": 700,
130
+ "learning_rate_case2": 1,
131
+ "init_phi_case2": -0.72,
132
+ "init_psi_case2": -2.91,
133
+ "amplitude": 0.02,
134
+ },
135
+ }
precompute/run_all.sh ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Pre-compute results for all odd p in [3, MAX_P].
3
+ # Deletes checkpoints after each p to save disk space.
4
+ #
5
+ # Usage:
6
+ # bash precompute/run_all.sh # p = 3, 5, 7, ..., 99
7
+ # MAX_P=199 bash precompute/run_all.sh # p = 3, 5, 7, ..., 199
8
+ #
9
+ # Run from the project root directory.
10
+
11
+ MAX_P=${MAX_P:-99}
12
+
13
+ set -e
14
+ echo "=== Pre-computing all odd p in [3, $MAX_P] ==="
15
+
16
+ COMPLETED=0
17
+ FAILED=0
18
+
19
+ for P in $(seq 3 2 "$MAX_P"); do
20
+ echo ""
21
+ echo "========================================"
22
+ echo " Processing p=$P"
23
+ echo "========================================"
24
+ if CLEANUP=1 bash precompute/run_pipeline.sh "$P"; then
25
+ COMPLETED=$((COMPLETED + 1))
26
+ else
27
+ echo "[FAIL] p=$P failed"
28
+ FAILED=$((FAILED + 1))
29
+ fi
30
+ done
31
+
32
+ echo ""
33
+ echo "=== All done. Completed: $COMPLETED, Failed: $FAILED ==="
34
+ echo "=== Precomputed results size: ==="
35
+ du -sh precomputed_results/
precompute/run_pipeline.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Full pre-computation pipeline for a single modulus p (any odd number >= 3).
3
+ #
4
+ # Usage:
5
+ # bash precompute/run_pipeline.sh 23
6
+ # bash precompute/run_pipeline.sh 9 --d_mlp 128
7
+ # P=23 bash precompute/run_pipeline.sh
8
+ #
9
+ # # Delete checkpoints after generating plots (saves disk space):
10
+ # CLEANUP=1 bash precompute/run_pipeline.sh 97
11
+ #
12
+ # Run from the project root directory.
13
+
14
+ P=${1:-${P:-23}}
15
+ shift 2>/dev/null || true # consume the p arg
16
+
17
+ # CLEANUP=1 to delete model checkpoints after plot generation
18
+ CLEANUP=${CLEANUP:-0}
19
+
20
+ # Collect remaining args (e.g. --d_mlp 128) to pass to train_all.py
21
+ EXTRA_ARGS="$@"
22
+
23
+ set -e
24
+ echo "=== Running full pipeline for p=$P $EXTRA_ARGS ==="
25
+
26
+ # Step 1: Train all 5 configurations
27
+ echo ""
28
+ echo "--- Step 1/4: Training ---"
29
+ python precompute/train_all.py --p "$P" --output ./trained_models --resume $EXTRA_ARGS
30
+
31
+ # Step 2: Generate model-based plots (d_mlp inferred from checkpoint)
32
+ echo ""
33
+ echo "--- Step 2/4: Generating model-based plots ---"
34
+ python precompute/generate_plots.py --p "$P" --input ./trained_models --output ./precomputed_results
35
+
36
+ # Step 3: Generate analytical simulation plots
37
+ echo ""
38
+ echo "--- Step 3/4: Generating analytical plots ---"
39
+ python precompute/generate_analytical.py --p "$P" --output ./precomputed_results
40
+
41
+ # Step 4: Cleanup checkpoints if requested
42
+ PADDED=$(printf '%03d' "$P")
43
+ MODEL_DIR="trained_models/p_${PADDED}"
44
+ if [ "$CLEANUP" = "1" ] && [ -d "$MODEL_DIR" ]; then
45
+ echo ""
46
+ echo "--- Cleanup: Deleting checkpoints for p=$P ---"
47
+ SIZE=$(du -sh "$MODEL_DIR" | cut -f1)
48
+ rm -rf "$MODEL_DIR"
49
+ echo " Freed $SIZE from $MODEL_DIR"
50
+ fi
51
+
52
+ # Step 5: Verify
53
+ echo ""
54
+ echo "--- Verification ---"
55
+ RESULT_DIR="precomputed_results/p_${PADDED}"
56
+ echo "=== Results in ${RESULT_DIR}/ ==="
57
+ ls -la "${RESULT_DIR}/"
58
+ FILE_COUNT=$(ls -1 "${RESULT_DIR}/" | wc -l | tr -d ' ')
59
+ echo "=== Total files: ${FILE_COUNT} ==="
60
+ echo "=== Pipeline complete for p=$P ==="
precompute/train_all.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Batch training script for all odd moduli p in [3, 199].
4
+
5
+ Usage:
6
+ # Train all runs for all odd p
7
+ python train_all.py --all
8
+
9
+ # Train specific p
10
+ python train_all.py --p 23
11
+
12
+ # Train specific run type for a p
13
+ python train_all.py --p 23 --run standard
14
+
15
+ # Resume (skips completed runs)
16
+ python train_all.py --all --resume
17
+
18
+ # Custom output directory
19
+ python train_all.py --all --output ./my_models
20
+ """
21
+ import argparse
22
+ import json
23
+ import os
24
+ import sys
25
+
26
+ # Add src to path
27
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
28
+
29
+ import torch
30
+ from prime_config import get_moduli, compute_d_mlp, TRAINING_RUNS, MIN_P, MIN_P_GROKKING
31
+ from utils import Config
32
+ from nnTrainer import Trainer
33
+
34
+
35
+ def build_config_dict(p, run_params, d_mlp_override=None):
36
+ """Build a nested config dict compatible with the Config class."""
37
+ d_mlp = d_mlp_override if d_mlp_override is not None else compute_d_mlp(p)
38
+ return {
39
+ 'data': {
40
+ 'p': p,
41
+ 'd_vocab': None,
42
+ 'fn_name': 'add',
43
+ 'frac_train': run_params['frac_train'],
44
+ 'batch_style': run_params['batch_style'],
45
+ },
46
+ 'model': {
47
+ 'd_model': None,
48
+ 'd_mlp': d_mlp,
49
+ 'act_type': run_params['act_type'],
50
+ 'embed_type': run_params['embed_type'],
51
+ 'init_type': run_params['init_type'],
52
+ 'init_scale': run_params['init_scale'],
53
+ },
54
+ 'training': {
55
+ 'num_epochs': run_params['num_epochs'],
56
+ 'lr': run_params['lr'],
57
+ 'weight_decay': run_params['weight_decay'],
58
+ 'optimizer': run_params['optimizer'],
59
+ 'stopping_thresh': -1,
60
+ 'save_models': run_params['save_models'],
61
+ 'save_every': run_params['save_every'],
62
+ 'seed': run_params['seed'],
63
+ },
64
+ }
65
+
66
+
67
+ def _save_training_log(output_dir, p, run_name, run_params, d_mlp, curves):
68
+ """Save a human-readable training_log.txt summarizing the run."""
69
+ log_path = os.path.join(output_dir, "training_log.txt")
70
+ n_epochs = len(curves.get('train_losses', []))
71
+ with open(log_path, 'w') as f:
72
+ f.write(f"{'=' * 70}\n")
73
+ f.write(f"Training Log: p={p}, run={run_name}\n")
74
+ f.write(f"{'=' * 70}\n\n")
75
+ f.write(f"Configuration:\n")
76
+ f.write(f" prime (p) = {p}\n")
77
+ f.write(f" d_mlp = {d_mlp}\n")
78
+ f.write(f" activation = {run_params['act_type']}\n")
79
+ f.write(f" init_type = {run_params['init_type']}\n")
80
+ f.write(f" init_scale = {run_params['init_scale']}\n")
81
+ f.write(f" optimizer = {run_params['optimizer']}\n")
82
+ f.write(f" learning_rate = {run_params['lr']}\n")
83
+ f.write(f" weight_decay = {run_params['weight_decay']}\n")
84
+ f.write(f" frac_train = {run_params['frac_train']}\n")
85
+ f.write(f" num_epochs = {run_params['num_epochs']}\n")
86
+ f.write(f" batch_style = {run_params['batch_style']}\n")
87
+ f.write(f" seed = {run_params['seed']}\n")
88
+ f.write(f"\n{'─' * 70}\n")
89
+ f.write(f"{'Epoch':>8s} {'Train Loss':>12s} {'Test Loss':>12s} "
90
+ f"{'Train Acc':>10s} {'Test Acc':>10s} "
91
+ f"{'Grad Norm':>10s} {'Param Norm':>11s}\n")
92
+ f.write(f"{'─' * 70}\n")
93
+
94
+ # Print every 100 epochs + the last epoch
95
+ train_losses = curves.get('train_losses', [])
96
+ test_losses = curves.get('test_losses', [])
97
+ train_accs = curves.get('train_accs', [])
98
+ test_accs = curves.get('test_accs', [])
99
+ grad_norms = curves.get('grad_norms', [])
100
+ param_norms = curves.get('param_norms', [])
101
+
102
+ step = max(1, n_epochs // 100) # ~100 lines
103
+ indices = list(range(0, n_epochs, step))
104
+ if n_epochs > 0 and (n_epochs - 1) not in indices:
105
+ indices.append(n_epochs - 1)
106
+
107
+ for i in indices:
108
+ tl = f"{train_losses[i]:.6f}" if i < len(train_losses) else "N/A"
109
+ tel = f"{test_losses[i]:.6f}" if i < len(test_losses) else "N/A"
110
+ ta = f"{train_accs[i]:.4f}" if i < len(train_accs) else "N/A"
111
+ tea = f"{test_accs[i]:.4f}" if i < len(test_accs) else "N/A"
112
+ gn = f"{grad_norms[i]:.4f}" if i < len(grad_norms) else "N/A"
113
+ pn = f"{param_norms[i]:.4f}" if i < len(param_norms) else "N/A"
114
+ f.write(f"{i:>8d} {tl:>12s} {tel:>12s} "
115
+ f"{ta:>10s} {tea:>10s} "
116
+ f"{gn:>10s} {pn:>11s}\n")
117
+
118
+ f.write(f"{'─' * 70}\n\n")
119
+ f.write(f"Final Results:\n")
120
+ if train_losses:
121
+ f.write(f" Train Loss = {train_losses[-1]:.6f}\n")
122
+ if test_losses:
123
+ f.write(f" Test Loss = {test_losses[-1]:.6f}\n")
124
+ if train_accs:
125
+ f.write(f" Train Acc = {train_accs[-1]:.4f}\n")
126
+ if test_accs:
127
+ f.write(f" Test Acc = {test_accs[-1]:.4f}\n")
128
+ if param_norms:
129
+ f.write(f" Param Norm = {param_norms[-1]:.4f}\n")
130
+ f.write(f"\nTotal epochs trained: {n_epochs}\n")
131
+
132
+
133
+ def run_training(p, run_name, output_base, d_mlp_override=None):
134
+ """Train a single run for a single prime."""
135
+ if p < MIN_P:
136
+ print(f"[SKIP] p={p}, run={run_name}: p < {MIN_P} (too few Fourier frequencies)")
137
+ return
138
+
139
+ # Single-freq init needs at least 1 non-DC frequency: (p-1)//2 >= 1 → p >= 3
140
+ if run_name in ('quad_single_freq', 'relu_single_freq') and (p - 1) // 2 < 1:
141
+ print(f"[SKIP] p={p}, run={run_name}: no non-DC frequencies for single-freq init")
142
+ return
143
+
144
+ if run_name == 'grokking' and p < MIN_P_GROKKING:
145
+ print(f"[SKIP] p={p}, run={run_name}: p < {MIN_P_GROKKING} (too few test points)")
146
+ return
147
+
148
+ run_params = TRAINING_RUNS[run_name]
149
+ config_dict = build_config_dict(p, run_params, d_mlp_override)
150
+ d_mlp = d_mlp_override if d_mlp_override is not None else compute_d_mlp(p)
151
+
152
+ output_dir = os.path.join(output_base, f"p_{p:03d}", run_name)
153
+ os.makedirs(output_dir, exist_ok=True)
154
+
155
+ # Check if already completed
156
+ marker = os.path.join(output_dir, "DONE")
157
+ if os.path.exists(marker):
158
+ print(f"[SKIP] p={p}, run={run_name} already completed")
159
+ return
160
+
161
+ print(f"[TRAIN] p={p}, d_mlp={d_mlp}, run={run_name}, "
162
+ f"epochs={run_params['num_epochs']}")
163
+
164
+ config = Config(config_dict)
165
+ trainer = Trainer(config=config, use_wandb=False)
166
+
167
+ # Override save directory so checkpoints go into our output structure
168
+ trainer.save_dir = output_dir
169
+ run_subdir = os.path.join(output_dir, trainer.run_name)
170
+ os.makedirs(run_subdir, exist_ok=True)
171
+
172
+ # Re-save train/test data to the overridden location so generate_plots.py
173
+ # can find them (Trainer.__init__ saves to the original save_dir)
174
+ torch.save(trainer.train, os.path.join(run_subdir, 'train_data.pth'))
175
+ torch.save(trainer.test, os.path.join(run_subdir, 'test_data.pth'))
176
+
177
+ trainer.initial_save_if_appropriate()
178
+
179
+ # Plateau early-stopping for grokking: after 10K epochs, if curves
180
+ # haven't changed in the last 1000 epochs, stop training.
181
+ plateau_check = (run_name == 'grokking')
182
+ plateau_min_epoch = 10000
183
+ plateau_window = 1000
184
+ plateau_loss_tol = 1e-3 # absolute change in loss
185
+ plateau_acc_tol = 0.005 # absolute change in accuracy
186
+
187
+ for epoch in range(config.num_epochs):
188
+ train_loss, test_loss = trainer.do_a_training_step(epoch)
189
+
190
+ if test_loss.item() < config.stopping_thresh:
191
+ print(f" Early stopping at epoch {epoch}: "
192
+ f"test loss {test_loss.item():.6f}")
193
+ break
194
+
195
+ # Plateau detection for grokking
196
+ if (plateau_check and epoch >= plateau_min_epoch
197
+ and epoch % plateau_window == 0):
198
+ tl = trainer.train_losses
199
+ tel = trainer.test_losses
200
+ ta = trainer.train_accs
201
+ tea = trainer.test_accs
202
+ w = plateau_window
203
+ if len(tl) >= w and len(tel) >= w:
204
+ tl_flat = (max(tl[-w:]) - min(tl[-w:])) < plateau_loss_tol
205
+ tel_flat = (max(tel[-w:]) - min(tel[-w:])) < plateau_loss_tol
206
+ ta_flat = (not ta) or (max(ta[-w:]) - min(ta[-w:])) < plateau_acc_tol
207
+ tea_flat = (not tea) or (max(tea[-w:]) - min(tea[-w:])) < plateau_acc_tol
208
+ if tl_flat and tel_flat and ta_flat and tea_flat:
209
+ print(f" Plateau early stopping at epoch {epoch}: "
210
+ f"no change in last {w} epochs")
211
+ break
212
+
213
+ if config.is_it_time_to_save(epoch=epoch):
214
+ trainer.save_epoch(epoch=epoch, save_to_wandb=False, local_save=True)
215
+
216
+ trainer.post_training_save(
217
+ save_optimizer_and_scheduler=False, log_to_wandb=False
218
+ )
219
+
220
+ # Save training curves as JSON for plot generation
221
+ curves = {
222
+ 'train_losses': trainer.train_losses,
223
+ 'test_losses': trainer.test_losses,
224
+ 'train_accs': trainer.train_accs,
225
+ 'test_accs': trainer.test_accs,
226
+ 'grad_norms': trainer.grad_norms,
227
+ 'param_norms': trainer.param_norms,
228
+ }
229
+ curves_path = os.path.join(output_dir, "training_curves.json")
230
+ with open(curves_path, 'w') as f:
231
+ json.dump(curves, f)
232
+
233
+ # Save a human-readable training log
234
+ _save_training_log(output_dir, p, run_name, run_params, d_mlp, curves)
235
+
236
+ # Write completion marker
237
+ with open(marker, 'w') as f:
238
+ f.write(f"p={p} run={run_name} completed\n")
239
+
240
+ print(f"[DONE] p={p}, run={run_name}, "
241
+ f"train_acc={trainer.train_accs[-1]:.4f}, "
242
+ f"test_acc={trainer.test_accs[-1]:.4f}")
243
+
244
+
245
+ def main():
246
+ parser = argparse.ArgumentParser(
247
+ description='Batch training for modular addition experiments'
248
+ )
249
+ parser.add_argument('--all', action='store_true',
250
+ help='Train all odd p in [3, 199]')
251
+ parser.add_argument('--p', type=int,
252
+ help='Train a specific odd modulus p')
253
+ parser.add_argument('--run', type=str, choices=list(TRAINING_RUNS.keys()),
254
+ help='Train a specific run type')
255
+ parser.add_argument('--output', type=str, default='./trained_models',
256
+ help='Output directory for trained models')
257
+ parser.add_argument('--d_mlp', type=int, default=None,
258
+ help='Override d_mlp (number of hidden neurons). '
259
+ 'Default: auto-computed from p.')
260
+ parser.add_argument('--resume', action='store_true',
261
+ help='Skip already-completed runs (checks DONE marker)')
262
+ args = parser.parse_args()
263
+
264
+ if not args.all and args.p is None:
265
+ parser.error("Specify --all or --p P")
266
+
267
+ moduli = [args.p] if args.p else get_moduli()
268
+ runs = [args.run] if args.run else list(TRAINING_RUNS.keys())
269
+
270
+ total = len(moduli) * len(runs)
271
+ completed = 0
272
+
273
+ for p in moduli:
274
+ for run_name in runs:
275
+ completed += 1
276
+ print(f"\n{'='*60}")
277
+ print(f"[{completed}/{total}] p={p}, run={run_name}")
278
+ print(f"{'='*60}")
279
+ try:
280
+ run_training(p, run_name, args.output, d_mlp_override=args.d_mlp)
281
+ except Exception as e:
282
+ print(f"[FAIL] p={p}, run={run_name}: {e}")
283
+ import traceback
284
+ traceback.print_exc()
285
+
286
+ print(f"\nAll done. {completed} runs processed.")
287
+
288
+
289
+ if __name__ == "__main__":
290
+ main()
precomputed_results/p_015/p015_full_training_para_origin.png ADDED
precomputed_results/p_015/p015_lineplot_in.png ADDED

Git LFS Details

  • SHA256: ebd5c296c2f284aab3f108a300f6d7b642fe276790230cb787fad5210de07845
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
precomputed_results/p_015/p015_lineplot_out.png ADDED

Git LFS Details

  • SHA256: 8476e480d7022ea6feb2263ccc0391f58918a3cf3638fe8108b437cf3c761a31
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB
precomputed_results/p_015/p015_logits_interactive.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"pairs": [[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0], [7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0], [13, 0], [14, 0]], "correct_answers": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "logits": [[5.845241069793701, -0.2678382396697998, -0.27407053112983704, 0.5236045718193054, 0.1727389544248581, -0.5161969065666199, 0.32797300815582275, 0.03217152878642082, 0.9118141531944275, 0.1262667030096054, -0.48561030626296997, -0.08425326645374298, 0.7506414651870728, 0.26417604088783264, 0.06685394793748856], [1.2039813995361328, 6.88271427154541, 1.5816727876663208, -1.3607698678970337, 0.2838805019855499, -0.1305035948753357, -0.1361088901758194, 0.1741233468055725, -0.20070961117744446, -0.19441378116607666, -0.07290667295455933, -0.036422792822122574, -0.045937348157167435, -0.4817635118961334, -1.3201030492782593], [1.021496057510376, 0.460868775844574, 6.665925979614258, -0.31677037477493286, 0.5722809433937073, -0.4090900421142578, -1.4602270126342773, -1.284806728363037, 0.1630801558494568, -0.14396820962429047, 0.21148079633712769, 0.009713075123727322, -0.13466662168502808, -1.4001739025115967, -0.48616650700569153], [2.0029213428497314, -1.0664080381393433, 0.1110871359705925, 6.9900641441345215, -0.6577256321907043, 0.14176632463932037, 0.5043318271636963, -0.9338439702987671, -0.5924315452575684, 0.0853288546204567, -0.8137574791908264, 0.2345130294561386, -0.747736930847168, -0.7098338603973389, -0.09835103899240494], [1.1424025297164917, -0.11194394528865814, -0.35566678643226624, 0.07304318249225616, 6.601017951965332, -0.7027722001075745, -0.2666034996509552, -0.44442567229270935, 0.8163065314292908, -0.5371125340461731, 0.3994847536087036, -2.494434118270874, -0.7756778001785278, -0.05254651606082916, -0.19730210304260254], [-0.2894173860549927, -0.4359486699104309, -0.5648083686828613, 0.2765766680240631, -0.8157410025596619, 6.83429479598999, -0.35287508368492126, -0.13654330372810364, 0.4015010595321655, -0.5329406261444092, -0.8750499486923218, 0.27872195839881897, 0.017333246767520905, 0.40312278270721436, -0.36508798599243164], [0.8411731719970703, -0.7100308537483215, -0.03279941901564598, -1.725080132484436, 0.041413623839616776, -0.020579706877470016, 7.065398216247559, -0.6480932235717773, 0.3441147208213806, -0.6311541795730591, -0.5848420858383179, 0.23020878434181213, 1.6761562824249268, -0.4346846342086792, 0.1869385540485382], [1.0707398653030396, 0.15244746208190918, -0.8544912338256836, -0.15931977331638336, -0.30050942301750183, -0.14030054211616516, -0.8752976059913635, 6.867440223693848, -0.9738104343414307, -0.016967639327049255, 0.06942860782146454, -0.36363551020622253, -0.5302596092224121, -0.04578635096549988, 1.362797498703003], [1.3280404806137085, 0.42291590571403503, 0.2617959976196289, -0.14120015501976013, 0.16784554719924927, -0.3041268587112427, -0.9258386492729187, -1.557655930519104, 6.762355327606201, -1.1574827432632446, 0.6234251260757446, 0.20528843998908997, -0.12094193696975708, -0.20512300729751587, 0.5443016290664673], [1.1472804546356201, -0.6102383732795715, -0.025639446452260017, 1.0203014612197876, -0.7307616472244263, -0.12430916726589203, -1.7754124402999878, -0.18994221091270447, -0.2811968922615051, 7.137627601623535, -0.20736677944660187, -0.2558075189590454, -0.8490561842918396, -0.11468800902366638, -0.2236211597919464], [-0.4398433268070221, -0.21957175433635712, 0.3112000524997711, -0.31722167134284973, -0.045922309160232544, -0.8279529213905334, 0.09667133539915085, -0.1345088630914688, -0.4463891386985779, -0.33588770031929016, 6.491363525390625, -0.12330206483602524, 0.10725130885839462, 0.42743009328842163, -0.8439421057701111], [0.8614866733551025, -0.09104573726654053, -0.26044338941574097, -1.4104074239730835, -1.4556188583374023, -0.1587509959936142, -0.8687411546707153, 0.6269921064376831, -0.11933901160955429, -0.3620198369026184, 0.2740001678466797, 6.834388256072998, 0.9378089308738708, -0.8011277914047241, 0.10253019630908966], [0.8134461045265198, -0.5108746886253357, -0.11389578878879547, -0.7348212003707886, -0.3866046965122223, -0.30147287249565125, -1.0970127582550049, -0.5308148264884949, -0.25204434990882874, 0.6024379134178162, -0.5250549912452698, 0.2944880723953247, 6.6044793128967285, -0.06822022795677185, 0.009033107198774815], [1.3036129474639893, 0.3358013927936554, -1.516667366027832, -0.7404187321662903, 0.37273597717285156, 0.4566952586174011, -0.5649716854095459, -0.16146983206272125, -0.5403905510902405, -0.6627720594406128, 0.10609856992959976, 1.1525702476501465, 0.3066888749599457, 6.983869552612305, 0.6811028122901917], [1.5695332288742065, -1.1690753698349, 0.29273056983947754, -0.6167480945587158, -0.481355756521225, 0.3711283802986145, -0.3477175235748291, 0.06404877454042435, -0.34069737792015076, -0.5002283453941345, -0.22661544382572174, 0.29657235741615295, -1.5196651220321655, 1.143660306930542, 6.862580299377441]], "output_classes": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]}
precomputed_results/p_015/p015_lottery_beta_contour.png ADDED
precomputed_results/p_015/p015_lottery_mech_magnitude.png ADDED
precomputed_results/p_015/p015_lottery_mech_phase.png ADDED
precomputed_results/p_015/p015_magnitude_distribution.png ADDED
precomputed_results/p_015/p015_metadata.json ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "prime": 15,
3
+ "d_mlp": 512,
4
+ "training_runs": {
5
+ "standard": {
6
+ "act_type": "ReLU",
7
+ "lr": 5e-05,
8
+ "weight_decay": 0,
9
+ "num_epochs": 5000,
10
+ "frac_train": 1.0,
11
+ "init_type": "random",
12
+ "init_scale": 0.1,
13
+ "optimizer": "AdamW"
14
+ },
15
+ "grokking": {
16
+ "act_type": "ReLU",
17
+ "lr": 0.0001,
18
+ "weight_decay": 2.0,
19
+ "num_epochs": 50000,
20
+ "frac_train": 0.75,
21
+ "init_type": "random",
22
+ "init_scale": 0.1,
23
+ "optimizer": "AdamW"
24
+ },
25
+ "quad_random": {
26
+ "act_type": "Quad",
27
+ "lr": 5e-05,
28
+ "weight_decay": 0,
29
+ "num_epochs": 5000,
30
+ "frac_train": 1.0,
31
+ "init_type": "random",
32
+ "init_scale": 0.1,
33
+ "optimizer": "AdamW"
34
+ },
35
+ "quad_single_freq": {
36
+ "act_type": "Quad",
37
+ "lr": 0.1,
38
+ "weight_decay": 0,
39
+ "num_epochs": 5000,
40
+ "frac_train": 1.0,
41
+ "init_type": "single-freq",
42
+ "init_scale": 0.02,
43
+ "optimizer": "SGD"
44
+ },
45
+ "relu_single_freq": {
46
+ "act_type": "ReLU",
47
+ "lr": 0.01,
48
+ "weight_decay": 0,
49
+ "num_epochs": 5000,
50
+ "frac_train": 1.0,
51
+ "init_type": "single-freq",
52
+ "init_scale": 0.002,
53
+ "optimizer": "SGD"
54
+ }
55
+ },
56
+ "final_metrics": {
57
+ "standard": {
58
+ "train_acc": 1.0,
59
+ "test_acc": 1.0,
60
+ "train_loss": 0.020928841084241867,
61
+ "test_loss": 0.020928841084241867
62
+ },
63
+ "quad_random": {
64
+ "train_acc": 1.0,
65
+ "test_acc": 1.0,
66
+ "train_loss": 0.0036203155759721994,
67
+ "test_loss": 0.0036203155759721994
68
+ },
69
+ "quad_single_freq": {
70
+ "train_acc": 1.0,
71
+ "test_acc": 1.0,
72
+ "train_loss": 0.04876862093806267,
73
+ "test_loss": 0.04876862093806267
74
+ },
75
+ "relu_single_freq": {
76
+ "train_acc": 1.0,
77
+ "test_acc": 1.0,
78
+ "train_loss": 2.7064406871795654,
79
+ "test_loss": 2.7064406871795654
80
+ }
81
+ }
82
+ }
precomputed_results/p_015/p015_neuron_spectra.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"fourier_basis_names": ["Const", "cos 1", "sin 1", "cos 2", "sin 2", "cos 3", "sin 3", "cos 4", "sin 4", "cos 5", "sin 5", "cos 6", "sin 6", "cos 7", "sin 7"], "neurons": {"neuron_0": {"global_index": 298, "dominant_freq": 1, "fourier_magnitudes_in": [0.009785229340195656, 0.835197389125824, 0.06538262963294983, 0.023014001548290253, 0.11825679987668991, 0.13374952971935272, 0.06052215397357941, 0.05648910254240036, 0.19347421824932098, 0.022295288741588593, 0.025461282581090927, 0.10164802521467209, 0.13591913878917694, 0.05420650914311409, 0.015495105646550655], "fourier_magnitudes_out": [0.05270613357424736, 0.8947794437408447, 0.023624544963240623, 0.0090867318212986, 0.013868873007595539, 0.23038825392723083, 0.01721891760826111, 0.03962008282542229, 0.026063552126288414, 0.13086184859275818, 0.030048014596104622, 0.0664755254983902, 0.05337228253483772, 0.13469304144382477, 0.026413951069116592]}, "neuron_1": {"global_index": 322, "dominant_freq": 2, "fourier_magnitudes_in": [0.0048967949114739895, 0.05831458047032356, 0.047817476093769073, 0.8282294273376465, 0.08354193717241287, 0.07255098968744278, 0.11165464669466019, 0.025725897401571274, 0.1043122410774231, 0.06207004189491272, 0.03182118758559227, 0.16142725944519043, 0.05974670872092247, 0.07021868228912354, 0.14918409287929535], "fourier_magnitudes_out": [0.0682767927646637, 0.04459874704480171, 0.048241592943668365, 0.8725450038909912, 0.12229588627815247, 0.00944548286497593, 0.04323046654462814, 0.04066399484872818, 0.016419490799307823, 0.08050032705068588, 0.10128692537546158, 0.2789347767829895, 0.11171453446149826, 0.0022370540536940098, 0.05760588496923447]}, "neuron_2": {"global_index": 506, "dominant_freq": 3, "fourier_magnitudes_in": [0.16119761765003204, 0.007301056291908026, 0.015491282567381859, 0.008759702555835247, 0.012521358206868172, 0.8759171366691589, 0.027130309492349625, 0.011930187232792377, 0.008996764197945595, 0.017064398154616356, 0.0032833898440003395, 0.2937963902950287, 0.03241246938705444, 0.01809331774711609, 0.01297019049525261], "fourier_magnitudes_out": [0.1367007941007614, 0.08054599165916443, 0.009201026521623135, 0.04534965381026268, 0.08026888966560364, 0.8633893728256226, 0.050462786108255386, 0.0044480981305241585, 0.03184027969837189, 0.05210161209106445, 0.04297168180346489, 0.32101351022720337, 0.06568200141191483, 0.035220254212617874, 0.014398200437426567]}, "neuron_3": {"global_index": 163, "dominant_freq": 4, "fourier_magnitudes_in": [0.031200144439935684, 0.04270334914326668, 0.08881780505180359, 0.030023256316781044, 0.08304275572299957, 0.04510613903403282, 0.1826746016740799, 0.03483400121331215, 0.8419885039329529, 0.013886661268770695, 0.08348363637924194, 0.050276018679142, 0.10324658453464508, 0.007152870763093233, 0.037329111248254776], "fourier_magnitudes_out": [0.0033031131606549025, 0.04313803091645241, 0.04859255626797676, 0.03949907422065735, 0.0202469564974308, 0.20968425273895264, 0.08351283520460129, 0.8787142038345337, 0.06666962057352066, 0.12211224436759949, 0.0748063325881958, 0.03903647139668465, 0.004882670473307371, 0.03963864594697952, 0.00997573509812355]}, "neuron_4": {"global_index": 56, "dominant_freq": 5, "fourier_magnitudes_in": [0.3008343279361725, 0.00417996384203434, 0.0018625137163326144, 0.0030638582538813353, 0.0034040315076708794, 0.0014134111115708947, 0.004357232246547937, 0.0004787891812156886, 0.004558315966278315, 0.8440173864364624, 0.003970965277403593, 0.0037021928001195192, 0.00269315205514431, 0.004479836206883192, 0.0009506550850346684], "fourier_magnitudes_out": [0.34606269001960754, 0.027366885915398598, 0.036413464695215225, 0.017639687284827232, 0.027893830090761185, 0.009557241573929787, 0.03348158299922943, 0.02666180580854416, 0.006313585676252842, 0.9367755651473999, 0.021462714299559593, 0.004096114542335272, 0.023737004026770592, 0.012917638756334782, 0.005492078140377998]}, "neuron_5": {"global_index": 382, "dominant_freq": 6, "fourier_magnitudes_in": [0.17134937644004822, 0.030019262805581093, 0.003407673677429557, 0.009750651195645332, 0.017143065109848976, 0.29516395926475525, 0.04455732926726341, 0.0013349098153412342, 0.022812439128756523, 0.007311187218874693, 0.0026531207840889692, 0.8911311626434326, 0.017707478255033493, 0.012044227682054043, 0.007314047310501337], "fourier_magnitudes_out": [0.15681979060173035, 0.02663942240178585, 0.015795163810253143, 0.03183992952108383, 0.019026663154363632, 0.34929507970809937, 0.021825529634952545, 0.0012088950024917722, 0.013926293700933456, 0.004897939506918192, 0.010765696875751019, 0.9581749439239502, 0.012587963603436947, 0.02056090161204338, 0.01620820164680481]}, "neuron_6": {"global_index": 44, "dominant_freq": 7, "fourier_magnitudes_in": [0.016063887625932693, 0.005506421905010939, 0.06481847167015076, 0.050152119249105453, 0.08404979109764099, 0.04145295172929764, 0.0641774982213974, 0.06344524770975113, 0.07214409857988358, 0.09170015156269073, 0.048928599804639816, 0.22082282602787018, 0.04532390087842941, 0.8467198014259338, 0.08364292979240417], "fourier_magnitudes_out": [0.015495719388127327, 0.0706552267074585, 0.03188098222017288, 0.022803593426942825, 0.05851001664996147, 0.006458980031311512, 0.07394170761108398, 0.06014903634786606, 0.08831571042537689, 0.10563234984874725, 0.11405274271965027, 0.26272422075271606, 0.13614341616630554, 0.9203148484230042, 0.14940021932125092]}, "neuron_7": {"global_index": 157, "dominant_freq": 1, "fourier_magnitudes_in": [0.023283498361706734, 0.14338412880897522, 0.8202316761016846, 0.05399405211210251, 0.058775369077920914, 0.12734520435333252, 0.1968384087085724, 0.070158950984478, 0.07580292969942093, 0.14696937799453735, 0.0658537819981575, 0.08757402002811432, 0.050403349101543427, 0.07286650687456131, 0.030624864622950554], "fourier_magnitudes_out": [0.10438423603773117, 0.8233321905136108, 0.2895442247390747, 0.11436242610216141, 0.043502938002347946, 0.08488017320632935, 0.21032141149044037, 0.03133478760719299, 0.147483691573143, 0.09128359705209732, 0.11333037912845612, 0.042169392108917236, 0.14409606158733368, 0.06334586441516876, 0.010641958564519882]}, "neuron_8": {"global_index": 436, "dominant_freq": 2, "fourier_magnitudes_in": [0.02816132642328739, 0.0366082526743412, 0.09432870894670486, 0.82112056016922, 0.2999240458011627, 0.07823537290096283, 0.07104597240686417, 0.010237633250653744, 0.05316127836704254, 0.01780366338789463, 0.10813448578119278, 0.10922082513570786, 0.2154001146554947, 0.06505948305130005, 0.005772633943706751], "fourier_magnitudes_out": [0.07935837656259537, 0.05681774765253067, 0.11203812062740326, 0.7207208275794983, 0.5456924438476562, 0.08229123800992966, 0.0757966861128807, 0.025318337604403496, 0.12979470193386078, 0.16904646158218384, 0.01633336953818798, 0.10492201149463654, 0.25266337394714355, 0.07896621525287628, 0.056506332010030746]}, "neuron_9": {"global_index": 339, "dominant_freq": 3, "fourier_magnitudes_in": [0.00793448369950056, 0.23576614260673523, 0.05173733830451965, 0.0808015987277031, 0.08904809504747391, 0.8467172980308533, 0.07009700685739517, 0.2346189022064209, 0.06041780859231949, 0.22861061990261078, 0.005940192844718695, 0.08250848948955536, 0.05041716247797012, 0.09760993719100952, 0.08870971947908401], "fourier_magnitudes_out": [0.07995986938476562, 0.058010708540678024, 0.0238045621663332, 0.044154077768325806, 0.06384548544883728, 0.8888619542121887, 0.1083059012889862, 0.05018500238656998, 0.06562003493309021, 0.032845836132764816, 0.06583794951438904, 0.19849520921707153, 0.08598171174526215, 0.048292793333530426, 0.056917209178209305]}, "neuron_10": {"global_index": 405, "dominant_freq": 4, "fourier_magnitudes_in": [0.04605866223573685, 0.10036957263946533, 0.1019270122051239, 0.08908756822347641, 0.05084078758955002, 0.09276082366704941, 0.1580628901720047, 0.11540821939706802, 0.8267641663551331, 0.07836127281188965, 0.007629718631505966, 0.10522190481424332, 0.08720945566892624, 0.07787368446588516, 0.07161819189786911], "fourier_magnitudes_out": [0.0646987333893776, 0.06891193985939026, 0.1397920846939087, 0.05930565670132637, 0.08342889696359634, 0.09649658203125, 0.19548238813877106, 0.8096501231193542, 0.276018351316452, 0.04780340939760208, 0.11156941205263138, 0.02261027880012989, 0.12174452841281891, 0.13799422979354858, 0.09665590524673462]}, "neuron_11": {"global_index": 363, "dominant_freq": 5, "fourier_magnitudes_in": [0.26700034737586975, 0.027114098891615868, 0.030112946406006813, 0.004235479515045881, 0.04029411822557449, 0.03277970105409622, 0.02381693758070469, 0.03963545709848404, 0.008424329571425915, 0.8159627914428711, 0.03508993238210678, 0.012520099990069866, 0.03853161633014679, 0.037013012915849686, 0.016481192782521248], "fourier_magnitudes_out": [0.3175621032714844, 0.022698314860463142, 0.029232148081064224, 0.009378736838698387, 0.04299170523881912, 0.021195683628320694, 0.006493568420410156, 0.02105909213423729, 0.005779783241450787, 0.9192318916320801, 0.025057973340153694, 0.015149646438658237, 0.004153982736170292, 0.01740305870771408, 0.028322339057922363]}, "neuron_12": {"global_index": 195, "dominant_freq": 6, "fourier_magnitudes_in": [0.09348957985639572, 0.09145021438598633, 0.019759872928261757, 0.14044030010700226, 0.010322250425815582, 0.09715981036424637, 0.15762653946876526, 0.08852937072515488, 0.012697099708020687, 0.10808465629816055, 0.005568717140704393, 0.11048246920108795, 0.8844196796417236, 0.06171249970793724, 0.021244505420327187], "fourier_magnitudes_out": [0.16980455815792084, 0.02732899598777294, 0.036924730986356735, 0.05185381695628166, 0.045569196343421936, 0.26525792479515076, 0.05154341459274292, 0.0016775119584053755, 0.005609582178294659, 0.06192351505160332, 0.05204838141798973, 0.878484845161438, 0.06501992046833038, 0.029970666393637657, 0.007208859547972679]}, "neuron_13": {"global_index": 400, "dominant_freq": 7, "fourier_magnitudes_in": [0.03170298784971237, 0.01636633276939392, 0.02570384368300438, 0.059365708380937576, 0.038365013897418976, 0.04416408762335777, 0.03580497205257416, 0.08423422276973724, 0.05347143113613129, 0.14569905400276184, 0.11170309782028198, 0.19443102180957794, 0.1770995706319809, 0.16263896226882935, 0.8427256941795349], "fourier_magnitudes_out": [0.001597732538357377, 0.03749779239296913, 0.021893106400966644, 0.021266808733344078, 0.003741121618077159, 0.06272387504577637, 0.031196942552924156, 0.06795763224363327, 0.09264159947633743, 0.009625964798033237, 0.16665540635585785, 0.20691494643688202, 0.2753947973251343, 0.8446058630943298, 0.2736137807369232]}, "neuron_14": {"global_index": 204, "dominant_freq": 1, "fourier_magnitudes_in": [0.009332116693258286, 0.18445806205272675, 0.8107064962387085, 0.0019745242316275835, 0.021862220019102097, 0.164852112531662, 0.19152098894119263, 0.0416061170399189, 0.006671736016869545, 0.12151989340782166, 0.0737357810139656, 0.07512296736240387, 0.00437825545668602, 0.09026234596967697, 0.010851189494132996], "fourier_magnitudes_out": [0.00363162811845541, 0.7924743294715881, 0.3171338140964508, 0.004241126589477062, 0.01878344640135765, 0.12165644019842148, 0.20536403357982635, 0.018597787246108055, 0.0810178741812706, 0.016301296651363373, 0.17405299842357635, 0.0980185940861702, 0.04579637944698334, 0.14084625244140625, 0.08989735692739487]}, "neuron_15": {"global_index": 200, "dominant_freq": 2, "fourier_magnitudes_in": [0.18793319165706635, 0.07831189781427383, 0.0546899177134037, 0.26695817708969116, 0.8170905709266663, 0.05296998843550682, 0.1943284422159195, 0.23756185173988342, 0.15034066140651703, 0.017131371423602104, 0.0028046099469065666, 0.11494778096675873, 0.10373453050851822, 0.12329491227865219, 0.18243563175201416], "fourier_magnitudes_out": [0.09242420643568039, 0.0712248757481575, 0.12614493072032928, 0.7048966288566589, 0.5352984666824341, 0.1033194437623024, 0.05411561205983162, 0.020737502723932266, 0.1372404545545578, 0.13828690350055695, 0.03525715693831444, 0.11022429913282394, 0.2024645209312439, 0.14993123710155487, 0.08065018057823181]}, "neuron_16": {"global_index": 29, "dominant_freq": 3, "fourier_magnitudes_in": [0.06590171158313751, 0.1875513195991516, 0.13940751552581787, 0.05203652381896973, 0.12810246646404266, 0.8321968913078308, 0.14579933881759644, 0.1637534499168396, 0.11946561932563782, 0.1717541664838791, 0.013540252111852169, 0.15768876671791077, 0.04509717971086502, 0.05775374174118042, 0.17158927023410797], "fourier_magnitudes_out": [0.004155661445111036, 0.03632631525397301, 0.11120419949293137, 0.07094360142946243, 0.0025094610173255205, 0.8083617091178894, 0.2703225910663605, 0.04779119789600372, 0.01505393534898758, 0.019837403669953346, 0.020111212506890297, 0.13071304559707642, 0.1915966272354126, 0.012666973285377026, 0.03241927549242973]}, "neuron_17": {"global_index": 263, "dominant_freq": 4, "fourier_magnitudes_in": [0.04006792977452278, 0.04470885172486305, 0.03644431754946709, 0.09875550121068954, 0.0844721719622612, 0.2842579483985901, 0.07786386460065842, 0.8181069493293762, 0.07545538246631622, 0.1041831225156784, 0.08087007701396942, 0.013009922578930855, 0.05374658852815628, 0.010971452109515667, 0.05345306172966957], "fourier_magnitudes_out": [0.0032453967723995447, 0.04532613605260849, 0.050931766629219055, 0.06243853271007538, 0.1471986621618271, 0.07246130704879761, 0.0618220679461956, 0.793653130531311, 0.21675735712051392, 0.10524935275316238, 0.002365024061873555, 0.05830618739128113, 0.12600253522396088, 0.015407783910632133, 0.04248036816716194]}, "neuron_18": {"global_index": 468, "dominant_freq": 5, "fourier_magnitudes_in": [0.27209052443504333, 0.01535376999527216, 0.024236787110567093, 0.014399196021258831, 0.026386309415102005, 0.03005525842308998, 0.0007289683562703431, 0.013318601995706558, 0.025410467758774757, 0.8113285303115845, 0.024080565199255943, 0.028674796223640442, 0.0011739643523469567, 0.015659669414162636, 0.025659453123807907], "fourier_magnitudes_out": [0.2273026406764984, 0.010362203232944012, 0.08463546633720398, 0.048502951860427856, 0.006748202722519636, 0.008361948654055595, 0.017721673473715782, 0.055468566715717316, 0.07133140414953232, 0.8983375430107117, 0.06657372415065765, 0.07770014554262161, 0.05767229199409485, 0.02166377194225788, 0.007756791543215513]}, "neuron_19": {"global_index": 502, "dominant_freq": 6, "fourier_magnitudes_in": [0.161783829331398, 0.011249735951423645, 0.011750025674700737, 0.006551133934408426, 0.005556976888328791, 0.3025686740875244, 0.004988331813365221, 0.01731167919933796, 0.009423403069376945, 0.001427029725164175, 0.0038674750830978155, 0.8558500409126282, 0.014558211900293827, 0.0111812399700284, 0.0006309517193585634], "fourier_magnitudes_out": [0.17063283920288086, 0.016314541921019554, 0.05021467059850693, 0.016806060448288918, 0.04298456758260727, 0.2806430757045746, 0.014211905188858509, 0.004449731670320034, 0.02185676619410515, 0.006607384420931339, 0.005355009343475103, 0.875554621219635, 0.016872091218829155, 0.022274665534496307, 0.03286416828632355]}}}
precomputed_results/p_015/p015_output_logits.png ADDED
precomputed_results/p_015/p015_overview.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"std_epochs": [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800], "std_ipr": [0.2509222626686096, 0.26378345489501953, 0.3110557198524475, 0.36821821331977844, 0.4191873371601105, 0.46056264638900757, 0.4932430386543274, 0.518873393535614, 0.5395100712776184, 0.5563467741012573, 0.5705686807632446, 0.5829041600227356, 0.5938962697982788, 0.6040970087051392, 0.6136700510978699, 0.6225748658180237, 0.6310203671455383, 0.6389821171760559, 0.6463097333908081, 0.6528617739677429, 0.6585475206375122, 0.6634621620178223, 0.6678526401519775, 0.6718972325325012, 0.6756421327590942], "std_train_loss": [2.7085084915161133, 2.677509307861328, 2.6341302394866943, 2.5730724334716797, 2.492614269256592, 2.3927817344665527, 2.2741761207580566, 2.137855052947998, 1.9853696823120117, 1.8187763690948486, 1.6410351991653442, 1.4555929899215698, 1.2665268182754517, 1.0787031650543213, 0.8973091840744019, 0.7275450825691223, 0.5740776062011719, 0.44004037976264954, 0.32742777466773987, 0.23663438856601715, 0.1664789468050003, 0.1142977625131607, 0.0767836645245552, 0.05059399828314781, 0.032766345888376236]}
precomputed_results/p_015/p015_overview_loss_ipr.png ADDED
precomputed_results/p_015/p015_overview_phase_scatter.png ADDED
precomputed_results/p_015/p015_phase_align_approx1.png ADDED

Git LFS Details

  • SHA256: 726fb7a76bbbe19d3cea0f55c72b7db45468644cf2403a8993bed849bd8a7715
  • Pointer size: 131 Bytes
  • Size of remote file: 241 kB
precomputed_results/p_015/p015_phase_align_approx2.png ADDED

Git LFS Details

  • SHA256: 4e968cd2206257c9fd0898e0f0ab810eb94728d93e3dc13ca732114107e3f2d1
  • Pointer size: 131 Bytes
  • Size of remote file: 232 kB
precomputed_results/p_015/p015_phase_align_quad.png ADDED
precomputed_results/p_015/p015_phase_align_relu.png ADDED
precomputed_results/p_015/p015_phase_distribution.png ADDED
precomputed_results/p_015/p015_phase_relationship.png ADDED
precomputed_results/p_015/p015_single_freq_quad.png ADDED

Git LFS Details

  • SHA256: ccbca685fc93991c9965ca3e8fc3fa87ae3fd03e74fb5c40b9f03468e63dda79
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB
precomputed_results/p_015/p015_single_freq_relu.png ADDED

Git LFS Details

  • SHA256: fa5f0678750033da499f044eb3d50a94cfd0f2ce00b69dbd76692a8ae6be4aa6
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
precomputed_results/p_015/p015_training_log.json ADDED
The diff for this file is too large to render. See raw diff
 
precomputed_results/p_023/p023_full_training_para_origin.png ADDED

Git LFS Details

  • SHA256: 6bda97a1602fffb1ff14fec8ba05d3a7c928ff8765d1cc7c957eee3e89c2875b
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
precomputed_results/p_023/p023_grokk_abs_phase_diff.png ADDED
precomputed_results/p_023/p023_grokk_acc.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"epochs": [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000, 10200, 10400, 10600, 10800, 11000, 11200, 11400, 11600, 11800, 12000, 12200, 12400, 12600, 12800, 13000, 13200, 13400, 13600, 13800, 14000, 14200, 14400, 14600, 14800, 15000, 15200, 15400, 15600, 15800, 16000, 16200, 16400, 16600, 16800, 17000, 17200, 17400, 17600, 17800, 18000, 18200, 18400, 18600, 18800, 19000, 19200, 19400, 19600, 19800, 20000, 20200, 20400, 20600, 20800, 21000, 21200, 21400, 21600, 21800, 22000, 22200, 22400, 22600, 22800, 23000, 23200, 23400, 23600, 23800], "train_accs": [0.045454545454545456, 0.23737373737373738, 0.5404040404040404, 0.7272727272727273, 0.7348484848484849, 0.7424242424242424, 0.76010101010101, 0.8055555555555556, 0.8838383838383839, 0.9671717171717171, 0.9924242424242424, 0.9974747474747475, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "test_accs": [0.03759398496240601, 0.007518796992481203, 0.007518796992481203, 0.007518796992481203, 0.015037593984962405, 0.03007518796992481, 0.06766917293233082, 0.18045112781954886, 0.39849624060150374, 0.6466165413533834, 0.7142857142857143, 0.7293233082706767, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7518796992481203, 0.7669172932330827, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7969924812030075, 0.8270676691729323, 0.8270676691729323, 0.8421052631578947, 0.9022556390977443, 0.9022556390977443, 0.9022556390977443, 0.9022556390977443, 0.9172932330827067, 0.9323308270676691, 0.9323308270676691, 0.9323308270676691, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939], "stage1_end": 1744, "stage2_end": 9431}
precomputed_results/p_023/p023_grokk_acc.png ADDED
precomputed_results/p_023/p023_grokk_avg_ipr.png ADDED
precomputed_results/p_023/p023_grokk_decoded_weights_dynamic.png ADDED

Git LFS Details

  • SHA256: aaadb8ffd70fa189e963c58714fcfd065cc9cdbdec5d6ca41a064201f1ebaf0c
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
precomputed_results/p_023/p023_grokk_epoch_data.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"prime": 23, "epochs": [0, 2600, 5200, 7800, 10400, 13200, 15800, 18400, 21000, 23800], "grids": [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], [[1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]]}
precomputed_results/p_023/p023_grokk_loss.json ADDED
The diff for this file is too large to render. See raw diff
 
precomputed_results/p_023/p023_grokk_loss.png ADDED
precomputed_results/p_023/p023_grokk_memorization_accuracy.png ADDED
precomputed_results/p_023/p023_grokk_memorization_common_to_rare.png ADDED

Git LFS Details

  • SHA256: 5694167dbc67e7f563f1b3744e4eb0ceb35343caceb46d45d409ca30e9a5fcbb
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
precomputed_results/p_023/p023_lineplot_in.png ADDED

Git LFS Details

  • SHA256: fe3f29f4670ab511d3f76fb5721747b8e03d8253fe97f542f34908024177d2b5
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB