from __future__ import annotations

import contextlib
import math
from dataclasses import asdict, dataclass
from typing import Iterable, Iterator

import torch
from torch import nn

from addition.config import ExperimentConfig
from addition.data import (
    AdditionProblem,
    EvaluationSuite,
    build_batch,
    carry_density,
    count_carry_chain,
    exact_sum_matches,
    maybe_trim_examples,
)
from addition.model import AdditionTransformer


@dataclass
class LengthMetrics:
    """Aggregate evaluation metrics for one set of problems at a fixed number of active digits."""

    # Fraction of (example, digit position) pairs whose predicted sum digit is correct.
    digit_accuracy: float
    # Fraction of examples whose predicted final carry-out matches the truth.
    final_carry_accuracy: float
    # Fraction of examples for which ``exact_sum_matches`` reports a full match.
    exact_match: float
    # Mean of ``count_carry_chain`` over the evaluated problems.
    avg_carry_chain: float
    # Mean of ``carry_density`` over the evaluated problems.
    avg_carry_density: float
    # Number of problems evaluated.
    example_count: int
    # Digit accuracy for each of the ``active_digits`` output positions.
    per_position_digit_accuracy: list[float]


def _chunked(sequence: list[AdditionProblem], chunk_size: int) -> Iterable[list[AdditionProblem]]:
    """Yield consecutive slices of ``sequence`` holding at most ``chunk_size`` items each."""
    for start in range(0, len(sequence), chunk_size):
        yield sequence[start : start + chunk_size]


@contextlib.contextmanager
def _eval_mode(model: nn.Module) -> Iterator[None]:
    """Switch ``model`` to eval mode for the ``with`` body, then restore its previous mode.

    Evaluation helpers in this module may be called from inside a training
    loop; restoring the prior ``training`` flag means layers whose behaviour
    depends on ``model.training`` act afterwards exactly as they did before the
    call, instead of the model being silently left in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(was_training)


@torch.no_grad()
def evaluate_problem_set(
    model: AdditionTransformer,
    config: ExperimentConfig,
    problems: list[AdditionProblem],
    active_digits: int,
    *,
    device: str,
    return_attention: bool = False,
) -> tuple[LengthMetrics, dict[str, float] | None]:
    """Evaluate ``model`` on ``problems`` using the first ``active_digits`` sum digits.

    The model is run without gradients and in eval mode (its previous mode is
    restored afterwards) over chunks of ``config.eval_batch_size`` problems,
    with ``config.latent_steps_for_stage(active_digits)`` latent steps.
    Predictions are the argmax of ``digit_logits`` (truncated to the first
    ``active_digits`` positions) and of ``final_carry_logits``; they are
    compared with ``problem.sum_digits[:active_digits]`` and
    ``problem.carry_out[active_digits - 1]``.

    Args:
        model: Model under evaluation.
        config: Experiment configuration (batch size, radix, latent schedule).
        problems: Problems to evaluate; may be empty.
        active_digits: Number of sum digits to score.
        device: Device the batches are built on.
        return_attention: Whether to request attention weights and summarize them.

    Returns:
        ``(metrics, attention_stats)``. ``attention_stats`` is ``None`` unless
        ``return_attention`` is set and ``problems`` is non-empty; it is
        summarized from the first evaluation batch only. For an empty
        ``problems`` list all metrics are zero.
    """
    with _eval_mode(model):
        latent_steps = config.latent_steps_for_stage(active_digits)
        num_examples = len(problems)
        if num_examples == 0:
            empty = LengthMetrics(
                digit_accuracy=0.0,
                final_carry_accuracy=0.0,
                exact_match=0.0,
                avg_carry_chain=0.0,
                avg_carry_density=0.0,
                example_count=0,
                per_position_digit_accuracy=[0.0] * active_digits,
            )
            return empty, None

        predicted_digits = torch.zeros(num_examples, active_digits, dtype=torch.long)
        predicted_final_carry = torch.zeros(num_examples, dtype=torch.long)
        truth_digits = torch.tensor(
            [[problem.sum_digits[position] for position in range(active_digits)] for problem in problems],
            dtype=torch.long,
        )
        truth_final_carry = torch.tensor(
            [problem.carry_out[active_digits - 1] for problem in problems],
            dtype=torch.long,
        )

        attention_stats: dict[str, float] | None = None
        offset = 0
        for problem_chunk in _chunked(problems, config.eval_batch_size):
            batch = build_batch(
                problems=problem_chunk,
                radix=config.radix,
                device=device,
            )
            outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=return_attention)
            stop = offset + len(problem_chunk)
            predicted_digits[offset:stop] = outputs.digit_logits.argmax(dim=-1)[:, :active_digits].cpu()
            predicted_final_carry[offset:stop] = outputs.final_carry_logits.argmax(dim=-1).cpu()
            # Attention statistics are taken from the first batch only.
            if return_attention and attention_stats is None:
                attention_stats = summarize_attention(
                    attention_weights=outputs.attention_weights,
                    active_digits=active_digits,
                    input_sequence_length=batch.input_ids.shape[1],
                    output_sequence_length=outputs.output_hidden.shape[1],
                )
            offset = stop

        exact_matches = [
            exact_sum_matches(
                predicted_digits=predicted_digits[index].tolist(),
                predicted_final_carry=int(predicted_final_carry[index].item()),
                truth_digits=problem.sum_digits[:active_digits],
                truth_final_carry=problem.carry_out[active_digits - 1],
            )
            for index, problem in enumerate(problems)
        ]
        # 1.0 where the predicted digit matches the truth, 0.0 otherwise; shape (num_examples, active_digits).
        digit_correct = (predicted_digits == truth_digits).float()
        metrics = LengthMetrics(
            digit_accuracy=float(digit_correct.mean().item()),
            final_carry_accuracy=float((predicted_final_carry == truth_final_carry).float().mean().item()),
            exact_match=float(torch.tensor(exact_matches, dtype=torch.float32).mean().item()),
            avg_carry_chain=float(sum(count_carry_chain(problem) for problem in problems) / num_examples),
            avg_carry_density=float(sum(carry_density(problem) for problem in problems) / num_examples),
            example_count=num_examples,
            per_position_digit_accuracy=[float(value) for value in digit_correct.mean(dim=0).tolist()],
        )
        return metrics, attention_stats


def summarize_attention(
    attention_weights: torch.Tensor | None,
    *,
    active_digits: int,
    input_sequence_length: int,
    output_sequence_length: int,
) -> dict[str, float]:
    """Summarize where the final query position attends.

    ``attention_weights`` is expected as ``[batch, heads, target_len, source_len]``.
    The attention row of the last target position is averaged over batch and
    heads, and the resulting distribution over source positions is reduced to
    a few scalars: the weight on selected operand digits, the entropy of the
    averaged distribution, the total weight on the latent span, and the total
    weight on earlier output positions (excluding the final position itself).

    NOTE(review): the digit indices assume operand ``a`` occupies source
    indices ``1..active_digits`` and operand ``b`` starts at
    ``input_sequence_length // 2 + 1``, and that source positions are ordered
    input, then latent steps, then outputs — confirm against ``build_batch``
    and the model.

    Returns:
        A dict of float summaries, or an empty dict when ``attention_weights``
        is ``None``.
    """
    if attention_weights is None:
        return {}

    # Shape: [batch, heads, target_len, source_len]
    final_attention = attention_weights[:, :, -1, :]
    attention_mean = final_attention.mean(dim=(0, 1))
    source_length = attention_mean.shape[0]

    # Explicit, clamped bounds instead of slicing by ``-output_sequence_length``:
    # with a zero output length, ``-0`` would be read as index 0 and both slices
    # would cover the wrong span. For positive lengths this matches the
    # negative-index form exactly.
    output_start = max(source_length - output_sequence_length, 0)
    latent_slice = attention_mean[input_sequence_length:output_start]
    output_slice = attention_mean[output_start : source_length - 1]

    # Entropy of the batch/head-averaged distribution; clamp avoids log(0).
    entropy = -torch.sum(attention_mean * torch.log(attention_mean.clamp_min(1e-9))).item()

    operand_b_offset = input_sequence_length // 2
    return {
        "lsd_a_attention": float(attention_mean[1].item()),
        "msd_a_attention": float(attention_mean[active_digits].item()),
        "lsd_b_attention": float(attention_mean[operand_b_offset + 1].item()),
        "msd_b_attention": float(attention_mean[operand_b_offset + active_digits].item()),
        "attention_entropy": float(entropy),
        "all_latent_attention": float(latent_slice.sum().item()) if latent_slice.numel() else 0.0,
        "previous_output_attention": float(output_slice.sum().item()) if output_slice.numel() else 0.0,
    }


@torch.no_grad()
def evaluate_length_dict(
    model: AdditionTransformer,
    config: ExperimentConfig,
    problems_by_length: dict[int, list[AdditionProblem]],
    *,
    device: str,
    attention_length: int | None = None,
) -> dict[str, dict]:
    """Evaluate every length in ``problems_by_length`` in ascending order.

    Args:
        model: Model under evaluation.
        config: Experiment configuration.
        problems_by_length: Problems grouped by their number of active digits.
        device: Device the batches are built on.
        attention_length: Length at which attention statistics are requested,
            or ``None`` for no attention statistics.

    Returns:
        A dict keyed by ``str(length)``. Each value holds the ``LengthMetrics``
        fields, plus an ``"attention_summary"`` entry for ``attention_length``
        when attention statistics were produced.
    """
    structured: dict[str, dict] = {}
    for length, problems in sorted(problems_by_length.items()):
        length_metrics, attention = evaluate_problem_set(
            model=model,
            config=config,
            problems=problems,
            active_digits=length,
            device=device,
            return_attention=attention_length == length,
        )
        entry = asdict(length_metrics)
        if attention is not None:
            entry["attention_summary"] = attention
        structured[str(length)] = entry
    return structured


def collect_hidden_dataset(
    model: AdditionTransformer,
    config: ExperimentConfig,
    problems: list[AdditionProblem],
    *,
    active_digits: int,
    device: str,
    limit_examples: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Gather output-slot hidden states and carry labels for a linear probe.

    Problems are first passed through ``maybe_trim_examples(problems,
    limit_examples)`` and then run through the model without gradients and in
    eval mode (its previous mode is restored afterwards). For each batch, the
    first ``active_digits`` positions of ``outputs.output_hidden`` are kept and
    filtered with ``batch.target_digit_mask``; the matching
    ``batch.target_carry`` entries become the labels.

    Returns:
        ``(hidden_states, carry_targets)`` concatenated over all batches on the
        CPU. When no problems are selected, both tensors are empty (shape
        ``(0,)``) rather than ``torch.cat`` raising on an empty list.
    """
    with _eval_mode(model), torch.no_grad():
        latent_steps = config.latent_steps_for_stage(active_digits)
        selected = maybe_trim_examples(problems, limit_examples)
        hidden_states: list[torch.Tensor] = []
        carry_targets: list[torch.Tensor] = []
        for problem_chunk in _chunked(selected, config.eval_batch_size):
            batch = build_batch(
                problems=problem_chunk,
                radix=config.radix,
                device=device,
            )
            outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=False)
            slot_hidden = outputs.output_hidden[:, :active_digits, :]
            # NOTE(review): assumes target_digit_mask is boolean with shape
            # [batch, active_digits] so it lines up with slot_hidden — confirm.
            slot_mask = batch.target_digit_mask
            hidden_states.append(slot_hidden[slot_mask].detach().cpu())
            carry_targets.append(batch.target_carry[slot_mask].detach().cpu())

    if not hidden_states:
        return torch.empty(0), torch.empty(0, dtype=torch.long)
    return torch.cat(hidden_states, dim=0), torch.cat(carry_targets, dim=0)


def fit_linear_probe(
    hidden_states: torch.Tensor,
    carry_targets: torch.Tensor,
    *,
    epochs: int,
    learning_rate: float,
) -> dict[str, float]:
    """Train a 2-class linear probe that predicts the carry from hidden states.

    Examples are shuffled with the global torch RNG and split 80/20 into
    train and held-out sets. The held-out set falls back to the training set
    when it would otherwise be empty. The probe is trained full-batch with
    AdamW and cross-entropy for ``epochs`` steps.

    Args:
        hidden_states: Feature rows, one per example (last dim = feature size).
        carry_targets: Class labels (0/1), one per row of ``hidden_states``.
        epochs: Number of full-batch optimizer steps.
        learning_rate: AdamW learning rate.

    Returns:
        ``{"probe_accuracy": accuracy}`` measured on the held-out split, or
        ``0.0`` when ``hidden_states`` is empty.
    """
    if hidden_states.numel() == 0:
        return {"probe_accuracy": 0.0}

    # The probe's parameters are float32 and CrossEntropyLoss needs integer
    # class indices; normalize dtypes up front (no-ops when already correct).
    hidden_states = hidden_states.float()
    carry_targets = carry_targets.long()

    permutation = torch.randperm(hidden_states.shape[0])
    hidden_states = hidden_states[permutation]
    carry_targets = carry_targets[permutation]

    split_index = max(1, int(0.8 * hidden_states.shape[0]))
    train_hidden = hidden_states[:split_index]
    train_targets = carry_targets[:split_index]
    test_hidden = hidden_states[split_index:]
    test_targets = carry_targets[split_index:]
    if test_hidden.numel() == 0:
        test_hidden = train_hidden
        test_targets = train_targets

    probe = nn.Linear(hidden_states.shape[-1], 2)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    # Evaluation code may run under torch.no_grad(); the probe still needs
    # gradients to train, so enable them explicitly for this loop.
    with torch.enable_grad():
        for _ in range(epochs):
            logits = probe(train_hidden)
            loss = loss_fn(logits, train_targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    with torch.no_grad():
        predictions = probe(test_hidden).argmax(dim=-1)
        accuracy = float((predictions == test_targets).float().mean().item())
    return {"probe_accuracy": accuracy}


def evaluate_suite(
    model: AdditionTransformer,
    config: ExperimentConfig,
    suite: EvaluationSuite,
    *,
    device: str,
) -> dict[str, dict]:
    """Run the full evaluation suite and return nested results.

    Returns a dict with these keys:

    * ``"validation_uniform"``: ``suite.validation_uniform`` evaluated at the
      in-distribution lengths ``1..config.train_max_digits``.
    * ``"test_uniform"`` / ``"test_carry_heavy"``: the corresponding test sets
      at the in-distribution lengths plus ``config.ood_lengths``. Attention is
      summarized at the longest OOD length, or at ``config.train_max_digits``
      when there are no OOD lengths.
    * ``"diagnostics"``: linear carry-probe accuracy on up to
      ``config.attention_probe_examples`` carry-heavy test problems at that
      same length, plus both attention summaries (empty dicts if none were
      produced).
    """
    id_lengths = list(range(1, config.train_max_digits + 1))
    ood_lengths = list(config.ood_lengths)
    test_lengths = sorted(set(id_lengths + ood_lengths))
    max_attention_length = max(ood_lengths) if ood_lengths else config.train_max_digits

    validation = evaluate_length_dict(
        model=model,
        config=config,
        problems_by_length={length: suite.validation_uniform[length] for length in id_lengths},
        device=device,
    )
    uniform_all = evaluate_length_dict(
        model=model,
        config=config,
        problems_by_length={length: suite.test_uniform[length] for length in test_lengths},
        device=device,
        attention_length=max_attention_length,
    )
    carry_heavy_all = evaluate_length_dict(
        model=model,
        config=config,
        problems_by_length={length: suite.test_carry_heavy[length] for length in test_lengths},
        device=device,
        attention_length=max_attention_length,
    )

    probe_hidden, probe_targets = collect_hidden_dataset(
        model=model,
        config=config,
        problems=suite.test_carry_heavy[max_attention_length],
        active_digits=max_attention_length,
        device=device,
        limit_examples=config.attention_probe_examples,
    )
    # Mixed-value dict: probe accuracy (float) plus two attention summaries (dicts).
    diagnostics: dict[str, object] = dict(
        fit_linear_probe(
            hidden_states=probe_hidden,
            carry_targets=probe_targets,
            epochs=config.linear_probe_epochs,
            learning_rate=config.linear_probe_lr,
        )
    )
    diagnostics["attention_uniform"] = uniform_all[str(max_attention_length)].get("attention_summary", {})
    diagnostics["attention_carry_heavy"] = carry_heavy_all[str(max_attention_length)].get("attention_summary", {})

    return {
        "validation_uniform": validation,
        "test_uniform": uniform_all,
        "test_carry_heavy": carry_heavy_all,
        "diagnostics": diagnostics,
    }


def stage_validation_metric(results: dict[str, dict], stage: int) -> float:
    """Return the uniform-validation digit accuracy at length ``stage`` from ``evaluate_suite`` results."""
    stage_metrics = results["validation_uniform"][str(stage)]
    return float(stage_metrics["digit_accuracy"])


def flatten_nested_metrics(prefix: str, nested: dict[str, dict]) -> dict[str, float]:
    """Flatten ``evaluate_suite``-style results into ``{"<prefix><path>": float}``.

    Per-length splits produce keys ``"{prefix}{split}/length_{n}/{metric}"``.
    Within them:

    * list-valued metrics are reduced to their mean under ``{metric}_mean``
      and skipped when empty;
    * dict-valued metrics are expanded one level as ``{metric}/{inner_key}``;
    * non-dict per-length entries are skipped.

    The ``"diagnostics"`` split has no length level: its entries become
    ``"{prefix}diagnostics/{key}"`` or ``"{prefix}diagnostics/{key}/{inner_key}"``.
    """
    flat: dict[str, float] = {}
    for split_name, split_metrics in nested.items():
        if split_name == "diagnostics":
            for key, value in split_metrics.items():
                if isinstance(value, dict):
                    for inner_key, inner_value in value.items():
                        flat[f"{prefix}{split_name}/{key}/{inner_key}"] = float(inner_value)
                else:
                    flat[f"{prefix}{split_name}/{key}"] = float(value)
            continue

        for length, length_metrics in split_metrics.items():
            if not isinstance(length_metrics, dict):
                continue
            base = f"{prefix}{split_name}/length_{length}"
            for metric_name, metric_value in length_metrics.items():
                if isinstance(metric_value, list):
                    if metric_value:
                        flat[f"{base}/{metric_name}_mean"] = float(sum(metric_value) / len(metric_value))
                elif isinstance(metric_value, dict):
                    for inner_key, inner_value in metric_value.items():
                        flat[f"{base}/{metric_name}/{inner_key}"] = float(inner_value)
                else:
                    flat[f"{base}/{metric_name}"] = float(metric_value)
    return flat