from __future__ import annotations

import json
import random
import time
from pathlib import Path
from typing import Any

import torch
from torch import nn

from addition.config import ExperimentConfig, ensure_output_dirs, parse_config, save_config
from addition.data import build_batch, build_evaluation_suite, digits_to_string, exact_sum_matches, sample_training_batch, seed_everything
from addition.eval import evaluate_problem_set, evaluate_suite, flatten_nested_metrics
from addition.model import build_model, describe_model
from addition.plots import plot_single_run_results
|
|
|
|
def _maybe_init_wandb(config: ExperimentConfig, output_dir: Path):
    """Start a Weights & Biases run when the config requests one.

    Args:
        config: Experiment configuration; reads ``use_wandb``, ``wandb_mode``,
            ``wandb_project``, ``wandb_entity`` and ``effective_run_name``.
        output_dir: Directory handed to ``wandb.init`` as its ``dir``.

    Returns:
        The object returned by ``wandb.init``, or ``None`` when W&B logging is
        off (``use_wandb`` falsy or ``wandb_mode == "disabled"``) or when the
        ``wandb`` package cannot be imported.
    """
    wandb_requested = config.use_wandb and config.wandb_mode != "disabled"
    if not wandb_requested:
        return None

    try:
        import wandb
    except ImportError:
        # Optional dependency: fall back to local-only logging.
        print("wandb is not installed; continuing with local logging only.")
        return None

    init_kwargs = dict(
        project=config.wandb_project,
        entity=config.wandb_entity or None,  # empty string -> let wandb pick the default entity
        name=config.effective_run_name,
        mode=config.wandb_mode,
        config={"experiment": config.__dict__},
        dir=str(output_dir),
        reinit=True,
    )
    return wandb.init(**init_kwargs)
|
|
|
|
def _save_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as UTF-8 JSON, indented by 2 with sorted keys.

    Any existing file at *path* is overwritten.
    """
    with open(path, mode="w", encoding="utf-8") as stream:
        json.dump(payload, stream, sort_keys=True, indent=2)
|
|
|
|
def _save_checkpoint(path: Path, model: nn.Module, optimizer: torch.optim.Optimizer, metadata: dict[str, Any]) -> None:
    """Serialize model weights, optimizer state and *metadata* to *path*.

    The saved object is a dict with the keys ``"model_state"`` (the model's
    ``state_dict()``), ``"optimizer_state"`` (the optimizer's ``state_dict()``)
    and ``"metadata"`` (stored as given).
    """
    checkpoint: dict[str, Any] = {}
    checkpoint["model_state"] = model.state_dict()
    checkpoint["optimizer_state"] = optimizer.state_dict()
    checkpoint["metadata"] = metadata
    torch.save(obj=checkpoint, f=path)
|
|
|
|
def _stage_checkpoint_path(stage_directory: Path, stage: int) -> Path:
    """Return the checkpoint path for a passed curriculum *stage*.

    The stage number is zero-padded to two digits, e.g. stage 3 maps to
    ``stage_directory / "stage_03_passed.pt"``.
    """
    filename = f"stage_{stage:02d}_passed.pt"
    return stage_directory / filename
|
|
|
|
def _evaluate_current_stage(
    model: nn.Module,
    config: ExperimentConfig,
    suite,
    stage: int,
    device: str,
) -> dict[str, float]:
    """Score *model* on the uniform validation problems of the current *stage*.

    Runs ``evaluate_problem_set`` on ``suite.validation_uniform[stage]`` with
    ``active_digits=stage`` and no attention capture. Returns the metrics
    object's ``digit_accuracy``, ``final_carry_accuracy`` and ``exact_match``
    attributes as a plain dict, in that key order.
    """
    stage_problems = suite.validation_uniform[stage]
    metrics, _unused_attention = evaluate_problem_set(
        model=model,
        config=config,
        problems=stage_problems,
        active_digits=stage,
        device=device,
        return_attention=False,
    )
    reported = ("digit_accuracy", "final_carry_accuracy", "exact_match")
    return {metric_name: getattr(metrics, metric_name) for metric_name in reported}
|
|
|
|
def _masked_digit_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask: torch.Tensor,
    loss_fn: nn.Module,
) -> torch.Tensor:
    """Apply *loss_fn* only to the positions selected by *mask*.

    *mask* is used as a boolean selector over *logits* and *targets*
    (presumably shaped like *targets* -- TODO confirm against the batch
    builder). It is cast with ``.bool()`` first, so a 0/1 integer or float
    mask selects positions. Without the cast, an integer mask would be treated
    as integer indices and a float mask would be rejected by PyTorch. For a
    mask that is already boolean the cast is a no-op.

    Returns:
        ``loss_fn(logits[mask], targets[mask])``, or a zero scalar on the
        device/dtype of *logits* when no position is selected.
    """
    selector = mask.bool()
    masked_logits = logits[selector]
    masked_targets = targets[selector]
    if masked_logits.numel() == 0:
        # Nothing to score (all-False mask). Return a plain zero rather than
        # calling a mean-reduced loss on an empty selection, which would give NaN.
        return logits.new_zeros(())
    return loss_fn(masked_logits, masked_targets)
|
|
|
|
@torch.no_grad()
def _print_model_debug_format(
    model: nn.Module,
    config: ExperimentConfig,
    *,
    stage: int,
    rng,
    device: str,
) -> None:
    """Print the model architecture and the tensor shapes of one sample batch.

    Draws one batch via ``sample_training_batch`` using the supplied *rng*
    (which presumably advances that generator) and runs one gradient-free
    forward pass on it. It then prints ``model`` itself, followed by a single
    line listing the stage and the shapes of the batch inputs/targets and of
    the model outputs. The model's train/eval mode is not changed here.
    """
    sample = sample_training_batch(config=config, stage=stage, rng=rng, device=device)
    latent_steps = config.latent_steps_for_stage(stage)
    result = model(sample.input_ids, latent_steps=latent_steps, return_attention=False)

    print("[addition debug] model_architecture", flush=True)
    print(model, flush=True)

    labelled_tensors = (
        ("input_shape", sample.input_ids),
        ("target_digits_shape", sample.target_digits),
        ("target_mask_shape", sample.target_digit_mask),
        ("target_final_carry_shape", sample.target_final_carry),
        ("digit_logits_shape", result.digit_logits),
        ("final_carry_logits_shape", result.final_carry_logits),
        ("output_hidden_shape", result.output_hidden),
    )
    shape_fields = " ".join(f"{label}={tuple(tensor.shape)}" for label, tensor in labelled_tensors)
    print(f"[addition debug] batch_format stage={stage} " + shape_fields, flush=True)
|
|
|
|
@torch.no_grad()
def _print_validation_samples(
    model: nn.Module,
    config: ExperimentConfig,
    problems,
    *,
    stage: int,
    device: str,
    limit: int = 3,
) -> None:
    """Print up to *limit* problems from *problems* next to the model's predictions.

    Only the first ``stage`` digit positions of each problem are shown. The
    true final carry is ``problem.carry_out[stage - 1]``. Each printed line
    has the operands and the predicted/true sums rendered by
    ``digits_to_string``, the raw digit/carry values, and whether
    ``exact_sum_matches`` accepts the prediction.

    The forward pass always runs with the model in ``eval()`` mode. This keeps
    the printed predictions independent of whatever train/eval mode earlier
    code left the model in. The caller's mode is restored afterwards, even if
    the forward pass raises.
    """
    sample_problems = list(problems[:limit])
    if not sample_problems:
        return
    batch = build_batch(problems=sample_problems, radix=config.radix, device=device)

    was_training = model.training
    model.eval()
    try:
        outputs = model(batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False)
    finally:
        model.train(was_training)

    predicted_digits = outputs.digit_logits.argmax(dim=-1).cpu().tolist()
    predicted_final_carry = outputs.final_carry_logits.argmax(dim=-1).cpu().tolist()

    for example_index, problem in enumerate(sample_problems):
        truth_digits = problem.sum_digits[:stage]
        truth_final_carry = problem.carry_out[stage - 1]
        pred_digits = predicted_digits[example_index][:stage]
        pred_final_carry = int(predicted_final_carry[example_index])
        exact = exact_sum_matches(
            predicted_digits=pred_digits,
            predicted_final_carry=pred_final_carry,
            truth_digits=truth_digits,
            truth_final_carry=truth_final_carry,
        )
        a_text = digits_to_string(problem.a_digits[:stage], final_carry=0, radix=config.radix)
        b_text = digits_to_string(problem.b_digits[:stage], final_carry=0, radix=config.radix)
        pred_text = digits_to_string(pred_digits, final_carry=pred_final_carry, radix=config.radix)
        truth_text = digits_to_string(truth_digits, final_carry=truth_final_carry, radix=config.radix)
        print(
            f"[addition sample] stage={stage} idx={example_index} "
            f"a={a_text} b={b_text} pred={pred_text} true={truth_text} "
            f"pred_digits={pred_digits} pred_carry={pred_final_carry} "
            f"true_digits={truth_digits} true_carry={truth_final_carry} exact={int(exact)}",
            flush=True,
        )
|
|
|
|
def _is_new_best(
    stage: int,
    exact_match: float,
    best_stage: int | None,
    best_exact_match: float,
) -> bool:
    """Return True when a validation result should replace the best checkpoint.

    Validation exact match is measured on the problems of the current stage
    (``active_digits=stage``), so scores from different curriculum stages are
    not comparable. The first result ever (``best_stage is None``) and any
    result from a later stage always win. Within the same stage, a higher or
    equal exact match wins.
    """
    if best_stage is None or stage > best_stage:
        return True
    return stage == best_stage and exact_match >= best_exact_match


def _write_run_artifacts(
    summary: dict[str, Any],
    history: list[dict[str, Any]],
    artifacts_dir: Path,
    plots_dir: Path,
) -> None:
    """Persist the run summary, the validation history and the summary plots.

    Writes ``summary.json`` and ``history.jsonl`` into *artifacts_dir*.
    ``history.jsonl`` holds one key-sorted JSON object per line. The plots
    are then rendered into *plots_dir* via ``plot_single_run_results``.
    """
    _save_json(artifacts_dir / "summary.json", summary)
    with (artifacts_dir / "history.jsonl").open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(entry, sort_keys=True) + "\n" for entry in history)
    plot_single_run_results(summary, plots_dir)


def run_experiment(config: ExperimentConfig) -> dict[str, Any]:
    """Train one model on the addition task, then evaluate it on the full suite.

    Training modes:

    - With a curriculum (``config.uses_curriculum``), training starts at
      ``config.initial_stage`` digits. It advances one stage whenever the
      current stage's validation exact match reaches
      ``config.stage_accuracy_threshold``.
    - Without a curriculum, every step trains at ``config.train_max_digits``.

    Training stops after ``config.train_steps`` optimizer steps, or (with a
    curriculum) once the final stage passes the threshold.

    Checkpoints written:

    - ``checkpoints/best.pt``: the best validation result so far, chosen
      stage-aware by ``_is_new_best``.
    - ``checkpoints/last.pt``: the final weights.
    - One ``stage_XX_passed.pt`` per passed curriculum stage, under the
      ``stage_checkpoints`` directory.

    If this run saved a best checkpoint, it is reloaded before
    ``evaluate_suite`` runs. A W&B run, if one was started, is always
    finished, even when training or evaluation raises.

    Args:
        config: Fully parsed experiment configuration.

    Returns:
        The run summary dict. It contains the config, parameter counts, best
        validation exact match and its stage, step count, final stage, stop
        reason, elapsed seconds, validation history and final metrics. The
        same dict is also written to ``summary.json``.
    """
    directories = ensure_output_dirs(config)
    save_config(config, directories["root"])
    seed_everything(config.seed)
    device = config.device
    model = build_model(config, device=device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    digit_loss_fn = nn.CrossEntropyLoss()
    final_carry_loss_fn = nn.CrossEntropyLoss()
    suite = build_evaluation_suite(config)
    # Dedicated RNG for training-batch sampling, seeded at an offset from config.seed.
    rng = random.Random(config.seed + 12345)
    history: list[dict[str, Any]] = []
    best_validation = -1.0
    # Stage at which best.pt was last written by *this* run. None means no
    # best checkpoint exists yet; any best.pt on disk would be stale.
    best_stage: int | None = None
    best_checkpoint_path = directories["checkpoints"] / "best.pt"
    last_checkpoint_path = directories["checkpoints"] / "last.pt"
    stage = config.initial_stage if config.uses_curriculum else config.train_max_digits
    stage_steps = 0
    global_step = 0
    stop_reason = "train_steps_exhausted"
    wandb_run = _maybe_init_wandb(config, directories["root"])
    # Monotonic clock: elapsed time is immune to wall-clock adjustments.
    started_at = time.perf_counter()
    try:
        param_counts = describe_model(config)
        print(
            f"[addition train] model={config.model} seed={config.seed} device={device} "
            f"params={param_counts['total_params']} stage={stage}",
            flush=True,
        )
        _print_model_debug_format(model=model, config=config, stage=stage, rng=rng, device=device)

        train_log_interval = max(1, config.validation_interval // 2)
        while global_step < config.train_steps:
            model.train()
            latent_steps = config.latent_steps_for_stage(stage)
            batch = sample_training_batch(config=config, stage=stage, rng=rng, device=device)
            optimizer.zero_grad(set_to_none=True)
            outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=False)
            digit_loss = _masked_digit_loss(
                logits=outputs.digit_logits,
                targets=batch.target_digits,
                mask=batch.target_digit_mask,
                loss_fn=digit_loss_fn,
            )
            final_carry_loss = final_carry_loss_fn(outputs.final_carry_logits, batch.target_final_carry)
            loss = digit_loss + final_carry_loss
            loss.backward()
            if config.grad_clip_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)
            optimizer.step()

            global_step += 1
            stage_steps += 1

            if global_step % train_log_interval == 0:
                print(
                    f"[addition train] step={global_step} stage={stage} "
                    f"loss={loss.item():.4f} digit_loss={digit_loss.item():.4f} "
                    f"final_carry_loss={final_carry_loss.item():.4f}",
                    flush=True,
                )

            should_validate = (
                global_step % config.validation_interval == 0
                or global_step == config.train_steps
                or (config.uses_curriculum and stage_steps == config.max_steps_per_stage)
            )
            if not should_validate:
                continue

            validation = _evaluate_current_stage(model=model, config=config, suite=suite, stage=stage, device=device)
            # Read each scalar loss once: every .item() synchronizes with the device.
            loss_value = float(loss.item())
            digit_loss_value = float(digit_loss.item())
            final_carry_loss_value = float(final_carry_loss.item())
            exact_match = validation["exact_match"]
            history.append(
                {
                    "global_step": global_step,
                    "stage": stage,
                    "stage_steps": stage_steps,
                    "loss": loss_value,
                    "digit_loss": digit_loss_value,
                    "final_carry_loss": final_carry_loss_value,
                    "validation_digit_accuracy": validation["digit_accuracy"],
                    "validation_final_carry_accuracy": validation["final_carry_accuracy"],
                    "validation_exact_match": exact_match,
                    "latent_steps": latent_steps,
                }
            )
            print(
                f"[addition val] step={global_step} stage={stage} "
                f"digit_acc={validation['digit_accuracy']:.4f} final_carry_acc={validation['final_carry_accuracy']:.4f} "
                f"exact={validation['exact_match']:.4f}",
                flush=True,
            )
            _print_validation_samples(
                model=model,
                config=config,
                problems=suite.validation_uniform[stage],
                stage=stage,
                device=device,
            )
            if wandb_run is not None:
                wandb_run.log(
                    {
                        "train/loss": loss_value,
                        "train/digit_loss": digit_loss_value,
                        "train/final_carry_loss": final_carry_loss_value,
                        "train/stage": stage,
                        "train/latent_steps": latent_steps,
                        "validation/digit_accuracy": validation["digit_accuracy"],
                        "validation/final_carry_accuracy": validation["final_carry_accuracy"],
                        "validation/exact_match": exact_match,
                        "step": global_step,
                    }
                )

            # Stage-aware: an easy early stage's score must not block later stages.
            if _is_new_best(stage, exact_match, best_stage, best_validation):
                best_validation = exact_match
                best_stage = stage
                _save_checkpoint(
                    best_checkpoint_path,
                    model,
                    optimizer,
                    metadata={
                        "global_step": global_step,
                        "stage": stage,
                        "best_validation_exact_match": best_validation,
                    },
                )

            reached_threshold = exact_match >= config.stage_accuracy_threshold
            reached_cap = stage_steps >= config.max_steps_per_stage

            if config.uses_curriculum:
                if stage < config.train_max_digits and reached_threshold:
                    _save_checkpoint(
                        _stage_checkpoint_path(directories["stage_checkpoints"], stage),
                        model,
                        optimizer,
                        metadata={
                            "global_step": global_step,
                            "stage": stage,
                            "validation_exact_match": exact_match,
                            "validation_digit_accuracy": validation["digit_accuracy"],
                            "validation_final_carry_accuracy": validation["final_carry_accuracy"],
                        },
                    )
                    print(
                        f"[addition curriculum] advance {stage} -> {stage + 1} "
                        f"(exact_match={validation['exact_match']:.4f})",
                        flush=True,
                    )
                    stage += 1
                    stage_steps = 0
                    continue
                if reached_cap and not reached_threshold:
                    # Logged only: the stage is held, not abandoned.
                    print(
                        f"[addition curriculum] hold stage={stage} after {stage_steps} steps "
                        f"(exact_match={validation['exact_match']:.4f} < threshold={config.stage_accuracy_threshold:.2f})",
                        flush=True,
                    )
                if stage == config.train_max_digits and reached_threshold:
                    stop_reason = "final_stage_threshold"
                    break

        _save_checkpoint(
            last_checkpoint_path,
            model,
            optimizer,
            metadata={
                "global_step": global_step,
                "stage": stage,
                "stop_reason": stop_reason,
            },
        )

        # Only reload a checkpoint this run actually wrote. Otherwise (e.g.
        # train_steps <= 0) best.pt is missing, or stale from an earlier run,
        # and the current weights are evaluated instead.
        if best_stage is not None:
            best_payload = torch.load(best_checkpoint_path, map_location=device)
            model.load_state_dict(best_payload["model_state"])
        final_results = evaluate_suite(model=model, config=config, suite=suite, device=device)
        flat_final_metrics = flatten_nested_metrics("", final_results)
        summary = {
            "config": config.__dict__,
            "param_counts": param_counts,
            "best_validation_exact_match": best_validation,
            "best_validation_stage": best_stage,
            "global_step": global_step,
            "final_stage": stage,
            "stop_reason": stop_reason,
            "elapsed_seconds": time.perf_counter() - started_at,
            "history": history,
            "final_results": final_results,
            "flat_final_metrics": flat_final_metrics,
        }
        _write_run_artifacts(summary, history, directories["artifacts"], directories["plots"])

        if wandb_run is not None:
            wandb_run.log(flat_final_metrics | {"step": global_step})
            wandb_run.summary.update(
                {
                    "best_validation_exact_match": best_validation,
                    "best_validation_stage": best_stage,
                    "final_stage": stage,
                    "stop_reason": stop_reason,
                }
            )
        return summary
    finally:
        # Close the W&B run even when training or evaluation raises.
        if wandb_run is not None:
            wandb_run.finish()
|
|
|
|
def main() -> None:
    """Command-line entry point: build the config from CLI arguments and run one experiment.

    The summary returned by ``run_experiment`` is discarded here.
    """
    run_experiment(parse_config("Train the addition carry experiment."))


if __name__ == "__main__":
    main()
|
|