#!/usr/bin/env python3
"""
Batch training script for all odd moduli p in [3, 199].
Usage:
# Train all runs for all odd p
python train_all.py --all
# Train specific p
python train_all.py --p 23
# Train specific run type for a p
python train_all.py --p 23 --run standard
# Resume (skips completed runs)
python train_all.py --all --resume
# Custom output directory
python train_all.py --all --output ./my_models
"""
import argparse
import json
import os
import sys
import time
# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
import torch
from prime_config import get_moduli, compute_d_mlp, TRAINING_RUNS, MIN_P, MIN_P_GROKKING
from utils import Config
from nnTrainer import Trainer
def build_config_dict(p, run_params, d_mlp_override=None):
"""Build a nested config dict compatible with the Config class."""
d_mlp = d_mlp_override if d_mlp_override is not None else compute_d_mlp(p)
return {
'data': {
'p': p,
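            # None values are placeholders; the Config class is assumed to
            # derive d_vocab (and d_model below) from p.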
'd_vocab': None,
'fn_name': 'add',
'frac_train': run_params['frac_train'],
'batch_style': run_params['batch_style'],
},
'model': {
'd_model': None,
'd_mlp': d_mlp,
'act_type': run_params['act_type'],
'embed_type': run_params['embed_type'],
'init_type': run_params['init_type'],
'init_scale': run_params['init_scale'],
},
'training': {
'num_epochs': run_params['num_epochs'],
'lr': run_params['lr'],
'weight_decay': run_params['weight_decay'],
'optimizer': run_params['optimizer'],
'stopping_thresh': -1,
'save_models': run_params['save_models'],
'save_every': run_params['save_every'],
'seed': run_params['seed'],
},
}
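# Example of how run_training (below) consumes this, with illustrative values:
#   config = Config(build_config_dict(23, TRAINING_RUNS['standard']))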
def _save_training_log(output_dir, p, run_name, run_params, d_mlp, curves):
"""Save a human-readable training_log.txt summarizing the run."""
log_path = os.path.join(output_dir, "training_log.txt")
n_epochs = len(curves.get('train_losses', []))
with open(log_path, 'w') as f:
f.write(f"{'=' * 70}\n")
f.write(f"Training Log: p={p}, run={run_name}\n")
f.write(f"{'=' * 70}\n\n")
f.write(f"Configuration:\n")
f.write(f" prime (p) = {p}\n")
f.write(f" d_mlp = {d_mlp}\n")
f.write(f" activation = {run_params['act_type']}\n")
f.write(f" init_type = {run_params['init_type']}\n")
f.write(f" init_scale = {run_params['init_scale']}\n")
f.write(f" optimizer = {run_params['optimizer']}\n")
f.write(f" learning_rate = {run_params['lr']}\n")
f.write(f" weight_decay = {run_params['weight_decay']}\n")
f.write(f" frac_train = {run_params['frac_train']}\n")
f.write(f" num_epochs = {run_params['num_epochs']}\n")
f.write(f" batch_style = {run_params['batch_style']}\n")
f.write(f" seed = {run_params['seed']}\n")
f.write(f"\n{'─' * 70}\n")
f.write(f"{'Epoch':>8s} {'Train Loss':>12s} {'Test Loss':>12s} "
f"{'Train Acc':>10s} {'Test Acc':>10s} "
f"{'Grad Norm':>10s} {'Param Norm':>11s}\n")
f.write(f"{'─' * 70}\n")
        # Subsample to ~100 evenly spaced rows, always including the last epoch
train_losses = curves.get('train_losses', [])
test_losses = curves.get('test_losses', [])
train_accs = curves.get('train_accs', [])
test_accs = curves.get('test_accs', [])
grad_norms = curves.get('grad_norms', [])
param_norms = curves.get('param_norms', [])
step = max(1, n_epochs // 100) # ~100 lines
indices = list(range(0, n_epochs, step))
if n_epochs > 0 and (n_epochs - 1) not in indices:
indices.append(n_epochs - 1)
for i in indices:
tl = f"{train_losses[i]:.6f}" if i < len(train_losses) else "N/A"
tel = f"{test_losses[i]:.6f}" if i < len(test_losses) else "N/A"
ta = f"{train_accs[i]:.4f}" if i < len(train_accs) else "N/A"
tea = f"{test_accs[i]:.4f}" if i < len(test_accs) else "N/A"
gn = f"{grad_norms[i]:.4f}" if i < len(grad_norms) else "N/A"
pn = f"{param_norms[i]:.4f}" if i < len(param_norms) else "N/A"
f.write(f"{i:>8d} {tl:>12s} {tel:>12s} "
f"{ta:>10s} {tea:>10s} "
f"{gn:>10s} {pn:>11s}\n")
f.write(f"{'─' * 70}\n\n")
f.write(f"Final Results:\n")
if train_losses:
f.write(f" Train Loss = {train_losses[-1]:.6f}\n")
if test_losses:
f.write(f" Test Loss = {test_losses[-1]:.6f}\n")
if train_accs:
f.write(f" Train Acc = {train_accs[-1]:.4f}\n")
if test_accs:
f.write(f" Test Acc = {test_accs[-1]:.4f}\n")
if param_norms:
f.write(f" Param Norm = {param_norms[-1]:.4f}\n")
f.write(f"\nTotal epochs trained: {n_epochs}\n")
def run_training(p, run_name, output_base, d_mlp_override=None, resume=False):
"""Train a single run for a single prime."""
if p < MIN_P:
print(f"[SKIP] p={p}, run={run_name}: p < {MIN_P} (too few Fourier frequencies)")
return
# Single-freq init needs at least 1 non-DC frequency: (p-1)//2 >= 1 → p >= 3
if run_name in ('quad_single_freq', 'relu_single_freq') and (p - 1) // 2 < 1:
print(f"[SKIP] p={p}, run={run_name}: no non-DC frequencies for single-freq init")
return
if run_name == 'grokking' and p < MIN_P_GROKKING:
print(f"[SKIP] p={p}, run={run_name}: p < {MIN_P_GROKKING} (too few test points)")
return
run_params = TRAINING_RUNS[run_name]
config_dict = build_config_dict(p, run_params, d_mlp_override)
d_mlp = d_mlp_override if d_mlp_override is not None else compute_d_mlp(p)
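    # p is zero-padded to three digits so per-prime directories sort numerically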
output_dir = os.path.join(output_base, f"p_{p:03d}", run_name)
os.makedirs(output_dir, exist_ok=True)
    # Skip runs that already have a completion marker, but only when resuming
    marker = os.path.join(output_dir, "DONE")
    if resume and os.path.exists(marker):
        print(f"[SKIP] p={p}, run={run_name} already completed")
        return
num_epochs = run_params['num_epochs']
print(f"[TRAIN] p={p}, d_mlp={d_mlp}, run={run_name}, "
f"epochs={num_epochs}")
config = Config(config_dict)
trainer = Trainer(config=config, use_wandb=False)
# Progress logging:
# - keep epoch-based logs reasonably frequent
# - also enforce a wall-clock heartbeat so streaming UIs stay active
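    #   (interval clamped to [1, 100]: ~20 logs for short runs, and at least
    #    one log per 100 epochs for long ones)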
log_interval = min(max(1, num_epochs // 20), 100)
max_silence_sec = 20
last_log_time = time.time()
# Override save directory so checkpoints go into our output structure
trainer.save_dir = output_dir
run_subdir = os.path.join(output_dir, trainer.run_name)
os.makedirs(run_subdir, exist_ok=True)
# Re-save train/test data to the overridden location so generate_plots.py
# can find them (Trainer.__init__ saves to the original save_dir)
torch.save(trainer.train, os.path.join(run_subdir, 'train_data.pth'))
torch.save(trainer.test, os.path.join(run_subdir, 'test_data.pth'))
trainer.initial_save_if_appropriate()
# Plateau early-stopping for grokking: after 10K epochs, if curves
# haven't changed in the last 1000 epochs, stop training.
plateau_check = (run_name == 'grokking')
plateau_min_epoch = 10000
plateau_window = 1000
plateau_loss_tol = 1e-3 # absolute change in loss
plateau_acc_tol = 0.005 # absolute change in accuracy
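    # "Flat" below means the max-min range of a curve over the trailing
    # plateau_window epochs is under its tolerance; all four curves must be
    # flat simultaneously for the early stop to trigger.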
for epoch in range(config.num_epochs):
train_loss, test_loss = trainer.do_a_training_step(epoch)
# Progress logging
now = time.time()
if (
epoch % log_interval == 0
or epoch == config.num_epochs - 1
or (now - last_log_time) >= max_silence_sec
):
pct = 100 * (epoch + 1) / config.num_epochs
train_acc = trainer.train_accs[-1] if trainer.train_accs else 0
test_acc = trainer.test_accs[-1] if trainer.test_accs else 0
print(f" [{run_name}] Epoch {epoch:>6d}/{config.num_epochs}"
f" ({pct:5.1f}%)"
f" train_loss={train_loss.item():.4f}"
f" test_loss={test_loss.item():.4f}"
f" train_acc={train_acc:.4f}"
f" test_acc={test_acc:.4f}",
flush=True)
last_log_time = now
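        # build_config_dict sets stopping_thresh = -1, so for non-negative
        # losses this loss-threshold stop is effectively disabled here.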
if test_loss.item() < config.stopping_thresh:
print(f" Early stopping at epoch {epoch}: "
f"test loss {test_loss.item():.6f}")
break
# Plateau detection for grokking
if (plateau_check and epoch >= plateau_min_epoch
and epoch % plateau_window == 0):
tl = trainer.train_losses
tel = trainer.test_losses
ta = trainer.train_accs
tea = trainer.test_accs
w = plateau_window
if len(tl) >= w and len(tel) >= w:
tl_flat = (max(tl[-w:]) - min(tl[-w:])) < plateau_loss_tol
tel_flat = (max(tel[-w:]) - min(tel[-w:])) < plateau_loss_tol
ta_flat = (not ta) or (max(ta[-w:]) - min(ta[-w:])) < plateau_acc_tol
tea_flat = (not tea) or (max(tea[-w:]) - min(tea[-w:])) < plateau_acc_tol
if tl_flat and tel_flat and ta_flat and tea_flat:
print(f" Plateau early stopping at epoch {epoch}: "
f"no change in last {w} epochs")
break
if config.is_it_time_to_save(epoch=epoch):
trainer.save_epoch(epoch=epoch, save_to_wandb=False, local_save=True)
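    # Final save via the Trainer's own mechanism; wandb logging stays disabled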
trainer.post_training_save(
save_optimizer_and_scheduler=False, log_to_wandb=False
)
# Save training curves as JSON for plot generation
curves = {
'train_losses': trainer.train_losses,
'test_losses': trainer.test_losses,
'train_accs': trainer.train_accs,
'test_accs': trainer.test_accs,
'grad_norms': trainer.grad_norms,
'param_norms': trainer.param_norms,
}
curves_path = os.path.join(output_dir, "training_curves.json")
with open(curves_path, 'w') as f:
json.dump(curves, f)
# Save a human-readable training log
_save_training_log(output_dir, p, run_name, run_params, d_mlp, curves)
# Write completion marker
with open(marker, 'w') as f:
f.write(f"p={p} run={run_name} completed\n")
print(f"[DONE] p={p}, run={run_name}, "
f"train_acc={trainer.train_accs[-1]:.4f}, "
f"test_acc={trainer.test_accs[-1]:.4f}")
def main():
parser = argparse.ArgumentParser(
description='Batch training for modular addition experiments'
)
parser.add_argument('--all', action='store_true',
help='Train all odd p in [3, 199]')
parser.add_argument('--p', type=int,
help='Train a specific odd modulus p')
parser.add_argument('--run', type=str, choices=list(TRAINING_RUNS.keys()),
help='Train a specific run type')
parser.add_argument('--output', type=str, default='./trained_models',
help='Output directory for trained models')
parser.add_argument('--d_mlp', type=int, default=None,
help='Override d_mlp (number of hidden neurons). '
'Default: auto-computed from p.')
parser.add_argument('--resume', action='store_true',
help='Skip already-completed runs (checks DONE marker)')
args = parser.parse_args()
if not args.all and args.p is None:
parser.error("Specify --all or --p P")
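    # A bare --p restricts to one modulus; otherwise get_moduli() is assumed
    # to yield every odd p in [3, 199], matching the --all help text.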
    moduli = [args.p] if args.p is not None else get_moduli()
runs = [args.run] if args.run else list(TRAINING_RUNS.keys())
total = len(moduli) * len(runs)
completed = 0
for p in moduli:
for run_name in runs:
completed += 1
print(f"\n{'='*60}")
print(f"[{completed}/{total}] p={p}, run={run_name}")
print(f"{'='*60}")
try:
                run_training(p, run_name, args.output,
                             d_mlp_override=args.d_mlp, resume=args.resume)
except Exception as e:
print(f"[FAIL] p={p}, run={run_name}: {e}")
import traceback
traceback.print_exc()
print(f"\nAll done. {completed} runs processed.")
if __name__ == "__main__":
main()