"""Evaluate NextTerm on OEIS Eval Neo with MLX-LM BatchGenerator."""

import argparse
import gc
import inspect
import json
import time
from pathlib import Path

import mlx.core as mx
from mlx_lm import load
from mlx_lm.generate import BatchGenerator
from tqdm import tqdm

SCRIPT_DIR = Path(__file__).resolve().parent


def default_model_path() -> str:
    """Return the model location used when ``--model`` is not given.

    Preference order: a ``model.safetensors`` next to this script (use the
    script directory itself), then a local ``NextTerm-440M`` directory beside
    the script, then the string ``"N8Programs/NextTerm-440M"`` (presumably a
    Hugging Face repo id resolved by ``mlx_lm.load``).
    """
    if (SCRIPT_DIR / "model.safetensors").exists():
        return str(SCRIPT_DIR)
    local_model = SCRIPT_DIR / "NextTerm-440M"
    if local_model.exists():
        return str(local_model)
    return "N8Programs/NextTerm-440M"


DATA_PATH = SCRIPT_DIR / "oeis_val_neo.jsonl"
MODEL_NAME = default_model_path()
MAX_NEW_TOKENS = 196
MAX_CONTEXT_TOKENS = 4096
BATCH_SIZE = 64
OUTPUT_PATH = Path("oeis_eval_results/oeis_eval_mlx_neo_per_doc.jsonl")
SUMMARY_PATH = Path("oeis_eval_results/oeis_eval_mlx_neo_summary.json")
PARSE_ERROR_PRINT_LIMIT = 25

# Number of parse failures printed so far; capped at PARSE_ERROR_PRINT_LIMIT so
# a badly-behaved model cannot flood the console.
parse_error_print_count = 0


def parse_generated(text: str) -> int | None:
    """Parse the first comma-delimited field of a completion as an integer.

    Args:
        text: Decoded model output.

    Returns:
        The integer before the first comma (or the whole text if there is no
        comma), or ``None`` if it is not a valid ``int`` literal. The first
        ``PARSE_ERROR_PRINT_LIMIT`` failures are printed.
    """
    global parse_error_print_count
    if "," in text:
        text = text.split(",")[0]
    try:
        return int(text)
    except ValueError:
        if parse_error_print_count < PARSE_ERROR_PRINT_LIMIT:
            print(f"Could not parse generated text: {text!r}")
            parse_error_print_count += 1
        return None


def load_sequences(path: Path):
    """Load prompt terms and held-out answers from a JSONL dataset.

    Each non-blank line is a JSON object whose ``seq`` list is split into all
    terms but the last (the prompt) and the last term (the answer). Integers
    are decoded as strings (``parse_int=str``). Records with fewer than two
    terms are skipped.

    Returns:
        ``(sequences, answers)``: parallel lists of prompt-term lists and
        answer strings.
    """
    sequences = []
    answers = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline), like load_completed.
            if not line.strip():
                continue
            record = json.loads(line, parse_int=str)
            seq = record.get("seq", [])
            if len(seq) < 2:
                continue
            sequences.append(seq[:-1])
            answers.append(str(seq[-1]))
    return sequences, answers


def parse_args():
    """Parse command-line options; defaults come from the module constants."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--data-path", type=Path, default=DATA_PATH)
    parser.add_argument("--model", default=MODEL_NAME)
    parser.add_argument("--max-new-tokens", type=int, default=MAX_NEW_TOKENS)
    parser.add_argument(
        "--max-context-tokens",
        type=int,
        default=MAX_CONTEXT_TOKENS,
        help="Skip prompts with at least this many tokens (<= 0 disables).",
    )
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE)
    parser.add_argument(
        "--max-examples",
        type=int,
        default=0,
        help="Only use the first N loaded sequences (<= 0 uses all).",
    )
    parser.add_argument("--output", type=Path, default=OUTPUT_PATH)
    parser.add_argument("--summary-output", type=Path, default=SUMMARY_PATH)
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Delete existing per-row output instead of resuming from it.",
    )
    parser.add_argument(
        "--restrict-digit-comma-eos",
        action="store_true",
        help="Same effect as --restrict-integer-comma-eos.",
    )
    parser.add_argument(
        "--restrict-integer-comma-eos",
        action="store_true",
        help="Mask logits to digits, '-', ',' and EOS.",
    )
    return parser.parse_args()


def load_completed(path: Path) -> dict[int, dict]:
    """Load previously written per-row records, keyed by ``row_index``.

    Blank lines are ignored. Lines that are not valid JSON (e.g. a record
    truncated by an interrupted run) are reported and skipped, so that row is
    regenerated instead of the resume aborting.

    Returns:
        Mapping of row index to record; empty if *path* does not exist.
    """
    completed = {}
    if not path.exists():
        return completed
    with path.open("r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"Skipping malformed line {line_number} in {path}")
                continue
            completed[int(record["row_index"])] = record
    return completed


def _terminate_partial_line(path: Path) -> None:
    """Append a newline to *path* if it is non-empty and lacks a trailing one.

    Without this, a record appended after a truncated last line would be fused
    onto it and both would be unreadable on the next resume.
    """
    if not path.exists() or path.stat().st_size == 0:
        return
    with path.open("rb") as f:
        f.seek(-1, 2)  # whence=2: relative to end of file
        last_byte = f.read(1)
    if last_byte != b"\n":
        with path.open("ab") as f:
            f.write(b"\n")


def encode_no_special(tokenizer, text: str) -> list[int]:
    """Encode *text* without special tokens when the tokenizer supports it.

    Falls back to a plain ``encode`` call for tokenizers whose ``encode``
    rejects the ``add_special_tokens`` keyword (raises ``TypeError``).
    """
    try:
        return tokenizer.encode(text, add_special_tokens=False)
    except TypeError:
        return tokenizer.encode(text)


def normalize_stop_tokens_for_batch_generator(stop_tokens: list[list[int]]):
    """Adapt stop sequences to the ``stop_tokens`` type BatchGenerator expects.

    If the installed mlx_lm annotates ``BatchGenerator.__init__``'s
    ``stop_tokens`` with a type whose string form contains ``"set"``, only
    single-token stops can be expressed: each one-token sequence becomes its
    token id and longer sequences are dropped (with a warning). Otherwise the
    list of sequences is returned unchanged.
    """
    annotation = inspect.signature(BatchGenerator.__init__).parameters[
        "stop_tokens"
    ].annotation
    if "set" not in str(annotation):
        return stop_tokens
    dropped = [seq for seq in stop_tokens if len(seq) != 1]
    if dropped:
        # Losing e.g. the separator stop makes every row run to max_new_tokens.
        print(f"Dropping stop sequences not expressible as single tokens: {dropped}")
    return {seq[0] for seq in stop_tokens if len(seq) == 1}


def split_batch_generator_responses(responses):
    """Normalize ``BatchGenerator.next()`` output to ``(prompt, generation)`` lists.

    Accepts either a 2-tuple of response lists or a bare list, which is treated
    as generation responses with no prompt responses.

    Raises:
        RuntimeError: for any other return shape.
    """
    if isinstance(responses, tuple) and len(responses) == 2:
        prompt_responses, generation_responses = responses
        return prompt_responses, generation_responses
    if isinstance(responses, list):
        return [], responses
    raise RuntimeError(
        "Unexpected mlx_lm BatchGenerator.next() API. Update this script for "
        f"{type(responses).__name__}: {responses!r}"
    )


def make_integer_comma_eos_processor(tokenizer):
    """Build a logits processor that only allows integer/comma/EOS tokens.

    The allowed texts are the digits ``0``-``9``, ``-`` and ``,``; each is kept
    only if it encodes to exactly one token (others are reported and skipped).
    The tokenizer's EOS id is added when defined. All other logits get
    ``-1e9`` added. The additive mask is cached per vocabulary size.

    Returns:
        ``(processor, allowed_ids)`` where ``processor(tokens, logits)``
        returns masked logits and ``allowed_ids`` is the sorted allowed ids.
    """
    allowed = set()
    for text in [str(i) for i in range(10)] + ["-", ","]:
        tokens = encode_no_special(tokenizer, text)
        if len(tokens) == 1:
            allowed.add(int(tokens[0]))
        else:
            print(f"Skipping multi-token allowed text {text!r}: {tokens}")
    if tokenizer.eos_token_id is not None:
        allowed.add(int(tokenizer.eos_token_id))
    allowed_ids = sorted(allowed)
    mask_cache = {}

    def processor(_tokens, logits):
        vocab_size = logits.shape[-1]
        mask = mask_cache.get(vocab_size)
        if mask is None:
            values = [-1e9] * vocab_size
            for token_id in allowed_ids:
                if 0 <= token_id < vocab_size:
                    values[token_id] = 0.0
            mask = mx.array(values, dtype=logits.dtype)
            mask_cache[vocab_size] = mask
        # Broadcast the 1-D mask over the leading logits dimension.
        return logits + mask[None, :]

    return processor, allowed_ids


def run_generation_queue(
    *,
    model,
    tokenizer,
    prompts,
    answers,
    row_indices,
    stop_tokens,
    max_new_tokens: int,
    batch_size: int,
    output_file,
    progress,
    logits_processors=None,
) -> None:
    """Generate a completion for every prompt and stream one scored record per row.

    Args:
        prompts: Token-id lists, parallel to ``answers`` and ``row_indices``.
        answers: Expected next terms as integer strings.
        row_indices: Dataset row index stored in each output record.
        stop_tokens: Stop sequences (lists of token ids).
        output_file: Text file each JSON record line is written (and flushed) to.
        progress: tqdm bar advanced once per finished row.

    Raises:
        RuntimeError: if generation ends before every prompt finished.
    """
    if not prompts:
        # Nothing left (e.g. resuming a completed run); skip building a generator.
        return

    gen = BatchGenerator(
        model,
        stop_tokens=normalize_stop_tokens_for_batch_generator(stop_tokens),
        logits_processors=logits_processors,
        completion_batch_size=batch_size,
        prefill_batch_size=batch_size,
    )
    uids = gen.insert(prompts, [max_new_tokens] * len(prompts))
    uid_to_pos = {uid: pos for pos, uid in enumerate(uids)}
    generated_tokens = {uid: [] for uid in uids}
    finished = set()
    try:
        while True:
            responses = gen.next()
            prompt_responses, generation_responses = split_batch_generator_responses(
                responses
            )
            if not prompt_responses and not generation_responses:
                break
            if not generation_responses:
                continue
            for response in generation_responses:
                uid = response.uid
                # Keep every sampled token except the one that ended generation
                # with finish_reason "stop" (presumably the stop token itself).
                if response.finish_reason != "stop":
                    generated_tokens[uid].append(int(response.token))
                if response.finish_reason is None or uid in finished:
                    continue
                finished.add(uid)
                pos = uid_to_pos[uid]
                text = tokenizer.decode(generated_tokens[uid])
                prediction = parse_generated(text)
                answer = answers[pos]
                answer_int = int(answer)
                record = {
                    "row_index": row_indices[pos],
                    "answer": answer,
                    "prediction": prediction,
                    "correct": prediction == answer_int,
                    "parsed": prediction is not None,
                    "generated_text": text,
                    "generated_tokens": generated_tokens[uid],
                    "finish_reason": response.finish_reason,
                }
                output_file.write(json.dumps(record) + "\n")
                # Persist each row immediately so an interrupted run can resume.
                output_file.flush()
                progress.update(1)
    finally:
        gen.close()
        mx.clear_cache()
        gc.collect()
    if len(finished) != len(uids):
        raise RuntimeError(f"Chunk finished {len(finished)}/{len(uids)} rows")


def main():
    """Run the evaluation: generate missing rows, then write the summary."""
    args = parse_args()
    started = time.perf_counter()
    sequences, answers = load_sequences(args.data_path)
    if args.max_examples > 0:
        sequences = sequences[: args.max_examples]
        answers = answers[: args.max_examples]
    print(f"Loaded {len(answers)} sequences from {args.data_path}")

    model, tokenizer = load(args.model)
    # The comma separator is also the stop sequence. If "," alone encodes to
    # nothing, use the last token of "1," instead.
    sep_tokens = encode_no_special(tokenizer, ",")
    if not sep_tokens:
        sep_tokens = encode_no_special(tokenizer, "1,")[-1:]

    # Prompt = known terms joined by commas plus a trailing comma, so the
    # model's next comma-delimited field is its predicted term.
    prompts = [",".join(str(x) for x in seq) + "," for seq in sequences]
    prompts = [tokenizer.encode(p) for p in prompts]
    eval_indices = [
        i
        for i, prompt in enumerate(prompts)
        if args.max_context_tokens <= 0 or len(prompt) < args.max_context_tokens
    ]
    skipped_long = len(prompts) - len(eval_indices)
    # Shortest prompts first (presumably to batch similar lengths together).
    eval_indices = sorted(eval_indices, key=lambda i: len(prompts[i]))
    print(
        f"Evaluating {len(eval_indices)} rows; skipped_long={skipped_long}; "
        f"sep_tokens={sep_tokens}; eos_token={tokenizer.eos_token_id}; "
        f"max_new_tokens={args.max_new_tokens}; batch_size={args.batch_size}"
    )

    stop_tokens = [sep_tokens]
    if tokenizer.eos_token_id is not None:
        stop_tokens.append([tokenizer.eos_token_id])

    logits_processors = None
    allowed_token_ids = None
    restrict_integer = args.restrict_digit_comma_eos or args.restrict_integer_comma_eos
    if restrict_integer:
        processor, allowed_token_ids = make_integer_comma_eos_processor(tokenizer)
        logits_processors = [processor]
        print(f"Restricting logits to integer/comma/EOS token ids: {allowed_token_ids}")

    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.summary_output.parent.mkdir(parents=True, exist_ok=True)
    if args.overwrite:
        args.output.unlink(missing_ok=True)

    completed = load_completed(args.output)
    if completed:
        print(f"Resuming from {args.output}: {len(completed)} rows already done")
    todo_indices = [idx for idx in eval_indices if idx not in completed]

    _terminate_partial_line(args.output)
    with args.output.open("a", encoding="utf-8") as output_file:
        with tqdm(total=len(todo_indices), desc="Generating") as progress:
            run_generation_queue(
                model=model,
                tokenizer=tokenizer,
                prompts=[prompts[i] for i in todo_indices],
                answers=[answers[i] for i in todo_indices],
                row_indices=todo_indices,
                stop_tokens=stop_tokens,
                max_new_tokens=args.max_new_tokens,
                batch_size=args.batch_size,
                output_file=output_file,
                progress=progress,
                logits_processors=logits_processors,
            )
        output_file.flush()

    # Re-read from disk so the summary covers resumed and new rows alike,
    # restricted to rows that are part of this evaluation.
    records = load_completed(args.output)
    eval_set = set(eval_indices)
    records = {idx: record for idx, record in records.items() if idx in eval_set}
    correct = sum(1 for record in records.values() if record["correct"])
    parsed = sum(1 for record in records.values() if record["parsed"])
    total = len(eval_indices)
    evaluated = len(records)
    # Guard against division by zero when no rows were evaluated.
    accuracy = correct / evaluated if evaluated else 0.0
    elapsed = time.perf_counter() - started

    print(f"Documents: {len(answers)}")
    print(f"Evaluated: {evaluated}/{total}")
    print(f"Skipped long: {skipped_long}")
    print(f"Parsed predictions: {parsed}/{evaluated}")
    print(f"Accuracy: {correct}/{evaluated} = {accuracy:.4f}")

    summary = {
        "data_path": str(args.data_path),
        "model": args.model,
        "output": str(args.output),
        "documents": len(answers),
        "evaluated": evaluated,
        "expected_evaluated": total,
        "skipped_long": skipped_long,
        "parsed": parsed,
        "correct": correct,
        "accuracy": accuracy,
        "max_new_tokens": args.max_new_tokens,
        "max_context_tokens": args.max_context_tokens,
        "batch_size": args.batch_size,
        "restrict_digit_comma_eos": args.restrict_digit_comma_eos,
        "restrict_integer_comma_eos": restrict_integer,
        "allowed_token_ids": allowed_token_ids,
        "seconds": elapsed,
    }
    args.summary_output.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
    print(f"Wrote {args.output}")
    print(f"Wrote {args.summary_output}")


if __name__ == "__main__":
    main()