# Avra98's picture
# Initial code dump (rebuttal-ready snapshot)
# 76de008 verified
from __future__ import annotations
import argparse
import dataclasses
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import torch
# Values accepted for ``ExperimentConfig.model``.
VALID_MODELS = ("nocurr_nocot", "curr_nocot", "curr_cot")
# Values accepted for ``ExperimentConfig.preset``.
VALID_PRESETS = ("default", "smoke")


@dataclass
class ExperimentConfig:
    """All settings for a single experiment run.

    Construction validates ``model``, ``preset``, ``radix``, the digit limits,
    ``max_latent_steps`` and ``initial_stage`` and raises ``ValueError`` on the
    first violation. ``ood_lengths`` is then normalized to a tuple of ints that
    are strictly greater than ``train_max_digits``; if nothing survives, it
    becomes ``(eval_max_digits,)``.

    The sequence-length properties are all evaluated at ``eval_max_digits``;
    the ``*_for_digits`` methods give the same quantities for any digit count.
    """

    # Run identity and output location.
    model: str = "nocurr_nocot"
    output_dir: str = "addition_runs/default"
    seed: int = 0
    # Evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    preset: str = "default"
    run_name: str = ""
    notes: str = ""

    # wandb settings.
    use_wandb: bool = True
    wandb_project: str = "addition-carry"
    wandb_entity: str = ""
    wandb_mode: str = "online"

    # Number base and operand sizes.
    radix: int = 10
    train_max_digits: int = 12
    eval_max_digits: int = 20
    ood_lengths: tuple[int, ...] = (14, 16, 20)

    # Batch sizes and optimizer hyper-parameters.
    train_batch_size: int = 256
    eval_batch_size: int = 512
    learning_rate: float = 3e-4
    weight_decay: float = 1e-2
    grad_clip_norm: float = 1.0
    carry_loss_weight: float = 0.0

    # Training schedule and stage progression.
    train_steps: int = 3600
    max_steps_per_stage: int = 300
    validation_interval: int = 100
    stage_accuracy_threshold: float = 0.99
    initial_stage: int = 1

    # Evaluation and training data mix.
    eval_examples_per_length: int = 256
    carry_heavy_examples_per_length: int = 256
    train_carry_heavy_prob: float = 0.15

    # Network dimensions.
    d_model: int = 512
    n_heads: int = 1
    ff_dim: int = 2048
    dropout: float = 0.0
    max_latent_steps: int = 12

    # Probe and multi-seed comparison settings.
    attention_probe_examples: int = 256
    linear_probe_epochs: int = 150
    linear_probe_lr: float = 1e-2
    comparison_num_seeds: int = 5

    def __post_init__(self) -> None:
        """Validate the fields, then normalize ``ood_lengths``.

        Raises:
            ValueError: if ``model`` or ``preset`` is not one of the valid
                names, ``train_max_digits`` exceeds ``eval_max_digits``,
                ``max_latent_steps`` is negative, ``radix`` is outside
                2..16, or ``initial_stage`` is outside
                1..``train_max_digits``.
        """
        if self.model not in VALID_MODELS:
            raise ValueError(f"Unsupported model: {self.model}")
        if self.preset not in VALID_PRESETS:
            raise ValueError(f"Unsupported preset: {self.preset}")
        if self.train_max_digits > self.eval_max_digits:
            raise ValueError("train_max_digits must be <= eval_max_digits")
        if self.max_latent_steps < 0:
            raise ValueError("max_latent_steps must be non-negative")
        if self.radix < 2 or self.radix > 16:
            raise ValueError("radix must be between 2 and 16")
        if self.initial_stage < 1 or self.initial_stage > self.train_max_digits:
            raise ValueError("initial_stage must be between 1 and train_max_digits")

        # Keep only lengths strictly beyond the training range; an empty
        # result falls back to the single longest evaluation length.
        beyond_training = tuple(
            length
            for length in (int(raw) for raw in self.ood_lengths)
            if length > self.train_max_digits
        )
        self.ood_lengths = beyond_training or (self.eval_max_digits,)

    @property
    def uses_curriculum(self) -> bool:
        """True for the ``curr_nocot`` and ``curr_cot`` models."""
        curriculum_models = {"curr_nocot", "curr_cot"}
        return self.model in curriculum_models

    @property
    def uses_latent_cot(self) -> bool:
        """True only for the ``curr_cot`` model."""
        return self.model == "curr_cot"

    @property
    def discrete_vocab_size(self) -> int:
        """``radix + 2``."""
        return 2 + self.radix

    @property
    def digit_vocab_size(self) -> int:
        """Same value as ``radix``."""
        return self.radix

    @property
    def input_sequence_length(self) -> int:
        """Input length for ``eval_max_digits`` digits."""
        return self.input_sequence_length_for_digits(self.eval_max_digits)

    @property
    def output_sequence_length(self) -> int:
        """Output length for ``eval_max_digits`` digits."""
        return self.output_sequence_length_for_digits(self.eval_max_digits)

    @property
    def base_sequence_length(self) -> int:
        """Input plus output length for ``eval_max_digits`` digits."""
        return self.base_sequence_length_for_digits(self.eval_max_digits)

    @property
    def max_sequence_length(self) -> int:
        """``base_sequence_length`` plus ``max_latent_steps``."""
        return self.max_latent_steps + self.base_sequence_length

    @property
    def effective_run_name(self) -> str:
        """``run_name`` when non-empty, otherwise a name built from model, radix and seed."""
        return self.run_name or f"{self.model}_base{self.radix}_seed{self.seed}"

    def input_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return ``2 * active_digits + 2``."""
        digits = int(active_digits)
        return 2 * digits + 2

    def output_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return ``active_digits + 1``."""
        return 1 + int(active_digits)

    def base_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return the input length plus the output length for ``active_digits``."""
        input_length = self.input_sequence_length_for_digits(active_digits)
        output_length = self.output_sequence_length_for_digits(active_digits)
        return input_length + output_length

    def latent_steps_for_stage(self, stage: int) -> int:
        """Return ``stage`` clamped to ``[0, max_latent_steps]``, or 0 without latent CoT."""
        if not self.uses_latent_cot:
            return 0
        capped = min(int(stage), int(self.max_latent_steps))
        return max(0, capped)
def default_output_root() -> Path:
    """Return the relative ``addition_runs`` directory used as the default output root."""
    root = Path("addition_runs")
    return root
def apply_preset(config: ExperimentConfig) -> ExperimentConfig:
    """Return a copy of ``config`` with the overrides of its preset applied.

    The argument is never mutated; ``dataclasses.replace`` builds a fresh
    instance first. Only the ``"smoke"`` preset changes anything: it assigns
    fixed, smaller values to the fields listed below regardless of what the
    caller had set. Any other preset returns the unmodified copy.
    """
    updated = dataclasses.replace(config)
    if updated.preset != "smoke":
        return updated

    # NOTE(review): build_config_from_args always supplies a non-empty
    # output_dir, so this fallback only applies to configs built by hand
    # with output_dir="" -- confirm that is intended.
    if not updated.output_dir:
        updated.output_dir = str(default_output_root() / "smoke")

    smoke_overrides = {
        "train_batch_size": 64,
        "eval_batch_size": 128,
        "d_model": 128,
        "n_heads": 1,
        "ff_dim": 512,
        "train_steps": 180,
        "max_steps_per_stage": 40,
        "validation_interval": 20,
        "eval_examples_per_length": 64,
        "carry_heavy_examples_per_length": 64,
        "attention_probe_examples": 64,
        "linear_probe_epochs": 60,
        "comparison_num_seeds": 2,
    }
    for field_name, value in smoke_overrides.items():
        setattr(updated, field_name, value)
    return updated
def config_to_dict(config: ExperimentConfig) -> dict:
    """Return the config as a plain dict: every field plus selected derived values.

    ``ood_lengths`` is stored as a list rather than a tuple, and the values of
    the derived properties named below are added under their own names.
    """
    payload = dataclasses.asdict(config)
    payload["ood_lengths"] = list(config.ood_lengths)

    derived_names = (
        "uses_curriculum",
        "uses_latent_cot",
        "discrete_vocab_size",
        "input_sequence_length",
        "output_sequence_length",
        "base_sequence_length",
        "max_sequence_length",
        "effective_run_name",
    )
    for name in derived_names:
        payload[name] = getattr(config, name)
    return payload
def save_config(config: ExperimentConfig, output_dir: Path | str) -> None:
    """Write the config as ``config.json`` inside ``output_dir``.

    The file holds the output of ``config_to_dict`` as UTF-8 JSON, indented
    by two spaces with keys sorted.

    Args:
        config: Configuration to serialize.
        output_dir: Target directory, created (with parents) if missing.
            Both ``Path`` and plain string paths are accepted, so
            ``config.output_dir`` can be passed directly.

    Raises:
        TypeError, ValueError: if the config cannot be encoded as JSON. This
            is raised before anything is created or written.
        OSError: if the directory or file cannot be created.
    """
    target_dir = Path(output_dir)
    # Serialize up front so an encoding failure cannot leave a truncated
    # config.json (or clobber an existing one).
    serialized = json.dumps(config_to_dict(config), indent=2, sort_keys=True)
    target_dir.mkdir(parents=True, exist_ok=True)
    (target_dir / "config.json").write_text(serialized, encoding="utf-8")
def add_config_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the experiment's command-line options on ``parser``.

    Options are added in a fixed order. ``--output_dir`` defaults to an empty
    string (``build_config_from_args`` substitutes a derived directory in
    that case). ``--use_wandb`` and ``--no_wandb`` are independent
    ``store_true`` flags; how they combine is decided by
    ``build_config_from_args``.
    """

    def add_scalars(specs) -> None:
        # Each spec is (flag, type, default) for a plain single-value option.
        for flag, caster, fallback in specs:
            parser.add_argument(flag, type=caster, default=fallback)

    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser.add_argument("--model", choices=VALID_MODELS, default="nocurr_nocot")
    add_scalars(
        [
            ("--output_dir", str, ""),
            ("--seed", int, 0),
            ("--device", str, default_device),
        ]
    )
    parser.add_argument("--preset", choices=VALID_PRESETS, default="default")
    add_scalars(
        [
            ("--run_name", str, ""),
            ("--notes", str, ""),
        ]
    )
    for switch in ("--use_wandb", "--no_wandb"):
        parser.add_argument(switch, action="store_true")
    add_scalars(
        [
            ("--wandb_project", str, "addition-carry"),
            ("--wandb_entity", str, ""),
        ]
    )
    parser.add_argument(
        "--wandb_mode",
        type=str,
        default="online",
        choices=("online", "offline", "disabled"),
    )
    add_scalars(
        [
            ("--radix", int, 10),
            ("--train_max_digits", int, 12),
            ("--eval_max_digits", int, 20),
        ]
    )
    parser.add_argument("--ood_lengths", type=int, nargs="*", default=[14, 16, 20])
    add_scalars(
        [
            ("--train_batch_size", int, 256),
            ("--eval_batch_size", int, 512),
            ("--learning_rate", float, 3e-4),
            ("--weight_decay", float, 1e-2),
            ("--grad_clip_norm", float, 1.0),
            ("--carry_loss_weight", float, 0.0),
            ("--train_steps", int, 3600),
            ("--max_steps_per_stage", int, 300),
            ("--validation_interval", int, 100),
            ("--stage_accuracy_threshold", float, 0.99),
            ("--initial_stage", int, 1),
            ("--eval_examples_per_length", int, 256),
            ("--carry_heavy_examples_per_length", int, 256),
            ("--train_carry_heavy_prob", float, 0.15),
            ("--d_model", int, 512),
            ("--n_heads", int, 1),
            ("--ff_dim", int, 2048),
            ("--dropout", float, 0.0),
            ("--max_latent_steps", int, 12),
            ("--attention_probe_examples", int, 256),
            ("--linear_probe_epochs", int, 150),
            ("--linear_probe_lr", float, 1e-2),
            ("--comparison_num_seeds", int, 5),
        ]
    )
def build_config_from_args(args: argparse.Namespace) -> ExperimentConfig:
    """Build an ``ExperimentConfig`` from parsed arguments and apply its preset.

    Three fields are derived rather than copied:

    * ``use_wandb`` is on unless ``--no_wandb`` was given without
      ``--use_wandb``, and is always off when ``wandb_mode`` is ``"disabled"``.
    * ``output_dir`` falls back to
      ``<default_output_root>/<model>_base<radix>_seed<seed>`` when empty.
    * ``ood_lengths`` is converted to a tuple.

    Every other field is taken from the attribute of the same name on
    ``args``. The result is passed through ``apply_preset``, so the "smoke"
    preset still replaces several of the values given here.
    """
    wandb_requested = bool(args.use_wandb or not args.no_wandb)
    wandb_enabled = wandb_requested and args.wandb_mode != "disabled"

    if args.output_dir:
        run_dir = args.output_dir
    else:
        run_dir = str(default_output_root() / f"{args.model}_base{args.radix}_seed{args.seed}")

    # Fields copied unchanged from the namespace attribute of the same name.
    copied_verbatim = (
        "model",
        "seed",
        "device",
        "preset",
        "run_name",
        "notes",
        "wandb_project",
        "wandb_entity",
        "wandb_mode",
        "radix",
        "train_max_digits",
        "eval_max_digits",
        "train_batch_size",
        "eval_batch_size",
        "learning_rate",
        "weight_decay",
        "grad_clip_norm",
        "carry_loss_weight",
        "train_steps",
        "max_steps_per_stage",
        "validation_interval",
        "stage_accuracy_threshold",
        "initial_stage",
        "eval_examples_per_length",
        "carry_heavy_examples_per_length",
        "train_carry_heavy_prob",
        "d_model",
        "n_heads",
        "ff_dim",
        "dropout",
        "max_latent_steps",
        "attention_probe_examples",
        "linear_probe_epochs",
        "linear_probe_lr",
        "comparison_num_seeds",
    )
    settings = {name: getattr(args, name) for name in copied_verbatim}

    config = ExperimentConfig(
        output_dir=run_dir,
        use_wandb=wandb_enabled,
        ood_lengths=tuple(args.ood_lengths),
        **settings,
    )
    return apply_preset(config)
def build_arg_parser(description: str) -> argparse.ArgumentParser:
    """Create an ``ArgumentParser`` with ``description`` and all config options registered."""
    configured = argparse.ArgumentParser(description=description)
    add_config_arguments(configured)
    return configured
def parse_config(description: str, argv: list[str] | None = None) -> ExperimentConfig:
    """Parse command-line options and return the resulting config.

    Args:
        description: Text shown in the parser's ``--help`` output.
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the previous behavior;
            passing a list allows the function to be driven from code
            without touching the process arguments.

    Returns:
        The config built by ``build_config_from_args``, preset applied.
    """
    parser = build_arg_parser(description)
    parsed = parser.parse_args(argv)
    return build_config_from_args(parsed)
def ensure_output_dirs(config: ExperimentConfig) -> dict[str, Path]:
    """Create the run's output directory tree and return its paths by name.

    The returned dict has the keys ``root``, ``checkpoints``,
    ``stage_checkpoints``, ``plots`` and ``artifacts``, all located under
    ``config.output_dir``. Directories that already exist are left alone.
    """
    base = Path(config.output_dir)
    # (key, path components below the root); an empty tuple is the root itself.
    relative_layout = (
        ("root", ()),
        ("checkpoints", ("checkpoints",)),
        ("stage_checkpoints", ("checkpoints", "stages")),
        ("plots", ("plots",)),
        ("artifacts", ("artifacts",)),
    )
    directories = {label: base.joinpath(*parts) for label, parts in relative_layout}
    for path in directories.values():
        path.mkdir(parents=True, exist_ok=True)
    return directories
def flatten_metric_dict(prefix: str, metrics: dict[str, float | int | str]) -> dict[str, float | int | str]:
    """Return a new dict with ``prefix`` prepended to every key of ``metrics``.

    Values and iteration order are carried over unchanged; ``metrics`` itself
    is not modified.
    """
    prefixed: dict[str, float | int | str] = {}
    for name, value in metrics.items():
        prefixed[f"{prefix}{name}"] = value
    return prefixed
def iter_stage_lengths(config: ExperimentConfig) -> Iterable[int]:
    """Yield the integers 1 through ``config.train_max_digits`` in ascending order."""
    yield from range(1, config.train_max_digits + 1)