"""Score a fusion GPT checkpoint on ArithMark 2.0."""

from __future__ import annotations

import argparse
from collections import Counter
from contextlib import nullcontext
import json
from pathlib import Path
import re
import urllib.request

import torch
import torch.nn.functional as F

# How scoring works: every example supplies a context ("ctx") and candidate
# "endings"; each ending is scored by the summed log-probability the model
# assigns to its tokens after the context, and the highest-scoring ending is
# the prediction, compared against the integer "label".
#
# ``transformers`` is imported inside ``main`` rather than here: only checkpoint
# loading needs it, so ``--help`` stays fast and the scoring helpers can be
# imported on their own.

DATA_URL = (
    "https://huggingface.co/datasets/AxiomicLabs/Arithmark-2.0/"
    "resolve/main/arithmark_2.0.jsonl"
)

# Used only when dumping failures, to pull operands and operators out of the
# arithmetic expression in a failed example's context.
_NUMBER_PATTERN = re.compile(r"\d+")
_OPERATOR_PATTERN = re.compile(r"[+\-*/]")


def ensure_data(path: Path) -> Path:
    """Return ``path``, downloading the benchmark JSONL from DATA_URL if missing.

    The download goes to a sibling ``<name>.part`` file that is renamed onto
    ``path`` only once it has completed, so an interrupted download never
    leaves a truncated file that a later run would accept as the dataset.
    """
    if path.exists():
        return path
    path.parent.mkdir(parents=True, exist_ok=True)
    partial = path.with_name(path.name + ".part")
    try:
        urllib.request.urlretrieve(DATA_URL, partial)
        partial.replace(path)
    finally:
        # Only still present if the download or rename failed.
        partial.unlink(missing_ok=True)
    return path


def load_examples(path: Path, *, max_examples: int = 0) -> list[dict]:
    """Parse one JSON object per non-blank line of ``path``.

    Args:
        path: UTF-8 JSONL file.
        max_examples: Stop after this many examples; 0 (or any value <= 0)
            reads the whole file.
    """
    examples = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue
            examples.append(json.loads(line))
            if max_examples > 0 and len(examples) >= max_examples:
                break
    return examples


def _encoded_choice(
    tokenizer,
    context: str,
    ending: str,
) -> tuple[list[int], int]:
    """Tokenize ``context + ending`` and decide how many trailing tokens to score.

    Returns ``(full_ids, continuation_length)``: the last
    ``continuation_length`` tokens of ``full_ids`` are the ones whose
    log-probabilities are summed for this choice.

    The length is ``len(full_ids) - len(context_ids)``, clamped to what can
    actually be scored:

    * at least 1 -- when the tokenizer merges the end of the context with the
      start of the ending the difference can be 0 or negative, and scoring no
      tokens would hand this choice a free log-likelihood of 0.0, beating
      every properly scored choice;
    * at most ``len(full_ids) - 1`` -- token 0 has no earlier position that
      predicts it; scoring it would read position -1, i.e. the last column of
      the padded batch.

    The result is 0 only when ``full_ids`` has at most one token.
    """
    context_ids = tokenizer(context, add_special_tokens=False).input_ids
    full_ids = tokenizer(context + ending, add_special_tokens=False).input_ids
    continuation_length = max(len(full_ids) - len(context_ids), 1)
    continuation_length = max(min(continuation_length, len(full_ids) - 1), 0)
    return full_ids, continuation_length


def _score_choices(
    model,
    encoded: list[tuple[list[int], int]],
    *,
    device: torch.device,
    pad_id: int,
    autocast_dtype: torch.dtype | None,
) -> list[float]:
    """Return the summed continuation log-likelihood of each encoded choice.

    All choices are right-padded with ``pad_id`` into a single
    ``(len(encoded), max_length)`` batch and scored in one forward pass.
    """
    max_length = max(len(ids) for ids, _ in encoded)
    # Assemble on the host and copy once instead of one small copy per row.
    input_ids = torch.full((len(encoded), max_length), int(pad_id), dtype=torch.long)
    attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for row, (ids, _) in enumerate(encoded):
        input_ids[row, : len(ids)] = torch.tensor(ids, dtype=torch.long)
        attention_mask[row, : len(ids)] = True
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    if device.type == "cuda" and autocast_dtype is not None:
        autocast = torch.autocast(device_type="cuda", dtype=autocast_dtype)
    else:
        autocast = nullcontext()
    with autocast:
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).logits
    log_probs = F.log_softmax(logits.float(), dim=-1)

    # Position p - 1 predicts token p: pair every predicting position with the
    # token after it, gather all of them at once and transfer a single time
    # (per-token ``.item()`` calls force a device sync each).
    # Row layout afterwards: token p's log-prob lives at index p - 1.
    token_log_probs = (
        log_probs[:, :-1]
        .gather(-1, input_ids[:, 1:].unsqueeze(-1))
        .squeeze(-1)
        .tolist()
    )

    scores = []
    for (ids, continuation_length), row_log_probs in zip(encoded, token_log_probs):
        length = len(ids)
        continuation_start = length - continuation_length
        # Left-to-right Python-float sum, starting from 0.0.
        scores.append(sum(row_log_probs[continuation_start - 1 : length - 1], 0.0))
    return scores


def _describe_failure(
    example: dict,
    *,
    topic: str,
    operator_count: str,
    prediction: int,
    label: int,
    likelihoods: list[float],
) -> dict:
    """Build the JSON-serializable record for one incorrectly scored example."""
    context = str(example["ctx"]).strip()
    # Drop a trailing "=" so only the expression itself is dissected.
    expression = context[:-1].strip() if context.endswith("=") else context
    # NOTE: \d+ captures digit runs only, so signs and decimal points of
    # operands are not reflected in "operands" / "max_operand_digits".
    operands = [int(value) for value in _NUMBER_PATTERN.findall(expression)]
    width = max((len(str(value)) for value in operands), default=0)
    return {
        "ctx": context,
        "topic": topic,
        "operator_count": operator_count,
        "operator": "".join(_OPERATOR_PATTERN.findall(expression)),
        "operands": operands,
        "max_operand_digits": width,
        "correct_answer": str(example["endings"][label]).strip(),
        "predicted_answer": str(example["endings"][prediction]).strip(),
        "choices": [str(value).strip() for value in example["endings"]],
        "choice_scores": [round(value, 4) for value in likelihoods],
        "score_margin_correct_minus_predicted": round(
            likelihoods[label] - likelihoods[prediction],
            4,
        ),
    }


def _accuracy_by_group(groups: dict[str, list[int]]) -> dict:
    """Turn ``{key: [correct, total]}`` into sorted per-key accuracy entries."""
    return {
        key: {
            "accuracy": correct / max(total, 1),
            "correct": correct,
            "total": total,
        }
        for key, (correct, total) in sorted(groups.items())
    }


@torch.inference_mode()
def evaluate(
    model,
    tokenizer,
    examples: list[dict],
    *,
    device: torch.device,
    batch_size: int,
    dump_failures: bool = False,
    failure_operator_count: int | None = None,
    max_failures: int = 100,
    autocast_dtype: torch.dtype | None = torch.bfloat16,
) -> dict:
    """Score ``examples``; report accuracy overall, per operator count and topic.

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=...)``
            and returning an object with ``.logits``.
        tokenizer: Callable returning an object with ``.input_ids``. Batches
            are padded with its ``pad_token_id``, falling back to
            ``eos_token_id`` and then 0.
        examples: Dicts with ``ctx``, ``endings``, ``label`` and an optional
            ``metadata`` dict (``operator_count``, ``topic``).
        device: Device the batches are placed on.
        batch_size: Number of *examples* per forward pass. Each example adds
            one row per ending, so the model sees
            ``batch_size * len(endings)`` rows. Must be >= 1.
        dump_failures: Also report incorrectly scored examples and a failure
            count grouped by (topic, operator string, widest operand).
        failure_operator_count: If set, only dump failures whose
            ``operator_count`` equals this value.
        max_failures: Cap on detailed failure records; the grouped summary
            still counts every dumped failure.
        autocast_dtype: dtype for CUDA autocast during the forward pass;
            None disables autocast. Has no effect on non-CUDA devices.

    Returns:
        JSON-serializable results; ``failure_summary`` and ``failures`` are
        present only when ``dump_failures`` is set.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    correct = 0
    total = 0
    by_operator_count: dict[str, list[int]] = {}
    by_topic: dict[str, list[int]] = {}
    failures: list[dict] = []
    failure_summary: Counter[tuple[str, str, str]] = Counter()

    model.eval()
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    for start in range(0, len(examples), batch_size):
        batch_examples = examples[start : start + batch_size]
        # Flatten every (example, ending) pair into one row; ``offsets`` maps
        # each example back to its (first row, number of rows).
        encoded = []
        offsets = []
        for example in batch_examples:
            offsets.append((len(encoded), len(example["endings"])))
            encoded.extend(
                _encoded_choice(tokenizer, example["ctx"], ending)
                for ending in example["endings"]
            )
        scores = _score_choices(
            model,
            encoded,
            device=device,
            pad_id=pad_id,
            autocast_dtype=autocast_dtype,
        )

        for example, (flat_start, choice_count) in zip(batch_examples, offsets):
            likelihoods = scores[flat_start : flat_start + choice_count]
            # Ties resolve to the lowest choice index.
            prediction = max(range(choice_count), key=likelihoods.__getitem__)
            label = int(example["label"])
            matched = prediction == label
            correct += int(matched)
            total += 1

            metadata = example.get("metadata", {})
            operator_count = str(metadata.get("operator_count", "unknown"))
            topic = str(metadata.get("topic", "unknown"))
            for grouped, key in (
                (by_operator_count, operator_count),
                (by_topic, topic),
            ):
                group = grouped.setdefault(key, [0, 0])
                group[0] += int(matched)
                group[1] += 1

            if matched or not dump_failures:
                continue
            try:
                op_count_int = int(operator_count)
            except ValueError:
                # e.g. "unknown": never equals a requested operator count.
                op_count_int = None
            if failure_operator_count is not None and op_count_int != failure_operator_count:
                continue
            record = _describe_failure(
                example,
                topic=topic,
                operator_count=operator_count,
                prediction=prediction,
                label=label,
                likelihoods=likelihoods,
            )
            failure_summary[
                (topic, record["operator"], f"width={record['max_operand_digits']}")
            ] += 1
            if len(failures) < max_failures:
                failures.append(record)

    results = {
        "benchmark": "arithmark_2.0",
        "model_type": "fusion_gpt",
        "accuracy": correct / max(total, 1),
        "correct": correct,
        "total": total,
        "by_operator_count": _accuracy_by_group(by_operator_count),
        "by_topic": _accuracy_by_group(by_topic),
    }
    if dump_failures:
        results["failure_summary"] = {
            "|".join(key): value for key, value in failure_summary.most_common()
        }
        results["failures"] = failures
    return results


def parse_args() -> argparse.Namespace:
    """Parse command-line options; exits with a usage error on a bad batch size."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--checkpoint", type=Path, default=Path("outputs/fusion_run/final_model"))
    parser.add_argument("--data-path", type=Path, default=Path("arithmark_2.0.jsonl"))
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--device", default="auto")
    parser.add_argument("--dtype", default="auto", choices=("auto", "float32", "bfloat16", "float16"))
    parser.add_argument("--output", type=Path)
    parser.add_argument(
        "--max-examples",
        type=int,
        default=0,
        help="Evaluate only the first N examples. Default evaluates all examples.",
    )
    parser.add_argument(
        "--dump-failures",
        action="store_true",
        help="Include incorrectly scored examples and grouped failure summary.",
    )
    parser.add_argument(
        "--failure-operator-count",
        type=int,
        default=None,
        help="Only dump failures with this operator count, e.g. 1 for easy examples.",
    )
    parser.add_argument("--max-failures", type=int, default=100)
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    return args


def main() -> None:
    """Load the checkpoint and dataset, score it, print and optionally save JSON."""
    args = parse_args()
    # Deferred until after argument parsing; see the module-level note.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    data_path = ensure_data(args.data_path)
    examples = load_examples(data_path, max_examples=args.max_examples)

    # "auto" maps to None, which is passed through to from_pretrained as-is.
    dtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }.get(args.dtype)
    # Honour an explicit --dtype in the forward pass too: float32 runs without
    # autocast and float16 autocasts to float16; "auto" and "bfloat16" keep
    # bfloat16 autocast (applied on CUDA only).
    autocast_dtype = {
        "float32": None,
        "float16": torch.float16,
    }.get(args.dtype, torch.bfloat16)

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint,
        dtype=dtype,
        trust_remote_code=True,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    results = evaluate(
        model,
        tokenizer,
        examples,
        device=device,
        batch_size=args.batch_size,
        dump_failures=args.dump_failures,
        failure_operator_count=args.failure_operator_count,
        max_failures=args.max_failures,
        autocast_dtype=autocast_dtype,
    )
    rendered = json.dumps(results, indent=2, sort_keys=True)
    print(rendered)
    if args.output:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(rendered + "\n", encoding="utf-8")


if __name__ == "__main__":
    main()