# NOTE(review): the three lines below were web-scrape residue (a Hugging Face
# commit header) pasted above the module docstring; kept as a comment so the
# module parses.
# zhuoranyang — "Fix Tab 7: 10K epochs, better neuron selection, fixed timepoints" (e518ead, verified)
"""
Configuration for all moduli and training runs.
Defines the d_mlp sizing formula and the 5 training run configurations.
p can be any odd number >= 3 (not restricted to primes).
"""
import math
def get_moduli(low=3, high=199):
    """Return every odd number n with max(low, 3) <= n <= high, ascending."""
    start = max(low, 3)  # 1 and 2 are never valid moduli here; floor at 3
    if start % 2 == 0:
        start += 1  # bump to the first odd candidate
    return list(range(start, high + 1, 2))
# Keep old name as alias for backward compatibility
# (per the module docstring, p is no longer restricted to primes — any odd p >= 3)
get_primes = get_moduli
def compute_d_mlp(p: int) -> int:
    """
    Size the MLP hidden width for modulus *p*.

    Preserves the neuron-per-p^2 ratio of the reference run (p=23, d_mlp=512):
    d_mlp = max(512, ceil((512/529) * p^2)). The 512 floor means small moduli
    may get proportionally more neurons, but never fewer than the ratio implies.
    """
    per_p_squared = 512 / (23 ** 2)  # ≈ 0.9679 neurons per unit of p^2
    scaled = math.ceil(per_p_squared * p * p)
    return scaled if scaled > 512 else 512
# Minimum p overall (p=2 has 0 non-DC frequencies, making Fourier analysis degenerate)
MIN_P = 3
# Minimum p for grokking experiments (need enough test data for meaningful split)
MIN_P_GROKKING = 19
# Backward-compatible aliases (older code imported the *_PRIME names before
# the module was generalized beyond prime moduli)
MIN_PRIME = MIN_P
MIN_PRIME_GROKKING = MIN_P_GROKKING
# 5 training run configurations per p.
# Common to all runs: one-hot embedding, full-batch training, checkpoints
# every 200 epochs, models saved, seed 42.
TRAINING_RUNS = {
    # Baseline: random init, AdamW + ReLU, all data used for training,
    # no weight decay.
    "standard": {
        "embed_type": "one_hot",
        "init_type": "random",
        "optimizer": "AdamW",
        "act_type": "ReLU",
        "lr": 5e-5,
        "weight_decay": 0,
        "frac_train": 1.0,
        "num_epochs": 5000,
        "save_every": 200,
        "init_scale": 0.1,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
    # Grokking setup: 75/25 train/test split, strong weight decay (2.0),
    # and a much longer run (50k epochs) so delayed generalization can appear.
    "grokking": {
        "embed_type": "one_hot",
        "init_type": "random",
        "optimizer": "AdamW",
        "act_type": "ReLU",
        "lr": 1e-4,
        "weight_decay": 2.0,
        "frac_train": 0.75,
        "num_epochs": 50000,
        "save_every": 200,
        "init_scale": 0.1,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
    # Same as "standard" but with quadratic activation instead of ReLU.
    "quad_random": {
        "embed_type": "one_hot",
        "init_type": "random",
        "optimizer": "AdamW",
        "act_type": "Quad",
        "lr": 5e-5,
        "weight_decay": 0,
        "frac_train": 1.0,
        "num_epochs": 5000,
        "save_every": 200,
        "init_scale": 0.1,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
    # Quadratic activation with a single-frequency init, trained with plain
    # SGD at a high lr (0.1) and a small init scale (0.02).
    "quad_single_freq": {
        "embed_type": "one_hot",
        "init_type": "single-freq",
        "optimizer": "SGD",
        "act_type": "Quad",
        "lr": 0.1,
        "weight_decay": 0,
        "frac_train": 1.0,
        "num_epochs": 10000,
        "save_every": 200,
        "init_scale": 0.02,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
    # ReLU counterpart of "quad_single_freq": SGD with a 10x smaller lr
    # (0.01) and a 10x smaller init scale (0.002).
    "relu_single_freq": {
        "embed_type": "one_hot",
        "init_type": "single-freq",
        "optimizer": "SGD",
        "act_type": "ReLU",
        "lr": 0.01,
        "weight_decay": 0,
        "frac_train": 1.0,
        "num_epochs": 10000,
        "save_every": 200,
        "init_scale": 0.002,
        "save_models": True,
        "batch_style": "full",
        "seed": 42,
    },
}
# Analytical computation configs (no training needed)
ANALYTICAL_CONFIGS = {
    # Parameters for the decoupled-dynamics computation, run as two cases
    # with different step counts and (phi, psi) starting points.
    # NOTE(review): the semantics of init_k, phi, and psi are defined by the
    # consumer of this config, not visible here — confirm against that module.
    "decouple_dynamics": {
        "init_k": 2,
        # Case 1: 1400 steps from (phi, psi) = (1.5, 0.18) at lr 1.
        "num_steps_case1": 1400,
        "learning_rate_case1": 1,
        "init_phi_case1": 1.5,
        "init_psi_case1": 0.18,
        # Case 2: 700 steps from (phi, psi) = (-0.72, -2.91) at lr 1.
        "num_steps_case2": 700,
        "learning_rate_case2": 1,
        "init_phi_case2": -0.72,
        "init_psi_case2": -2.91,
        "amplitude": 0.02,
    },
}