# AGILLM-4 / long_context_curriculum.py
# Commit 9c63689 (verified) by OpenTransformer: "Add AGILLM-4 training scaffold"
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import random
import string
from pathlib import Path
# Topic nouns attached to each needle record; one is drawn per record.
SUBJECTS = [
    # engineering / systems
    "compiler pass",
    "proof lemma",
    "satellite sensor",
    "database shard",
    # document-style domains
    "medical note",
    "legal clause",
    # physical / operational
    "robot actuator",
    "training run",
    "network trace",
    "chemical sample",
]

# Relation labels attached to each needle record; one is drawn per record.
RELATIONS = [
    "checksum",
    "owner",
    "dependency",
    "calibration",
    "exception",
    "threshold",
    "signature",
    "revision",
]

# Paragraph repeated by the filler generator to pad prompts with irrelevant
# text. Each sentence ends with ". " so copies joined by a space stay readable.
FILLER = (
    "The surrounding material is intentionally plausible but irrelevant. "
    "A correct model must preserve the named key and ignore nearby distractors. "
    "The text may contain repeated structure, altered order, and misleading values. "
)
def code(rng: random.Random, prefix: str, size: int = 8) -> str:
    """Return an identifier of the form ``"<prefix>-XXXX…"``.

    The suffix has ``size`` characters, each drawn independently with
    ``rng.choice`` from uppercase ASCII letters and digits. The result
    depends only on ``prefix``, ``size`` and the state of ``rng``.
    """
    charset = string.ascii_uppercase + string.digits
    suffix_chars = []
    for _ in range(size):
        suffix_chars.append(rng.choice(charset))
    return f"{prefix}-{''.join(suffix_chars)}"
def filler_words(rng: random.Random, target_words: int) -> str:
    """Return distractor prose containing at least ``target_words`` words.

    Appends alternating copies of ``FILLER`` and a "Noise marker NOISE-xxxxx
    is not the answer." sentence until the running word count reaches
    ``target_words``. The result may overshoot the target by less than one
    FILLER + noise pair. ``target_words <= 0`` yields an empty string.

    Each loop iteration consumes one ``code(rng, 'NOISE', 5)`` draw.
    """
    # FILLER never changes, so count its words once. Pieces are joined with
    # spaces and contain no words spanning a boundary, so the sum of per-piece
    # word counts equals the word count of the joined text. That lets us avoid
    # re-joining and re-splitting the whole growing text each iteration, which
    # was quadratic in target_words.
    filler_len = len(FILLER.split())
    pieces = []
    word_count = 0
    while word_count < target_words:
        noise = f"Noise marker {code(rng, 'NOISE', 5)} is not the answer."
        pieces.append(FILLER)
        pieces.append(noise)
        word_count += filler_len + len(noise.split())
    return " ".join(pieces)
def make_needle_example(rng: random.Random, idx: int, slots: int, distractor_words: int) -> dict:
    """Build one "needle in a haystack" recall example.

    Generates ``slots`` records (key, value, subject, relation). Each is
    rendered as a "Record …" line, and the lines are shuffled. The block is
    placed between two filler sections of at least ``distractor_words // 2``
    words each. One record is chosen as the target. The text ends with a
    question about that record, followed by its answer.

    Returns a dict with keys "mode", "target_key", "target_value" and "text".
    Raises IndexError (from ``rng.choice`` on an empty list) when
    ``slots <= 0``.
    """
    # Tuple elements evaluate left to right, so per-record RNG draws happen in
    # the order key, value, subject, relation.
    records = [
        (code(rng, "KEY"), code(rng, "VAL"), rng.choice(SUBJECTS), rng.choice(RELATIONS))
        for _ in range(slots)
    ]
    record_lines = [
        f"Record {key}: subject={subject}; relation={relation}; value={value}."
        for key, value, subject, relation in records
    ]
    rng.shuffle(record_lines)
    key, value, subject, relation = rng.choice(records)

    half = distractor_words // 2
    # Leading filler is drawn before trailing filler to keep the RNG order fixed.
    leading = filler_words(rng, half)
    trailing = filler_words(rng, half)

    sections = (
        f"Long-context memory drill {idx}.",
        "Store each record exactly. Later, answer from the records, not from nearby guesses.",
        leading,
        "\n".join(record_lines),
        trailing,
        f"Question: For {key}, what value is attached to relation {relation} for {subject}?",
        f"Answer: {value}",
    )
    return {
        "mode": "needle",
        "target_key": key,
        "target_value": value,
        "text": "\n\n".join(sections),
    }
def make_multihop_example(
    rng: random.Random,
    idx: int,
    chain_len: int,
    distractor_words: int,
    decoy_chains: int = 2,
) -> dict:
    """Build one multi-hop link-following example.

    The true chain has ``chain_len`` "Link A -> B." lines. It runs from a
    start node to a terminal that carries the answer value.

    ``decoy_chains`` extra chains of the same length end in their own
    terminals, each with a decoy FINAL value. Without them, the prompt held
    exactly one "Terminal" line, so the answer could be read off without
    following any link.

    The original isolated single-link distractors are kept as extra noise.
    All lines are shuffled and placed between two filler sections of at least
    ``distractor_words // 2`` words each.

    Returns a dict with keys "mode", "start", "target_value" and "text".
    Raises ValueError if ``chain_len`` or ``decoy_chains`` is negative.
    """
    # A negative chain_len would leave no nodes at all (nodes[-1] would fail).
    if chain_len < 0:
        raise ValueError(f"chain_len must be >= 0, got {chain_len}")
    if decoy_chains < 0:
        raise ValueError(f"decoy_chains must be >= 0, got {decoy_chains}")

    nodes = [code(rng, "NODE") for _ in range(chain_len + 1)]
    final_value = code(rng, "FINAL")
    all_lines = [f"Link {nodes[i]} -> {nodes[i + 1]}." for i in range(chain_len)]
    all_lines.append(f"Terminal {nodes[-1]} has value {final_value}.")

    # Decoy chains mirror the true chain's shape, so neither "the only
    # terminal" nor "the terminal at the end of the longest chain" reveals
    # the answer. Only following links from the start node does.
    for _ in range(decoy_chains):
        decoy = [code(rng, "NODE") for _ in range(chain_len + 1)]
        all_lines.extend(f"Link {a} -> {b}." for a, b in zip(decoy, decoy[1:]))
        all_lines.append(f"Terminal {decoy[-1]} has value {code(rng, 'FINAL')}.")

    # Isolated single links between fresh nodes: cheap extra noise.
    all_lines.extend(
        f"Link {code(rng, 'NODE')} -> {code(rng, 'NODE')}."
        for _ in range(max(4, chain_len * 2))
    )
    rng.shuffle(all_lines)

    prompt = [
        f"Long-context chain drill {idx}.",
        "Follow only exact links. Distractor links may look similar but are unrelated.",
        filler_words(rng, distractor_words // 2),
        "\n".join(all_lines),
        filler_words(rng, distractor_words // 2),
        f"Question: Starting from {nodes[0]}, follow the links to the terminal. What terminal value is reached?",
        f"Answer: {final_value}",
    ]
    return {
        "mode": "multihop",
        "start": nodes[0],
        "target_value": final_value,
        "text": "\n\n".join(prompt),
    }
def main() -> int:
parser = argparse.ArgumentParser(description="Generate AGILLM-4 long-context recall curriculum JSONL")
parser.add_argument("--out", required=True)
parser.add_argument("--examples", type=int, default=128)
parser.add_argument("--mode", choices=["needle", "multihop", "mixed"], default="mixed")
parser.add_argument("--slots", type=int, default=64)
parser.add_argument("--chain_len", type=int, default=8)
parser.add_argument("--distractor_words", type=int, default=1200)
parser.add_argument("--seed", type=int, default=401)
args = parser.parse_args()
rng = random.Random(args.seed)
out = Path(args.out)
out.parent.mkdir(parents=True, exist_ok=True)
with out.open("w", encoding="utf-8") as handle:
for idx in range(args.examples):
mode = args.mode
if mode == "mixed":
mode = "needle" if idx % 2 == 0 else "multihop"
if mode == "needle":
row = make_needle_example(rng, idx, args.slots, args.distractor_words)
else:
row = make_multihop_example(rng, idx, args.chain_len, args.distractor_words)
handle.write(json.dumps(row, ensure_ascii=True) + "\n")
print(f"wrote {args.examples} examples to {out}")
return 0
if __name__ == "__main__":
    # Run the CLI and use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)