NextTerm-440M / oeis_eval_mlx_neo.py
N8Programs's picture
Bundle evaluation datasets with model card scripts
a46649b verified
Raw
History Blame Contribute Delete
11.6 kB
"""Evaluate NextTerm on OEIS Eval Neo with MLX-LM BatchGenerator."""
import argparse
import gc
import inspect
import json
import time
from pathlib import Path
import mlx.core as mx
from mlx_lm import load
from mlx_lm.generate import BatchGenerator
from tqdm import tqdm
# Directory containing this script; the default dataset path and the local
# model lookups below are resolved relative to it.
SCRIPT_DIR = Path(__file__).resolve().parent
def default_model_path() -> str:
    """Pick the default model location.

    Preference order: the script directory itself when it contains
    ``model.safetensors``, then a ``NextTerm-440M`` entry next to the script,
    and finally the ``N8Programs/NextTerm-440M`` identifier.
    """
    bundled_weights = SCRIPT_DIR / "model.safetensors"
    sibling_model = SCRIPT_DIR / "NextTerm-440M"
    if bundled_weights.exists():
        chosen = str(SCRIPT_DIR)
    elif sibling_model.exists():
        chosen = str(sibling_model)
    else:
        chosen = "N8Programs/NextTerm-440M"
    return chosen
# Default evaluation dataset, looked up next to this script (--data-path).
DATA_PATH = SCRIPT_DIR / "oeis_val_neo.jsonl"
# Default model location (--model); resolved once, at import time.
MODEL_NAME = default_model_path()
# Default per-prompt generation budget (--max-new-tokens).
MAX_NEW_TOKENS = 196
# Default prompt-length cutoff (--max-context-tokens): main() only evaluates
# prompts shorter than this; a value <= 0 disables the cutoff.
MAX_CONTEXT_TOKENS = 4096
# Default for --batch-size; used for both prefill and completion batches.
BATCH_SIZE = 64
# Default destinations for the per-row records and the aggregate summary.
OUTPUT_PATH = Path("oeis_eval_results/oeis_eval_mlx_neo_per_doc.jsonl")
SUMMARY_PATH = Path("oeis_eval_results/oeis_eval_mlx_neo_summary.json")
# Maximum number of unparseable generations that parse_generated() prints.
PARSE_ERROR_PRINT_LIMIT = 25
# Mutable counter of parse failures printed so far (updated by parse_generated()).
parse_error_print_count = 0
def parse_generated(text: str) -> int | None:
global parse_error_print_count
if "," in text:
text = text.split(",")[0]
try:
return int(text)
except ValueError:
if parse_error_print_count < PARSE_ERROR_PRINT_LIMIT:
print(f"Could not parse generated text: {text!r}")
parse_error_print_count += 1
return None
def load_sequences(path: Path):
    """Read a JSONL dataset and split each sequence into prompt terms and answer.

    Each non-blank line must be a JSON object; its ``"seq"`` list supplies the
    terms. Records whose ``"seq"`` is missing or has fewer than two terms are
    dropped, because they cannot provide both a prompt and an answer.

    Args:
        path: Location of the JSONL file.

    Returns:
        A ``(sequences, answers)`` pair of equal-length lists: ``sequences[i]``
        holds every term but the last, and ``answers[i]`` is the last term as a
        string.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    sequences = []
    answers = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # letting json.loads fail on them.
            if not line.strip():
                continue
            # parse_int=str keeps integer terms as their exact decimal text
            # rather than converting them to int objects.
            record = json.loads(line, parse_int=str)
            seq = record.get("seq", [])
            if len(seq) < 2:
                continue
            sequences.append(seq[:-1])
            answers.append(str(seq[-1]))
    return sequences, answers
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Defaults come from the module-level constants; ``--max-examples`` defaults
    to 0 and the three switches default to off.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--data-path", {"type": Path, "default": DATA_PATH}),
        ("--model", {"default": MODEL_NAME}),
        ("--max-new-tokens", {"type": int, "default": MAX_NEW_TOKENS}),
        ("--max-context-tokens", {"type": int, "default": MAX_CONTEXT_TOKENS}),
        ("--batch-size", {"type": int, "default": BATCH_SIZE}),
        ("--max-examples", {"type": int, "default": 0}),
        ("--output", {"type": Path, "default": OUTPUT_PATH}),
        ("--summary-output", {"type": Path, "default": SUMMARY_PATH}),
        ("--overwrite", {"action": "store_true"}),
        ("--restrict-digit-comma-eos", {"action": "store_true"}),
        ("--restrict-integer-comma-eos", {"action": "store_true"}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def load_completed(path: Path) -> dict[int, dict]:
    """Load previously written per-row records, keyed by ``row_index``.

    Returns an empty dict when ``path`` does not exist. Blank lines are
    skipped, and when a row index occurs more than once the later record wins.
    """
    if not path.exists():
        return {}
    with path.open("r", encoding="utf-8") as handle:
        # Lazy generator: lines are parsed one at a time, in file order.
        parsed_rows = (json.loads(raw) for raw in handle if raw.strip())
        return {int(row["row_index"]): row for row in parsed_rows}
def encode_no_special(tokenizer, text: str) -> list[int]:
    """Encode ``text`` while asking the tokenizer to omit special tokens.

    Falls back to a plain ``tokenizer.encode(text)`` call when the first
    attempt raises ``TypeError`` (for instance, an ``encode`` that does not
    accept ``add_special_tokens``).
    """
    try:
        token_ids = tokenizer.encode(text, add_special_tokens=False)
    except TypeError:
        token_ids = tokenizer.encode(text)
    return token_ids
def normalize_stop_tokens_for_batch_generator(stop_tokens: list[list[int]]):
    """Adapt ``stop_tokens`` to the form the installed BatchGenerator declares.

    The annotation of ``BatchGenerator.__init__``'s ``stop_tokens`` parameter
    is inspected. If its text mentions ``set``, a set of token ids is returned,
    built from the single-token stop sequences only (longer sequences are
    dropped). Otherwise the list of sequences is returned unchanged.
    """
    init_parameters = inspect.signature(BatchGenerator.__init__).parameters
    declared_type = str(init_parameters["stop_tokens"].annotation)
    if "set" not in declared_type:
        return stop_tokens
    return {ids[0] for ids in stop_tokens if len(ids) == 1}
def split_batch_generator_responses(responses):
    """Normalise the result of ``BatchGenerator.next()`` into two sequences.

    A 2-tuple is returned as ``(prompt_responses, generation_responses)``.
    A list is treated as generation responses only and paired with an empty
    prompt list. Anything else raises ``RuntimeError``.
    """
    is_pair = isinstance(responses, tuple) and len(responses) == 2
    if is_pair:
        return responses[0], responses[1]
    if not isinstance(responses, list):
        raise RuntimeError(
            "Unexpected mlx_lm BatchGenerator.next() API. Update this script for "
            f"{type(responses).__name__}: {responses!r}"
        )
    return [], responses
def make_integer_comma_eos_processor(tokenizer):
    """Create a logits processor that only permits integer-like output.

    The allowed token ids are those of the digits ``0``-``9``, ``-`` and ``,``
    (each only when it encodes to exactly one token; others are reported and
    skipped), plus the tokenizer's EOS id when it has one.

    Returns:
        ``(processor, allowed_ids)``: ``processor(tokens, logits)`` adds 0.0 to
        allowed positions and -1e9 to every other position along the last
        axis; ``allowed_ids`` is the sorted list of permitted token ids.
    """
    allowed = set()
    candidate_texts = [str(digit) for digit in range(10)] + ["-", ","]
    for piece in candidate_texts:
        ids = encode_no_special(tokenizer, piece)
        if len(ids) != 1:
            print(f"Skipping multi-token allowed text {piece!r}: {ids}")
            continue
        allowed.add(int(ids[0]))
    eos_id = tokenizer.eos_token_id
    if eos_id is not None:
        allowed.add(int(eos_id))
    allowed_ids = sorted(allowed)
    # One additive mask per vocabulary size seen, built on first use.
    mask_cache = {}

    def _build_mask(vocab_size, dtype):
        # Ids outside [0, vocab_size) never match an index and are ignored.
        values = [
            0.0 if index in allowed else -1e9 for index in range(vocab_size)
        ]
        return mx.array(values, dtype=dtype)

    def processor(_tokens, logits):
        vocab_size = logits.shape[-1]
        if vocab_size not in mask_cache:
            mask_cache[vocab_size] = _build_mask(vocab_size, logits.dtype)
        # The added leading axis broadcasts the mask over the logits rows
        # (assumes logits are 2-D, batch x vocab; confirm against mlx_lm).
        return logits + mask_cache[vocab_size][None, :]

    return processor, allowed_ids
def run_generation_queue(
    *,
    model,
    tokenizer,
    prompts,
    answers,
    row_indices,
    stop_tokens,
    max_new_tokens: int,
    batch_size: int,
    output_file,
    progress,
    logits_processors=None,
) -> None:
    """Generate one completion per prompt and write a JSON record for each.

    All prompts are inserted into a single ``BatchGenerator`` with the same
    token budget. As each prompt finishes, its tokens are decoded, parsed into
    a prediction, compared with the expected answer, and written to
    ``output_file`` as one JSON line; ``progress`` is advanced by one.

    Args:
        model: Model passed to ``BatchGenerator``.
        tokenizer: Used to decode the generated token ids.
        prompts: Tokenised prompts, positionally aligned with ``answers`` and
            ``row_indices``.
        answers: Expected answers as integer strings.
        row_indices: Value stored as ``row_index`` in each record.
        stop_tokens: Stop sequences, normalised for the installed generator.
        max_new_tokens: Per-prompt generation budget.
        batch_size: Used for both prefill and completion batch sizes.
        output_file: Writable text file receiving the JSONL records.
        progress: Object with an ``update(n)`` method (e.g. a tqdm bar).
        logits_processors: Optional processors forwarded to the generator.

    Raises:
        RuntimeError: If the generator stops before every prompt finished.
    """
    generator = BatchGenerator(
        model,
        stop_tokens=normalize_stop_tokens_for_batch_generator(stop_tokens),
        logits_processors=logits_processors,
        completion_batch_size=batch_size,
        prefill_batch_size=batch_size,
    )
    budgets = [max_new_tokens] * len(prompts)
    uids = generator.insert(prompts, budgets)
    position_of = {uid: index for index, uid in enumerate(uids)}
    tokens_by_uid = {uid: [] for uid in uids}
    done = set()
    try:
        while True:
            prompt_events, token_events = split_batch_generator_responses(
                generator.next()
            )
            # Nothing in either stream means the generator has been drained.
            if not (prompt_events or token_events):
                break
            # A step with prompt activity only yields no tokens to handle.
            for event in token_events or ():
                uid = event.uid
                reason = event.finish_reason
                # The token delivered with a "stop" finish is not recorded
                # (presumably it is the stop token itself).
                if reason != "stop":
                    tokens_by_uid[uid].append(int(event.token))
                still_running = reason is None
                if still_running or uid in done:
                    continue
                done.add(uid)
                row = position_of[uid]
                token_ids = tokens_by_uid[uid]
                decoded = tokenizer.decode(token_ids)
                predicted = parse_generated(decoded)
                expected = answers[row]
                expected_int = int(expected)
                record = {
                    "row_index": row_indices[row],
                    "answer": expected,
                    "prediction": predicted,
                    "correct": predicted == expected_int,
                    "parsed": predicted is not None,
                    "generated_text": decoded,
                    "generated_tokens": token_ids,
                    "finish_reason": reason,
                }
                output_file.write(json.dumps(record) + "\n")
                progress.update(1)
    finally:
        # Always release the generator and cached buffers, even on failure.
        generator.close()
        mx.clear_cache()
        gc.collect()
    if len(done) != len(uids):
        raise RuntimeError(f"Chunk finished {len(done)}/{len(uids)} rows")
def main():
    """Run the evaluation end to end and write per-row and summary results.

    Steps: load the dataset, tokenise each prompt (the known terms joined by
    commas, with a trailing comma), drop prompts that reach the context
    cutoff, generate for every row not already present in the output file,
    then aggregate accuracy over the evaluated rows and write a JSON summary.

    Accuracy is reported as 0.0 when no rows were evaluated, both on the
    console and in the summary.
    """
    args = parse_args()
    started = time.perf_counter()
    sequences, answers = load_sequences(args.data_path)
    if args.max_examples > 0:
        sequences = sequences[: args.max_examples]
        answers = answers[: args.max_examples]
    print(f"Loaded {len(answers)} sequences from {args.data_path}")
    model, tokenizer = load(args.model)
    sep_tokens = encode_no_special(tokenizer, ",")
    if not sep_tokens:
        # Encoding "," alone produced nothing: take the final token of "1,".
        sep_tokens = encode_no_special(tokenizer, "1,")[-1:]
    prompts = [",".join(str(x) for x in seq) + "," for seq in sequences]
    prompts = [tokenizer.encode(p) for p in prompts]
    # Keep only prompts strictly shorter than the cutoff (<= 0 disables it).
    eval_indices = [
        i
        for i, prompt in enumerate(prompts)
        if args.max_context_tokens <= 0 or len(prompt) < args.max_context_tokens
    ]
    skipped_long = len(prompts) - len(eval_indices)
    # Process shortest prompts first.
    eval_indices = sorted(eval_indices, key=lambda i: len(prompts[i]))
    print(
        f"Evaluating {len(eval_indices)} rows; skipped_long={skipped_long}; "
        f"sep_tokens={sep_tokens}; eos_token={tokenizer.eos_token_id}; "
        f"max_new_tokens={args.max_new_tokens}; batch_size={args.batch_size}"
    )
    # Generation stops at the separator or, when defined, at EOS.
    stop_tokens = [sep_tokens]
    if tokenizer.eos_token_id is not None:
        stop_tokens.append([tokenizer.eos_token_id])
    logits_processors = None
    allowed_token_ids = None
    # Both flags enable the same restriction.
    restrict_integer = args.restrict_digit_comma_eos or args.restrict_integer_comma_eos
    if restrict_integer:
        processor, allowed_token_ids = make_integer_comma_eos_processor(tokenizer)
        logits_processors = [processor]
        print(f"Restricting logits to integer/comma/EOS token ids: {allowed_token_ids}")
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.summary_output.parent.mkdir(parents=True, exist_ok=True)
    if args.overwrite and args.output.exists():
        args.output.unlink()
    # Rows already present in the output file are not generated again.
    completed = load_completed(args.output)
    if completed:
        print(f"Resuming from {args.output}: {len(completed)} rows already done")
    todo_indices = [idx for idx in eval_indices if idx not in completed]
    with args.output.open("a", encoding="utf-8") as output_file:
        with tqdm(total=len(todo_indices), desc="Generating") as progress:
            run_generation_queue(
                model=model,
                tokenizer=tokenizer,
                prompts=[prompts[i] for i in todo_indices],
                answers=[answers[i] for i in todo_indices],
                row_indices=todo_indices,
                stop_tokens=stop_tokens,
                max_new_tokens=args.max_new_tokens,
                batch_size=args.batch_size,
                output_file=output_file,
                progress=progress,
                logits_processors=logits_processors,
            )
        output_file.flush()
    # Re-read the file so resumed and freshly generated rows are scored
    # together, restricted to rows that are part of this evaluation.
    records = load_completed(args.output)
    eval_set = set(eval_indices)
    records = {idx: record for idx, record in records.items() if idx in eval_set}
    correct = sum(1 for record in records.values() if record["correct"])
    parsed = sum(1 for record in records.values() if record["parsed"])
    total = len(eval_indices)
    evaluated = len(records)
    # Computed once and guarded: with nothing evaluated (empty dataset, or
    # every prompt skipped) a bare division would raise ZeroDivisionError.
    accuracy = correct / evaluated if evaluated else 0.0
    elapsed = time.perf_counter() - started
    print(f"Documents: {len(answers)}")
    print(f"Evaluated: {evaluated}/{total}")
    print(f"Skipped long: {skipped_long}")
    print(f"Parsed predictions: {parsed}/{evaluated}")
    print(f"Accuracy: {correct}/{evaluated} = {accuracy:.4f}")
    summary = {
        "data_path": str(args.data_path),
        "model": args.model,
        "output": str(args.output),
        "documents": len(answers),
        "evaluated": evaluated,
        "expected_evaluated": total,
        "skipped_long": skipped_long,
        "parsed": parsed,
        "correct": correct,
        "accuracy": accuracy,
        "max_new_tokens": args.max_new_tokens,
        "max_context_tokens": args.max_context_tokens,
        "batch_size": args.batch_size,
        "restrict_digit_comma_eos": args.restrict_digit_comma_eos,
        "restrict_integer_comma_eos": restrict_integer,
        "allowed_token_ids": allowed_token_ids,
        "seconds": elapsed,
    }
    args.summary_output.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
    print(f"Wrote {args.output}")
    print(f"Wrote {args.summary_output}")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()