#!/usr/bin/env python3 from __future__ import annotations import argparse import json import random import string from pathlib import Path SUBJECTS = [ "compiler pass", "proof lemma", "satellite sensor", "database shard", "medical note", "legal clause", "robot actuator", "training run", "network trace", "chemical sample", ] RELATIONS = [ "checksum", "owner", "dependency", "calibration", "exception", "threshold", "signature", "revision", ] FILLER = ( "The surrounding material is intentionally plausible but irrelevant. " "A correct model must preserve the named key and ignore nearby distractors. " "The text may contain repeated structure, altered order, and misleading values. " ) def code(rng: random.Random, prefix: str, size: int = 8) -> str: alphabet = string.ascii_uppercase + string.digits return prefix + "-" + "".join(rng.choice(alphabet) for _ in range(size)) def filler_words(rng: random.Random, target_words: int) -> str: pieces = [] while len(" ".join(pieces).split()) < target_words: pieces.append(FILLER) pieces.append(f"Noise marker {code(rng, 'NOISE', 5)} is not the answer.") return " ".join(pieces) def make_needle_example(rng: random.Random, idx: int, slots: int, distractor_words: int) -> dict: facts = [] keys = [] for _ in range(slots): key = code(rng, "KEY") value = code(rng, "VAL") subject = rng.choice(SUBJECTS) relation = rng.choice(RELATIONS) keys.append((key, value, subject, relation)) facts.append(f"Record {key}: subject={subject}; relation={relation}; value={value}.") rng.shuffle(facts) target_key, target_value, target_subject, target_relation = rng.choice(keys) prompt = [ f"Long-context memory drill {idx}.", "Store each record exactly. Later, answer from the records, not from nearby guesses.", filler_words(rng, distractor_words // 2), "\n".join(facts), filler_words(rng, distractor_words // 2), f"Question: For {target_key}, what value is attached to relation {target_relation} for {target_subject}?", f"Answer: {target_value}", ] return { "mode": "needle", "target_key": target_key, "target_value": target_value, "text": "\n\n".join(prompt), } def make_multihop_example(rng: random.Random, idx: int, chain_len: int, distractor_words: int) -> dict: nodes = [code(rng, "NODE") for _ in range(chain_len + 1)] final_value = code(rng, "FINAL") edges = [f"Link {nodes[i]} -> {nodes[i + 1]}." for i in range(chain_len)] edges.append(f"Terminal {nodes[-1]} has value {final_value}.") distractors = [ f"Link {code(rng, 'NODE')} -> {code(rng, 'NODE')}." for _ in range(max(4, chain_len * 2)) ] all_lines = edges + distractors rng.shuffle(all_lines) prompt = [ f"Long-context chain drill {idx}.", "Follow only exact links. Distractor links may look similar but are unrelated.", filler_words(rng, distractor_words // 2), "\n".join(all_lines), filler_words(rng, distractor_words // 2), f"Question: Starting from {nodes[0]}, follow the links to the terminal. What terminal value is reached?", f"Answer: {final_value}", ] return { "mode": "multihop", "start": nodes[0], "target_value": final_value, "text": "\n\n".join(prompt), } def main() -> int: parser = argparse.ArgumentParser(description="Generate AGILLM-4 long-context recall curriculum JSONL") parser.add_argument("--out", required=True) parser.add_argument("--examples", type=int, default=128) parser.add_argument("--mode", choices=["needle", "multihop", "mixed"], default="mixed") parser.add_argument("--slots", type=int, default=64) parser.add_argument("--chain_len", type=int, default=8) parser.add_argument("--distractor_words", type=int, default=1200) parser.add_argument("--seed", type=int, default=401) args = parser.parse_args() rng = random.Random(args.seed) out = Path(args.out) out.parent.mkdir(parents=True, exist_ok=True) with out.open("w", encoding="utf-8") as handle: for idx in range(args.examples): mode = args.mode if mode == "mixed": mode = "needle" if idx % 2 == 0 else "multihop" if mode == "needle": row = make_needle_example(rng, idx, args.slots, args.distractor_words) else: row = make_multihop_example(rng, idx, args.chain_len, args.distractor_words) handle.write(json.dumps(row, ensure_ascii=True) + "\n") print(f"wrote {args.examples} examples to {out}") return 0 if __name__ == "__main__": raise SystemExit(main())