File size: 11,891 Bytes
76de008
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
from __future__ import annotations

import argparse
import dataclasses
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import torch


# Supported model variants. The "curr_" variants train with a stage
# curriculum; only "curr_cot" additionally uses latent chain-of-thought steps.
VALID_MODELS: tuple[str, ...] = (
    "nocurr_nocot",
    "curr_nocot",
    "curr_cot",
)
# Named hyperparameter presets understood by apply_preset().
VALID_PRESETS: tuple[str, ...] = (
    "default",
    "smoke",
)


@dataclass
class ExperimentConfig:
    """Hyperparameters and bookkeeping for one multi-digit addition experiment.

    Instances are validated in ``__post_init__`` (which runs again on every
    ``dataclasses.replace``), and ``ood_lengths`` is normalized there.
    Derived sizes (vocabulary, sequence lengths, run name) are exposed as
    read-only properties so they always agree with the underlying fields.
    """

    # Run identity, placement and logging.
    model: str = "nocurr_nocot"  # must be one of VALID_MODELS
    output_dir: str = "addition_runs/default"
    seed: int = 0
    # Evaluated once, when the class body executes (i.e. at import time).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    preset: str = "default"  # must be one of VALID_PRESETS
    run_name: str = ""  # empty -> effective_run_name synthesizes one
    notes: str = ""
    use_wandb: bool = True
    wandb_project: str = "addition-carry"
    wandb_entity: str = ""
    wandb_mode: str = "online"
    # Task definition and digit-length ranges.
    radix: int = 10
    train_max_digits: int = 12
    eval_max_digits: int = 20
    ood_lengths: tuple[int, ...] = (14, 16, 20)
    # Optimization.
    train_batch_size: int = 256
    eval_batch_size: int = 512
    learning_rate: float = 3e-4
    weight_decay: float = 1e-2
    grad_clip_norm: float = 1.0
    carry_loss_weight: float = 0.0
    # Training schedule, curriculum and evaluation sizes.
    train_steps: int = 3600
    max_steps_per_stage: int = 300
    validation_interval: int = 100
    stage_accuracy_threshold: float = 0.99
    initial_stage: int = 1
    eval_examples_per_length: int = 256
    carry_heavy_examples_per_length: int = 256
    train_carry_heavy_prob: float = 0.15
    # Architecture.
    d_model: int = 512
    n_heads: int = 1
    ff_dim: int = 2048
    dropout: float = 0.0
    max_latent_steps: int = 12
    # Analysis / probing.
    attention_probe_examples: int = 256
    linear_probe_epochs: int = 150
    linear_probe_lr: float = 1e-2
    comparison_num_seeds: int = 5

    def __post_init__(self) -> None:
        """Validate the configuration and normalize ``ood_lengths``.

        Raises:
            ValueError: if ``model`` or ``preset`` is unknown, if
                ``train_max_digits > eval_max_digits``, if
                ``max_latent_steps`` is negative, if ``radix`` is outside
                ``[2, 16]``, or if ``initial_stage`` is outside
                ``[1, train_max_digits]``.
        """
        if self.model not in VALID_MODELS:
            raise ValueError(f"Unsupported model: {self.model}")
        if self.preset not in VALID_PRESETS:
            raise ValueError(f"Unsupported preset: {self.preset}")
        if self.train_max_digits > self.eval_max_digits:
            raise ValueError("train_max_digits must be <= eval_max_digits")
        if self.max_latent_steps < 0:
            raise ValueError("max_latent_steps must be non-negative")
        if self.radix < 2 or self.radix > 16:
            raise ValueError("radix must be between 2 and 16")
        if self.initial_stage < 1 or self.initial_stage > self.train_max_digits:
            raise ValueError("initial_stage must be between 1 and train_max_digits")
        # Keep only lengths that are genuinely out-of-distribution (longer than
        # anything trained on) AND that still fit: every fixed sequence length
        # below is derived from eval_max_digits, so anything longer would
        # exceed input_sequence_length / max_sequence_length. Deduplicate and
        # sort so evaluation order is stable.
        self.ood_lengths = tuple(
            sorted(
                {
                    length
                    for length in (int(v) for v in self.ood_lengths)
                    if self.train_max_digits < length <= self.eval_max_digits
                }
            )
        )
        if not self.ood_lengths:
            # Fall back to the longest evaluated length. This is only truly
            # out-of-distribution when train_max_digits < eval_max_digits.
            self.ood_lengths = (self.eval_max_digits,)

    @property
    def uses_curriculum(self) -> bool:
        """True for the curriculum variants ("curr_nocot" and "curr_cot")."""
        return self.model in {"curr_nocot", "curr_cot"}

    @property
    def uses_latent_cot(self) -> bool:
        """True only for the "curr_cot" variant."""
        return self.model == "curr_cot"

    @property
    def discrete_vocab_size(self) -> int:
        """Number of discrete tokens: one per digit plus two extra tokens.

        NOTE(review): the two extra tokens are presumably special symbols
        (e.g. operator / separator) — confirm against the tokenizer.
        """
        return self.radix + 2

    @property
    def digit_vocab_size(self) -> int:
        """Number of distinct digit values, i.e. ``radix``."""
        return self.radix

    @property
    def input_sequence_length(self) -> int:
        """Input length for the longest evaluated problem (``eval_max_digits``)."""
        return self.input_sequence_length_for_digits(self.eval_max_digits)

    @property
    def output_sequence_length(self) -> int:
        """Output length for the longest evaluated problem (``eval_max_digits``)."""
        return self.output_sequence_length_for_digits(self.eval_max_digits)

    @property
    def base_sequence_length(self) -> int:
        """Input plus output length at ``eval_max_digits``, excluding latent steps."""
        return self.base_sequence_length_for_digits(self.eval_max_digits)

    @property
    def max_sequence_length(self) -> int:
        """Longest possible sequence: base length plus ``max_latent_steps``."""
        return self.base_sequence_length + self.max_latent_steps

    @property
    def effective_run_name(self) -> str:
        """``run_name`` if set, otherwise ``"{model}_base{radix}_seed{seed}"``."""
        if self.run_name:
            return self.run_name
        return f"{self.model}_base{self.radix}_seed{self.seed}"

    def input_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return the input length for ``active_digits``: ``2 * active_digits + 2``."""
        return (int(active_digits) * 2) + 2

    def output_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return the output length for ``active_digits``.

        One extra position over the digit count leaves room for a final
        carry digit in the sum.
        """
        return int(active_digits) + 1

    def base_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return input plus output length for ``active_digits`` (no latent steps)."""
        return self.input_sequence_length_for_digits(active_digits) + self.output_sequence_length_for_digits(active_digits)

    def latent_steps_for_stage(self, stage: int) -> int:
        """Return the number of latent steps used at curriculum ``stage``.

        Always 0 unless latent CoT is enabled. Otherwise ``stage`` is clamped
        to ``[0, max_latent_steps]``.
        """
        if not self.uses_latent_cot:
            return 0
        return max(0, min(int(stage), int(self.max_latent_steps)))


def default_output_root() -> Path:
    """Return the directory that holds run outputs when none is given.

    The path is relative, so it resolves against the current working
    directory at the time it is used.
    """
    return Path(".", "addition_runs")


def apply_preset(config: ExperimentConfig) -> ExperimentConfig:
    """Return a copy of *config* with its named preset's overrides applied.

    The input object is never mutated. The ``"default"`` preset yields an
    unchanged copy. ``"smoke"`` shrinks model width, batch sizes, step
    budgets and evaluation/probe sizes for a quick end-to-end check. It
    fills in an ``output_dir`` under the default output root only when
    none is set. Smoke values replace whatever the caller passed for
    those fields.
    """
    overrides: dict = {}
    if config.preset == "smoke":
        overrides = {
            "output_dir": config.output_dir or str(default_output_root() / "smoke"),
            "train_batch_size": 64,
            "eval_batch_size": 128,
            "d_model": 128,
            "n_heads": 1,
            "ff_dim": 512,
            "train_steps": 180,
            "max_steps_per_stage": 40,
            "validation_interval": 20,
            "eval_examples_per_length": 64,
            "carry_heavy_examples_per_length": 64,
            "attention_probe_examples": 64,
            "linear_probe_epochs": 60,
            "comparison_num_seeds": 2,
        }
    # replace() always builds a fresh instance (re-running __post_init__),
    # even when there is nothing to override.
    return dataclasses.replace(config, **overrides)


def config_to_dict(config: ExperimentConfig) -> dict:
    """Return a JSON-serializable snapshot of *config*, including derived values.

    All dataclass fields are copied via ``dataclasses.asdict``, and
    ``ood_lengths`` is emitted as a list. Every derived property is added so
    the saved record is self-describing: the curriculum and latent-CoT flags,
    both vocabulary sizes, the sequence lengths and the effective run name.
    """
    data = dataclasses.asdict(config)
    data["ood_lengths"] = list(config.ood_lengths)
    # Single source of truth for exported properties. Previously
    # digit_vocab_size was the only derived property left out.
    derived_properties = (
        "uses_curriculum",
        "uses_latent_cot",
        "discrete_vocab_size",
        "digit_vocab_size",
        "input_sequence_length",
        "output_sequence_length",
        "base_sequence_length",
        "max_sequence_length",
        "effective_run_name",
    )
    data.update({name: getattr(config, name) for name in derived_properties})
    return data


def save_config(config: ExperimentConfig, output_dir: Path) -> None:
    """Write *config* (with its derived fields) to ``output_dir/config.json``.

    *output_dir* and any missing parents are created first. Keys are sorted
    and indented by two spaces, so the file is stable and diff-friendly.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / "config.json"
    with target.open("w", encoding="utf-8") as stream:
        stream.write(json.dumps(config_to_dict(config), indent=2, sort_keys=True))


def add_config_arguments(parser: argparse.ArgumentParser) -> None:
    """Register one command-line flag per ``ExperimentConfig`` field on *parser*.

    Defaults mirror the dataclass defaults, with two exceptions:
    ``--output_dir`` defaults to empty, meaning "derive a directory", and
    W&B is controlled by the on/off switches ``--use_wandb`` and
    ``--no_wandb``. Flags are registered in dataclass field order, with
    ``--no_wandb`` placed right after ``--use_wandb``.
    """
    cuda_or_cpu = "cuda" if torch.cuda.is_available() else "cpu"
    flag_specs: tuple[tuple[str, dict], ...] = (
        ("--model", {"choices": VALID_MODELS, "default": "nocurr_nocot"}),
        ("--output_dir", {"type": str, "default": ""}),
        ("--seed", {"type": int, "default": 0}),
        ("--device", {"type": str, "default": cuda_or_cpu}),
        ("--preset", {"choices": VALID_PRESETS, "default": "default"}),
        ("--run_name", {"type": str, "default": ""}),
        ("--notes", {"type": str, "default": ""}),
        ("--use_wandb", {"action": "store_true"}),
        ("--no_wandb", {"action": "store_true"}),
        ("--wandb_project", {"type": str, "default": "addition-carry"}),
        ("--wandb_entity", {"type": str, "default": ""}),
        ("--wandb_mode", {"type": str, "default": "online", "choices": ("online", "offline", "disabled")}),
        ("--radix", {"type": int, "default": 10}),
        ("--train_max_digits", {"type": int, "default": 12}),
        ("--eval_max_digits", {"type": int, "default": 20}),
        ("--ood_lengths", {"type": int, "nargs": "*", "default": [14, 16, 20]}),
        ("--train_batch_size", {"type": int, "default": 256}),
        ("--eval_batch_size", {"type": int, "default": 512}),
        ("--learning_rate", {"type": float, "default": 3e-4}),
        ("--weight_decay", {"type": float, "default": 1e-2}),
        ("--grad_clip_norm", {"type": float, "default": 1.0}),
        ("--carry_loss_weight", {"type": float, "default": 0.0}),
        ("--train_steps", {"type": int, "default": 3600}),
        ("--max_steps_per_stage", {"type": int, "default": 300}),
        ("--validation_interval", {"type": int, "default": 100}),
        ("--stage_accuracy_threshold", {"type": float, "default": 0.99}),
        ("--initial_stage", {"type": int, "default": 1}),
        ("--eval_examples_per_length", {"type": int, "default": 256}),
        ("--carry_heavy_examples_per_length", {"type": int, "default": 256}),
        ("--train_carry_heavy_prob", {"type": float, "default": 0.15}),
        ("--d_model", {"type": int, "default": 512}),
        ("--n_heads", {"type": int, "default": 1}),
        ("--ff_dim", {"type": int, "default": 2048}),
        ("--dropout", {"type": float, "default": 0.0}),
        ("--max_latent_steps", {"type": int, "default": 12}),
        ("--attention_probe_examples", {"type": int, "default": 256}),
        ("--linear_probe_epochs", {"type": int, "default": 150}),
        ("--linear_probe_lr", {"type": float, "default": 1e-2}),
        ("--comparison_num_seeds", {"type": int, "default": 5}),
    )
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)


def build_config_from_args(args: argparse.Namespace) -> ExperimentConfig:
    """Translate parsed CLI arguments into a validated, preset-applied config.

    W&B logging is on unless ``--no_wandb`` is given (``--use_wandb`` turns
    it back on). It is always off with ``--wandb_mode disabled``.

    When ``--output_dir`` is omitted, the run directory is derived from
    model, radix and seed under ``default_output_root()``. Smoke-preset runs
    go to a separate ``smoke/`` subtree. Otherwise a quick smoke run would
    overwrite the checkpoints and ``config.json`` of the full run with the
    same model/radix/seed.
    """
    use_wandb = bool(args.use_wandb or not args.no_wandb)
    if args.wandb_mode == "disabled":
        use_wandb = False
    run_dirname = f"{args.model}_base{args.radix}_seed{args.seed}"
    if args.output_dir:
        output_dir = args.output_dir
    elif args.preset == "smoke":
        output_dir = str(default_output_root() / "smoke" / run_dirname)
    else:
        output_dir = str(default_output_root() / run_dirname)
    config = ExperimentConfig(
        model=args.model,
        output_dir=output_dir,
        seed=args.seed,
        device=args.device,
        preset=args.preset,
        run_name=args.run_name,
        notes=args.notes,
        use_wandb=use_wandb,
        wandb_project=args.wandb_project,
        wandb_entity=args.wandb_entity,
        wandb_mode=args.wandb_mode,
        radix=args.radix,
        train_max_digits=args.train_max_digits,
        eval_max_digits=args.eval_max_digits,
        ood_lengths=tuple(args.ood_lengths),
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        grad_clip_norm=args.grad_clip_norm,
        carry_loss_weight=args.carry_loss_weight,
        train_steps=args.train_steps,
        max_steps_per_stage=args.max_steps_per_stage,
        validation_interval=args.validation_interval,
        stage_accuracy_threshold=args.stage_accuracy_threshold,
        initial_stage=args.initial_stage,
        eval_examples_per_length=args.eval_examples_per_length,
        carry_heavy_examples_per_length=args.carry_heavy_examples_per_length,
        train_carry_heavy_prob=args.train_carry_heavy_prob,
        d_model=args.d_model,
        n_heads=args.n_heads,
        ff_dim=args.ff_dim,
        dropout=args.dropout,
        max_latent_steps=args.max_latent_steps,
        attention_probe_examples=args.attention_probe_examples,
        linear_probe_epochs=args.linear_probe_epochs,
        linear_probe_lr=args.linear_probe_lr,
        comparison_num_seeds=args.comparison_num_seeds,
    )
    return apply_preset(config)


def build_arg_parser(description: str) -> argparse.ArgumentParser:
    """Create an ``ArgumentParser`` with *description* and every config flag registered."""
    arg_parser = argparse.ArgumentParser(description=description)
    add_config_arguments(arg_parser)
    return arg_parser


def parse_config(description: str) -> ExperimentConfig:
    """Parse the process command line into a preset-applied ``ExperimentConfig``.

    Arguments come from ``sys.argv`` through ``ArgumentParser.parse_args()``.
    Invalid flags make argparse print usage and exit.
    """
    namespace = build_arg_parser(description).parse_args()
    return build_config_from_args(namespace)


def ensure_output_dirs(config: ExperimentConfig) -> dict[str, Path]:
    """Create the standard run directory layout under ``config.output_dir``.

    Returns a mapping with the keys ``root``, ``checkpoints``,
    ``stage_checkpoints`` (``checkpoints/stages``), ``plots`` and
    ``artifacts``, in that order. Each directory is created with its
    parents. Directories that already exist are left alone.
    """
    root = Path(config.output_dir)
    checkpoints = root / "checkpoints"
    layout = {
        "root": root,
        "checkpoints": checkpoints,
        "stage_checkpoints": checkpoints / "stages",
        "plots": root / "plots",
        "artifacts": root / "artifacts",
    }
    for path in layout.values():
        path.mkdir(parents=True, exist_ok=True)
    return layout


def flatten_metric_dict(prefix: str, metrics: dict[str, float | int | str]) -> dict[str, float | int | str]:
    """Return a new dict whose keys are *prefix* prepended to each key of *metrics*.

    Values are passed through unchanged. *metrics* itself is not modified.
    """
    return {f"{prefix}{name}": metric for name, metric in metrics.items()}


def iter_stage_lengths(config: ExperimentConfig) -> Iterable[int]:
    """Lazily yield the curriculum stages ``1 .. config.train_max_digits``, ascending.

    ``config.initial_stage`` is not consulted. Enumeration always starts at 1.
    """
    yield from range(1, config.train_max_digits + 1)