Deploy app with precomputed results for p=15,23,29,31
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +41 -0
- .gitignore +18 -0
- README.md +200 -12
- hf_app/app.py +1375 -0
- hf_app/requirements.txt +9 -0
- precompute/README.md +337 -0
- precompute/__init__.py +0 -0
- precompute/generate_analytical.py +358 -0
- precompute/generate_plots.py +2192 -0
- precompute/grokking_stage_detector.py +55 -0
- precompute/neuron_selector.py +98 -0
- precompute/prime_config.py +135 -0
- precompute/run_all.sh +35 -0
- precompute/run_pipeline.sh +60 -0
- precompute/train_all.py +290 -0
- precomputed_results/p_015/p015_full_training_para_origin.png +0 -0
- precomputed_results/p_015/p015_lineplot_in.png +3 -0
- precomputed_results/p_015/p015_lineplot_out.png +3 -0
- precomputed_results/p_015/p015_logits_interactive.json +1 -0
- precomputed_results/p_015/p015_lottery_beta_contour.png +0 -0
- precomputed_results/p_015/p015_lottery_mech_magnitude.png +0 -0
- precomputed_results/p_015/p015_lottery_mech_phase.png +0 -0
- precomputed_results/p_015/p015_magnitude_distribution.png +0 -0
- precomputed_results/p_015/p015_metadata.json +82 -0
- precomputed_results/p_015/p015_neuron_spectra.json +1 -0
- precomputed_results/p_015/p015_output_logits.png +0 -0
- precomputed_results/p_015/p015_overview.json +1 -0
- precomputed_results/p_015/p015_overview_loss_ipr.png +0 -0
- precomputed_results/p_015/p015_overview_phase_scatter.png +0 -0
- precomputed_results/p_015/p015_phase_align_approx1.png +3 -0
- precomputed_results/p_015/p015_phase_align_approx2.png +3 -0
- precomputed_results/p_015/p015_phase_align_quad.png +0 -0
- precomputed_results/p_015/p015_phase_align_relu.png +0 -0
- precomputed_results/p_015/p015_phase_distribution.png +0 -0
- precomputed_results/p_015/p015_phase_relationship.png +0 -0
- precomputed_results/p_015/p015_single_freq_quad.png +3 -0
- precomputed_results/p_015/p015_single_freq_relu.png +3 -0
- precomputed_results/p_015/p015_training_log.json +0 -0
- precomputed_results/p_023/p023_full_training_para_origin.png +3 -0
- precomputed_results/p_023/p023_grokk_abs_phase_diff.png +0 -0
- precomputed_results/p_023/p023_grokk_acc.json +1 -0
- precomputed_results/p_023/p023_grokk_acc.png +0 -0
- precomputed_results/p_023/p023_grokk_avg_ipr.png +0 -0
- precomputed_results/p_023/p023_grokk_decoded_weights_dynamic.png +3 -0
- precomputed_results/p_023/p023_grokk_epoch_data.json +1 -0
- precomputed_results/p_023/p023_grokk_loss.json +0 -0
- precomputed_results/p_023/p023_grokk_loss.png +0 -0
- precomputed_results/p_023/p023_grokk_memorization_accuracy.png +0 -0
- precomputed_results/p_023/p023_grokk_memorization_common_to_rare.png +3 -0
- precomputed_results/p_023/p023_lineplot_in.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,44 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
precomputed_results/p_015/p015_lineplot_in.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
precomputed_results/p_015/p015_lineplot_out.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
precomputed_results/p_015/p015_phase_align_approx1.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
precomputed_results/p_015/p015_phase_align_approx2.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
precomputed_results/p_015/p015_single_freq_quad.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
precomputed_results/p_015/p015_single_freq_relu.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
precomputed_results/p_023/p023_full_training_para_origin.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
precomputed_results/p_023/p023_grokk_decoded_weights_dynamic.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
precomputed_results/p_023/p023_grokk_memorization_common_to_rare.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
precomputed_results/p_023/p023_lineplot_in.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
precomputed_results/p_023/p023_lineplot_out.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
precomputed_results/p_023/p023_output_logits.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
precomputed_results/p_023/p023_overview_loss_ipr.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
precomputed_results/p_023/p023_phase_align_approx1.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
precomputed_results/p_023/p023_phase_align_approx2.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
precomputed_results/p_023/p023_single_freq_quad.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
precomputed_results/p_023/p023_single_freq_relu.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
precomputed_results/p_029/p029_full_training_para_origin.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
precomputed_results/p_029/p029_grokk_decoded_weights_dynamic.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
precomputed_results/p_029/p029_grokk_memorization_accuracy.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
precomputed_results/p_029/p029_grokk_memorization_common_to_rare.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
precomputed_results/p_029/p029_lineplot_in.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
precomputed_results/p_029/p029_lineplot_out.png filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
precomputed_results/p_029/p029_output_logits.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
precomputed_results/p_029/p029_overview_loss_ipr.png filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
precomputed_results/p_029/p029_phase_align_approx1.png filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
precomputed_results/p_029/p029_phase_align_approx2.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
precomputed_results/p_029/p029_single_freq_quad.png filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
precomputed_results/p_029/p029_single_freq_relu.png filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
precomputed_results/p_031/p031_full_training_para_origin.png filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
precomputed_results/p_031/p031_grokk_decoded_weights_dynamic.png filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
precomputed_results/p_031/p031_grokk_memorization_accuracy.png filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
precomputed_results/p_031/p031_grokk_memorization_common_to_rare.png filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
precomputed_results/p_031/p031_lineplot_in.png filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
precomputed_results/p_031/p031_lineplot_out.png filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
precomputed_results/p_031/p031_output_logits.png filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
precomputed_results/p_031/p031_overview_loss_ipr.png filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
precomputed_results/p_031/p031_phase_align_approx1.png filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
precomputed_results/p_031/p031_phase_align_approx2.png filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
precomputed_results/p_031/p031_single_freq_quad.png filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
precomputed_results/p_031/p031_single_freq_relu.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src/wandb/
|
| 2 |
+
notebooks/simulate_dynamics.ipynb
|
| 3 |
+
|
| 4 |
+
# Claude AI files
|
| 5 |
+
.claude/
|
| 6 |
+
|
| 7 |
+
# Python cache files
|
| 8 |
+
__pycache__/
|
| 9 |
+
*.py[cod]
|
| 10 |
+
*$py.class
|
| 11 |
+
|
| 12 |
+
# Model checkpoints (too large for git; regenerate with precompute/run_pipeline.sh)
|
| 13 |
+
trained_models/
|
| 14 |
+
saved_models/
|
| 15 |
+
|
| 16 |
+
# OS files
|
| 17 |
+
.DS_Store
|
| 18 |
+
tmp/
|
README.md
CHANGED
|
@@ -1,14 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
title: Modular Addition Feature Learning
|
| 3 |
-
emoji: 😻
|
| 4 |
-
colorFrom: green
|
| 5 |
-
colorTo: green
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 6.6.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: mit
|
| 11 |
-
short_description: Interactive Demo of Paper on Modular Addition
|
| 12 |
-
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# On the Mechanism and Dynamics of Modular Addition
|
| 2 |
+
|
| 3 |
+
### Fourier Features, Lottery Ticket, and Grokking
|
| 4 |
+
|
| 5 |
+
**Jianliang He, Leda Wang, Siyu Chen, Zhuoran Yang**
|
| 6 |
+
*Department of Statistics and Data Science, Yale University*
|
| 7 |
+
|
| 8 |
+
[[arXiv (coming soon)](#)] [[Blog (coming soon)](#)] [[Interactive Demo](https://huggingface.co/spaces/y-agent/modular-addition-feature-learning)]
|
| 9 |
+
|
| 10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
## Overview
|
| 13 |
+
|
| 14 |
+
This repository provides the code for studying how a two-layer neural network learns modular arithmetic $f(x,y) = (x+y) \bmod p$. We analyze three phenomena:
|
| 15 |
+
|
| 16 |
+
1. **Fourier Feature Learning** — Each neuron independently discovers a cosine wave at a single frequency, collectively implementing a discrete Fourier transform that the network was never taught.
|
| 17 |
+
2. **Lottery Ticket Dynamics** — Random initialization determines which frequency each neuron will specialize in: the frequency with the best initial phase alignment wins a winner-take-all competition.
|
| 18 |
+
3. **Grokking** — Under partial data with weight decay, the network first memorizes, then suddenly generalizes through a three-stage process: memorization → sparsification → cleanup.
|
| 19 |
+
|
| 20 |
+
## Interactive Demo
|
| 21 |
+
|
| 22 |
+
An interactive Gradio app visualizes all results with math explanations and interactive Plotly charts:
|
| 23 |
+
|
| 24 |
+
- **9 analysis tabs** covering mechanism, dynamics, grokking, and analytical simulations
|
| 25 |
+
- **Interactive features**: neuron frequency inspector, logit explorer, grokking epoch slider
|
| 26 |
+
- **On-demand training**: generate results for any odd $p$ with $3 \leq p \leq 97$ directly from the app (the upper limit is `MAX_P_ON_DEMAND` in `hf_app/app.py`)
|
| 27 |
+
- **Pre-computed examples** included for $p = 15, 23, 29, 31$
|
| 28 |
+
|
| 29 |
+
### Launch Locally
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
pip install -r requirements.txt
|
| 33 |
+
python hf_app/app.py
|
| 34 |
+
# Opens at http://localhost:7860
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
### Deploy to Hugging Face Spaces
|
| 38 |
+
|
| 39 |
+
1. Create a new Space at [huggingface.co/new-space](https://huggingface.co/new-space) (SDK: Gradio)
|
| 40 |
+
2. Push the repo:
|
| 41 |
+
```bash
|
| 42 |
+
git remote add hf https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
|
| 43 |
+
git push hf main
|
| 44 |
+
```
|
| 45 |
+
3. The app reads from `precomputed_results/` — the included examples (p=15, 23, 29, 31) work out of the box
|
| 46 |
+
4. Users can generate results for additional $p$ values on demand via the "Generate" button. New results are auto-committed back to the Space repo so they persist.
|
| 47 |
+
|
| 48 |
+
> **Tip:** For GPU-accelerated on-demand training, select a GPU runtime in your Space settings.
|
| 49 |
+
|
| 50 |
+
## Pre-computation Pipeline
|
| 51 |
+
|
| 52 |
+
The `precompute/` directory trains 5 model configurations per modulus and generates all plots + interactive JSON data. See [`precompute/README.md`](precompute/README.md) for full documentation.
|
| 53 |
+
|
| 54 |
+
### Quick Start
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
# Full pipeline for a single modulus (train → plots → analytical → verify)
|
| 58 |
+
bash precompute/run_pipeline.sh 23
|
| 59 |
+
|
| 60 |
+
# With custom d_mlp
|
| 61 |
+
bash precompute/run_pipeline.sh 23 --d_mlp 128
|
| 62 |
+
|
| 63 |
+
# Delete checkpoints after generating plots (saves disk space)
|
| 64 |
+
CLEANUP=1 bash precompute/run_pipeline.sh 23
|
| 65 |
+
|
| 66 |
+
# Batch: all odd p in [3, 99]
|
| 67 |
+
bash precompute/run_all.sh
|
| 68 |
+
|
| 69 |
+
# Or up to p=199
|
| 70 |
+
MAX_P=199 bash precompute/run_all.sh
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
### Manual Steps
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
# Step 1: Train all 5 configurations
|
| 77 |
+
python precompute/train_all.py --p 23 --output ./trained_models --resume
|
| 78 |
+
|
| 79 |
+
# Step 2: Generate model-based plots (21 PNGs + 7 JSONs)
|
| 80 |
+
python precompute/generate_plots.py --p 23 --input ./trained_models --output ./precomputed_results
|
| 81 |
+
|
| 82 |
+
# Step 3: Generate analytical simulation plots (2 PNGs, no model needed)
|
| 83 |
+
python precompute/generate_analytical.py --p 23 --output ./precomputed_results
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### Output
|
| 87 |
+
|
| 88 |
+
Each modulus produces ~33 files in `precomputed_results/p_XXX/`:
|
| 89 |
+
|
| 90 |
+
| Category | Files | Description |
|
| 91 |
+
|----------|-------|-------------|
|
| 92 |
+
| Overview (Tab 1) | 2 PNGs + 1 JSON | Loss, IPR, phase scatter |
|
| 93 |
+
| Fourier Weights (Tab 2) | 3 PNGs + 1 JSON | DFT heatmaps, cosine fits, neuron spectra |
|
| 94 |
+
| Phase Analysis (Tab 3) | 3 PNGs | Phase distribution, alignment, magnitudes |
|
| 95 |
+
| Output Logits (Tab 4) | 1 PNG + 1 JSON | Logit heatmap, interactive explorer |
|
| 96 |
+
| Lottery Mechanism (Tab 5) | 3 PNGs | Magnitude race, phase convergence, contour |
|
| 97 |
+
| Grokking (Tab 6) | 5 PNGs + 3 JSONs | Loss/acc curves, memorization, weight evolution |
|
| 98 |
+
| Gradient Dynamics (Tab 7) | 4 PNGs | Phase alignment + DFT for Quad and ReLU |
|
| 99 |
+
| Decoupled Simulation (Tab 8) | 2 PNGs | Analytical ODE integration |
|
| 100 |
+
| Metadata | 2 JSONs | Config + training log |
|
| 101 |
+
|
| 102 |
+
> **Note:** Grokking results (Tab 6) require $p \geq 19$. Smaller values of $p$ have too few data points for a meaningful train/test split.
|
| 103 |
+
|
| 104 |
+
## The 5 Training Configurations
|
| 105 |
+
|
| 106 |
+
| Config | Activation | Optimizer | LR | Weight Decay | Data | Epochs | Used In |
|
| 107 |
+
|--------|-----------|-----------|-----|-------------|------|--------|---------|
|
| 108 |
+
| `standard` | ReLU | AdamW | 5e-5 | 0 | 100% | 5,000 | Tabs 1–4 |
|
| 109 |
+
| `grokking` | ReLU | AdamW | 1e-4 | 2.0 | 75% | 50,000 | Tabs 1, 6 |
|
| 110 |
+
| `quad_random` | Quad | AdamW | 5e-5 | 0 | 100% | 5,000 | Tab 5 |
|
| 111 |
+
| `quad_single_freq` | Quad | SGD | 0.1 | 0 | 100% | 5,000 | Tab 7 |
|
| 112 |
+
| `relu_single_freq` | ReLU | SGD | 0.01 | 0 | 100% | 5,000 | Tab 7 |
|
| 113 |
+
|
| 114 |
+
## Running a Single Experiment
|
| 115 |
+
|
| 116 |
+
For custom experiments outside the pre-computation pipeline:
|
| 117 |
+
|
| 118 |
+
```bash
|
| 119 |
+
cd src
|
| 120 |
+
|
| 121 |
+
# Train with default config (p=97, d_mlp=1024, ReLU, 5000 epochs)
|
| 122 |
+
python module_nn.py
|
| 123 |
+
|
| 124 |
+
# Train with specific parameters
|
| 125 |
+
python module_nn.py --p 23 --d_mlp 512 --num_epochs 5000 --lr 5e-5
|
| 126 |
+
|
| 127 |
+
# Dry run: see config without training
|
| 128 |
+
python module_nn.py --dry_run --p 23 --d_mlp 512
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
## Notebooks
|
| 132 |
+
|
| 133 |
+
Interactive analysis notebooks in `notebooks/`:
|
| 134 |
+
|
| 135 |
+
| Notebook | Description |
|
| 136 |
+
|----------|-------------|
|
| 137 |
+
| `empirical_insight_standard.ipynb` | Fourier weight analysis, phase distributions, output logits |
|
| 138 |
+
| `empirical_insight_grokk.ipynb` | Grokking stages, weight dynamics, IPR evolution |
|
| 139 |
+
| `lottery_mechanism.ipynb` | Neuron specialization, frequency magnitude/phase tracking |
|
| 140 |
+
| `interprete_gd_dynamics.ipynb` | Phase alignment under single-frequency initialization |
|
| 141 |
+
| `decouple_dynamics_simulation.ipynb` | Analytical gradient flow simulation |
|
| 142 |
+
|
| 143 |
+
## Setup
|
| 144 |
+
|
| 145 |
+
### Requirements
|
| 146 |
+
|
| 147 |
+
- Python 3.8+
|
| 148 |
+
- PyTorch 2.0+
|
| 149 |
+
- CUDA-capable GPU (recommended for $p > 50$; CPU works for small $p$)
|
| 150 |
+
|
| 151 |
+
### Installation
|
| 152 |
+
|
| 153 |
+
```bash
|
| 154 |
+
git clone https://github.com/Y-Agent/modular-addition-feature-learning.git
|
| 155 |
+
cd modular-addition-feature-learning
|
| 156 |
+
pip install -r requirements.txt
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
## Project Structure
|
| 160 |
+
|
| 161 |
+
```
|
| 162 |
+
modular-addition-feature-learning/
|
| 163 |
+
├── src/ # Core source code
|
| 164 |
+
│ ├── module_nn.py # Training script with CLI
|
| 165 |
+
│ ├── nnTrainer.py # Training loop and optimization
|
| 166 |
+
│ ├── model_base.py # Neural network architecture (EmbedMLP)
|
| 167 |
+
│ ├── mechanism_base.py # Fourier analysis and decomposition
|
| 168 |
+
│ ├── utils.py # Configuration and helpers
|
| 169 |
+
│ └── configs.yaml # Default hyperparameters
|
| 170 |
+
├── precompute/ # Batch training and plot generation
|
| 171 |
+
│ ├── run_pipeline.sh # Full pipeline for one modulus
|
| 172 |
+
│ ├── run_all.sh # Batch pipeline for all odd p
|
| 173 |
+
│ ├── train_all.py # Train 5 configurations
|
| 174 |
+
│ ├── generate_plots.py # Generate model-based plots + JSONs
|
| 175 |
+
│ ├── generate_analytical.py # Analytical ODE simulation plots
|
| 176 |
+
│ └── prime_config.py # Configurations and sizing formula
|
| 177 |
+
├── hf_app/ # Gradio web application
|
| 178 |
+
│ └── app.py # Interactive visualization app
|
| 179 |
+
├── precomputed_results/ # Pre-computed plots and data
|
| 180 |
+
│ ├── p_015/ # Results for p=15
|
| 181 |
+
│ ├── p_023/ # Results for p=23
|
| 182 |
+
│ ├── p_029/ # Results for p=29
|
| 183 |
+
│ └── p_031/ # Results for p=31
|
| 184 |
+
├── notebooks/ # Analysis and visualization notebooks
|
| 185 |
+
├── requirements.txt # Python dependencies
|
| 186 |
+
└── README.md
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
## Citation
|
| 190 |
+
|
| 191 |
+
```bibtex
|
| 192 |
+
@article{he2025modular,
|
| 193 |
+
title={On the Mechanism and Dynamics of Modular Addition: Fourier Features, Lottery Ticket, and Grokking},
|
| 194 |
+
author={He, Jianliang and Wang, Leda and Chen, Siyu and Yang, Zhuoran},
|
| 195 |
+
journal={arXiv preprint arXiv:XXXX.XXXXX},
|
| 196 |
+
year={2025}
|
| 197 |
+
}
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
## License
|
| 201 |
+
|
| 202 |
+
[MIT License](LICENSE)
|
hf_app/app.py
ADDED
|
@@ -0,0 +1,1375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Gradio app for Modular Addition Feature Learning visualization.
|
| 4 |
+
Serves pre-computed results for odd moduli p in [3, 199].
|
| 5 |
+
|
| 6 |
+
All results are pre-computed as PNG images and JSON data files.
|
| 7 |
+
No GPU needed at serving time.
|
| 8 |
+
|
| 9 |
+
Tab structure:
|
| 10 |
+
Core Interpretability:
|
| 11 |
+
1. Training Overview -- loss + IPR sparsity
|
| 12 |
+
2. Fourier Weights -- decoded W_in/W_out heatmaps + line plots + neuron inspector
|
| 13 |
+
3. Phase Analysis -- phase distribution, 2phi vs psi, magnitudes
|
| 14 |
+
4. Output Logits -- predicted logit heatmap + interactive logit explorer
|
| 15 |
+
5. Lottery Mechanism -- neuron specialization, magnitude/phase, contour
|
| 16 |
+
Grokking:
|
| 17 |
+
6. Grokking -- loss/acc, phase alignment, IPR, memorization, epoch slider
|
| 18 |
+
Theory:
|
| 19 |
+
7. Gradient Dynamics -- phase alignment for Quad & ReLU single-freq init
|
| 20 |
+
8. Decoupled Simulation -- analytical gradient flow (no model needed)
|
| 21 |
+
Diagnostics:
|
| 22 |
+
9. Training Log -- per-run hyperparameters and epoch-by-epoch metrics
|
| 23 |
+
"""
|
| 24 |
+
import gradio as gr
|
| 25 |
+
import json
|
| 26 |
+
import logging
|
| 27 |
+
import os
|
| 28 |
+
import shutil
|
| 29 |
+
import subprocess
|
| 30 |
+
import sys
|
| 31 |
+
|
| 32 |
+
import numpy as np
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
# Force pandas to be fully imported before plotly lazily imports it
|
| 36 |
+
# (avoids "partially initialized module 'pandas'" in threaded callbacks)
|
| 37 |
+
import pandas # noqa: F401
|
| 38 |
+
import plotly.graph_objects as go
|
| 39 |
+
|
| 40 |
+
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 41 |
+
RESULTS_DIR = os.path.join(PROJECT_ROOT, "precomputed_results")
|
| 42 |
+
TRAINED_MODELS_DIR = os.path.join(PROJECT_ROOT, "trained_models")
|
| 43 |
+
|
| 44 |
+
# Max p for on-demand training (d_mlp grows as O(p^2), memory limit)
|
| 45 |
+
MAX_P_ON_DEMAND = 97
|
| 46 |
+
|
| 47 |
+
COLORS = ['#0D2758', '#60656F', '#DEA54B', '#A32015', '#347186']
|
| 48 |
+
STAGE_COLORS = ['rgba(212,175,55,0.15)', 'rgba(139,115,85,0.15)', 'rgba(192,192,192,0.15)']
|
| 49 |
+
|
| 50 |
+
# KaTeX delimiters for Gradio Markdown
|
| 51 |
+
LATEX_DELIMITERS = [
|
| 52 |
+
{"left": "$$", "right": "$$", "display": True},
|
| 53 |
+
{"left": "$", "right": "$", "display": False},
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
# Custom CSS for Palatino font and styling.
# - Imports "Libre Baskerville" from Google Fonts; it appears in the serif
#   font stacks after the Palatino variants.
# - Code-like elements get a monospace stack and KaTeX keeps its own fonts, so
#   the global "*" rule does not override them.
# - h1/h3 are centered; blockquotes get an accent border.
CUSTOM_CSS = r"""
@import url('https://fonts.googleapis.com/css2?family=Libre+Baskerville:ital,wght@0,400;0,700;1,400&display=swap');

* {
font-family: "Palatino Linotype", "Book Antiqua", Palatino, "Libre Baskerville", Georgia, serif !important;
}
code, pre, .code, .monospace {
font-family: "Menlo", "Consolas", "Monaco", monospace !important;
}
.katex, .katex * {
font-family: KaTeX_Main, "Times New Roman", serif !important;
}
h1 {
font-family: "Palatino Linotype", "Book Antiqua", Palatino, "Libre Baskerville", Georgia, serif !important;
text-align: center !important;
margin-bottom: 0.1em !important;
}
h3 {
font-family: "Palatino Linotype", "Book Antiqua", Palatino, "Libre Baskerville", Georgia, serif !important;
text-align: center !important;
color: var(--neutral-500) !important;
font-weight: normal !important;
margin-top: 0 !important;
}
h2, h4 {
font-family: "Palatino Linotype", "Book Antiqua", Palatino, "Libre Baskerville", Georgia, serif !important;
}
blockquote {
border-left: 3px solid var(--color-accent) !important;
background-color: var(--block-background-fill) !important;
padding: 0.5em 1em !important;
margin: 0.5em 0 !important;
}
"""
|
| 91 |
+
|
| 92 |
+
# ---------------------------------------------------------------------------
|
| 93 |
+
# Math explanation text for each tab (following the paper precisely)
|
| 94 |
+
# ---------------------------------------------------------------------------
|
| 95 |
+
|
| 96 |
+
MATH_TAB1 = r"""
|
| 97 |
+
### Overview
|
| 98 |
+
|
| 99 |
+
We study how a two-layer neural network learns to compute modular addition $f(x,y) = (x+y) \bmod p$. The network has $M$ hidden neurons. Each input integer $x$ is represented as a one-hot vector, and the network produces a score for each of the $p$ possible answers. During training, the network learns two weight vectors per neuron: an **input weight** $\theta_m$ and an **output weight** $\xi_m$, both vectors of length $p$.
|
| 100 |
+
|
| 101 |
+
#### Two Training Setups
|
| 102 |
+
|
| 103 |
+
1. **Full-data (Tabs 1--5, 7).** Train on all $p^2$ input pairs with no held-out data and no regularization. This produces clean features ideal for studying what the network learns and how.
|
| 104 |
+
|
| 105 |
+
2. **Grokking (Tab 6).** Train on only 75% of input pairs with weight decay $\lambda = 2.0$ (a penalty that shrinks weights over time). These two ingredients -- incomplete data + weight decay -- cause the network to first memorize, then suddenly generalize, a phenomenon called **grokking**.
|
| 106 |
+
|
| 107 |
+
#### What the Network Learns
|
| 108 |
+
|
| 109 |
+
Each neuron's weight vectors turn into **cosine waves** at a single frequency -- the network independently rediscovers the Discrete Fourier Transform. The neurons collectively cover all frequencies with balanced strengths, enabling them to "vote" together and identify the correct answer $(x+y) \bmod p$.
|
| 110 |
+
|
| 111 |
+
#### How It Learns (Dynamics)
|
| 112 |
+
|
| 113 |
+
Frequencies **compete** within each neuron during training. The frequency whose input and output phases happen to start best-aligned grows fastest -- a **lottery ticket mechanism** where the random initialization determines the outcome before training begins.
|
| 114 |
+
|
| 115 |
+
#### Grokking (Three Stages)
|
| 116 |
+
|
| 117 |
+
When trained on partial data with weight decay: **(I) Memorization** -- the network fits the training data using noisy, multi-frequency features. **(II) Generalization** -- weight decay prunes away the noise, leaving clean single-frequency features; test accuracy jumps. **(III) Cleanup** -- weight decay slowly polishes the solution.
|
| 118 |
+
|
| 119 |
+
#### Progress Measures on These Plots
|
| 120 |
+
|
| 121 |
+
- **Loss**: Cross-entropy loss (lower = better predictions). We show both training loss and test loss.
|
| 122 |
+
|
| 123 |
+
- **IPR (Inverse Participation Ratio)**: Measures how concentrated a neuron's energy is across frequencies. We decompose each neuron's weights into Fourier components, measure the strength $A_k$ at each frequency $k$, and compute:
|
| 124 |
+
|
| 125 |
+
$$\text{IPR} = \frac{\sum_k A_k^4}{\left(\sum_k A_k^2\right)^2}.$$
|
| 126 |
+
|
| 127 |
+
When a neuron uses only **one frequency**, IPR $= 1$ (fully specialized). When energy is spread across **many frequencies**, IPR is close to $0$. Watching IPR rise toward 1 during training shows the network specializing.
|
| 128 |
+
|
| 129 |
+
- **Phase scatter**: Each neuron has an input phase $\phi_m$ and output phase $\psi_m$. The theory predicts the output phase equals twice the input phase ($\psi_m = 2\phi_m$). The scatter plot checks this: all points should fall on the diagonal.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
MATH_TAB2 = r"""
|
| 133 |
+
### Every Neuron is a Cosine Wave
|
| 134 |
+
> **Setup:** ReLU activation, full data, no weight decay.
|
| 135 |
+
|
| 136 |
+
After training, each neuron's weight vectors become clean **cosine waves** at a single frequency. Concretely, the input weight of neuron $m$ looks like:
|
| 137 |
+
|
| 138 |
+
$$\underbrace{\theta_m[j]}_{\text{input weight at position } j} = \underbrace{\alpha_m}_{\text{input magnitude}} \cdot \cos\!\left(\underbrace{\frac{2\pi k}{p}}_{\text{frequency}} \cdot j + \underbrace{\phi_m}_{\text{input phase}}\right),$$
|
| 139 |
+
|
| 140 |
+
and the output weight has the same form with its own magnitude $\beta_m$ (output magnitude) and phase $\psi_m$ (output phase). Each neuron picks **one frequency** $k$ out of the $(p{-}1)/2$ possible frequencies. No one told the network about Fourier analysis -- it rediscovered this representation on its own through training.
|
| 141 |
+
|
| 142 |
+
**Heatmap**: Each row is a neuron, each column is a Fourier component (cosine and sine at each frequency). If a row has only one bright cell, that neuron is using a single frequency -- and that's exactly what we see.
|
| 143 |
+
|
| 144 |
+
**Line Plots**: The dots are the actual learned weights; the dashed curves are best-fit cosines. The near-perfect fits confirm each neuron is well-described by a single cosine at a single frequency.
|
| 145 |
+
|
| 146 |
+
**Neuron Inspector**: Select a neuron from the dropdown to see how its energy is distributed across all frequencies (for both input and output weights).
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
MATH_TAB3 = r"""
|
| 150 |
+
### Phase Alignment and Collective Diversification
|
| 151 |
+
> **Setup:** ReLU activation, full data, no weight decay.
|
| 152 |
+
|
| 153 |
+
#### The Input and Output Phases Lock Together
|
| 154 |
+
|
| 155 |
+
Each neuron has an input phase $\phi_m$ and an output phase $\psi_m$ (the "shift" of each cosine wave). These are not independent -- training drives them into a precise relationship:
|
| 156 |
+
|
| 157 |
+
$$\underbrace{\psi_m}_{\text{output phase}} = 2 \times \underbrace{\phi_m}_{\text{input phase}}.$$
|
| 158 |
+
|
| 159 |
+
**Why "doubled"?** The activation function squares (or, for ReLU, roughly squares) the sum of two cosines. Squaring a cosine at phase $\phi$ naturally produces terms at phase $2\phi$. The output layer learns to match this by setting its own phase to $2\phi$, so the two layers work together coherently.
|
| 160 |
+
|
| 161 |
+
The **scatter plot** checks this: we plot $2\phi_m$ (horizontal) vs. $\psi_m$ (vertical) for every neuron. If the relationship holds, all points land on the diagonal. This relationship is not built into the architecture -- it **emerges from training** (see Tab 7 for why).
|
| 162 |
+
|
| 163 |
+
#### Neurons Organize Themselves into a Balanced Ensemble
|
| 164 |
+
|
| 165 |
+
The neurons don't just specialize to single frequencies -- they also organize *collectively*:
|
| 166 |
+
|
| 167 |
+
1. **Frequency balance:** Every frequency gets roughly the same number of neurons.
|
| 168 |
+
2. **Phase spread:** Within each frequency group, the phases are spread uniformly around the circle. This is what enables **noise cancellation** -- the random noise from individual neurons averages out when their phases are evenly spaced.
|
| 169 |
+
3. **Magnitude balance:** All neurons contribute roughly equally to the output (no single neuron dominates).
|
| 170 |
+
|
| 171 |
+
The **polar plot** shows phases at multiples ($1\times, 2\times, 3\times, 4\times$) on concentric rings -- uniform spread confirms the cancellation condition. The **violin plots** show the distribution of input magnitudes ($\alpha$) and output magnitudes ($\beta$) -- tight concentration confirms magnitude balance.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
MATH_TAB4 = r"""
|
| 175 |
+
### The Mechanism: Majority Voting in Fourier Space
|
| 176 |
+
> **Setup:** ReLU activation, full data, no weight decay.
|
| 177 |
+
|
| 178 |
+
#### How Neurons Vote for the Correct Answer
|
| 179 |
+
|
| 180 |
+
Each neuron produces a score for every possible output $j \in \{0, 1, \ldots, p{-}1\}$. Thanks to the phase alignment ($\psi = 2\phi$, see Tab 3), each neuron's score has a **signal** component that peaks at the correct answer $j = (x+y) \bmod p$, plus **noise** that depends on that neuron's particular phase.
|
| 181 |
+
|
| 182 |
+
When we sum over many neurons within a frequency group, the signal adds up (every neuron agrees on the right answer) while the noise cancels out (different neurons have different phases, and the phase spread from Tab 3 ensures the noise averages to zero). This is **majority voting** -- each neuron casts a noisy vote, but the consensus is correct.
|
| 183 |
+
|
| 184 |
+
#### The "Flawed Indicator"
|
| 185 |
+
|
| 186 |
+
After summing over all neurons and all frequency groups, the network's output simplifies to:
|
| 187 |
+
|
| 188 |
+
$$\text{score for answer } j \;\propto\; \underbrace{\frac{p}{2} \cdot \mathbf{1}[j = (x{+}y) \bmod p]}_{\text{correct answer (strongest)}} \;+\; \underbrace{\frac{p}{4} \cdot \bigl(\mathbf{1}[j = 2x \bmod p] + \mathbf{1}[j = 2y \bmod p]\bigr)}_{\text{two "ghost" peaks (half strength)}}.$$
|
| 189 |
+
|
| 190 |
+
The correct answer gets score $p/2$, but two **spurious ghost peaks** appear at $2x \bmod p$ and $2y \bmod p$ with score $p/4$. The correct answer always wins because $p/2 > p/4$, so the network always predicts correctly despite the ghosts.
|
| 191 |
+
|
| 192 |
+
**Heatmap**: The network's output scores for all inputs with $x = 0$. The bright diagonal is the correct answer. The faint lines are the ghost peaks.
|
| 193 |
+
|
| 194 |
+
**Logit Explorer**: Pick an input pair $(x, y)$ to see the full score distribution. The correct answer (highlighted) should be the tallest bar.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
MATH_TAB5 = r"""
|
| 198 |
+
### The Lottery Ticket: How Each Neuron Picks Its Frequency
|
| 199 |
+
> **Setup:** Quadratic activation ($\sigma(x) = x^2$), full data, random initialization.
|
| 200 |
+
|
| 201 |
+
#### The Competition
|
| 202 |
+
|
| 203 |
+
At the start of training, every neuron has a tiny bit of energy at **every** frequency -- nothing is specialized yet. But the input and output phases at each frequency start at random values, so some frequencies happen to be better aligned (input phase and output phase closer to the $\psi = 2\phi$ relationship) than others.
|
| 204 |
+
|
| 205 |
+
The key insight: **a frequency grows faster when its phases are better aligned.** The growth rate of a frequency's magnitude depends on how close it is to alignment:
|
| 206 |
+
|
| 207 |
+
$$\text{growth rate} \;\propto\; \cos(\underbrace{2\phi - \psi}_{\text{phase misalignment }\mathcal{D}}).$$
|
| 208 |
+
|
| 209 |
+
When the misalignment $\mathcal{D}$ is small (phases nearly aligned), $\cos(\mathcal{D}) \approx 1$ and the frequency grows quickly. When $\mathcal{D}$ is large, growth stalls.
|
| 210 |
+
|
| 211 |
+
#### Winner Takes All
|
| 212 |
+
|
| 213 |
+
This creates a **positive feedback loop**: the best-aligned frequency grows a little, which helps it align even better, which makes it grow even faster. The gap compounds exponentially until one frequency completely dominates -- **the winner takes all.**
|
| 214 |
+
|
| 215 |
+
The winning frequency is simply the one that started closest to alignment:
|
| 216 |
+
|
| 217 |
+
$$\text{winning frequency} = \text{the } k \text{ with smallest initial misalignment } |\mathcal{D}_m^k|.$$
|
| 218 |
+
|
| 219 |
+
This is a **lottery ticket**: the outcome is determined by the random initialization before training even begins. Since each neuron draws independent random phases, different neurons pick different winning frequencies, naturally producing the balanced frequency coverage seen in Tab 3.
|
| 220 |
+
|
| 221 |
+
**Phase plot:** Shows how the misalignment $\mathcal{D}$ evolves over training for each frequency within one neuron. The winner (red) converges to zero first; the others barely move.
|
| 222 |
+
|
| 223 |
+
**Magnitude plot:** Shows how the output magnitude $\beta$ (strength of each frequency) evolves. All start equal. Once the winner aligns, it grows explosively while the others stay frozen.
|
| 224 |
+
|
| 225 |
+
**Contour plot:** Final magnitude as a function of (initial magnitude, initial misalignment). Largest values appear at small misalignment -- confirming that alignment determines the winner.
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
MATH_TAB6 = r"""
|
| 229 |
+
### Grokking: From Memorization to Generalization
|
| 230 |
+
> **Setup:** ReLU activation, 75% training fraction, weight decay $\lambda = 2.0$.
|
| 231 |
+
|
| 232 |
+
Under the train-test split setup, the network quickly memorizes the training set but takes much longer to generalize. Our analysis reveals grokking is a **three-stage process**, each driven by a different balance of forces.
|
| 233 |
+
|
| 234 |
+
**Stage I -- Memorization (loss gradient dominates).** The loss gradient dominates and the network rapidly memorizes training data. Training accuracy reaches 100% while test accuracy reaches only ~70%. The ~70% figure (not ~50%) arises because the architecture is symmetric in $x$ and $y$: since $\theta_m[x] + \theta_m[y]$ is invariant under swapping $(x,y) \leftrightarrow (y,x)$, memorizing $(x,y)$ automatically gives the correct answer for $(y,x)$. The lottery mechanism runs on incomplete data, producing a "noisy" multi-frequency representation. We also observe a **common-to-rare ordering**: the network first memorizes symmetric pairs (both $(i,j)$ and $(j,i)$ in training) while actively *suppressing* rare pairs, before eventually memorizing them too.
|
| 235 |
+
|
| 236 |
+
**Stage II -- Fast Generalization (loss + weight decay).** Weight decay penalizes all magnitudes equally, but the dominant frequency has much larger magnitude and can "afford" the penalty. Non-feature frequencies are driven to zero -- a **sparsification** effect visible as the sharp IPR increase. This transforms the noisy memorization solution into clean single-frequency-per-neuron features. Test accuracy jumps steeply.
|
| 237 |
+
|
| 238 |
+
**Stage III -- Slow Cleanup (weight decay dominates).** The loss gradient becomes negligible (both losses $\approx 0$). Weight decay alone slowly shrinks norms at rate $\partial_t \|w\| = -\lambda \|w\|$. The feature frequencies are already identified; this stage fine-tunes magnitudes. The network transitions from a lookup table to a generalizing algorithm implementing the indicator function from the mechanism (Tab 4).
|
| 239 |
+
|
| 240 |
+
**Four progress measures**: (a) Loss -- train drops in Stage I, test drops in Stage II. (b) Accuracy -- train reaches 100% early, test jumps in Stage II. (c) Phase alignment -- $|\sin(\mathcal{D}_m^\star)|$ decreases throughout. (d) IPR + parameter norms -- IPR increases sharply in Stage II, norms shrink in Stage III.
|
| 241 |
+
|
| 242 |
+
**Epoch Slider**: Use the slider below to see how the accuracy grid evolves across the three stages.
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
MATH_TAB7 = r"""
|
| 246 |
+
### Training Dynamics: Phase Alignment and Single-Frequency Preservation
|
| 247 |
+
> **Setup:** Quadratic and ReLU activations, full data, single-frequency initialization, SGD.
|
| 248 |
+
|
| 249 |
+
#### The Four-Variable ODE
|
| 250 |
+
|
| 251 |
+
Under small initialization ($\kappa_{\mathrm{init}} \ll 1$), the dynamics decouple: each neuron evolves independently, and within each neuron, different Fourier modes evolve independently (because $\sum_{x \in \mathbb{Z}_p} \cos(\omega_k x) \cos(\omega_\tau x) = \frac{p}{2}\delta_{k,\tau}$). The full dynamics reduce to independent four-variable ODEs per (neuron, frequency):
|
| 252 |
+
|
| 253 |
+
$$\partial_t \alpha \approx 2p \cdot \alpha \cdot \beta \cdot \cos(\mathcal{D}), \qquad \partial_t \beta \approx p \cdot \alpha^2 \cdot \cos(\mathcal{D}),$$
|
| 254 |
+
$$\partial_t \phi \approx 2p \cdot \beta \cdot \sin(\mathcal{D}), \qquad \partial_t \psi \approx -p \cdot \frac{\alpha^2}{\beta} \cdot \sin(\mathcal{D}),$$
|
| 255 |
+
|
| 256 |
+
where $\mathcal{D} = (2\phi - \psi) \bmod 2\pi$ is the **phase misalignment**. This system has a clear physical interpretation: **magnitudes grow when phases are aligned** ($\cos(\mathcal{D}) \approx 1$), and **phases rotate toward alignment** ($\sin(\mathcal{D}) \to 0$). The dynamics self-coordinate: phases align first (while magnitudes are small), then magnitudes explode.
|
| 257 |
+
|
| 258 |
+
#### Phase Alignment Theorem
|
| 259 |
+
|
| 260 |
+
$\mathcal{D}(t) \to 0$ from any initial condition except the measure-zero unstable point $\mathcal{D} = \pi$. The dynamics on the circle behave like an **overdamped pendulum**: $\mathcal{D} = 0$ is a stable attractor, $\mathcal{D} = \pi$ is an unstable repeller. This is not a coincidence or a property of initialization -- it is an **inevitable consequence of the training dynamics**. It explains Observation 2 ($\psi = 2\phi$).
|
| 261 |
+
|
| 262 |
+
#### Single-Frequency Preservation Theorem
|
| 263 |
+
|
| 264 |
+
Under the decoupled flow, if a neuron starts at a single frequency, it remains there for all time. The Fourier orthogonality on $\mathbb{Z}_p$ prevents energy from leaking between modes.
|
| 265 |
+
|
| 266 |
+
**Quadratic** (left panels): Theory matches experiment almost exactly. The DFT heatmap shows the dominant frequency growing while all others stay at zero.
|
| 267 |
+
|
| 268 |
+
**ReLU** (right panels): Same qualitative behavior with minor quantitative differences. Small energy "leaks" to harmonic multiples ($3k^\star, 5k^\star, \ldots$ for input; $2k^\star, 3k^\star, \ldots$ for output). The leakage decays as $O(r^{-2})$ where $r$ is the harmonic order (third harmonic has $1/9$ the strength, fifth has $1/25$), keeping the dominant frequency overwhelmingly dominant.
|
| 269 |
+
"""
|
| 270 |
+
|
| 271 |
+
MATH_TAB9 = r"""
|
| 272 |
+
### Training Log
|
| 273 |
+
|
| 274 |
+
This tab shows the training logs for each of the 5 configurations run for the selected modulo $p$. Select a run from the dropdown to view its hyperparameters and per-epoch training metrics.
|
| 275 |
+
|
| 276 |
+
The 5 training runs are:
|
| 277 |
+
- **standard**: ReLU, full data, no weight decay -- produces the clean Fourier features analyzed in Tabs 1--5
|
| 278 |
+
- **grokking**: ReLU, 75% data, weight decay $\lambda = 2.0$ -- demonstrates the memorization $\to$ generalization transition (Tab 6)
|
| 279 |
+
- **quad_random**: Quadratic activation, full data, random init -- the lottery ticket mechanism (Tab 5)
|
| 280 |
+
- **quad_single_freq**: Quadratic activation, single-frequency init, SGD -- verifies single-frequency preservation (Tab 7)
|
| 281 |
+
- **relu_single_freq**: ReLU, single-frequency init, SGD -- ReLU variant of the dynamics (Tab 7)
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
MATH_TAB8 = r"""
|
| 285 |
+
### Decoupled Gradient Flow Simulation
|
| 286 |
+
> **Setup:** Analytical ODE integration (no neural network training).
|
| 287 |
+
|
| 288 |
+
This tab shows a pure mathematical simulation of the multi-frequency gradient flow, **without training any neural network**. We numerically integrate the four-variable ODEs for all frequency modes simultaneously within a single neuron:
|
| 289 |
+
|
| 290 |
+
$$\partial_t \alpha_k \approx 2p \cdot \alpha_k \cdot \beta_k \cdot \cos(\mathcal{D}_k), \qquad \partial_t \beta_k \approx p \cdot \alpha_k^2 \cdot \cos(\mathcal{D}_k),$$
|
| 291 |
+
$$\partial_t \phi_k \approx 2p \cdot \beta_k \cdot \sin(\mathcal{D}_k), \qquad \partial_t \psi_k \approx -p \cdot \frac{\alpha_k^2}{\beta_k} \cdot \sin(\mathcal{D}_k),$$
|
| 292 |
+
|
| 293 |
+
for each frequency $k = 1, \ldots, (p{-}1)/2$, with random initial conditions.
|
| 294 |
+
|
| 295 |
+
The simulation confirms the theoretical predictions from Tab 7:
|
| 296 |
+
|
| 297 |
+
- **Phase alignment:** Phase misalignments $\mathcal{D}_k = (2\phi_k - \psi_k) \bmod 2\pi$ converge to $0$ for most frequencies, or linger near $\pi$ (the unstable repeller) before eventually escaping.
|
| 298 |
+
- **Magnitude competition:** Magnitudes grow explosively for the frequency where $\mathcal{D}_k \approx 0$ first, while others remain near their initial level.
|
| 299 |
+
- **Lottery outcome:** The winning frequency (smallest initial $\mathcal{D}_k$) dominates all others, reproducing the full lottery ticket mechanism without any neural network -- just ODEs.
|
| 300 |
+
|
| 301 |
+
Two cases are shown with different initial conditions to illustrate that the mechanism is robust: regardless of the random starting point, the frequency with the best initial phase alignment always wins.
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# ---------------------------------------------------------------------------
|
| 306 |
+
# Data loading helpers
|
| 307 |
+
# ---------------------------------------------------------------------------
|
| 308 |
+
|
| 309 |
+
# Smallest modulus accepted by get_available_moduli(): p=2 has 0 non-DC
# Fourier frequencies, so the analysis is degenerate.
MIN_P = 3
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def get_available_moduli():
    """Discover which moduli ``p`` have a pre-computed results directory.

    Scans ``RESULTS_DIR`` for sub-directories whose name starts with ``p_``
    and whose second ``_``-separated field parses as an integer (``p_015``
    -> 15). Only values ``p >= MIN_P`` are kept. Parity of ``p`` is not
    checked here.

    Returns:
        list[int]: Unique moduli in ascending numeric order. Empty when
        ``RESULTS_DIR`` is missing, is not a directory, or contains no
        matching sub-directory.
    """
    if not os.path.isdir(RESULTS_DIR):
        return []

    found = set()  # de-duplicates e.g. "p_015" and "p_015_old"
    for entry in os.listdir(RESULTS_DIR):
        if not entry.startswith("p_"):
            continue
        # Stray files such as "p_015.zip" are not result directories.
        if not os.path.isdir(os.path.join(RESULTS_DIR, entry)):
            continue
        try:
            p = int(entry.split("_")[1])
        except ValueError:
            # Non-numeric suffix ("p_tmp", "p_"): not a modulus folder.
            continue
        if p >= MIN_P:
            found.add(p)

    # Sort numerically: name order would put "p_1009" before "p_101".
    return sorted(found)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def _prime_dir(p):
    """Return the results folder for modulus ``p`` under ``RESULTS_DIR``.

    The folder name is ``p_`` followed by ``p`` zero-padded to three digits.
    """
    folder_name = "p_{:03d}".format(p)
    return os.path.join(RESULTS_DIR, folder_name)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def load_json_file(p, filename):
    """Load a JSON file from the modulus directory.

    The file is looked up as ``<_prime_dir(p)>/p{p:03d}_{filename}``.

    Args:
        p: Modulus (integer) selecting the results directory and file prefix.
        filename: File name without the ``pNNN_`` prefix, e.g.
            ``"overview.json"``.

    Returns:
        The decoded JSON value, or ``None`` when the file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    path = os.path.join(_prime_dir(p), f"p{p:03d}_{filename}")
    try:
        # Open directly instead of exists()-then-open (no race), and read as
        # UTF-8 explicitly so decoding does not depend on the locale default.
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return None
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def safe_img(p, filename):
    """Return the path of a pre-rendered image for modulus ``p``, or None.

    The image is expected at ``<_prime_dir(p)>/p{p:03d}_{filename}``; ``None``
    is returned when nothing exists at that path (Gradio handles None
    gracefully).
    """
    candidate = os.path.join(_prime_dir(p), f"p{p:03d}_{filename}")
    if not os.path.exists(candidate):
        return None
    return candidate
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# ---------------------------------------------------------------------------
|
| 347 |
+
# Interactive Plotly chart builders
|
| 348 |
+
# ---------------------------------------------------------------------------
|
| 349 |
+
|
| 350 |
+
def _to_np(v):
|
| 351 |
+
"""Convert a list/value to a numpy array (bypasses plotly's pandas check)."""
|
| 352 |
+
if v is None:
|
| 353 |
+
return None
|
| 354 |
+
return np.asarray(v)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def make_loss_chart(data, title="Training Loss"):
    """Build an interactive Plotly loss chart from JSON data.

    Args:
        data: Mapping that may contain ``'train_losses'`` and
            ``'test_losses'`` (sequences of per-epoch losses) and the
            optional stage boundaries ``'stage1_end'`` / ``'stage2_end'``.
            ``None`` means "no data".
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when ``data`` is None.
    """
    if data is None:
        return None

    # A missing 'train_losses' key yields an empty trace rather than a
    # KeyError, matching the tolerant length computation.
    train_losses = data.get('train_losses', [])
    n = len(train_losses)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=np.arange(n), y=_to_np(train_losses),
        name='Train Loss', line=dict(color=COLORS[0]),
    ))

    test_losses = data.get('test_losses')
    if test_losses is not None:
        # Index the test curve by its own length so a train/test length
        # mismatch cannot produce x and y of different sizes.
        fig.add_trace(go.Scatter(
            x=np.arange(len(test_losses)), y=_to_np(test_losses),
            name='Test Loss', line=dict(color=COLORS[3]),
        ))

    # Shade the training stages when their boundaries are known.
    # NOTE(review): the Grokking tab text names the stages Memorization /
    # Generalization / Cleanup; the labels below differ -- confirm intent.
    s1 = data.get('stage1_end')
    s2 = data.get('stage2_end')
    if s1 is not None:
        fig.add_vrect(x0=0, x1=s1, fillcolor=STAGE_COLORS[0],
                      line_width=0, annotation_text="Memorization",
                      annotation_position="top left")
    if s1 is not None and s2 is not None:
        fig.add_vrect(x0=s1, x1=s2, fillcolor=STAGE_COLORS[1],
                      line_width=0, annotation_text="Transition",
                      annotation_position="top left")
    if s2 is not None:
        fig.add_vrect(x0=s2, x1=n, fillcolor=STAGE_COLORS[2],
                      line_width=0, annotation_text="Generalization",
                      annotation_position="top left")

    fig.update_layout(
        title=title, xaxis_title='Epoch', yaxis_title='Loss',
        template='plotly_white', height=400,
        legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
    )
    return fig
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def make_acc_chart(data, title="Training Accuracy"):
    """Build an interactive Plotly accuracy chart.

    Args:
        data: Mapping that may contain ``'train_accs'`` and ``'test_accs'``
            (sequences of accuracies), ``'epochs'`` (x values; defaults to
            ``0..len(train_accs)-1``) and the optional stage boundaries
            ``'stage1_end'`` / ``'stage2_end'``. ``None`` means "no data".
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when ``data`` is None.
    """
    if data is None:
        return None

    # A missing 'train_accs' key yields an empty trace rather than a KeyError.
    train_accs = data.get('train_accs', [])

    # A missing or null 'epochs' entry falls back to 0..n-1, so the
    # len()/max() calls below never see None.
    epochs = data.get('epochs')
    if epochs is None:
        epochs = np.arange(len(train_accs))
    else:
        epochs = _to_np(epochs)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=epochs, y=_to_np(train_accs),
        name='Train Acc', line=dict(color=COLORS[0]),
    ))
    if 'test_accs' in data:
        fig.add_trace(go.Scatter(
            x=epochs, y=_to_np(data['test_accs']),
            name='Test Acc', line=dict(color=COLORS[3]),
        ))

    # Shade the training stages (no text annotations on this chart).
    s1 = data.get('stage1_end')
    s2 = data.get('stage2_end')
    if s1 is not None:
        fig.add_vrect(x0=0, x1=s1, fillcolor=STAGE_COLORS[0], line_width=0)
    if s1 is not None and s2 is not None:
        fig.add_vrect(x0=s1, x1=s2, fillcolor=STAGE_COLORS[1], line_width=0)
    if s2 is not None:
        # The last shaded region ends at the largest recorded epoch.
        n = int(epochs.max()) if len(epochs) > 0 else len(train_accs)
        fig.add_vrect(x0=s2, x1=n, fillcolor=STAGE_COLORS[2], line_width=0)

    fig.update_layout(
        title=title, xaxis_title='Epoch', yaxis_title='Accuracy',
        template='plotly_white', height=400,
        legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
    )
    return fig
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def make_neuron_spectrum_chart(data, neuron_key):
    """Build a grouped Plotly bar chart of one neuron's Fourier spectrum.

    Reads ``data['neurons'][neuron_key]`` for the input and output Fourier
    magnitudes and ``data['fourier_basis_names']`` for the x-axis labels.
    Returns ``None`` when ``data`` is None or the neuron key is unknown.
    """
    if data is None:
        return None
    all_neurons = data.get('neurons', {})
    if neuron_key not in all_neurons:
        return None

    entry = all_neurons[neuron_key]
    basis_labels = data.get('fourier_basis_names', [])
    series = (
        ('W_in magnitude', _to_np(entry['fourier_magnitudes_in']), COLORS[0]),
        ('W_out magnitude', _to_np(entry['fourier_magnitudes_out']), COLORS[3]),
    )
    dominant = entry.get('dominant_freq', '?')

    fig = go.Figure()
    for label, magnitudes, color in series:
        fig.add_trace(go.Bar(
            x=basis_labels, y=magnitudes, name=label,
            marker_color=color, opacity=0.8,
        ))
    fig.update_layout(
        title=f"Neuron {neuron_key} (dominant freq={dominant})",
        xaxis_title='Fourier Component',
        yaxis_title='Magnitude',
        barmode='group',
        template='plotly_white', height=350,
    )
    return fig
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def make_logit_bar_chart(data, pair_index):
    """Build a Plotly bar chart of logits for a specific (a,b) pair.

    Args:
        data: Mapping with parallel sequences ``'pairs'`` (``(a, b)``
            inputs), ``'logits'`` (one logit vector per pair) and
            ``'correct_answers'``, plus ``'output_classes'`` for the x axis.
            ``None`` means "no data".
        pair_index: Position of the pair to display.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when there is no data,
        no selection, or ``pair_index`` is out of range.
    """
    if data is None or pair_index is None:
        return None
    pairs = data.get('pairs', [])
    logits_all = data.get('logits', [])
    correct = data.get('correct_answers', [])
    classes = data.get('output_classes', [])

    # All three parallel lists are indexed below, so bound by the shortest.
    # Negative indices are rejected instead of silently wrapping around.
    usable = min(len(pairs), len(logits_all), len(correct))
    if not 0 <= pair_index < usable:
        return None

    a, b = pairs[pair_index]
    logits = _to_np(logits_all[pair_index])
    correct_ans = correct[pair_index]

    # Highlight the bar of the correct class.
    bar_colors = [COLORS[3] if c == correct_ans else COLORS[0] for c in classes]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=[str(c) for c in classes], y=logits,
        marker_color=bar_colors,
        hovertemplate='Class %{x}: %{y:.3f}<extra></extra>',
    ))
    fig.update_layout(
        title=f"Logits for ({a}, {b}) -- correct = {correct_ans}",
        xaxis_title='Output Class',
        yaxis_title='Logit Value',
        template='plotly_white', height=350,
    )
    return fig
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def make_grokk_heatmap(data, epoch_index):
    """Build a Plotly heatmap of the accuracy grid at a grokking checkpoint.

    Args:
        data: Mapping with ``'grids'`` (one 2-D grid per checkpoint, values
            plotted on a 0..1 color scale) and ``'epochs'`` (epoch label per
            checkpoint). ``None`` means "no data".
        epoch_index: Position of the checkpoint to display.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when there is no data,
        no selection, or ``epoch_index`` is out of range.
    """
    if data is None or epoch_index is None:
        return None
    epochs = data.get('epochs', [])
    grids = data.get('grids', [])
    # Negative indices are rejected instead of silently wrapping around.
    if not 0 <= epoch_index < len(grids):
        return None

    grid = _to_np(grids[epoch_index])
    # 'epochs' may be shorter than 'grids'; still show the grid and mark the
    # epoch label as unknown rather than raising IndexError.
    ep = epochs[epoch_index] if epoch_index < len(epochs) else '?'

    fig = go.Figure(data=go.Heatmap(
        z=grid,
        colorscale=[[0, 'white'], [1, COLORS[0]]],
        zmin=0, zmax=1,
        hovertemplate='a=%{y}, b=%{x}: %{z:.0f}<extra></extra>',
    ))
    fig.update_layout(
        title=f"Accuracy Grid at Epoch {ep}",
        xaxis_title='Second Input (b)',
        yaxis_title='First Input (a)',
        template='plotly_white', height=450,
        # Row 0 at the top, like a matrix printout.
        yaxis=dict(autorange='reversed'),
    )
    return fig
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
# ---------------------------------------------------------------------------
|
| 525 |
+
# Tab update functions
|
| 526 |
+
# ---------------------------------------------------------------------------
|
| 527 |
+
|
| 528 |
+
def update_tab1(p):
    """Collect every output of the Overview tab for modulo ``p``.

    Two static images are resolved through ``safe_img`` and up to four
    interactive charts are built from ``overview.json``.

    Returns:
        A 6-tuple ``(img_overview, std_loss_chart, grokk_loss_chart,
        std_ipr_chart, grokk_ipr_chart, img_phase)``.  A chart entry is
        ``None`` when ``overview.json`` is missing/empty or when none of the
        series feeding that chart contains data.
    """
    img_overview = safe_img(p, "overview_loss_ipr.png")
    img_phase = safe_img(p, "overview_phase_scatter.png")
    # Also build interactive charts from overview.json
    data = load_json_file(p, "overview.json")

    std_loss_chart = None
    grokk_loss_chart = None
    std_ipr_chart = None
    grokk_ipr_chart = None

    if data:
        std_ep = data.get('std_epochs', [])
        grokk_ep = data.get('grokk_epochs', [])

        # Each trace is (legend name, y-values, index into COLORS).
        std_loss_chart = _overview_line_chart(
            std_ep,
            [('Train Loss', data.get('std_train_loss', []), 0)],
            title='Standard: Training Loss (ReLU, full data)',
            y_title='Loss',
        )
        std_ipr_chart = _overview_line_chart(
            std_ep,
            [('Avg IPR', data.get('std_ipr', []), 3)],
            title='Standard: IPR (Fourier Sparsity)',
            y_title='IPR',
            y_range=(0, 1.05),
        )
        grokk_loss_chart = _overview_line_chart(
            grokk_ep,
            [
                ('Train Loss', data.get('grokk_train_loss', []), 0),
                ('Test Loss', data.get('grokk_test_loss', []), 3),
            ],
            title='Grokking: Loss (ReLU, 75% data, WD)',
            y_title='Loss',
        )
        grokk_ipr_chart = _overview_line_chart(
            grokk_ep,
            [('Avg IPR', data.get('grokk_ipr', []), 3)],
            title='Grokking: IPR (weight decay drives sparsification)',
            y_title='IPR',
            y_range=(0, 1.05),
        )

    return (img_overview, std_loss_chart, grokk_loss_chart,
            std_ipr_chart, grokk_ipr_chart, img_phase)


def _overview_line_chart(x_values, traces, title, y_title, y_range=None):
    """Build one Overview line chart, or return ``None`` if it would be empty.

    Args:
        x_values: Sequence of x coordinates shared by all traces; it is
            truncated to each trace's length before plotting.
        traces: Iterable of ``(name, y_values, color_index)`` triples, where
            ``color_index`` indexes the module-level ``COLORS`` palette.
            Triples whose ``y_values`` is empty are skipped.
        title: Figure title.
        y_title: Y-axis title (the x-axis title is always ``'Step'``).
        y_range: Optional ``(low, high)`` pair fixing the y-axis range.
    """
    populated = [trace for trace in traces if trace[1]]
    if not populated:
        return None

    fig = go.Figure()
    for name, y_values, color_index in populated:
        fig.add_trace(go.Scatter(
            x=_to_np(x_values[:len(y_values)]), y=_to_np(y_values),
            name=name, line=dict(color=COLORS[color_index]),
        ))

    layout = dict(title=title, xaxis_title='Step', yaxis_title=y_title)
    if y_range is not None:
        layout['yaxis'] = dict(range=list(y_range))
    layout.update(template='plotly_white', height=350)
    fig.update_layout(**layout)
    return fig
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def update_tab2(p):
    """Fourier Weights tab: resolve the heatmap and the two line-plot images.

    Returns a 3-tuple of ``safe_img`` results in the order
    (heatmap, first-layer line plot, second-layer line plot).
    """
    filenames = (
        "full_training_para_origin.png",
        "lineplot_in.png",
        "lineplot_out.png",
    )
    return tuple(safe_img(p, filename) for filename in filenames)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def update_tab3(p):
    """Phase Analysis tab: resolve the three static images.

    Returns a 3-tuple of ``safe_img`` results in the order
    (phase distribution, phase relationship, magnitude distribution).
    """
    filenames = (
        "phase_distribution.png",
        "phase_relationship.png",
        "magnitude_distribution.png",
    )
    return tuple(safe_img(p, filename) for filename in filenames)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def update_tab4(p):
    """Output Logits tab: resolve the single logits heatmap image for ``p``."""
    logits_image = safe_img(p, "output_logits.png")
    return logits_image
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def update_tab5(p):
    """Lottery Mechanism tab: resolve the three static images.

    Returns a 3-tuple of ``safe_img`` results in the order
    (magnitude evolution, phase misalignment, beta contour).
    """
    filenames = (
        "lottery_mech_magnitude.png",
        "lottery_mech_phase.png",
        "lottery_beta_contour.png",
    )
    return tuple(safe_img(p, filename) for filename in filenames)
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
def update_tab6(p):
    """Grokking tab: two interactive charts followed by five static images.

    Returns a 7-tuple ``(loss_chart, acc_chart, phase_diff, avg_ipr,
    memorization_accuracy, memorization_common_to_rare, decoded_weights)``.
    The charts come from ``make_loss_chart`` / ``make_acc_chart`` applied to
    the corresponding JSON files; the images come from ``safe_img``.
    """
    loss_json = load_json_file(p, "grokk_loss.json")
    acc_json = load_json_file(p, "grokk_acc.json")
    charts = (
        make_loss_chart(loss_json, title="Grokking: Loss"),
        make_acc_chart(acc_json, title="Grokking: Accuracy"),
    )
    image_names = (
        "grokk_abs_phase_diff.png",
        "grokk_avg_ipr.png",
        "grokk_memorization_accuracy.png",
        "grokk_memorization_common_to_rare.png",
        "grokk_decoded_weights_dynamic.png",
    )
    images = tuple(safe_img(p, filename) for filename in image_names)
    return charts + images
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def update_tab7(p):
    """Gradient Dynamics tab: resolve the Quad and ReLU image pairs.

    Returns a 4-tuple of ``safe_img`` results in the order
    (phase-align quad, single-freq quad, phase-align ReLU, single-freq ReLU).
    """
    filenames = (
        "phase_align_quad.png",
        "single_freq_quad.png",
        "phase_align_relu.png",
        "single_freq_relu.png",
    )
    return tuple(safe_img(p, filename) for filename in filenames)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def update_tab8(p):
    """Decoupled Simulation tab: resolve the two analytical gradient-flow plots.

    Returns a 2-tuple of ``safe_img`` results (approx1, approx2).
    """
    filenames = ("phase_align_approx1.png", "phase_align_approx2.png")
    return tuple(safe_img(p, filename) for filename in filenames)
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def update_tab9(p):
    """Training Log tab: list the available runs and preload the first one.

    Returns a 4-tuple ``(run_names, selected_run, config_markdown, log_text)``.
    When ``training_log.json`` cannot be loaded the result is
    ``([], None, "", "")``; when there is no usable first run the two text
    fields are empty strings.
    """
    data = load_json_file(p, "training_log.json")
    if data is None:
        return [], None, "", ""

    run_names = list(data.keys())
    # Show first run by default
    selected = run_names[0] if run_names else None
    if not selected:
        return run_names, selected, "", ""

    entry = data[selected]
    config_text = _format_config_md(selected, entry.get('config', {}))
    log_text = entry.get('log_text', 'No log available.')
    return run_names, selected, config_text, log_text
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def _format_config_md(run_name, config):
|
| 702 |
+
"""Format a run config as a Markdown summary."""
|
| 703 |
+
lines = [f"**Run: {run_name}**\n"]
|
| 704 |
+
key_labels = {
|
| 705 |
+
'prime': 'Modulo (p)', 'd_mlp': 'd_mlp',
|
| 706 |
+
'act_type': 'Activation', 'init_type': 'Init Type',
|
| 707 |
+
'init_scale': 'Init Scale', 'optimizer': 'Optimizer',
|
| 708 |
+
'lr': 'Learning Rate', 'weight_decay': 'Weight Decay',
|
| 709 |
+
'frac_train': 'Frac Train', 'num_epochs': 'Num Epochs',
|
| 710 |
+
'seed': 'Seed',
|
| 711 |
+
}
|
| 712 |
+
for key, label in key_labels.items():
|
| 713 |
+
val = config.get(key, 'N/A')
|
| 714 |
+
lines.append(f"- **{label}**: `{val}`")
|
| 715 |
+
return "\n".join(lines)
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def update_info(p):
    """Build the one-line Markdown summary for modulo ``p``.

    Reads ``metadata.json``; without it the summary says no metadata is
    available.  Otherwise the summary lists ``d_mlp`` and whichever of the
    standard run's final train accuracy, test accuracy and train loss are
    present, joined with ``" | "``.
    """
    meta = load_json_file(p, "metadata.json")
    if not meta:
        return f"**p = {p}** | No metadata available"

    summary = [f"**p = {p}**", f"d_mlp = {meta.get('d_mlp', '?')}"]
    standard = meta.get('final_metrics', {}).get('standard', {})
    # (metadata key, display label, format spec) in display order.
    metric_specs = (
        ('train_acc', 'Train Acc', '.4f'),
        ('test_acc', 'Test Acc', '.4f'),
        ('train_loss', 'Train Loss', '.6f'),
    )
    for key, label, spec in metric_specs:
        if key in standard:
            summary.append(f"{label} = {format(standard[key], spec)}")
    return " | ".join(summary)
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
# ---------------------------------------------------------------------------
|
| 735 |
+
# Interactive callback helpers
|
| 736 |
+
# ---------------------------------------------------------------------------
|
| 737 |
+
|
| 738 |
+
def _get_neuron_choices(p):
    """List the keys under ``'neurons'`` in ``neuron_spectra.json`` for ``p``.

    Returns an empty list when the file cannot be loaded or has no
    ``'neurons'`` entry.
    """
    spectra = load_json_file(p, "neuron_spectra.json")
    return [] if spectra is None else list(spectra.get('neurons', {}).keys())
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def _get_pair_choices(p):
    """Build ``"(a, b)"`` labels from ``'pairs'`` in ``logits_interactive.json``.

    Returns an empty list when the file cannot be loaded or has no pairs.
    """
    payload = load_json_file(p, "logits_interactive.json")
    if payload is None:
        return []

    labels = []
    for a, b in payload.get('pairs', []):
        labels.append(f"({a}, {b})")
    return labels
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def _get_grokk_epochs(p):
    """Return the ``'epochs'`` entry of ``grokk_epoch_data.json`` for ``p``.

    Falls back to an empty list when the file cannot be loaded or the key is
    absent.
    """
    payload = load_json_file(p, "grokk_epoch_data.json")
    return [] if payload is None else payload.get('epochs', [])
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
# ---------------------------------------------------------------------------
|
| 764 |
+
# Markdown helper -- ensures latex_delimiters are set
|
| 765 |
+
# ---------------------------------------------------------------------------
|
| 766 |
+
|
| 767 |
+
def _md(text, **kwargs):
    """Create a ``gr.Markdown`` for ``text`` with the shared LaTeX delimiters.

    The module-level ``LATEX_DELIMITERS`` is always passed as
    ``latex_delimiters``; any extra keyword arguments are forwarded to
    ``gr.Markdown`` unchanged.
    """
    component = gr.Markdown(
        text,
        latex_delimiters=LATEX_DELIMITERS,
        **kwargs,
    )
    return component
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
# ---------------------------------------------------------------------------
|
| 773 |
+
# Main app
|
| 774 |
+
# ---------------------------------------------------------------------------
|
| 775 |
+
|
| 776 |
+
def on_p_change(p_str):
    """Recompute every tab output after the modulo dropdown changes.

    Args:
        p_str: The selected modulo as a string (converted with ``int``).

    Returns:
        A list of component values/updates.  Its order must stay in lockstep
        with the ``all_outputs`` list assembled in ``create_app``: the info
        line first, then the outputs of tabs 1 through 9 in sequence, with the
        interactive widgets (neuron dropdown/chart, pair dropdown/chart,
        epoch slider/heatmap/label) slotted into their tabs.
    """
    p = int(p_str)

    info = update_info(p)

    # Overview
    (t1_img_overview, t1_std_loss, t1_grokk_loss,
     t1_std_ipr, t1_grokk_ipr, t1_phase_scatter) = update_tab1(p)
    # Core Interpretability
    t2_heatmap, t2_line_in, t2_line_out = update_tab2(p)
    t3_phase_dist, t3_phase_rel, t3_magnitude = update_tab3(p)
    t4_logits = update_tab4(p)
    t5_mag, t5_phase, t5_contour = update_tab5(p)
    # Grokking
    (t6_loss, t6_acc, t6_phase_diff, t6_ipr,
     t6_memo, t6_memo_rare, t6_decoded) = update_tab6(p)
    # Theory
    t7_pa_quad, t7_sf_quad, t7_pa_relu, t7_sf_relu = update_tab7(p)
    t8_approx1, t8_approx2 = update_tab8(p)

    # Training Log
    t9_run_names, t9_default_run, t9_config_text, t9_log = update_tab9(p)
    t9_run_dd_update = gr.update(
        choices=t9_run_names,
        value=t9_default_run,
    )

    # Interactive widget updates: each selector is reset to its first choice
    # and the matching chart is pre-rendered for that choice.
    neuron_choices = _get_neuron_choices(p)
    neuron_dd_update = gr.update(
        choices=neuron_choices,
        value=neuron_choices[0] if neuron_choices else None,
    )
    neuron_spectra_data = load_json_file(p, "neuron_spectra.json")
    neuron_chart = make_neuron_spectrum_chart(
        neuron_spectra_data, neuron_choices[0]
    ) if neuron_choices else None

    pair_choices = _get_pair_choices(p)
    pair_dd_update = gr.update(
        choices=pair_choices,
        value=pair_choices[0] if pair_choices else None,
    )
    logit_data = load_json_file(p, "logits_interactive.json")
    logit_chart = make_logit_bar_chart(logit_data, 0) if pair_choices else None

    grokk_epochs = _get_grokk_epochs(p)
    if grokk_epochs:
        # The slider selects a snapshot index, not an epoch number.
        slider_update = gr.update(
            minimum=0, maximum=len(grokk_epochs) - 1, value=0, step=1,
            visible=True,
        )
    else:
        slider_update = gr.update(minimum=0, maximum=0, value=0, visible=False)
    grokk_slider_data = load_json_file(p, "grokk_epoch_data.json")
    grokk_heatmap = make_grokk_heatmap(grokk_slider_data, 0) if grokk_epochs else None
    # Bold, to match the label the slider handler writes into the same
    # Markdown component when the user moves the slider.
    epoch_label = f"**Epoch: {grokk_epochs[0]}**" if grokk_epochs else ""

    return [
        info,
        # Tab 1: Overview
        t1_img_overview, t1_std_loss, t1_grokk_loss,
        t1_std_ipr, t1_grokk_ipr, t1_phase_scatter,
        # Tab 2: Fourier Weights
        t2_heatmap, t2_line_in, t2_line_out,
        neuron_dd_update, neuron_chart,
        # Tab 3: Phase Analysis
        t3_phase_dist, t3_phase_rel, t3_magnitude,
        # Tab 4: Output Logits
        t4_logits,
        pair_dd_update, logit_chart,
        # Tab 5: Lottery Mechanism
        t5_mag, t5_phase, t5_contour,
        # Tab 6: Grokking
        t6_loss, t6_acc, t6_phase_diff, t6_ipr,
        t6_memo, t6_memo_rare, t6_decoded,
        slider_update, grokk_heatmap, epoch_label,
        # Tab 7: Gradient Dynamics
        t7_pa_quad, t7_sf_quad, t7_pa_relu, t7_sf_relu,
        # Tab 8: Decoupled Simulation
        t8_approx1, t8_approx2,
        # Tab 9: Training Log
        t9_run_dd_update, t9_config_text, t9_log,
    ]
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
def _commit_results_to_repo(p):
    """Best-effort upload of the results folder for ``p`` to the HF Space repo.

    The upload is attempted only when ``huggingface_hub`` is importable, the
    ``SPACE_ID`` environment variable is set, and the local results directory
    for ``p`` exists.  Upload errors are logged as warnings, never raised.

    Returns:
        ``(success, message)`` where ``message`` explains the outcome.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        return False, "huggingface_hub not installed"

    space_id = os.getenv("SPACE_ID")  # e.g. "username/space-name"
    if not space_id:
        return False, "Not running on HF Spaces (SPACE_ID not set)"

    local_dir = os.path.join(RESULTS_DIR, f"p_{p:03d}")
    if not os.path.isdir(local_dir):
        return False, "No results directory found"

    upload_kwargs = dict(
        folder_path=local_dir,
        path_in_repo=f"precomputed_results/p_{p:03d}",
        repo_id=space_id,
        repo_type="space",
        commit_message=f"Add precomputed results for p={p}",
    )
    try:
        HfApi().upload_folder(**upload_kwargs)
    except Exception as e:
        # Committing is optional; report the failure instead of propagating.
        logger.warning(f"Failed to commit results for p={p}: {e}")
        return False, str(e)
    return True, f"Committed results for p={p} to {space_id}"
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
def _run_step_streaming(cmd, env, label):
    """Run a subprocess, yielding ``(line, error_flag)`` for each output line.

    Args:
        cmd: Argument list passed to ``subprocess.Popen``.
        env: Environment mapping for the child process.
        label: Human-readable step name used in the failure message.

    Yields:
        ``(text, False)`` for every line of combined stdout/stderr (trailing
        newline stripped) and, if the process exits with a non-zero code, a
        final ``("[ERROR] ...", True)`` tuple.

    The child runs with ``PROJECT_ROOT`` as its working directory.  If the
    consumer stops iterating early, or reading the output raises, the child
    is killed and reaped and its pipe is closed rather than being left
    running.
    """
    proc = subprocess.Popen(
        cmd, cwd=PROJECT_ROOT, env=env,
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        text=True, bufsize=1,
    )
    try:
        for line in proc.stdout:
            yield line.rstrip("\n"), False
        returncode = proc.wait()
    finally:
        # Also reached via GeneratorExit when the consumer abandons the
        # generator mid-stream; make sure no orphaned child or open pipe
        # is left behind.
        if proc.poll() is None:
            proc.kill()
            proc.wait()
        proc.stdout.close()

    if returncode != 0:
        yield f"[ERROR] {label} failed (exit code {returncode})", True
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
def run_pipeline_for_p_streaming(p):
    """Generator: run the full pipeline for ``p``, yielding log lines.

    Yields ``(log_line: str, is_error: bool, is_done: bool)`` tuples.  The
    last tuple yielded always has ``is_done=True``.

    The pipeline is three subprocess steps (training, model-based plots,
    analytical plots).  It stops at the first failing step.  After the steps
    succeed, the model checkpoints for ``p`` are deleted to save space and a
    best-effort commit of the results to the HF repo is attempted.

    Nothing is run when ``p`` is even or below 3, exceeds
    ``MAX_P_ON_DEMAND``, or when the results directory for ``p`` already
    holds more than five files.
    """
    if p < 3 or p % 2 == 0:
        yield f"Error: p must be an odd number >= 3, got {p}", True, True
        return
    if p > MAX_P_ON_DEMAND:
        yield f"Error: p={p} exceeds on-demand limit of {MAX_P_ON_DEMAND}", True, True
        return

    result_dir = os.path.join(RESULTS_DIR, f"p_{p:03d}")
    if os.path.isdir(result_dir):
        n_existing = len(os.listdir(result_dir))
        if n_existing > 5:
            yield f"Results for p={p} already exist ({n_existing} files)", False, True
            return

    env = os.environ.copy()
    # Prepend the project root so the precompute scripts can import project
    # modules.  Use the platform separator and avoid a trailing empty entry,
    # which Python would interpret as the current directory.
    inherited_path = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        os.pathsep.join([PROJECT_ROOT, inherited_path])
        if inherited_path else PROJECT_ROOT
    )

    steps = [
        ("Step 1/3: Training 5 configurations", [
            sys.executable, "precompute/train_all.py",
            "--p", str(p), "--output", TRAINED_MODELS_DIR, "--resume",
        ]),
        ("Step 2/3: Generating model-based plots", [
            sys.executable, "precompute/generate_plots.py",
            "--p", str(p), "--input", TRAINED_MODELS_DIR,
            "--output", RESULTS_DIR,
        ]),
        ("Step 3/3: Generating analytical plots", [
            sys.executable, "precompute/generate_analytical.py",
            "--p", str(p), "--output", RESULTS_DIR,
        ]),
    ]

    banner = "=" * 60
    for label, cmd in steps:
        yield f"\n{banner}", False, False
        yield f"  {label} (p={p})", False, False
        yield banner, False, False
        for line, is_err in _run_step_streaming(cmd, env, label):
            if is_err:
                # A failed step aborts the whole pipeline.
                yield line, True, True
                return
            yield line, False, False

    # Cleanup checkpoints
    model_dir = os.path.join(TRAINED_MODELS_DIR, f"p_{p:03d}")
    if os.path.isdir(model_dir):
        shutil.rmtree(model_dir)
        yield f"Cleaned up checkpoints: {model_dir}", False, False

    n_files = len(os.listdir(result_dir)) if os.path.isdir(result_dir) else 0

    # Try to commit results back to the HF repo; failure is not fatal, but
    # the reason is surfaced in the log instead of being dropped.
    ok_commit, commit_msg = _commit_results_to_repo(p)
    if ok_commit:
        yield "Results saved to HF repo.", False, False
    else:
        yield f"Results not committed to HF repo: {commit_msg}", False, False

    yield f"\nDone! Generated {n_files} files for p={p}.", False, True
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
def create_app():
|
| 976 |
+
moduli = get_available_moduli()
|
| 977 |
+
p_choices = [str(p) for p in moduli]
|
| 978 |
+
default_p = p_choices[0] if p_choices else None
|
| 979 |
+
|
| 980 |
+
with gr.Blocks(
|
| 981 |
+
title="Modular Addition Feature Learning",
|
| 982 |
+
) as app:
|
| 983 |
+
_md(
|
| 984 |
+
r"# On the Mechanism and Dynamics of Modular Addition" "\n"
|
| 985 |
+
r"### Fourier Features, Lottery Ticket, and Grokking" "\n\n"
|
| 986 |
+
r"**Jianliang He, Leda Wang, Siyu Chen, Zhuoran Yang**" "\n"
|
| 987 |
+
r"*Department of Statistics and Data Science, Yale University*" "\n\n"
|
| 988 |
+
r'<a href="#">[arXiv]</a> '
|
| 989 |
+
r'<a href="#">[Blog]</a> '
|
| 990 |
+
r'<a href="https://github.com/Y-Agent/modular-addition-feature-learning">[Code]</a>' "\n\n"
|
| 991 |
+
r"---" "\n\n"
|
| 992 |
+
r"This interactive explorer visualizes how a two-layer neural network "
|
| 993 |
+
r"learns modular arithmetic $f(x,y) = (x + y) \bmod p$ through "
|
| 994 |
+
r"Fourier feature learning, lottery ticket dynamics, and the grokking "
|
| 995 |
+
r"phenomenon. Select a modulo $p$ (any odd number $\geq 3$) below to view pre-computed results." "\n\n"
|
| 996 |
+
r"> **Note:** Grokking experiments (Tab 6) require $p \geq 19$ to have enough data for a meaningful train/test split. "
|
| 997 |
+
r"For $p < 19$, grokking plots will not be generated."
|
| 998 |
+
)
|
| 999 |
+
|
| 1000 |
+
# Hidden state for current modulo
|
| 1001 |
+
current_p = gr.State(value=int(default_p) if default_p else 3)
|
| 1002 |
+
|
| 1003 |
+
with gr.Row():
|
| 1004 |
+
p_dropdown = gr.Dropdown(
|
| 1005 |
+
choices=p_choices,
|
| 1006 |
+
value=default_p,
|
| 1007 |
+
label="Select Modulo (p)",
|
| 1008 |
+
interactive=True,
|
| 1009 |
+
scale=2,
|
| 1010 |
+
)
|
| 1011 |
+
info_md = _md(
|
| 1012 |
+
update_info(int(default_p)) if default_p else ""
|
| 1013 |
+
)
|
| 1014 |
+
|
| 1015 |
+
with gr.Accordion("Generate results for a new p", open=False):
|
| 1016 |
+
_md(
|
| 1017 |
+
f"Enter any odd number $p \\geq 3$ (max {MAX_P_ON_DEMAND} "
|
| 1018 |
+
f"for on-demand training). This will train 5 model "
|
| 1019 |
+
f"configurations and generate all plots. Training logs "
|
| 1020 |
+
f"are streamed below in real time."
|
| 1021 |
+
)
|
| 1022 |
+
with gr.Row():
|
| 1023 |
+
new_p_input = gr.Number(
|
| 1024 |
+
value=None, label="New p (odd, ≥ 3)",
|
| 1025 |
+
precision=0, scale=1,
|
| 1026 |
+
)
|
| 1027 |
+
generate_btn = gr.Button(
|
| 1028 |
+
"Generate", variant="primary", scale=1,
|
| 1029 |
+
)
|
| 1030 |
+
generate_status = _md("")
|
| 1031 |
+
generate_log = gr.Code(
|
| 1032 |
+
value="", language=None, label="Pipeline Log",
|
| 1033 |
+
lines=15, interactive=False, visible=False,
|
| 1034 |
+
)
|
| 1035 |
+
|
| 1036 |
+
# ----- Tabs -----
|
| 1037 |
+
with gr.Tabs():
|
| 1038 |
+
|
| 1039 |
+
# === Core Interpretability ===
|
| 1040 |
+
|
| 1041 |
+
# Tab 1: Overview
|
| 1042 |
+
with gr.Tab("1. Overview"):
|
| 1043 |
+
_md(MATH_TAB1)
|
| 1044 |
+
t1_img_overview = gr.Image(
|
| 1045 |
+
label="Loss & IPR Overview (Static)", type="filepath"
|
| 1046 |
+
)
|
| 1047 |
+
with gr.Row():
|
| 1048 |
+
t1_std_loss = gr.Plot(label="Standard: Loss")
|
| 1049 |
+
t1_grokk_loss = gr.Plot(label="Grokking: Loss")
|
| 1050 |
+
with gr.Row():
|
| 1051 |
+
t1_std_ipr = gr.Plot(label="Standard: IPR")
|
| 1052 |
+
t1_grokk_ipr = gr.Plot(label="Grokking: IPR")
|
| 1053 |
+
t1_phase_scatter = gr.Image(
|
| 1054 |
+
label="Phase Alignment: \u03c8 = 2\u03c6", type="filepath"
|
| 1055 |
+
)
|
| 1056 |
+
|
| 1057 |
+
# Tab 2: Fourier Weights
|
| 1058 |
+
with gr.Tab("2. Fourier Weights"):
|
| 1059 |
+
_md(MATH_TAB2)
|
| 1060 |
+
t2_heatmap = gr.Image(label="Decoded W_in / W_out Heatmap", type="filepath")
|
| 1061 |
+
with gr.Row():
|
| 1062 |
+
t2_line_in = gr.Image(label="First-Layer Line Plots (with cosine fit)", type="filepath")
|
| 1063 |
+
t2_line_out = gr.Image(label="Second-Layer Line Plots (with cosine fit)", type="filepath")
|
| 1064 |
+
_md("#### Neuron Frequency Inspector")
|
| 1065 |
+
t2_neuron_dd = gr.Dropdown(
|
| 1066 |
+
choices=[], value=None,
|
| 1067 |
+
label="Select Neuron", interactive=True,
|
| 1068 |
+
)
|
| 1069 |
+
t2_neuron_chart = gr.Plot(label="Neuron Fourier Spectrum")
|
| 1070 |
+
|
| 1071 |
+
# Tab 3: Phase Analysis
|
| 1072 |
+
with gr.Tab("3. Phase Analysis"):
|
| 1073 |
+
_md(MATH_TAB3)
|
| 1074 |
+
with gr.Row():
|
| 1075 |
+
t3_phase_dist = gr.Image(label="Phase Distribution", type="filepath")
|
| 1076 |
+
t3_phase_rel = gr.Image(
|
| 1077 |
+
label="Phase Relationship (2\u03c6 vs \u03c8)", type="filepath"
|
| 1078 |
+
)
|
| 1079 |
+
t3_magnitude = gr.Image(label="Magnitude Distribution", type="filepath")
|
| 1080 |
+
|
| 1081 |
+
# Tab 4: Output Logits
|
| 1082 |
+
with gr.Tab("4. Output Logits"):
|
| 1083 |
+
_md(MATH_TAB4)
|
| 1084 |
+
t4_logits = gr.Image(label="Output Logits Heatmap", type="filepath")
|
| 1085 |
+
_md("#### Logit Explorer")
|
| 1086 |
+
t4_pair_dd = gr.Dropdown(
|
| 1087 |
+
choices=[], value=None,
|
| 1088 |
+
label="Select Input Pair (a, b)", interactive=True,
|
| 1089 |
+
)
|
| 1090 |
+
t4_logit_chart = gr.Plot(label="Logit Distribution")
|
| 1091 |
+
|
| 1092 |
+
# Tab 5: Lottery Mechanism
|
| 1093 |
+
with gr.Tab("5. Lottery Mechanism"):
|
| 1094 |
+
_md(MATH_TAB5)
|
| 1095 |
+
_md(r"""**Magnitude plot** below: Each curve tracks one frequency's output magnitude $\beta_k$ within a single neuron over training. All frequencies start with equal magnitude (from random initialization). The winning frequency (best initial phase alignment) grows explosively while others remain frozen.""")
|
| 1096 |
+
t5_mag = gr.Image(label="Frequency Magnitude Evolution", type="filepath")
|
| 1097 |
+
_md(r"""**Phase plot** below: Each curve shows the phase misalignment $\mathcal{D}_k = 2\phi_k - \psi_k$ for one frequency within the same neuron. The winning frequency (colored) converges to $\mathcal{D} = 0$ (perfect alignment) first; other frequencies barely change because their magnitudes remain small.""")
|
| 1098 |
+
t5_phase = gr.Image(label="Phase Misalignment Convergence", type="filepath")
|
| 1099 |
+
_md(r"""**Contour plot** below: Final output magnitude as a function of initial magnitude and initial phase misalignment, across all neurons. The largest final magnitudes (brightest regions) appear at small initial misalignment $|\mathcal{D}|$, confirming that initial phase alignment -- not initial magnitude -- determines which frequency wins.""")
|
| 1100 |
+
t5_contour = gr.Image(label="Final Magnitude Contour", type="filepath")
|
| 1101 |
+
|
| 1102 |
+
# === Grokking ===
|
| 1103 |
+
|
| 1104 |
+
# Tab 6: Grokking
|
| 1105 |
+
with gr.Tab("6. Grokking"):
|
| 1106 |
+
_md(MATH_TAB6)
|
| 1107 |
+
|
| 1108 |
+
_md(r"""#### (a) Loss and (b) Accuracy
|
| 1109 |
+
|
| 1110 |
+
**(a) Loss:** Training loss (blue) drops rapidly in Stage I as the network memorizes training data. Test loss (red) stays high until Stage II, when weight decay forces the network to find a generalizing solution, causing test loss to plummet. The three colored bands mark the three stages.
|
| 1111 |
+
|
| 1112 |
+
**(b) Accuracy:** Training accuracy reaches 100% early (Stage I). Test accuracy stays at ~70% during memorization (not 50% -- the built-in symmetry $f(a,b) = f(b,a)$ gives "free" correct answers for the swapped pair). Test accuracy jumps sharply in Stage II when the network transitions from memorization to Fourier features.""")
|
| 1113 |
+
with gr.Row():
|
| 1114 |
+
t6_loss = gr.Plot(label="Grokking Loss (Interactive)")
|
| 1115 |
+
t6_acc = gr.Plot(label="Grokking Accuracy (Interactive)")
|
| 1116 |
+
|
| 1117 |
+
_md(r"""#### (c) Phase Alignment and (d) IPR & Norms
|
| 1118 |
+
|
| 1119 |
+
**(c) Phase alignment:** Average $|\sin(\mathcal{D}_m^\star)|$ over all neurons, where $\mathcal{D}_m^\star = 2\phi_m^\star - \psi_m^\star$ is the phase misalignment at each neuron's dominant frequency. This measures how far the network is from the ideal relationship $\psi = 2\phi$. It decreases throughout training as phases align, with the steepest drop during Stage II.
|
| 1120 |
+
|
| 1121 |
+
**(d) IPR and parameter norms:** IPR (Fourier sparsity) increases sharply in Stage II -- this is the "aha" moment where multi-frequency noise collapses into clean single-frequency features. Parameter norms shrink steadily in Stage III as weight decay slowly polishes the solution.""")
|
| 1122 |
+
with gr.Row():
|
| 1123 |
+
t6_phase_diff = gr.Image(
|
| 1124 |
+
label="Phase Difference |sin(D*)|", type="filepath"
|
| 1125 |
+
)
|
| 1126 |
+
t6_ipr = gr.Image(label="IPR & Parameter Norms", type="filepath")
|
| 1127 |
+
|
| 1128 |
+
_md(r"""#### (e) Memorization Accuracy Grid
|
| 1129 |
+
|
| 1130 |
+
Each cell $(i,j)$ in the grid shows whether the network correctly predicts $(i+j) \bmod p$ at a given training epoch. **White = correct, dark = incorrect.** Training pairs are marked with dots.
|
| 1131 |
+
|
| 1132 |
+
During Stage I, the network first memorizes **symmetric pairs** -- pairs where both $(i,j)$ and $(j,i)$ are in the training set (they appear on both sides of the diagonal). These are learned first because the architecture treats inputs symmetrically: $\theta_m[i] + \theta_m[j] = \theta_m[j] + \theta_m[i]$, so learning one automatically gives the other.
|
| 1133 |
+
|
| 1134 |
+
**Asymmetric pairs** (where only one of $(i,j)$ or $(j,i)$ is in training) are harder to memorize and are learned later. Some test pairs may even be *actively suppressed* (the network gets them wrong on purpose) before eventually being memorized.""")
|
| 1135 |
+
t6_memo = gr.Image(label="Memorization Accuracy", type="filepath")
|
| 1136 |
+
|
| 1137 |
+
_md(r"""#### (f) Common-to-Rare Ordering
|
| 1138 |
+
|
| 1139 |
+
This plot reorders the accuracy grid to reveal the **memorization sequence**. Instead of plotting by input value, it sorts pairs by how "common" they are in the training set:
|
| 1140 |
+
|
| 1141 |
+
- **Common pairs** (top-left): Both $(i,j)$ and $(j,i)$ in training set. These are memorized first.
|
| 1142 |
+
- **Rare pairs** (bottom-right): Only one ordering in training set. These are memorized last, and may be temporarily suppressed before being learned.
|
| 1143 |
+
|
| 1144 |
+
The plot shows a clear **top-left to bottom-right** progression, confirming that the network memorizes common pairs before rare ones.""")
|
| 1145 |
+
t6_memo_rare = gr.Image(label="Memorization: Common to Rare", type="filepath")
|
| 1146 |
+
|
| 1147 |
+
_md(r"""#### (g) Decoded Weights Across Stages
|
| 1148 |
+
|
| 1149 |
+
DFT heatmaps of the network's weights at key epochs through the three stages. Each row is a neuron; each column is a Fourier frequency component.
|
| 1150 |
+
|
| 1151 |
+
- **Stage I (memorization):** Weights are noisy with energy spread across many frequencies -- the network is using all available capacity to memorize.
|
| 1152 |
+
- **Stage II (generalization):** Weight decay kills the weak frequencies. Each neuron's energy concentrates into a single frequency -- clean Fourier features emerge.
|
| 1153 |
+
- **Stage III (cleanup):** Features are already clean; weight decay slowly shrinks overall magnitude without changing the structure.""")
|
| 1154 |
+
t6_decoded = gr.Image(label="Decoded Weights Across Stages", type="filepath")
|
| 1155 |
+
|
| 1156 |
+
_md(r"""#### Accuracy Grid Across Training (Interactive)
|
| 1157 |
+
|
| 1158 |
+
Use the slider to scrub through training epochs and watch the accuracy grid evolve. In Stage I, you'll see the symmetric pairs (along both diagonals) light up first, then asymmetric pairs fill in, and finally the entire grid becomes white in Stage II as the network generalizes.""")
|
| 1159 |
+
t6_slider = gr.Slider(
|
| 1160 |
+
minimum=0, maximum=0, value=0, step=1,
|
| 1161 |
+
label="Epoch Snapshot Index", interactive=True,
|
| 1162 |
+
visible=False,
|
| 1163 |
+
)
|
| 1164 |
+
t6_heatmap = gr.Plot(label="Accuracy Heatmap")
|
| 1165 |
+
t6_epoch_label = _md("")
|
| 1166 |
+
|
| 1167 |
+
# === Theory ===
|
| 1168 |
+
|
| 1169 |
+
# Tab 7: Gradient Dynamics
|
| 1170 |
+
with gr.Tab("7. Gradient Dynamics"):
|
| 1171 |
+
_md(MATH_TAB7)
|
| 1172 |
+
_md(r"""#### Quadratic Activation ($\sigma(x) = x^2$)
|
| 1173 |
+
|
| 1174 |
+
**Left -- Phase alignment:** Tracks the input phase $\phi_m^\star$, output phase $\psi_m^\star$, and doubled input phase $2\phi_m^\star$ of the dominant frequency in a single neuron over training. The theory predicts $\psi \to 2\phi$; here we see $\psi$ (red) and $2\phi$ (blue) converge and overlap, confirming phase alignment. The phases lock in early while magnitudes are still small.
|
| 1175 |
+
|
| 1176 |
+
**Right -- DFT heatmaps:** Decoded weights in Fourier space at key training steps. At step 0, the neuron starts with energy at a single frequency (by construction -- single-frequency initialization). At later steps, the dominant frequency grows while all other frequencies stay at zero. This confirms the **single-frequency preservation theorem**: Fourier orthogonality prevents energy leakage between modes.""")
|
| 1177 |
+
with gr.Row():
|
| 1178 |
+
t7_pa_quad = gr.Image(label="Phase Alignment (Quad)", type="filepath")
|
| 1179 |
+
t7_sf_quad = gr.Image(label="Decoded Weights (Quad)", type="filepath")
|
| 1180 |
+
_md(r"""#### ReLU Activation ($\sigma(x) = \max(0, x)$)
|
| 1181 |
+
|
| 1182 |
+
**Left -- Phase alignment:** Same as quadratic above, but with ReLU. The qualitative behavior is identical: $\psi$ converges to $2\phi$. Minor quantitative differences arise because ReLU is not exactly $x^2$.
|
| 1183 |
+
|
| 1184 |
+
**Right -- DFT heatmaps:** Unlike quadratic, ReLU leaks small amounts of energy to **harmonic multiples** of the dominant frequency ($3k^\star, 5k^\star, \ldots$ for input weights; $2k^\star, 3k^\star, \ldots$ for output weights). This leakage decays as $O(r^{-2})$ where $r$ is the harmonic order, so the dominant frequency remains overwhelmingly dominant. The faint "stripes" at harmonic positions are this leakage.""")
|
| 1185 |
+
with gr.Row():
|
| 1186 |
+
t7_pa_relu = gr.Image(label="Phase Alignment (ReLU)", type="filepath")
|
| 1187 |
+
t7_sf_relu = gr.Image(label="Decoded Weights (ReLU)", type="filepath")
|
| 1188 |
+
|
| 1189 |
+
# Tab 8: Decoupled Simulation
|
| 1190 |
+
with gr.Tab("8. Decoupled Simulation"):
|
| 1191 |
+
_md(MATH_TAB8)
|
| 1192 |
+
_md(r"""Each 3-panel figure below shows one simulation run. The gray curves are non-winning frequencies; the colored curves are the winning frequency $k^\star$.
|
| 1193 |
+
|
| 1194 |
+
**Top panel -- Phase alignment:** $\psi_{k^\star}$ (red) and $2\phi_{k^\star}$ (blue) converge toward each other, confirming that training drives phases into the $\psi = 2\phi$ relationship even in this pure ODE setting (no neural network).
|
| 1195 |
+
|
| 1196 |
+
**Middle panel -- Phase difference $D_{k^\star}$:** The misalignment $\mathcal{D}_{k^\star} = 2\phi_{k^\star} - \psi_{k^\star}$ converges toward $0$ (or $\pi/2$ transiently in Case 1). The dashed horizontal line marks $\pi/2$. Non-winning frequencies (gray) remain scattered because their magnitudes are too small to drive phase alignment.
|
| 1197 |
+
|
| 1198 |
+
**Bottom panel -- Magnitude evolution:** The winning frequency's magnitudes ($\alpha_{k^\star}$ and $\beta_{k^\star}$) grow explosively once phase alignment is achieved, while all other frequencies remain near their initial values. This is the lottery ticket effect in pure form.""")
|
| 1199 |
+
with gr.Row():
|
| 1200 |
+
t8_approx1 = gr.Image(
|
| 1201 |
+
label="Gradient Flow (Case 1: with annotations)", type="filepath"
|
| 1202 |
+
)
|
| 1203 |
+
t8_approx2 = gr.Image(label="Gradient Flow (Case 2)", type="filepath")
|
| 1204 |
+
|
| 1205 |
+
# Tab 9: Training Log
|
| 1206 |
+
with gr.Tab("9. Training Log"):
|
| 1207 |
+
_md(MATH_TAB9)
|
| 1208 |
+
t9_run_dd = gr.Dropdown(
|
| 1209 |
+
choices=[], value=None,
|
| 1210 |
+
label="Select Training Run", interactive=True,
|
| 1211 |
+
)
|
| 1212 |
+
t9_config_md = _md("")
|
| 1213 |
+
t9_log_text = gr.Code(
|
| 1214 |
+
value="", language=None, label="Training Log",
|
| 1215 |
+
lines=30, interactive=False,
|
| 1216 |
+
)
|
| 1217 |
+
|
| 1218 |
+
# All outputs for prime change
|
| 1219 |
+
all_outputs = [
|
| 1220 |
+
info_md,
|
| 1221 |
+
# Tab 1: Overview
|
| 1222 |
+
t1_img_overview, t1_std_loss, t1_grokk_loss,
|
| 1223 |
+
t1_std_ipr, t1_grokk_ipr, t1_phase_scatter,
|
| 1224 |
+
# Tab 2
|
| 1225 |
+
t2_heatmap, t2_line_in, t2_line_out,
|
| 1226 |
+
t2_neuron_dd, t2_neuron_chart,
|
| 1227 |
+
# Tab 3
|
| 1228 |
+
t3_phase_dist, t3_phase_rel, t3_magnitude,
|
| 1229 |
+
# Tab 4
|
| 1230 |
+
t4_logits,
|
| 1231 |
+
t4_pair_dd, t4_logit_chart,
|
| 1232 |
+
# Tab 5
|
| 1233 |
+
t5_mag, t5_phase, t5_contour,
|
| 1234 |
+
# Tab 6
|
| 1235 |
+
t6_loss, t6_acc, t6_phase_diff, t6_ipr,
|
| 1236 |
+
t6_memo, t6_memo_rare, t6_decoded,
|
| 1237 |
+
t6_slider, t6_heatmap, t6_epoch_label,
|
| 1238 |
+
# Tab 7
|
| 1239 |
+
t7_pa_quad, t7_sf_quad, t7_pa_relu, t7_sf_relu,
|
| 1240 |
+
# Tab 8
|
| 1241 |
+
t8_approx1, t8_approx2,
|
| 1242 |
+
# Tab 9
|
| 1243 |
+
t9_run_dd, t9_config_md, t9_log_text,
|
| 1244 |
+
]
|
| 1245 |
+
|
| 1246 |
+
# --- Prime change handler ---
|
| 1247 |
+
def p_change_and_store(p_str):
    """Handle a change of the modulus dropdown.

    Parses the selected modulus and places it in front of the component
    updates produced by ``on_p_change``; the event wiring sends that first
    value to ``current_p`` and the remaining ones to ``all_outputs``.

    Args:
        p_str: The dropdown value, a string holding an integer modulus.

    Returns:
        A list whose first element is the integer modulus, followed by every
        update returned by ``on_p_change(p_str)``.
    """
    p = int(p_str)
    # Unpack instead of concatenating so that any iterable of updates
    # (list or tuple) returned by on_p_change is accepted.
    return [p, *on_p_change(p_str)]
|
| 1251 |
+
|
| 1252 |
+
p_dropdown.change(
|
| 1253 |
+
fn=p_change_and_store,
|
| 1254 |
+
inputs=[p_dropdown],
|
| 1255 |
+
outputs=[current_p] + all_outputs,
|
| 1256 |
+
)
|
| 1257 |
+
|
| 1258 |
+
# --- Neuron dropdown handler ---
|
| 1259 |
+
def on_neuron_change(neuron_key, p):
    """Redraw the neuron spectrum chart for the neuron picked in the dropdown.

    Args:
        neuron_key: The value selected in the neuron dropdown.
        p: The currently selected modulus.

    Returns:
        The result of ``make_neuron_spectrum_chart``, or ``None`` when
        ``load_json_file`` returns ``None`` or no neuron is selected.
    """
    data = load_json_file(p, "neuron_spectra.json")
    # Same early exit as the other dropdown/slider handlers: without data or
    # a selection there is nothing to plot.
    if data is None or neuron_key is None:
        return None
    return make_neuron_spectrum_chart(data, neuron_key)
|
| 1262 |
+
|
| 1263 |
+
t2_neuron_dd.change(
|
| 1264 |
+
fn=on_neuron_change,
|
| 1265 |
+
inputs=[t2_neuron_dd, current_p],
|
| 1266 |
+
outputs=[t2_neuron_chart],
|
| 1267 |
+
)
|
| 1268 |
+
|
| 1269 |
+
# --- Logit pair dropdown handler ---
|
| 1270 |
+
def on_pair_change(pair_str, p):
    """Build the logit bar chart for the pair label picked in the dropdown.

    Returns ``None`` when ``load_json_file`` yields no data for ``p`` or no
    pair is selected. A label that matches none of the stored pairs falls
    back to the first pair (index 0).
    """
    logits = load_json_file(p, "logits_interactive.json")
    if logits is None or not pair_str:
        return None
    labels = [f"({a}, {b})" for a, b in logits.get('pairs', [])]
    try:
        selected = labels.index(pair_str)
    except ValueError:
        # Unknown label: default to the first pair.
        selected = 0
    return make_logit_bar_chart(logits, selected)
|
| 1281 |
+
|
| 1282 |
+
t4_pair_dd.change(
|
| 1283 |
+
fn=on_pair_change,
|
| 1284 |
+
inputs=[t4_pair_dd, current_p],
|
| 1285 |
+
outputs=[t4_logit_chart],
|
| 1286 |
+
)
|
| 1287 |
+
|
| 1288 |
+
# --- Grokking slider handler ---
|
| 1289 |
+
def on_grokk_slider(slider_val, p):
    """Update the grokking heatmap and its epoch label for a slider position.

    Args:
        slider_val: The slider value; converted to an integer index.
        p: The currently selected modulus.

    Returns:
        A ``(heatmap, label)`` pair. It is ``(None, "")`` when
        ``load_json_file`` returns ``None`` or the slider has no value. The
        label is empty when the index does not address an entry of the
        ``'epochs'`` list.
    """
    data = load_json_file(p, "grokk_epoch_data.json")
    if data is None or slider_val is None:
        return None, ""
    idx = int(slider_val)
    epochs = data.get('epochs', [])
    # Require 0 <= idx as well: a negative index would otherwise pass the
    # upper-bound check and silently label an epoch counted from the end.
    label = f"**Epoch: {epochs[idx]}**" if 0 <= idx < len(epochs) else ""
    return make_grokk_heatmap(data, idx), label
|
| 1297 |
+
|
| 1298 |
+
t6_slider.change(
|
| 1299 |
+
fn=on_grokk_slider,
|
| 1300 |
+
inputs=[t6_slider, current_p],
|
| 1301 |
+
outputs=[t6_heatmap, t6_epoch_label],
|
| 1302 |
+
)
|
| 1303 |
+
|
| 1304 |
+
# --- Training log run dropdown handler ---
|
| 1305 |
+
def on_log_run_change(run_name, p):
    """Return the config markdown and raw log text for the chosen run.

    Both values are empty strings when ``load_json_file`` returns ``None``
    for ``p`` or the loaded log has no entry named ``run_name``.
    """
    logs = load_json_file(p, "training_log.json")
    if logs is not None and run_name in logs:
        entry = logs[run_name]
        config_md = _format_config_md(run_name, entry.get('config', {}))
        return config_md, entry.get('log_text', 'No log available.')
    return "", ""
|
| 1314 |
+
|
| 1315 |
+
t9_run_dd.change(
|
| 1316 |
+
fn=on_log_run_change,
|
| 1317 |
+
inputs=[t9_run_dd, current_p],
|
| 1318 |
+
outputs=[t9_config_md, t9_log_text],
|
| 1319 |
+
)
|
| 1320 |
+
|
| 1321 |
+
# --- On-demand training handler (streaming) ---
|
| 1322 |
+
def on_generate_click(new_p):
    """Run the on-demand pipeline for a new modulus, streaming progress.

    Generator handler. Every yield is a 4-tuple of updates matching the
    click wiring: ``(p_dropdown, current_p, generate_status, generate_log)``.

    Args:
        new_p: The value of the modulus input, or ``None`` when it is empty.

    Yields:
        A prompt to enter a value when ``new_p`` is ``None``; otherwise a
        "running" status with the growing log after each streamed line, and
        a final error or success tuple. On success the dropdown choices are
        refreshed from ``get_available_moduli()`` and the new modulus is
        selected.
    """
    from collections import deque

    if new_p is None:
        yield (
            gr.update(), gr.update(),
            "Enter a value for p.",
            gr.update(visible=False, value=""),
        )
        return

    p = int(new_p)
    running_status = f"**Running pipeline for p={p}...**"
    # Bounded buffer: only the most recent 200 lines are ever displayed, so
    # older lines are dropped instead of accumulating for the whole run.
    log_lines = deque(maxlen=200)
    yield (
        gr.update(), gr.update(),
        running_status,
        gr.update(visible=True, value="Starting...\n"),
    )

    for line, is_err, is_done in run_pipeline_for_p_streaming(p):
        log_lines.append(line)
        display = "\n".join(log_lines)
        if is_err:
            yield (
                gr.update(), gr.update(),
                f"**Error:** {line}",
                gr.update(value=display),
            )
            return
        if is_done:
            new_moduli = get_available_moduli()
            new_choices = [str(v) for v in new_moduli]
            yield (
                gr.update(choices=new_choices, value=str(p)),
                gr.update(),
                f"**Success:** {line}",
                gr.update(value=display),
            )
            return
        yield (
            gr.update(), gr.update(),
            running_status,
            gr.update(value=display),
        )
|
| 1363 |
+
|
| 1364 |
+
generate_btn.click(
|
| 1365 |
+
fn=on_generate_click,
|
| 1366 |
+
inputs=[new_p_input],
|
| 1367 |
+
outputs=[p_dropdown, current_p, generate_status, generate_log],
|
| 1368 |
+
)
|
| 1369 |
+
|
| 1370 |
+
return app
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
if __name__ == "__main__":
    # Script entry point: build the app, then serve it with the Soft theme
    # and the custom stylesheet.
    app = create_app()
    launch_options = {"theme": gr.themes.Soft(), "css": CUSTOM_CSS}
    app.launch(**launch_options)
|
hf_app/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=6.0
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
| 4 |
+
numpy>=1.24
|
| 5 |
+
matplotlib>=3.7
|
| 6 |
+
Pillow>=9.0
|
| 7 |
+
plotly>=5.0
|
| 8 |
+
einops>=0.6
|
| 9 |
+
scipy>=1.10
|
precompute/README.md
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pre-computation Pipeline
|
| 2 |
+
|
| 3 |
+
Batch training and plot generation for all odd moduli $p$ in [3, 199]. Trains 5 model configurations per $p$ and generates publication-quality figures plus interactive JSON data files covering the paper's core results.
|
| 4 |
+
|
| 5 |
+
All commands are run from the **project root directory**.
|
| 6 |
+
|
| 7 |
+
## Quick Start (Shell Script)
|
| 8 |
+
|
| 9 |
+
The easiest way to run the full pipeline for a single modulus:
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
# Run the complete pipeline for p=23
|
| 13 |
+
bash precompute/run_pipeline.sh 23
|
| 14 |
+
|
| 15 |
+
# Or using an environment variable
|
| 16 |
+
P=23 bash precompute/run_pipeline.sh
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
This runs training, plot generation, analytical simulation, and verification in sequence.
|
| 20 |
+
|
| 21 |
+
## Complete Pipeline (Single Modulus, Manual Steps)
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
# Step 1: Train all 5 model configurations
|
| 25 |
+
python precompute/train_all.py --p 23 --output ./trained_models
|
| 26 |
+
|
| 27 |
+
# Step 2: Generate model-based plots (21 PNGs + 6 JSONs + metadata)
|
| 28 |
+
python precompute/generate_plots.py --p 23 --input ./trained_models --output ./precomputed_results
|
| 29 |
+
|
| 30 |
+
# Step 3: Generate analytical simulation plots (2 PNGs, no model needed)
|
| 31 |
+
python precompute/generate_analytical.py --p 23 --output ./precomputed_results
|
| 32 |
+
|
| 33 |
+
# Step 4: Verify
|
| 34 |
+
ls precomputed_results/p_023/
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
## Complete Pipeline (All Odd p)
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
# Train everything (5 configurations per modulus; grokking is skipped for p < 19). Use --resume to skip completed runs.
|
| 41 |
+
python precompute/train_all.py --all --output ./trained_models --resume
|
| 42 |
+
|
| 43 |
+
# Generate all plots
|
| 44 |
+
python precompute/generate_plots.py --all --input ./trained_models --output ./precomputed_results
|
| 45 |
+
|
| 46 |
+
# Generate all analytical plots
|
| 47 |
+
python precompute/generate_analytical.py --all --output ./precomputed_results
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## The 5 Model Configurations
|
| 53 |
+
|
| 54 |
+
Each modulus is trained with 5 configurations that correspond to different sections of the paper:
|
| 55 |
+
|
| 56 |
+
### 1. Standard Training (`standard`)
|
| 57 |
+
|
| 58 |
+
The baseline experiment for Parts I--II (Mechanism & Dynamics). Demonstrates Fourier feature learning: neurons decompose modular addition into sparse frequency components with phase alignment (ψ ≈ 2φ).
|
| 59 |
+
|
| 60 |
+
| Parameter | Value |
|
| 61 |
+
|-----------|-------|
|
| 62 |
+
| Activation | ReLU |
|
| 63 |
+
| Initialization | random |
|
| 64 |
+
| Optimizer | AdamW |
|
| 65 |
+
| Learning rate | 5e-5 |
|
| 66 |
+
| Weight decay | 0 |
|
| 67 |
+
| Train fraction | 1.0 (all p² pairs) |
|
| 68 |
+
| Epochs | 5,000 |
|
| 69 |
+
| Init scale | 0.1 |
|
| 70 |
+
|
| 71 |
+
**Used by:** Tab 1 (Overview), Tab 2 (Fourier Weights), Tab 3 (Phase Analysis), Tab 4 (Output Logits)
|
| 72 |
+
|
| 73 |
+
### 2. Grokking (`grokking`)
|
| 74 |
+
|
| 75 |
+
Reproduces the grokking phenomenon (Part III). The model memorizes training data first, then suddenly generalizes. Requires partial training data + weight decay.
|
| 76 |
+
|
| 77 |
+
| Parameter | Value |
|
| 78 |
+
|-----------|-------|
|
| 79 |
+
| Activation | ReLU |
|
| 80 |
+
| Initialization | random |
|
| 81 |
+
| Optimizer | AdamW |
|
| 82 |
+
| Learning rate | 1e-4 |
|
| 83 |
+
| Weight decay | **2.0** |
|
| 84 |
+
| Train fraction | **0.75** |
|
| 85 |
+
| Epochs | **50,000** |
|
| 86 |
+
| Init scale | 0.1 |
|
| 87 |
+
|
| 88 |
+
**Used by:** Tab 1 (Overview, grokking curves), Tab 6 (Grokking)
|
| 89 |
+
**Note:** Only runs for p ≥ 19 (smaller $p$ have too few test points for meaningful grokking).
|
| 90 |
+
|
| 91 |
+
### 3. Quadratic Activation (`quad_random`)
|
| 92 |
+
|
| 93 |
+
Uses σ(x) = x² activation. The quadratic nonlinearity directly implements the frequency factorization mechanism from the theory, enabling clean analysis of the lottery ticket mechanism.
|
| 94 |
+
|
| 95 |
+
| Parameter | Value |
|
| 96 |
+
|-----------|-------|
|
| 97 |
+
| Activation | **Quad** |
|
| 98 |
+
| Initialization | random |
|
| 99 |
+
| Optimizer | AdamW |
|
| 100 |
+
| Learning rate | 5e-5 |
|
| 101 |
+
| Weight decay | 0 |
|
| 102 |
+
| Train fraction | 1.0 |
|
| 103 |
+
| Epochs | 5,000 |
|
| 104 |
+
| Init scale | 0.1 |
|
| 105 |
+
|
| 106 |
+
**Used by:** Tab 5 (Lottery Mechanism)
|
| 107 |
+
|
| 108 |
+
### 4. Single-Frequency Quad (`quad_single_freq`)
|
| 109 |
+
|
| 110 |
+
Initializes neurons at specific frequencies to study gradient dynamics under controlled conditions. Validates the phase alignment theorem and single-frequency preservation theorem.
|
| 111 |
+
|
| 112 |
+
| Parameter | Value |
|
| 113 |
+
|-----------|-------|
|
| 114 |
+
| Activation | **Quad** |
|
| 115 |
+
| Initialization | **single-freq** |
|
| 116 |
+
| Optimizer | **SGD** |
|
| 117 |
+
| Learning rate | **0.1** |
|
| 118 |
+
| Weight decay | 0 |
|
| 119 |
+
| Train fraction | 1.0 |
|
| 120 |
+
| Epochs | 5,000 |
|
| 121 |
+
| Init scale | **0.02** |
|
| 122 |
+
|
| 123 |
+
**Used by:** Tab 7 (Gradient Dynamics, quadratic panels)
|
| 124 |
+
|
| 125 |
+
### 5. Single-Frequency ReLU (`relu_single_freq`)
|
| 126 |
+
|
| 127 |
+
Same as above but with ReLU activation. Shows that the theoretical results (proved for quadratic) hold approximately for ReLU with small harmonic leakage.
|
| 128 |
+
|
| 129 |
+
| Parameter | Value |
|
| 130 |
+
|-----------|-------|
|
| 131 |
+
| Activation | **ReLU** |
|
| 132 |
+
| Initialization | **single-freq** |
|
| 133 |
+
| Optimizer | **SGD** |
|
| 134 |
+
| Learning rate | **0.01** |
|
| 135 |
+
| Weight decay | 0 |
|
| 136 |
+
| Train fraction | 1.0 |
|
| 137 |
+
| Epochs | 5,000 |
|
| 138 |
+
| Init scale | **0.002** |
|
| 139 |
+
|
| 140 |
+
**Used by:** Tab 7 (Gradient Dynamics, ReLU panels)
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
## Neuron Sizing
|
| 145 |
+
|
| 146 |
+
The number of hidden neurons scales with $p$ to maintain the ratio from the baseline experiment ($p=23$, $d_\text{mlp}=512$):
|
| 147 |
+
|
| 148 |
+
```
|
| 149 |
+
d_mlp = max(512, ceil(512/529 * p²))
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
Examples: $p=3 \to 512$, $p=23 \to 512$, $p=53 \to 2719$, $p=97 \to 9107$, $p=199 \to 38329$.
|
| 153 |
+
|
| 154 |
+
---
|
| 155 |
+
|
| 156 |
+
## Blog Figure → Pipeline Output Mapping
|
| 157 |
+
|
| 158 |
+
The table below maps every figure in the blog post to the corresponding file generated by the pipeline. Each figure is reproduced for every $p$, allowing users to verify the paper's claims across different moduli.
|
| 159 |
+
|
| 160 |
+
### Part I: Mechanism (Tabs 2--4, standard run)
|
| 161 |
+
|
| 162 |
+
| Blog Figure | Description | Pipeline Output | Tab | Verified? |
|
| 163 |
+
|-------------|-------------|----------------|-----|-----------|
|
| 164 |
+
| **Fig. 2** — Fourier sparsity of learned weights | DFT heatmap: each row is a neuron, each column is a Fourier mode (cos k, sin k). Sparse = one bright cell per row, confirming single-frequency specialization. | `pXXX_full_training_para_origin.png` | 2 | The heatmap applies `W_in @ fourier_basis.T` and `W_out.T @ fourier_basis.T` to show DFT coefficients. X-axis labels are Fourier mode names (Const, cos 1, sin 1, ...). Sparsity is visible as one dominant pair per neuron row. |
|
| 165 |
+
| **Fig. 3** — Cosine fits to individual neurons | Raw learned weight values (dots) vs. best-fit cosine (dashed) for 3 representative neurons. Left: input weights θ_m. Right: output weights ξ_m. | `pXXX_lineplot_in.png`, `pXXX_lineplot_out.png` | 2 | Projects raw weights into Fourier space, keeps top-2 components, projects back. The fit quality demonstrates that each neuron is well-described by a single cosine. |
|
| 166 |
+
| **Fig. 4** — Phase alignment ψ = 2φ | Scatter plot of (2φ_m mod 2π) vs (ψ_m mod 2π). All points lie on the diagonal y = x. | `pXXX_phase_relationship.png` | 3 | Computed via `compute_neuron()` for every neuron. The diagonal pattern is Observation 2 from the paper. |
|
| 167 |
+
| **Fig. 5** — Higher-order phase symmetry | Polar plot: phase angles ι·φ_m on concentric rings for ι = 1, 2, 3, 4. Uniform spread confirms the cancellation condition Σ exp(i·ι·φ_m) ≈ 0. | `pXXX_phase_distribution.png` | 3 | Shows phases for the most common frequency group. For large p with many neurons, the uniform spread is clearly visible. |
|
| 168 |
+
| **Fig. 6** — Magnitude homogeneity | Violin plots of α_m (input) and β_m (output) across all neurons. Tight concentration confirms magnitude homogeneity (Observation 3c). | `pXXX_magnitude_distribution.png` | 3 | Uses `compute_neuron()` to extract scale for every neuron. |
|
| 169 |
+
| **Fig. 7** — Output logits (flawed indicator) | Heatmap of f(x,y)[j] for x=0. Bright red diagonal at j=(x+y) mod p (correct answer, coefficient p/2). Faint pink at j=2x mod p and j=2y mod p (spurious peaks, coefficient p/4). | `pXXX_output_logits.png` | 4 | Forward pass through the trained model with **matching activation** (ReLU for standard run). Rectangles highlight the correct answer and spurious peak positions. |
|
| 170 |
+
|
| 171 |
+
### Part II: Dynamics (Tabs 5, 7, 8)
|
| 172 |
+
|
| 173 |
+
| Blog Figure | Description | Pipeline Output | Tab | Verified? |
|
| 174 |
+
|-------------|-------------|----------------|-----|-----------|
|
| 175 |
+
| **Fig. 8** — Phase alignment dynamics | Phase trajectories (φ, ψ, 2φ) and magnitude growth (α, β) over training. Left: Quad activation. Right: ReLU. Shows ψ → 2φ convergence. | `pXXX_phase_align_quad.png`, `pXXX_phase_align_relu.png` | 7 | Tracks the neuron with largest final scale across all checkpoints. Shows phases converging and magnitudes growing. |
|
| 176 |
+
| **Fig. 9** — Lottery ticket race | Left: phase misalignment D_m^k(t) for all frequencies within one neuron. The winner (smallest initial D) converges first. Right: magnitude β_m^k(t). Winner grows explosively. | `pXXX_lottery_mech_phase.png`, `pXXX_lottery_mech_magnitude.png` | 5 | Tracks all frequency components of a single neuron via `decode_scales_phis()` across checkpoints from the `quad_random` run. The winning frequency (highlighted in red) has the smallest initial misalignment. |
|
| 177 |
+
| **Fig. 10** — Lottery outcome contour | Final magnitude β as a function of (initial magnitude, initial phase difference 2φ₀). Largest values at small D, symmetric about D = π. | `pXXX_lottery_beta_contour.png` | 5 | Simulates gradient flow on a 30×30 grid of initial conditions. Each point runs 100 steps of the analytical ODE. |
|
| 178 |
+
| **Fig. 11** — Single-frequency preservation (Quad) | DFT heatmap at multiple training timepoints. The initialized frequency retains all energy; no cross-frequency leakage. | `pXXX_single_freq_quad.png` | 7 | Shows DFT of weights at 3 timepoints (step 0, mid, final) for the `quad_single_freq` run. Each column is a Fourier mode; sparsity confirms preservation. |
|
| 179 |
+
| **Fig. 12a** — Single-frequency preservation (ReLU) | Same as Fig. 11 but with ReLU. Small harmonic leakage visible at 3k*, 5k* (input) and 2k*, 3k* (output), decaying as O(r⁻²). | `pXXX_single_freq_relu.png` | 7 | Shows DFT at 2 timepoints (step 0, final) for the `relu_single_freq` run. Dominant frequency overwhelms harmonics. |
|
| 180 |
+
| **Fig. 12b** — Phase alignment under ReLU | Phase and magnitude trajectories for ReLU single-frequency init. Same zero-attractor behavior as Quad. | `pXXX_phase_align_relu.png` | 7 | Same as Fig. 8 right panel. |
|
| 181 |
+
| — Decoupled ODE simulation | Pure ODE integration (no neural network) showing phase convergence and magnitude competition for all frequencies within one neuron. Two cases with different initial conditions. | `pXXX_phase_align_approx1.png`, `pXXX_phase_align_approx2.png` | 8 | Generated by `generate_analytical.py`, not `generate_plots.py`. Validates the theory without any training. |
|
| 182 |
+
|
| 183 |
+
### Part III: Grokking (Tab 6)
|
| 184 |
+
|
| 185 |
+
| Blog Figure | Description | Pipeline Output | Tab | Verified? |
|
| 186 |
+
|-------------|-------------|----------------|-----|-----------|
|
| 187 |
+
| **Fig. 13a** — Grokking loss curves | Training and test loss over 50k epochs. Three stages: (I) train loss drops, (II) test loss drops, (III) both near zero. Stage boundaries marked. | `pXXX_grokk_loss.json` → interactive Plotly chart | 6 | From `training_curves.json`. Stage boundaries detected by `grokking_stage_detector.py`. Shaded regions distinguish the three stages. |
|
| 188 |
+
| **Fig. 13b** — Grokking accuracy curves | Training and test accuracy. Train → 100% in Stage I, test jumps in Stage II. | `pXXX_grokk_acc.json` → interactive Plotly chart | 6 | Computed by running forward pass on train/test data at each checkpoint. |
|
| 189 |
+
| **Fig. 13c** — Phase alignment progress | Average |sin(D_m*)| over training. Decreases throughout, steepest in Stage II. | `pXXX_grokk_abs_phase_diff.png` | 6 | Computed via `decode_weights()` + `compute_neuron()` at each grokking checkpoint. |
|
| 190 |
+
| **Fig. 13d** — IPR and parameter norm | Dual-axis: IPR (Fourier sparsity) increases sharply in Stage II; parameter norm shrinks in Stage III. | `pXXX_grokk_avg_ipr.png` | 6 | IPR uses the corrected per-frequency magnitude formula: A_k = sqrt(c_k² + s_k²), IPR = Σ A_k⁴ / (Σ A_k²)². Parameter norms from `training_curves.json`. |
|
| 191 |
+
| **Fig. 14** — Memorization accuracy heatmap | Three panels at end of Stage I: (1) training data distribution under symmetry, (2) accuracy grid, (3) softmax probability at ground truth. Red rectangles = true test pairs. | `pXXX_grokk_memorization_accuracy.png` | 6 | Forward pass at the checkpoint closest to stage1_end. The symmetric architecture guarantees ~70% test accuracy during memorization. |
|
| 192 |
+
| **Fig. 15** — Common-to-rare memorization | Four panels: training data distribution + accuracy at 3 timepoints during Stage I. Shows common pairs (both (i,j) and (j,i) in train) memorized before rare pairs (only one ordering). | `pXXX_grokk_memorization_common_to_rare.png` | 6 | Epochs selected at 0, stage1_end/2, stage1_end. Red rectangles mark asymmetric training pairs. |
|
| 193 |
+
| **Fig. 16** — Weight evolution during grokking | 2×3 DFT heatmap grid showing θ_m and ξ_m at Step 0 (random init), end of Stage I (noisy multi-frequency), and end of Stage II (clean single-frequency). | `pXXX_grokk_decoded_weights_dynamic.png` | 6 | DFT coefficients `W @ fourier_basis.T` at 3 key epochs. The transition from diffuse to sparse confirms the sparsification narrative. |
|
| 194 |
+
|
| 195 |
+
### Overview (Tab 1)
|
| 196 |
+
|
| 197 |
+
| Blog Figure | Description | Pipeline Output | Tab | Verified? |
|
| 198 |
+
|-------------|-------------|----------------|-----|-----------|
|
| 199 |
+
| — Overview dashboard | 2×2 grid: standard loss + grokking loss (top), standard IPR + grokking IPR (bottom). Plus phase scatter from standard final checkpoint. | `pXXX_overview_loss_ipr.png`, `pXXX_overview_phase_scatter.png`, `pXXX_overview.json` | 1 | Combines data from standard and grokking runs. The phase scatter uses the same computation as Fig. 4. For $p < 19$, only the standard column is shown. |
|
| 200 |
+
|
| 201 |
+
### Not Currently Generated (Blog-Only Figures)
|
| 202 |
+
|
| 203 |
+
| Blog Figure | Description | Status |
|
| 204 |
+
|-------------|-------------|--------|
|
| 205 |
+
| **Fig. 1** — Architecture illustration | Schematic of the two-layer network, DFT decomposition, and mechanism overview. | Static illustration, not dependent on p. Could be included as a fixed image in the app. |
|
| 206 |
+
|
| 207 |
+
---
|
| 208 |
+
|
| 209 |
+
## Interactive JSON Data Files
|
| 210 |
+
|
| 211 |
+
In addition to static PNG plots, the pipeline generates JSON files for interactive Plotly charts in the Gradio app:
|
| 212 |
+
|
| 213 |
+
| File | Content | Used In |
|
| 214 |
+
|------|---------|---------|
|
| 215 |
+
| `pXXX_overview.json` | Standard loss/IPR + grokking loss/IPR time series | Tab 1: Interactive loss and IPR charts |
|
| 216 |
+
| `pXXX_neuron_spectra.json` | Per-neuron Fourier magnitudes (W_in and W_out) for top-20 neurons, sorted by frequency | Tab 2: Neuron Inspector dropdown → bar chart of frequency decomposition |
|
| 217 |
+
| `pXXX_logits_interactive.json` | Output logits for p representative (a,b) pairs, plus correct answers | Tab 4: Logit Explorer dropdown → bar chart with correct answer highlighted |
|
| 218 |
+
| `pXXX_grokk_loss.json` | Full training/test loss curves + stage boundaries | Tab 6: Interactive loss chart with stage shading |
|
| 219 |
+
| `pXXX_grokk_acc.json` | Accuracy at each checkpoint epoch + stage boundaries | Tab 6: Interactive accuracy chart with stage shading |
|
| 220 |
+
| `pXXX_grokk_epoch_data.json` | p×p accuracy grids at ~10 evenly-spaced grokking epochs | Tab 6: Epoch Slider → heatmap animation across training |
|
| 221 |
+
| `pXXX_metadata.json` | Config for all 5 runs + final metrics (loss, accuracy) | Displayed in the app's info panel for the selected $p$ |
|
| 222 |
+
|
| 223 |
+
---
|
| 224 |
+
|
| 225 |
+
## Output Structure
|
| 226 |
+
|
| 227 |
+
All plots for a modulus are saved in a single flat directory. Each file is prefixed with `pXXX_` so the folder is self-contained and browsable:
|
| 228 |
+
|
| 229 |
+
```
|
| 230 |
+
precomputed_results/p_023/
|
| 231 |
+
# Metadata
|
| 232 |
+
p023_metadata.json
|
| 233 |
+
|
| 234 |
+
# Tab 1: Overview (Blog: summary of standard + grokking)
|
| 235 |
+
p023_overview_loss_ipr.png # 2×2 grid: loss + IPR for both setups
|
| 236 |
+
p023_overview_phase_scatter.png # Phase alignment scatter (same as Fig. 4)
|
| 237 |
+
p023_overview.json # Interactive data
|
| 238 |
+
|
| 239 |
+
# Tab 2: Fourier Weights (Blog: Figures 2, 3)
|
| 240 |
+
p023_full_training_para_origin.png # DFT heatmap (Fig. 2)
|
| 241 |
+
p023_lineplot_in.png # Cosine fits, input layer (Fig. 3 left)
|
| 242 |
+
p023_lineplot_out.png # Cosine fits, output layer (Fig. 3 right)
|
| 243 |
+
p023_neuron_spectra.json # Interactive: neuron inspector
|
| 244 |
+
|
| 245 |
+
# Tab 3: Phase Analysis (Blog: Figures 4, 5, 6)
|
| 246 |
+
p023_phase_distribution.png # Polar phase plot (Fig. 5)
|
| 247 |
+
p023_phase_relationship.png # 2φ vs ψ scatter (Fig. 4)
|
| 248 |
+
p023_magnitude_distribution.png # Violin plots (Fig. 6)
|
| 249 |
+
|
| 250 |
+
# Tab 4: Output Logits (Blog: Figure 7)
|
| 251 |
+
p023_output_logits.png # Logit heatmap (Fig. 7)
|
| 252 |
+
p023_logits_interactive.json # Interactive: logit explorer
|
| 253 |
+
|
| 254 |
+
# Tab 5: Lottery Mechanism (Blog: Figures 9, 10)
|
| 255 |
+
p023_lottery_mech_magnitude.png # Magnitude race (Fig. 9 right)
|
| 256 |
+
p023_lottery_mech_phase.png # Phase misalignment race (Fig. 9 left)
|
| 257 |
+
p023_lottery_beta_contour.png # Contour plot (Fig. 10)
|
| 258 |
+
|
| 259 |
+
# Tab 6: Grokking (Blog: Figures 13, 14, 15, 16)
|
| 260 |
+
p023_grokk_loss.json # Interactive loss curves (Fig. 13a)
|
| 261 |
+
p023_grokk_acc.json # Interactive accuracy curves (Fig. 13b)
|
| 262 |
+
p023_grokk_abs_phase_diff.png # Phase alignment progress (Fig. 13c)
|
| 263 |
+
p023_grokk_avg_ipr.png # IPR + param norms (Fig. 13d)
|
| 264 |
+
p023_grokk_memorization_accuracy.png # 3-panel heatmap (Fig. 14)
|
| 265 |
+
p023_grokk_memorization_common_to_rare.png # 4-panel sequence (Fig. 15)
|
| 266 |
+
p023_grokk_decoded_weights_dynamic.png # DFT evolution (Fig. 16)
|
| 267 |
+
p023_grokk_epoch_data.json # Interactive: epoch slider
|
| 268 |
+
|
| 269 |
+
# Tab 7: Gradient Dynamics (Blog: Figures 8, 11, 12)
|
| 270 |
+
p023_phase_align_quad.png # Phase + magnitude, Quad (Fig. 8 left)
|
| 271 |
+
p023_single_freq_quad.png # DFT heatmap over time, Quad (Fig. 11)
|
| 272 |
+
p023_phase_align_relu.png # Phase + magnitude, ReLU (Fig. 8 right / 12b)
|
| 273 |
+
p023_single_freq_relu.png # DFT heatmap over time, ReLU (Fig. 12a)
|
| 274 |
+
|
| 275 |
+
# Tab 8: Decoupled Simulation (no blog figure number)
|
| 276 |
+
p023_phase_align_approx1.png # ODE simulation case 1
|
| 277 |
+
p023_phase_align_approx2.png # ODE simulation case 2
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
**30 files per $p$:** 21 PNGs + 6 interactive JSONs + 1 metadata JSON from trained models, plus 2 PNGs from analytical simulation.
|
| 281 |
+
|
| 282 |
+
---
|
| 283 |
+
|
| 284 |
+
## Correctness Verification
|
| 285 |
+
|
| 286 |
+
### How each computation matches the paper
|
| 287 |
+
|
| 288 |
+
1. **DFT Decomposition (Figs. 2, 11, 12a, 16):** We compute `W @ fourier_basis.T` where `fourier_basis` is the orthonormal DFT basis from `get_fourier_basis(p)`. The basis has rows: [Const, cos 1, sin 1, cos 2, sin 2, ..., cos K, sin K] with K = (p-1)/2 for odd $p$. Each row is L2-normalized. This matches the standard real DFT on Z_p.
|
| 289 |
+
|
| 290 |
+
2. **Phase extraction (Figs. 4, 8, 9, 13c):** For frequency k, the DFT gives coefficients (c_k, s_k) at indices (2k-1, 2k). The magnitude is α = sqrt(2/p) · sqrt(c_k² + s_k²), and the phase is φ = arctan2(-s_k, c_k). This convention matches the paper's θ_m[j] = α cos(ω_k j + φ) representation.
|
| 291 |
+
|
| 292 |
+
3. **IPR (Figs. 13d, Overview):** Uses the corrected per-frequency magnitude formula: A_k = sqrt(c_k² + s_k²) (combining cos/sin pairs), then IPR = Σ A_k⁴ / (Σ A_k²)². This gives IPR → 1 for perfect single-frequency neurons, matching the paper's definition.
|
| 293 |
+
|
| 294 |
+
4. **Phase alignment (Fig. 4):** The doubled-phase relationship ψ_m = 2φ_m is verified by extracting φ from W_in and ψ from W_out using the same `compute_neuron()` function, then plotting (2φ mod 2π) vs (ψ mod 2π).
|
| 295 |
+
|
| 296 |
+
5. **Output logits (Fig. 7):** Forward pass uses the **same activation function** as training (ReLU for standard run). The flawed indicator structure (main diagonal + two ghost diagonals) is visible because the standard run trains to 100% accuracy with clean Fourier features.
|
| 297 |
+
|
| 298 |
+
6. **Lottery mechanism (Figs. 9, 10):** Uses the `quad_random` run (quadratic activation, random init) which matches the theoretical setting. `decode_scales_phis()` extracts per-frequency magnitudes and phases at each checkpoint. The winning frequency is the one with smallest initial |D| = |2φ - ψ|.
|
| 299 |
+
|
| 300 |
+
7. **Grokking stages (Figs. 13--16):** `grokking_stage_detector.py` identifies stage boundaries from training curves. Stage I ends when train accuracy ≈ 1.0, Stage II ends when test accuracy ≈ 1.0. Memorization heatmaps use forward pass at the closest checkpoint to stage1_end.
|
| 301 |
+
|
| 302 |
+
8. **Analytical simulation (Tab 8):** Numerically integrates the four-variable ODE system from Section 5.3 of the paper. No neural network is involved — this validates the theoretical dynamics directly.
|
| 303 |
+
|
| 304 |
+
### Why results generalize across $p$
|
| 305 |
+
|
| 306 |
+
The paper's theory is stated for general odd $p$. Key properties that scale:
|
| 307 |
+
|
| 308 |
+
- **Fourier basis:** Always has (p-1)/2 non-DC frequencies for any odd $p$.
|
| 309 |
+
- **Phase alignment:** The ψ = 2φ relationship is a consequence of the gradient dynamics, independent of p.
|
| 310 |
+
- **Lottery mechanism:** Random initial misalignments are uniform on [0, 2π) for any p.
|
| 311 |
+
- **Grokking three stages:** The stage structure depends on the balance of loss gradient vs. weight decay, not on p specifically (though the stage durations and test accuracy during memorization may vary).
|
| 312 |
+
- **Network width:** d_mlp scales as O(p²) to maintain the neuron-to-frequency ratio, ensuring enough neurons per frequency for diversification.
|
| 313 |
+
|
| 314 |
+
---
|
| 315 |
+
|
| 316 |
+
## Scripts
|
| 317 |
+
|
| 318 |
+
| Script | Purpose |
|
| 319 |
+
|--------|---------|
|
| 320 |
+
| `run_pipeline.sh` | Runs the complete pipeline (train + plots + analytical + verify) for a single modulus. |
|
| 321 |
+
| `train_all.py` | Trains all 5 model configurations. Saves checkpoints + `training_curves.json`. |
|
| 322 |
+
| `generate_plots.py` | Loads trained models and generates all model-dependent plots (Tabs 1--7) plus interactive JSONs and metadata. |
|
| 323 |
+
| `generate_analytical.py` | Runs gradient flow simulations to generate theory plots (Tab 8). No model needed. |
|
| 324 |
+
| `prime_config.py` | Configuration: moduli list, d_mlp formula, training run parameters. |
|
| 325 |
+
| `neuron_selector.py` | Automated neuron selection for plots (replaces hardcoded indices from notebooks). |
|
| 326 |
+
| `grokking_stage_detector.py` | Detects memorization/transition/generalization stage boundaries from training curves. |
|
| 327 |
+
|
| 328 |
+
---
|
| 329 |
+
|
| 330 |
+
## Analytical Simulations (No Model Needed)
|
| 331 |
+
|
| 332 |
+
`generate_analytical.py` produces 2 plots per $p$ by simulating gradient flow on decoupled frequency dynamics. These validate the theoretical analysis without training any model.
|
| 333 |
+
|
| 334 |
+
- **Case 1**: Shows phase difference D* converging from initial conditions (φ₀=1.5, ψ₀=0.18)
|
| 335 |
+
- **Case 2**: Different initial conditions (φ₀=-0.72, ψ₀=-2.91) showing convergence from the other side
|
| 336 |
+
|
| 337 |
+
Both cases confirm the phase alignment theorem: D = 0 is the stable attractor, while D = π is an unstable fixed point.
|
precompute/__init__.py
ADDED
|
File without changes
|
precompute/generate_analytical.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Generate "Decoupled Simulation" plots -- analytical gradient flow simulations
|
| 4 |
+
that don't require trained models.
|
| 5 |
+
|
| 6 |
+
Produces 2 plots per p, saved to {output_dir}/p_{p:03d}/:
|
| 7 |
+
1. p{p:03d}_phase_align_approx1.png -- case 1: longer simulation with annotations
|
| 8 |
+
2. p{p:03d}_phase_align_approx2.png -- case 2: shorter simulation
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python generate_analytical.py --all
|
| 12 |
+
python generate_analytical.py --p 23
|
| 13 |
+
python generate_analytical.py --p 23 --output ./my_output
|
| 14 |
+
"""
|
| 15 |
+
import argparse
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import matplotlib
|
| 22 |
+
matplotlib.use('Agg')
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
|
| 25 |
+
# Add project root to path so we can import src modules
|
| 26 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
| 27 |
+
from mechanism_base import get_fourier_basis, normalize_to_pi
|
| 28 |
+
from prime_config import get_moduli, ANALYTICAL_CONFIGS
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Style constants
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
COLORS = ['#0D2758', '#60656F', '#DEA54B', '#A32015', '#347186']
|
| 34 |
+
DPI = 150
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ===========================================================================
|
| 38 |
+
# Decouple dynamics simulation
|
| 39 |
+
# ===========================================================================
|
| 40 |
+
|
| 41 |
+
def gradient_update(theta, xi, p, device):
    """
    Compute the sum of gradients over all frequency modes k.

    For each frequency k from 1 to (p-1)//2, project theta and xi onto the
    Fourier basis to obtain 2-coefficient vectors, then compute alpha, phi,
    beta, psi and the corresponding gradient contributions.

    Args:
        theta: 1-D tensor of length p (first parameter vector).
        xi: 1-D tensor of length p (second parameter vector).
        p: modulus; also the number of sample positions j = 0..p-1.
        device: torch device passed to ``get_fourier_basis`` and used for
            the position index tensor.

    Returns:
        Tuple ``(total_grad_theta, total_grad_xi)`` of tensors shaped like
        ``theta`` and ``xi``; each per-frequency term is scaled by 1/p**2.
    """
    fourier_basis, _ = get_fourier_basis(p, device)
    # Match the basis dtype to the inputs so the matmuls below do not mix
    # precisions (callers in this file run in float64).
    fourier_basis = fourier_basis.to(theta.dtype)
    theta_coeff = fourier_basis @ theta
    xi_coeff = fourier_basis @ xi

    total_grad_theta = torch.zeros_like(theta)
    total_grad_xi = torch.zeros_like(xi)

    j_values = torch.arange(p, device=device, dtype=theta.dtype)
    factor = np.sqrt(2.0 / p)

    # Iterate over the documented range 1..(p-1)//2. For odd p this equals
    # the previous bound p//2; for even p it skips the extra k = p/2 mode,
    # which the plotting code in this file never reads either.
    for k in range(1, (p - 1) // 2 + 1):
        # assumes basis rows 2k-1 and 2k hold the two components for
        # frequency k -- TODO confirm against get_fourier_basis
        coeff_indices = [k * 2 - 1, k * 2]
        neuron_coeff_theta = theta_coeff[coeff_indices]
        neuron_coeff_xi = xi_coeff[coeff_indices]

        # Magnitude/phase of theta at frequency k.
        alpha = factor * torch.norm(neuron_coeff_theta, dim=0)
        phi = torch.arctan2(-neuron_coeff_theta[1], neuron_coeff_theta[0])

        # Magnitude/phase of xi at frequency k.
        beta = factor * torch.norm(neuron_coeff_xi, dim=0)
        psi = torch.arctan2(-neuron_coeff_xi[1], neuron_coeff_xi[0])

        w_k = 2 * np.pi * k / p
        grad_theta_k = 2 * p * alpha * beta * torch.cos(w_k * j_values + psi - phi)
        grad_xi_k = p * alpha.pow(2) * torch.cos(w_k * j_values + 2 * phi)

        total_grad_theta += grad_theta_k / p ** 2
        total_grad_xi += grad_xi_k / p ** 2

    return total_grad_theta, total_grad_xi
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def simulate_gradient_flow(theta, xi, p, num_steps, learning_rate, device):
    """
    Integrate the coupled theta/xi dynamics with fixed-step Euler updates.

    Each step evaluates ``gradient_update`` at the current state and adds
    ``learning_rate`` times the returned gradients to theta and xi.

    Returns:
        Tuple ``(theta_history, xi_history)``: two lists of tensors, each of
        length ``num_steps + 1`` (the initial state followed by one entry
        per step).
    """
    trajectory_theta = [theta.clone()]
    trajectory_xi = [xi.clone()]
    current_theta, current_xi = theta, xi

    for _step in range(num_steps):
        d_theta, d_xi = gradient_update(current_theta, current_xi, p, device)
        current_theta = current_theta + learning_rate * d_theta
        current_xi = current_xi + learning_rate * d_xi
        trajectory_theta.append(current_theta.clone())
        trajectory_xi.append(current_xi.clone())

    return trajectory_theta, trajectory_xi
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def analyze_history(theta_history, xi_history, p, fourier_basis):
    """
    Extract time series of alpha, phi, beta, psi, delta for every frequency k.

    Args:
        theta_history: list of 1-D theta tensors, one per time step.
        xi_history: list of 1-D xi tensors, one per time step.
        p: modulus; frequencies 1..(p-1)//2 are analysed.
        fourier_basis: basis matrix applied to each snapshot; expected to
            share dtype/device with the history tensors.

    Returns:
        Dict with keys 'alphas', 'phis', 'betas', 'psis', 'deltas'. Each maps
        frequency k to a NumPy array with one value per time step, where
        delta = normalize_to_pi(2 * phi - psi).
    """
    theta_hist_tensor = torch.stack(theta_history)
    xi_hist_tensor = torch.stack(xi_history)

    # Columns of the coefficient matrices index time steps.
    theta_coeffs_hist = fourier_basis @ theta_hist_tensor.T
    xi_coeffs_hist = fourier_basis @ xi_hist_tensor.T

    def _to_numpy(tensor):
        # Plain .numpy() raises for tensors that track gradients or live on
        # a non-CPU device; detach and move first so either case works.
        return tensor.detach().cpu().numpy()

    results = {
        'alphas': {}, 'phis': {}, 'betas': {}, 'psis': {}, 'deltas': {}
    }
    factor = np.sqrt(2.0 / p)

    # Same frequency range as the plotting code, (p-1)//2; identical to the
    # previous p//2 bound for odd p.
    for k in range(1, (p - 1) // 2 + 1):
        # assumes basis rows 2k-1 and 2k hold the two components for
        # frequency k -- TODO confirm against get_fourier_basis
        idx = [k * 2 - 1, k * 2]
        neuron_theta_hist = theta_coeffs_hist[idx, :]
        neuron_xi_hist = xi_coeffs_hist[idx, :]

        alphas_k = factor * torch.norm(neuron_theta_hist, dim=0)
        phis_k = torch.atan2(-neuron_theta_hist[1, :], neuron_theta_hist[0, :])

        betas_k = factor * torch.norm(neuron_xi_hist, dim=0)
        psis_k = torch.atan2(-neuron_xi_hist[1, :], neuron_xi_hist[0, :])

        deltas_k = normalize_to_pi(2 * phis_k - psis_k)

        results['alphas'][k] = _to_numpy(alphas_k)
        results['phis'][k] = _to_numpy(phis_k)
        results['betas'][k] = _to_numpy(betas_k)
        results['psis'][k] = _to_numpy(psis_k)
        results['deltas'][k] = _to_numpy(deltas_k)

    return results
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _run_decouple_simulation(p, init_k, num_steps, lr, init_phi, init_psi,
                             amplitude, device):
    """
    Build single-frequency initial conditions, simulate, and analyse the run.

    theta and xi both start as ``amplitude * cos(w_k * j + phase)`` over
    j = 0..p-1 at frequency ``init_k``, with phases ``init_phi`` and
    ``init_psi`` respectively. Returns the dict from ``analyze_history``.
    """
    basis, _ = get_fourier_basis(p, device)
    basis = basis.to(torch.float64)
    omega = 2 * np.pi * init_k / p

    def _cosine_wave(phase):
        # Sample the cosine at every position j, then scale by amplitude.
        samples = [np.cos(omega * j + phase) for j in range(p)]
        return amplitude * torch.tensor(
            samples, device=device, dtype=torch.float64,
        )

    theta_start = _cosine_wave(init_phi)
    xi_start = _cosine_wave(init_psi)

    theta_track, xi_track = simulate_gradient_flow(
        theta_start, xi_start, p, num_steps, lr, device,
    )
    return analyze_history(theta_track, xi_track, p, basis)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _plot_decouple(results, p, num_steps, lr, init_k, save_path,
                   show_vline=True, vline_x=500):
    """
    Publication-quality 3-panel figure:
      Top:    psi_k* and 2*phi_k* vs time
      Middle: D_k* (phase difference) vs time, horizontal line at pi/2
      Bottom: alpha_k* and beta_k* vs time

    Args:
        results: dict with 'alphas', 'phis', 'betas', 'psis', 'deltas', each
            keyed by frequency k (as produced by ``analyze_history``).
        p: modulus; frequencies 1..(p-1)//2 are drawn.
        num_steps: number of simulation steps; the x axis has
            ``num_steps + 1`` points.
        lr: step size, used to convert step index to time.
        init_k: the highlighted frequency k*; all others are drawn in gray.
        save_path: output image path.
        show_vline: whether to draw a vertical marker on every panel.
        vline_x: x position (in time units) of that marker.
    """
    # Note: this changes matplotlib's global rcParams, not just this figure.
    plt.rcParams['mathtext.fontset'] = 'cm'

    # Time series for the highlighted frequency.
    alphas = np.array(results['alphas'][init_k])
    betas = np.array(results['betas'][init_k])
    deltas = np.array(results['deltas'][init_k])
    phis = np.array(results['phis'][init_k])
    psis = np.array(results['psis'][init_k])

    # Phase wrapping fix: normalize 2*phi to [-pi,pi], adjust psi to
    # stay within pi of 2*phi, then unwrap the time series so there
    # are no discontinuous jumps at +-pi boundaries.
    def _fix_phase_pair(two_phi_raw, psi_raw):
        two_phi = normalize_to_pi(two_phi_raw)
        # Copy so the in-place shifts below do not touch the caller's array.
        psi_fixed = normalize_to_pi(psi_raw).copy()
        diff = psi_fixed - two_phi
        psi_fixed[diff > np.pi] -= 2 * np.pi
        psi_fixed[diff < -np.pi] += 2 * np.pi
        return np.unwrap(two_phi), np.unwrap(psi_fixed)

    phis2_plot, psis_plot = _fix_phase_pair(2 * phis, psis)

    # Time axis: step index scaled by the step size.
    x = np.arange(num_steps + 1) * lr
    # Shared style for the vertical marker and the pi/2 reference line.
    vline_kwargs = dict(color='gray', linestyle='--', linewidth=1.5)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(7, 9), sharex=True)
    fig.suptitle(f'Decoupled Gradient Flow (p={p})', fontsize=20, y=1.01)

    # --- Top: phase alignment ---
    # Background: every non-highlighted frequency in gray (psi solid,
    # 2*phi dashed).
    for k in range(1, (p - 1) // 2 + 1):
        if k != init_k:
            bg_2phi, bg_psi = _fix_phase_pair(
                2 * np.array(results['phis'][k]),
                np.array(results['psis'][k]),
            )
            ax1.plot(x, bg_psi, lw=1.5, alpha=0.4, color='gray')
            ax1.plot(x, bg_2phi, lw=1.5, alpha=0.4, color='gray',
                     linestyle='--')
    ax1.plot(x, psis_plot, color=COLORS[3], linewidth=2.5,
             label=r"$\psi_{k^\star}$")
    ax1.plot(x, phis2_plot, linewidth=2.5, color=COLORS[0],
             label=r"$2\phi_{k^\star}$")
    if show_vline:
        ax1.axvline(x=vline_x, **vline_kwargs)
    ax1.set_title('Dynamics of Phase Alignment', fontsize=18)
    ax1.set_ylabel('Phase (radians)', fontsize=14)
    ax1.legend(fontsize=18)
    ax1.grid(True)

    # --- Middle: phase difference ---
    for k in range(1, (p - 1) // 2 + 1):
        if k != init_k:
            ax2.plot(x, np.array(results['deltas'][k]),
                     lw=1.5, alpha=0.4, color='gray')
    ax2.plot(x, deltas, color=COLORS[0], linewidth=2.5,
             label=r"$D_{k^\star}$")
    if show_vline:
        ax2.axvline(x=vline_x, **vline_kwargs)
    # Horizontal reference line at pi/2, with its label placed just below it.
    ax2.axhline(y=np.pi / 2, **vline_kwargs)
    ax2.text(x=max(x) * 0.05, y=np.pi / 2 - 0.45,
             s=r"$D^\star_{k^\star}=\pi/2$", fontsize=16, color='black')
    ax2.set_title('Dynamics of Phase Difference', fontsize=18)
    ax2.set_ylabel('Phase (radians)', fontsize=14)
    ax2.legend(fontsize=18)
    ax2.grid(True)

    # --- Bottom: magnitude evolution ---
    # Background: alpha solid, beta dashed, for non-highlighted frequencies.
    for k in range(1, (p - 1) // 2 + 1):
        if k != init_k:
            ax3.plot(x, np.array(results['alphas'][k]),
                     lw=1.5, alpha=0.4, color='gray')
            ax3.plot(x, np.array(results['betas'][k]),
                     lw=1.5, alpha=0.4, color='gray', linestyle='--')
    ax3.plot(x, alphas, linewidth=2.5, color=COLORS[0],
             label=r"$\alpha_{k^\star}$")
    ax3.plot(x, betas, linewidth=2.5, color=COLORS[3],
             label=r"$\beta_{k^\star}$")
    if show_vline:
        ax3.axvline(x=vline_x, **vline_kwargs)
    ax3.set_title('Magnitude Evolution', fontsize=18)
    ax3.set_xlabel('Time', fontsize=18)
    ax3.set_ylabel('Magnitude', fontsize=14)
    ax3.legend(fontsize=18)
    ax3.grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=DPI, bbox_inches='tight')
    # Close explicitly so figures do not accumulate across many values of p.
    plt.close(fig)
    print(f" Saved {save_path}")
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def generate_decouple_dynamics(p, output_dir):
    """
    Run both decouple-dynamics cases for modulus p and save their plots.

    Case 1 is plotted with a vertical marker at 36% of the simulated time
    span; case 2 is plotted without one. Settings come from
    ``ANALYTICAL_CONFIGS["decouple_dynamics"]``. Nothing is produced when p
    has no frequency k >= 1.
    """
    max_freq = (p - 1) // 2
    if max_freq < 1:
        print(f" SKIP: p={p} has no non-DC frequencies for analytical simulation")
        return

    cfg = ANALYTICAL_CONFIGS["decouple_dynamics"]
    device = torch.device("cpu")
    # Clamp the configured frequency to what this modulus supports.
    init_k = min(cfg["init_k"], max_freq)
    amplitude = cfg["amplitude"]

    for case in (1, 2):
        steps = cfg[f"num_steps_case{case}"]
        step_size = cfg[f"learning_rate_case{case}"]

        print(f" Running decouple dynamics case {case} (p={p}) ...")
        results = _run_decouple_simulation(
            p, init_k,
            num_steps=steps,
            lr=step_size,
            init_phi=cfg[f"init_phi_case{case}"],
            init_psi=cfg[f"init_psi_case{case}"],
            amplitude=amplitude,
            device=device,
        )

        # Only the first (longer) case carries the vertical annotation.
        if case == 1:
            marker_kwargs = {
                "show_vline": True,
                "vline_x": steps * step_size * 0.36,
            }
        else:
            marker_kwargs = {"show_vline": False}

        _plot_decouple(
            results, p,
            num_steps=steps,
            lr=step_size,
            init_k=init_k,
            save_path=os.path.join(
                output_dir, f"p{p:03d}_phase_align_approx{case}.png"),
            **marker_kwargs,
        )
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# ===========================================================================
|
| 309 |
+
# Entry point
|
| 310 |
+
# ===========================================================================
|
| 311 |
+
|
| 312 |
+
def generate_all_for_prime(p, output_base):
    """
    Produce the decoupled-simulation plots for one modulus.

    Creates ``{output_base}/p_{p:03d}`` if needed, runs the simulations with
    float64 as the torch default dtype, and restores the previous default
    dtype afterwards even if plot generation raises.
    """
    output_dir = os.path.join(output_base, f"p_{p:03d}")
    os.makedirs(output_dir, exist_ok=True)

    rule = '=' * 60
    print(f"\n{rule}")
    print(f"Generating decoupled simulation plots for p={p}")
    print(f"Output: {output_dir}")
    print(rule)

    # Simulations need double precision; remember the old default so the
    # change does not leak to other code in the same process.
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        generate_decouple_dynamics(p, output_dir)
    finally:
        torch.set_default_dtype(saved_dtype)

    print(f"[DONE] p={p}: 2 plots written to {output_dir}")
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def main():
    """
    Command-line entry point.

    Requires ``--all`` or ``--p P``. When ``--p`` is given it takes
    precedence and only that modulus is processed; otherwise every modulus
    returned by ``get_moduli()`` is processed. Plots are written under
    ``--output`` (default ``./precomputed_results``).
    """
    parser = argparse.ArgumentParser(
        description='Generate decoupled simulation plots (analytical, no model needed)'
    )
    parser.add_argument('--all', action='store_true',
                        help='Generate plots for all odd p in [3, 199]')
    parser.add_argument('--p', type=int,
                        help='Generate plots for a specific p')
    parser.add_argument('--output', type=str, default='./precomputed_results',
                        help='Base output directory (default: ./precomputed_results)')
    args = parser.parse_args()

    if not args.all and args.p is None:
        parser.error("Specify --all or --p P")

    # Test against None, not truthiness: with a truthiness test `--p 0`
    # would be treated as "no --p given" and silently run every modulus.
    moduli = [args.p] if args.p is not None else get_moduli()

    for p in moduli:
        generate_all_for_prime(p, args.output)

    print(f"\nAll done. Processed {len(moduli)} value(s) of p.")
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
if __name__ == "__main__":
|
| 358 |
+
main()
|
precompute/generate_plots.py
ADDED
|
@@ -0,0 +1,2192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Main plot generation script for the HF app.
|
| 4 |
+
Creates all model-dependent plots (Tabs 1-7) from trained checkpoints.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python generate_plots.py --all # Generate for all primes
|
| 8 |
+
python generate_plots.py --p 23 # Generate for a specific p
|
| 9 |
+
python generate_plots.py --p 23 --input ./trained_models --output ./hf_app/precomputed_results
|
| 10 |
+
"""
|
| 11 |
+
import matplotlib
|
| 12 |
+
matplotlib.use('Agg')
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import math
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import traceback
|
| 20 |
+
|
| 21 |
+
import matplotlib.colors as mcolors
|
| 22 |
+
import matplotlib.patches as patches
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 28 |
+
from matplotlib.ticker import FuncFormatter
|
| 29 |
+
|
| 30 |
+
# Add project root to path so we can import src modules
|
| 31 |
+
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 32 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 33 |
+
sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src'))
|
| 34 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 35 |
+
|
| 36 |
+
from src.mechanism_base import (
|
| 37 |
+
get_fourier_basis,
|
| 38 |
+
decode_weights,
|
| 39 |
+
compute_neuron,
|
| 40 |
+
decode_scales_phis,
|
| 41 |
+
normalize_to_pi,
|
| 42 |
+
)
|
| 43 |
+
from src.model_base import EmbedMLP
|
| 44 |
+
from src.utils import cross_entropy_high_precision, acc_rate
|
| 45 |
+
from precompute.neuron_selector import (
|
| 46 |
+
select_top_neurons_by_frequency,
|
| 47 |
+
select_lineplot_neurons,
|
| 48 |
+
select_phase_frequency,
|
| 49 |
+
select_lottery_neuron,
|
| 50 |
+
)
|
| 51 |
+
from precompute.grokking_stage_detector import detect_grokking_stages
|
| 52 |
+
from precompute.prime_config import compute_d_mlp, TRAINING_RUNS, MIN_P_GROKKING
|
| 53 |
+
|
| 54 |
+
# ---------- Lightweight train/test data regeneration ----------
|
| 55 |
+
|
| 56 |
+
def _gen_train_test(p, frac_train=0.75, seed=42):
    """
    Regenerate the train/test split deterministically without a Config object.

    All ``p * p`` ordered pairs ``(i, j)`` are enumerated, shuffled, and cut at
    ``int(frac_train * p * p)``.

    The shuffle draws its permutation from a private ``torch.Generator`` seeded
    with ``seed``. Previously only Python's ``random`` module was seeded, which
    has no effect on ``torch.randperm``, so two calls with the same seed could
    return different splits.

    NOTE(review): this is meant to mirror ``utils.gen_train_test`` for the
    'add' function; that function is not visible from here, so confirm the
    permutation matches the split used at training time.

    Parameters
    ----------
    p : int
        Modulus; pairs range over ``0 .. p-1`` in both coordinates.
    frac_train : float
        Fraction of pairs placed in the training split. Values ``>= 1.0``
        return the full shuffled set as both train and test.
    seed : int
        Seed for the shuffle.

    Returns
    -------
    (train_data, test_data) : tuple of torch.LongTensor
        Each of shape ``(N, 2)``.
    """
    import random as _random

    all_pairs = [(i, j) for i in range(p) for j in range(p)]
    # reshape keeps the (N, 2) contract even when there are no pairs at all.
    data_tensor = torch.tensor(all_pairs, dtype=torch.long).reshape(-1, 2)

    # Kept for backward compatibility: the original seeded Python's global RNG
    # here and later code may rely on that side effect. It does not influence
    # the permutation below.
    _random.seed(seed)

    # A local generator makes the permutation depend on `seed` alone and
    # leaves torch's global RNG state untouched.
    generator = torch.Generator()
    generator.manual_seed(seed)
    indices = torch.randperm(len(all_pairs), generator=generator)
    data_tensor = data_tensor[indices]

    if frac_train >= 1.0:
        return data_tensor, data_tensor
    div = int(frac_train * len(all_pairs))
    return data_tensor[:div], data_tensor[div:]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ---------- Style constants ----------
|
| 78 |
+
# Shared five-colour palette; plotting code indexes into it by position.
COLORS = ['#0D2758', '#60656F', '#DEA54B', '#A32015', '#347186']
# Diverging colormap (dark blue -> white -> red); the heatmaps pair it with
# symmetric vmin/vmax so that white marks zero.
CMAP_DIVERGING = LinearSegmentedColormap.from_list(
    'cividis_white_center', ['#0D2758', 'white', '#A32015'], N=256
)
# Sequential colormap (white -> dark blue).
CMAP_SEQUENTIAL = LinearSegmentedColormap.from_list(
    'cividis_white_seq', ['white', '#0D2758'], N=256
)
# Resolution passed to savefig by _save_fig for every PNG.
DPI = 150
# Use the 'cm' mathtext font set for all math labels.
plt.rcParams['mathtext.fontset'] = 'cm'
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _save_fig(fig, path):
    """Write ``fig`` to ``path`` as a tightly cropped PNG at ``DPI``, then close it."""
    save_options = {'dpi': DPI, 'bbox_inches': 'tight', 'format': 'png'}
    fig.savefig(path, **save_options)
    # Closing releases the figure so repeated plot generation does not pile up
    # open figures in pyplot.
    plt.close(fig)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ======================================================================
|
| 96 |
+
# Helpers for loading checkpoints
|
| 97 |
+
# ======================================================================
|
| 98 |
+
|
| 99 |
+
def _find_run_dir(base_dir):
    """
    Locate the directory that actually holds a run's checkpoint files.

    Given a run type directory (e.g. trained_models/p_023/standard/), the
    checkpoints may live directly in ``base_dir`` or inside a (typically
    timestamped) subdirectory. ``base_dir`` itself is preferred; otherwise the
    first subdirectory, in sorted name order, that contains checkpoints wins.

    A directory counts as holding checkpoints when it has at least one
    ``.pth`` file other than the dataset dumps ``train_data.pth`` and
    ``test_data.pth``. The same rule is applied at both levels; previously the
    subdirectory scan accepted the dataset dumps, so a data-only subdirectory
    could be returned ahead of one with real checkpoints.

    Returns the path that contains the .pth checkpoint files, or None when
    ``base_dir`` is not a directory or no checkpoints are found.
    """
    data_files = ('train_data.pth', 'test_data.pth')

    def _has_checkpoints(directory):
        """Return True if `directory` holds a .pth file that is not a dataset dump."""
        return any(
            name.endswith('.pth') and name not in data_files
            for name in os.listdir(directory)
        )

    if not os.path.isdir(base_dir):
        return None

    # Checkpoints stored directly in the run type directory take priority.
    if _has_checkpoints(base_dir):
        return base_dir

    # Otherwise scan subdirectories in sorted order for the first real run.
    for sd in sorted(os.listdir(base_dir)):
        candidate = os.path.join(base_dir, sd)
        if os.path.isdir(candidate) and _has_checkpoints(candidate):
            return candidate
    return None
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _load_checkpoints(run_dir, device='cpu'):
    """
    Load every epoch-numbered checkpoint found in ``run_dir``.

    Only regular files named ``<integer>.pth`` are considered; ``final.pth``
    and the dataset dumps are skipped, as is any ``.pth`` file whose stem is
    not an integer. When a loaded object is a dict with a ``'model'`` key, the
    value under that key is kept; otherwise the loaded object is kept as is.

    Returns a dict {epoch_int: state_dict} whose keys are in ascending order.
    """
    skipped = {'final.pth', 'test_data.pth', 'train_data.pth'}
    by_epoch = {}
    for entry in os.listdir(run_dir):
        full_path = os.path.join(run_dir, entry)
        if not os.path.isfile(full_path):
            continue
        if not entry.endswith('.pth') or entry in skipped:
            continue
        stem, _ext = os.path.splitext(entry)
        try:
            epoch = int(stem)
        except ValueError:
            # Not an epoch-numbered checkpoint.
            continue
        payload = torch.load(full_path, weights_only=True, map_location=device)
        is_wrapped = isinstance(payload, dict) and 'model' in payload
        by_epoch[epoch] = payload['model'] if is_wrapped else payload
    # Keys are unique, so sorting the items orders purely by epoch.
    return dict(sorted(by_epoch.items()))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _load_final(run_dir, device='cpu'):
    """
    Load the final model data for a run.

    Returns the contents of ``final.pth`` when that file exists. Otherwise
    falls back to the highest-epoch numbered checkpoint, wrapped as
    ``{'model': state_dict}``. Returns None when neither is available.
    """
    final_path = os.path.join(run_dir, 'final.pth')
    if os.path.exists(final_path):
        return torch.load(final_path, weights_only=True, map_location=device)

    # No final.pth: use the latest numbered checkpoint instead.
    checkpoints = _load_checkpoints(run_dir, device)
    if not checkpoints:
        return None
    latest_epoch = max(checkpoints.keys())
    return {'model': checkpoints[latest_epoch]}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _load_training_curves(run_type_dir):
    """
    Load the training curves for a run, trying three sources in order.

    1. ``training_curves.json`` directly inside ``run_type_dir``.
    2. ``training_curves.json`` inside the checkpoint subdirectory found by
       ``_find_run_dir`` (when that differs from ``run_type_dir``).
    3. The curve entries stored in ``final.pth`` in the checkpoint directory;
       tensor values are converted to plain lists.

    Returns the curves dict, or None when no source yields any curves.
    """
    curves_name = 'training_curves.json'

    top_level = os.path.join(run_type_dir, curves_name)
    if os.path.exists(top_level):
        with open(top_level) as f:
            return json.load(f)

    run_dir = _find_run_dir(run_type_dir)
    if not run_dir:
        return None

    if run_dir != run_type_dir:
        nested = os.path.join(run_dir, curves_name)
        if os.path.exists(nested):
            with open(nested) as f:
                return json.load(f)

    # Last resort: pull whatever curve series final.pth carries.
    final_path = os.path.join(run_dir, 'final.pth')
    if not os.path.exists(final_path):
        return None
    data = torch.load(final_path, weights_only=True, map_location='cpu')
    if not isinstance(data, dict):
        return None

    wanted = ('train_losses', 'test_losses', 'train_accs', 'test_accs',
              'grad_norms', 'param_norms')
    curves = {}
    for key in wanted:
        if key not in data:
            continue
        series = data[key]
        curves[key] = series.cpu().tolist() if isinstance(series, torch.Tensor) else series
    return curves if curves else None
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ======================================================================
|
| 195 |
+
# PlotGenerator
|
| 196 |
+
# ======================================================================
|
| 197 |
+
|
| 198 |
+
class PlotGenerator:
|
| 199 |
+
"""
|
| 200 |
+
Generates all model-dependent plots for a single prime p.
|
| 201 |
+
|
| 202 |
+
Parameters
|
| 203 |
+
----------
|
| 204 |
+
p : int
|
| 205 |
+
The prime modulus.
|
| 206 |
+
input_dir : str
|
| 207 |
+
Path to trained_models/p_PPP/ containing run-type subdirectories.
|
| 208 |
+
output_dir : str
|
| 209 |
+
Path to hf_app/precomputed_results/p_PPP/ where plots are saved.
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
def __init__(self, p, input_dir, output_dir):
    """Store paths and sizes, create ``output_dir``, and precompute shared inputs.

    ``d_vocab`` and ``d_model`` are both set to ``p``. ``d_mlp`` is taken from
    a checkpoint when one can be found and from ``compute_d_mlp(p)`` otherwise.
    The Fourier basis and the full grid of ``(a, b)`` inputs with their
    ``(a + b) % p`` labels are built once here for the tab generators.
    """
    self.p = p
    self.input_dir = input_dir
    self.output_dir = output_dir
    self.device = 'cpu'
    self.d_vocab = p
    self.d_model = p

    os.makedirs(output_dir, exist_ok=True)

    # Prefer the width recorded in the checkpoint weights; fall back to the formula.
    inferred_width = self._infer_d_mlp()
    self.d_mlp = inferred_width if inferred_width else compute_d_mlp(p)

    # Fourier basis (mechanism_base version with device arg)
    basis, basis_names = get_fourier_basis(p, self.device)
    self.fourier_basis = basis
    self.fourier_basis_names = basis_names

    # Every ordered (a, b) pair, in row-major order, and its modular sum.
    pairs = [(a, b) for a in range(p) for b in range(p)]
    self.all_data = torch.tensor(pairs, dtype=torch.long)
    self.all_labels = torch.tensor(
        [(a + b) % p for a, b in pairs], dtype=torch.long
    )
|
| 235 |
+
|
| 236 |
+
def _infer_d_mlp(self):
    """Return d_mlp read from the first usable checkpoint, or None if there is none.

    Runs are tried in ``TRAINING_RUNS`` iteration order. A run is usable when
    its final model data contains an ``'mlp.W_in'`` weight; d_mlp is that
    weight's first dimension.
    """
    for run_name in TRAINING_RUNS:
        run_dir = _find_run_dir(os.path.join(self.input_dir, run_name))
        if run_dir is None:
            continue
        final = _load_final(run_dir, 'cpu')
        if not final or 'model' not in final:
            continue
        weights = final['model']
        if 'mlp.W_in' not in weights:
            continue
        d_mlp = weights['mlp.W_in'].shape[0]
        print(f" Inferred d_mlp={d_mlp} from {run_name} checkpoint")
        return d_mlp
    return None
|
| 249 |
+
|
| 250 |
+
# ------------------------------------------------------------------
|
| 251 |
+
# Path helpers
|
| 252 |
+
# ------------------------------------------------------------------
|
| 253 |
+
|
| 254 |
+
def _run_type_dir(self, run_name):
    """Return the directory under ``input_dir`` for the run type ``run_name``."""
    run_type_path = os.path.join(self.input_dir, run_name)
    return run_type_path
|
| 256 |
+
|
| 257 |
+
def _run_dir(self, run_name):
    """Return the checkpoint directory for ``run_name``, or None if none is found."""
    run_type_path = self._run_type_dir(run_name)
    return _find_run_dir(run_type_path)
|
| 259 |
+
|
| 260 |
+
def _out(self, filename):
    """Return the output path for ``filename``, prefixed with the zero-padded prime.

    Every file is named ``pXXX_<filename>`` so each output folder is
    self-contained and browsable.
    """
    prefixed_name = "p{:03d}_{}".format(self.p, filename)
    return os.path.join(self.output_dir, prefixed_name)
|
| 263 |
+
|
| 264 |
+
# ------------------------------------------------------------------
|
| 265 |
+
# ------------------------------------------------------------------
|
| 266 |
+
# Shared IPR helper
|
| 267 |
+
# ------------------------------------------------------------------
|
| 268 |
+
|
| 269 |
+
def _compute_freq_ipr(self, W_dec):
    """Mean inverse participation ratio over per-frequency energies.

    For frequency k (1 <= k <= (p - 1) // 2) the cosine coefficient is read
    from column 2k - 1 of ``W_dec`` and the sine coefficient from column 2k;
    column 0 is not used. With A_k^2 = c_k^2 + s_k^2,

        IPR = sum_k A_k^4 / (sum_k A_k^2)^2

    per row. IPR -> 1 means all energy sits at a single frequency. Rows with
    zero total energy contribute an IPR of 0.

    Returns the mean IPR across rows as a 0-dim tensor.
    """
    n_rows = W_dec.shape[0]
    n_freqs = (self.p - 1) // 2

    # Strided views pick out columns 1, 3, ..., 2K-1 and 2, 4, ..., 2K.
    cos_coeffs = W_dec[:, 1:2 * n_freqs:2]
    sin_coeffs = W_dec[:, 2:2 * n_freqs + 1:2]

    energy = torch.zeros(n_rows, n_freqs)
    energy.copy_(cos_coeffs.pow(2) + sin_coeffs.pow(2))

    total_sq = energy.sum(dim=1).pow(2)
    has_energy = total_sq > 0

    # Rows without energy keep IPR 0 instead of dividing by zero.
    ipr = torch.zeros(n_rows)
    ipr[has_energy] = energy[has_energy].pow(2).sum(dim=1) / total_sq[has_energy]
    return ipr.mean()
|
| 286 |
+
|
| 287 |
+
def _ipr_at_checkpoint(self, model_sd):
    """Return the IPR of one checkpoint, averaged over both layers, as a float.

    The checkpoint is decoded with ``decode_weights`` against the stored
    Fourier basis; the frequency IPR of the decoded first-layer and
    second-layer weights are then averaged.
    """
    W_in_dec, W_out_dec, _ = decode_weights(model_sd, self.fourier_basis)
    ipr_in = self._compute_freq_ipr(W_in_dec)
    ipr_out = self._compute_freq_ipr(W_out_dec)
    return ((ipr_in + ipr_out) / 2).item()
|
| 292 |
+
|
| 293 |
+
# ------------------------------------------------------------------
|
| 294 |
+
# Tab 1: Overview (standard loss+IPR, grokking loss+IPR, phase plot)
|
| 295 |
+
# ------------------------------------------------------------------
|
| 296 |
+
|
| 297 |
+
def generate_tab1(self):
    """Generate overview plots: standard + grokking loss/IPR, plus phase scatter.

    Writes, via ``self._out``:

    * ``overview_loss_ipr.png`` - loss (top row) and mean IPR (bottom row)
      per checkpoint. One column for the standard run and, when grokking
      checkpoints exist, a second column for the grokking run.
    * ``overview_phase_scatter.png`` - 2*phi against psi for every neuron of
      the highest-epoch standard checkpoint. Written only when standard
      checkpoints exist.
    * ``overview.json`` - the same series, for the interactive charts.

    The grokking run is only considered when ``self.p >= MIN_P_GROKKING``.
    Returns early, writing nothing, when neither run has checkpoints.
    """
    print(f" [Tab 1] Overview for p={self.p}")

    # ---- Standard run: loss + IPR ----
    std_dir = self._run_dir('standard')
    std_epochs, std_loss, std_ipr = [], [], []
    # Bound up front: it is read again for the phase scatter and the file
    # list below. Without this, a missing standard run combined with an
    # existing grokking run raised UnboundLocalError.
    std_ckpts = {}
    if std_dir is not None:
        std_curves = _load_training_curves(self._run_type_dir('standard'))
        std_ckpts = _load_checkpoints(std_dir, self.device)
        if std_ckpts:
            std_epochs = sorted(std_ckpts.keys())
            std_ipr = [self._ipr_at_checkpoint(std_ckpts[ep]) for ep in std_epochs]
            if std_curves and 'train_losses' in std_curves:
                # Subsample the loss curve with the checkpoint stride so it
                # pairs up with the checkpoint epochs.
                # NOTE(review): assumes one loss entry per step starting at
                # step 0 - confirm against the trainer.
                se = std_epochs[1] - std_epochs[0] if len(std_epochs) > 1 else 200
                std_loss = std_curves['train_losses'][::se][:len(std_epochs)]

    # ---- Grokking run: train/test loss + IPR ----
    grokk_epochs, grokk_train_loss, grokk_test_loss, grokk_ipr = [], [], [], []
    has_grokk = self.p >= MIN_P_GROKKING
    if has_grokk:
        grokk_dir = self._run_dir('grokking')
        if grokk_dir is not None:
            grokk_curves = _load_training_curves(self._run_type_dir('grokking'))
            grokk_ckpts = _load_checkpoints(grokk_dir, self.device)
            if grokk_ckpts:
                grokk_epochs = sorted(grokk_ckpts.keys())
                grokk_ipr = [self._ipr_at_checkpoint(grokk_ckpts[ep])
                             for ep in grokk_epochs]
                if grokk_curves:
                    se = grokk_epochs[1] - grokk_epochs[0] if len(grokk_epochs) > 1 else 200
                    if 'train_losses' in grokk_curves:
                        grokk_train_loss = grokk_curves['train_losses'][::se][:len(grokk_epochs)]
                    if 'test_losses' in grokk_curves:
                        grokk_test_loss = grokk_curves['test_losses'][::se][:len(grokk_epochs)]

    if not std_epochs and not grokk_epochs:
        print(" SKIP: no checkpoints found for standard or grokking run")
        return

    # ---- Static plot: 2×2 grid (std loss, grokk loss, std IPR, grokk IPR) ----
    n_cols = 2 if has_grokk and grokk_epochs else 1
    fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 7),
                             constrained_layout=True)
    if n_cols == 1:
        # plt.subplots returns a 1-D array for a single column; restore 2-D
        # indexing so axes[row, col] works in both layouts.
        axes = axes.reshape(2, 1)

    # Standard loss (top-left)
    ax = axes[0, 0]
    if std_loss:
        ax.plot(std_epochs[:len(std_loss)], std_loss,
                color=COLORS[0], linewidth=1.5, label="Train Loss")
    ax.set_title('Standard (ReLU, full data)', fontsize=14)
    ax.set_ylabel('Loss', fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.4)

    # Standard IPR (bottom-left)
    ax = axes[1, 0]
    if std_ipr:
        ax.plot(std_epochs[:len(std_ipr)], std_ipr,
                color=COLORS[3], linewidth=1.5, label="Avg. IPR")
    ax.set_title('Standard IPR', fontsize=14)
    ax.set_xlabel('Step', fontsize=13)
    ax.set_ylabel('IPR', fontsize=13)
    ax.set_ylim([0, 1.05])
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.4)

    if n_cols == 2:
        # Grokking loss (top-right)
        ax = axes[0, 1]
        gx = grokk_epochs
        if grokk_train_loss:
            ax.plot(gx[:len(grokk_train_loss)], grokk_train_loss,
                    color=COLORS[0], linewidth=1.5, label="Train Loss")
        if grokk_test_loss:
            ax.plot(gx[:len(grokk_test_loss)], grokk_test_loss,
                    color=COLORS[3], linewidth=1.5, label="Test Loss")
        ax.set_title('Grokking (ReLU, 75% data, WD)', fontsize=14)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.4)

        # Grokking IPR (bottom-right)
        ax = axes[1, 1]
        if grokk_ipr:
            ax.plot(gx[:len(grokk_ipr)], grokk_ipr,
                    color=COLORS[3], linewidth=1.5, label="Avg. IPR")
        ax.set_title('Grokking IPR', fontsize=14)
        ax.set_xlabel('Step', fontsize=13)
        ax.set_ylim([0, 1.05])
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.4)

    _save_fig(fig, self._out('overview_loss_ipr.png'))

    # ---- Phase relationship scatter from standard final checkpoint ----
    if std_ckpts:
        final_ep = max(std_ckpts.keys())
        model_sd = std_ckpts[final_ep]
        W_in_d, W_out_d, mfl = decode_weights(model_sd, self.fourier_basis)
        n_neurons = W_in_d.shape[0]
        phis_2, psis = [], []
        for neuron in range(n_neurons):
            _, phi = compute_neuron(neuron, mfl, W_in_d)
            _, psi = compute_neuron(neuron, mfl, W_out_d)
            two_phi = normalize_to_pi(2 * phi)
            psi_n = normalize_to_pi(psi)
            # Fix ±π wrap: keep ψ within π of 2φ
            if psi_n - two_phi > np.pi:
                psi_n -= 2 * np.pi
            elif psi_n - two_phi < -np.pi:
                psi_n += 2 * np.pi
            phis_2.append(two_phi)
            psis.append(psi_n)

        fig, ax = plt.subplots(figsize=(5, 5))
        ax.plot([-np.pi, np.pi], [-np.pi, np.pi], 'r-',
                linewidth=3, alpha=0.8,
                label=r'$\psi_m = 2\phi_m$', zorder=1)
        ax.scatter(phis_2, psis, s=12, alpha=0.6, color=COLORS[0], zorder=2)
        ax.legend(fontsize=12, loc='upper left')
        ax.set_xlabel(r'$2\phi_m$', fontsize=14)
        ax.set_ylabel(r'$\psi_m$', fontsize=14)
        ax.set_title(r'Phase Alignment: $\psi_m = 2\phi_m$', fontsize=14)
        ax.set_xlim([-np.pi, np.pi])
        ax.set_ylim([-np.pi, np.pi])
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)
        _save_fig(fig, self._out('overview_phase_scatter.png'))

    # ---- JSON for interactive Plotly charts ----
    payload = {
        'std_epochs': [int(e) for e in std_epochs],
        'std_ipr': std_ipr,
    }
    if std_loss:
        payload['std_train_loss'] = [float(v) for v in std_loss]

    if has_grokk and grokk_epochs:
        payload['grokk_epochs'] = [int(e) for e in grokk_epochs]
        payload['grokk_ipr'] = grokk_ipr
        if grokk_train_loss:
            payload['grokk_train_loss'] = [float(v) for v in grokk_train_loss]
        if grokk_test_loss:
            payload['grokk_test_loss'] = [float(v) for v in grokk_test_loss]

    with open(self._out('overview.json'), 'w') as f:
        json.dump(payload, f)

    files = ['overview_loss_ipr.png', 'overview.json']
    if std_ckpts:
        files.append('overview_phase_scatter.png')
    print(f" Saved {', '.join(files)}")
|
| 451 |
+
|
| 452 |
+
# ------------------------------------------------------------------
|
| 453 |
+
# Tab 2: Fourier Weights (heatmap + lineplots)
|
| 454 |
+
# ------------------------------------------------------------------
|
| 455 |
+
|
| 456 |
+
def generate_tab2(self) -> None:
    """Generate full_training_para_origin.png, lineplot_in.png, lineplot_out.png.

    All three figures come from the final model data of the 'standard' run:

    * ``full_training_para_origin.png`` - two stacked heatmaps of the decoded
      (Fourier-space) first- and second-layer weights for up to 20 selected
      neurons, ordered by their dominant frequency.
    * ``lineplot_in.png`` / ``lineplot_out.png`` - raw weights of three of
      those neurons, each overlaid with a fit rebuilt from its two largest
      Fourier components.

    Prints a SKIP message and returns without writing anything when the
    standard run directory or its final checkpoint is missing.
    """
    print(f" [Tab 2] Fourier Weights for p={self.p}")
    run_dir = self._run_dir('standard')
    if run_dir is None:
        print(" SKIP: standard run directory not found")
        return

    final_data = _load_final(run_dir, self.device)
    if final_data is None:
        print(" SKIP: no final checkpoint")
        return
    model_load = final_data['model']

    W_in_decode, W_out_decode, max_freq_ls = decode_weights(
        model_load, self.fourier_basis
    )
    d_mlp = W_in_decode.shape[0]
    # Cap the number of displayed neurons at 20.
    num_neurons = min(20, d_mlp)

    # Sort neurons by frequency
    # NOTE(review): the tick positions below assume this returns exactly
    # num_neurons indices - confirm against neuron_selector.
    sorted_indices = select_top_neurons_by_frequency(
        max_freq_ls, W_in_decode, n=num_neurons
    )

    freq_ls = np.array([max_freq_ls[i] for i in sorted_indices])

    # DFT coefficients for heatmap (matches blog Figure 2)
    W_in_dft = W_in_decode[sorted_indices, :]
    W_out_dft = W_out_decode[sorted_indices, :]
    # Raw weights for line plots (matches blog Figure 3)
    W_in_raw = model_load['mlp.W_in'][sorted_indices, :]
    # W_out is transposed so rows are indexed by neuron, as for W_in.
    W_out_raw = model_load['mlp.W_out'].T[sorted_indices, :]

    # Sort within selected set by frequency
    sort_order = np.argsort(freq_ls)
    ranked_W_in_dft = W_in_dft[sort_order, :]
    ranked_W_out_dft = W_out_dft[sort_order, :]
    ranked_W_in_raw = W_in_raw[sort_order, :]
    ranked_W_out_raw = W_out_raw[sort_order, :]

    # ---- Heatmap plot (DFT coefficients, matching blog Figure 2) ----
    fb_names = self.fourier_basis_names
    n_modes = len(fb_names)
    # Figure size grows with the number of Fourier modes and neurons so the
    # tick labels stay legible, with a floor of 8 inches each way.
    fig_w = max(8, n_modes * 0.4)
    fig_h = max(8, num_neurons * 0.35 + 3)
    fig, axes = plt.subplots(
        2, 1, figsize=(fig_w, fig_h), constrained_layout=True,
        gridspec_kw={"hspace": 0.15}
    )

    # W_in DFT
    ax_in = axes[0]
    W_in_np = ranked_W_in_dft.detach().cpu().numpy()
    # Symmetric colour limits keep zero at the white centre of the colormap.
    abs_max_in = np.abs(W_in_np).max()
    im_in = ax_in.imshow(
        W_in_np,
        cmap=CMAP_DIVERGING, vmin=-abs_max_in, vmax=abs_max_in,
        aspect='auto'
    )
    ax_in.set_title(r'First-Layer $\theta_m$ after DFT', fontsize=18)
    fig.colorbar(im_in, ax=ax_in, shrink=0.8)
    # Rows are labelled by display position (0..num_neurons-1), not by the
    # neuron's index in the model.
    y_locs = np.arange(num_neurons)
    ax_in.set_yticks(y_locs)
    ax_in.set_yticklabels(y_locs, fontsize=10)
    ax_in.set_ylabel('Neuron #', fontsize=14)
    x_locs = np.arange(n_modes)
    ax_in.set_xticks(x_locs)
    ax_in.set_xticklabels(fb_names, rotation=90, fontsize=10)

    # W_out DFT
    ax_out = axes[1]
    W_out_np = ranked_W_out_dft.detach().cpu().numpy()
    abs_max_out = np.abs(W_out_np).max()
    im_out = ax_out.imshow(
        W_out_np,
        cmap=CMAP_DIVERGING, vmin=-abs_max_out, vmax=abs_max_out,
        aspect='auto'
    )
    ax_out.set_title(r'Second-Layer $\xi_m$ after DFT', fontsize=18)
    fig.colorbar(im_out, ax=ax_out, shrink=0.8)
    ax_out.set_yticks(y_locs)
    ax_out.set_yticklabels(y_locs, fontsize=10)
    ax_out.set_ylabel('Neuron #', fontsize=14)
    ax_out.set_xticks(x_locs)
    ax_out.set_xticklabels(fb_names, rotation=90, fontsize=10)
    ax_out.set_xlabel('Fourier Component', fontsize=14)

    _save_fig(fig, self._out('full_training_para_origin.png'))

    # ---- Line plots (raw weights + cosine fits, matching blog Figure 3) ----
    # Indices here are positions within the ranked (displayed) neurons.
    # NOTE(review): the loop below indexes three rows unconditionally, so it
    # presumably needs at least three selected neurons - confirm.
    lineplot_idx = select_lineplot_neurons(list(range(num_neurons)), n=3)
    fb = self.fourier_basis
    positions = np.arange(ranked_W_in_raw.shape[1])

    for tag, weight_data, title_tex in [
        ('lineplot_in', ranked_W_in_raw, r'First-Layer Parameters $\theta_m$'),
        ('lineplot_out', ranked_W_out_raw, r'Second-Layer Parameters $\xi_m$'),
    ]:
        # Despite the name, weight_np stays a tensor when weight_data is one;
        # it is only detached and moved to the CPU.
        if hasattr(weight_data, 'detach'):
            weight_np = weight_data.detach().cpu()
        else:
            weight_np = weight_data

        top3 = weight_np[lineplot_idx]

        lp_w = max(8, self.p * 0.35)
        fig, axes_lp = plt.subplots(
            nrows=3, ncols=1, figsize=(lp_w, 8),
            constrained_layout=True,
            gridspec_kw={'hspace': 0.08}
        )

        for i, ax in enumerate(axes_lp):
            data = top3[i]
            if isinstance(data, torch.Tensor):
                data_t = data.float()
            else:
                data_t = torch.tensor(data, dtype=torch.float32)
            # Project into Fourier space, keep top 2 components, project back
            proj = data_t @ fb.T
            abs_proj = torch.abs(proj)
            _, top2_idx = torch.topk(abs_proj, 2)
            mask = torch.zeros_like(proj)
            mask[top2_idx] = proj[top2_idx]
            data_est = mask @ fb
            data_np = data_t.numpy()
            data_est_np = data_est.numpy()

            ax.plot(data_np, marker='o', markersize=5,
                    color=COLORS[0], linewidth=1.5, linestyle=':',
                    label="Actual")
            ax.plot(data_est_np, marker='o', markersize=5,
                    color=COLORS[3], linewidth=1.5, linestyle=':',
                    alpha=0.7, label="Fitted")
            # Fixed y-range shared by all panels; values outside it are clipped
            # from view.
            ax.set_ylim(-0.9, 0.9)
            # Panels are labelled 1..3 by panel position.
            ax.set_ylabel(f'Neuron #{i+1}', fontsize=14)
            ax.set_xticks(positions)
            ax.grid(True, which='major', axis='both',
                    linestyle='--', linewidth=0.5, alpha=0.6)
            # Only the bottom panel keeps its x tick labels.
            if i < len(axes_lp) - 1:
                ax.set_xticklabels([])
            ax.legend(fontsize=12, loc="upper right")

        axes_lp[-1].set_xlabel('Input Dimension', fontsize=14)
        axes_lp[-1].set_xticks(positions)
        axes_lp[-1].set_xticklabels(
            np.arange(ranked_W_in_raw.shape[1]), rotation=0, fontsize=10
        )
        axes_lp[0].set_title(title_tex, fontsize=18)

        _save_fig(fig, self._out(f'{tag}.png'))

    print(" Saved full_training_para_origin.png, lineplot_in.png, lineplot_out.png")
|
| 610 |
+
|
| 611 |
+
# ------------------------------------------------------------------
|
| 612 |
+
# Tab 3: Phase Analysis
|
| 613 |
+
# ------------------------------------------------------------------
|
| 614 |
+
|
| 615 |
+
def generate_tab3(self) -> None:
    """Generate phase_distribution.png, phase_relationship.png, magnitude_distribution.png.

    All three figures come from the final model data of the 'standard' run:

    * ``phase_distribution.png`` - for the neurons at one selected frequency,
      the first-layer phase phi and its multiples 2*phi, 3*phi, 4*phi drawn
      as points on four concentric circles.
    * ``phase_relationship.png`` - scatter of 2*phi (first layer) against psi
      (second layer) for every neuron, with the line psi = 2*phi.
    * ``magnitude_distribution.png`` - violin plot of the per-neuron
      first-layer and second-layer magnitudes.

    Prints a SKIP message and returns without writing anything when the
    standard run directory or its final checkpoint is missing.
    """
    print(f" [Tab 3] Phase Analysis for p={self.p}")
    run_dir = self._run_dir('standard')
    if run_dir is None:
        print(" SKIP: standard run directory not found")
        return

    final_data = _load_final(run_dir, self.device)
    if final_data is None:
        print(" SKIP: no final checkpoint")
        return
    model_load = final_data['model']

    W_in_decode, W_out_decode, max_freq_ls = decode_weights(
        model_load, self.fourier_basis
    )
    d_mlp = W_in_decode.shape[0]

    # Compute all neuron phases and magnitudes
    coeff_in_scale_ls = []
    coeff_out_scale_ls = []
    coeff_phi_ls = []
    coeff_psi_ls = []

    for neuron in range(d_mlp):
        # compute_neuron returns a (scale, phase) pair for the given layer.
        s_in, phi_in = compute_neuron(neuron, max_freq_ls, W_in_decode)
        s_out, phi_out = compute_neuron(neuron, max_freq_ls, W_out_decode)
        coeff_in_scale_ls.append(s_in)
        coeff_out_scale_ls.append(s_out)
        coeff_phi_ls.append(phi_in)
        coeff_psi_ls.append(phi_out)

    coeff_phi_arr = np.array(coeff_phi_ls)
    coeff_psi_arr = np.array(coeff_psi_ls)

    # ---- Phase distribution on concentric circles ----
    # Select the most common non-zero frequency for phase analysis
    target_freq = select_phase_frequency(max_freq_ls, self.p)
    freq_neurons = [i for i, f in enumerate(max_freq_ls) if f == target_freq]
    phi_subset = np.array([coeff_phi_ls[i] for i in freq_neurons])

    theta = np.linspace(0, 2 * np.pi, 300)
    # Each multiple m of the phase gets its own circle, from the outermost
    # (m = 1) inwards.
    multipliers = [1, 2, 3, 4]
    radii = [1.0, 0.88, 0.76, 0.64]

    fig, ax = plt.subplots(figsize=(4, 4))
    for m, r in zip(multipliers, radii):
        # Guide circle of radius r.
        x_c, y_c = r * np.cos(theta), r * np.sin(theta)
        ax.plot(x_c, y_c, linewidth=0.8, color='gray', alpha=0.6)

        # Points at angle m * phi on that circle.
        x_pts = r * np.cos(m * phi_subset)
        y_pts = r * np.sin(m * phi_subset)
        label = fr'$\phi_m$' if m == 1 else fr'${m}\phi_m$'
        ax.scatter(x_pts, y_pts, s=20, marker='o',
                   color=COLORS[m - 1], label=label)

    ax.legend(
        fontsize=15, loc='upper center', columnspacing=0.2,
        handletextpad=0.1, bbox_to_anchor=(0.5, 1.15), ncol=4, frameon=False
    )
    ax.set_xlabel(r'$\cos(\phi_m)$', fontsize=19)
    ax.set_ylabel(r'$\sin(\phi_m)$', fontsize=19)
    # Hide ticks and the frame so only the circles and points remain.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    _save_fig(fig, self._out('phase_distribution.png'))

    # ---- Phase relationship: 2*phi vs psi ----
    coeff_2phi_arr = np.array([normalize_to_pi(2 * phi) for phi in coeff_phi_arr])
    # NOTE(review): unlike the overview scatter in generate_tab1, psi is not
    # passed through normalize_to_pi here - confirm compute_neuron already
    # returns phases in [-pi, pi].
    coeff_psi_plot = coeff_psi_arr.copy()
    # Fix ±π wrap: keep ψ within π of 2φ so boundary points stay on diagonal
    diff = coeff_psi_plot - coeff_2phi_arr
    coeff_psi_plot[diff > np.pi] -= 2 * np.pi
    coeff_psi_plot[diff < -np.pi] += 2 * np.pi

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot([-np.pi, np.pi], [-np.pi, np.pi], 'r-', linewidth=3, alpha=0.8,
            label=r'$\psi_m = 2\phi_m$', zorder=1)
    ax.scatter(
        coeff_2phi_arr, coeff_psi_plot,
        marker='.', color=COLORS[0], s=20, zorder=2
    )
    ax.legend(fontsize=12, loc='upper left')
    ax.set_xlabel(r'$2\phi_m$', fontsize=14)
    ax.set_ylabel(r'$\psi_m$', fontsize=14)
    ax.set_title(r'Phase Alignment: $\psi_m = 2\phi_m$', fontsize=14)
    # 10% margin beyond ±π so unwrapped points near the boundary stay visible.
    ax.set_xlim(-np.pi * 1.1, np.pi * 1.1)
    ax.set_ylim(-np.pi * 1.1, np.pi * 1.1)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

    _save_fig(fig, self._out('phase_relationship.png'))

    # ---- Magnitude distribution (violin) ----
    fig, ax = plt.subplots(figsize=(4, 4))
    data_for_plot = [coeff_in_scale_ls, coeff_out_scale_ls]
    positions = [1, 2]

    parts = ax.violinplot(
        data_for_plot, positions=positions, widths=0.6,
        showmeans=True, showmedians=True, showextrema=True
    )
    # Restyle the violin bodies and the summary lines with the shared palette.
    for pc in parts['bodies']:
        pc.set_facecolor(COLORS[0])
        pc.set_alpha(0.7)
    parts['cmedians'].set_color(COLORS[2])
    parts['cmedians'].set_linewidth(2)
    parts['cmeans'].set_color(COLORS[2])
    parts['cmeans'].set_linewidth(2)
    parts['cbars'].set_color(COLORS[0])
    parts['cbars'].set_linewidth(1.5)
    parts['cmaxes'].set_color(COLORS[0])
    parts['cmins'].set_color(COLORS[0])

    ax.set_xticks(positions)
    ax.set_xticklabels(['First-Layer', 'Second-Layer'], fontsize=14)
    ax.set_ylabel('Magnitude', fontsize=19)
    ax.grid(True, alpha=0.3)
    # Acts on pyplot's current figure, which is the violin figure created above.
    plt.tight_layout()

    _save_fig(fig, self._out('magnitude_distribution.png'))

    print(" Saved phase_distribution.png, phase_relationship.png, magnitude_distribution.png")
|
| 741 |
+
|
| 742 |
+
# ------------------------------------------------------------------
|
| 743 |
+
# Tab 4: Output Logits
|
| 744 |
+
# ------------------------------------------------------------------
|
| 745 |
+
|
| 746 |
+
def generate_tab4(self):
    """Render the output-logit heatmap of the 'standard' run.

    The final checkpoint of the 'standard' run is loaded into a fresh
    ``EmbedMLP`` and evaluated on ``self.all_data``.  The logits of the
    first ``p`` input pairs are drawn as a heatmap (one column per pair,
    one row per output class) and, for every pair ``(x, y)``, the cells
    ``2x mod p``, ``2y mod p`` and ``(x + y) mod p`` are outlined.  The
    figure is written to ``output_logits.png`` via ``_save_fig``.

    Returns early, after printing a SKIP message, when the run directory
    or the final checkpoint cannot be found.
    """
    print(f" [Tab 4] Output Logits for p={self.p}")

    standard_dir = self._run_dir('standard')
    if standard_dir is None:
        print(" SKIP: standard run directory not found")
        return

    final_ckpt = _load_final(standard_dir, self.device)
    if final_ckpt is None:
        print(" SKIP: no final checkpoint")
        return
    state_dict = final_ckpt['model']

    p = self.p
    net = EmbedMLP(
        d_vocab=self.d_vocab,
        d_model=self.d_model,
        d_mlp=self.d_mlp,
        act_type=TRAINING_RUNS['standard']['act_type'],
        use_cache=False,
    )
    net.to(self.device)
    net.load_state_dict(state_dict)
    net.eval()

    with torch.no_grad():
        all_logits = net(self.all_data).squeeze(1)
    all_logits_np = all_logits.cpu().numpy()

    # Only the first p pairs (first row of the input grid) are displayed.
    first, last = 0, p
    pair_count = last - first
    shown_logits = all_logits_np[first:last]
    shown_pairs = self.all_data[first:last]

    fig, ax = plt.subplots(figsize=(7, 6))
    # The colour range is derived from *all* logits, not just the shown ones.
    color_limit = np.abs(all_logits_np).max() * 0.8
    heat = ax.imshow(
        shown_logits.T, cmap=CMAP_DIVERGING, aspect='auto',
        vmin=-color_limit, vmax=color_limit
    )

    # Outline the 2x, 2y and x+y output cells in each column.  The 2y box
    # is skipped when it coincides with 2x; the sum box is always drawn.
    box_style = dict(linewidth=1.6, edgecolor='#0D2758',
                     facecolor='none', alpha=0.9)
    for column, (x_t, y_t) in enumerate(shown_pairs):
        x_val, y_val = x_t.item(), y_t.item()
        doubled_x = (2 * x_val) % p
        doubled_y = (2 * y_val) % p
        outlined_rows = [doubled_x]
        if doubled_y != doubled_x:
            outlined_rows.append(doubled_y)
        outlined_rows.append((x_val + y_val) % p)
        for row in outlined_rows:
            ax.add_patch(patches.Rectangle(
                (column - 0.5, row - 0.5), 1, 1, **box_style
            ))

    # Label every pair when there are few of them, otherwise about 25.
    if pair_count <= 50:
        tick_positions = np.arange(pair_count)
    else:
        stride = pair_count // min(25, pair_count)
        tick_positions = np.arange(0, pair_count, stride)
    tick_labels = [
        f"({shown_pairs[int(k)][0].item()},{shown_pairs[int(k)][1].item()})"
        for k in tick_positions
    ]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=90, ha='right', fontsize=14)

    output_ticks = np.arange(p)
    ax.set_yticks(output_ticks)
    ax.set_yticklabels(output_ticks, fontsize=14)
    ax.set_xlabel("Input Pair", fontsize=18)
    ax.set_ylabel("Output", fontsize=18)
    plt.colorbar(heat, ax=ax)
    ax.grid(True, alpha=0.2, linestyle=':', linewidth=0.5, axis='x')
    plt.tight_layout()

    _save_fig(fig, self._out('output_logits.png'))
    print(" Saved output_logits.png")
|
| 842 |
+
|
| 843 |
+
# ------------------------------------------------------------------
|
| 844 |
+
# Tab 5: Grokking
|
| 845 |
+
# ------------------------------------------------------------------
|
| 846 |
+
|
| 847 |
+
def generate_tab5(self):
    """Generate the grokking-tab artifacts for ``self.p``.

    Files written (all paths resolved through ``self._out``):
      * ``grokk_loss.json`` / ``grokk_loss.png`` - only when the loaded
        training curves contain train losses
      * ``grokk_acc.json`` / ``grokk_acc.png``
      * ``grokk_abs_phase_diff.png``
      * ``grokk_avg_ipr.png``
      * ``grokk_memorization_accuracy.png``
      * ``grokk_memorization_common_to_rare.png``
      * ``grokk_decoded_weights_dynamic.png``

    Returns early, after printing a SKIP message, when ``self.p`` is below
    ``MIN_P_GROKKING``, when the grokking run directory is missing, or
    when no checkpoints could be loaded.

    Throughout this method ``epochs`` holds the sorted checkpoint keys,
    which are used as training-step numbers on every x-axis, and
    ``stage1_end`` / ``stage2_end`` are expressed in the same step units.
    """
    print(f" [Tab 5] Grokking for p={self.p}")
    if self.p < MIN_P_GROKKING:
        print(f" SKIP: p={self.p} < {MIN_P_GROKKING} (too few test points for grokking)")
        return
    run_dir = self._run_dir('grokking')
    if run_dir is None:
        print(" SKIP: grokking run directory not found")
        return

    curves = _load_training_curves(self._run_type_dir('grokking'))
    checkpoints = _load_checkpoints(run_dir, self.device)

    if not checkpoints:
        print(" SKIP: no grokking checkpoints")
        return

    epochs = sorted(checkpoints.keys())
    # Earliest saved checkpoint; used wherever "the initial model" is
    # wanted so that a run whose first checkpoint is not step 0 still works.
    first_epoch = epochs[0]
    p = self.p
    d_mlp = self.d_mlp
    act_type = TRAINING_RUNS['grokking']['act_type']

    def _build_model(state_dict):
        """Return an eval-mode EmbedMLP on self.device loaded with state_dict."""
        net = EmbedMLP(
            d_vocab=self.d_vocab, d_model=self.d_model,
            d_mlp=d_mlp, act_type=act_type, use_cache=False
        ).to(self.device)
        net.load_state_dict(state_dict)
        net.eval()
        return net

    def _load_split(path):
        """Load a saved data split; returns (pairs, labels_or_None)."""
        if not os.path.exists(path):
            return None, None
        # NOTE(review): weights_only=False unpickles arbitrary objects;
        # acceptable only because these files are produced by our own
        # training pipeline - never point this at untrusted data.
        raw = torch.load(path, weights_only=False, map_location=self.device)
        # Handle both formats: plain tensor or (pairs, labels) tuple
        if isinstance(raw, (tuple, list)):
            return raw[0], raw[1]
        return raw, None

    def _sum_labels(pairs):
        """Return (a + b) mod p for every row of pairs as a CPU long tensor."""
        return ((pairs[:, 0] + pairs[:, 1]) % p).to(
            dtype=torch.long, device='cpu'
        )

    # Load train/test data
    train_data, train_labels = _load_split(os.path.join(run_dir, 'train_data.pth'))
    test_data, test_labels = _load_split(os.path.join(run_dir, 'test_data.pth'))

    # Fallback: regenerate data deterministically if files are missing
    if train_data is None or test_data is None:
        grokk_cfg = TRAINING_RUNS['grokking']
        frac = grokk_cfg['frac_train']
        seed = grokk_cfg['seed']
        print(f" Regenerating train/test data (frac={frac}, seed={seed})")
        train_data, test_data = _gen_train_test(p, frac_train=frac, seed=seed)
        # Labels loaded from a partially present file may not correspond
        # to the regenerated pairs, so recompute them below.
        train_labels = None
        test_labels = None

    # Compute labels from pairs if not loaded directly
    if train_labels is None and train_data is not None:
        train_labels = _sum_labels(train_data)
    if test_labels is None and test_data is not None:
        test_labels = _sum_labels(test_data)

    # Detect stage boundaries
    train_losses = curves.get('train_losses', []) if curves else []
    test_losses = curves.get('test_losses', []) if curves else []
    train_accs_curve = curves.get('train_accs', None) if curves else None
    test_accs_curve = curves.get('test_accs', None) if curves else None

    stage1_end, stage2_end = detect_grokking_stages(
        train_losses, test_losses, train_accs_curve, test_accs_curve
    )
    # Fallback boundaries at 1/5 and 3/5 of the run.  They are looked up
    # in ``epochs`` so that they are step numbers (like everything that
    # consumes them below), not checkpoint counts.
    if stage1_end is None:
        stage1_end = epochs[len(epochs) // 5]
    if stage2_end is None:
        stage2_end = epochs[len(epochs) * 3 // 5]

    def _shade_stages(ax, x_end, clip_middle=False):
        """Shade the three training stages on ax and mark both boundaries.

        x_end is the right edge of the last stage; with clip_middle the
        middle band is additionally limited to x_end.
        """
        middle_end = min(stage2_end, x_end) if clip_middle else stage2_end
        ax.axvspan(0, stage1_end, alpha=0.15, color='#D4AF37')
        ax.axvspan(stage1_end, middle_end, alpha=0.15, color='#8B7355')
        ax.axvspan(stage2_end, x_end, alpha=0.15, color='#60656F')
        ax.axvline(x=stage1_end, color='black', linestyle='--', linewidth=1)
        ax.axvline(x=stage2_end, color='black', linestyle='--', linewidth=1)

    def _nearest_epoch(step):
        """Return the checkpoint step closest to ``step``."""
        return min(epochs, key=lambda e: abs(e - step))

    # ---- Loss JSON + static PNG ----
    if train_losses:
        loss_data = {
            'train_losses': train_losses,
            'test_losses': test_losses,
            'stage1_end': stage1_end,
            'stage2_end': stage2_end,
        }
        with open(self._out('grokk_loss.json'), 'w') as f:
            json.dump(loss_data, f)

        # Static loss PNG (matches blog Figure 13a)
        max_step = min(len(train_losses), len(test_losses)) if test_losses else len(train_losses)
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.plot(train_losses[:max_step], color='#0D2758', linewidth=2, label='Train')
        if test_losses:
            ax.plot(test_losses[:max_step], color='#A32015', linewidth=2, label='Test')
        _shade_stages(ax, max_step)
        ax.set_xlabel('Step', fontsize=16)
        ax.set_ylabel('Loss', fontsize=16)
        ax.legend(fontsize=16, loc='upper right')
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()
        _save_fig(fig, self._out('grokk_loss.png'))

    # ---- Accuracy: compute from checkpoints if not in curves ----
    train_accs = []
    test_accs = []
    if train_data is not None and test_data is not None:
        for ep in epochs:
            model = _build_model(checkpoints[ep])
            with torch.no_grad():
                tr_logits = model(train_data)
                te_logits = model(test_data)
            train_accs.append(acc_rate(tr_logits, train_labels))
            test_accs.append(acc_rate(te_logits, test_labels))
    elif train_accs_curve is not None:
        # Use curves data, subsample to match checkpoint epochs
        save_every = epochs[1] - epochs[0] if len(epochs) > 1 else 200
        train_accs = train_accs_curve[::save_every][:len(epochs)]
        test_accs = test_accs_curve[::save_every][:len(epochs)]

    acc_data = {
        'epochs': epochs,
        'train_accs': train_accs,
        'test_accs': test_accs,
        'stage1_end': stage1_end,
        'stage2_end': stage2_end,
    }
    with open(self._out('grokk_acc.json'), 'w') as f:
        json.dump(acc_data, f)

    # Static accuracy PNG (matches blog Figure 13b)
    if train_accs and test_accs:
        fig, ax = plt.subplots(figsize=(4, 4))
        _shade_stages(ax, epochs[-1])
        ax.plot(epochs[:len(train_accs)], train_accs,
                label='Train', color='#0D2758', linewidth=2.5)
        ax.plot(epochs[:len(test_accs)], test_accs,
                label='Test', color='#A32015', linewidth=2.5)
        ax.set_xlabel('Step', fontsize=16)
        ax.set_ylabel('Accuracy', fontsize=16)
        ax.legend(fontsize=16, loc='lower right')
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()
        _save_fig(fig, self._out('grokk_acc.png'))

    # ---- Phase difference |sin(D*)| ----
    abs_phase_diff = []
    sparse_level = []

    for ep in epochs:
        model_sd = checkpoints[ep]
        W_in_d, W_out_d, mfl = decode_weights(model_sd, self.fourier_basis)

        sparse_level.append(self._ipr_at_checkpoint(model_sd))

        phase_diffs = []
        for neuron in range(W_in_d.shape[0]):
            _, phi_in = compute_neuron(neuron, mfl, W_in_d)
            _, phi_out = compute_neuron(neuron, mfl, W_out_d)
            phase_diffs.append(normalize_to_pi(phi_out - 2 * phi_in))
        phase_diffs = np.array(phase_diffs)
        abs_phase_diff.append(np.mean(np.abs(np.sin(phase_diffs))))

    # Limit to reasonable number of points for plotting
    n_plot = min(len(epochs), 100)
    x_phase = np.array(epochs[:n_plot])

    fig, ax = plt.subplots(figsize=(4, 4))
    _shade_stages(ax, x_phase[-1], clip_middle=True)
    ax.plot(x_phase, abs_phase_diff[:n_plot], marker='x', markersize=5,
            color='#986d56', label=r"Avg. $|\sin(D_m^\star)|$", linewidth=1.5)
    ax.set_xlabel('Step', fontsize=16)
    ax.set_ylabel('Average Value', fontsize=16)
    ax.set_ylim([0, 0.65])
    ax.legend(fontsize=16, loc="upper right")
    ax.grid(True, alpha=0.5, linestyle='--')
    plt.tight_layout()
    _save_fig(fig, self._out('grokk_abs_phase_diff.png'))

    # ---- IPR + param norms (dual axis) ----
    x_all = np.array(epochs)
    param_norms = []
    if curves and 'param_norms' in curves:
        save_every = epochs[1] - epochs[0] if len(epochs) > 1 else 200
        param_norms = curves['param_norms'][::save_every][:len(epochs)]

    fig, ax1 = plt.subplots(figsize=(4, 4))
    _shade_stages(ax1, x_all[-1], clip_middle=True)

    line1 = ax1.plot(x_all, sparse_level, marker='x', markersize=3,
                     color='#986d56', label=r"Avg. IPR", linewidth=1.5)
    ax1.set_xlabel('Step', fontsize=16)
    ax1.tick_params(axis='y')
    ax1.set_ylim([0, 1.05])

    if param_norms:
        ax2 = ax1.twinx()
        line2 = ax2.plot(x_all[:len(param_norms)], param_norms,
                         marker='o', markersize=3, color='#2E5266',
                         label=r"Param. Norm", linewidth=1.5)
        ax2.tick_params(axis='y')
        # Merge the handles of both axes into a single legend.
        lines = line1 + line2
        labels = [ln.get_label() for ln in lines]
        ax1.legend(lines, labels, fontsize=16, loc="lower right")
    else:
        ax1.legend(fontsize=16, loc="lower right")

    ax1.grid(True, alpha=0.5, linestyle='--')
    plt.tight_layout()
    _save_fig(fig, self._out('grokk_avg_ipr.png'))

    # ---- Memorization figures (need the training pairs) ----
    if train_data is not None:
        grid_ticks = np.arange(p)
        train_set = {(int(a), int(b)) for a, b in train_data.tolist()}
        gt_grid = self.all_labels.view(p, p)

        def _draw_grid_panel(ax, image, title, outlined, linewidth,
                             vmax=1, ylabel=False):
            """Draw one p-by-p heatmap panel and outline the given cells."""
            im = ax.imshow(image, cmap=CMAP_SEQUENTIAL,
                           vmin=0, vmax=vmax, aspect='equal')
            ax.set_title(title, fontsize=21)
            if ylabel:
                ax.set_ylabel('First Input', fontsize=18)
            ax.set_xlabel('Second Input', fontsize=18)
            ax.set_xticks(grid_ticks)
            ax.set_yticks(grid_ticks)
            ax.set_xticklabels(grid_ticks, fontsize=11)
            ax.set_yticklabels(grid_ticks, fontsize=11)
            for row, col in outlined:
                ax.add_patch(plt.Rectangle(
                    (col - 0.5, row - 0.5), 1, 1,
                    linewidth=linewidth, edgecolor='red', facecolor='none'
                ))
            return im

        def _accuracy_grid(state_dict):
            """Return (logits, p-by-p 0/1 correctness array) on all pairs."""
            net = _build_model(state_dict)
            with torch.no_grad():
                out = net(self.all_data).squeeze(1)
            predicted = torch.argmax(out, dim=1).view(p, p)
            # .cpu() so the conversion also works when self.device is a GPU.
            return out, (predicted == gt_grid).float().cpu().numpy()

        # ---- Memorization accuracy (3-panel) ----
        # Find a checkpoint near stage1_end
        closest_epoch = _nearest_epoch(stage1_end)
        logits, accuracy_np = _accuracy_grid(checkpoints[closest_epoch])

        # 1.0: pair is in the training set, 0.65: only its mirror (j, i)
        # is, 0.0: neither - those are the "true" test points.
        train_mask = torch.zeros(p, p)
        true_test_points = []
        for i in range(p):
            for j in range(p):
                if (i, j) in train_set:
                    train_mask[i, j] = 1.0
                elif (j, i) in train_set:
                    train_mask[i, j] = 0.65
                else:
                    true_test_points.append((i, j))

        # Softmax probability assigned to the correct answer (a + b) mod p.
        probs = torch.softmax(logits, dim=1)
        correct_cls = ((self.all_data[:, 0] + self.all_data[:, 1]) % p).to(
            dtype=torch.long, device=probs.device
        )
        gt_probs = probs.gather(1, correct_cls.unsqueeze(1)).squeeze(1)
        gt_probs_grid = gt_probs.detach().cpu().view(p, p)

        fig = plt.figure(figsize=(20, 6))
        gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 1.1], wspace=0.15)
        ax1 = fig.add_subplot(gs[0])
        ax2 = fig.add_subplot(gs[1])
        ax3 = fig.add_subplot(gs[2])

        _draw_grid_panel(ax1, train_mask.numpy(),
                         'Training Data under Symmetry',
                         true_test_points, 2.5, ylabel=True)
        _draw_grid_panel(ax2, accuracy_np, 'Accuracy before Grokking',
                         true_test_points, 2.5)
        im3 = _draw_grid_panel(ax3, gt_probs_grid.numpy(),
                               'Softmax Weight at Ground-Truth',
                               true_test_points, 2.5,
                               vmax=gt_probs_grid.max().item())
        cbar3 = fig.colorbar(im3, ax=ax3, fraction=0.046, pad=0.04)
        cbar3.ax.tick_params(labelsize=12)
        plt.tight_layout()
        _save_fig(fig, self._out('grokk_memorization_accuracy.png'))

        # ---- Memorization common-to-rare (4-panel) ----
        # 1.0: both (i, j) and (j, i) are training pairs, 0.5: only
        # (i, j) is (outlined in red), 0.0: (i, j) is not a training pair.
        asymmetric_train_points = []
        train_mask_dist = torch.zeros(p, p)
        for i in range(p):
            for j in range(p):
                if (i, j) not in train_set:
                    continue
                if (j, i) in train_set:
                    train_mask_dist[i, j] = 1.0
                else:
                    train_mask_dist[i, j] = 0.5
                    asymmetric_train_points.append((i, j))

        # Pick 3 epochs: first checkpoint, ~stage1/2, ~stage1
        selected_epochs = [first_epoch]
        for candidate in (_nearest_epoch(stage1_end // 2),
                          _nearest_epoch(stage1_end)):
            if candidate not in selected_epochs:
                selected_epochs.append(candidate)
        # Ensure we have exactly 3 + distribution = 4 panels
        while len(selected_epochs) < 3:
            selected_epochs.append(epochs[min(len(epochs) - 1, 2)])

        fig = plt.figure(figsize=(26, 6))
        gs = fig.add_gridspec(
            1, 4, width_ratios=[1, 1, 1, 1.1], wspace=0.15
        )

        # Panel 1: training data distribution
        _draw_grid_panel(fig.add_subplot(gs[0]), train_mask_dist.numpy(),
                         'Training Data Distribution',
                         asymmetric_train_points, 2.0, ylabel=True)

        # Panels 2-4: accuracy at selected epochs
        for panel_idx, sel_ep in enumerate(selected_epochs):
            ax_p = fig.add_subplot(gs[panel_idx + 1])
            _, acc_np = _accuracy_grid(checkpoints[sel_ep])
            _draw_grid_panel(ax_p, acc_np, f'Accuracy at Step {sel_ep}',
                             asymmetric_train_points, 2.0)

        plt.tight_layout()
        _save_fig(fig, self._out('grokk_memorization_common_to_rare.png'))

    # ---- Decoded weights dynamic (3 timepoints) ----
    # Pick 3 representative epochs: first checkpoint, stage1, stage2
    key_epochs = [first_epoch]
    for candidate in (_nearest_epoch(stage1_end), _nearest_epoch(stage2_end)):
        if candidate not in key_epochs:
            key_epochs.append(candidate)
    while len(key_epochs) < 3:
        key_epochs.append(epochs[-1])

    num_components = min(20, d_mlp)
    # The padding above guarantees three columns, so ``axes`` is 2-D.
    n = len(key_epochs)
    fig, axes = plt.subplots(
        2, n, figsize=(18, 3.3 * n),
        gridspec_kw={"hspace": 0.05}, constrained_layout=True
    )

    x_locs = np.arange(len(self.fourier_basis_names))
    y_locs = np.arange(num_components)

    for col, key in enumerate(key_epochs):
        W_in = checkpoints[key]['mlp.W_in']
        W_out = checkpoints[key]['mlp.W_out']

        # (row, weights projected onto the Fourier basis, colour-range
        #  shrink factor or None, LaTeX symbol used in the title)
        rows = (
            (0, W_in @ self.fourier_basis.T, None, r'\theta_m'),
            (1, W_out.T @ self.fourier_basis.T, 0.85, r'\xi_m'),
        )
        for row, decoded, shrink, symbol in rows:
            decoded_np = decoded[:num_components].detach().cpu().numpy()
            limit = np.abs(decoded_np).max()
            if shrink is not None:
                limit = limit * shrink
            ax_w = axes[row, col]
            im_w = ax_w.imshow(
                decoded_np, cmap=CMAP_DIVERGING,
                vmin=-limit, vmax=limit, aspect='auto'
            )
            ax_w.set_title(rf'Step {key}, ${symbol}$ after DFT', fontsize=18)
            ax_w.set_xticks(x_locs)
            ax_w.set_xticklabels(self.fourier_basis_names, rotation=90, fontsize=11)
            ax_w.set_yticks(y_locs)
            ax_w.set_yticklabels(y_locs)
            if col == 0:
                ax_w.set_ylabel('Neuron #', fontsize=16)
            fig.colorbar(im_w, ax=ax_w)

    _save_fig(fig, self._out('grokk_decoded_weights_dynamic.png'))

    print(" Saved grokk_loss.json, grokk_loss.png, grokk_acc.json, grokk_acc.png, "
          "grokk_abs_phase_diff.png, grokk_avg_ipr.png, "
          "grokk_memorization_accuracy.png, "
          "grokk_memorization_common_to_rare.png, grokk_decoded_weights_dynamic.png")
|
| 1334 |
+
|
| 1335 |
+
# ------------------------------------------------------------------
|
| 1336 |
+
# Tab 6: Lottery Mechanism
|
| 1337 |
+
# ------------------------------------------------------------------
|
| 1338 |
+
|
| 1339 |
+
def generate_tab6(self):
    """Generate the lottery-mechanism plots from the 'quad_random' run.

    One neuron is chosen with ``select_lottery_neuron`` on the final
    checkpoint.  For that neuron, every saved checkpoint is decoded with
    ``decode_scales_phis`` and two figures are written:

      * ``lottery_mech_magnitude.png`` - per-frequency magnitude over time
      * ``lottery_mech_phase.png`` - per-frequency misalignment
        ``psi - 2 * phi`` (wrapped by ``normalize_to_pi``) over time

    The frequency reported by ``decode_weights`` for the neuron is drawn
    last in a highlight colour; the other frequencies are shaded from
    light to dark by their magnitude at the last plotted checkpoint.
    Finally ``self._generate_lottery_contour()`` is invoked.

    Returns early, after printing a SKIP message, when the run directory,
    the checkpoints or the final checkpoint are missing.
    """
    print(f" [Tab 6] Lottery Mechanism for p={self.p}")
    run_dir = self._run_dir('quad_random')
    if run_dir is None:
        print(" SKIP: quad_random run directory not found")
        return

    checkpoints = _load_checkpoints(run_dir, self.device)
    if not checkpoints:
        print(" SKIP: no quad_random checkpoints")
        return

    final_data = _load_final(run_dir, self.device)
    if final_data is None:
        print(" SKIP: no final quad_random checkpoint")
        return
    model_load_final = final_data['model']

    # Select best neuron
    neuron_id = select_lottery_neuron(
        model_load_final, self.fourier_basis, decode_scales_phis
    )

    epochs = sorted(checkpoints.keys())

    # Collect per-checkpoint scales and phase diffs for the selected neuron
    scales_list = []
    diff_list = []
    for ep in epochs:
        scales, phis, psis = decode_scales_phis(
            checkpoints[ep], self.fourier_basis
        )
        scales_list.append(scales[neuron_id])
        diff_list.append(normalize_to_pi(
            psis[neuron_id] - 2 * phis[neuron_id]
        ))

    # Stack: [num_checkpoints, K+1], skip DC.  After dropping the DC
    # column, column index f corresponds to frequency f + 1.
    scales_all = torch.stack(scales_list, dim=0)[:, 1:]
    diff_all = torch.stack(diff_list, dim=0)[:, 1:]

    # Determine which frequency this neuron specializes in
    _, _, max_freq_ls = decode_weights(model_load_final, self.fourier_basis)
    # int() normalises tensor / numpy scalars so that comparisons and the
    # legend label below behave like plain integers.
    max_freq = int(max_freq_ls[neuron_id]) - 1  # 0-indexed into scales_all

    scales_np = scales_all.cpu().numpy()
    diff_np = diff_all.cpu().numpy()
    num_models, num_freqs = scales_np.shape
    n_plot = min(num_models, 160)
    scales_np = scales_np[:n_plot]
    diff_np = diff_np[:n_plot]
    x_idx = np.arange(n_plot)

    # Color gradient for non-highlighted frequencies
    base_rgb = np.array(mcolors.to_rgb(COLORS[0]))
    gray_rgb = np.array(mcolors.to_rgb('white'))
    highlight_color = COLORS[3]

    nonmax = [f for f in range(num_freqs) if f != max_freq]
    final_scales = scales_np[-1]
    sorted_nonmax = sorted(nonmax, key=lambda f: final_scales[f])
    M = len(sorted_nonmax)

    # Compute save_every for x-axis formatter
    # NOTE(review): the tick formatter maps checkpoint index -> step as
    # index * save_every, i.e. it presumes evenly spaced checkpoints that
    # start at step 0 - confirm against the training script.
    save_every = epochs[1] - epochs[0] if len(epochs) > 1 else 200

    def _plot_tracks(ax, series):
        """Plot one dotted curve per frequency column of ``series``."""
        track_style = dict(linestyle=':', marker='x',
                           linewidth=3.5, markersize=1.5)
        for idx, f in enumerate(sorted_nonmax):
            blend = idx / (M - 1) if M > 1 else 0.0
            # The +/-0.05 offset extrapolates slightly past both end
            # colours; clip so the result is always a valid RGB triple.
            col_rgb = np.clip(
                (1 - blend - 0.05) * gray_rgb + (blend + 0.05) * base_rgb,
                0.0, 1.0
            )
            ax.plot(x_idx, series[:, f], color=col_rgb,
                    label=f"Freq. {f + 1}", **track_style)
        # Highlighted (specialised) frequency is drawn last, on top.
        ax.plot(x_idx, series[:, max_freq], color=highlight_color,
                label=f"Freq. {max_freq + 1}", **track_style)
        ax.xaxis.set_major_formatter(
            FuncFormatter(lambda val, pos: f"{int(val * save_every)}")
        )

    # ---- Magnitude plot ----
    fig, ax = plt.subplots(figsize=(4, 4))
    _plot_tracks(ax, scales_np)
    ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1),
              borderaxespad=0.2, frameon=False, fontsize=13)
    ax.set_xlabel("Step", fontsize=16)
    ax.set_ylabel("Magnitude", fontsize=16)
    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
    _save_fig(fig, self._out('lottery_mech_magnitude.png'))

    # ---- Phase misalignment plot ----
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.axhline(y=0, color='black', linewidth=1, linestyle='dotted')
    _plot_tracks(ax, diff_np)
    ax.set_xlabel("Step", fontsize=16)
    ax.set_ylabel("Misalignment", fontsize=16)
    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
    _save_fig(fig, self._out('lottery_mech_phase.png'))

    # ---- Beta contour: simulate gradient flow ----
    self._generate_lottery_contour()

    print(" Saved lottery_mech_magnitude.png, lottery_mech_phase.png, "
          "lottery_beta_contour.png")
|
| 1457 |
+
|
| 1458 |
+
def _generate_lottery_contour(self):
    """Render the contour of final output magnitude over initial conditions.

    For every (initial magnitude, initial input phase) point of a 30 x 30
    grid, both weight vectors start as cosines at frequency ``init_k``, the
    analytical gradient-flow update is applied ``num_steps`` times, and the
    resulting output-side magnitude at that frequency is recorded.  The
    grid is drawn with ``contourf`` and written to
    ``lottery_beta_contour.png``.
    """
    p = self.p
    device = self.device
    init_k = 1
    init_psi = 0.0
    num_steps = 100
    learning_rate = 0.01

    fourier_basis, _ = get_fourier_basis(p, device)
    fourier_basis = fourier_basis.to(torch.get_default_dtype())

    initial_scales = np.linspace(0.01, 0.02, num=30)
    phi0_vals = np.linspace(0, np.pi, num=30)

    w_k = 2 * np.pi * init_k / p
    # Rows of the Fourier coefficient vector holding frequency init_k.
    freq_rows = [init_k * 2 - 1, init_k * 2]

    # Z[row, col]: row indexes the phase value, col the magnitude value,
    # which is the layout np.meshgrid produces below.
    Z = np.zeros((len(phi0_vals), len(initial_scales)))
    for col, scale in enumerate(initial_scales):
        for row, phi0 in enumerate(phi0_vals):
            theta = scale * torch.tensor(
                [np.cos(w_k * j + phi0) for j in range(p)],
                device=device
            )
            xi = scale * torch.tensor(
                [np.cos(w_k * j + init_psi) for j in range(p)],
                device=device
            )

            # Integrate the gradient flow with fixed-size steps.
            remaining = num_steps
            while remaining > 0:
                theta, xi = self._gradient_flow_step(
                    theta, xi, init_k, p, learning_rate, fourier_basis
                )
                remaining -= 1

            # Final output magnitude at frequency init_k.
            xi_spectrum = fourier_basis.to(xi.dtype) @ xi
            Z[row, col] = (
                torch.norm(xi_spectrum[freq_rows]).item() * np.sqrt(2 / p)
            )

    # The y-axis shows twice the initial input phase.
    X, Y = np.meshgrid(initial_scales, 2 * phi0_vals)

    fig = plt.figure(figsize=(4.5, 4))
    cf = plt.contourf(X, Y, Z, levels=12, cmap=CMAP_DIVERGING, extend='both')
    plt.axhline(y=np.pi, color='white', linewidth=1, linestyle=':')
    plt.xlabel("Initial Magnitude", fontsize=16)
    plt.ylabel("Initial Phase Difference", fontsize=16)
    plt.title("Contour of Final Magnitude", fontsize=16)
    plt.colorbar(cf)
    plt.tight_layout()
    _save_fig(fig, self._out('lottery_beta_contour.png'))
|
| 1524 |
+
|
| 1525 |
+
@staticmethod
|
| 1526 |
+
def _gradient_flow_step(theta, xi, init_k, p, lr, fourier_basis):
|
| 1527 |
+
"""One step of analytical gradient flow."""
|
| 1528 |
+
fb = fourier_basis.to(theta.dtype)
|
| 1529 |
+
theta_coeff = fb @ theta
|
| 1530 |
+
xi_coeff = fb @ xi
|
| 1531 |
+
|
| 1532 |
+
neuron_coeff_theta = theta_coeff[[init_k * 2 - 1, init_k * 2]]
|
| 1533 |
+
alpha = np.sqrt(2 / p) * torch.sqrt(
|
| 1534 |
+
torch.sum(neuron_coeff_theta.pow(2))
|
| 1535 |
+
).item()
|
| 1536 |
+
phi = np.arctan2(
|
| 1537 |
+
-neuron_coeff_theta[1].item(), neuron_coeff_theta[0].item()
|
| 1538 |
+
)
|
| 1539 |
+
|
| 1540 |
+
neuron_coeff_xi = xi_coeff[[init_k * 2 - 1, init_k * 2]]
|
| 1541 |
+
beta = np.sqrt(2 / p) * torch.sqrt(
|
| 1542 |
+
torch.sum(neuron_coeff_xi.pow(2))
|
| 1543 |
+
).item()
|
| 1544 |
+
psi = np.arctan2(
|
| 1545 |
+
-neuron_coeff_xi[1].item(), neuron_coeff_xi[0].item()
|
| 1546 |
+
)
|
| 1547 |
+
|
| 1548 |
+
w_k = 2 * np.pi * init_k / p
|
| 1549 |
+
grad_theta = torch.tensor(
|
| 1550 |
+
[2 * p * alpha * beta * np.cos(w_k * j + psi - phi)
|
| 1551 |
+
for j in range(p)],
|
| 1552 |
+
device=theta.device
|
| 1553 |
+
)
|
| 1554 |
+
grad_xi = torch.tensor(
|
| 1555 |
+
[p * alpha ** 2 * np.cos(w_k * j + 2 * phi)
|
| 1556 |
+
for j in range(p)],
|
| 1557 |
+
device=theta.device
|
| 1558 |
+
)
|
| 1559 |
+
|
| 1560 |
+
theta = theta + lr * grad_theta
|
| 1561 |
+
xi = xi + lr * grad_xi
|
| 1562 |
+
return theta, xi
|
| 1563 |
+
|
| 1564 |
+
# ------------------------------------------------------------------
|
| 1565 |
+
# Tab 7: Gradient Dynamics
|
| 1566 |
+
# ------------------------------------------------------------------
|
| 1567 |
+
|
| 1568 |
+
def generate_tab7(self):
    """Generate gradient dynamics plots for quad_single_freq and relu_single_freq.

    For each of the two runs this writes two figures:

    - ``phase_align_<prefix>.png``: phase and magnitude trajectories of one
      neuron (the one with the largest ``scale_in`` at the last checkpoint).
    - ``single_freq_<prefix>.png``: heatmaps of the first ``num_components``
      rows of W_in and W_out projected onto the Fourier basis at a few
      selected checkpoints.

    A run is skipped (with a printed message) when its directory or its
    checkpoints are missing.
    """
    print(f" [Tab 7] Gradient Dynamics for p={self.p}")

    # NOTE(review): act_name is unpacked but not used in this method.
    for run_name, act_name, prefix in [
        ('quad_single_freq', 'Quad', 'quad'),
        ('relu_single_freq', 'ReLU', 'relu'),
    ]:
        run_dir = self._run_dir(run_name)
        if run_dir is None:
            print(f" SKIP: {run_name} run directory not found")
            continue

        checkpoints = _load_checkpoints(run_dir, self.device)
        if not checkpoints:
            print(f" SKIP: no {run_name} checkpoints")
            continue

        epochs = sorted(checkpoints.keys())
        d_mlp = self.d_mlp

        # Build all neuron records across epochs
        # (one dict per (epoch, neuron); presumably compute_neuron returns
        # a (scale, phase) pair -- confirm against its definition).
        all_neuron_records = []
        for ep in epochs:
            model_sd = checkpoints[ep]
            W_in_d, W_out_d, mfl = decode_weights(model_sd, self.fourier_basis)
            for neuron in range(W_in_d.shape[0]):
                s_in, phi_in = compute_neuron(neuron, mfl, W_in_d)
                s_out, phi_out = compute_neuron(neuron, mfl, W_out_d)
                all_neuron_records.append({
                    'epoch': ep,
                    'neuron': neuron,
                    'scale_in': s_in,
                    'phi_in': phi_in,
                    'scale_out': s_out,
                    'phi_out': phi_out,
                })

        # Select a neuron that shows clear phase alignment
        # Pick neuron with largest final scale
        final_records = [r for r in all_neuron_records if r['epoch'] == epochs[-1]]
        if not final_records:
            continue
        best_neuron = max(final_records, key=lambda r: r['scale_in'])['neuron']

        # Extract trajectory for this neuron
        # (records were appended epoch by epoch, so this stays in epoch order).
        neuron_records = [r for r in all_neuron_records if r['neuron'] == best_neuron]
        # Remove last few points if noisy (as notebooks do):
        # drops the final 4 records for relu and the final 14 otherwise;
        # when that would leave nothing, the full trajectory is kept.
        trim = max(0, len(neuron_records) - 4) if prefix == 'relu' else max(0, len(neuron_records) - 14)
        neuron_records = neuron_records[:trim] if trim > 0 else neuron_records

        phi_in_raw = [r['phi_in'] for r in neuron_records]
        phi_out_raw = [r['phi_out'] for r in neuron_records]
        scale_in_list = [r['scale_in'] for r in neuron_records]
        scale_out_list = [r['scale_out'] for r in neuron_records]

        # Phase wrapping fix: normalize 2*phi to [-pi, pi], then adjust
        # psi to stay within pi of 2*phi (same fix as Tab 3 scatter).
        phi2_in_list = [normalize_to_pi(2 * v) for v in phi_in_raw]
        phi_out_list = []
        for two_phi, psi in zip(phi2_in_list, phi_out_raw):
            psi_n = normalize_to_pi(psi)
            if psi_n - two_phi > np.pi:
                psi_n -= 2 * np.pi
            elif psi_n - two_phi < -np.pi:
                psi_n += 2 * np.pi
            phi_out_list.append(psi_n)

        # Unwrap time series to remove remaining jumps at +-pi boundary
        phi_in_list = list(np.unwrap(phi_in_raw))
        phi2_in_list = list(np.unwrap(phi2_in_list))
        phi_out_list = list(np.unwrap(phi_out_list))

        # x-axis in training steps: record index times the gap between the
        # first two checkpoints (200 when only one checkpoint exists).
        x = np.arange(len(phi_in_list)) * (epochs[1] - epochs[0] if len(epochs) > 1 else 200)

        # ---- Phase alignment + magnitude plot ----
        fig_width = 8 if prefix == 'quad' else 5
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(fig_width, 6), sharex=True)

        ax1.plot(x, phi_in_list, marker='o', markersize=4,
                 color=COLORS[1], label=r"$\phi_m^\star$")
        ax1.plot(x, phi_out_list, marker='x', markersize=4,
                 color=COLORS[3], label=r"$\psi_m^\star$")
        ax1.plot(x, phi2_in_list, marker='^', markersize=4,
                 color=COLORS[0], label=r"$2\phi_m^\star$")
        ax1.set_title('Phase Alignment of Neuron $m$', fontsize=16)
        ax1.legend(fontsize=18, loc="upper right")
        ax1.grid(True)

        ax2.plot(x, scale_in_list, marker='o', markersize=4,
                 color=COLORS[0], label=r"$\alpha_m^\star$")
        ax2.plot(x, scale_out_list, marker='x', markersize=4,
                 color=COLORS[3], label=r"$\beta_m^\star$")
        ax2.set_title('Magnitude Growth of Neuron $m$', fontsize=16)
        ax2.set_xlabel('Step', fontsize=16)
        ax2.legend(fontsize=18, loc="upper left")
        ax2.grid(True)

        plt.tight_layout()
        _save_fig(fig, self._out(f'phase_align_{prefix}.png'))

        # ---- Decoded weights at timepoints ----
        # quad: step 0, the checkpoint closest to step 1000, and the last
        # one (duplicates removed); relu: step 0 and the last checkpoint.
        if prefix == 'quad':
            keys = [0]
            mid = min(epochs, key=lambda e: abs(e - 1000))
            end = epochs[-1]
            if mid not in keys:
                keys.append(mid)
            if end not in keys:
                keys.append(end)
        else:
            keys = [0, epochs[-1]]

        num_components = min(20, d_mlp)
        n = len(keys)
        fig, axes = plt.subplots(
            2, n, figsize=(12 if n <= 2 else 18, 4 * n if n <= 2 else 3.3 * n),
            gridspec_kw={"hspace": 0.05}, constrained_layout=True
        )
        # With a single column plt.subplots returns a 1-D array; restore
        # the (row, col) indexing used below.
        if n == 1:
            axes = axes.reshape(2, 1)

        x_locs = np.arange(len(self.fourier_basis_names))
        y_locs = np.arange(num_components)

        for col, key in enumerate(keys):
            # Fall back to the nearest saved checkpoint when the requested
            # step (e.g. 0) was not saved.
            if key not in checkpoints:
                key = min(checkpoints.keys(), key=lambda e: abs(e - key))
            W_in = checkpoints[key]['mlp.W_in']
            W_out = checkpoints[key]['mlp.W_out']

            # Top row: input weights in the Fourier basis, colour scale
            # symmetric around zero.
            data_in = (W_in @ self.fourier_basis.T)[:num_components]
            data_in_np = data_in.detach().cpu().numpy()
            abs_max_in = np.abs(data_in_np).max()
            ax_in = axes[0, col]
            im_in = ax_in.imshow(
                data_in_np, cmap=CMAP_DIVERGING,
                vmin=-abs_max_in, vmax=abs_max_in, aspect='auto'
            )
            ax_in.set_title(rf'Step {key}, $\theta_m$ after DFT', fontsize=18)
            ax_in.set_xticks(x_locs)
            ax_in.set_xticklabels(self.fourier_basis_names, rotation=90, fontsize=11)
            ax_in.set_yticks(y_locs)
            ax_in.set_yticklabels(y_locs)
            if col == 0:
                ax_in.set_ylabel('Neuron #', fontsize=16)
            fig.colorbar(im_in, ax=ax_in)

            # Bottom row: output weights (transposed first) in the same basis.
            data_out = (W_out.T @ self.fourier_basis.T)[:num_components]
            data_out_np = data_out.detach().cpu().numpy()
            abs_max_out = np.abs(data_out_np).max()
            ax_out = axes[1, col]
            im_out = ax_out.imshow(
                data_out_np, cmap=CMAP_DIVERGING,
                vmin=-abs_max_out, vmax=abs_max_out, aspect='auto'
            )
            ax_out.set_title(rf'Step {key}, $\xi_m$ after DFT', fontsize=18)
            ax_out.set_xticks(x_locs)
            ax_out.set_xticklabels(self.fourier_basis_names, rotation=90, fontsize=11)
            ax_out.set_yticks(y_locs)
            ax_out.set_yticklabels(y_locs)
            if col == 0:
                ax_out.set_ylabel('Neuron #', fontsize=16)
            fig.colorbar(im_out, ax=ax_out)

        _save_fig(fig, self._out(f'single_freq_{prefix}.png'))

        print(f" Saved phase_align_{prefix}.png, single_freq_{prefix}.png")
|
| 1736 |
+
|
| 1737 |
+
# ------------------------------------------------------------------
|
| 1738 |
+
# Metadata JSON
|
| 1739 |
+
# ------------------------------------------------------------------
|
| 1740 |
+
|
| 1741 |
+
def _save_metadata(self):
    """Write ``metadata.json`` with run hyperparameters and final metrics.

    Every entry of ``TRAINING_RUNS`` contributes its hyperparameters.  When
    training curves can be loaded for a run, the last value of each
    non-empty accuracy/loss series is stored under ``final_metrics``.
    """
    print(f" [Meta] Saving metadata for p={self.p}")

    hyperparameter_names = (
        'act_type', 'lr', 'weight_decay', 'num_epochs',
        'frac_train', 'init_type', 'init_scale', 'optimizer',
    )
    # (series name in the curves, key written to the metadata)
    series_to_metric = (
        ('train_accs', 'train_acc'),
        ('test_accs', 'test_acc'),
        ('train_losses', 'train_loss'),
        ('test_losses', 'test_loss'),
    )

    meta = {
        'prime': self.p,
        'd_mlp': self.d_mlp,
        'training_runs': {},
        'final_metrics': {},
    }
    for run_name, params in TRAINING_RUNS.items():
        meta['training_runs'][run_name] = {
            name: params[name] for name in hyperparameter_names
        }

        curves = _load_training_curves(self._run_type_dir(run_name))
        if not curves:
            continue

        metrics = {
            metric: curves[series][-1]
            for series, metric in series_to_metric
            if series in curves and curves[series]
        }
        if metrics:
            meta['final_metrics'][run_name] = metrics

    with open(self._out('metadata.json'), 'w') as f:
        json.dump(meta, f, indent=2)
    print(" Saved metadata.json")
|
| 1778 |
+
|
| 1779 |
+
# ------------------------------------------------------------------
|
| 1780 |
+
# Interactive JSON precomputation
|
| 1781 |
+
# ------------------------------------------------------------------
|
| 1782 |
+
|
| 1783 |
+
def _precompute_neuron_spectra(self):
    """Dump absolute Fourier coefficients of up to 20 selected neurons.

    Loads the final checkpoint of the ``standard`` run, decodes its weights
    in the Fourier basis, picks neurons with
    ``select_top_neurons_by_frequency`` and writes their input/output
    coefficient magnitudes to ``neuron_spectra.json``.
    """
    print(f" [Interactive] Neuron spectra for p={self.p}")

    run_dir = self._run_dir('standard')
    if run_dir is None:
        print(" SKIP: standard run directory not found")
        return

    final_data = _load_final(run_dir, self.device)
    if final_data is None:
        print(" SKIP: no final checkpoint")
        return

    W_in_decode, W_out_decode, max_freq_ls = decode_weights(
        final_data['model'], self.fourier_basis
    )
    num_neurons = min(20, W_in_decode.shape[0])
    chosen = select_top_neurons_by_frequency(
        max_freq_ls, W_in_decode, n=num_neurons
    )

    def _abs_row(decoded, index):
        # Absolute coefficients of one neuron as a plain Python list.
        return decoded[index].abs().cpu().tolist()

    spectra = {}
    for rank, neuron_idx in enumerate(chosen):
        in_magnitudes = _abs_row(W_in_decode, neuron_idx)
        out_magnitudes = _abs_row(W_out_decode, neuron_idx)
        spectra[f"neuron_{rank}"] = {
            'global_index': int(neuron_idx),
            'dominant_freq': int(max_freq_ls[neuron_idx]),
            'fourier_magnitudes_in': in_magnitudes,
            'fourier_magnitudes_out': out_magnitudes,
        }

    payload = {
        'fourier_basis_names': self.fourier_basis_names,
        'neurons': spectra,
    }
    with open(self._out('neuron_spectra.json'), 'w') as f:
        json.dump(payload, f)
    print(" Saved neuron_spectra.json")
|
| 1827 |
+
|
| 1828 |
+
def _precompute_logit_explorer(self):
    """Precompute logits for p representative (a, b) input pairs.

    Loads the final checkpoint of the ``standard`` run into an ``EmbedMLP``,
    evaluates it on the selected pairs and writes the pairs, the expected
    answers ``(a + b) % p`` and the raw logits to
    ``logits_interactive.json``.  Skips (with a message) when the run
    directory or the final checkpoint is missing.
    """
    print(f" [Interactive] Logit explorer for p={self.p}")
    run_dir = self._run_dir('standard')
    if run_dir is None:
        print(" SKIP: standard run directory not found")
        return

    final_data = _load_final(run_dir, self.device)
    if final_data is None:
        print(" SKIP: no final checkpoint")
        return
    model_load = final_data['model']

    p = self.p
    act_type = TRAINING_RUNS['standard']['act_type']
    model = EmbedMLP(
        d_vocab=self.d_vocab, d_model=self.d_model,
        d_mlp=self.d_mlp, act_type=act_type, use_cache=False
    )
    model.to(self.device)
    model.load_state_dict(model_load)
    model.eval()

    # Select p representative pairs spread across both operands:
    # (0, 0), (1, 2), (2, 4), ...  The previous flat-index walk used a
    # stride of (p * p) // p == p, which made b == 0 for every pair.
    pairs = [(a, (2 * a) % p) for a in range(p)]

    pair_tensor = torch.tensor(pairs, dtype=torch.long, device=self.device)
    with torch.no_grad():
        logits = model(pair_tensor).squeeze(1)  # [n_pairs, p]

    payload = {
        'pairs': pairs,
        'correct_answers': [(a + b) % p for a, b in pairs],
        'logits': logits.cpu().tolist(),
        'output_classes': list(range(p)),
    }
    with open(self._out('logits_interactive.json'), 'w') as f:
        json.dump(payload, f)
    print(" Saved logits_interactive.json")
|
| 1875 |
+
|
| 1876 |
+
def _precompute_grokk_slider(self):
    """Dump per-input correctness grids for up to 10 grokking checkpoints.

    Checkpoints of the ``grokking`` run are subsampled evenly by position.
    For each one the model is evaluated on ``self.all_data`` and a p x p
    grid of 1.0/0.0 (prediction equals label or not) is stored in
    ``grokk_epoch_data.json``.
    """
    print(f" [Interactive] Grokking epoch slider for p={self.p}")

    if self.p < MIN_P_GROKKING:
        print(f" SKIP: p={self.p} < {MIN_P_GROKKING}")
        return

    run_dir = self._run_dir('grokking')
    if run_dir is None:
        print(" SKIP: grokking run directory not found")
        return

    checkpoints = _load_checkpoints(run_dir, self.device)
    if not checkpoints:
        print(" SKIP: no grokking checkpoints")
        return

    available = sorted(checkpoints.keys())
    p = self.p
    d_mlp = self.d_mlp
    act_type = TRAINING_RUNS['grokking']['act_type']
    gt_grid = self.all_labels.view(p, p)

    # Evenly spaced positions into the sorted checkpoint list.
    n_snapshots = min(10, len(available))
    positions = np.linspace(0, len(available) - 1, n_snapshots, dtype=int)
    selected_epochs = [available[pos] for pos in positions]

    snapshot_epochs = []
    snapshot_grids = []
    for ep in selected_epochs:
        model = EmbedMLP(
            d_vocab=self.d_vocab, d_model=self.d_model,
            d_mlp=d_mlp, act_type=act_type, use_cache=False
        ).to(self.device)
        model.load_state_dict(checkpoints[ep])
        model.eval()
        with torch.no_grad():
            logits = model(self.all_data).squeeze(1)
            predicted = torch.argmax(logits, dim=1).view(p, p)
            correct = (predicted == gt_grid).float().cpu().tolist()
        snapshot_epochs.append(int(ep))
        snapshot_grids.append(correct)

    payload = {
        'prime': p,
        'epochs': snapshot_epochs,
        'grids': snapshot_grids,
    }
    with open(self._out('grokk_epoch_data.json'), 'w') as f:
        json.dump(payload, f)
    print(" Saved grokk_epoch_data.json")
|
| 1928 |
+
|
| 1929 |
+
# ------------------------------------------------------------------
|
| 1930 |
+
# Training Log consolidation
|
| 1931 |
+
# ------------------------------------------------------------------
|
| 1932 |
+
|
| 1933 |
+
def _save_training_log(self):
    """Consolidate training logs from all runs into a precomputed JSON.

    For each run in ``TRAINING_RUNS`` whose curves can be loaded, the
    output contains:
    - config: hyperparameters
    - log_text: human-readable formatted log (the run's
      ``training_log.txt`` when present, otherwise rebuilt from the curves)
    - table: subsampled per-epoch metrics for display
    - total_epochs: length of the ``train_losses`` series

    Writes ``training_log.json`` only when at least one run had curves.
    """
    print(f" [Log] Saving training log for p={self.p}")
    all_runs = {}
    metric_keys = ('train_losses', 'test_losses', 'train_accs',
                   'test_accs', 'grad_norms', 'param_norms')

    for run_name, params in TRAINING_RUNS.items():
        run_type_dir = self._run_type_dir(run_name)
        curves = _load_training_curves(run_type_dir)
        if curves is None:
            continue

        # Also check for a pre-saved training_log.txt
        log_text_path = os.path.join(run_type_dir, "training_log.txt")
        if os.path.exists(log_text_path):
            with open(log_text_path) as f:
                log_text = f.read()
        else:
            # Reconstruct from curves data
            log_text = self._reconstruct_log_text(
                run_name, params, curves
            )

        # Build a subsampled table.  The stride is n_epochs // 100, so the
        # table has between 100 and roughly 200 rows for long runs; the
        # final epoch is always included.
        n_epochs = len(curves.get('train_losses', []))
        step = max(1, n_epochs // 100)
        indices = list(range(0, n_epochs, step))
        if n_epochs > 0 and (n_epochs - 1) not in indices:
            indices.append(n_epochs - 1)

        table = []
        for i in indices:
            row = {'epoch': i}
            for key in metric_keys:
                vals = curves.get(key, [])
                # Shorter series are padded with None.
                row[key] = round(vals[i], 6) if i < len(vals) else None
            table.append(row)

        all_runs[run_name] = {
            'config': {
                'prime': self.p,
                'd_mlp': self.d_mlp,
                'act_type': params['act_type'],
                'init_type': params['init_type'],
                'init_scale': params['init_scale'],
                'optimizer': params['optimizer'],
                'lr': params['lr'],
                'weight_decay': params['weight_decay'],
                'frac_train': params['frac_train'],
                'num_epochs': params['num_epochs'],
                'seed': params['seed'],
            },
            'log_text': log_text,
            'table': table,
            'total_epochs': n_epochs,
        }

    if all_runs:
        with open(self._out('training_log.json'), 'w') as f:
            json.dump(all_runs, f)
        print(f" Saved training_log.json ({len(all_runs)} runs)")
    else:
        print(" SKIP: no training curves found")
|
| 2004 |
+
|
| 2005 |
+
def _reconstruct_log_text(self, run_name, params, curves):
|
| 2006 |
+
"""Reconstruct a human-readable training log from curves data."""
|
| 2007 |
+
lines = []
|
| 2008 |
+
lines.append(f"{'=' * 70}")
|
| 2009 |
+
lines.append(f"Training Log: p={self.p}, run={run_name}")
|
| 2010 |
+
lines.append(f"{'=' * 70}")
|
| 2011 |
+
lines.append("")
|
| 2012 |
+
lines.append("Configuration:")
|
| 2013 |
+
lines.append(f" prime (p) = {self.p}")
|
| 2014 |
+
lines.append(f" d_mlp = {self.d_mlp}")
|
| 2015 |
+
lines.append(f" activation = {params['act_type']}")
|
| 2016 |
+
lines.append(f" init_type = {params['init_type']}")
|
| 2017 |
+
lines.append(f" init_scale = {params['init_scale']}")
|
| 2018 |
+
lines.append(f" optimizer = {params['optimizer']}")
|
| 2019 |
+
lines.append(f" learning_rate = {params['lr']}")
|
| 2020 |
+
lines.append(f" weight_decay = {params['weight_decay']}")
|
| 2021 |
+
lines.append(f" frac_train = {params['frac_train']}")
|
| 2022 |
+
lines.append(f" num_epochs = {params['num_epochs']}")
|
| 2023 |
+
lines.append(f" seed = {params['seed']}")
|
| 2024 |
+
lines.append("")
|
| 2025 |
+
lines.append(f"{'─' * 70}")
|
| 2026 |
+
lines.append(
|
| 2027 |
+
f"{'Epoch':>8s} {'Train Loss':>12s} {'Test Loss':>12s} "
|
| 2028 |
+
f"{'Train Acc':>10s} {'Test Acc':>10s} "
|
| 2029 |
+
f"{'Grad Norm':>10s} {'Param Norm':>11s}"
|
| 2030 |
+
)
|
| 2031 |
+
lines.append(f"{'─' * 70}")
|
| 2032 |
+
|
| 2033 |
+
train_losses = curves.get('train_losses', [])
|
| 2034 |
+
test_losses = curves.get('test_losses', [])
|
| 2035 |
+
train_accs = curves.get('train_accs', [])
|
| 2036 |
+
test_accs = curves.get('test_accs', [])
|
| 2037 |
+
grad_norms = curves.get('grad_norms', [])
|
| 2038 |
+
param_norms = curves.get('param_norms', [])
|
| 2039 |
+
n_epochs = len(train_losses)
|
| 2040 |
+
|
| 2041 |
+
step = max(1, n_epochs // 100)
|
| 2042 |
+
indices = list(range(0, n_epochs, step))
|
| 2043 |
+
if n_epochs > 0 and (n_epochs - 1) not in indices:
|
| 2044 |
+
indices.append(n_epochs - 1)
|
| 2045 |
+
|
| 2046 |
+
for i in indices:
|
| 2047 |
+
tl = f"{train_losses[i]:.6f}" if i < len(train_losses) else "N/A"
|
| 2048 |
+
tel = f"{test_losses[i]:.6f}" if i < len(test_losses) else "N/A"
|
| 2049 |
+
ta = f"{train_accs[i]:.4f}" if i < len(train_accs) else "N/A"
|
| 2050 |
+
tea = f"{test_accs[i]:.4f}" if i < len(test_accs) else "N/A"
|
| 2051 |
+
gn = f"{grad_norms[i]:.4f}" if i < len(grad_norms) else "N/A"
|
| 2052 |
+
pn = f"{param_norms[i]:.4f}" if i < len(param_norms) else "N/A"
|
| 2053 |
+
lines.append(
|
| 2054 |
+
f"{i:>8d} {tl:>12s} {tel:>12s} "
|
| 2055 |
+
f"{ta:>10s} {tea:>10s} "
|
| 2056 |
+
f"{gn:>10s} {pn:>11s}"
|
| 2057 |
+
)
|
| 2058 |
+
|
| 2059 |
+
lines.append(f"{'─' * 70}")
|
| 2060 |
+
lines.append("")
|
| 2061 |
+
lines.append("Final Results:")
|
| 2062 |
+
if train_losses:
|
| 2063 |
+
lines.append(f" Train Loss = {train_losses[-1]:.6f}")
|
| 2064 |
+
if test_losses:
|
| 2065 |
+
lines.append(f" Test Loss = {test_losses[-1]:.6f}")
|
| 2066 |
+
if train_accs:
|
| 2067 |
+
lines.append(f" Train Acc = {train_accs[-1]:.4f}")
|
| 2068 |
+
if test_accs:
|
| 2069 |
+
lines.append(f" Test Acc = {test_accs[-1]:.4f}")
|
| 2070 |
+
if param_norms:
|
| 2071 |
+
lines.append(f" Param Norm = {param_norms[-1]:.4f}")
|
| 2072 |
+
lines.append(f"\nTotal epochs trained: {n_epochs}")
|
| 2073 |
+
return "\n".join(lines)
|
| 2074 |
+
|
| 2075 |
+
# ------------------------------------------------------------------
|
| 2076 |
+
# Generate all
|
| 2077 |
+
# ------------------------------------------------------------------
|
| 2078 |
+
|
| 2079 |
+
def generate_all(self):
    """Run every generation step, isolating failures per step.

    Order: metadata, training log, Tab 1 to Tab 7, then the three
    interactive JSON precomputations.  An exception in one step is printed
    with its traceback and does not stop the remaining steps.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"Generating plots for p={self.p}")
    print(f" Input: {self.input_dir}")
    print(f" Output: {self.output_dir}")
    print(f"{rule}")

    def _guarded(label, step):
        # Run one step; report and swallow any exception it raises.
        try:
            step()
        except Exception as e:
            print(f" [ERROR] {label} failed: {e}")
            traceback.print_exc()

    # Save metadata and training logs first
    _guarded('metadata', lambda: self._save_metadata())
    _guarded('training log', lambda: self._save_training_log())

    generators = [
        ('Tab 1', self.generate_tab1),
        ('Tab 2', self.generate_tab2),
        ('Tab 3', self.generate_tab3),
        ('Tab 4', self.generate_tab4),
        ('Tab 5', self.generate_tab5),
        ('Tab 6', self.generate_tab6),
        ('Tab 7', self.generate_tab7),
    ]
    for name, gen_fn in generators:
        _guarded(name, gen_fn)

    # Precompute interactive JSON data
    interactive = [
        ('Neuron Spectra', self._precompute_neuron_spectra),
        ('Logit Explorer', self._precompute_logit_explorer),
        ('Grokking Slider', self._precompute_grokk_slider),
    ]
    for name, fn in interactive:
        _guarded(name, fn)

    print(f"\nDone generating plots for p={self.p}")
|
| 2131 |
+
|
| 2132 |
+
|
| 2133 |
+
# ======================================================================
|
| 2134 |
+
# CLI
|
| 2135 |
+
# ======================================================================
|
| 2136 |
+
|
| 2137 |
+
def main():
    """Command-line entry point: generate precomputed results per modulus.

    ``--p P`` processes a single modulus; ``--all`` discovers every
    ``p_<int>`` directory under ``--input``.  When both are given, ``--p``
    takes precedence.  For each modulus the input directory is looked up
    as ``p_PPP`` (zero-padded) and then ``p_P``; a modulus without an
    input directory is reported and skipped.

    Exits with status 1 when ``--all`` finds no modulus directories.
    """
    parser = argparse.ArgumentParser(
        description='Generate all model-dependent plots for the HF app.'
    )
    parser.add_argument('--all', action='store_true',
                        help='Generate plots for all p found in input dir')
    parser.add_argument('--p', type=int,
                        help='Generate plots for a specific p')
    parser.add_argument('--input', type=str, default='./trained_models',
                        help='Base input directory containing p_PPP subdirs')
    parser.add_argument('--output', type=str,
                        default='./precomputed_results',
                        help='Base output directory for precomputed results')
    args = parser.parse_args()

    if not args.all and args.p is None:
        parser.error("Specify --all or --p P")

    # Same "is None" test as the validation above, so an explicit value
    # such as --p 0 is not silently treated as "not given".
    if args.p is not None:
        moduli = [args.p]
    else:
        # Discover moduli from input directory
        moduli = []
        if os.path.isdir(args.input):
            for d in sorted(os.listdir(args.input)):
                if not d.startswith('p_'):
                    continue
                try:
                    p = int(d.split('_')[1])
                except (ValueError, IndexError):
                    continue
                # p_23 and p_023 name the same modulus; process it once.
                if p not in moduli:
                    moduli.append(p)
        if not moduli:
            print(f"No p_PPP directories found in {args.input}")
            sys.exit(1)

    total = len(moduli)
    processed = 0
    for i, p in enumerate(moduli):
        print(f"\n[{i + 1}/{total}] Processing p={p}")
        # Handle both p_23 and p_023 naming conventions
        input_dir = os.path.join(args.input, f'p_{p:03d}')
        if not os.path.isdir(input_dir):
            input_dir = os.path.join(args.input, f'p_{p}')
        if not os.path.isdir(input_dir):
            print(f" Input directory not found: {input_dir}")
            continue

        output_dir = os.path.join(args.output, f'p_{p:03d}')

        gen = PlotGenerator(p=p, input_dir=input_dir, output_dir=output_dir)
        gen.generate_all()
        processed += 1

    # Report only the moduli that were actually generated, not the skipped ones.
    print(f"\nAll done. Processed {processed} prime(s).")
|
| 2189 |
+
|
| 2190 |
+
|
| 2191 |
+
# Run the command-line interface only when this file is executed as a
# script; importing the module has no side effects.
if __name__ == '__main__':
    main()
|
precompute/grokking_stage_detector.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Automatic detection of grokking stage boundaries from training curves.
|
| 3 |
+
|
| 4 |
+
Three stages:
|
| 5 |
+
Stage 1 (Memorization): Train accuracy rises, test accuracy stays low
|
| 6 |
+
Stage 2 (Transition): Test accuracy starts climbing
|
| 7 |
+
Stage 3 (Generalization): Test accuracy near 1.0
|
| 8 |
+
|
| 9 |
+
Returns (stage1_end, stage2_end) as epoch indices, or (None, None) if
|
| 10 |
+
grokking is not detected.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _first_index_where(values, predicate):
    """Return the index of the first element satisfying predicate, or None."""
    return next((i for i, v in enumerate(values) if predicate(v)), None)


def detect_grokking_stages(train_losses, test_losses, train_accs=None, test_accs=None):
    """
    Detect memorization -> transition -> generalization boundaries.

    Heuristic (used when both accuracy curves are provided and non-empty):
    - stage1_end: first epoch where train accuracy >= 0.95 (memorization complete)
    - stage2_end: first epoch where test accuracy >= 0.95 (generalization reached)
    - if both are found but stage1_end is not strictly before stage2_end,
      stage1_end is replaced by stage2_end // 3

    Fallback (if accuracy curves are missing or empty):
    - stage1_end: first epoch where train loss < 0.1
    - stage2_end: first epoch where test loss < 0.1

    Args:
        train_losses: per-epoch training losses.
        test_losses: per-epoch test losses.
        train_accs: optional per-epoch training accuracies.
        test_accs: optional per-epoch test accuracies.

    Returns:
        (stage1_end, stage2_end) as epoch indices; each entry is None when
        its threshold is never reached.
    """
    # len() rather than truthiness so numpy arrays / tensors are accepted too.
    # An empty curve carries no information, so it is treated like None and
    # the loss-based fallback is used instead of returning (None, None).
    have_accs = (
        train_accs is not None
        and test_accs is not None
        and len(train_accs) > 0
        and len(test_accs) > 0
    )

    if have_accs:
        stage1_end = _first_index_where(train_accs, lambda a: a >= 0.95)
        stage2_end = _first_index_where(test_accs, lambda a: a >= 0.95)

        # Sanity: stage1 should come before stage2
        if stage1_end is not None and stage2_end is not None and stage1_end >= stage2_end:
            stage1_end = stage2_end // 3
    else:
        stage1_end = _first_index_where(train_losses, lambda loss: loss < 0.1)
        stage2_end = _first_index_where(test_losses, lambda loss: loss < 0.1)

    return stage1_end, stage2_end
|
precompute/neuron_selector.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Automated neuron selection strategies for all primes.
|
| 3 |
+
Replaces hard-coded neuron indices from the analysis notebooks.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from collections import Counter
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def select_top_neurons_by_frequency(max_freq_ls, W_in_decode, n=20):
    """
    Select top N neurons covering all frequencies (round-robin).
    Used for heatmap plots (Tab 2).

    Picks the highest-magnitude neuron from each frequency in turn,
    cycling through frequencies (ascending) until n neurons are selected
    or every non-DC neuron has been taken.

    Args:
        max_freq_ls: per-neuron dominant frequency, indexable by neuron
            index (list, numpy array or tensor). Frequency 0 (DC) is skipped.
        W_in_decode: 2-D torch tensor with one row per neuron; a neuron's
            magnitude is the largest absolute entry of its row.
        n: maximum number of neurons to return.

    Returns:
        List of neuron indices into the original d_mlp-sized arrays.
    """
    from collections import defaultdict

    d_mlp = W_in_decode.shape[0]
    # Convert once to plain Python floats instead of calling .item() per neuron.
    magnitudes = W_in_decode.abs().max(dim=1).values.tolist()

    # Group neurons by their dominant frequency.
    freq_groups = defaultdict(list)
    for i in range(d_mlp):
        # Coerce to int: tensor elements hash by identity, so using them
        # directly as dict keys would put every neuron in its own group.
        f = int(max_freq_ls[i])
        if f > 0:  # skip DC neurons
            freq_groups[f].append((magnitudes[i], i))

    # Sort each group by magnitude descending (stable for ties).
    for group in freq_groups.values():
        group.sort(key=lambda x: -x[0])

    # Round-robin across frequencies (ascending order).
    freqs_sorted = sorted(freq_groups)
    selected = []
    pointers = {f: 0 for f in freqs_sorted}
    target = min(n, d_mlp)

    while len(selected) < target and freqs_sorted:
        exhausted = []
        for f in freqs_sorted:
            if len(selected) >= n:
                break
            if pointers[f] < len(freq_groups[f]):
                _, idx = freq_groups[f][pointers[f]]
                selected.append(idx)
                pointers[f] += 1
            else:
                exhausted.append(f)
        # Remove after the pass so the list is not mutated while iterating.
        for f in exhausted:
            freqs_sorted.remove(f)

    return selected
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def select_lineplot_neurons(sorted_indices, n=3):
    """
    Pick up to N evenly spaced positions in the frequency-sorted set for
    line plots (Tab 2).

    Note that the returned values are positions into ``sorted_indices``
    (0, step, 2*step, ...), not the elements stored in it.

    Args:
        sorted_indices: sequence of neurons, already sorted by frequency.
        n: number of positions wanted.

    Returns:
        List of positions. All positions are returned when the sequence has
        at most n elements; an empty list is returned when n <= 0.
    """
    # Guard: n == 0 with a non-empty input would otherwise divide by zero.
    if n <= 0:
        return []
    total = len(sorted_indices)
    if total <= n:
        return list(range(total))
    step = total // n
    return [i * step for i in range(n)]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def select_phase_frequency(max_freq_ls, p):
    """
    Choose the frequency for phase distribution analysis (Tab 3).
    Picks the frequency with the most neurons assigned to it (mode),
    excluding frequency 0 (DC component). Ties go to the frequency that
    appears first in ``max_freq_ls``.

    Args:
        max_freq_ls: iterable of per-neuron dominant frequencies.
        p: modulus; not used by the current implementation.

    Returns:
        The modal non-DC frequency as an int, or 1 when there is none.
    """
    # int() so tensor / numpy scalars are counted by value: tensor elements
    # hash by identity and would otherwise each count as a distinct key.
    freq_counts = Counter(f for f in map(int, max_freq_ls) if f > 0)
    if not freq_counts:
        return 1
    return freq_counts.most_common(1)[0][0]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def select_lottery_neuron(model_load, fourier_basis, decode_scales_phis_fn):
    """
    Return the index of the most sharply frequency-specialized neuron (Tab 6).

    Each neuron is scored by the ratio of its largest non-DC frequency
    scale to its second largest; the neuron with the highest score wins.
    If fewer than two non-DC frequencies exist, index 0 is returned.

    Args:
        model_load: passed through unchanged to ``decode_scales_phis_fn``.
        fourier_basis: passed through unchanged to ``decode_scales_phis_fn``.
        decode_scales_phis_fn: callable returning a 3-tuple whose first
            item is a 2-D tensor of scales with DC in column 0.
    """
    scales, _second, _third = decode_scales_phis_fn(model_load, fourier_basis)

    # Column 0 is the DC component; only real frequencies compete.
    non_dc = scales[:, 1:]
    if non_dc.shape[1] < 2:
        return 0

    ranked = torch.sort(non_dc, dim=1, descending=True).values
    strongest = ranked[:, 0]
    runner_up = ranked[:, 1]
    # Small epsilon keeps the ratio finite when the runner-up scale is zero.
    dominance = strongest / (runner_up + 1e-10)
    return dominance.argmax().item()
|
precompute/prime_config.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration for all moduli and training runs.
|
| 3 |
+
Defines the d_mlp sizing formula and the 5 training run configurations.
|
| 4 |
+
p can be any odd number >= 3 (not restricted to primes).
|
| 5 |
+
"""
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_moduli(low=3, high=199):
    """List the odd integers n with low <= n <= high and n >= 3, ascending."""
    return [n for n in range(low, high + 1) if n >= 3 and n % 2 == 1]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Keep old name as alias for backward compatibility
|
| 19 |
+
get_primes = get_moduli
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def compute_d_mlp(p: int) -> int:
    """
    Compute d_mlp maintaining the ratio from p=23, d_mlp=512.
    Formula: d_mlp = max(512, ceil(512/529 * p^2))
    Can have more neurons but not less than the ratio dictates.

    Args:
        p: modulus (integer).

    Returns:
        Hidden-layer width, never below 512.
    """
    base_p, base_d_mlp = 23, 512
    numerator = base_d_mlp * p * p
    denominator = base_p * base_p  # 529
    # Exact integer ceiling division. The float form ceil(512/529 * p*p) can
    # land just above an exact multiple (e.g. p=23 or p=46) and round up by one.
    scaled = -(-numerator // denominator)
    return max(base_d_mlp, scaled)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Minimum p overall (p=2 has 0 non-DC frequencies, making Fourier analysis degenerate)
|
| 33 |
+
MIN_P = 3
|
| 34 |
+
|
| 35 |
+
# Minimum p for grokking experiments (need enough test data for meaningful split)
|
| 36 |
+
MIN_P_GROKKING = 19
|
| 37 |
+
|
| 38 |
+
# Backward-compatible aliases
|
| 39 |
+
MIN_PRIME = MIN_P
|
| 40 |
+
MIN_PRIME_GROKKING = MIN_P_GROKKING
|
| 41 |
+
|
| 42 |
+
# 5 training run configurations per p
|
| 43 |
+
# Named training configurations, keyed by run name. train_all.py reads every
# key of each entry (via build_config_dict and _save_training_log), so all
# entries must define the same set of keys.
TRAINING_RUNS = {
    # AdamW + ReLU from random init on the full data set (frac_train = 1.0).
    "standard": {
        "embed_type": "one_hot",
        "init_type": "random",
        "optimizer": "AdamW",
        "act_type": "ReLU",
        "lr": 5e-5,
        "weight_decay": 0,
        "frac_train": 1.0,
        "num_epochs": 5000,
        "save_every": 200,
        "init_scale": 0.1,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
    # The only run with a held-out split (frac_train = 0.75) and weight decay.
    # train_all.py skips it for p < MIN_P_GROKKING and applies plateau early
    # stopping to it.
    "grokking": {
        "embed_type": "one_hot",
        "init_type": "random",
        "optimizer": "AdamW",
        "act_type": "ReLU",
        "lr": 1e-4,
        "weight_decay": 2.0,
        "frac_train": 0.75,
        "num_epochs": 50000,
        "save_every": 200,
        "init_scale": 0.1,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
    # Same as "standard" but with the Quad activation.
    "quad_random": {
        "embed_type": "one_hot",
        "init_type": "random",
        "optimizer": "AdamW",
        "act_type": "Quad",
        "lr": 5e-5,
        "weight_decay": 0,
        "frac_train": 1.0,
        "num_epochs": 5000,
        "save_every": 200,
        "init_scale": 0.1,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
    # Single-frequency init, SGD, Quad activation.
    "quad_single_freq": {
        "embed_type": "one_hot",
        "init_type": "single-freq",
        "optimizer": "SGD",
        "act_type": "Quad",
        "lr": 0.1,
        "weight_decay": 0,
        "frac_train": 1.0,
        "num_epochs": 5000,
        "save_every": 200,
        "init_scale": 0.02,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
    # Single-frequency init, SGD, ReLU activation (smaller lr and init_scale
    # than the Quad variant).
    "relu_single_freq": {
        "embed_type": "one_hot",
        "init_type": "single-freq",
        "optimizer": "SGD",
        "act_type": "ReLU",
        "lr": 0.01,
        "weight_decay": 0,
        "frac_train": 1.0,
        "num_epochs": 5000,
        "save_every": 200,
        "init_scale": 0.002,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
}
|
| 120 |
+
|
| 121 |
+
# Analytical computation configs (no training needed)
|
| 122 |
+
# Parameters for analytical computations that need no trained model.
# NOTE(review): the consumer is not in this module (presumably
# generate_analytical.py) -- confirm key meanings there before changing them.
ANALYTICAL_CONFIGS = {
    "decouple_dynamics": {
        "init_k": 2,
        # Two cases, each with its own step count, learning rate and
        # initial phi/psi values.
        "num_steps_case1": 1400,
        "learning_rate_case1": 1,
        "init_phi_case1": 1.5,
        "init_psi_case1": 0.18,
        "num_steps_case2": 700,
        "learning_rate_case2": 1,
        "init_phi_case2": -0.72,
        "init_psi_case2": -2.91,
        "amplitude": 0.02,
    },
}
|
precompute/run_all.sh
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Pre-compute results for all odd p in [3, MAX_P].
# Deletes checkpoints after each p to save disk space.
#
# Usage:
#   bash precompute/run_all.sh              # p = 3, 5, 7, ..., 99
#   MAX_P=199 bash precompute/run_all.sh    # p = 3, 5, 7, ..., 199
#
# Run from the project root directory.
#
# Exit status: 0 when every p succeeded, 1 when MAX_P is invalid or at least
# one p failed (so callers and CI can detect partial failure).

MAX_P=${MAX_P:-99}

set -e

# Reject anything that is not a plain non-negative integer before handing it to seq.
case "$MAX_P" in
    ''|*[!0-9]*)
        echo "MAX_P must be a non-negative integer, got '$MAX_P'" >&2
        exit 1
        ;;
esac

echo "=== Pre-computing all odd p in [3, $MAX_P] ==="

COMPLETED=0
FAILED=0
FAILED_LIST=""

for P in $(seq 3 2 "$MAX_P"); do
    echo ""
    echo "========================================"
    echo " Processing p=$P"
    echo "========================================"
    # Running the pipeline inside `if` keeps `set -e` from aborting the whole
    # sweep when a single p fails.
    if CLEANUP=1 bash precompute/run_pipeline.sh "$P"; then
        COMPLETED=$((COMPLETED + 1))
    else
        echo "[FAIL] p=$P failed"
        FAILED=$((FAILED + 1))
        FAILED_LIST="$FAILED_LIST $P"
    fi
done

echo ""
echo "=== All done. Completed: $COMPLETED, Failed: $FAILED ==="
echo "=== Precomputed results size: ==="
# The directory does not exist if nothing succeeded; do not let du abort the summary.
if [ -d precomputed_results ]; then
    du -sh precomputed_results/
fi

if [ "$FAILED" -gt 0 ]; then
    echo "Failed p values:$FAILED_LIST" >&2
    exit 1
fi
|
precompute/run_pipeline.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Full pre-computation pipeline for a single modulus p (any odd number >= 3).
#
# Usage:
#   bash precompute/run_pipeline.sh 23
#   bash precompute/run_pipeline.sh 9 --d_mlp 128
#   P=23 bash precompute/run_pipeline.sh
#
#   # Delete checkpoints after generating plots (saves disk space):
#   CLEANUP=1 bash precompute/run_pipeline.sh 97
#
# Run from the project root directory.

P=${1:-${P:-23}}
if [ "$#" -gt 0 ]; then
    shift  # consume the p arg
fi

# CLEANUP=1 to delete model checkpoints after plot generation
CLEANUP=${CLEANUP:-0}

# Collect remaining args (e.g. --d_mlp 128) to pass to train_all.py.
# An array keeps each argument intact; a flattened string would be re-split
# on whitespace and glob-expanded when passed on.
EXTRA_ARGS=("$@")

set -e

# P ends up in directory names (one of which may be rm -rf'ed), so insist on digits.
case "$P" in
    ''|*[!0-9]*)
        echo "p must be a positive integer, got '$P'" >&2
        exit 1
        ;;
esac
# Force base 10 so a zero-padded value such as 09 is not parsed as octal.
P=$((10#$P))

echo "=== Running full pipeline for p=$P ${EXTRA_ARGS[*]} ==="

# Step 1: Train all 5 configurations
echo ""
echo "--- Step 1/4: Training ---"
python precompute/train_all.py --p "$P" --output ./trained_models --resume "${EXTRA_ARGS[@]}"

# Step 2: Generate model-based plots (d_mlp inferred from checkpoint)
echo ""
echo "--- Step 2/4: Generating model-based plots ---"
python precompute/generate_plots.py --p "$P" --input ./trained_models --output ./precomputed_results

# Step 3: Generate analytical simulation plots
echo ""
echo "--- Step 3/4: Generating analytical plots ---"
python precompute/generate_analytical.py --p "$P" --output ./precomputed_results

# Step 4: Cleanup checkpoints if requested
PADDED=$(printf '%03d' "$P")
MODEL_DIR="trained_models/p_${PADDED}"
if [ "$CLEANUP" = "1" ] && [ -d "$MODEL_DIR" ]; then
    echo ""
    echo "--- Cleanup: Deleting checkpoints for p=$P ---"
    SIZE=$(du -sh "$MODEL_DIR" | cut -f1)
    rm -rf "$MODEL_DIR"
    echo " Freed $SIZE from $MODEL_DIR"
fi

# Step 5: Verify
echo ""
echo "--- Verification ---"
RESULT_DIR="precomputed_results/p_${PADDED}"
if [ ! -d "$RESULT_DIR" ]; then
    # Fail with a clear message instead of a bare ls error under set -e.
    echo "Expected results directory not found: ${RESULT_DIR}" >&2
    exit 1
fi
echo "=== Results in ${RESULT_DIR}/ ==="
ls -la "${RESULT_DIR}/"
FILE_COUNT=$(ls -1 "${RESULT_DIR}/" | wc -l | tr -d ' ')
echo "=== Total files: ${FILE_COUNT} ==="
echo "=== Pipeline complete for p=$P ==="
|
precompute/train_all.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Batch training script for all odd moduli p in [3, 199].
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
# Train all runs for all odd p
|
| 7 |
+
python train_all.py --all
|
| 8 |
+
|
| 9 |
+
# Train specific p
|
| 10 |
+
python train_all.py --p 23
|
| 11 |
+
|
| 12 |
+
# Train specific run type for a p
|
| 13 |
+
python train_all.py --p 23 --run standard
|
| 14 |
+
|
| 15 |
+
# Resume (skips completed runs)
|
| 16 |
+
python train_all.py --all --resume
|
| 17 |
+
|
| 18 |
+
# Custom output directory
|
| 19 |
+
python train_all.py --all --output ./my_models
|
| 20 |
+
"""
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
# Add src to path
|
| 27 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from prime_config import get_moduli, compute_d_mlp, TRAINING_RUNS, MIN_P, MIN_P_GROKKING
|
| 31 |
+
from utils import Config
|
| 32 |
+
from nnTrainer import Trainer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def build_config_dict(p, run_params, d_mlp_override=None):
    """
    Assemble the nested data/model/training dict that the Config class takes.

    Args:
        p: modulus.
        run_params: one entry of TRAINING_RUNS (flat dict of hyperparameters).
        d_mlp_override: hidden width to use; when None it is derived from p
            with compute_d_mlp.
    """
    if d_mlp_override is None:
        hidden_width = compute_d_mlp(p)
    else:
        hidden_width = d_mlp_override

    data_section = {
        'p': p,
        'd_vocab': None,
        'fn_name': 'add',
        'frac_train': run_params['frac_train'],
        'batch_style': run_params['batch_style'],
    }
    model_section = {
        'd_model': None,
        'd_mlp': hidden_width,
        'act_type': run_params['act_type'],
        'embed_type': run_params['embed_type'],
        'init_type': run_params['init_type'],
        'init_scale': run_params['init_scale'],
    }
    training_section = {
        'num_epochs': run_params['num_epochs'],
        'lr': run_params['lr'],
        'weight_decay': run_params['weight_decay'],
        'optimizer': run_params['optimizer'],
        # -1 as the loss threshold; run_training compares test loss against it.
        'stopping_thresh': -1,
        'save_models': run_params['save_models'],
        'save_every': run_params['save_every'],
        'seed': run_params['seed'],
    }
    return {
        'data': data_section,
        'model': model_section,
        'training': training_section,
    }
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _save_training_log(output_dir, p, run_name, run_params, d_mlp, curves):
|
| 68 |
+
"""Save a human-readable training_log.txt summarizing the run."""
|
| 69 |
+
log_path = os.path.join(output_dir, "training_log.txt")
|
| 70 |
+
n_epochs = len(curves.get('train_losses', []))
|
| 71 |
+
with open(log_path, 'w') as f:
|
| 72 |
+
f.write(f"{'=' * 70}\n")
|
| 73 |
+
f.write(f"Training Log: p={p}, run={run_name}\n")
|
| 74 |
+
f.write(f"{'=' * 70}\n\n")
|
| 75 |
+
f.write(f"Configuration:\n")
|
| 76 |
+
f.write(f" prime (p) = {p}\n")
|
| 77 |
+
f.write(f" d_mlp = {d_mlp}\n")
|
| 78 |
+
f.write(f" activation = {run_params['act_type']}\n")
|
| 79 |
+
f.write(f" init_type = {run_params['init_type']}\n")
|
| 80 |
+
f.write(f" init_scale = {run_params['init_scale']}\n")
|
| 81 |
+
f.write(f" optimizer = {run_params['optimizer']}\n")
|
| 82 |
+
f.write(f" learning_rate = {run_params['lr']}\n")
|
| 83 |
+
f.write(f" weight_decay = {run_params['weight_decay']}\n")
|
| 84 |
+
f.write(f" frac_train = {run_params['frac_train']}\n")
|
| 85 |
+
f.write(f" num_epochs = {run_params['num_epochs']}\n")
|
| 86 |
+
f.write(f" batch_style = {run_params['batch_style']}\n")
|
| 87 |
+
f.write(f" seed = {run_params['seed']}\n")
|
| 88 |
+
f.write(f"\n{'─' * 70}\n")
|
| 89 |
+
f.write(f"{'Epoch':>8s} {'Train Loss':>12s} {'Test Loss':>12s} "
|
| 90 |
+
f"{'Train Acc':>10s} {'Test Acc':>10s} "
|
| 91 |
+
f"{'Grad Norm':>10s} {'Param Norm':>11s}\n")
|
| 92 |
+
f.write(f"{'─' * 70}\n")
|
| 93 |
+
|
| 94 |
+
# Print every 100 epochs + the last epoch
|
| 95 |
+
train_losses = curves.get('train_losses', [])
|
| 96 |
+
test_losses = curves.get('test_losses', [])
|
| 97 |
+
train_accs = curves.get('train_accs', [])
|
| 98 |
+
test_accs = curves.get('test_accs', [])
|
| 99 |
+
grad_norms = curves.get('grad_norms', [])
|
| 100 |
+
param_norms = curves.get('param_norms', [])
|
| 101 |
+
|
| 102 |
+
step = max(1, n_epochs // 100) # ~100 lines
|
| 103 |
+
indices = list(range(0, n_epochs, step))
|
| 104 |
+
if n_epochs > 0 and (n_epochs - 1) not in indices:
|
| 105 |
+
indices.append(n_epochs - 1)
|
| 106 |
+
|
| 107 |
+
for i in indices:
|
| 108 |
+
tl = f"{train_losses[i]:.6f}" if i < len(train_losses) else "N/A"
|
| 109 |
+
tel = f"{test_losses[i]:.6f}" if i < len(test_losses) else "N/A"
|
| 110 |
+
ta = f"{train_accs[i]:.4f}" if i < len(train_accs) else "N/A"
|
| 111 |
+
tea = f"{test_accs[i]:.4f}" if i < len(test_accs) else "N/A"
|
| 112 |
+
gn = f"{grad_norms[i]:.4f}" if i < len(grad_norms) else "N/A"
|
| 113 |
+
pn = f"{param_norms[i]:.4f}" if i < len(param_norms) else "N/A"
|
| 114 |
+
f.write(f"{i:>8d} {tl:>12s} {tel:>12s} "
|
| 115 |
+
f"{ta:>10s} {tea:>10s} "
|
| 116 |
+
f"{gn:>10s} {pn:>11s}\n")
|
| 117 |
+
|
| 118 |
+
f.write(f"{'─' * 70}\n\n")
|
| 119 |
+
f.write(f"Final Results:\n")
|
| 120 |
+
if train_losses:
|
| 121 |
+
f.write(f" Train Loss = {train_losses[-1]:.6f}\n")
|
| 122 |
+
if test_losses:
|
| 123 |
+
f.write(f" Test Loss = {test_losses[-1]:.6f}\n")
|
| 124 |
+
if train_accs:
|
| 125 |
+
f.write(f" Train Acc = {train_accs[-1]:.4f}\n")
|
| 126 |
+
if test_accs:
|
| 127 |
+
f.write(f" Test Acc = {test_accs[-1]:.4f}\n")
|
| 128 |
+
if param_norms:
|
| 129 |
+
f.write(f" Param Norm = {param_norms[-1]:.4f}\n")
|
| 130 |
+
f.write(f"\nTotal epochs trained: {n_epochs}\n")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def run_training(p, run_name, output_base, d_mlp_override=None):
    """
    Train a single run for a single modulus and write its artifacts.

    Everything is written under <output_base>/p_<PPP>/<run_name>/: the
    checkpoints saved by the trainer, training_curves.json, training_log.txt
    and a DONE marker. A run whose DONE marker already exists is skipped, as
    are unsupported (p, run_name) combinations.

    Args:
        p: modulus.
        run_name: key into TRAINING_RUNS.
        output_base: root directory for trained models.
        d_mlp_override: hidden width; when None it is derived from p.
    """
    if p < MIN_P:
        print(f"[SKIP] p={p}, run={run_name}: p < {MIN_P} (too few Fourier frequencies)")
        return

    # Single-freq init needs at least 1 non-DC frequency: (p-1)//2 >= 1 → p >= 3
    if run_name in ('quad_single_freq', 'relu_single_freq') and (p - 1) // 2 < 1:
        print(f"[SKIP] p={p}, run={run_name}: no non-DC frequencies for single-freq init")
        return

    if run_name == 'grokking' and p < MIN_P_GROKKING:
        print(f"[SKIP] p={p}, run={run_name}: p < {MIN_P_GROKKING} (too few test points)")
        return

    run_params = TRAINING_RUNS[run_name]
    config_dict = build_config_dict(p, run_params, d_mlp_override)
    # Read the width back from the config so the log always reports exactly
    # what the model was built with (no second, separate computation).
    d_mlp = config_dict['model']['d_mlp']

    output_dir = os.path.join(output_base, f"p_{p:03d}", run_name)
    os.makedirs(output_dir, exist_ok=True)

    # Check if already completed
    marker = os.path.join(output_dir, "DONE")
    if os.path.exists(marker):
        print(f"[SKIP] p={p}, run={run_name} already completed")
        return

    print(f"[TRAIN] p={p}, d_mlp={d_mlp}, run={run_name}, "
          f"epochs={run_params['num_epochs']}")

    config = Config(config_dict)
    trainer = Trainer(config=config, use_wandb=False)

    # Override save directory so checkpoints go into our output structure
    trainer.save_dir = output_dir
    run_subdir = os.path.join(output_dir, trainer.run_name)
    os.makedirs(run_subdir, exist_ok=True)

    # Re-save train/test data to the overridden location so generate_plots.py
    # can find them (Trainer.__init__ saves to the original save_dir)
    torch.save(trainer.train, os.path.join(run_subdir, 'train_data.pth'))
    torch.save(trainer.test, os.path.join(run_subdir, 'test_data.pth'))

    trainer.initial_save_if_appropriate()

    # Plateau early-stopping for grokking: after 10K epochs, if curves
    # haven't changed in the last 1000 epochs, stop training.
    plateau_check = (run_name == 'grokking')
    plateau_min_epoch = 10000
    plateau_window = 1000
    plateau_loss_tol = 1e-3  # absolute change in loss
    plateau_acc_tol = 0.005  # absolute change in accuracy

    for epoch in range(config.num_epochs):
        train_loss, test_loss = trainer.do_a_training_step(epoch)

        if test_loss.item() < config.stopping_thresh:
            print(f" Early stopping at epoch {epoch}: "
                  f"test loss {test_loss.item():.6f}")
            break

        # Plateau detection for grokking
        if (plateau_check and epoch >= plateau_min_epoch
                and epoch % plateau_window == 0):
            tl = trainer.train_losses
            tel = trainer.test_losses
            ta = trainer.train_accs
            tea = trainer.test_accs
            w = plateau_window
            if len(tl) >= w and len(tel) >= w:
                tl_flat = (max(tl[-w:]) - min(tl[-w:])) < plateau_loss_tol
                tel_flat = (max(tel[-w:]) - min(tel[-w:])) < plateau_loss_tol
                # Empty accuracy curves count as flat.
                ta_flat = (not ta) or (max(ta[-w:]) - min(ta[-w:])) < plateau_acc_tol
                tea_flat = (not tea) or (max(tea[-w:]) - min(tea[-w:])) < plateau_acc_tol
                if tl_flat and tel_flat and ta_flat and tea_flat:
                    print(f" Plateau early stopping at epoch {epoch}: "
                          f"no change in last {w} epochs")
                    break

        if config.is_it_time_to_save(epoch=epoch):
            trainer.save_epoch(epoch=epoch, save_to_wandb=False, local_save=True)

    trainer.post_training_save(
        save_optimizer_and_scheduler=False, log_to_wandb=False
    )

    # Save training curves as JSON for plot generation
    curves = {
        'train_losses': trainer.train_losses,
        'test_losses': trainer.test_losses,
        'train_accs': trainer.train_accs,
        'test_accs': trainer.test_accs,
        'grad_norms': trainer.grad_norms,
        'param_norms': trainer.param_norms,
    }
    curves_path = os.path.join(output_dir, "training_curves.json")
    with open(curves_path, 'w') as f:
        json.dump(curves, f)

    # Save a human-readable training log
    _save_training_log(output_dir, p, run_name, run_params, d_mlp, curves)

    # Write completion marker
    with open(marker, 'w') as f:
        f.write(f"p={p} run={run_name} completed\n")

    # The accuracy lists may be empty (the plateau logic above allows for it).
    # Indexing [-1] unguarded would raise after the run is already marked
    # DONE and make the caller report a finished run as failed.
    final_train_acc = f"{trainer.train_accs[-1]:.4f}" if trainer.train_accs else "N/A"
    final_test_acc = f"{trainer.test_accs[-1]:.4f}" if trainer.test_accs else "N/A"
    print(f"[DONE] p={p}, run={run_name}, "
          f"train_acc={final_train_acc}, "
          f"test_acc={final_test_acc}")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def main():
    """
    Command-line entry point: train the selected runs for the selected moduli.

    Every (p, run) pair is attempted even if an earlier one raises; failures
    are printed with a traceback and listed again in the final summary.
    """
    parser = argparse.ArgumentParser(
        description='Batch training for modular addition experiments'
    )
    parser.add_argument('--all', action='store_true',
                        help='Train all odd p in [3, 199]')
    parser.add_argument('--p', type=int,
                        help='Train a specific odd modulus p')
    parser.add_argument('--run', type=str, choices=list(TRAINING_RUNS.keys()),
                        help='Train a specific run type')
    parser.add_argument('--output', type=str, default='./trained_models',
                        help='Output directory for trained models')
    parser.add_argument('--d_mlp', type=int, default=None,
                        help='Override d_mlp (number of hidden neurons). '
                             'Default: auto-computed from p.')
    # NOTE: the flag is accepted but not read below; run_training skips runs
    # that have a DONE marker regardless of it.
    parser.add_argument('--resume', action='store_true',
                        help='Skip already-completed runs (checks DONE marker)')
    args = parser.parse_args()

    if not args.all and args.p is None:
        parser.error("Specify --all or --p P")

    # `is not None`, not truthiness: `--p 0` must not silently fall back to
    # training every modulus.
    moduli = [args.p] if args.p is not None else get_moduli()
    runs = [args.run] if args.run else list(TRAINING_RUNS.keys())

    total = len(moduli) * len(runs)
    processed = 0
    failures = []

    for p in moduli:
        for run_name in runs:
            processed += 1
            print(f"\n{'='*60}")
            print(f"[{processed}/{total}] p={p}, run={run_name}")
            print(f"{'='*60}")
            try:
                run_training(p, run_name, args.output, d_mlp_override=args.d_mlp)
            except Exception as e:
                # Best effort: record the failure and continue with the next run.
                failures.append((p, run_name))
                print(f"[FAIL] p={p}, run={run_name}: {e}")
                import traceback
                traceback.print_exc()

    print(f"\nAll done. {processed} runs processed.")
    if failures:
        failed_desc = ", ".join(f"p={fp}/{fr}" for fp, fr in failures)
        print(f"{len(failures)} run(s) failed: {failed_desc}")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
if __name__ == "__main__":
|
| 290 |
+
main()
|
precomputed_results/p_015/p015_full_training_para_origin.png
ADDED
|
precomputed_results/p_015/p015_lineplot_in.png
ADDED
|
Git LFS Details
|
precomputed_results/p_015/p015_lineplot_out.png
ADDED
|
Git LFS Details
|
precomputed_results/p_015/p015_logits_interactive.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"pairs": [[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0], [7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0], [13, 0], [14, 0]], "correct_answers": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "logits": [[5.845241069793701, -0.2678382396697998, -0.27407053112983704, 0.5236045718193054, 0.1727389544248581, -0.5161969065666199, 0.32797300815582275, 0.03217152878642082, 0.9118141531944275, 0.1262667030096054, -0.48561030626296997, -0.08425326645374298, 0.7506414651870728, 0.26417604088783264, 0.06685394793748856], [1.2039813995361328, 6.88271427154541, 1.5816727876663208, -1.3607698678970337, 0.2838805019855499, -0.1305035948753357, -0.1361088901758194, 0.1741233468055725, -0.20070961117744446, -0.19441378116607666, -0.07290667295455933, -0.036422792822122574, -0.045937348157167435, -0.4817635118961334, -1.3201030492782593], [1.021496057510376, 0.460868775844574, 6.665925979614258, -0.31677037477493286, 0.5722809433937073, -0.4090900421142578, -1.4602270126342773, -1.284806728363037, 0.1630801558494568, -0.14396820962429047, 0.21148079633712769, 0.009713075123727322, -0.13466662168502808, -1.4001739025115967, -0.48616650700569153], [2.0029213428497314, -1.0664080381393433, 0.1110871359705925, 6.9900641441345215, -0.6577256321907043, 0.14176632463932037, 0.5043318271636963, -0.9338439702987671, -0.5924315452575684, 0.0853288546204567, -0.8137574791908264, 0.2345130294561386, -0.747736930847168, -0.7098338603973389, -0.09835103899240494], [1.1424025297164917, -0.11194394528865814, -0.35566678643226624, 0.07304318249225616, 6.601017951965332, -0.7027722001075745, -0.2666034996509552, -0.44442567229270935, 0.8163065314292908, -0.5371125340461731, 0.3994847536087036, -2.494434118270874, -0.7756778001785278, -0.05254651606082916, -0.19730210304260254], [-0.2894173860549927, -0.4359486699104309, -0.5648083686828613, 0.2765766680240631, -0.8157410025596619, 6.83429479598999, -0.35287508368492126, -0.13654330372810364, 0.4015010595321655, -0.5329406261444092, 
-0.8750499486923218, 0.27872195839881897, 0.017333246767520905, 0.40312278270721436, -0.36508798599243164], [0.8411731719970703, -0.7100308537483215, -0.03279941901564598, -1.725080132484436, 0.041413623839616776, -0.020579706877470016, 7.065398216247559, -0.6480932235717773, 0.3441147208213806, -0.6311541795730591, -0.5848420858383179, 0.23020878434181213, 1.6761562824249268, -0.4346846342086792, 0.1869385540485382], [1.0707398653030396, 0.15244746208190918, -0.8544912338256836, -0.15931977331638336, -0.30050942301750183, -0.14030054211616516, -0.8752976059913635, 6.867440223693848, -0.9738104343414307, -0.016967639327049255, 0.06942860782146454, -0.36363551020622253, -0.5302596092224121, -0.04578635096549988, 1.362797498703003], [1.3280404806137085, 0.42291590571403503, 0.2617959976196289, -0.14120015501976013, 0.16784554719924927, -0.3041268587112427, -0.9258386492729187, -1.557655930519104, 6.762355327606201, -1.1574827432632446, 0.6234251260757446, 0.20528843998908997, -0.12094193696975708, -0.20512300729751587, 0.5443016290664673], [1.1472804546356201, -0.6102383732795715, -0.025639446452260017, 1.0203014612197876, -0.7307616472244263, -0.12430916726589203, -1.7754124402999878, -0.18994221091270447, -0.2811968922615051, 7.137627601623535, -0.20736677944660187, -0.2558075189590454, -0.8490561842918396, -0.11468800902366638, -0.2236211597919464], [-0.4398433268070221, -0.21957175433635712, 0.3112000524997711, -0.31722167134284973, -0.045922309160232544, -0.8279529213905334, 0.09667133539915085, -0.1345088630914688, -0.4463891386985779, -0.33588770031929016, 6.491363525390625, -0.12330206483602524, 0.10725130885839462, 0.42743009328842163, -0.8439421057701111], [0.8614866733551025, -0.09104573726654053, -0.26044338941574097, -1.4104074239730835, -1.4556188583374023, -0.1587509959936142, -0.8687411546707153, 0.6269921064376831, -0.11933901160955429, -0.3620198369026184, 0.2740001678466797, 6.834388256072998, 0.9378089308738708, -0.8011277914047241, 
0.10253019630908966], [0.8134461045265198, -0.5108746886253357, -0.11389578878879547, -0.7348212003707886, -0.3866046965122223, -0.30147287249565125, -1.0970127582550049, -0.5308148264884949, -0.25204434990882874, 0.6024379134178162, -0.5250549912452698, 0.2944880723953247, 6.6044793128967285, -0.06822022795677185, 0.009033107198774815], [1.3036129474639893, 0.3358013927936554, -1.516667366027832, -0.7404187321662903, 0.37273597717285156, 0.4566952586174011, -0.5649716854095459, -0.16146983206272125, -0.5403905510902405, -0.6627720594406128, 0.10609856992959976, 1.1525702476501465, 0.3066888749599457, 6.983869552612305, 0.6811028122901917], [1.5695332288742065, -1.1690753698349, 0.29273056983947754, -0.6167480945587158, -0.481355756521225, 0.3711283802986145, -0.3477175235748291, 0.06404877454042435, -0.34069737792015076, -0.5002283453941345, -0.22661544382572174, 0.29657235741615295, -1.5196651220321655, 1.143660306930542, 6.862580299377441]], "output_classes": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]}
|
precomputed_results/p_015/p015_lottery_beta_contour.png
ADDED
|
precomputed_results/p_015/p015_lottery_mech_magnitude.png
ADDED
|
precomputed_results/p_015/p015_lottery_mech_phase.png
ADDED
|
precomputed_results/p_015/p015_magnitude_distribution.png
ADDED
|
precomputed_results/p_015/p015_metadata.json
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"prime": 15,
|
| 3 |
+
"d_mlp": 512,
|
| 4 |
+
"training_runs": {
|
| 5 |
+
"standard": {
|
| 6 |
+
"act_type": "ReLU",
|
| 7 |
+
"lr": 5e-05,
|
| 8 |
+
"weight_decay": 0,
|
| 9 |
+
"num_epochs": 5000,
|
| 10 |
+
"frac_train": 1.0,
|
| 11 |
+
"init_type": "random",
|
| 12 |
+
"init_scale": 0.1,
|
| 13 |
+
"optimizer": "AdamW"
|
| 14 |
+
},
|
| 15 |
+
"grokking": {
|
| 16 |
+
"act_type": "ReLU",
|
| 17 |
+
"lr": 0.0001,
|
| 18 |
+
"weight_decay": 2.0,
|
| 19 |
+
"num_epochs": 50000,
|
| 20 |
+
"frac_train": 0.75,
|
| 21 |
+
"init_type": "random",
|
| 22 |
+
"init_scale": 0.1,
|
| 23 |
+
"optimizer": "AdamW"
|
| 24 |
+
},
|
| 25 |
+
"quad_random": {
|
| 26 |
+
"act_type": "Quad",
|
| 27 |
+
"lr": 5e-05,
|
| 28 |
+
"weight_decay": 0,
|
| 29 |
+
"num_epochs": 5000,
|
| 30 |
+
"frac_train": 1.0,
|
| 31 |
+
"init_type": "random",
|
| 32 |
+
"init_scale": 0.1,
|
| 33 |
+
"optimizer": "AdamW"
|
| 34 |
+
},
|
| 35 |
+
"quad_single_freq": {
|
| 36 |
+
"act_type": "Quad",
|
| 37 |
+
"lr": 0.1,
|
| 38 |
+
"weight_decay": 0,
|
| 39 |
+
"num_epochs": 5000,
|
| 40 |
+
"frac_train": 1.0,
|
| 41 |
+
"init_type": "single-freq",
|
| 42 |
+
"init_scale": 0.02,
|
| 43 |
+
"optimizer": "SGD"
|
| 44 |
+
},
|
| 45 |
+
"relu_single_freq": {
|
| 46 |
+
"act_type": "ReLU",
|
| 47 |
+
"lr": 0.01,
|
| 48 |
+
"weight_decay": 0,
|
| 49 |
+
"num_epochs": 5000,
|
| 50 |
+
"frac_train": 1.0,
|
| 51 |
+
"init_type": "single-freq",
|
| 52 |
+
"init_scale": 0.002,
|
| 53 |
+
"optimizer": "SGD"
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
"final_metrics": {
|
| 57 |
+
"standard": {
|
| 58 |
+
"train_acc": 1.0,
|
| 59 |
+
"test_acc": 1.0,
|
| 60 |
+
"train_loss": 0.020928841084241867,
|
| 61 |
+
"test_loss": 0.020928841084241867
|
| 62 |
+
},
|
| 63 |
+
"quad_random": {
|
| 64 |
+
"train_acc": 1.0,
|
| 65 |
+
"test_acc": 1.0,
|
| 66 |
+
"train_loss": 0.0036203155759721994,
|
| 67 |
+
"test_loss": 0.0036203155759721994
|
| 68 |
+
},
|
| 69 |
+
"quad_single_freq": {
|
| 70 |
+
"train_acc": 1.0,
|
| 71 |
+
"test_acc": 1.0,
|
| 72 |
+
"train_loss": 0.04876862093806267,
|
| 73 |
+
"test_loss": 0.04876862093806267
|
| 74 |
+
},
|
| 75 |
+
"relu_single_freq": {
|
| 76 |
+
"train_acc": 1.0,
|
| 77 |
+
"test_acc": 1.0,
|
| 78 |
+
"train_loss": 2.7064406871795654,
|
| 79 |
+
"test_loss": 2.7064406871795654
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
}
|
precomputed_results/p_015/p015_neuron_spectra.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"fourier_basis_names": ["Const", "cos 1", "sin 1", "cos 2", "sin 2", "cos 3", "sin 3", "cos 4", "sin 4", "cos 5", "sin 5", "cos 6", "sin 6", "cos 7", "sin 7"], "neurons": {"neuron_0": {"global_index": 298, "dominant_freq": 1, "fourier_magnitudes_in": [0.009785229340195656, 0.835197389125824, 0.06538262963294983, 0.023014001548290253, 0.11825679987668991, 0.13374952971935272, 0.06052215397357941, 0.05648910254240036, 0.19347421824932098, 0.022295288741588593, 0.025461282581090927, 0.10164802521467209, 0.13591913878917694, 0.05420650914311409, 0.015495105646550655], "fourier_magnitudes_out": [0.05270613357424736, 0.8947794437408447, 0.023624544963240623, 0.0090867318212986, 0.013868873007595539, 0.23038825392723083, 0.01721891760826111, 0.03962008282542229, 0.026063552126288414, 0.13086184859275818, 0.030048014596104622, 0.0664755254983902, 0.05337228253483772, 0.13469304144382477, 0.026413951069116592]}, "neuron_1": {"global_index": 322, "dominant_freq": 2, "fourier_magnitudes_in": [0.0048967949114739895, 0.05831458047032356, 0.047817476093769073, 0.8282294273376465, 0.08354193717241287, 0.07255098968744278, 0.11165464669466019, 0.025725897401571274, 0.1043122410774231, 0.06207004189491272, 0.03182118758559227, 0.16142725944519043, 0.05974670872092247, 0.07021868228912354, 0.14918409287929535], "fourier_magnitudes_out": [0.0682767927646637, 0.04459874704480171, 0.048241592943668365, 0.8725450038909912, 0.12229588627815247, 0.00944548286497593, 0.04323046654462814, 0.04066399484872818, 0.016419490799307823, 0.08050032705068588, 0.10128692537546158, 0.2789347767829895, 0.11171453446149826, 0.0022370540536940098, 0.05760588496923447]}, "neuron_2": {"global_index": 506, "dominant_freq": 3, "fourier_magnitudes_in": [0.16119761765003204, 0.007301056291908026, 0.015491282567381859, 0.008759702555835247, 0.012521358206868172, 0.8759171366691589, 0.027130309492349625, 0.011930187232792377, 0.008996764197945595, 0.017064398154616356, 0.0032833898440003395, 
0.2937963902950287, 0.03241246938705444, 0.01809331774711609, 0.01297019049525261], "fourier_magnitudes_out": [0.1367007941007614, 0.08054599165916443, 0.009201026521623135, 0.04534965381026268, 0.08026888966560364, 0.8633893728256226, 0.050462786108255386, 0.0044480981305241585, 0.03184027969837189, 0.05210161209106445, 0.04297168180346489, 0.32101351022720337, 0.06568200141191483, 0.035220254212617874, 0.014398200437426567]}, "neuron_3": {"global_index": 163, "dominant_freq": 4, "fourier_magnitudes_in": [0.031200144439935684, 0.04270334914326668, 0.08881780505180359, 0.030023256316781044, 0.08304275572299957, 0.04510613903403282, 0.1826746016740799, 0.03483400121331215, 0.8419885039329529, 0.013886661268770695, 0.08348363637924194, 0.050276018679142, 0.10324658453464508, 0.007152870763093233, 0.037329111248254776], "fourier_magnitudes_out": [0.0033031131606549025, 0.04313803091645241, 0.04859255626797676, 0.03949907422065735, 0.0202469564974308, 0.20968425273895264, 0.08351283520460129, 0.8787142038345337, 0.06666962057352066, 0.12211224436759949, 0.0748063325881958, 0.03903647139668465, 0.004882670473307371, 0.03963864594697952, 0.00997573509812355]}, "neuron_4": {"global_index": 56, "dominant_freq": 5, "fourier_magnitudes_in": [0.3008343279361725, 0.00417996384203434, 0.0018625137163326144, 0.0030638582538813353, 0.0034040315076708794, 0.0014134111115708947, 0.004357232246547937, 0.0004787891812156886, 0.004558315966278315, 0.8440173864364624, 0.003970965277403593, 0.0037021928001195192, 0.00269315205514431, 0.004479836206883192, 0.0009506550850346684], "fourier_magnitudes_out": [0.34606269001960754, 0.027366885915398598, 0.036413464695215225, 0.017639687284827232, 0.027893830090761185, 0.009557241573929787, 0.03348158299922943, 0.02666180580854416, 0.006313585676252842, 0.9367755651473999, 0.021462714299559593, 0.004096114542335272, 0.023737004026770592, 0.012917638756334782, 0.005492078140377998]}, "neuron_5": {"global_index": 382, "dominant_freq": 6, 
"fourier_magnitudes_in": [0.17134937644004822, 0.030019262805581093, 0.003407673677429557, 0.009750651195645332, 0.017143065109848976, 0.29516395926475525, 0.04455732926726341, 0.0013349098153412342, 0.022812439128756523, 0.007311187218874693, 0.0026531207840889692, 0.8911311626434326, 0.017707478255033493, 0.012044227682054043, 0.007314047310501337], "fourier_magnitudes_out": [0.15681979060173035, 0.02663942240178585, 0.015795163810253143, 0.03183992952108383, 0.019026663154363632, 0.34929507970809937, 0.021825529634952545, 0.0012088950024917722, 0.013926293700933456, 0.004897939506918192, 0.010765696875751019, 0.9581749439239502, 0.012587963603436947, 0.02056090161204338, 0.01620820164680481]}, "neuron_6": {"global_index": 44, "dominant_freq": 7, "fourier_magnitudes_in": [0.016063887625932693, 0.005506421905010939, 0.06481847167015076, 0.050152119249105453, 0.08404979109764099, 0.04145295172929764, 0.0641774982213974, 0.06344524770975113, 0.07214409857988358, 0.09170015156269073, 0.048928599804639816, 0.22082282602787018, 0.04532390087842941, 0.8467198014259338, 0.08364292979240417], "fourier_magnitudes_out": [0.015495719388127327, 0.0706552267074585, 0.03188098222017288, 0.022803593426942825, 0.05851001664996147, 0.006458980031311512, 0.07394170761108398, 0.06014903634786606, 0.08831571042537689, 0.10563234984874725, 0.11405274271965027, 0.26272422075271606, 0.13614341616630554, 0.9203148484230042, 0.14940021932125092]}, "neuron_7": {"global_index": 157, "dominant_freq": 1, "fourier_magnitudes_in": [0.023283498361706734, 0.14338412880897522, 0.8202316761016846, 0.05399405211210251, 0.058775369077920914, 0.12734520435333252, 0.1968384087085724, 0.070158950984478, 0.07580292969942093, 0.14696937799453735, 0.0658537819981575, 0.08757402002811432, 0.050403349101543427, 0.07286650687456131, 0.030624864622950554], "fourier_magnitudes_out": [0.10438423603773117, 0.8233321905136108, 0.2895442247390747, 0.11436242610216141, 0.043502938002347946, 0.08488017320632935, 
0.21032141149044037, 0.03133478760719299, 0.147483691573143, 0.09128359705209732, 0.11333037912845612, 0.042169392108917236, 0.14409606158733368, 0.06334586441516876, 0.010641958564519882]}, "neuron_8": {"global_index": 436, "dominant_freq": 2, "fourier_magnitudes_in": [0.02816132642328739, 0.0366082526743412, 0.09432870894670486, 0.82112056016922, 0.2999240458011627, 0.07823537290096283, 0.07104597240686417, 0.010237633250653744, 0.05316127836704254, 0.01780366338789463, 0.10813448578119278, 0.10922082513570786, 0.2154001146554947, 0.06505948305130005, 0.005772633943706751], "fourier_magnitudes_out": [0.07935837656259537, 0.05681774765253067, 0.11203812062740326, 0.7207208275794983, 0.5456924438476562, 0.08229123800992966, 0.0757966861128807, 0.025318337604403496, 0.12979470193386078, 0.16904646158218384, 0.01633336953818798, 0.10492201149463654, 0.25266337394714355, 0.07896621525287628, 0.056506332010030746]}, "neuron_9": {"global_index": 339, "dominant_freq": 3, "fourier_magnitudes_in": [0.00793448369950056, 0.23576614260673523, 0.05173733830451965, 0.0808015987277031, 0.08904809504747391, 0.8467172980308533, 0.07009700685739517, 0.2346189022064209, 0.06041780859231949, 0.22861061990261078, 0.005940192844718695, 0.08250848948955536, 0.05041716247797012, 0.09760993719100952, 0.08870971947908401], "fourier_magnitudes_out": [0.07995986938476562, 0.058010708540678024, 0.0238045621663332, 0.044154077768325806, 0.06384548544883728, 0.8888619542121887, 0.1083059012889862, 0.05018500238656998, 0.06562003493309021, 0.032845836132764816, 0.06583794951438904, 0.19849520921707153, 0.08598171174526215, 0.048292793333530426, 0.056917209178209305]}, "neuron_10": {"global_index": 405, "dominant_freq": 4, "fourier_magnitudes_in": [0.04605866223573685, 0.10036957263946533, 0.1019270122051239, 0.08908756822347641, 0.05084078758955002, 0.09276082366704941, 0.1580628901720047, 0.11540821939706802, 0.8267641663551331, 0.07836127281188965, 0.007629718631505966, 0.10522190481424332, 
0.08720945566892624, 0.07787368446588516, 0.07161819189786911], "fourier_magnitudes_out": [0.0646987333893776, 0.06891193985939026, 0.1397920846939087, 0.05930565670132637, 0.08342889696359634, 0.09649658203125, 0.19548238813877106, 0.8096501231193542, 0.276018351316452, 0.04780340939760208, 0.11156941205263138, 0.02261027880012989, 0.12174452841281891, 0.13799422979354858, 0.09665590524673462]}, "neuron_11": {"global_index": 363, "dominant_freq": 5, "fourier_magnitudes_in": [0.26700034737586975, 0.027114098891615868, 0.030112946406006813, 0.004235479515045881, 0.04029411822557449, 0.03277970105409622, 0.02381693758070469, 0.03963545709848404, 0.008424329571425915, 0.8159627914428711, 0.03508993238210678, 0.012520099990069866, 0.03853161633014679, 0.037013012915849686, 0.016481192782521248], "fourier_magnitudes_out": [0.3175621032714844, 0.022698314860463142, 0.029232148081064224, 0.009378736838698387, 0.04299170523881912, 0.021195683628320694, 0.006493568420410156, 0.02105909213423729, 0.005779783241450787, 0.9192318916320801, 0.025057973340153694, 0.015149646438658237, 0.004153982736170292, 0.01740305870771408, 0.028322339057922363]}, "neuron_12": {"global_index": 195, "dominant_freq": 6, "fourier_magnitudes_in": [0.09348957985639572, 0.09145021438598633, 0.019759872928261757, 0.14044030010700226, 0.010322250425815582, 0.09715981036424637, 0.15762653946876526, 0.08852937072515488, 0.012697099708020687, 0.10808465629816055, 0.005568717140704393, 0.11048246920108795, 0.8844196796417236, 0.06171249970793724, 0.021244505420327187], "fourier_magnitudes_out": [0.16980455815792084, 0.02732899598777294, 0.036924730986356735, 0.05185381695628166, 0.045569196343421936, 0.26525792479515076, 0.05154341459274292, 0.0016775119584053755, 0.005609582178294659, 0.06192351505160332, 0.05204838141798973, 0.878484845161438, 0.06501992046833038, 0.029970666393637657, 0.007208859547972679]}, "neuron_13": {"global_index": 400, "dominant_freq": 7, "fourier_magnitudes_in": 
[0.03170298784971237, 0.01636633276939392, 0.02570384368300438, 0.059365708380937576, 0.038365013897418976, 0.04416408762335777, 0.03580497205257416, 0.08423422276973724, 0.05347143113613129, 0.14569905400276184, 0.11170309782028198, 0.19443102180957794, 0.1770995706319809, 0.16263896226882935, 0.8427256941795349], "fourier_magnitudes_out": [0.001597732538357377, 0.03749779239296913, 0.021893106400966644, 0.021266808733344078, 0.003741121618077159, 0.06272387504577637, 0.031196942552924156, 0.06795763224363327, 0.09264159947633743, 0.009625964798033237, 0.16665540635585785, 0.20691494643688202, 0.2753947973251343, 0.8446058630943298, 0.2736137807369232]}, "neuron_14": {"global_index": 204, "dominant_freq": 1, "fourier_magnitudes_in": [0.009332116693258286, 0.18445806205272675, 0.8107064962387085, 0.0019745242316275835, 0.021862220019102097, 0.164852112531662, 0.19152098894119263, 0.0416061170399189, 0.006671736016869545, 0.12151989340782166, 0.0737357810139656, 0.07512296736240387, 0.00437825545668602, 0.09026234596967697, 0.010851189494132996], "fourier_magnitudes_out": [0.00363162811845541, 0.7924743294715881, 0.3171338140964508, 0.004241126589477062, 0.01878344640135765, 0.12165644019842148, 0.20536403357982635, 0.018597787246108055, 0.0810178741812706, 0.016301296651363373, 0.17405299842357635, 0.0980185940861702, 0.04579637944698334, 0.14084625244140625, 0.08989735692739487]}, "neuron_15": {"global_index": 200, "dominant_freq": 2, "fourier_magnitudes_in": [0.18793319165706635, 0.07831189781427383, 0.0546899177134037, 0.26695817708969116, 0.8170905709266663, 0.05296998843550682, 0.1943284422159195, 0.23756185173988342, 0.15034066140651703, 0.017131371423602104, 0.0028046099469065666, 0.11494778096675873, 0.10373453050851822, 0.12329491227865219, 0.18243563175201416], "fourier_magnitudes_out": [0.09242420643568039, 0.0712248757481575, 0.12614493072032928, 0.7048966288566589, 0.5352984666824341, 0.1033194437623024, 0.05411561205983162, 0.020737502723932266, 
0.1372404545545578, 0.13828690350055695, 0.03525715693831444, 0.11022429913282394, 0.2024645209312439, 0.14993123710155487, 0.08065018057823181]}, "neuron_16": {"global_index": 29, "dominant_freq": 3, "fourier_magnitudes_in": [0.06590171158313751, 0.1875513195991516, 0.13940751552581787, 0.05203652381896973, 0.12810246646404266, 0.8321968913078308, 0.14579933881759644, 0.1637534499168396, 0.11946561932563782, 0.1717541664838791, 0.013540252111852169, 0.15768876671791077, 0.04509717971086502, 0.05775374174118042, 0.17158927023410797], "fourier_magnitudes_out": [0.004155661445111036, 0.03632631525397301, 0.11120419949293137, 0.07094360142946243, 0.0025094610173255205, 0.8083617091178894, 0.2703225910663605, 0.04779119789600372, 0.01505393534898758, 0.019837403669953346, 0.020111212506890297, 0.13071304559707642, 0.1915966272354126, 0.012666973285377026, 0.03241927549242973]}, "neuron_17": {"global_index": 263, "dominant_freq": 4, "fourier_magnitudes_in": [0.04006792977452278, 0.04470885172486305, 0.03644431754946709, 0.09875550121068954, 0.0844721719622612, 0.2842579483985901, 0.07786386460065842, 0.8181069493293762, 0.07545538246631622, 0.1041831225156784, 0.08087007701396942, 0.013009922578930855, 0.05374658852815628, 0.010971452109515667, 0.05345306172966957], "fourier_magnitudes_out": [0.0032453967723995447, 0.04532613605260849, 0.050931766629219055, 0.06243853271007538, 0.1471986621618271, 0.07246130704879761, 0.0618220679461956, 0.793653130531311, 0.21675735712051392, 0.10524935275316238, 0.002365024061873555, 0.05830618739128113, 0.12600253522396088, 0.015407783910632133, 0.04248036816716194]}, "neuron_18": {"global_index": 468, "dominant_freq": 5, "fourier_magnitudes_in": [0.27209052443504333, 0.01535376999527216, 0.024236787110567093, 0.014399196021258831, 0.026386309415102005, 0.03005525842308998, 0.0007289683562703431, 0.013318601995706558, 0.025410467758774757, 0.8113285303115845, 0.024080565199255943, 0.028674796223640442, 0.0011739643523469567, 
0.015659669414162636, 0.025659453123807907], "fourier_magnitudes_out": [0.2273026406764984, 0.010362203232944012, 0.08463546633720398, 0.048502951860427856, 0.006748202722519636, 0.008361948654055595, 0.017721673473715782, 0.055468566715717316, 0.07133140414953232, 0.8983375430107117, 0.06657372415065765, 0.07770014554262161, 0.05767229199409485, 0.02166377194225788, 0.007756791543215513]}, "neuron_19": {"global_index": 502, "dominant_freq": 6, "fourier_magnitudes_in": [0.161783829331398, 0.011249735951423645, 0.011750025674700737, 0.006551133934408426, 0.005556976888328791, 0.3025686740875244, 0.004988331813365221, 0.01731167919933796, 0.009423403069376945, 0.001427029725164175, 0.0038674750830978155, 0.8558500409126282, 0.014558211900293827, 0.0111812399700284, 0.0006309517193585634], "fourier_magnitudes_out": [0.17063283920288086, 0.016314541921019554, 0.05021467059850693, 0.016806060448288918, 0.04298456758260727, 0.2806430757045746, 0.014211905188858509, 0.004449731670320034, 0.02185676619410515, 0.006607384420931339, 0.005355009343475103, 0.875554621219635, 0.016872091218829155, 0.022274665534496307, 0.03286416828632355]}}}
|
precomputed_results/p_015/p015_output_logits.png
ADDED
|
precomputed_results/p_015/p015_overview.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"std_epochs": [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800], "std_ipr": [0.2509222626686096, 0.26378345489501953, 0.3110557198524475, 0.36821821331977844, 0.4191873371601105, 0.46056264638900757, 0.4932430386543274, 0.518873393535614, 0.5395100712776184, 0.5563467741012573, 0.5705686807632446, 0.5829041600227356, 0.5938962697982788, 0.6040970087051392, 0.6136700510978699, 0.6225748658180237, 0.6310203671455383, 0.6389821171760559, 0.6463097333908081, 0.6528617739677429, 0.6585475206375122, 0.6634621620178223, 0.6678526401519775, 0.6718972325325012, 0.6756421327590942], "std_train_loss": [2.7085084915161133, 2.677509307861328, 2.6341302394866943, 2.5730724334716797, 2.492614269256592, 2.3927817344665527, 2.2741761207580566, 2.137855052947998, 1.9853696823120117, 1.8187763690948486, 1.6410351991653442, 1.4555929899215698, 1.2665268182754517, 1.0787031650543213, 0.8973091840744019, 0.7275450825691223, 0.5740776062011719, 0.44004037976264954, 0.32742777466773987, 0.23663438856601715, 0.1664789468050003, 0.1142977625131607, 0.0767836645245552, 0.05059399828314781, 0.032766345888376236]}
|
precomputed_results/p_015/p015_overview_loss_ipr.png
ADDED
|
precomputed_results/p_015/p015_overview_phase_scatter.png
ADDED
|
precomputed_results/p_015/p015_phase_align_approx1.png
ADDED
|
Git LFS Details
|
precomputed_results/p_015/p015_phase_align_approx2.png
ADDED
|
Git LFS Details
|
precomputed_results/p_015/p015_phase_align_quad.png
ADDED
|
precomputed_results/p_015/p015_phase_align_relu.png
ADDED
|
precomputed_results/p_015/p015_phase_distribution.png
ADDED
|
precomputed_results/p_015/p015_phase_relationship.png
ADDED
|
precomputed_results/p_015/p015_single_freq_quad.png
ADDED
|
Git LFS Details
|
precomputed_results/p_015/p015_single_freq_relu.png
ADDED
|
Git LFS Details
|
precomputed_results/p_015/p015_training_log.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
precomputed_results/p_023/p023_full_training_para_origin.png
ADDED
|
Git LFS Details
|
precomputed_results/p_023/p023_grokk_abs_phase_diff.png
ADDED
|
precomputed_results/p_023/p023_grokk_acc.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"epochs": [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000, 10200, 10400, 10600, 10800, 11000, 11200, 11400, 11600, 11800, 12000, 12200, 12400, 12600, 12800, 13000, 13200, 13400, 13600, 13800, 14000, 14200, 14400, 14600, 14800, 15000, 15200, 15400, 15600, 15800, 16000, 16200, 16400, 16600, 16800, 17000, 17200, 17400, 17600, 17800, 18000, 18200, 18400, 18600, 18800, 19000, 19200, 19400, 19600, 19800, 20000, 20200, 20400, 20600, 20800, 21000, 21200, 21400, 21600, 21800, 22000, 22200, 22400, 22600, 22800, 23000, 23200, 23400, 23600, 23800], "train_accs": [0.045454545454545456, 0.23737373737373738, 0.5404040404040404, 0.7272727272727273, 0.7348484848484849, 0.7424242424242424, 0.76010101010101, 0.8055555555555556, 0.8838383838383839, 0.9671717171717171, 0.9924242424242424, 0.9974747474747475, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "test_accs": [0.03759398496240601, 0.007518796992481203, 0.007518796992481203, 0.007518796992481203, 0.015037593984962405, 0.03007518796992481, 0.06766917293233082, 0.18045112781954886, 0.39849624060150374, 0.6466165413533834, 0.7142857142857143, 0.7293233082706767, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 
0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7368421052631579, 0.7518796992481203, 0.7669172932330827, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7819548872180451, 0.7969924812030075, 0.8270676691729323, 0.8270676691729323, 0.8421052631578947, 0.9022556390977443, 0.9022556390977443, 0.9022556390977443, 0.9022556390977443, 0.9172932330827067, 0.9323308270676691, 0.9323308270676691, 0.9323308270676691, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 0.9624060150375939, 
0.9624060150375939], "stage1_end": 1744, "stage2_end": 9431}
|
precomputed_results/p_023/p023_grokk_acc.png
ADDED
|
precomputed_results/p_023/p023_grokk_avg_ipr.png
ADDED
|
precomputed_results/p_023/p023_grokk_decoded_weights_dynamic.png
ADDED
|
Git LFS Details
|
precomputed_results/p_023/p023_grokk_epoch_data.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"prime": 23, "epochs": [0, 2600, 5200, 7800, 10400, 13200, 15800, 18400, 21000, 23800], "grids": [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], [[1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]]}
|
precomputed_results/p_023/p023_grokk_loss.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
precomputed_results/p_023/p023_grokk_loss.png
ADDED
|
precomputed_results/p_023/p023_grokk_memorization_accuracy.png
ADDED
|
precomputed_results/p_023/p023_grokk_memorization_common_to_rare.png
ADDED
|
Git LFS Details
|
precomputed_results/p_023/p023_lineplot_in.png
ADDED
|
Git LFS Details
|