# Pre-computation Pipeline
Batch training and plot generation for all odd moduli $p$ in [3, 199]. Trains up to 5 model configurations per $p$ (the grokking run is skipped for $p < 19$) and generates publication-quality figures plus interactive JSON data files covering the paper's core results.
All commands are run from the **project root directory**.
## Quick Start (Shell Script)
The easiest way to run the full pipeline for a single modulus:
```bash
# Run the complete pipeline for p=23
bash precompute/run_pipeline.sh 23
# Or using an environment variable
P=23 bash precompute/run_pipeline.sh
```
This runs training, plot generation, analytical simulation, and verification in sequence.
## Complete Pipeline (Single Modulus, Manual Steps)
```bash
# Step 1: Train all 5 model configurations
python precompute/train_all.py --p 23 --output ./trained_models
# Step 2: Generate model-based plots (22 PNGs + 6 JSONs + metadata)
python precompute/generate_plots.py --p 23 --input ./trained_models --output ./precomputed_results
# Step 3: Generate analytical simulation plots (2 PNGs, no model needed)
python precompute/generate_analytical.py --p 23 --output ./precomputed_results
# Step 4: Verify
ls precomputed_results/p_023/
```
## Complete Pipeline (All Odd p)
```bash
# Train everything (225 runs total). Use --resume to skip completed runs.
python precompute/train_all.py --all --output ./trained_models --resume
# Generate all plots
python precompute/generate_plots.py --all --input ./trained_models --output ./precomputed_results
# Generate all analytical plots
python precompute/generate_analytical.py --all --output ./precomputed_results
```
---
## The 5 Model Configurations
Each modulus is trained with 5 configurations that correspond to different sections of the paper:
### 1. Standard Training (`standard`)
The baseline experiment for Parts I--II (Mechanism & Dynamics). Demonstrates Fourier feature learning: neurons decompose modular addition into sparse frequency components with phase alignment (ψ ≈ 2φ).
| Parameter | Value |
|-----------|-------|
| Activation | ReLU |
| Initialization | random |
| Optimizer | AdamW |
| Learning rate | 5e-5 |
| Weight decay | 0 |
| Train fraction | 1.0 (all p² pairs) |
| Epochs | 5,000 |
| Init scale | 0.1 |
**Used by:** Tab 1 (Overview), Tab 2 (Fourier Weights), Tab 3 (Phase Analysis), Tab 4 (Output Logits)
### 2. Grokking (`grokking`)
Reproduces the grokking phenomenon (Part III). The model memorizes training data first, then suddenly generalizes. Requires partial training data + weight decay.
| Parameter | Value |
|-----------|-------|
| Activation | ReLU |
| Initialization | random |
| Optimizer | AdamW |
| Learning rate | 1e-4 |
| Weight decay | **2.0** |
| Train fraction | **0.75** |
| Epochs | **50,000** |
| Init scale | 0.1 |
**Used by:** Tab 1 (Overview, grokking curves), Tab 6 (Grokking)
**Note:** Runs only for $p \ge 19$; smaller moduli have too few test points for meaningful grokking.
### 3. Quadratic Activation (`quad_random`)
Uses σ(x) = x² activation. The quadratic nonlinearity directly implements the frequency factorization mechanism from the theory, enabling clean analysis of the lottery ticket mechanism.
| Parameter | Value |
|-----------|-------|
| Activation | **Quad** |
| Initialization | random |
| Optimizer | AdamW |
| Learning rate | 5e-5 |
| Weight decay | 0 |
| Train fraction | 1.0 |
| Epochs | 5,000 |
| Init scale | 0.1 |
**Used by:** Tab 5 (Lottery Mechanism)
### 4. Single-Frequency Quad (`quad_single_freq`)
Initializes neurons at specific frequencies to study gradient dynamics under controlled conditions. Validates the phase alignment theorem and single-frequency preservation theorem.
| Parameter | Value |
|-----------|-------|
| Activation | **Quad** |
| Initialization | **single-freq** |
| Optimizer | **SGD** |
| Learning rate | **0.1** |
| Weight decay | 0 |
| Train fraction | 1.0 |
| Epochs | 5,000 |
| Init scale | **0.02** |
**Used by:** Tab 7 (Gradient Dynamics, quadratic panels)
### 5. Single-Frequency ReLU (`relu_single_freq`)
Same single-frequency setup as `quad_single_freq`, but with ReLU activation and a smaller learning rate and init scale. Shows that the theoretical results (proved for the quadratic activation) hold approximately for ReLU, with small harmonic leakage.
| Parameter | Value |
|-----------|-------|
| Activation | **ReLU** |
| Initialization | **single-freq** |
| Optimizer | **SGD** |
| Learning rate | **0.01** |
| Weight decay | 0 |
| Train fraction | 1.0 |
| Epochs | 5,000 |
| Init scale | **0.002** |
**Used by:** Tab 7 (Gradient Dynamics, ReLU panels)
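
The five tables above can be collected into one lookup structure. The sketch below is illustrative only: `RunConfig`, `CONFIGS`, `runs_for`, and the `min_p` field are hypothetical names and may not match how `prime_config.py` stores these settings. The values are copied from the tables.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunConfig:
    """One training configuration (hypothetical container, values from the tables)."""
    activation: str      # "ReLU" or "Quad"
    init: str            # "random" or "single-freq"
    optimizer: str       # "AdamW" or "SGD"
    lr: float
    weight_decay: float
    train_frac: float
    epochs: int
    init_scale: float
    min_p: int = 3       # smallest modulus for which the run is trained


CONFIGS = {
    "standard": RunConfig("ReLU", "random", "AdamW", 5e-5, 0.0, 1.0, 5_000, 0.1),
    "grokking": RunConfig("ReLU", "random", "AdamW", 1e-4, 2.0, 0.75, 50_000, 0.1, min_p=19),
    "quad_random": RunConfig("Quad", "random", "AdamW", 5e-5, 0.0, 1.0, 5_000, 0.1),
    "quad_single_freq": RunConfig("Quad", "single-freq", "SGD", 0.1, 0.0, 1.0, 5_000, 0.02),
    "relu_single_freq": RunConfig("ReLU", "single-freq", "SGD", 0.01, 0.0, 1.0, 5_000, 0.002),
}


def runs_for(p: int) -> list[str]:
    """Names of the runs that apply to modulus p (grokking needs p >= 19)."""
    return [name for name, cfg in CONFIGS.items() if p >= cfg.min_p]


print(runs_for(23))
print(runs_for(7))
```

Keeping the `p >= 19` rule next to the grokking entry means the skip condition lives in one place instead of being repeated in each script.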
---
## Neuron Sizing
The number of hidden neurons scales with $p^2$ so that the ratio $d_\text{mlp}/p^2$ matches the baseline experiment ($p=23$, $d_\text{mlp}=512$), with a floor of 512 neurons for smaller moduli:
```
d_mlp = max(512, ceil(512/529 * p²))
```
Examples: $p=3 \to 512$, $p=23 \to 512$, $p=53 \to 2719$, $p=97 \to 9107$, $p=199 \to 38329$.
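
A minimal sketch of the sizing rule, assuming it is exactly the formula shown above. The function name `d_mlp` and its keyword arguments are illustrative, not necessarily what `prime_config.py` uses. It works in integer arithmetic because a floating-point `512 / 529 * p**2` may land slightly above an exact integer (for example at $p=23$), in which case `ceil` would round up by one.

```python
def d_mlp(p: int, base_p: int = 23, base_width: int = 512) -> int:
    """Hidden width for modulus p: max(512, ceil(512/529 * p^2))."""
    # -(-a // b) is ceiling division on integers, so no float rounding is involved.
    scaled = -(-base_width * p * p // (base_p * base_p))
    return max(base_width, scaled)


for p in (3, 23, 199):
    print(p, d_mlp(p))
```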
---
## Blog Figure → Pipeline Output Mapping
The table below maps every figure in the blog post to the corresponding file generated by the pipeline. Each figure is reproduced for every $p$, allowing users to verify the paper's claims across different moduli.
### Part I: Mechanism (Tabs 2--4, standard run)
| Blog Figure | Description | Pipeline Output | Tab | Verification notes |
|-------------|-------------|----------------|-----|-----------|
| **Fig. 2** — Fourier sparsity of learned weights | DFT heatmap: each row is a neuron, each column is a Fourier mode (cos k, sin k). Sparsity shows up as one bright (cos k, sin k) pair per row, confirming single-frequency specialization. | `pXXX_dft_heatmap_in.png`, `pXXX_dft_heatmap_out.png` | 2 | The heatmap applies `W_in @ fourier_basis.T` and `W_out.T @ fourier_basis.T` to show DFT coefficients. X-axis labels are Fourier mode names (Const, cos 1, sin 1, ...). Sparsity is visible as one dominant pair per neuron row. |
| **Fig. 3** — Cosine fits to individual neurons | Raw learned weight values (dots) vs. best-fit cosine (dashed) for 3 representative neurons. Left: input weights θ_m. Right: output weights ξ_m. | `pXXX_lineplot_in.png`, `pXXX_lineplot_out.png` | 2 | Projects raw weights into Fourier space, keeps top-2 components, projects back. The fit quality demonstrates that each neuron is well-described by a single cosine. |
| **Fig. 4** — Phase alignment ψ = 2φ | Scatter plot of (2φ_m mod 2π) vs (ψ_m mod 2π). All points lie on the diagonal y = x. | `pXXX_phase_relationship.png` | 3 | Computed via `compute_neuron()` for every neuron. The diagonal pattern is Observation 2 from the paper. |
| **Fig. 5** — Higher-order phase symmetry | Polar plot: phase angles ι·φ_m on concentric rings for ι = 1, 2, 3, 4. Uniform spread confirms the cancellation condition Σ exp(i·ι·φ_m) ≈ 0. | `pXXX_phase_distribution.png` | 3 | Shows phases for the most common frequency group. For large p with many neurons, the uniform spread is clearly visible. |
| **Fig. 6** — Magnitude homogeneity | Violin plots of α_m (input) and β_m (output) across all neurons. Tight concentration confirms magnitude homogeneity (Observation 3c). | `pXXX_magnitude_distribution.png` | 3 | Uses `compute_neuron()` to extract scale for every neuron. |
| **Fig. 7** — Output logits (flawed indicator) | Heatmap of f(x,y)[j] for x=0. Bright red diagonal at j=(x+y) mod p (correct answer, coefficient p/2). Faint pink at j=2x mod p and j=2y mod p (spurious peaks, coefficient p/4). | `pXXX_output_logits.png` | 4 | Forward pass through the trained model with **matching activation** (ReLU for standard run). Rectangles highlight the correct answer and spurious peak positions. |
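
The Fig. 3 fit (project into Fourier space, keep the top-2 components, project back) can be sketched in NumPy as follows. This is a stand-alone illustration, not the pipeline's code: `fourier_basis` is a hypothetical stand-in for `get_fourier_basis(p)`, `top2_fit` is an invented helper name, and the test neuron (frequency 5, amplitude 0.8, phase 0.3, noise level 0.01) is synthetic.

```python
import numpy as np


def fourier_basis(p: int) -> np.ndarray:
    """Orthonormal real DFT basis on Z_p (odd p): rows Const, cos 1, sin 1, ..., cos K, sin K."""
    j = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, (p - 1) // 2 + 1):
        rows.append(np.cos(2 * np.pi * k * j / p))
        rows.append(np.sin(2 * np.pi * k * j / p))
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)  # L2-normalize each row


def top2_fit(w: np.ndarray, basis: np.ndarray) -> np.ndarray:
    """Keep the two largest-magnitude DFT coefficients of w and project back."""
    coeffs = basis @ w
    keep = np.argsort(np.abs(coeffs))[-2:]
    sparse = np.zeros_like(coeffs)
    sparse[keep] = coeffs[keep]
    return basis.T @ sparse


p = 23
B = fourier_basis(p)
j = np.arange(p)
w_clean = 0.8 * np.cos(2 * np.pi * 5 * j / p + 0.3)      # synthetic single-frequency neuron
rng = np.random.default_rng(0)
w_noisy = w_clean + 0.01 * rng.standard_normal(p)        # small perturbation
fit = top2_fit(w_noisy, B)
print("max |fit - clean| =", np.max(np.abs(fit - w_clean)))
```

For a neuron dominated by one frequency, the two retained coefficients are that frequency's cos and sin components, so the reconstruction is a single cosine with fitted amplitude and phase.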
### Part II: Dynamics (Tabs 5, 7, 8)
| Blog Figure | Description | Pipeline Output | Tab | Verification notes |
|-------------|-------------|----------------|-----|-----------|
| **Fig. 8** — Phase alignment dynamics | Phase trajectories (φ, ψ, 2φ) and magnitude growth (α, β) over training. Left: Quad activation. Right: ReLU. Shows ψ → 2φ convergence. | `pXXX_phase_align_quad.png`, `pXXX_phase_align_relu.png` | 7 | Tracks the neuron with largest final scale across all checkpoints. Shows phases converging and magnitudes growing. |
| **Fig. 9** — Lottery ticket race | Left: phase misalignment D_m^k(t) for all frequencies within one neuron. The winner (smallest initial D) converges first. Right: magnitude β_m^k(t). Winner grows explosively. | `pXXX_lottery_mech_phase.png`, `pXXX_lottery_mech_magnitude.png` | 5 | Tracks all frequency components of a single neuron via `decode_scales_phis()` across checkpoints from the `quad_random` run. The winning frequency (highlighted in red) has the smallest initial misalignment. |
| **Fig. 10** — Lottery outcome contour | Final magnitude β as a function of (initial magnitude, initial phase difference 2φ₀). Largest values at small D, symmetric about D = π. | `pXXX_lottery_beta_contour.png` | 5 | Simulates gradient flow on a 30×30 grid of initial conditions. Each point runs 100 steps of the analytical ODE. |
| **Fig. 11** — Single-frequency preservation (Quad) | DFT heatmap at multiple training timepoints. The initialized frequency retains all energy; no cross-frequency leakage. | `pXXX_single_freq_quad.png` | 7 | Shows DFT of weights at 3 timepoints (step 0, mid, final) for the `quad_single_freq` run. Each column is a Fourier mode; sparsity confirms preservation. |
| **Fig. 12a** — Single-frequency preservation (ReLU) | Same as Fig. 11 but with ReLU. Small harmonic leakage visible at 3k*, 5k* (input) and 2k*, 3k* (output), decaying as O(r⁻²). | `pXXX_single_freq_relu.png` | 7 | Shows DFT at 2 timepoints (step 0, final) for the `relu_single_freq` run. Dominant frequency overwhelms harmonics. |
| **Fig. 12b** — Phase alignment under ReLU | Phase and magnitude trajectories for ReLU single-frequency init. Same zero-attractor behavior as Quad. | `pXXX_phase_align_relu.png` | 7 | Same as Fig. 8 right panel. |
| — Decoupled ODE simulation | Pure ODE integration (no neural network) showing phase convergence and magnitude competition for all frequencies within one neuron. Two cases with different initial conditions. | `pXXX_phase_align_approx1.png`, `pXXX_phase_align_approx2.png` | 8 | Generated by `generate_analytical.py`, not `generate_plots.py`. Validates the theory without any training. |
### Part III: Grokking (Tab 6)
| Blog Figure | Description | Pipeline Output | Tab | Verification notes |
|-------------|-------------|----------------|-----|-----------|
| **Fig. 13a** — Grokking loss curves | Training and test loss over 50k epochs. Three stages: (I) train loss drops, (II) test loss drops, (III) both near zero. Stage boundaries marked. | `pXXX_grokk_loss.json` → interactive Plotly chart | 6 | From `training_curves.json`. Stage boundaries detected by `grokking_stage_detector.py`. Shaded regions distinguish the three stages. |
| **Fig. 13b** — Grokking accuracy curves | Training and test accuracy. Train → 100% in Stage I, test jumps in Stage II. | `pXXX_grokk_acc.json` → interactive Plotly chart | 6 | Computed by running forward pass on train/test data at each checkpoint. |
| **Fig. 13c** — Phase alignment progress | Average \|sin(D_m*)\| over training. Decreases throughout, steepest in Stage II. | `pXXX_grokk_abs_phase_diff.png` | 6 | Computed via `decode_weights()` + `compute_neuron()` at each grokking checkpoint. |
| **Fig. 13d** — IPR and parameter norm | Dual-axis: IPR (Fourier sparsity) increases sharply in Stage II; parameter norm shrinks in Stage III. | `pXXX_grokk_avg_ipr.png` | 6 | IPR uses the corrected per-frequency magnitude formula: A_k = sqrt(c_k² + s_k²), IPR = Σ A_k⁴ / (Σ A_k²)². Parameter norms from `training_curves.json`. |
| **Fig. 14** — Memorization accuracy heatmap | Three panels at end of Stage I: (1) training data distribution under symmetry, (2) accuracy grid, (3) softmax probability at ground truth. Red rectangles = true test pairs. | `pXXX_grokk_memorization_accuracy.png` | 6 | Forward pass at the checkpoint closest to stage1_end. The symmetric architecture guarantees ~70% test accuracy during memorization. |
| **Fig. 15** — Common-to-rare memorization | Four panels: training data distribution + accuracy at 3 timepoints during Stage I. Shows common pairs (both (i,j) and (j,i) in train) memorized before rare pairs (only one ordering). | `pXXX_grokk_memorization_common_to_rare.png` | 6 | Epochs selected at 0, stage1_end/2, stage1_end. Red rectangles mark asymmetric training pairs. |
| **Fig. 16** — Weight evolution during grokking | 2×3 DFT heatmap grid showing θ_m and ξ_m at Step 0 (random init), end of Stage I (noisy multi-frequency), and end of Stage II (clean single-frequency). | `pXXX_grokk_decoded_weights_dynamic.png` | 6 | DFT coefficients `W @ fourier_basis.T` at 3 key epochs. The transition from diffuse to sparse confirms the sparsification narrative. |
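
As a rough illustration of how stage boundaries could be read off accuracy curves, the sketch below marks the end of Stage I as the first epoch where train accuracy reaches a threshold, and the end of Stage II as the first epoch where test accuracy does. The function name, the 0.99 threshold, and the sample curves are all assumptions made for this example. The actual logic in `grokking_stage_detector.py` may differ (for instance, it may smooth the curves or use loss values).

```python
import numpy as np


def detect_stages(epochs, train_acc, test_acc, threshold=0.99):
    """Return (stage1_end, stage2_end) epochs, or None where the threshold is never met."""
    epochs = np.asarray(epochs)
    train_ok = np.asarray(train_acc) >= threshold
    test_ok = np.asarray(test_acc) >= threshold
    # np.argmax on a boolean array gives the index of the first True entry.
    stage1_end = int(epochs[np.argmax(train_ok)]) if train_ok.any() else None
    stage2_end = int(epochs[np.argmax(test_ok)]) if test_ok.any() else None
    return stage1_end, stage2_end


# Made-up curves: memorization by epoch 1000, generalization by epoch 3000.
epochs = [0, 1000, 2000, 3000, 4000]
train_acc = [0.10, 1.00, 1.00, 1.00, 1.00]
test_acc = [0.10, 0.70, 0.72, 0.995, 1.00]
print(detect_stages(epochs, train_acc, test_acc))
```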
### Overview (Tab 1)
| Blog Figure | Description | Pipeline Output | Tab | Verification notes |
|-------------|-------------|----------------|-----|-----------|
| — Overview dashboard | 2×2 grid: standard loss + grokking loss (top), standard IPR + grokking IPR (bottom). Plus phase scatter from standard final checkpoint. | `pXXX_overview_loss_ipr.png`, `pXXX_overview_phase_scatter.png`, `pXXX_overview.json` | 1 | Combines data from standard and grokking runs. The phase scatter uses the same computation as Fig. 4. For $p < 19$, only the standard column is shown. |
### Not Currently Generated (Blog-Only Figures)
| Blog Figure | Description | Status |
|-------------|-------------|--------|
| **Fig. 1** — Architecture illustration | Schematic of the two-layer network, DFT decomposition, and mechanism overview. | Static illustration, not dependent on p. Could be included as a fixed image in the app. |
---
## Interactive JSON Data Files
In addition to static PNG plots, the pipeline generates JSON files for interactive Plotly charts in the Gradio app:
| File | Content | Used In |
|------|---------|---------|
| `pXXX_overview.json` | Standard loss/IPR + grokking loss/IPR time series | Tab 1: Interactive loss and IPR charts |
| `pXXX_neuron_spectra.json` | Per-neuron Fourier magnitudes (W_in and W_out) for top-20 neurons, sorted by frequency | Tab 2: Neuron Inspector dropdown → bar chart of frequency decomposition |
| `pXXX_logits_interactive.json` | Output logits for p representative (a,b) pairs, plus correct answers | Tab 4: Logit Explorer dropdown → bar chart with correct answer highlighted |
| `pXXX_grokk_loss.json` | Full training/test loss curves + stage boundaries | Tab 6: Interactive loss chart with stage shading |
| `pXXX_grokk_acc.json` | Accuracy at each checkpoint epoch + stage boundaries | Tab 6: Interactive accuracy chart with stage shading |
| `pXXX_grokk_epoch_data.json` | p×p accuracy grids at ~10 evenly-spaced grokking epochs | Tab 6: Epoch Slider → heatmap animation across training |
| `pXXX_metadata.json` | Config for all 5 runs + final metrics (loss, accuracy) | Displayed in the app's info panel for the selected $p$ |
---
## Output Structure
All plots for a modulus are saved in a single flat directory. Each file is prefixed with `pXXX_` so the folder is self-contained and browsable:
```
precomputed_results/p_023/
# Metadata
p023_metadata.json
# Tab 1: Overview (Blog: summary of standard + grokking)
p023_overview_loss_ipr.png # 2×2 grid: loss + IPR for both setups
p023_overview_phase_scatter.png # Phase alignment scatter (same as Fig. 4)
p023_overview.json # Interactive data
# Tab 2: Fourier Weights (Blog: Figures 2, 3)
p023_dft_heatmap_in.png # DFT heatmap, W_E input layer (Fig. 2)
p023_dft_heatmap_out.png # DFT heatmap, W_L output layer (Fig. 2)
p023_lineplot_in.png # Cosine fits, input layer (Fig. 3 left)
p023_lineplot_out.png # Cosine fits, output layer (Fig. 3 right)
p023_neuron_spectra.json # Interactive: neuron inspector
# Tab 3: Phase Analysis (Blog: Figures 4, 5, 6)
p023_phase_distribution.png # Polar phase plot (Fig. 5)
p023_phase_relationship.png # 2φ vs ψ scatter (Fig. 4)
p023_magnitude_distribution.png # Violin plots (Fig. 6)
# Tab 4: Output Logits (Blog: Figure 7)
p023_output_logits.png # Logit heatmap (Fig. 7)
p023_logits_interactive.json # Interactive: logit explorer
# Tab 5: Lottery Mechanism (Blog: Figures 9, 10)
p023_lottery_mech_magnitude.png # Magnitude race (Fig. 9 right)
p023_lottery_mech_phase.png # Phase misalignment race (Fig. 9 left)
p023_lottery_beta_contour.png # Contour plot (Fig. 10)
# Tab 6: Grokking (Blog: Figures 13, 14, 15, 16)
p023_grokk_loss.json # Interactive loss curves (Fig. 13a)
p023_grokk_acc.json # Interactive accuracy curves (Fig. 13b)
p023_grokk_abs_phase_diff.png # Phase alignment progress (Fig. 13c)
p023_grokk_avg_ipr.png # IPR + param norms (Fig. 13d)
p023_grokk_memorization_accuracy.png # 3-panel heatmap (Fig. 14)
p023_grokk_memorization_common_to_rare.png # 4-panel sequence (Fig. 15)
p023_grokk_decoded_weights_dynamic.png # DFT evolution (Fig. 16)
p023_grokk_epoch_data.json # Interactive: epoch slider
# Tab 7: Gradient Dynamics (Blog: Figures 8, 11, 12)
p023_phase_align_quad.png # Phase + magnitude, Quad (Fig. 8 left)
p023_single_freq_quad.png # DFT heatmap over time, Quad (Fig. 11)
p023_phase_align_relu.png # Phase + magnitude, ReLU (Fig. 8 right / 12b)
p023_single_freq_relu.png # DFT heatmap over time, ReLU (Fig. 12a)
# Tab 8: Decoupled Simulation (no blog figure number)
p023_phase_align_approx1.png # ODE simulation case 1
p023_phase_align_approx2.png # ODE simulation case 2
```
**31 files per $p$:** 22 PNGs + 7 JSONs (6 interactive data files plus metadata) from trained models, and 2 PNGs from analytical simulation.
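
The `ls` check in Step 4 can be made stricter by comparing a folder against the listing above. The sketch below is a hypothetical helper, not part of the pipeline. File names are copied from the listing. Treating the `grokk_*` files as optional for $p < 19$ is an assumption based on the grokking run being skipped there.

```python
from pathlib import Path

# Suffixes copied from the directory listing; the grouping is illustrative.
ALWAYS = [
    "metadata.json",
    "overview_loss_ipr.png", "overview_phase_scatter.png", "overview.json",
    "dft_heatmap_in.png", "dft_heatmap_out.png",
    "lineplot_in.png", "lineplot_out.png", "neuron_spectra.json",
    "phase_distribution.png", "phase_relationship.png", "magnitude_distribution.png",
    "output_logits.png", "logits_interactive.json",
    "lottery_mech_magnitude.png", "lottery_mech_phase.png", "lottery_beta_contour.png",
    "phase_align_quad.png", "single_freq_quad.png",
    "phase_align_relu.png", "single_freq_relu.png",
    "phase_align_approx1.png", "phase_align_approx2.png",
]
GROKKING_ONLY = [
    "grokk_loss.json", "grokk_acc.json", "grokk_abs_phase_diff.png",
    "grokk_avg_ipr.png", "grokk_memorization_accuracy.png",
    "grokk_memorization_common_to_rare.png",
    "grokk_decoded_weights_dynamic.png", "grokk_epoch_data.json",
]


def expected_files(p: int) -> list[str]:
    """File names expected in precomputed_results/p_XXX/ for modulus p."""
    suffixes = ALWAYS + (GROKKING_ONLY if p >= 19 else [])  # assumption: no grokking files below 19
    return [f"p{p:03d}_{s}" for s in suffixes]


def missing_files(p: int, root: str = "./precomputed_results") -> list[str]:
    """Expected files that are not present on disk."""
    folder = Path(root) / f"p_{p:03d}"
    return [name for name in expected_files(p) if not (folder / name).is_file()]


print(missing_files(23))  # an empty list means the folder is complete
```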
---
## Correctness Verification
### How each computation matches the paper
1. **DFT Decomposition (Figs. 2, 11, 12a, 16):** We compute `W @ fourier_basis.T` where `fourier_basis` is the orthonormal DFT basis from `get_fourier_basis(p)`. The basis has rows: [Const, cos 1, sin 1, cos 2, sin 2, ..., cos K, sin K] with K = (p-1)/2 for odd $p$. Each row is L2-normalized. This matches the standard real DFT on Z_p.
2. **Phase extraction (Figs. 4, 8, 9, 13c):** For frequency k, the DFT gives coefficients (c_k, s_k) at indices (2k-1, 2k). The magnitude is α = sqrt(2/p) · sqrt(c_k² + s_k²), and the phase is φ = arctan2(-s_k, c_k). This convention matches the paper's θ_m[j] = α cos(ω_k j + φ) representation.
3. **IPR (Figs. 13d, Overview):** Uses the corrected per-frequency magnitude formula: A_k = sqrt(c_k² + s_k²) (combining cos/sin pairs), then IPR = Σ A_k⁴ / (Σ A_k²)². This gives IPR → 1 for perfect single-frequency neurons, matching the paper's definition.
4. **Phase alignment (Fig. 4):** The doubled-phase relationship ψ_m = 2φ_m is verified by extracting φ from W_in and ψ from W_out using the same `compute_neuron()` function, then plotting (2φ mod 2π) vs (ψ mod 2π).
5. **Output logits (Fig. 7):** Forward pass uses the **same activation function** as training (ReLU for standard run). The flawed indicator structure (main diagonal + two ghost diagonals) is visible because the standard run trains to 100% accuracy with clean Fourier features.
6. **Lottery mechanism (Figs. 9, 10):** Uses the `quad_random` run (quadratic activation, random init) which matches the theoretical setting. `decode_scales_phis()` extracts per-frequency magnitudes and phases at each checkpoint. The winning frequency is the one with smallest initial |D| = |2φ - ψ|.
7. **Grokking stages (Figs. 13--16):** `grokking_stage_detector.py` identifies stage boundaries from training curves. Stage I ends when train accuracy ≈ 1.0, Stage II ends when test accuracy ≈ 1.0. Memorization heatmaps use forward pass at the closest checkpoint to stage1_end.
8. **Analytical simulation (Tab 8):** Numerically integrates the four-variable ODE system from Section 5.3 of the paper. No neural network is involved — this validates the theoretical dynamics directly.
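
Items 1 to 4 can be checked on a synthetic neuron. The sketch below builds the orthonormal basis described in item 1, recovers (α, φ) with the convention in item 2, and evaluates the IPR from item 3. It is a stand-alone illustration under stated assumptions: `fourier_basis` and `decode_neuron` are hypothetical stand-ins for `get_fourier_basis()` and `compute_neuron()`, the DC component is excluded from the IPR, and the weights are made up (frequency 5, φ = 0.3, ψ = 0.6).

```python
import numpy as np


def fourier_basis(p: int) -> np.ndarray:
    """Orthonormal real DFT basis on Z_p (odd p): rows Const, cos 1, sin 1, ..., cos K, sin K."""
    j = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, (p - 1) // 2 + 1):
        rows.append(np.cos(2 * np.pi * k * j / p))
        rows.append(np.sin(2 * np.pi * k * j / p))
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


def decode_neuron(w: np.ndarray, basis: np.ndarray):
    """Return (dominant frequency k, magnitude, phase, IPR) for one weight vector."""
    p = len(w)
    coeffs = basis @ w
    c, s = coeffs[1::2], coeffs[2::2]          # (c_k, s_k) sit at indices (2k-1, 2k)
    A = np.sqrt(c**2 + s**2)                   # per-frequency magnitude, DC excluded
    ipr = np.sum(A**4) / np.sum(A**2) ** 2
    k = int(np.argmax(A)) + 1
    alpha = np.sqrt(2 / p) * A[k - 1]
    phi = np.arctan2(-s[k - 1], c[k - 1])
    return k, alpha, phi, ipr


p = 23
B = fourier_basis(p)
j = np.arange(p)
theta = 0.8 * np.cos(2 * np.pi * 5 * j / p + 0.3)   # input weights, phase phi = 0.3
xi = 1.2 * np.cos(2 * np.pi * 5 * j / p + 0.6)      # output weights, phase psi = 2 * phi

k, alpha, phi, ipr = decode_neuron(theta, B)
_, beta, psi, _ = decode_neuron(xi, B)
# Wrap 2*phi - psi into [-pi, pi); a value near 0 means the phases are aligned.
misalignment = (2 * phi - psi + np.pi) % (2 * np.pi) - np.pi
print(k, alpha, phi, ipr, misalignment)
```

Since the basis rows are L2-normalized, a pure cosine of amplitude α has coefficient norm α·sqrt(p/2), which is where the sqrt(2/p) factor in item 2 comes from.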
### Why results generalize across $p$
The paper's theory is stated for general odd $p$. Key properties that scale:
- **Fourier basis:** Always has (p-1)/2 non-DC frequencies for any odd $p$.
- **Phase alignment:** The ψ = 2φ relationship is a consequence of the gradient dynamics, independent of p.
- **Lottery mechanism:** Random initial misalignments are uniform on [0, 2π) for any p.
- **Grokking three stages:** The stage structure depends on the balance of loss gradient vs. weight decay, not on p specifically (though the stage durations and test accuracy during memorization may vary).
- **Network width:** d_mlp scales as O(p²), keeping the d_mlp / p² ratio of the baseline, so each of the (p-1)/2 frequencies retains enough neurons for diversification.
---
## Scripts
| Script | Purpose |
|--------|---------|
| `run_pipeline.sh` | Runs the complete pipeline (train + plots + analytical + verify) for a single modulus. |
| `train_all.py` | Trains all 5 model configurations. Saves checkpoints + `training_curves.json`. |
| `generate_plots.py` | Loads trained models and generates all model-dependent plots (Tabs 1--7) plus interactive JSONs and metadata. |
| `generate_analytical.py` | Runs gradient flow simulations to generate theory plots (Tab 8). No model needed. |
| `prime_config.py` | Configuration: moduli list, d_mlp formula, training run parameters. |
| `neuron_selector.py` | Automated neuron selection for plots (replaces hardcoded indices from notebooks). |
| `grokking_stage_detector.py` | Detects memorization/transition/generalization stage boundaries from training curves. |
---
## Analytical Simulations (No Model Needed)
`generate_analytical.py` produces 2 plots per $p$ by simulating gradient flow on decoupled frequency dynamics. These validate the theoretical analysis without training any model.
- **Case 1**: Shows phase difference D* converging from initial conditions (φ₀=1.5, ψ₀=0.18)
- **Case 2**: Different initial conditions (φ₀=-0.72, ψ₀=-2.91) showing convergence from the other side
Both cases confirm the phase alignment theorem: D = 0 is the stable attractor, while D = π is an unstable fixed point.