zhuoranyang commited on
Commit
e518ead
·
verified ·
1 Parent(s): 8df851f

Fix Tab 7: 10K epochs, better neuron selection, fixed timepoints

Browse files
README.md CHANGED
@@ -160,8 +160,8 @@ Each modulus produces ~33 files in `precomputed_results/p_XXX/`:
160
  | `standard` | ReLU | AdamW | 5e-5 | 0 | 100% | 5,000 | Tabs 1–4 |
161
  | `grokking` | ReLU | AdamW | 1e-4 | 2.0 | 75% | 50,000 | Tabs 1, 6 |
162
  | `quad_random` | Quad | AdamW | 5e-5 | 0 | 100% | 5,000 | Tab 5 |
163
- | `quad_single_freq` | Quad | SGD | 0.1 | 0 | 100% | 5,000 | Tab 7 |
164
- | `relu_single_freq` | ReLU | SGD | 0.01 | 0 | 100% | 5,000 | Tab 7 |
165
 
166
  ## Running a Single Experiment
167
 
 
160
  | `standard` | ReLU | AdamW | 5e-5 | 0 | 100% | 5,000 | Tabs 1–4 |
161
  | `grokking` | ReLU | AdamW | 1e-4 | 2.0 | 75% | 50,000 | Tabs 1, 6 |
162
  | `quad_random` | Quad | AdamW | 5e-5 | 0 | 100% | 5,000 | Tab 5 |
163
+ | `quad_single_freq` | Quad | SGD | 0.1 | 0 | 100% | 10,000 | Tab 7 |
164
+ | `relu_single_freq` | ReLU | SGD | 0.01 | 0 | 100% | 10,000 | Tab 7 |
165
 
166
  ## Running a Single Experiment
167
 
precompute/generate_plots.py CHANGED
@@ -1603,12 +1603,36 @@ class PlotGenerator:
1603
  'phi_out': phi_out,
1604
  })
1605
 
1606
- # Select a neuron that shows clear phase alignment
1607
- # Pick neuron with largest final scale
 
 
 
1608
  final_records = [r for r in all_neuron_records if r['epoch'] == epochs[-1]]
1609
  if not final_records:
1610
  continue
1611
- best_neuron = max(final_records, key=lambda r: r['scale_in'])['neuron']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1612
 
1613
  # Extract trajectory for this neuron
1614
  neuron_records = [r for r in all_neuron_records if r['neuron'] == best_neuron]
@@ -1667,16 +1691,18 @@ class PlotGenerator:
1667
  _save_fig(fig, self._out(f'phase_align_{prefix}.png'))
1668
 
1669
  # ---- Decoded weights at timepoints ----
 
 
1670
  if prefix == 'quad':
1671
- keys = [0]
1672
- mid = min(epochs, key=lambda e: abs(e - 1000))
1673
- end = epochs[-1]
1674
- if mid not in keys:
1675
- keys.append(mid)
1676
- if end not in keys:
1677
- keys.append(end)
1678
  else:
1679
- keys = [0, epochs[-1]]
 
 
 
 
 
 
1680
 
1681
  num_components = min(20, d_mlp)
1682
  n = len(keys)
 
1603
  'phi_out': phi_out,
1604
  })
1605
 
1606
+ # Select a neuron that shows interesting phase convergence dynamics.
1607
+ # The lottery winner (largest final scale) already has ψ ≈ 2φ from
1608
+ # the start, producing flat boring plots. Instead, pick a neuron that
1609
+ # (a) has significant final scale (top quartile → actually learned),
1610
+ # (b) had the largest initial phase misalignment |ψ₀ - 2φ₀|.
1611
  final_records = [r for r in all_neuron_records if r['epoch'] == epochs[-1]]
1612
  if not final_records:
1613
  continue
1614
+ init_records = [r for r in all_neuron_records if r['epoch'] == epochs[0]]
1615
+ init_by_neuron = {r['neuron']: r for r in init_records}
1616
+
1617
+ # Keep neurons with final scale in top 25%
1618
+ scales = sorted([r['scale_in'] for r in final_records], reverse=True)
1619
+ scale_threshold = scales[max(0, len(scales) // 4 - 1)] if len(scales) >= 4 else scales[-1]
1620
+ strong_neurons = [r for r in final_records if r['scale_in'] >= scale_threshold]
1621
+
1622
+ # Among strong neurons, pick the one with largest initial misalignment
1623
+ best_neuron = None
1624
+ best_misalign = -1.0
1625
+ for r in strong_neurons:
1626
+ n = r['neuron']
1627
+ if n not in init_by_neuron:
1628
+ continue
1629
+ ir = init_by_neuron[n]
1630
+ misalign = abs(normalize_to_pi(ir['phi_out'] - 2 * ir['phi_in']))
1631
+ if misalign > best_misalign:
1632
+ best_misalign = misalign
1633
+ best_neuron = n
1634
+ if best_neuron is None:
1635
+ best_neuron = max(final_records, key=lambda r: r['scale_in'])['neuron']
1636
 
1637
  # Extract trajectory for this neuron
1638
  neuron_records = [r for r in all_neuron_records if r['neuron'] == best_neuron]
 
1691
  _save_fig(fig, self._out(f'phase_align_{prefix}.png'))
1692
 
1693
  # ---- Decoded weights at timepoints ----
1694
+ # Use fixed timepoints matching the notebook figures:
1695
+ # Quad: steps 0, 1000, 5000 ReLU: steps 0, 5000
1696
  if prefix == 'quad':
1697
+ target_keys = [0, 1000, 5000]
 
 
 
 
 
 
1698
  else:
1699
+ target_keys = [0, 5000]
1700
+ # Snap each target to the nearest available checkpoint epoch
1701
+ keys = []
1702
+ for t in target_keys:
1703
+ nearest = min(epochs, key=lambda e: abs(e - t))
1704
+ if nearest not in keys:
1705
+ keys.append(nearest)
1706
 
1707
  num_components = min(20, d_mlp)
1708
  n = len(keys)
precompute/prime_config.py CHANGED
@@ -94,7 +94,7 @@ TRAINING_RUNS = {
94
  "lr": 0.1,
95
  "weight_decay": 0,
96
  "frac_train": 1.0,
97
- "num_epochs": 5000,
98
  "save_every": 200,
99
  "init_scale": 0.02,
100
  "save_models": True,
@@ -109,7 +109,7 @@ TRAINING_RUNS = {
109
  "lr": 0.01,
110
  "weight_decay": 0,
111
  "frac_train": 1.0,
112
- "num_epochs": 5000,
113
  "save_every": 200,
114
  "init_scale": 0.002,
115
  "save_models": True,
 
94
  "lr": 0.1,
95
  "weight_decay": 0,
96
  "frac_train": 1.0,
97
+ "num_epochs": 10000,
98
  "save_every": 200,
99
  "init_scale": 0.02,
100
  "save_models": True,
 
109
  "lr": 0.01,
110
  "weight_decay": 0,
111
  "frac_train": 1.0,
112
+ "num_epochs": 10000,
113
  "save_every": 200,
114
  "init_scale": 0.002,
115
  "save_models": True,
precomputed_results/p_015/p015_phase_align_quad.png CHANGED
precomputed_results/p_015/p015_phase_align_relu.png CHANGED
precomputed_results/p_015/p015_single_freq_quad.png CHANGED

Git LFS Details

  • SHA256: ccbca685fc93991c9965ca3e8fc3fa87ae3fd03e74fb5c40b9f03468e63dda79
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB

Git LFS Details

  • SHA256: e5906064f6dbfbeca87b012e3a4df82884d4c5349469e352ffe1ea2339a1d857
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
precomputed_results/p_015/p015_single_freq_relu.png CHANGED

Git LFS Details

  • SHA256: fa5f0678750033da499f044eb3d50a94cfd0f2ce00b69dbd76692a8ae6be4aa6
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB

Git LFS Details

  • SHA256: d6c7fe07aa3804a74002a27f11d9d5fc0cf59f4878a782ed599182af615b366c
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
precomputed_results/p_023/p023_phase_align_quad.png CHANGED
precomputed_results/p_023/p023_phase_align_relu.png CHANGED
precomputed_results/p_023/p023_single_freq_quad.png CHANGED

Git LFS Details

  • SHA256: 48cecf07f923c6d0757652f68c427b4da1b3cde5f6b3dd8e2503bf60d9058f13
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB

Git LFS Details

  • SHA256: 514323ec0a4055901b250f0fc32289d125dd2e1678d947417ad26ce33e531063
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
precomputed_results/p_023/p023_single_freq_relu.png CHANGED

Git LFS Details

  • SHA256: 441e3592d831b1ce488d7445610f296841302f31b208417f9e942d50adad4129
  • Pointer size: 131 Bytes
  • Size of remote file: 125 kB

Git LFS Details

  • SHA256: 8e30db53c38ab81e903456d954ee0c279d40fc28b8481d710a2d4e41a3c236bc
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
precomputed_results/p_029/p029_phase_align_quad.png CHANGED
precomputed_results/p_029/p029_phase_align_relu.png CHANGED
precomputed_results/p_029/p029_single_freq_quad.png CHANGED

Git LFS Details

  • SHA256: a4bc4fa7afb28ff80d3ffa30a460293e12f604a4ee3238eddddb8402cc13f79a
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB

Git LFS Details

  • SHA256: 71aa792034d88cd9c5a7dbeb8ffa3088b52817fcf7209d0ddf1c120f1439f0f8
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
precomputed_results/p_029/p029_single_freq_relu.png CHANGED

Git LFS Details

  • SHA256: 85dcb7facbf0161bad07679baeb1a2c96b684176b9706e4547b0808ea122adce
  • Pointer size: 131 Bytes
  • Size of remote file: 126 kB

Git LFS Details

  • SHA256: 964d258244a653f430aa88e4aa62abcbba753bef38961e6b8c49797ab22bf4bb
  • Pointer size: 131 Bytes
  • Size of remote file: 125 kB
precomputed_results/p_031/p031_phase_align_quad.png CHANGED
precomputed_results/p_031/p031_phase_align_relu.png CHANGED
precomputed_results/p_031/p031_single_freq_quad.png CHANGED

Git LFS Details

  • SHA256: 00ed32cbc3259510286729ea0d16cac0f7050de0e35c4bca4e65f5d46b0270dc
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB

Git LFS Details

  • SHA256: c5a7ccd341127c493b97e32ea2b02b42a5d3678b7f38db3bc57d8b441abd30d9
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
precomputed_results/p_031/p031_single_freq_relu.png CHANGED

Git LFS Details

  • SHA256: b4873e378e97c8ec6c04d1da29b9df4382d434baf78edbb1e6ca3e6571b646e4
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB

Git LFS Details

  • SHA256: bc893000f749ec845d0cb00c4103a2d68dd4acfbabd55f8f17eb4196624b3848
  • Pointer size: 131 Bytes
  • Size of remote file: 126 kB