Fix Tab 7: 10K epochs, better neuron selection, fixed timepoints
Browse files- README.md +2 -2
- precompute/generate_plots.py +37 -11
- precompute/prime_config.py +2 -2
- precomputed_results/p_015/p015_phase_align_quad.png +0 -0
- precomputed_results/p_015/p015_phase_align_relu.png +0 -0
- precomputed_results/p_015/p015_single_freq_quad.png +2 -2
- precomputed_results/p_015/p015_single_freq_relu.png +2 -2
- precomputed_results/p_023/p023_phase_align_quad.png +0 -0
- precomputed_results/p_023/p023_phase_align_relu.png +0 -0
- precomputed_results/p_023/p023_single_freq_quad.png +2 -2
- precomputed_results/p_023/p023_single_freq_relu.png +2 -2
- precomputed_results/p_029/p029_phase_align_quad.png +0 -0
- precomputed_results/p_029/p029_phase_align_relu.png +0 -0
- precomputed_results/p_029/p029_single_freq_quad.png +2 -2
- precomputed_results/p_029/p029_single_freq_relu.png +2 -2
- precomputed_results/p_031/p031_phase_align_quad.png +0 -0
- precomputed_results/p_031/p031_phase_align_relu.png +0 -0
- precomputed_results/p_031/p031_single_freq_quad.png +2 -2
- precomputed_results/p_031/p031_single_freq_relu.png +2 -2
README.md
CHANGED
|
@@ -160,8 +160,8 @@ Each modulus produces ~33 files in `precomputed_results/p_XXX/`:
|
|
| 160 |
| `standard` | ReLU | AdamW | 5e-5 | 0 | 100% | 5,000 | Tabs 1–4 |
|
| 161 |
| `grokking` | ReLU | AdamW | 1e-4 | 2.0 | 75% | 50,000 | Tabs 1, 6 |
|
| 162 |
| `quad_random` | Quad | AdamW | 5e-5 | 0 | 100% | 5,000 | Tab 5 |
|
| 163 |
-
| `quad_single_freq` | Quad | SGD | 0.1 | 0 | 100% |
|
| 164 |
-
| `relu_single_freq` | ReLU | SGD | 0.01 | 0 | 100% |
|
| 165 |
|
| 166 |
## Running a Single Experiment
|
| 167 |
|
|
|
|
| 160 |
| `standard` | ReLU | AdamW | 5e-5 | 0 | 100% | 5,000 | Tabs 1–4 |
|
| 161 |
| `grokking` | ReLU | AdamW | 1e-4 | 2.0 | 75% | 50,000 | Tabs 1, 6 |
|
| 162 |
| `quad_random` | Quad | AdamW | 5e-5 | 0 | 100% | 5,000 | Tab 5 |
|
| 163 |
+
| `quad_single_freq` | Quad | SGD | 0.1 | 0 | 100% | 10,000 | Tab 7 |
|
| 164 |
+
| `relu_single_freq` | ReLU | SGD | 0.01 | 0 | 100% | 10,000 | Tab 7 |
|
| 165 |
|
| 166 |
## Running a Single Experiment
|
| 167 |
|
precompute/generate_plots.py
CHANGED
|
@@ -1603,12 +1603,36 @@ class PlotGenerator:
|
|
| 1603 |
'phi_out': phi_out,
|
| 1604 |
})
|
| 1605 |
|
| 1606 |
-
# Select a neuron that shows
|
| 1607 |
-
#
|
|
|
|
|
|
|
|
|
|
| 1608 |
final_records = [r for r in all_neuron_records if r['epoch'] == epochs[-1]]
|
| 1609 |
if not final_records:
|
| 1610 |
continue
|
| 1611 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1612 |
|
| 1613 |
# Extract trajectory for this neuron
|
| 1614 |
neuron_records = [r for r in all_neuron_records if r['neuron'] == best_neuron]
|
|
@@ -1667,16 +1691,18 @@ class PlotGenerator:
|
|
| 1667 |
_save_fig(fig, self._out(f'phase_align_{prefix}.png'))
|
| 1668 |
|
| 1669 |
# ---- Decoded weights at timepoints ----
|
|
|
|
|
|
|
| 1670 |
if prefix == 'quad':
|
| 1671 |
-
|
| 1672 |
-
mid = min(epochs, key=lambda e: abs(e - 1000))
|
| 1673 |
-
end = epochs[-1]
|
| 1674 |
-
if mid not in keys:
|
| 1675 |
-
keys.append(mid)
|
| 1676 |
-
if end not in keys:
|
| 1677 |
-
keys.append(end)
|
| 1678 |
else:
|
| 1679 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1680 |
|
| 1681 |
num_components = min(20, d_mlp)
|
| 1682 |
n = len(keys)
|
|
|
|
| 1603 |
'phi_out': phi_out,
|
| 1604 |
})
|
| 1605 |
|
| 1606 |
+
# Select a neuron that shows interesting phase convergence dynamics.
|
| 1607 |
+
# The lottery winner (largest final scale) already has ψ ≈ 2φ from
|
| 1608 |
+
# the start, producing flat boring plots. Instead, pick a neuron that
|
| 1609 |
+
# (a) has significant final scale (top quartile → actually learned),
|
| 1610 |
+
# (b) had the largest initial phase misalignment |ψ₀ - 2φ₀|.
|
| 1611 |
final_records = [r for r in all_neuron_records if r['epoch'] == epochs[-1]]
|
| 1612 |
if not final_records:
|
| 1613 |
continue
|
| 1614 |
+
init_records = [r for r in all_neuron_records if r['epoch'] == epochs[0]]
|
| 1615 |
+
init_by_neuron = {r['neuron']: r for r in init_records}
|
| 1616 |
+
|
| 1617 |
+
# Keep neurons with final scale in top 25%
|
| 1618 |
+
scales = sorted([r['scale_in'] for r in final_records], reverse=True)
|
| 1619 |
+
scale_threshold = scales[max(0, len(scales) // 4 - 1)] if len(scales) >= 4 else scales[-1]
|
| 1620 |
+
strong_neurons = [r for r in final_records if r['scale_in'] >= scale_threshold]
|
| 1621 |
+
|
| 1622 |
+
# Among strong neurons, pick the one with largest initial misalignment
|
| 1623 |
+
best_neuron = None
|
| 1624 |
+
best_misalign = -1.0
|
| 1625 |
+
for r in strong_neurons:
|
| 1626 |
+
n = r['neuron']
|
| 1627 |
+
if n not in init_by_neuron:
|
| 1628 |
+
continue
|
| 1629 |
+
ir = init_by_neuron[n]
|
| 1630 |
+
misalign = abs(normalize_to_pi(ir['phi_out'] - 2 * ir['phi_in']))
|
| 1631 |
+
if misalign > best_misalign:
|
| 1632 |
+
best_misalign = misalign
|
| 1633 |
+
best_neuron = n
|
| 1634 |
+
if best_neuron is None:
|
| 1635 |
+
best_neuron = max(final_records, key=lambda r: r['scale_in'])['neuron']
|
| 1636 |
|
| 1637 |
# Extract trajectory for this neuron
|
| 1638 |
neuron_records = [r for r in all_neuron_records if r['neuron'] == best_neuron]
|
|
|
|
| 1691 |
_save_fig(fig, self._out(f'phase_align_{prefix}.png'))
|
| 1692 |
|
| 1693 |
# ---- Decoded weights at timepoints ----
|
| 1694 |
+
# Use fixed timepoints matching the notebook figures:
|
| 1695 |
+
# Quad: steps 0, 1000, 5000 ReLU: steps 0, 5000
|
| 1696 |
if prefix == 'quad':
|
| 1697 |
+
target_keys = [0, 1000, 5000]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1698 |
else:
|
| 1699 |
+
target_keys = [0, 5000]
|
| 1700 |
+
# Snap each target to the nearest available checkpoint epoch
|
| 1701 |
+
keys = []
|
| 1702 |
+
for t in target_keys:
|
| 1703 |
+
nearest = min(epochs, key=lambda e: abs(e - t))
|
| 1704 |
+
if nearest not in keys:
|
| 1705 |
+
keys.append(nearest)
|
| 1706 |
|
| 1707 |
num_components = min(20, d_mlp)
|
| 1708 |
n = len(keys)
|
precompute/prime_config.py
CHANGED
|
@@ -94,7 +94,7 @@ TRAINING_RUNS = {
|
|
| 94 |
"lr": 0.1,
|
| 95 |
"weight_decay": 0,
|
| 96 |
"frac_train": 1.0,
|
| 97 |
-
"num_epochs":
|
| 98 |
"save_every": 200,
|
| 99 |
"init_scale": 0.02,
|
| 100 |
"save_models": True,
|
|
@@ -109,7 +109,7 @@ TRAINING_RUNS = {
|
|
| 109 |
"lr": 0.01,
|
| 110 |
"weight_decay": 0,
|
| 111 |
"frac_train": 1.0,
|
| 112 |
-
"num_epochs":
|
| 113 |
"save_every": 200,
|
| 114 |
"init_scale": 0.002,
|
| 115 |
"save_models": True,
|
|
|
|
| 94 |
"lr": 0.1,
|
| 95 |
"weight_decay": 0,
|
| 96 |
"frac_train": 1.0,
|
| 97 |
+
"num_epochs": 10000,
|
| 98 |
"save_every": 200,
|
| 99 |
"init_scale": 0.02,
|
| 100 |
"save_models": True,
|
|
|
|
| 109 |
"lr": 0.01,
|
| 110 |
"weight_decay": 0,
|
| 111 |
"frac_train": 1.0,
|
| 112 |
+
"num_epochs": 10000,
|
| 113 |
"save_every": 200,
|
| 114 |
"init_scale": 0.002,
|
| 115 |
"save_models": True,
|
precomputed_results/p_015/p015_phase_align_quad.png
CHANGED
|
|
precomputed_results/p_015/p015_phase_align_relu.png
CHANGED
|
|
precomputed_results/p_015/p015_single_freq_quad.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
precomputed_results/p_015/p015_single_freq_relu.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
precomputed_results/p_023/p023_phase_align_quad.png
CHANGED
|
|
precomputed_results/p_023/p023_phase_align_relu.png
CHANGED
|
|
precomputed_results/p_023/p023_single_freq_quad.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
precomputed_results/p_023/p023_single_freq_relu.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
precomputed_results/p_029/p029_phase_align_quad.png
CHANGED
|
|
precomputed_results/p_029/p029_phase_align_relu.png
CHANGED
|
|
precomputed_results/p_029/p029_single_freq_quad.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
precomputed_results/p_029/p029_single_freq_relu.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
precomputed_results/p_031/p031_phase_align_quad.png
CHANGED
|
|
precomputed_results/p_031/p031_phase_align_relu.png
CHANGED
|
|
precomputed_results/p_031/p031_single_freq_quad.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
precomputed_results/p_031/p031_single_freq_relu.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|