File size: 4,876 Bytes
9c63689 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | #!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import random
import string
from pathlib import Path
# Subject labels; make_needle_example draws one at random for every record.
SUBJECTS = [
    "compiler pass", "proof lemma",
    "satellite sensor", "database shard",
    "medical note", "legal clause",
    "robot actuator", "training run",
    "network trace", "chemical sample",
]
# Relation labels; make_needle_example draws one at random for every record.
RELATIONS = [
    "checksum", "owner",
    "dependency", "calibration",
    "exception", "threshold",
    "signature", "revision",
]
# Fixed padding paragraph that filler_words repeats between noise markers.
FILLER = " ".join(
    [
        "The surrounding material is intentionally plausible but irrelevant.",
        "A correct model must preserve the named key and ignore nearby distractors.",
        "The text may contain repeated structure, altered order, and misleading values.",
        "",  # empty last item keeps the single trailing space of the paragraph
    ]
)
def code(rng: random.Random, prefix: str, size: int = 8) -> str:
    """Build a random identifier of the form ``<prefix>-<suffix>``.

    Args:
        rng: Random source; ``rng.choice`` is called once per suffix character.
        prefix: Text placed in front of the hyphen.
        size: Number of suffix characters to draw (default 8).

    Returns:
        ``prefix``, a hyphen, and ``size`` characters taken from the ASCII
        uppercase letters and the decimal digits.
    """
    symbols = string.ascii_uppercase + string.digits
    drawn = []
    for _ in range(size):
        drawn.append(rng.choice(symbols))
    tail = "".join(drawn)
    return "-".join((prefix, tail))
def filler_words(rng: random.Random, target_words: int) -> str:
    """Return padding text containing at least ``target_words`` words.

    Each round appends the fixed ``FILLER`` paragraph followed by a sentence
    carrying a fresh ``NOISE-xxxxx`` marker from ``code``. Rounds are always
    appended whole, so the result may run past ``target_words``.

    The word total is tracked incrementally instead of re-joining and
    re-splitting the accumulated text on every iteration; the returned string
    and the sequence of ``rng`` draws are the same as with the re-join approach.

    Args:
        rng: Random source used for the noise markers.
        target_words: Minimum number of whitespace-separated words wanted.

    Returns:
        The pieces joined by single spaces, or ``""`` when ``target_words``
        is zero or negative.
    """
    # FILLER never changes, so its word count is computed once.
    filler_word_count = len(FILLER.split())
    pieces = []
    word_total = 0
    while word_total < target_words:
        marker = f"Noise marker {code(rng, 'NOISE', 5)} is not the answer."
        pieces.append(FILLER)
        pieces.append(marker)
        # Pieces are joined with spaces, so per-piece word counts simply add up.
        word_total += filler_word_count + len(marker.split())
    return " ".join(pieces)
def make_needle_example(rng: random.Random, idx: int, slots: int, distractor_words: int) -> dict:
    """Create one "needle" example: many records, one question about a single record.

    ``slots`` records are generated, each with a random key, value, subject and
    relation. The record lines are shuffled, one record is picked as the
    target, and the lines are placed between two filler passages of
    ``distractor_words // 2`` target words each. The text ends with the
    question and its answer.

    Args:
        rng: Random source for all draws.
        idx: Example number shown in the prompt's first line.
        slots: Number of records to generate.
        distractor_words: Word budget split evenly between the two filler passages.

    Returns:
        A dict with ``mode``, ``target_key``, ``target_value`` and ``text``.
    """
    records = []
    for _ in range(slots):
        # Draw order (key, value, subject, relation) fixes the seeded output.
        record_key = code(rng, "KEY")
        record_value = code(rng, "VAL")
        records.append((record_key, record_value, rng.choice(SUBJECTS), rng.choice(RELATIONS)))

    record_lines = [
        f"Record {key}: subject={subject}; relation={relation}; value={value}."
        for key, value, subject, relation in records
    ]
    rng.shuffle(record_lines)

    # The target is chosen from the unshuffled tuples, after the lines are shuffled.
    target_key, target_value, target_subject, target_relation = rng.choice(records)

    half_budget = distractor_words // 2
    leading_filler = filler_words(rng, half_budget)
    trailing_filler = filler_words(rng, half_budget)

    sections = (
        f"Long-context memory drill {idx}.",
        "Store each record exactly. Later, answer from the records, not from nearby guesses.",
        leading_filler,
        "\n".join(record_lines),
        trailing_filler,
        f"Question: For {target_key}, what value is attached to relation {target_relation} for {target_subject}?",
        f"Answer: {target_value}",
    )

    example = {"mode": "needle"}
    example["target_key"] = target_key
    example["target_value"] = target_value
    example["text"] = "\n\n".join(sections)
    return example
def make_multihop_example(rng: random.Random, idx: int, chain_len: int, distractor_words: int) -> dict:
    """Create one "multihop" example: a chain of links ending in a terminal value.

    ``chain_len + 1`` node codes are generated and linked consecutively. A
    terminal line attaches a final value to the last node. Unrelated links
    between freshly generated node codes are added as distractors (at least
    four, otherwise ``chain_len * 2``). All lines are shuffled and placed
    between two filler passages of ``distractor_words // 2`` target words each.

    Args:
        rng: Random source for all draws.
        idx: Example number shown in the prompt's first line.
        chain_len: Number of links in the real chain.
        distractor_words: Word budget split evenly between the two filler passages.

    Returns:
        A dict with ``mode``, ``start``, ``target_value`` and ``text``.
    """
    chain = []
    for _ in range(chain_len + 1):
        chain.append(code(rng, "NODE"))
    terminal_value = code(rng, "FINAL")

    # Pair every node with its successor to form the real chain links.
    lines = [f"Link {src} -> {dst}." for src, dst in zip(chain, chain[1:])]
    lines.append(f"Terminal {chain[-1]} has value {terminal_value}.")

    decoy_count = max(4, chain_len * 2)
    for _ in range(decoy_count):
        decoy_src = code(rng, "NODE")
        decoy_dst = code(rng, "NODE")
        lines.append(f"Link {decoy_src} -> {decoy_dst}.")
    rng.shuffle(lines)

    half_budget = distractor_words // 2
    leading_filler = filler_words(rng, half_budget)
    trailing_filler = filler_words(rng, half_budget)

    sections = (
        f"Long-context chain drill {idx}.",
        "Follow only exact links. Distractor links may look similar but are unrelated.",
        leading_filler,
        "\n".join(lines),
        trailing_filler,
        f"Question: Starting from {chain[0]}, follow the links to the terminal. What terminal value is reached?",
        f"Answer: {terminal_value}",
    )

    example = {"mode": "multihop"}
    example["start"] = chain[0]
    example["target_value"] = terminal_value
    example["text"] = "\n\n".join(sections)
    return example
def main() -> int:
    """Parse the command line and write the generated examples as JSONL.

    One JSON object is written per line to ``--out``; missing parent
    directories are created first. With ``--mode mixed`` even-numbered
    examples are "needle" and odd-numbered ones are "multihop". All randomness
    comes from a single ``random.Random`` seeded with ``--seed``.

    Returns:
        0 once every row has been written.
    """
    cli = argparse.ArgumentParser(description="Generate AGILLM-4 long-context recall curriculum JSONL")
    cli.add_argument("--out", required=True)
    cli.add_argument("--examples", type=int, default=128)
    cli.add_argument("--mode", choices=["needle", "multihop", "mixed"], default="mixed")
    cli.add_argument("--slots", type=int, default=64)
    cli.add_argument("--chain_len", type=int, default=8)
    cli.add_argument("--distractor_words", type=int, default=1200)
    cli.add_argument("--seed", type=int, default=401)
    opts = cli.parse_args()

    rng = random.Random(opts.seed)
    destination = Path(opts.out)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with destination.open("w", encoding="utf-8") as sink:
        for idx in range(opts.examples):
            # Resolve "mixed" into a concrete kind by alternating on parity.
            if opts.mode == "mixed":
                kind = "multihop" if idx % 2 else "needle"
            else:
                kind = opts.mode

            if kind == "needle":
                row = make_needle_example(rng, idx, opts.slots, opts.distractor_words)
            else:
                row = make_multihop_example(rng, idx, opts.chain_len, opts.distractor_words)

            sink.write(json.dumps(row, ensure_ascii=True))
            sink.write("\n")

    print(f"wrote {opts.examples} examples to {destination}")
    return 0
# Script entry point: run main() and hand its return value to the interpreter
# as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
|