"""Problem sampling, batching, and reporting helpers for digit-wise addition.

Digit lists are stored least-significant digit first: index 0 is the lowest
place value, which is the order in which ``compute_sum_and_carry`` propagates
carries and the order ``digits_to_string`` reverses for display.
"""

from __future__ import annotations

import dataclasses
import math
import random
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterable

import torch

if TYPE_CHECKING:
    # Only referenced in annotations, which are never evaluated at runtime
    # under ``from __future__ import annotations``.
    from addition.config import ExperimentConfig

# Digit ``d`` is encoded as token ``DIGIT_OFFSET + d``; the two operand marker
# tokens sit directly above the digit range (see ``a_token_id``/``b_token_id``).
DIGIT_OFFSET = 0
# Display alphabet used by ``digits_to_string``; covers radices up to 16.
DEFAULT_SYMBOLS = "0123456789ABCDEF"


@dataclass
class AdditionProblem:
    """One addition problem ``a + b`` expressed as per-position digit lists."""

    # Operand and result digits, least-significant first.
    a_digits: list[int]
    b_digits: list[int]
    sum_digits: list[int]
    # carry_out[i] is the carry produced by position i (fed into position i+1).
    carry_out: list[int]
    # Number of leading positions that are part of the problem.
    active_digits: int
    # True when the problem came from the carry-heavy sampler.
    is_carry_heavy: bool


@dataclass
class Batch:
    """Tensors built by ``build_batch`` for a list of same-length problems."""

    input_ids: torch.Tensor
    target_digits: torch.Tensor
    target_digit_mask: torch.Tensor
    target_carry: torch.Tensor
    target_final_carry: torch.Tensor
    active_digits: torch.Tensor
    is_carry_heavy: torch.Tensor


@dataclass
class EvaluationSuite:
    """Fixed evaluation problem sets, keyed by problem length in digits."""

    validation_uniform: dict[int, list[AdditionProblem]]
    test_uniform: dict[int, list[AdditionProblem]]
    test_carry_heavy: dict[int, list[AdditionProblem]]


def a_token_id(radix: int) -> int:
    """Return the token id that introduces operand ``a``."""
    return radix


def b_token_id(radix: int) -> int:
    """Return the token id that introduces operand ``b``."""
    return radix + 1


def seed_everything(seed: int) -> None:
    """Seed Python's ``random`` and torch (CPU, plus all CUDA devices if any)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def compute_sum_and_carry(
    a_digits: list[int],
    b_digits: list[int],
    radix: int,
) -> tuple[list[int], list[int]]:
    """Add two little-endian digit lists position by position.

    Returns ``(sum_digits, carry_out)`` where ``carry_out[i]`` is the carry
    leaving position ``i``. The result has as many positions as the shorter
    input, because the operands are paired with ``zip``.
    """
    sum_digits: list[int] = []
    carry_out: list[int] = []
    carry = 0
    for a_digit, b_digit in zip(a_digits, b_digits):
        carry, digit = divmod(int(a_digit) + int(b_digit) + carry, radix)
        sum_digits.append(digit)
        carry_out.append(carry)
    return sum_digits, carry_out


def _final_carry(problem: AdditionProblem) -> int:
    """Return the carry out of the last active position (0 if there is none)."""
    if problem.active_digits <= 0:
        # No active position means nothing could have produced a carry; also
        # avoids reading carry_out[-1] (an IndexError on empty problems).
        return 0
    return problem.carry_out[problem.active_digits - 1]


def sample_uniform_problem(
    max_digits: int,
    active_digits: int,
    radix: int,
    rng: random.Random,
) -> AdditionProblem:
    """Sample a problem whose active digits are uniform in ``[0, radix - 1]``.

    Positions from ``active_digits`` up to ``max_digits`` are left at zero.
    """
    a_digits = [0] * max_digits
    b_digits = [0] * max_digits
    for index in range(active_digits):
        # Draw order (a then b, per position) is kept so seeded runs reproduce.
        a_digits[index] = rng.randint(0, radix - 1)
        b_digits[index] = rng.randint(0, radix - 1)
    sum_digits, carry_out = compute_sum_and_carry(a_digits, b_digits, radix=radix)
    return AdditionProblem(
        a_digits=a_digits,
        b_digits=b_digits,
        sum_digits=sum_digits,
        carry_out=carry_out,
        active_digits=active_digits,
        is_carry_heavy=False,
    )


def sample_carry_heavy_problem(
    max_digits: int,
    active_digits: int,
    radix: int,
    rng: random.Random,
) -> AdditionProblem:
    """Sample a problem in which every active position produces a carry.

    ``a`` digits are drawn from the upper half of the digit range and each
    ``b`` digit is drawn from the values that, together with the incoming
    carry, reach at least ``radix``. With ``radix < 2`` and at least one
    active digit the sampling range is empty and ``rng.randint`` raises
    ``ValueError``.
    """
    a_digits = [0] * max_digits
    b_digits = [0] * max_digits
    high_floor = max(0, radix // 2)
    carry = 0
    for index in range(active_digits):
        a_digit = rng.randint(high_floor, radix - 1)
        # Smallest b such that a + b + carry >= radix (carry is 0 or 1 here).
        min_b = max(0, radix - carry - a_digit)
        b_digit = rng.randint(min_b, radix - 1)
        a_digits[index] = a_digit
        b_digits[index] = b_digit
        carry = (a_digit + b_digit + carry) // radix
    sum_digits, carry_out = compute_sum_and_carry(a_digits, b_digits, radix=radix)
    return AdditionProblem(
        a_digits=a_digits,
        b_digits=b_digits,
        sum_digits=sum_digits,
        carry_out=carry_out,
        active_digits=active_digits,
        is_carry_heavy=True,
    )


def sample_problem(
    max_digits: int,
    active_digits: int,
    radix: int,
    rng: random.Random,
    carry_heavy: bool = False,
) -> AdditionProblem:
    """Sample one problem from the carry-heavy or the uniform sampler."""
    sampler = sample_carry_heavy_problem if carry_heavy else sample_uniform_problem
    return sampler(
        max_digits=max_digits,
        active_digits=active_digits,
        radix=radix,
        rng=rng,
    )


def encode_problem_tokens(problem: AdditionProblem, radix: int) -> list[int]:
    """Encode a problem as ``[A, a digits..., B, b digits...]`` token ids.

    Only the first ``problem.active_digits`` digits of each operand are
    emitted, in stored (least-significant first) order.
    """
    active = problem.active_digits
    tokens = [a_token_id(radix)]
    tokens.extend(DIGIT_OFFSET + digit for digit in problem.a_digits[:active])
    tokens.append(b_token_id(radix))
    tokens.extend(DIGIT_OFFSET + digit for digit in problem.b_digits[:active])
    return tokens


def build_batch(
    problems: list[AdditionProblem],
    radix: int,
    device: str,
) -> Batch:
    """Stack same-length problems into a ``Batch`` of tensors on ``device``.

    Raises:
        ValueError: if the problems do not all share one ``active_digits``
            value. Targets are sliced with a single length, so a mixed batch
            cannot be represented as rectangular tensors.
    """
    active_digits = problems[0].active_digits if problems else 0
    lengths = {problem.active_digits for problem in problems}
    if len(lengths) > 1:
        raise ValueError(
            "build_batch requires all problems to share the same "
            f"active_digits; got {sorted(lengths)}"
        )

    def long_tensor(rows: list) -> torch.Tensor:
        return torch.tensor(rows, dtype=torch.long, device=device)

    return Batch(
        input_ids=long_tensor(
            [encode_problem_tokens(problem=problem, radix=radix) for problem in problems]
        ),
        target_digits=long_tensor(
            [problem.sum_digits[:active_digits] for problem in problems]
        ),
        # Every position is a real target because the batch is uniform-length.
        target_digit_mask=torch.tensor(
            [[1] * active_digits for _ in problems],
            dtype=torch.bool,
            device=device,
        ),
        target_carry=long_tensor(
            [problem.carry_out[:active_digits] for problem in problems]
        ),
        target_final_carry=long_tensor(
            [_final_carry(problem) for problem in problems]
        ),
        active_digits=long_tensor([problem.active_digits for problem in problems]),
        is_carry_heavy=torch.tensor(
            [int(problem.is_carry_heavy) for problem in problems],
            dtype=torch.bool,
            device=device,
        ),
    )


def sample_training_batch(
    config: ExperimentConfig,
    stage: int,
    rng: random.Random,
    device: str,
) -> Batch:
    """Sample ``config.train_batch_size`` problems of ``stage`` digits and batch them.

    Each problem is carry-heavy with probability
    ``config.train_carry_heavy_prob``.
    """
    problems: list[AdditionProblem] = []
    for _ in range(config.train_batch_size):
        # The coin flip is drawn before the problem so the RNG stream is stable.
        carry_heavy = rng.random() < config.train_carry_heavy_prob
        problems.append(
            sample_problem(
                max_digits=stage,
                active_digits=stage,
                radix=config.radix,
                rng=rng,
                carry_heavy=carry_heavy,
            )
        )
    return build_batch(problems=problems, radix=config.radix, device=device)


def build_problem_set(
    *,
    max_digits: int,
    active_digits: int,
    radix: int,
    count: int,
    seed: int,
    carry_heavy: bool,
) -> list[AdditionProblem]:
    """Build ``count`` problems from a private RNG seeded with ``seed``."""
    rng = random.Random(seed)
    return [
        sample_problem(
            max_digits=max_digits,
            active_digits=active_digits,
            radix=radix,
            rng=rng,
            carry_heavy=carry_heavy,
        )
        for _ in range(count)
    ]


def build_evaluation_suite(config: ExperimentConfig) -> EvaluationSuite:
    """Build the validation/test problem sets for every evaluation length.

    Each split uses its own seed base (10_000 / 20_000 / 30_000, plus the
    length) so the sets are reproducible and seeded differently per split.
    """
    validation_uniform: dict[int, list[AdditionProblem]] = {}
    test_uniform: dict[int, list[AdditionProblem]] = {}
    test_carry_heavy: dict[int, list[AdditionProblem]] = {}
    for length in infer_eval_lengths(config):
        validation_uniform[length] = build_problem_set(
            max_digits=length,
            active_digits=length,
            radix=config.radix,
            count=config.eval_examples_per_length,
            seed=10_000 + length,
            carry_heavy=False,
        )
        test_uniform[length] = build_problem_set(
            max_digits=length,
            active_digits=length,
            radix=config.radix,
            count=config.eval_examples_per_length,
            seed=20_000 + length,
            carry_heavy=False,
        )
        test_carry_heavy[length] = build_problem_set(
            max_digits=length,
            active_digits=length,
            radix=config.radix,
            count=config.carry_heavy_examples_per_length,
            seed=30_000 + length,
            carry_heavy=True,
        )
    return EvaluationSuite(
        validation_uniform=validation_uniform,
        test_uniform=test_uniform,
        test_carry_heavy=test_carry_heavy,
    )


def digits_to_string(digits: Iterable[int], final_carry: int, radix: int) -> str:
    """Render little-endian digits (plus an optional final carry) for display.

    Leading zeros of the displayed number are stripped, keeping at least one
    digit when any are present. Symbols come from ``DEFAULT_SYMBOLS[:radix]``,
    so digits of 16 or above cannot be rendered.
    """
    significant_digits = list(digits)
    if final_carry:
        significant_digits.append(final_carry)
    # The most significant digit is at the end of the list.
    while len(significant_digits) > 1 and significant_digits[-1] == 0:
        significant_digits.pop()
    symbols = DEFAULT_SYMBOLS[:radix]
    return "".join(symbols[digit] for digit in reversed(significant_digits))


def value_from_digits(digits: Iterable[int], final_carry: int, radix: int) -> int:
    """Return the integer value of little-endian digits plus a final carry."""
    value = 0
    place = 1
    for digit in digits:
        value += int(digit) * place
        place *= radix
    if final_carry:
        # The carry occupies the place just above the last digit.
        value += int(final_carry) * place
    return value


def exact_sum_matches(
    predicted_digits: list[int],
    predicted_final_carry: int,
    truth_digits: list[int],
    truth_final_carry: int,
) -> bool:
    """Return True when both the digit lists and the final carries agree."""
    return (
        predicted_digits == truth_digits
        and int(predicted_final_carry) == int(truth_final_carry)
    )


def summarize_problem(problem: AdditionProblem, radix: int) -> dict[str, int | str]:
    """Return a flat, printable summary of a problem's operands and sum."""
    active = problem.active_digits
    return {
        "a": digits_to_string(problem.a_digits[:active], final_carry=0, radix=radix),
        "b": digits_to_string(problem.b_digits[:active], final_carry=0, radix=radix),
        "sum": digits_to_string(
            problem.sum_digits[:active],
            final_carry=_final_carry(problem),
            radix=radix,
        ),
        "radix": radix,
        "active_digits": active,
        "carry_heavy": int(problem.is_carry_heavy),
    }


def count_carry_chain(problem: AdditionProblem) -> int:
    """Return the longest run of consecutive carrying positions."""
    longest = 0
    current = 0
    for carry in problem.carry_out[: problem.active_digits]:
        if carry:
            current += 1
            longest = max(longest, current)
        else:
            current = 0
    return longest


def carry_density(problem: AdditionProblem) -> float:
    """Return the fraction of active positions that produce a carry."""
    if problem.active_digits <= 0:
        return 0.0
    carries = sum(problem.carry_out[: problem.active_digits])
    return float(carries) / float(problem.active_digits)


def curriculum_stage_lengths(config: ExperimentConfig) -> list[int]:
    """Return the training lengths: 1..max with a curriculum, else just max."""
    if config.uses_curriculum:
        return list(range(1, config.train_max_digits + 1))
    return [config.train_max_digits]


def infer_eval_lengths(config: ExperimentConfig) -> list[int]:
    """Return the sorted union of training lengths and ``config.ood_lengths``."""
    in_distribution = set(range(1, config.train_max_digits + 1))
    return sorted(in_distribution.union(config.ood_lengths))


def estimate_train_tokens_per_step(config: ExperimentConfig, stage: int) -> int:
    """Return batch size times (base sequence length + latent steps) for ``stage``."""
    latent_steps = config.latent_steps_for_stage(stage)
    per_example = config.base_sequence_length_for_digits(stage) + latent_steps
    return config.train_batch_size * per_example


def stage_fraction(stage: int, max_stage: int) -> float:
    """Map ``stage`` in ``1..max_stage`` linearly onto ``0.0..1.0``.

    Returns 1.0 when there is only a single stage (``max_stage <= 1``).
    """
    if max_stage <= 1:
        return 1.0
    return float(stage - 1) / float(max_stage - 1)


def maybe_trim_examples(
    problems: list[AdditionProblem],
    limit: int,
) -> list[AdditionProblem]:
    """Return a copy of ``problems`` cut to ``limit`` items; ``limit <= 0`` keeps all."""
    if limit <= 0:
        return list(problems)
    return list(problems[:limit])


def stage_display_name(stage: int) -> str:
    """Return an ordinal label such as ``"1st-digit"`` or ``"12th-digit"``."""
    suffix = "th"
    # 11, 12 and 13 (mod 100) keep "th" despite ending in 1, 2 or 3.
    if stage % 10 == 1 and stage % 100 != 11:
        suffix = "st"
    elif stage % 10 == 2 and stage % 100 != 12:
        suffix = "nd"
    elif stage % 10 == 3 and stage % 100 != 13:
        suffix = "rd"
    return f"{stage}{suffix}-digit"


def ideal_carry_chain_examples(
    config: ExperimentConfig,
    active_digits: int,
) -> list[AdditionProblem]:
    """Build two hand-made problems: ``a`` is all ``radix - 2`` or all ``radix - 1``.

    In both cases ``b`` is all ones.
    """
    examples: list[AdditionProblem] = []
    for base_digit in (max(0, config.radix - 2), config.radix - 1):
        a_digits = [base_digit] * active_digits
        b_digits = [1] * active_digits
        sum_digits, carry_out = compute_sum_and_carry(
            a_digits, b_digits, radix=config.radix
        )
        examples.append(
            AdditionProblem(
                a_digits=a_digits,
                b_digits=b_digits,
                sum_digits=sum_digits,
                carry_out=carry_out,
                active_digits=active_digits,
                is_carry_heavy=True,
            )
        )
    return examples


def expected_sum_length(problem: AdditionProblem) -> int:
    """Return the digit count of the sum: active digits plus one if it overflows."""
    return problem.active_digits + int(_final_carry(problem) > 0)


def average_query_count(config: ExperimentConfig) -> float:
    """Return the mean of the curriculum stage lengths."""
    lengths = curriculum_stage_lengths(config)
    return sum(lengths) / float(len(lengths))


def token_budget(config: ExperimentConfig) -> int:
    """Return base sequence length plus latent steps at the rounded-up mean stage."""
    avg_stage = int(math.ceil(average_query_count(config)))
    return (
        config.base_sequence_length_for_digits(avg_stage)
        + config.latent_steps_for_stage(avg_stage)
    )