| from __future__ import annotations |
|
|
| import argparse |
| import dataclasses |
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Iterable |
|
|
| import torch |
|
|
|
|
# Model identifiers accepted by ExperimentConfig (checked in __post_init__)
# and offered as the ``--model`` command-line choices.
VALID_MODELS = (
    "nocurr_nocot",
    "curr_nocot",
    "curr_cot",
)

# Preset names accepted by ExperimentConfig and by the ``--preset`` option.
VALID_PRESETS = (
    "default",
    "smoke",
)
|
|
|
|
@dataclass
class ExperimentConfig:
    """Settings for a single experiment run, validated on construction.

    ``__post_init__`` raises ``ValueError`` for an unknown ``model`` or
    ``preset`` and for inconsistent digit, radix, latent-step or stage
    settings, and then normalises ``ood_lengths``.  The properties and the
    ``*_for_digits`` helpers derive vocabulary sizes and sequence lengths
    from the stored fields.
    """

    model: str = "nocurr_nocot"
    output_dir: str = "addition_runs/default"
    seed: int = 0
    # Evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    preset: str = "default"
    run_name: str = ""
    notes: str = ""
    use_wandb: bool = True
    wandb_project: str = "addition-carry"
    wandb_entity: str = ""
    wandb_mode: str = "online"
    radix: int = 10
    train_max_digits: int = 12
    eval_max_digits: int = 20
    # Reduced in __post_init__ to the entries above train_max_digits.
    ood_lengths: tuple[int, ...] = (14, 16, 20)
    train_batch_size: int = 256
    eval_batch_size: int = 512
    learning_rate: float = 3e-4
    weight_decay: float = 1e-2
    grad_clip_norm: float = 1.0
    carry_loss_weight: float = 0.0
    train_steps: int = 3600
    max_steps_per_stage: int = 300
    validation_interval: int = 100
    stage_accuracy_threshold: float = 0.99
    initial_stage: int = 1
    eval_examples_per_length: int = 256
    carry_heavy_examples_per_length: int = 256
    train_carry_heavy_prob: float = 0.15
    d_model: int = 512
    n_heads: int = 1
    ff_dim: int = 2048
    dropout: float = 0.0
    # Caps latent_steps_for_stage and is added on top of base_sequence_length.
    max_latent_steps: int = 12
    attention_probe_examples: int = 256
    linear_probe_epochs: int = 150
    linear_probe_lr: float = 1e-2
    comparison_num_seeds: int = 5

    def __post_init__(self) -> None:
        """Validate the fields and normalise ``ood_lengths``.

        Raises:
            ValueError: for an unsupported model or preset, when
                ``train_max_digits`` exceeds ``eval_max_digits``, for a
                negative ``max_latent_steps``, for a radix outside 2..16, or
                for an ``initial_stage`` outside 1..``train_max_digits``.
        """
        # The checks run strictly in this order; the first failure wins.
        if self.model not in VALID_MODELS:
            raise ValueError(f"Unsupported model: {self.model}")
        if self.preset not in VALID_PRESETS:
            raise ValueError(f"Unsupported preset: {self.preset}")
        if self.eval_max_digits < self.train_max_digits:
            raise ValueError("train_max_digits must be <= eval_max_digits")
        if self.max_latent_steps < 0:
            raise ValueError("max_latent_steps must be non-negative")
        if not (2 <= self.radix <= 16):
            raise ValueError("radix must be between 2 and 16")
        if not (1 <= self.initial_stage <= self.train_max_digits):
            raise ValueError("initial_stage must be between 1 and train_max_digits")

        # Keep only lengths strictly longer than anything seen in training.
        beyond_training = []
        for raw_length in self.ood_lengths:
            length = int(raw_length)
            if length > self.train_max_digits:
                beyond_training.append(length)
        # With nothing left, fall back to the single longest evaluated length.
        self.ood_lengths = tuple(beyond_training) or (self.eval_max_digits,)

    @property
    def uses_curriculum(self) -> bool:
        """True for the ``curr_nocot`` and ``curr_cot`` models."""
        return self.model in ("curr_nocot", "curr_cot")

    @property
    def uses_latent_cot(self) -> bool:
        """True only for the ``curr_cot`` model."""
        return self.model == "curr_cot"

    @property
    def discrete_vocab_size(self) -> int:
        """``radix`` plus two additional ids."""
        return self.radix + 2

    @property
    def digit_vocab_size(self) -> int:
        """Number of digit symbols, i.e. ``radix``."""
        return self.radix

    @property
    def input_sequence_length(self) -> int:
        """Input length at ``eval_max_digits``."""
        longest = self.eval_max_digits
        return self.input_sequence_length_for_digits(longest)

    @property
    def output_sequence_length(self) -> int:
        """Output length at ``eval_max_digits``."""
        longest = self.eval_max_digits
        return self.output_sequence_length_for_digits(longest)

    @property
    def base_sequence_length(self) -> int:
        """Input plus output length at ``eval_max_digits``."""
        longest = self.eval_max_digits
        return self.base_sequence_length_for_digits(longest)

    @property
    def max_sequence_length(self) -> int:
        """``base_sequence_length`` extended by ``max_latent_steps``."""
        return self.base_sequence_length + self.max_latent_steps

    @property
    def effective_run_name(self) -> str:
        """``run_name`` if non-empty, else a name built from model, radix and seed."""
        return self.run_name or f"{self.model}_base{self.radix}_seed{self.seed}"

    def input_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return two positions per digit plus two extra positions."""
        digit_positions = int(active_digits) * 2
        return digit_positions + 2

    def output_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return one position per digit plus one extra position."""
        return int(active_digits) + 1

    def base_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return the combined input and output length for ``active_digits``."""
        input_part = self.input_sequence_length_for_digits(active_digits)
        output_part = self.output_sequence_length_for_digits(active_digits)
        return input_part + output_part

    def latent_steps_for_stage(self, stage: int) -> int:
        """Return ``stage`` clamped to ``[0, max_latent_steps]``.

        Always 0 when the model does not use latent chain-of-thought.
        """
        if not self.uses_latent_cot:
            return 0
        capped = min(int(stage), int(self.max_latent_steps))
        return capped if capped > 0 else 0
|
|
|
|
def default_output_root() -> Path:
    """Return the relative ``addition_runs`` directory used for fallback output paths."""
    root = Path("addition_runs")
    return root
|
|
|
|
def apply_preset(config: ExperimentConfig) -> ExperimentConfig:
    """Return a copy of ``config`` with its preset's overrides applied.

    The input object is never modified.  The copy is made with
    ``dataclasses.replace``, so the dataclass constructor (and therefore its
    validation) runs again on the copy.  Only the ``"smoke"`` preset changes
    any values; every other preset yields a plain copy.
    """
    tuned = dataclasses.replace(config)
    if tuned.preset != "smoke":
        return tuned

    smoke_overrides = {
        # Keep an explicit output directory; only fill in an empty one.
        "output_dir": tuned.output_dir or str(default_output_root() / "smoke"),
        "train_batch_size": 64,
        "eval_batch_size": 128,
        "d_model": 128,
        "n_heads": 1,
        "ff_dim": 512,
        "train_steps": 180,
        "max_steps_per_stage": 40,
        "validation_interval": 20,
        "eval_examples_per_length": 64,
        "carry_heavy_examples_per_length": 64,
        "attention_probe_examples": 64,
        "linear_probe_epochs": 60,
        "comparison_num_seeds": 2,
    }
    # Plain attribute assignment: the overrides are not re-validated.
    for field_name, value in smoke_overrides.items():
        setattr(tuned, field_name, value)
    return tuned
|
|
|
|
def config_to_dict(config: ExperimentConfig) -> dict:
    """Convert ``config`` to a plain dict of its fields plus derived values.

    ``ood_lengths`` is stored as a list, and the values of the listed
    read-only properties are appended after the dataclass fields.
    """
    payload = dataclasses.asdict(config)
    payload["ood_lengths"] = list(config.ood_lengths)

    derived_names = (
        "uses_curriculum",
        "uses_latent_cot",
        "discrete_vocab_size",
        "input_sequence_length",
        "output_sequence_length",
        "base_sequence_length",
        "max_sequence_length",
        "effective_run_name",
    )
    for derived in derived_names:
        payload[derived] = getattr(config, derived)
    return payload
|
|
|
|
def save_config(config: ExperimentConfig, output_dir: Path | str) -> None:
    """Write ``config`` as ``config.json`` inside ``output_dir``.

    Args:
        config: Configuration to serialise via ``config_to_dict``.
        output_dir: Target directory, as a ``Path`` or a plain string (the
            config's own ``output_dir`` field is a string).  It is created,
            including parents, if it does not exist.

    The JSON is indented by two spaces with sorted keys and written as UTF-8.
    """
    target_dir = Path(output_dir)
    # Serialise before touching the file system so that a serialisation error
    # cannot leave a truncated config.json behind.
    text = json.dumps(config_to_dict(config), indent=2, sort_keys=True)
    target_dir.mkdir(parents=True, exist_ok=True)
    (target_dir / "config.json").write_text(text, encoding="utf-8")
|
|
|
|
def add_config_arguments(parser: argparse.ArgumentParser) -> None:
    """Register every experiment command-line option on ``parser``.

    Options are added in a fixed order.  ``--output_dir`` defaults to an
    empty string, and ``--use_wandb`` / ``--no_wandb`` are independent
    ``store_true`` switches.
    """
    add = parser.add_argument

    add("--model", choices=VALID_MODELS, default="nocurr_nocot")
    add("--output_dir", type=str, default="")
    add("--seed", type=int, default=0)
    add("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    add("--preset", choices=VALID_PRESETS, default="default")
    for text_flag in ("--run_name", "--notes"):
        add(text_flag, type=str, default="")
    for switch in ("--use_wandb", "--no_wandb"):
        add(switch, action="store_true")
    add("--wandb_project", type=str, default="addition-carry")
    add("--wandb_entity", type=str, default="")
    add("--wandb_mode", type=str, default="online", choices=("online", "offline", "disabled"))

    for digit_flag, digit_default in (
        ("--radix", 10),
        ("--train_max_digits", 12),
        ("--eval_max_digits", 20),
    ):
        add(digit_flag, type=int, default=digit_default)
    # The list literal is created on every call, so parsers never share it.
    add("--ood_lengths", type=int, nargs="*", default=[14, 16, 20])

    numeric_options = (
        ("--train_batch_size", int, 256),
        ("--eval_batch_size", int, 512),
        ("--learning_rate", float, 3e-4),
        ("--weight_decay", float, 1e-2),
        ("--grad_clip_norm", float, 1.0),
        ("--carry_loss_weight", float, 0.0),
        ("--train_steps", int, 3600),
        ("--max_steps_per_stage", int, 300),
        ("--validation_interval", int, 100),
        ("--stage_accuracy_threshold", float, 0.99),
        ("--initial_stage", int, 1),
        ("--eval_examples_per_length", int, 256),
        ("--carry_heavy_examples_per_length", int, 256),
        ("--train_carry_heavy_prob", float, 0.15),
        ("--d_model", int, 512),
        ("--n_heads", int, 1),
        ("--ff_dim", int, 2048),
        ("--dropout", float, 0.0),
        ("--max_latent_steps", int, 12),
        ("--attention_probe_examples", int, 256),
        ("--linear_probe_epochs", int, 150),
        ("--linear_probe_lr", float, 1e-2),
        ("--comparison_num_seeds", int, 5),
    )
    for flag, caster, default in numeric_options:
        add(flag, type=caster, default=default)
|
|
|
|
def build_config_from_args(args: argparse.Namespace) -> ExperimentConfig:
    """Build an ``ExperimentConfig`` from parsed arguments and apply its preset.

    wandb is enabled unless ``--no_wandb`` is given without ``--use_wandb``,
    and it is always disabled when ``wandb_mode`` is ``"disabled"``.  An
    empty ``output_dir`` is replaced by a directory under the default output
    root named after the model, radix and seed.
    """
    wandb_enabled = args.wandb_mode != "disabled" and bool(args.use_wandb or not args.no_wandb)

    resolved_output_dir = args.output_dir
    if not resolved_output_dir:
        run_folder = f"{args.model}_base{args.radix}_seed{args.seed}"
        resolved_output_dir = str(default_output_root() / run_folder)

    # Fields copied from the namespace unchanged.
    passthrough = (
        "model",
        "seed",
        "device",
        "preset",
        "run_name",
        "notes",
        "wandb_project",
        "wandb_entity",
        "wandb_mode",
        "radix",
        "train_max_digits",
        "eval_max_digits",
        "train_batch_size",
        "eval_batch_size",
        "learning_rate",
        "weight_decay",
        "grad_clip_norm",
        "carry_loss_weight",
        "train_steps",
        "max_steps_per_stage",
        "validation_interval",
        "stage_accuracy_threshold",
        "initial_stage",
        "eval_examples_per_length",
        "carry_heavy_examples_per_length",
        "train_carry_heavy_prob",
        "d_model",
        "n_heads",
        "ff_dim",
        "dropout",
        "max_latent_steps",
        "attention_probe_examples",
        "linear_probe_epochs",
        "linear_probe_lr",
        "comparison_num_seeds",
    )
    field_values = {name: getattr(args, name) for name in passthrough}
    field_values["output_dir"] = resolved_output_dir
    field_values["use_wandb"] = wandb_enabled
    field_values["ood_lengths"] = tuple(args.ood_lengths)

    return apply_preset(ExperimentConfig(**field_values))
|
|
|
|
def build_arg_parser(description: str) -> argparse.ArgumentParser:
    """Create an ``ArgumentParser`` with all experiment options registered."""
    cli = argparse.ArgumentParser(description=description)
    add_config_arguments(cli)
    return cli
|
|
|
|
def parse_config(description: str) -> ExperimentConfig:
    """Parse the process command line into an ``ExperimentConfig``."""
    namespace = build_arg_parser(description).parse_args()
    return build_config_from_args(namespace)
|
|
|
|
def ensure_output_dirs(config: ExperimentConfig) -> dict[str, Path]:
    """Create the run's output directory tree and return its paths by role.

    The returned mapping has the keys ``root``, ``checkpoints``,
    ``stage_checkpoints``, ``plots`` and ``artifacts``.  Directories that
    already exist are left as they are.
    """
    base = Path(config.output_dir)
    checkpoint_dir = base / "checkpoints"
    layout = {
        "root": base,
        "checkpoints": checkpoint_dir,
        "stage_checkpoints": checkpoint_dir / "stages",
        "plots": base / "plots",
        "artifacts": base / "artifacts",
    }
    for path in layout.values():
        path.mkdir(parents=True, exist_ok=True)
    return layout
|
|
|
|
def flatten_metric_dict(prefix: str, metrics: dict[str, float | int | str]) -> dict[str, float | int | str]:
    """Return a new dict in which every key of ``metrics`` is prefixed with ``prefix``.

    Values and key order are kept; ``metrics`` itself is not modified.
    """
    prefixed = {}
    for name, value in metrics.items():
        prefixed[f"{prefix}{name}"] = value
    return prefixed
|
|
|
|
def iter_stage_lengths(config: ExperimentConfig) -> Iterable[int]:
    """Yield the stage numbers 1 through ``config.train_max_digits`` inclusive."""
    yield from range(1, config.train_max_digits + 1)
|
|