| """Score a fusion GPT checkpoint on ArithMark 2.0.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from collections import Counter |
| from contextlib import nullcontext |
| import json |
| from pathlib import Path |
| import re |
| import urllib.request |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
# Direct-download URL of the ArithMark 2.0 JSONL file; fetched by
# ``ensure_data`` when no local copy exists.
DATA_URL = (
    "https://huggingface.co/datasets/"
    "AxiomicLabs/Arithmark-2.0/resolve/main/"
    "arithmark_2.0.jsonl"
)
|
|
|
|
def ensure_data(path: Path) -> Path:
    """Return ``path``, downloading the benchmark file first if it is missing.

    An existing ``path`` is returned untouched. Otherwise the parent directory
    is created and ``DATA_URL`` is fetched.

    The download is written to a sibling ``<name>.part`` file and only renamed
    onto ``path`` once it has completed. Writing directly to ``path`` would
    leave a truncated file behind after an interrupted download, and the
    ``exists()`` check above would then accept that file on every later run.

    Args:
        path: Location where the JSONL benchmark file is expected.

    Returns:
        The same ``path`` that was passed in.

    Raises:
        OSError: If the download or the final rename fails
            (``urllib.error.URLError`` is a subclass of ``OSError``).
    """
    if path.exists():
        return path

    path.parent.mkdir(parents=True, exist_ok=True)
    partial = path.with_name(path.name + ".part")
    try:
        urllib.request.urlretrieve(DATA_URL, partial)
        # Same-directory rename, so the final file appears all at once.
        partial.replace(path)
    finally:
        # No-op after a successful rename; removes the leftover on failure.
        partial.unlink(missing_ok=True)
    return path
|
|
|
|
def load_examples(path: Path, *, max_examples: int = 0) -> list[dict]:
    """Parse a JSONL file into a list of example objects.

    Lines that are empty or whitespace-only are skipped. Reading stops as soon
    as ``max_examples`` records have been collected, so later lines are never
    parsed.

    Args:
        path: JSONL file to read (decoded as UTF-8).
        max_examples: Upper bound on the number of records returned. Zero or a
            negative value means "no limit".

    Returns:
        The decoded JSON value of every retained line, in file order.
    """
    limit = max_examples if max_examples > 0 else None
    records: list[dict] = []
    with path.open("r", encoding="utf-8") as stream:
        for raw_line in stream:
            if raw_line.strip():
                records.append(json.loads(raw_line))
                if limit is not None and len(records) >= limit:
                    break
    return records
|
|
|
|
| def _encoded_choice( |
| tokenizer, |
| context: str, |
| ending: str, |
| ) -> tuple[list[int], int]: |
| context_ids = tokenizer(context, add_special_tokens=False).input_ids |
| full_ids = tokenizer(context + ending, add_special_tokens=False).input_ids |
| continuation_length = len(full_ids) - len(context_ids) |
| return full_ids, continuation_length |
|
|
|
|
def _tally_report(tallies: dict[str, list[int]]) -> dict[str, dict]:
    """Convert ``{key: [correct, total]}`` tallies into key-sorted accuracy reports."""
    return {
        key: {
            "accuracy": hits / max(seen, 1),
            "correct": hits,
            "total": seen,
        }
        for key, (hits, seen) in sorted(tallies.items())
    }


def _score_choices(model, encoded, *, pad_id: int, device: torch.device) -> list[float]:
    """Return the summed continuation log-probability of every encoded choice.

    ``encoded`` is a list of ``(token_ids, continuation_length)`` pairs. The
    result has one float per pair, in the same order.
    """
    if not encoded:
        return []
    lengths = [len(ids) for ids, _ in encoded]
    max_length = max(lengths)
    if max_length < 2:
        # A token can only be scored from the logits of a preceding position;
        # with at most one token per row there is nothing to score.
        return [0.0] * len(encoded)

    # Assemble the padded batch on the CPU and move it to the device once,
    # instead of one host-to-device copy per row.
    input_ids = torch.full((len(encoded), max_length), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(encoded), max_length), dtype=torch.bool)
    for row, (ids, _) in enumerate(encoded):
        input_ids[row, : len(ids)] = torch.tensor(ids, dtype=torch.long)
        attention_mask[row, : len(ids)] = True
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    autocast = (
        torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        if device.type == "cuda"
        else nullcontext()
    )
    with autocast:
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).logits

    # logits[:, p] is the distribution over the token at position p + 1, so the
    # log-probability of token p lives at index p - 1 of the shifted tensors.
    log_probs = F.log_softmax(logits[:, :-1].float(), dim=-1)
    token_log_probs = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

    positions = torch.arange(1, max_length, device=device).unsqueeze(0)
    ends = torch.tensor(lengths, device=device).unsqueeze(1)
    # Position 0 has no preceding logits. Clamping the start to 1 avoids the
    # ``position - 1 == -1`` wrap-around that would read the last position.
    starts = torch.tensor(
        [
            max(length - continuation_length, 1)
            for length, (_, continuation_length) in zip(lengths, encoded)
        ],
        device=device,
    ).unsqueeze(1)
    scored = (positions >= starts) & (positions < ends)

    # masked_fill (rather than multiplying by the mask) keeps -inf entries at
    # unscored positions from turning the sum into NaN. Summing in float64 on
    # the CPU gives one device sync per batch.
    return token_log_probs.masked_fill(~scored, 0.0).cpu().double().sum(dim=1).tolist()


def _failure_record(
    example: dict,
    *,
    topic: str,
    operator_count: str,
    prediction: int,
    label: int,
    likelihoods: list[float],
) -> dict:
    """Build the diagnostic entry describing one misclassified example."""
    context = str(example["ctx"]).strip()
    # Drop a trailing "=" so only the expression itself is parsed.
    expression = context[:-1].strip() if context.endswith("=") else context
    operands = [int(value) for value in re.findall(r"\d+", expression)]
    choices = [str(value).strip() for value in example["endings"]]
    return {
        "ctx": context,
        "topic": topic,
        "operator_count": operator_count,
        "operator": "".join(re.findall(r"[+\-*/]", expression)),
        "operands": operands,
        "max_operand_digits": max((len(str(value)) for value in operands), default=0),
        "correct_answer": choices[label],
        "predicted_answer": choices[prediction],
        "choices": choices,
        "choice_scores": [round(value, 4) for value in likelihoods],
        "score_margin_correct_minus_predicted": round(
            likelihoods[label] - likelihoods[prediction],
            4,
        ),
    }


@torch.inference_mode()
def evaluate(
    model,
    tokenizer,
    examples: list[dict],
    *,
    device: torch.device,
    batch_size: int,
    dump_failures: bool = False,
    failure_operator_count: int | None = None,
    max_failures: int = 100,
) -> dict:
    """Score multiple-choice examples by continuation log-likelihood.

    For every example, each entry of ``example["endings"]`` is appended to
    ``example["ctx"]`` and the log-probabilities of the ending's tokens are
    summed. The ending with the highest sum is the prediction, which is
    compared with ``int(example["label"])``.

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            tensors and returning an object with a ``logits`` attribute; must
            also provide ``eval()``.
        tokenizer: Tokenizer passed to ``_encoded_choice``. Its
            ``pad_token_id`` (falling back to ``eos_token_id``, then 0) is
            used for padding.
        examples: Records with ``ctx``, ``endings`` and ``label`` keys, plus
            optional ``metadata`` holding ``operator_count`` and ``topic``.
        device: Device on which the batches are evaluated. On CUDA the forward
            pass runs under bfloat16 autocast.
        batch_size: Number of examples (not choices) per forward pass.
        dump_failures: When true, add ``failure_summary`` and ``failures`` to
            the result.
        failure_operator_count: If given, only failures whose operator count
            parses to this integer are recorded.
        max_failures: Cap on the number of detailed failure records. The
            summary counts are not capped.

    Returns:
        A dict with overall ``accuracy``/``correct``/``total``, the
        ``by_operator_count`` and ``by_topic`` breakdowns, and the optional
        failure information.

    Raises:
        ValueError: If ``batch_size`` is not positive, or an example has no
            endings.
    """
    if batch_size < 1:
        # A negative step would make ``range`` empty and silently report an
        # accuracy over zero examples.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    correct = 0
    total = 0
    by_operator_count: dict[str, list[int]] = {}
    by_topic: dict[str, list[int]] = {}
    failures: list[dict] = []
    failure_summary: Counter[tuple[str, str, str]] = Counter()
    model.eval()
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    for start in range(0, len(examples), batch_size):
        batch_examples = examples[start : start + batch_size]

        # Flatten every (example, ending) pair into one batch row and remember
        # which slice of rows belongs to which example.
        encoded = []
        offsets = []
        for example in batch_examples:
            endings = example["endings"]
            offsets.append((len(encoded), len(endings)))
            encoded.extend(
                _encoded_choice(tokenizer, example["ctx"], ending)
                for ending in endings
            )

        scores = _score_choices(model, encoded, pad_id=int(pad_id), device=device)

        for example, (flat_start, choice_count) in zip(batch_examples, offsets):
            likelihoods = scores[flat_start : flat_start + choice_count]
            prediction = max(range(choice_count), key=likelihoods.__getitem__)
            label = int(example["label"])
            matched = prediction == label
            correct += int(matched)
            total += 1

            # ``or {}`` also covers an explicit ``"metadata": null``.
            metadata = example.get("metadata") or {}
            operator_count = str(metadata.get("operator_count", "unknown"))
            topic = str(metadata.get("topic", "unknown"))
            for grouped, key in (
                (by_operator_count, operator_count),
                (by_topic, topic),
            ):
                group = grouped.setdefault(key, [0, 0])
                group[0] += int(matched)
                group[1] += 1

            if matched or not dump_failures:
                continue
            if failure_operator_count is not None:
                try:
                    parsed_operator_count = int(operator_count)
                except ValueError:
                    # Non-numeric counts such as "unknown" never match the filter.
                    continue
                if parsed_operator_count != failure_operator_count:
                    continue

            record = _failure_record(
                example,
                topic=topic,
                operator_count=operator_count,
                prediction=prediction,
                label=label,
                likelihoods=likelihoods,
            )
            width = record["max_operand_digits"]
            failure_summary[(topic, record["operator"], f"width={width}")] += 1
            if len(failures) < max_failures:
                failures.append(record)

    results = {
        "benchmark": "arithmark_2.0",
        "model_type": "fusion_gpt",
        "accuracy": correct / max(total, 1),
        "correct": correct,
        "total": total,
        "by_operator_count": _tally_report(by_operator_count),
        "by_topic": _tally_report(by_topic),
    }
    if dump_failures:
        results["failure_summary"] = {
            "|".join(key): value
            for key, value in failure_summary.most_common()
        }
        results["failures"] = failures
    return results
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        The parsed namespace with ``checkpoint``, ``data_path``,
        ``batch_size``, ``device``, ``dtype``, ``output``, ``max_examples``,
        ``dump_failures``, ``failure_operator_count`` and ``max_failures``.
    """
    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output keeps the same layout.
    option_table = (
        ("--checkpoint", {"type": Path, "default": Path("outputs/fusion_run/final_model")}),
        ("--data-path", {"type": Path, "default": Path("arithmark_2.0.jsonl")}),
        ("--batch-size", {"type": int, "default": 64}),
        ("--device", {"default": "auto"}),
        (
            "--dtype",
            {"default": "auto", "choices": ("auto", "float32", "bfloat16", "float16")},
        ),
        ("--output", {"type": Path}),
        (
            "--max-examples",
            {
                "type": int,
                "default": 0,
                "help": "Evaluate only the first N examples. Default evaluates all examples.",
            },
        ),
        (
            "--dump-failures",
            {
                "action": "store_true",
                "help": "Include incorrectly scored examples and grouped failure summary.",
            },
        ),
        (
            "--failure-operator-count",
            {
                "type": int,
                "default": None,
                "help": "Only dump failures with this operator count, e.g. 1 for easy examples.",
            },
        ),
        ("--max-failures", {"type": int, "default": 100}),
    )

    cli = argparse.ArgumentParser(description=__doc__)
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
def main() -> None:
    """Run the benchmark end to end from the command line.

    Resolves the device, makes sure the dataset is present, loads the model
    and tokenizer from ``--checkpoint``, evaluates, prints the JSON report to
    stdout and, when ``--output`` is given, also writes it to that file.
    """
    args = parse_args()

    device_name = args.device
    if device_name == "auto":
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    examples = load_examples(
        ensure_data(args.data_path),
        max_examples=args.max_examples,
    )

    # "auto" is deliberately absent: it maps to ``None``, leaving the dtype
    # choice to ``from_pretrained``.
    dtype_by_name = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }
    # NOTE(review): trust_remote_code=True runs Python code shipped with the
    # checkpoint; only point --checkpoint at sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint,
        dtype=dtype_by_name.get(args.dtype),
        trust_remote_code=True,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    results = evaluate(
        model,
        tokenizer,
        examples,
        device=device,
        batch_size=args.batch_size,
        dump_failures=args.dump_failures,
        failure_operator_count=args.failure_operator_count,
        max_failures=args.max_failures,
    )

    report = json.dumps(results, indent=2, sort_keys=True)
    print(report)
    if args.output:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(report + "\n", encoding="utf-8")
|
|
|
|
# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger it.
if __name__ == "__main__":
    main()
|
|