from __future__ import annotations import json import time from pathlib import Path from typing import Any import torch from torch import nn from addition.config import ExperimentConfig, ensure_output_dirs, parse_config, save_config from addition.data import build_batch, build_evaluation_suite, digits_to_string, exact_sum_matches, sample_training_batch, seed_everything from addition.eval import evaluate_problem_set, evaluate_suite, flatten_nested_metrics from addition.model import build_model, describe_model from addition.plots import plot_single_run_results def _maybe_init_wandb(config: ExperimentConfig, output_dir: Path): if not config.use_wandb or config.wandb_mode == "disabled": return None try: import wandb except ImportError: print("wandb is not installed; continuing with local logging only.") return None run = wandb.init( project=config.wandb_project, entity=config.wandb_entity or None, name=config.effective_run_name, mode=config.wandb_mode, config={"experiment": config.__dict__}, dir=str(output_dir), reinit=True, ) return run def _save_json(path: Path, payload: dict[str, Any]) -> None: with path.open("w", encoding="utf-8") as handle: json.dump(payload, handle, indent=2, sort_keys=True) def _save_checkpoint(path: Path, model: nn.Module, optimizer: torch.optim.Optimizer, metadata: dict[str, Any]) -> None: torch.save( { "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "metadata": metadata, }, path, ) def _stage_checkpoint_path(stage_directory: Path, stage: int) -> Path: return stage_directory / f"stage_{stage:02d}_passed.pt" def _evaluate_current_stage( model: nn.Module, config: ExperimentConfig, suite, stage: int, device: str, ) -> dict[str, float]: stage_metrics, _ = evaluate_problem_set( model=model, config=config, problems=suite.validation_uniform[stage], active_digits=stage, device=device, return_attention=False, ) return { "digit_accuracy": stage_metrics.digit_accuracy, "final_carry_accuracy": stage_metrics.final_carry_accuracy, "exact_match": stage_metrics.exact_match, } def _masked_digit_loss( logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor, loss_fn: nn.Module, ) -> torch.Tensor: masked_logits = logits[mask] masked_targets = targets[mask] if masked_logits.numel() == 0: return logits.new_zeros(()) return loss_fn(masked_logits, masked_targets) @torch.no_grad() def _print_model_debug_format( model: nn.Module, config: ExperimentConfig, *, stage: int, rng, device: str, ) -> None: debug_batch = sample_training_batch(config=config, stage=stage, rng=rng, device=device) outputs = model(debug_batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False) print("[addition debug] model_architecture", flush=True) print(model, flush=True) print( "[addition debug] batch_format " f"stage={stage} input_shape={tuple(debug_batch.input_ids.shape)} " f"target_digits_shape={tuple(debug_batch.target_digits.shape)} " f"target_mask_shape={tuple(debug_batch.target_digit_mask.shape)} " f"target_final_carry_shape={tuple(debug_batch.target_final_carry.shape)} " f"digit_logits_shape={tuple(outputs.digit_logits.shape)} " f"final_carry_logits_shape={tuple(outputs.final_carry_logits.shape)} " f"output_hidden_shape={tuple(outputs.output_hidden.shape)}", flush=True, ) @torch.no_grad() def _print_validation_samples( model: nn.Module, config: ExperimentConfig, problems, *, stage: int, device: str, limit: int = 3, ) -> None: sample_problems = list(problems[:limit]) if not sample_problems: return batch = build_batch(problems=sample_problems, radix=config.radix, device=device) outputs = model(batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False) predicted_digits = outputs.digit_logits.argmax(dim=-1).cpu().tolist() predicted_final_carry = outputs.final_carry_logits.argmax(dim=-1).cpu().tolist() for example_index, problem in enumerate(sample_problems): truth_digits = problem.sum_digits[:stage] truth_final_carry = problem.carry_out[stage - 1] pred_digits = predicted_digits[example_index][:stage] pred_final_carry = int(predicted_final_carry[example_index]) exact = exact_sum_matches( predicted_digits=pred_digits, predicted_final_carry=pred_final_carry, truth_digits=truth_digits, truth_final_carry=truth_final_carry, ) a_text = digits_to_string(problem.a_digits[:stage], final_carry=0, radix=config.radix) b_text = digits_to_string(problem.b_digits[:stage], final_carry=0, radix=config.radix) pred_text = digits_to_string(pred_digits, final_carry=pred_final_carry, radix=config.radix) truth_text = digits_to_string(truth_digits, final_carry=truth_final_carry, radix=config.radix) print( f"[addition sample] stage={stage} idx={example_index} " f"a={a_text} b={b_text} pred={pred_text} true={truth_text} " f"pred_digits={pred_digits} pred_carry={pred_final_carry} " f"true_digits={truth_digits} true_carry={truth_final_carry} exact={int(exact)}", flush=True, ) def run_experiment(config: ExperimentConfig) -> dict[str, Any]: directories = ensure_output_dirs(config) save_config(config, directories["root"]) seed_everything(config.seed) device = config.device model = build_model(config, device=device) optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay) digit_loss_fn = nn.CrossEntropyLoss() final_carry_loss_fn = nn.CrossEntropyLoss() suite = build_evaluation_suite(config) rng = __import__("random").Random(config.seed + 12345) history: list[dict[str, Any]] = [] best_validation = -1.0 best_checkpoint_path = directories["checkpoints"] / "best.pt" last_checkpoint_path = directories["checkpoints"] / "last.pt" stage = config.initial_stage if config.uses_curriculum else config.train_max_digits stage_steps = 0 global_step = 0 stop_reason = "train_steps_exhausted" wandb_run = _maybe_init_wandb(config, directories["root"]) started_at = time.time() param_counts = describe_model(config) print( f"[addition train] model={config.model} seed={config.seed} device={device} " f"params={param_counts['total_params']} stage={stage}", flush=True, ) _print_model_debug_format(model=model, config=config, stage=stage, rng=rng, device=device) while global_step < config.train_steps: model.train() batch = sample_training_batch(config=config, stage=stage, rng=rng, device=device) optimizer.zero_grad(set_to_none=True) outputs = model(batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False) digit_loss = _masked_digit_loss( logits=outputs.digit_logits, targets=batch.target_digits, mask=batch.target_digit_mask, loss_fn=digit_loss_fn, ) final_carry_loss = final_carry_loss_fn(outputs.final_carry_logits, batch.target_final_carry) loss = digit_loss + final_carry_loss loss.backward() if config.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm) optimizer.step() global_step += 1 stage_steps += 1 if global_step % max(1, config.validation_interval // 2) == 0: train_message = ( f"[addition train] step={global_step} stage={stage} " f"loss={loss.item():.4f} digit_loss={digit_loss.item():.4f} " f"final_carry_loss={final_carry_loss.item():.4f}" ) print(train_message, flush=True) should_validate = ( global_step % config.validation_interval == 0 or global_step == config.train_steps or (config.uses_curriculum and stage_steps == config.max_steps_per_stage) ) if not should_validate: continue validation = _evaluate_current_stage(model=model, config=config, suite=suite, stage=stage, device=device) history_entry = { "global_step": global_step, "stage": stage, "stage_steps": stage_steps, "loss": float(loss.item()), "digit_loss": float(digit_loss.item()), "final_carry_loss": float(final_carry_loss.item()), "validation_digit_accuracy": validation["digit_accuracy"], "validation_final_carry_accuracy": validation["final_carry_accuracy"], "validation_exact_match": validation["exact_match"], "latent_steps": config.latent_steps_for_stage(stage), } history.append(history_entry) print( f"[addition val] step={global_step} stage={stage} " f"digit_acc={validation['digit_accuracy']:.4f} final_carry_acc={validation['final_carry_accuracy']:.4f} " f"exact={validation['exact_match']:.4f}", flush=True, ) _print_validation_samples( model=model, config=config, problems=suite.validation_uniform[stage], stage=stage, device=device, ) if wandb_run is not None: payload = { "train/loss": float(loss.item()), "train/digit_loss": float(digit_loss.item()), "train/final_carry_loss": float(final_carry_loss.item()), "train/stage": stage, "train/latent_steps": config.latent_steps_for_stage(stage), "validation/digit_accuracy": validation["digit_accuracy"], "validation/final_carry_accuracy": validation["final_carry_accuracy"], "validation/exact_match": validation["exact_match"], "step": global_step, } wandb_run.log(payload) if validation["exact_match"] >= best_validation: best_validation = validation["exact_match"] _save_checkpoint( best_checkpoint_path, model, optimizer, metadata={ "global_step": global_step, "stage": stage, "best_validation_exact_match": best_validation, }, ) reached_threshold = validation["exact_match"] >= config.stage_accuracy_threshold reached_cap = stage_steps >= config.max_steps_per_stage if config.uses_curriculum: if stage < config.train_max_digits and reached_threshold: _save_checkpoint( _stage_checkpoint_path(directories["stage_checkpoints"], stage), model, optimizer, metadata={ "global_step": global_step, "stage": stage, "validation_exact_match": validation["exact_match"], "validation_digit_accuracy": validation["digit_accuracy"], "validation_final_carry_accuracy": validation["final_carry_accuracy"], }, ) print( f"[addition curriculum] advance {stage} -> {stage + 1} " f"(exact_match={validation['exact_match']:.4f})", flush=True, ) stage += 1 stage_steps = 0 continue if reached_cap and not reached_threshold: print( f"[addition curriculum] hold stage={stage} after {stage_steps} steps " f"(exact_match={validation['exact_match']:.4f} < threshold={config.stage_accuracy_threshold:.2f})", flush=True, ) if stage == config.train_max_digits and reached_threshold: stop_reason = "final_stage_threshold" break _save_checkpoint( last_checkpoint_path, model, optimizer, metadata={ "global_step": global_step, "stage": stage, "stop_reason": stop_reason, }, ) best_payload = torch.load(best_checkpoint_path, map_location=device) model.load_state_dict(best_payload["model_state"]) final_results = evaluate_suite(model=model, config=config, suite=suite, device=device) flat_final_metrics = flatten_nested_metrics("", final_results) summary = { "config": config.__dict__, "param_counts": param_counts, "best_validation_exact_match": best_validation, "global_step": global_step, "final_stage": stage, "stop_reason": stop_reason, "elapsed_seconds": time.time() - started_at, "history": history, "final_results": final_results, "flat_final_metrics": flat_final_metrics, } _save_json(directories["artifacts"] / "summary.json", summary) with (directories["artifacts"] / "history.jsonl").open("w", encoding="utf-8") as handle: for entry in history: handle.write(json.dumps(entry, sort_keys=True) + "\n") plot_single_run_results(summary, directories["plots"]) if wandb_run is not None: wandb_run.log(flat_final_metrics | {"step": global_step}) wandb_run.summary.update( { "best_validation_exact_match": best_validation, "final_stage": stage, "stop_reason": stop_reason, } ) wandb_run.finish() return summary def main() -> None: config = parse_config("Train the addition carry experiment.") run_experiment(config) if __name__ == "__main__": main()