Pre-computation Pipeline

Batch training and plot generation for all odd primes $p$ in [3, 199] (45 moduli). Trains 5 model configurations per $p$ and generates publication-quality figures plus interactive JSON data files covering the paper's core results.

All commands are run from the project root directory.

Quick Start (Shell Script)

The easiest way to run the full pipeline for a single modulus:

# Run the complete pipeline for p=23
bash precompute/run_pipeline.sh 23

# Or using an environment variable
P=23 bash precompute/run_pipeline.sh

This runs training, plot generation, analytical simulation, and verification in sequence.

Complete Pipeline (Single Modulus, Manual Steps)

# Step 1: Train all 5 model configurations
python precompute/train_all.py --p 23 --output ./trained_models

# Step 2: Generate model-based plots (22 PNGs + 6 JSONs + metadata)
python precompute/generate_plots.py --p 23 --input ./trained_models --output ./precomputed_results

# Step 3: Generate analytical simulation plots (2 PNGs, no model needed)
python precompute/generate_analytical.py --p 23 --output ./precomputed_results

# Step 4: Verify
ls precomputed_results/p_023/

Complete Pipeline (All Odd p)

# Train everything (225 runs total). Use --resume to skip completed runs.
python precompute/train_all.py --all --output ./trained_models --resume

# Generate all plots
python precompute/generate_plots.py --all --input ./trained_models --output ./precomputed_results

# Generate all analytical plots
python precompute/generate_analytical.py --all --output ./precomputed_results

The 5 Model Configurations

Each modulus is trained with 5 configurations that correspond to different sections of the paper:

1. Standard Training (standard)

The baseline experiment for Parts I--II (Mechanism & Dynamics). Demonstrates Fourier feature learning: neurons decompose modular addition into sparse frequency components with phase alignment (ψ ≈ 2φ).

Parameter Value
Activation ReLU
Initialization random
Optimizer AdamW
Learning rate 5e-5
Weight decay 0
Train fraction 1.0 (all p² pairs)
Epochs 5,000
Init scale 0.1

Used by: Tab 1 (Overview), Tab 2 (Fourier Weights), Tab 3 (Phase Analysis), Tab 4 (Output Logits)

2. Grokking (grokking)

Reproduces the grokking phenomenon (Part III). The model memorizes training data first, then suddenly generalizes. Requires partial training data + weight decay.

Parameter Value
Activation ReLU
Initialization random
Optimizer AdamW
Learning rate 1e-4
Weight decay 2.0
Train fraction 0.75
Epochs 50,000
Init scale 0.1

Used by: Tab 1 (Overview, grokking curves), Tab 6 (Grokking).
Note: only runs for p ≥ 19 (smaller $p$ have too few test points for meaningful grokking).

3. Quadratic Activation (quad_random)

Uses σ(x) = x² activation. The quadratic nonlinearity directly implements the frequency factorization mechanism from the theory, enabling clean analysis of the lottery ticket mechanism.
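The frequency factorization can be checked numerically: squaring a single-frequency neuron's pre-activation produces a component at the sum frequency (a+b) with doubled phase, which is exactly the ψ = 2φ structure. A standalone sketch (illustration only, not pipeline code):

```python
import numpy as np

p, k, phi = 23, 3, 0.7          # modulus, frequency, input phase (arbitrary example)
omega = 2 * np.pi * k / p

a = np.arange(p)
# Pre-activation of a single-frequency neuron on input pair (a, b),
# then the quadratic activation sigma(x) = x^2.
pre = np.cos(omega * a[:, None] + phi) + np.cos(omega * a[None, :] + phi)
act = pre ** 2

# Average the activation over all pairs with the same sum s = (a + b) mod p.
s = (a[:, None] + a[None, :]) % p
g = np.array([act[s == v].mean() for v in range(p)])

# Product-to-sum identities predict g(s) = 1 + cos(omega*s + 2*phi):
# the output phase is exactly twice the input phase (psi = 2*phi).
expected = 1 + np.cos(omega * np.arange(p) + 2 * phi)
print(np.allclose(g, expected))   # True
```

The cross term cos(ωa+φ)cos(ωb+φ) is what carries the cos(ω(a+b)+2φ) signal; the remaining terms average to a constant.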

Parameter Value
Activation Quad
Initialization random
Optimizer AdamW
Learning rate 5e-5
Weight decay 0
Train fraction 1.0
Epochs 5,000
Init scale 0.1

Used by: Tab 5 (Lottery Mechanism)

4. Single-Frequency Quad (quad_single_freq)

Initializes neurons at specific frequencies to study gradient dynamics under controlled conditions. Validates the phase alignment theorem and single-frequency preservation theorem.

Parameter Value
Activation Quad
Initialization single-freq
Optimizer SGD
Learning rate 0.1
Weight decay 0
Train fraction 1.0
Epochs 5,000
Init scale 0.02

Used by: Tab 7 (Gradient Dynamics, quadratic panels)

5. Single-Frequency ReLU (relu_single_freq)

Same as above but with ReLU activation. Shows that the theoretical results (proved for quadratic) hold approximately for ReLU with small harmonic leakage.

Parameter Value
Activation ReLU
Initialization single-freq
Optimizer SGD
Learning rate 0.01
Weight decay 0
Train fraction 1.0
Epochs 5,000
Init scale 0.002

Used by: Tab 7 (Gradient Dynamics, ReLU panels)


Neuron Sizing

The number of hidden neurons scales with $p$ to maintain the ratio from the baseline experiment ($p=23$, $d_\text{mlp}=512$):

d_mlp = max(512, ceil(512/529 * p²))

Examples: $p=3 \to 512$, $p=23 \to 512$, $p=53 \to 2720$, $p=97 \to 9108$, $p=199 \to 38329$.
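A minimal sketch of this sizing rule. The authoritative formula lives in prime_config.py, whose rounding may differ slightly from a plain ceiling (e.g. plain ceiling gives 2719 for p = 53 rather than the 2720 listed above):

```python
import math

def d_mlp(p: int, base_p: int = 23, base_width: int = 512) -> int:
    """Width scaling: keep the same width-to-p^2 ratio as the p=23 baseline."""
    return max(base_width, math.ceil(base_width * p * p / (base_p * base_p)))

print(d_mlp(3), d_mlp(23), d_mlp(199))   # 512 512 38329
```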


Blog Figure → Pipeline Output Mapping

The table below maps every figure in the blog post to the corresponding file generated by the pipeline. Each figure is reproduced for every $p$, allowing users to verify the paper's claims across different moduli.

Part I: Mechanism (Tabs 2--4, standard run)

Fig. 2 — Fourier sparsity of learned weights (Tab 2)
  Description: DFT heatmap: each row is a neuron, each column is a Fourier mode (cos k, sin k). Sparsity appears as one bright cell per row, confirming single-frequency specialization.
  Output: pXXX_dft_heatmap_in.png, pXXX_dft_heatmap_out.png
  Verification: Applies W_in @ fourier_basis.T and W_out.T @ fourier_basis.T to show DFT coefficients. X-axis labels are Fourier mode names (Const, cos 1, sin 1, ...). Sparsity is visible as one dominant pair per neuron row.

Fig. 3 — Cosine fits to individual neurons (Tab 2)
  Description: Raw learned weight values (dots) vs. best-fit cosine (dashed) for 3 representative neurons. Left: input weights θ_m. Right: output weights ξ_m.
  Output: pXXX_lineplot_in.png, pXXX_lineplot_out.png
  Verification: Projects raw weights into Fourier space, keeps the top-2 components, and projects back. The fit quality demonstrates that each neuron is well described by a single cosine.

Fig. 4 — Phase alignment ψ = 2φ (Tab 3)
  Description: Scatter plot of (2φ_m mod 2π) vs. (ψ_m mod 2π). All points lie on the diagonal y = x.
  Output: pXXX_phase_relationship.png
  Verification: Computed via compute_neuron() for every neuron. The diagonal pattern is Observation 2 from the paper.

Fig. 5 — Higher-order phase symmetry (Tab 3)
  Description: Polar plot: phase angles ι·φ_m on concentric rings for ι = 1, 2, 3, 4. Uniform spread confirms the cancellation condition Σ exp(i·ι·φ_m) ≈ 0.
  Output: pXXX_phase_distribution.png
  Verification: Shows phases for the most common frequency group. For large p with many neurons, the uniform spread is clearly visible.

Fig. 6 — Magnitude homogeneity (Tab 3)
  Description: Violin plots of α_m (input) and β_m (output) across all neurons. Tight concentration confirms magnitude homogeneity (Observation 3c).
  Output: pXXX_magnitude_distribution.png
  Verification: Uses compute_neuron() to extract the scale of every neuron.

Fig. 7 — Output logits (flawed indicator) (Tab 4)
  Description: Heatmap of f(x,y)[j] for x=0. Bright red diagonal at j=(x+y) mod p (correct answer, coefficient p/2). Faint pink at j=2x mod p and j=2y mod p (spurious peaks, coefficient p/4).
  Output: pXXX_output_logits.png
  Verification: Forward pass through the trained model with the matching activation (ReLU for the standard run). Rectangles highlight the correct answer and the spurious peak positions.

Part II: Dynamics (Tabs 5, 7, 8)

Fig. 8 — Phase alignment dynamics (Tab 7)
  Description: Phase trajectories (φ, ψ, 2φ) and magnitude growth (α, β) over training. Left: Quad activation. Right: ReLU. Shows ψ → 2φ convergence.
  Output: pXXX_phase_align_quad.png, pXXX_phase_align_relu.png
  Verification: Tracks the neuron with the largest final scale across all checkpoints. Shows phases converging and magnitudes growing.

Fig. 9 — Lottery ticket race (Tab 5)
  Description: Left: phase misalignment D_m^k(t) for all frequencies within one neuron; the winner (smallest initial D) converges first. Right: magnitude β_m^k(t); the winner grows explosively.
  Output: pXXX_lottery_mech_phase.png, pXXX_lottery_mech_magnitude.png
  Verification: Tracks all frequency components of a single neuron via decode_scales_phis() across checkpoints from the quad_random run. The winning frequency (highlighted in red) has the smallest initial misalignment.

Fig. 10 — Lottery outcome contour (Tab 5)
  Description: Final magnitude β as a function of (initial magnitude, initial phase difference 2φ₀). Largest values at small D, symmetric about D = π.
  Output: pXXX_lottery_beta_contour.png
  Verification: Simulates gradient flow on a 30×30 grid of initial conditions. Each point runs 100 steps of the analytical ODE.

Fig. 11 — Single-frequency preservation (Quad) (Tab 7)
  Description: DFT heatmap at multiple training timepoints. The initialized frequency retains all energy; no cross-frequency leakage.
  Output: pXXX_single_freq_quad.png
  Verification: Shows the DFT of weights at 3 timepoints (step 0, mid, final) for the quad_single_freq run. Each column is a Fourier mode; sparsity confirms preservation.

Fig. 12a — Single-frequency preservation (ReLU) (Tab 7)
  Description: Same as Fig. 11 but with ReLU. Small harmonic leakage visible at 3k*, 5k* (input) and 2k*, 3k* (output), decaying as O(r⁻²).
  Output: pXXX_single_freq_relu.png
  Verification: Shows the DFT at 2 timepoints (step 0, final) for the relu_single_freq run. The dominant frequency overwhelms the harmonics.

Fig. 12b — Phase alignment under ReLU (Tab 7)
  Description: Phase and magnitude trajectories for ReLU single-frequency init. Same zero-attractor behavior as Quad.
  Output: pXXX_phase_align_relu.png
  Verification: Same as the right panel of Fig. 8.

Decoupled ODE simulation (Tab 8, no blog figure number)
  Description: Pure ODE integration (no neural network) showing phase convergence and magnitude competition for all frequencies within one neuron. Two cases with different initial conditions.
  Output: pXXX_phase_align_approx1.png, pXXX_phase_align_approx2.png
  Verification: Generated by generate_analytical.py, not generate_plots.py. Validates the theory without any training.

Part III: Grokking (Tab 6)

Fig. 13a — Grokking loss curves (Tab 6)
  Description: Training and test loss over 50k epochs. Three stages: (I) train loss drops, (II) test loss drops, (III) both near zero. Stage boundaries marked.
  Output: pXXX_grokk_loss.json → interactive Plotly chart
  Verification: From training_curves.json. Stage boundaries detected by grokking_stage_detector.py. Shaded regions distinguish the three stages.

Fig. 13b — Grokking accuracy curves (Tab 6)
  Description: Training and test accuracy. Train → 100% in Stage I, test jumps in Stage II.
  Output: pXXX_grokk_acc.json → interactive Plotly chart
  Verification: Computed by running a forward pass on the train/test data at each checkpoint.

Fig. 13c — Phase alignment progress (Tab 6)
  Description: Average sin(D_m*) over training. Decreases throughout, steepest in Stage II.
  Output: pXXX_grokk_abs_phase_diff.png

Fig. 13d — IPR and parameter norm (Tab 6)
  Description: Dual-axis: IPR (Fourier sparsity) increases sharply in Stage II; parameter norm shrinks in Stage III.
  Output: pXXX_grokk_avg_ipr.png
  Verification: IPR uses the corrected per-frequency magnitude formula: A_k = sqrt(c_k² + s_k²), IPR = Σ A_k⁴ / (Σ A_k²)². Parameter norms come from training_curves.json.

Fig. 14 — Memorization accuracy heatmap (Tab 6)
  Description: Three panels at the end of Stage I: (1) training data distribution under symmetry, (2) accuracy grid, (3) softmax probability at the ground truth. Red rectangles = true test pairs.
  Output: pXXX_grokk_memorization_accuracy.png
  Verification: Forward pass at the checkpoint closest to stage1_end. The symmetric architecture guarantees ~70% test accuracy during memorization.

Fig. 15 — Common-to-rare memorization (Tab 6)
  Description: Four panels: training data distribution + accuracy at 3 timepoints during Stage I. Common pairs (both (i,j) and (j,i) in train) are memorized before rare pairs (only one ordering).
  Output: pXXX_grokk_memorization_common_to_rare.png
  Verification: Epochs selected at 0, stage1_end/2, stage1_end. Red rectangles mark asymmetric training pairs.

Fig. 16 — Weight evolution during grokking (Tab 6)
  Description: 2×3 DFT heatmap grid showing θ_m and ξ_m at step 0 (random init), end of Stage I (noisy multi-frequency), and end of Stage II (clean single-frequency).
  Output: pXXX_grokk_decoded_weights_dynamic.png
  Verification: DFT coefficients W @ fourier_basis.T at 3 key epochs. The transition from diffuse to sparse confirms the sparsification narrative.

Overview (Tab 1)

Overview dashboard (Tab 1)
  Description: 2×2 grid: standard loss + grokking loss (top), standard IPR + grokking IPR (bottom), plus a phase scatter from the standard run's final checkpoint.
  Output: pXXX_overview_loss_ipr.png, pXXX_overview_phase_scatter.png, pXXX_overview.json
  Verification: Combines data from the standard and grokking runs. The phase scatter uses the same computation as Fig. 4. For $p < 19$, only the standard column is shown.

Not Currently Generated (Blog-Only Figures)

Fig. 1 — Architecture illustration
  Description: Schematic of the two-layer network, DFT decomposition, and mechanism overview.
  Status: Static illustration, not dependent on $p$. Could be included as a fixed image in the app.

Interactive JSON Data Files

In addition to static PNG plots, the pipeline generates JSON files for interactive Plotly charts in the Gradio app:

pXXX_overview.json: standard loss/IPR + grokking loss/IPR time series. Used in Tab 1 for the interactive loss and IPR charts.
pXXX_neuron_spectra.json: per-neuron Fourier magnitudes (W_in and W_out) for the top-20 neurons, sorted by frequency. Used in Tab 2 for the Neuron Inspector dropdown → bar chart of the frequency decomposition.
pXXX_logits_interactive.json: output logits for p representative (a,b) pairs, plus the correct answers. Used in Tab 4 for the Logit Explorer dropdown → bar chart with the correct answer highlighted.
pXXX_grokk_loss.json: full training/test loss curves + stage boundaries. Used in Tab 6 for the interactive loss chart with stage shading.
pXXX_grokk_acc.json: accuracy at each checkpoint epoch + stage boundaries. Used in Tab 6 for the interactive accuracy chart with stage shading.
pXXX_grokk_epoch_data.json: p×p accuracy grids at ~10 evenly spaced grokking epochs. Used in Tab 6 for the Epoch Slider → heatmap animation across training.
pXXX_metadata.json: config for all 5 runs + final metrics (loss, accuracy). Displayed in the app's info panel for the selected $p$.

Output Structure

All plots for a modulus are saved in a single flat directory. Each file is prefixed with pXXX_ so the folder is self-contained and browsable:

precomputed_results/p_023/
  # Metadata
  p023_metadata.json

  # Tab 1: Overview (Blog: summary of standard + grokking)
  p023_overview_loss_ipr.png          # 2×2 grid: loss + IPR for both setups
  p023_overview_phase_scatter.png     # Phase alignment scatter (same as Fig. 4)
  p023_overview.json                  # Interactive data

  # Tab 2: Fourier Weights (Blog: Figures 2, 3)
  p023_dft_heatmap_in.png             # DFT heatmap, W_E input layer (Fig. 2)
  p023_dft_heatmap_out.png            # DFT heatmap, W_L output layer (Fig. 2)
  p023_lineplot_in.png                # Cosine fits, input layer (Fig. 3 left)
  p023_lineplot_out.png               # Cosine fits, output layer (Fig. 3 right)
  p023_neuron_spectra.json            # Interactive: neuron inspector

  # Tab 3: Phase Analysis (Blog: Figures 4, 5, 6)
  p023_phase_distribution.png         # Polar phase plot (Fig. 5)
  p023_phase_relationship.png         # 2φ vs ψ scatter (Fig. 4)
  p023_magnitude_distribution.png     # Violin plots (Fig. 6)

  # Tab 4: Output Logits (Blog: Figure 7)
  p023_output_logits.png              # Logit heatmap (Fig. 7)
  p023_logits_interactive.json        # Interactive: logit explorer

  # Tab 5: Lottery Mechanism (Blog: Figures 9, 10)
  p023_lottery_mech_magnitude.png     # Magnitude race (Fig. 9 right)
  p023_lottery_mech_phase.png         # Phase misalignment race (Fig. 9 left)
  p023_lottery_beta_contour.png       # Contour plot (Fig. 10)

  # Tab 6: Grokking (Blog: Figures 13, 14, 15, 16)
  p023_grokk_loss.json                # Interactive loss curves (Fig. 13a)
  p023_grokk_acc.json                 # Interactive accuracy curves (Fig. 13b)
  p023_grokk_abs_phase_diff.png       # Phase alignment progress (Fig. 13c)
  p023_grokk_avg_ipr.png              # IPR + param norms (Fig. 13d)
  p023_grokk_memorization_accuracy.png          # 3-panel heatmap (Fig. 14)
  p023_grokk_memorization_common_to_rare.png    # 4-panel sequence (Fig. 15)
  p023_grokk_decoded_weights_dynamic.png        # DFT evolution (Fig. 16)
  p023_grokk_epoch_data.json          # Interactive: epoch slider

  # Tab 7: Gradient Dynamics (Blog: Figures 8, 11, 12)
  p023_phase_align_quad.png           # Phase + magnitude, Quad (Fig. 8 left)
  p023_single_freq_quad.png           # DFT heatmap over time, Quad (Fig. 11)
  p023_phase_align_relu.png           # Phase + magnitude, ReLU (Fig. 8 right / 12b)
  p023_single_freq_relu.png           # DFT heatmap over time, ReLU (Fig. 12a)

  # Tab 8: Decoupled Simulation (no blog figure number)
  p023_phase_align_approx1.png        # ODE simulation case 1
  p023_phase_align_approx2.png        # ODE simulation case 2

The listing above totals 31 files per $p$: 24 PNGs + 7 JSONs (counting pXXX_metadata.json). 22 PNGs and all JSONs come from trained models; 2 PNGs come from the analytical simulation. Grokking files are only produced for p ≥ 19.
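A quick way to verify a results directory is to enumerate the filenames from the listing above and diff against the folder. A sketch (filenames transcribed from the listing; grokking files will be absent for p < 19):

```python
from pathlib import Path

PNGS = [
    "overview_loss_ipr", "overview_phase_scatter",
    "dft_heatmap_in", "dft_heatmap_out", "lineplot_in", "lineplot_out",
    "phase_distribution", "phase_relationship", "magnitude_distribution",
    "output_logits",
    "lottery_mech_magnitude", "lottery_mech_phase", "lottery_beta_contour",
    "grokk_abs_phase_diff", "grokk_avg_ipr", "grokk_memorization_accuracy",
    "grokk_memorization_common_to_rare", "grokk_decoded_weights_dynamic",
    "phase_align_quad", "single_freq_quad", "phase_align_relu", "single_freq_relu",
    "phase_align_approx1", "phase_align_approx2",
]
JSONS = [
    "metadata", "overview", "neuron_spectra", "logits_interactive",
    "grokk_loss", "grokk_acc", "grokk_epoch_data",
]

def expected_files(p: int) -> list[str]:
    """All 31 filenames (24 PNGs + 7 JSONs) for one modulus."""
    return sorted([f"p{p:03d}_{s}.png" for s in PNGS]
                  + [f"p{p:03d}_{s}.json" for s in JSONS])

def missing(p: int, root: str = "precomputed_results") -> list[str]:
    """Filenames from the listing that are absent from precomputed_results/p_XXX/."""
    d = Path(root) / f"p_{p:03d}"
    return [f for f in expected_files(p) if not (d / f).exists()]

print(len(expected_files(23)))   # 31
```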


Correctness Verification

How each computation matches the paper

  1. DFT Decomposition (Figs. 2, 11, 12a, 16): We compute W @ fourier_basis.T where fourier_basis is the orthonormal DFT basis from get_fourier_basis(p). The basis has rows: [Const, cos 1, sin 1, cos 2, sin 2, ..., cos K, sin K] with K = (p-1)/2 for odd $p$. Each row is L2-normalized. This matches the standard real DFT on Z_p.

  2. Phase extraction (Figs. 4, 8, 9, 13c): For frequency k, the DFT gives coefficients (c_k, s_k) at indices (2k-1, 2k). The magnitude is α = sqrt(2/p) · sqrt(c_k² + s_k²), and the phase is φ = arctan2(-s_k, c_k). This convention matches the paper's θ_m[j] = α cos(ω_k j + φ) representation.

  3. IPR (Figs. 13d, Overview): Uses the corrected per-frequency magnitude formula: A_k = sqrt(c_k² + s_k²) (combining cos/sin pairs), then IPR = Σ A_k⁴ / (Σ A_k²)². This gives IPR → 1 for perfect single-frequency neurons, matching the paper's definition.

  4. Phase alignment (Fig. 4): The doubled-phase relationship ψ_m = 2φ_m is verified by extracting φ from W_in and ψ from W_out using the same compute_neuron() function, then plotting (2φ mod 2π) vs (ψ mod 2π).

  5. Output logits (Fig. 7): Forward pass uses the same activation function as training (ReLU for standard run). The flawed indicator structure (main diagonal + two ghost diagonals) is visible because the standard run trains to 100% accuracy with clean Fourier features.

  6. Lottery mechanism (Figs. 9, 10): Uses the quad_random run (quadratic activation, random init) which matches the theoretical setting. decode_scales_phis() extracts per-frequency magnitudes and phases at each checkpoint. The winning frequency is the one with smallest initial |D| = |2φ - ψ|.

  7. Grokking stages (Figs. 13--16): grokking_stage_detector.py identifies stage boundaries from training curves. Stage I ends when train accuracy ≈ 1.0, Stage II ends when test accuracy ≈ 1.0. Memorization heatmaps use forward pass at the closest checkpoint to stage1_end.

  8. Analytical simulation (Tab 8): Numerically integrates the four-variable ODE system from Section 5.3 of the paper. No neural network is involved — this validates the theoretical dynamics directly.
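Items 1 to 3 above can be sketched in a few lines of NumPy. This is a minimal reimplementation for illustration, assuming a length-p weight vector per neuron; the pipeline's own get_fourier_basis() and compute_neuron() are the authoritative versions:

```python
import numpy as np

def fourier_basis(p: int) -> np.ndarray:
    """Orthonormal real DFT basis on Z_p: rows [Const, cos 1, sin 1, ..., cos K, sin K]."""
    j = np.arange(p)
    rows = [np.full(p, 1.0)]
    for k in range(1, (p - 1) // 2 + 1):
        rows.append(np.cos(2 * np.pi * k * j / p))
        rows.append(np.sin(2 * np.pi * k * j / p))
    B = np.stack(rows)
    return B / np.linalg.norm(B, axis=1, keepdims=True)   # L2-normalize each row

def neuron_phase_and_mag(w: np.ndarray, k: int, p: int):
    """Magnitude and phase of frequency k in one neuron's weight vector w (length p)."""
    coeffs = fourier_basis(p) @ w
    c_k, s_k = coeffs[2 * k - 1], coeffs[2 * k]           # (cos k, sin k) coefficients
    alpha = np.sqrt(2 / p) * np.hypot(c_k, s_k)
    phi = np.arctan2(-s_k, c_k)
    return alpha, phi

def ipr(w: np.ndarray, p: int) -> float:
    """Inverse participation ratio over per-frequency magnitudes A_k."""
    coeffs = fourier_basis(p) @ w
    A = np.hypot(coeffs[1::2], coeffs[2::2])              # combine cos/sin pairs
    return float((A ** 4).sum() / (A ** 2).sum() ** 2)

# Sanity check: a perfect single-frequency neuron has IPR = 1 and recovers (alpha, phi).
p, k, alpha0, phi0 = 23, 5, 0.8, 1.1
w = alpha0 * np.cos(2 * np.pi * k * np.arange(p) / p + phi0)
a, f = neuron_phase_and_mag(w, k, p)
print(round(ipr(w, p), 6), round(a, 6), round(f, 6))   # 1.0 0.8 1.1
```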

Why results generalize across $p$

The paper's theory is stated for general odd $p$. Key properties that scale:

  • Fourier basis: Always has (p-1)/2 non-DC frequencies for any odd $p$.
  • Phase alignment: The ψ = 2φ relationship is a consequence of the gradient dynamics, independent of p.
  • Lottery mechanism: Random initial misalignments are uniform on [0, 2π) for any p.
  • Grokking three stages: The stage structure depends on the balance of loss gradient vs. weight decay, not on p specifically (though the stage durations and test accuracy during memorization may vary).
  • Network width: d_mlp scales as O(p²) to maintain the neuron-to-frequency ratio, ensuring enough neurons per frequency for diversification.

Scripts

run_pipeline.sh: runs the complete pipeline (train + plots + analytical + verify) for a single modulus.
train_all.py: trains all 5 model configurations; saves checkpoints + training_curves.json.
generate_plots.py: loads trained models and generates all model-dependent plots (Tabs 1--7) plus interactive JSONs and metadata.
generate_analytical.py: runs gradient flow simulations to generate the theory plots (Tab 8). No model needed.
prime_config.py: configuration: moduli list, d_mlp formula, training run parameters.
neuron_selector.py: automated neuron selection for plots (replaces hard-coded indices from the notebooks).
grokking_stage_detector.py: detects memorization/transition/generalization stage boundaries from training curves.
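The boundary rule described in the verification notes (Stage I ends when train accuracy ≈ 1.0, Stage II when test accuracy ≈ 1.0) can be sketched as below; the real grokking_stage_detector.py may add smoothing or hysteresis:

```python
import numpy as np

def detect_stages(train_acc, test_acc, epochs, thresh=0.99):
    """Return (stage1_end, stage2_end) epochs from accuracy curves.

    Stage I (memorization) ends at the first epoch where train accuracy
    crosses `thresh`; Stage II (generalization) ends when test accuracy does.
    """
    train_acc, test_acc = np.asarray(train_acc), np.asarray(test_acc)
    i1 = int(np.argmax(train_acc >= thresh))
    i2 = int(np.argmax(test_acc >= thresh))
    return epochs[i1], epochs[i2]

# Toy curves: train accuracy saturates early, test accuracy "groks" much later.
epochs = list(range(0, 1000, 100))
train = [0.1, 0.5, 0.99, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
test  = [0.1, 0.2, 0.3, 0.3, 0.3, 0.4, 0.7, 0.99, 1.0, 1.0]
print(detect_stages(train, test, epochs))   # (200, 700)
```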

Analytical Simulations (No Model Needed)

generate_analytical.py produces 2 plots per $p$ by simulating gradient flow on decoupled frequency dynamics. These validate the theoretical analysis without training any model.

  • Case 1: Shows phase difference D* converging from initial conditions (φ₀=1.5, ψ₀=0.18)
  • Case 2: Different initial conditions (φ₀=-0.72, ψ₀=-2.91) showing convergence from the other side

Both cases confirm the phase alignment theorem: D = 0 is the stable attractor, while D = π is an unstable equilibrium.
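The paper's four-variable ODE system is not reproduced in this README, but the claimed attractor structure can be illustrated with a one-dimensional toy flow dD/dt = -sin(D), chosen here (not taken from the paper) because it has a stable fixed point at D = 0 and an unstable one at D = π:

```python
import numpy as np

def simulate_D(D0: float, dt: float = 0.05, steps: int = 2000) -> float:
    """Euler-integrate dD/dt = -sin(D): D = 0 is stable, D = pi is unstable."""
    D = D0
    for _ in range(steps):
        D -= np.sin(D) * dt
    return D

# Even an initial condition a hair away from the unstable point escapes to D = 0.
for D0 in (0.5, 2.5, np.pi - 1e-3):
    print(f"D0={D0:.3f} -> D={simulate_D(D0):.4f}")   # all three converge to ~0.0000
```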