# Avra98's picture
# Initial code dump (rebuttal-ready snapshot)
# 76de008 verified
from __future__ import annotations
import dataclasses
import math
import random
from dataclasses import dataclass
from typing import Iterable
import torch
from addition.config import ExperimentConfig
# Added to a digit value to obtain its token id (see encode_problem_tokens);
# with an offset of 0, digit d is encoded as token id d.
DIGIT_OFFSET = 0
# Display alphabet indexed by digit value, used by digits_to_string.
# It holds 16 symbols, so only digit values below 16 can be rendered.
DEFAULT_SYMBOLS = "0123456789ABCDEF"
@dataclass
class AdditionProblem:
    """One addition instance ``a + b`` stored digit by digit.

    Digit lists are ordered least-significant digit first: carries are
    propagated from index 0 upward by ``compute_sum_and_carry``. The samplers
    in this module allocate ``max_digits`` entries per list and fill only the
    first ``active_digits`` positions; the remainder stays zero.
    """

    # Digits of the first operand (index 0 = least significant).
    a_digits: list[int]
    # Digits of the second operand, same layout as a_digits.
    b_digits: list[int]
    # Per-position sum digit, i.e. (a + b + incoming carry) % radix.
    sum_digits: list[int]
    # Carry leaving each position, i.e. (a + b + incoming carry) // radix.
    carry_out: list[int]
    # Number of leading positions that belong to the problem proper.
    active_digits: int
    # True when the problem came from the carry-heavy construction.
    is_carry_heavy: bool
@dataclass
class Batch:
    """Tensor bundle for a list of problems, as populated by ``build_batch``."""

    # One row per problem: a-marker, a's active digits, b-marker, b's active digits.
    input_ids: torch.Tensor
    # One row per problem: the leading sum digits (long).
    target_digits: torch.Tensor
    # Same layout as target_digits; build_batch fills it with True everywhere.
    target_digit_mask: torch.Tensor
    # One row per problem: the leading per-position carries (long).
    target_carry: torch.Tensor
    # One entry per problem: the carry leaving the last active position (long).
    target_final_carry: torch.Tensor
    # One entry per problem: that problem's active_digits (long).
    active_digits: torch.Tensor
    # One entry per problem: its is_carry_heavy flag (bool).
    is_carry_heavy: torch.Tensor
@dataclass
class EvaluationSuite:
    """Fixed evaluation problem sets, as populated by ``build_evaluation_suite``.

    Each mapping is keyed by digit length and holds the problems generated
    for that length.
    """

    # Uniformly sampled validation problems per length.
    validation_uniform: dict[int, list[AdditionProblem]]
    # Uniformly sampled test problems per length (seeded separately from validation).
    test_uniform: dict[int, list[AdditionProblem]]
    # Carry-heavy test problems per length.
    test_carry_heavy: dict[int, list[AdditionProblem]]
def a_token_id(radix: int) -> int:
    """Return the token id of the marker that precedes operand ``a``.

    The marker id equals ``radix`` itself.
    """
    marker_id = radix
    return marker_id
def b_token_id(radix: int) -> int:
    """Return the token id of the marker that precedes operand ``b``.

    The marker id is ``radix + 1``.
    """
    marker_id = 1 + radix
    return marker_id
def seed_everything(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch with ``seed``.

    CUDA generators are seeded as well, but only when CUDA is available.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def compute_sum_and_carry(a_digits: list[int], b_digits: list[int], radix: int) -> tuple[list[int], list[int]]:
    """Add two little-endian digit lists position by position.

    Args:
        a_digits: Digits of the first operand, least significant first.
        b_digits: Digits of the second operand, same layout.
        radix: Number base used for the digits.

    Returns:
        ``(sum_digits, carry_out)``: the sum digit at each position and the
        carry leaving that position. Positions are paired with ``zip``, so
        the result is as long as the shorter input; the last carry is
        reported in ``carry_out`` rather than appended as an extra digit.
    """
    digits_out: list[int] = []
    carries_out: list[int] = []
    incoming = 0
    for left, right in zip(a_digits, b_digits):
        # divmod yields the outgoing carry and the digit that stays in place.
        incoming, kept = divmod(int(left) + int(right) + incoming, radix)
        digits_out.append(kept)
        carries_out.append(incoming)
    return digits_out, carries_out
def sample_uniform_problem(max_digits: int, active_digits: int, radix: int, rng: random.Random) -> AdditionProblem:
    """Sample a problem whose active digits are drawn uniformly at random.

    Args:
        max_digits: Length of the digit lists that are allocated.
        active_digits: Number of leading positions filled with random digits;
            the remaining positions stay zero.
        radix: Number base; each digit is drawn from ``[0, radix - 1]``.
        rng: Random source. For every position the ``a`` digit is drawn
            before the ``b`` digit.

    Returns:
        An ``AdditionProblem`` with ``is_carry_heavy`` set to ``False``.
    """
    top_digit = radix - 1
    first_operand = [0] * max_digits
    second_operand = [0] * max_digits
    for position in range(active_digits):
        first_operand[position] = rng.randint(0, top_digit)
        second_operand[position] = rng.randint(0, top_digit)
    totals, carries = compute_sum_and_carry(first_operand, second_operand, radix=radix)
    return AdditionProblem(
        a_digits=first_operand,
        b_digits=second_operand,
        sum_digits=totals,
        carry_out=carries,
        active_digits=active_digits,
        is_carry_heavy=False,
    )
def sample_carry_heavy_problem(max_digits: int, active_digits: int, radix: int, rng: random.Random) -> AdditionProblem:
    """Sample a problem constructed so that active positions generate carries.

    For each active position the ``a`` digit is drawn from
    ``[radix // 2, radix - 1]`` and the ``b`` digit is then drawn from the
    range whose lower bound makes ``a + b + incoming carry`` reach ``radix``.

    Args:
        max_digits: Length of the digit lists that are allocated.
        active_digits: Number of leading positions that are filled; the
            remaining positions stay zero.
        radix: Number base for the digits.
        rng: Random source. For every position the ``a`` digit is drawn
            before the ``b`` digit.

    Returns:
        An ``AdditionProblem`` with ``is_carry_heavy`` set to ``True``.
    """
    first_operand = [0] * max_digits
    second_operand = [0] * max_digits
    top_digit = radix - 1
    # Lower bound for every ``a`` digit; it does not depend on the position.
    lowest_a = max(0, radix // 2)
    incoming = 0
    for position in range(active_digits):
        a_value = rng.randint(lowest_a, top_digit)
        # Smallest ``b`` that still pushes the column total up to ``radix``:
        # an incoming carry already contributes one unit.
        needed = (radix - a_value) if incoming == 0 else (top_digit - a_value)
        b_value = rng.randint(max(0, needed), top_digit)
        first_operand[position] = a_value
        second_operand[position] = b_value
        incoming = (a_value + b_value + incoming) // radix
    totals, carries = compute_sum_and_carry(first_operand, second_operand, radix=radix)
    return AdditionProblem(
        a_digits=first_operand,
        b_digits=second_operand,
        sum_digits=totals,
        carry_out=carries,
        active_digits=active_digits,
        is_carry_heavy=True,
    )
def sample_problem(
    max_digits: int,
    active_digits: int,
    radix: int,
    rng: random.Random,
    carry_heavy: bool = False,
) -> AdditionProblem:
    """Sample one problem with either the carry-heavy or the uniform sampler.

    Args:
        max_digits: Length of the digit lists that are allocated.
        active_digits: Number of leading positions that are filled.
        radix: Number base for the digits.
        rng: Random source handed to the selected sampler.
        carry_heavy: When true, use ``sample_carry_heavy_problem``; otherwise
            use ``sample_uniform_problem``.

    Returns:
        The problem produced by the selected sampler.
    """
    sampler = sample_carry_heavy_problem if carry_heavy else sample_uniform_problem
    return sampler(
        max_digits=max_digits,
        active_digits=active_digits,
        radix=radix,
        rng=rng,
    )
def encode_problem_tokens(problem: AdditionProblem, radix: int) -> list[int]:
    """Encode a problem's operands as a flat list of token ids.

    Layout: the ``a`` marker, the active digits of ``a``, the ``b`` marker,
    then the active digits of ``b``. Each digit is shifted by
    ``DIGIT_OFFSET``; digits keep their stored order.
    """
    width = problem.active_digits
    tokens = [a_token_id(radix)]
    tokens.extend(DIGIT_OFFSET + value for value in problem.a_digits[:width])
    tokens.append(b_token_id(radix))
    tokens.extend(DIGIT_OFFSET + value for value in problem.b_digits[:width])
    return tokens
def build_batch(
    problems: list[AdditionProblem],
    radix: int,
    device: str,
) -> Batch:
    """Pack a list of problems into a ``Batch`` of tensors on ``device``.

    The slice width used for the digit, mask and carry targets is taken from
    the first problem only, while token rows use each problem's own
    ``active_digits``; the tensors are therefore only rectangular when all
    problems share the same ``active_digits``.

    Args:
        problems: Problems to pack; may be empty.
        radix: Number base, forwarded to ``encode_problem_tokens``.
        device: Device passed to every ``torch.tensor`` call.

    Returns:
        A ``Batch`` whose fields hold one row or entry per problem.
    """

    def to_tensor(rows, dtype):
        return torch.tensor(rows, dtype=dtype, device=device)

    width = problems[0].active_digits if problems else 0

    token_tensor = to_tensor(
        [encode_problem_tokens(problem=item, radix=radix) for item in problems],
        torch.long,
    )
    digit_tensor = to_tensor([item.sum_digits[:width] for item in problems], torch.long)
    # Every target position is marked as valid.
    mask_tensor = to_tensor([[1] * width for _ in problems], torch.bool)
    carry_tensor = to_tensor([item.carry_out[:width] for item in problems], torch.long)
    # Carry leaving the last active position of each problem.
    final_carry_tensor = to_tensor(
        [item.carry_out[item.active_digits - 1] for item in problems],
        torch.long,
    )
    length_tensor = to_tensor([item.active_digits for item in problems], torch.long)
    heavy_tensor = to_tensor([int(item.is_carry_heavy) for item in problems], torch.bool)

    return Batch(
        input_ids=token_tensor,
        target_digits=digit_tensor,
        target_digit_mask=mask_tensor,
        target_carry=carry_tensor,
        target_final_carry=final_carry_tensor,
        active_digits=length_tensor,
        is_carry_heavy=heavy_tensor,
    )
def sample_training_batch(
    config: ExperimentConfig,
    stage: int,
    rng: random.Random,
    device: str,
) -> Batch:
    """Sample ``config.train_batch_size`` problems for a stage and batch them.

    ``stage`` is used as both ``max_digits`` and ``active_digits``. For each
    problem a draw from ``rng`` is compared against
    ``config.train_carry_heavy_prob`` to choose the carry-heavy sampler; that
    draw happens before the problem itself is sampled.

    Args:
        config: Supplies ``train_batch_size``, ``train_carry_heavy_prob`` and
            ``radix``.
        stage: Digit count of every sampled problem.
        rng: Random source shared by the selection draw and the samplers.
        device: Device for the resulting tensors.

    Returns:
        The ``Batch`` built from the sampled problems.
    """

    def draw_one() -> AdditionProblem:
        use_carry_heavy = rng.random() < config.train_carry_heavy_prob
        return sample_problem(
            max_digits=stage,
            active_digits=stage,
            radix=config.radix,
            rng=rng,
            carry_heavy=use_carry_heavy,
        )

    sampled = [draw_one() for _ in range(config.train_batch_size)]
    return build_batch(problems=sampled, radix=config.radix, device=device)
def build_problem_set(
    *,
    max_digits: int,
    active_digits: int,
    radix: int,
    count: int,
    seed: int,
    carry_heavy: bool,
) -> list[AdditionProblem]:
    """Generate ``count`` problems from a private ``random.Random(seed)``.

    All arguments are keyword-only. Because a fresh generator is created from
    ``seed``, the function does not touch the global random state.

    Args:
        max_digits: Length of the digit lists that are allocated.
        active_digits: Number of leading positions that are filled.
        radix: Number base for the digits.
        count: Number of problems to generate.
        seed: Seed for the private generator.
        carry_heavy: Whether to use the carry-heavy sampler.

    Returns:
        The generated problems, in sampling order.
    """
    generator = random.Random(seed)
    collected: list[AdditionProblem] = []
    for _ in range(count):
        collected.append(
            sample_problem(
                max_digits=max_digits,
                active_digits=active_digits,
                radix=radix,
                rng=generator,
                carry_heavy=carry_heavy,
            )
        )
    return collected
def build_evaluation_suite(config: ExperimentConfig) -> EvaluationSuite:
    """Build the fixed validation and test problem sets for every evaluated length.

    The lengths come from ``infer_eval_lengths`` so that the suite and any
    other consumer of the evaluation lengths cannot drift apart. Each split
    is generated from its own seed (a per-split base plus the length), which
    makes the suite reproducible and independent of global random state.

    Args:
        config: Supplies ``radix``, ``eval_examples_per_length``,
            ``carry_heavy_examples_per_length`` and the fields read by
            ``infer_eval_lengths``.

    Returns:
        An ``EvaluationSuite`` whose three mappings are keyed by length.
    """
    # Per-split seed bases; the length is added so each (split, length) pair
    # gets a distinct seed.
    validation_seed_base = 10_000
    test_seed_base = 20_000
    carry_heavy_seed_base = 30_000

    def make_split(length: int, count: int, seed_base: int, carry_heavy: bool) -> list[AdditionProblem]:
        # Every digit position is active: max_digits == active_digits == length.
        return build_problem_set(
            max_digits=length,
            active_digits=length,
            radix=config.radix,
            count=count,
            seed=seed_base + length,
            carry_heavy=carry_heavy,
        )

    validation_uniform: dict[int, list[AdditionProblem]] = {}
    test_uniform: dict[int, list[AdditionProblem]] = {}
    test_carry_heavy: dict[int, list[AdditionProblem]] = {}
    for length in infer_eval_lengths(config):
        validation_uniform[length] = make_split(
            length, config.eval_examples_per_length, validation_seed_base, False
        )
        test_uniform[length] = make_split(
            length, config.eval_examples_per_length, test_seed_base, False
        )
        test_carry_heavy[length] = make_split(
            length, config.carry_heavy_examples_per_length, carry_heavy_seed_base, True
        )
    return EvaluationSuite(
        validation_uniform=validation_uniform,
        test_uniform=test_uniform,
        test_carry_heavy=test_carry_heavy,
    )
def digits_to_string(digits: Iterable[int], final_carry: int, radix: int) -> str:
    """Render little-endian digits as a most-significant-first string.

    A truthy ``final_carry`` is appended as the most significant digit.
    Leading zeros (trailing entries of the little-endian list) are dropped,
    but one digit is always kept when any digit is present. Symbols come
    from the first ``radix`` characters of ``DEFAULT_SYMBOLS``.

    Args:
        digits: Digit values, least significant first.
        final_carry: Extra most-significant digit; ignored when falsy.
        radix: Number base, which limits the usable symbols.

    Returns:
        The rendered number; an empty string when there are no digits and no
        carry.
    """
    little_endian = list(digits)
    if final_carry:
        little_endian.append(final_carry)
    # Find how many entries remain after stripping high-order zeros.
    keep = len(little_endian)
    while keep > 1 and little_endian[keep - 1] == 0:
        keep -= 1
    alphabet = DEFAULT_SYMBOLS[:radix]
    return "".join(alphabet[value] for value in reversed(little_endian[:keep]))
def value_from_digits(digits: Iterable[int], final_carry: int, radix: int) -> int:
    """Convert little-endian digits (plus an optional carry) to an integer.

    Args:
        digits: Digit values, least significant first.
        final_carry: Digit placed one position above the highest digit;
            ignored when falsy.
        radix: Number base of the digits.

    Returns:
        The integer value represented by the digits and carry.
    """
    # Horner evaluation from the most significant end; the carry, if any,
    # is the most significant digit.
    accumulated = int(final_carry) if final_carry else 0
    for digit in reversed(list(digits)):
        accumulated = accumulated * radix + int(digit)
    return accumulated
def exact_sum_matches(
    predicted_digits: list[int],
    predicted_final_carry: int,
    truth_digits: list[int],
    truth_final_carry: int,
) -> bool:
    """Return whether a predicted sum equals the reference sum exactly.

    The digit lists must be equal and the final carries must be equal after
    conversion to ``int``. The carries are only compared when the digits
    already match.
    """
    if predicted_digits != truth_digits:
        return False
    return int(predicted_final_carry) == int(truth_final_carry)
def summarize_problem(problem: AdditionProblem, radix: int) -> dict[str, int | str]:
    """Build a small summary of a problem with its numbers rendered as text.

    Args:
        problem: The problem to summarize.
        radix: Number base used for rendering; also echoed in the result.

    Returns:
        A dict with the rendered operands (``"a"``, ``"b"``), the rendered
        sum including the final carry (``"sum"``), plus ``"radix"``,
        ``"active_digits"`` and ``"carry_heavy"`` (0 or 1).
    """
    width = problem.active_digits
    # Carry leaving the last active position becomes the sum's top digit.
    top_carry = problem.carry_out[width - 1]

    def render(values: list[int], carry: int) -> str:
        return digits_to_string(values[:width], final_carry=carry, radix=radix)

    summary: dict[str, int | str] = {}
    summary["a"] = render(problem.a_digits, 0)
    summary["b"] = render(problem.b_digits, 0)
    summary["sum"] = render(problem.sum_digits, top_carry)
    summary["radix"] = radix
    summary["active_digits"] = problem.active_digits
    summary["carry_heavy"] = int(problem.is_carry_heavy)
    return summary
def count_carry_chain(problem: AdditionProblem) -> int:
    """Return the length of the longest run of consecutive carries.

    Only the first ``problem.active_digits`` positions of ``carry_out`` are
    inspected; a position counts when its carry is truthy.
    """
    best_run = 0
    run_length = 0
    for position in range(problem.active_digits):
        run_length = run_length + 1 if problem.carry_out[position] else 0
        if run_length > best_run:
            best_run = run_length
    return best_run
def carry_density(problem: AdditionProblem) -> float:
    """Return the mean carry value over the active positions.

    Returns ``0.0`` when ``problem.active_digits`` is zero or negative.
    """
    width = problem.active_digits
    if width > 0:
        carried = sum(problem.carry_out[:width])
        return float(carried) / float(width)
    return 0.0
def curriculum_stage_lengths(config: ExperimentConfig) -> list[int]:
    """Return the digit lengths of the training stages.

    With ``config.uses_curriculum`` set, the stages are
    ``1 .. config.train_max_digits``; otherwise there is a single stage at
    ``config.train_max_digits``.
    """
    if not config.uses_curriculum:
        return [config.train_max_digits]
    return [*range(1, config.train_max_digits + 1)]
def infer_eval_lengths(config: ExperimentConfig) -> list[int]:
    """Return the sorted, de-duplicated lengths used for evaluation.

    The result combines ``1 .. config.train_max_digits`` with
    ``config.ood_lengths``.
    """
    lengths = {*range(1, config.train_max_digits + 1), *config.ood_lengths}
    return sorted(lengths)
def estimate_train_tokens_per_step(config: ExperimentConfig, stage: int) -> int:
    """Estimate the number of tokens processed in one training step.

    Computed as ``config.train_batch_size`` times the per-example length,
    where the per-example length is
    ``config.base_sequence_length_for_digits(stage)`` plus
    ``config.latent_steps_for_stage(stage)``.
    """
    extra_steps = config.latent_steps_for_stage(stage)
    per_example = config.base_sequence_length_for_digits(stage) + extra_steps
    return config.train_batch_size * per_example
def stage_fraction(stage: int, max_stage: int) -> float:
    """Map a 1-based stage onto a linear scale between the first and last stage.

    Stage 1 maps to ``0.0`` and ``max_stage`` maps to ``1.0``. When
    ``max_stage`` is 1 or less, the result is always ``1.0``.
    """
    span = max_stage - 1
    if span > 0:
        return float(stage - 1) / float(span)
    return 1.0
def maybe_trim_examples(problems: list[AdditionProblem], limit: int) -> list[AdditionProblem]:
    """Return a new list holding at most ``limit`` leading problems.

    A ``limit`` of zero or less disables trimming. The result is always a
    copy, never the input list itself.
    """
    if 0 < limit < len(problems):
        return list(problems[:limit])
    return list(problems)
def stage_display_name(stage: int) -> str:
    """Return a label such as ``"1st-digit"`` or ``"12th-digit"`` for a stage.

    The English ordinal suffix is chosen from the last digit, with 11, 12
    and 13 (in the last two digits) taking ``"th"``.
    """
    if stage % 100 in (11, 12, 13):
        ending = "th"
    else:
        ending = {1: "st", 2: "nd", 3: "rd"}.get(stage % 10, "th")
    return f"{stage}{ending}-digit"
def ideal_carry_chain_examples(config: ExperimentConfig, active_digits: int) -> list[AdditionProblem]:
    """Build two hand-constructed problems that add ``11...1`` to a repeated digit.

    The first operand repeats ``max(0, radix - 2)`` in the first example and
    ``radix - 1`` in the second; the second operand is all ones. Both are
    flagged ``is_carry_heavy=True`` and have exactly ``active_digits``
    positions.

    NOTE(review): for ``radix >= 2`` the ``radix - 2`` example sums to
    ``radix - 1`` in every column and therefore produces no carries at all;
    confirm that this near-carry case is intended to be labelled carry-heavy.
    """
    radix = config.radix

    def repeated_digit_problem(repeated: int) -> AdditionProblem:
        first_operand = [repeated] * active_digits
        second_operand = [1] * active_digits
        totals, carries = compute_sum_and_carry(first_operand, second_operand, radix=radix)
        return AdditionProblem(
            a_digits=first_operand,
            b_digits=second_operand,
            sum_digits=totals,
            carry_out=carries,
            active_digits=active_digits,
            is_carry_heavy=True,
        )

    return [
        repeated_digit_problem(max(0, radix - 2)),
        repeated_digit_problem(radix - 1),
    ]
def expected_sum_length(problem: AdditionProblem) -> int:
    """Return the digit count of the sum, counting a final carry as one digit.

    The result is ``problem.active_digits``, plus one when the carry leaving
    the last active position is positive.
    """
    width = problem.active_digits
    overflows = problem.carry_out[width - 1] > 0
    return width + 1 if overflows else width
def average_query_count(config: ExperimentConfig) -> float:
    """Return the arithmetic mean of the curriculum stage lengths.

    The lengths come from ``curriculum_stage_lengths(config)``.
    """
    stage_lengths = curriculum_stage_lengths(config)
    stage_total = sum(stage_lengths)
    return stage_total / float(len(stage_lengths))
def token_budget(config: ExperimentConfig) -> int:
    """Return the per-example token count at the rounded-up mean stage.

    The mean stage length from ``average_query_count`` is rounded up to an
    integer; the result is the base sequence length plus the latent steps
    that ``config`` reports for that stage.
    """
    mean_stage = average_query_count(config)
    rounded_stage = int(math.ceil(mean_stage))
    base_length = config.base_sequence_length_for_digits(rounded_stage)
    return base_length + config.latent_steps_for_stage(rounded_stage)