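# Source: Atom2.7m / benchmark_fusion_arithmark.py (Hugging Face repository file page)
# Uploaded by: ucr-max
# Commit: "Update Atom2.7m submission" (2fd4f23, verified)
# File size: 10.5 kB
# Page navigation links on the original page: Raw, History, Blame, Contribute, Delete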
"""Score a fusion GPT checkpoint on ArithMark 2.0."""
from __future__ import annotations
import argparse
from collections import Counter
from contextlib import nullcontext
import json
from pathlib import Path
import re
import urllib.request
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
# Direct-download location of the ArithMark 2.0 JSONL file; fetched by
# ensure_data() when no local copy exists.
DATA_URL = (
    "https://huggingface.co/datasets/AxiomicLabs/Arithmark-2.0"
    "/resolve/main/arithmark_2.0.jsonl"
)
def ensure_data(path: Path) -> Path:
    """Return *path*, downloading the benchmark file to it first if it is missing.

    The download is written to a sibling ``<name>.part`` file and only renamed
    onto *path* once the transfer has finished. An interrupted or failed
    download therefore never leaves a truncated file at *path* that a later
    run would mistake for the complete dataset.

    Args:
        path: Where the local copy of the dataset lives (or should be created).

    Returns:
        The same *path* object that was passed in.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or the final rename raises;
        errors are not swallowed.
    """
    if path.exists():
        return path
    path.parent.mkdir(parents=True, exist_ok=True)
    partial = path.with_name(path.name + ".part")
    try:
        urllib.request.urlretrieve(DATA_URL, partial)
        # Publish the file only after the whole payload has been written.
        partial.replace(path)
    finally:
        # No-op after a successful rename; removes the leftover after a failure.
        partial.unlink(missing_ok=True)
    return path
def load_examples(path: Path, *, max_examples: int = 0) -> list[dict]:
    """Read JSON-lines records from *path*.

    Lines that are empty or whitespace-only are ignored. Every other line is
    parsed with ``json.loads``.

    Args:
        path: UTF-8 encoded JSONL file to read.
        max_examples: Stop after this many records. Zero or a negative value
            means "read everything".

    Returns:
        The parsed records in file order.
    """
    limit = max_examples if max_examples > 0 else None
    records: list[dict] = []
    with path.open("r", encoding="utf-8") as stream:
        for raw_line in stream:
            if raw_line.strip():
                records.append(json.loads(raw_line))
                if limit is not None and len(records) >= limit:
                    break
    return records
def _encoded_choice(
tokenizer,
context: str,
ending: str,
) -> tuple[list[int], int]:
context_ids = tokenizer(context, add_special_tokens=False).input_ids
full_ids = tokenizer(context + ending, add_special_tokens=False).input_ids
continuation_length = len(full_ids) - len(context_ids)
return full_ids, continuation_length
def _summarize_groups(grouped: dict[str, list[int]]) -> dict[str, dict]:
    """Convert ``{key: [correct, total]}`` tallies into accuracy reports sorted by key."""
    return {
        key: {
            "accuracy": hits / max(count, 1),
            "correct": hits,
            "total": count,
        }
        for key, (hits, count) in sorted(grouped.items())
    }


def _describe_failure(
    example: dict,
    *,
    prediction: int,
    label: int,
    likelihoods: list[float],
    topic: str,
    operator_count: str,
) -> tuple[tuple[str, str, str], dict]:
    """Build the failure-summary key and the detailed record for one missed example."""
    context = str(example["ctx"]).strip()
    # Drop a trailing "=" so only the expression itself is parsed.
    expression = context[:-1].strip() if context.endswith("=") else context
    operands = [int(value) for value in re.findall(r"\d+", expression)]
    operator = "".join(re.findall(r"[+\-*/]", expression))
    choices = [str(value).strip() for value in example["endings"]]
    width = max((len(str(value)) for value in operands), default=0)
    record = {
        "ctx": context,
        "topic": topic,
        "operator_count": operator_count,
        "operator": operator,
        "operands": operands,
        "max_operand_digits": width,
        "correct_answer": choices[label],
        "predicted_answer": choices[prediction],
        "choices": choices,
        "choice_scores": [round(value, 4) for value in likelihoods],
        "score_margin_correct_minus_predicted": round(
            likelihoods[label] - likelihoods[prediction],
            4,
        ),
    }
    return (topic, operator, f"width={width}"), record


@torch.inference_mode()
def evaluate(
    model,
    tokenizer,
    examples: list[dict],
    *,
    device: torch.device,
    batch_size: int,
    dump_failures: bool = False,
    failure_operator_count: int | None = None,
    max_failures: int = 100,
) -> dict:
    """Score multiple-choice examples by summed continuation log-likelihood.

    Every ending of every example is appended to the example's context,
    tokenized, right-padded and run through *model* in batches. The choice
    whose ending tokens have the highest total log-probability is the
    prediction, and it is compared against the example's label.

    Args:
        model: Object with ``eval()`` that, when called with ``input_ids`` and
            ``attention_mask``, returns something exposing ``.logits``.
        tokenizer: Tokenizer passed to ``_encoded_choice``; its
            ``pad_token_id`` (falling back to ``eos_token_id``, then 0) is
            used as the padding value.
        examples: Records with ``"ctx"``, ``"endings"`` and ``"label"`` keys
            and an optional ``"metadata"`` dict providing ``"operator_count"``
            and ``"topic"`` (both default to ``"unknown"``).
        device: Device on which the input tensors are created. On CUDA the
            forward pass runs under bfloat16 autocast.
        batch_size: Number of examples (not choices) per forward pass.
        dump_failures: Also report misclassified examples and a grouped
            failure summary.
        failure_operator_count: When set, only failures whose operator count
            equals this value are reported.
        max_failures: Cap on detailed failure records; the summary counter is
            not capped.

    Returns:
        A dict with overall accuracy, per-operator-count and per-topic
        breakdowns and, if requested, ``"failure_summary"`` and ``"failures"``.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    correct = 0
    total = 0
    by_operator_count: dict[str, list[int]] = {}
    by_topic: dict[str, list[int]] = {}
    failures: list[dict] = []
    failure_summary: Counter[tuple[str, str, str]] = Counter()

    model.eval()
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    for start in range(0, len(examples), batch_size):
        batch_examples = examples[start : start + batch_size]

        # Flatten all choices of the batch into rows; remember each example's slice.
        encoded = []
        offsets = []
        for example in batch_examples:
            offsets.append((len(encoded), len(example["endings"])))
            encoded.extend(
                _encoded_choice(tokenizer, example["ctx"], ending)
                for ending in example["endings"]
            )

        max_length = max(len(ids) for ids, _ in encoded)
        input_ids = torch.full(
            (len(encoded), max_length),
            int(pad_id),
            dtype=torch.long,
            device=device,
        )
        attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for row, (ids, _) in enumerate(encoded):
            input_ids[row, : len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
            attention_mask[row, : len(ids)] = True

        autocast = (
            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            if device.type == "cuda"
            else nullcontext()
        )
        with autocast:
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits

        if max_length > 1:
            # Logits at position p - 1 predict the token at position p.
            log_probs = F.log_softmax(logits[:, :-1, :].float(), dim=-1)
            # Padded slots are never read below; point them at index 0 so the
            # gather stays in range whatever the pad id is.
            targets = input_ids[:, 1:].masked_fill(~attention_mask[:, 1:], 0)
            # One device-to-host transfer per batch instead of one per token.
            token_log_probs = (
                log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1).tolist()
            )
        else:
            token_log_probs = [[] for _ in encoded]

        for example, (flat_start, choice_count) in zip(batch_examples, offsets):
            likelihoods = []
            for row in range(flat_start, flat_start + choice_count):
                ids, continuation_length = encoded[row]
                length = len(ids)
                # Position 0 has no preceding logit. Indexing position - 1
                # there would wrap around to the last (possibly padded)
                # position, so the first token is left unscored instead.
                first_scored = max(length - continuation_length, 1)
                row_scores = token_log_probs[row]
                likelihood = 0.0
                for position in range(first_scored, length):
                    likelihood += row_scores[position - 1]
                likelihoods.append(likelihood)

            # Ties resolve to the lowest choice index.
            prediction = max(range(choice_count), key=likelihoods.__getitem__)
            label = int(example["label"])
            matched = prediction == label
            correct += int(matched)
            total += 1

            metadata = example.get("metadata", {})
            operator_count = str(metadata.get("operator_count", "unknown"))
            topic = str(metadata.get("topic", "unknown"))
            for grouped, key in (
                (by_operator_count, operator_count),
                (by_topic, topic),
            ):
                group = grouped.setdefault(key, [0, 0])
                group[0] += int(matched)
                group[1] += 1

            if matched or not dump_failures:
                continue
            try:
                op_count_int = int(operator_count)
            except ValueError:
                # Non-numeric counts (e.g. "unknown") never match a numeric filter.
                op_count_int = None
            if failure_operator_count is None or op_count_int == failure_operator_count:
                summary_key, record = _describe_failure(
                    example,
                    prediction=prediction,
                    label=label,
                    likelihoods=likelihoods,
                    topic=topic,
                    operator_count=operator_count,
                )
                failure_summary[summary_key] += 1
                if len(failures) < max_failures:
                    failures.append(record)

    results = {
        "benchmark": "arithmark_2.0",
        "model_type": "fusion_gpt",
        "accuracy": correct / max(total, 1),
        "correct": correct,
        "total": total,
        "by_operator_count": _summarize_groups(by_operator_count),
        "by_topic": _summarize_groups(by_topic),
    }
    if dump_failures:
        results["failure_summary"] = {
            "|".join(key): value
            for key, value in failure_summary.most_common()
        }
        results["failures"] = failures
    return results
def parse_args() -> argparse.Namespace:
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        Namespace with ``checkpoint``, ``data_path``, ``batch_size``,
        ``device``, ``dtype``, ``output``, ``max_examples``,
        ``dump_failures``, ``failure_operator_count`` and ``max_failures``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    # (flag, keyword arguments) pairs, registered in the order they appear in --help.
    option_table = (
        ("--checkpoint", {"type": Path, "default": Path("outputs/fusion_run/final_model")}),
        ("--data-path", {"type": Path, "default": Path("arithmark_2.0.jsonl")}),
        ("--batch-size", {"type": int, "default": 64}),
        ("--device", {"default": "auto"}),
        (
            "--dtype",
            {"default": "auto", "choices": ("auto", "float32", "bfloat16", "float16")},
        ),
        ("--output", {"type": Path}),
        (
            "--max-examples",
            {
                "type": int,
                "default": 0,
                "help": "Evaluate only the first N examples. Default evaluates all examples.",
            },
        ),
        (
            "--dump-failures",
            {
                "action": "store_true",
                "help": "Include incorrectly scored examples and grouped failure summary.",
            },
        ),
        (
            "--failure-operator-count",
            {
                "type": int,
                "default": None,
                "help": "Only dump failures with this operator count, e.g. 1 for easy examples.",
            },
        ),
        ("--max-failures", {"type": int, "default": 100}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def main() -> None:
    """Run the benchmark end to end from the command line.

    Parses the arguments, makes sure the dataset file exists, loads the
    examples, loads the model and tokenizer from the checkpoint, evaluates,
    prints the results as JSON and, when ``--output`` is given, writes the
    same JSON to that file.
    """
    args = parse_args()

    device_name = args.device
    if device_name == "auto":
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    examples = load_examples(
        ensure_data(args.data_path),
        max_examples=args.max_examples,
    )

    # "auto" has no entry here, so it resolves to None and is passed on as such.
    explicit_dtypes = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint,
        dtype=explicit_dtypes.get(args.dtype),
        trust_remote_code=True,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        # Reuse the end-of-sequence token for padding when none is configured.
        tokenizer.pad_token = tokenizer.eos_token

    results = evaluate(
        model,
        tokenizer,
        examples,
        device=device,
        batch_size=args.batch_size,
        dump_failures=args.dump_failures,
        failure_operator_count=args.failure_operator_count,
        max_failures=args.max_failures,
    )

    report = json.dumps(results, indent=2, sort_keys=True)
    print(report)
    if args.output:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(report + "\n", encoding="utf-8")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()