File size: 4,876 Bytes
9c63689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import random
import string
from pathlib import Path


# Pool of subject phrases; one is sampled for the ``subject=`` field of every
# record emitted by make_needle_example().
SUBJECTS = [
    "compiler pass",
    "proof lemma",
    "satellite sensor",
    "database shard",
    "medical note",
    "legal clause",
    "robot actuator",
    "training run",
    "network trace",
    "chemical sample",
]

# Pool of relation names; one is sampled for the ``relation=`` field of every
# record emitted by make_needle_example().
RELATIONS = [
    "checksum",
    "owner",
    "dependency",
    "calibration",
    "exception",
    "threshold",
    "signature",
    "revision",
]

# Boilerplate paragraph that filler_words() repeats to pad prompts. The
# adjacent string literals are concatenated into one string; the trailing
# space of the last fragment is part of the emitted text.
FILLER = (
    "The surrounding material is intentionally plausible but irrelevant. "
    "A correct model must preserve the named key and ignore nearby distractors. "
    "The text may contain repeated structure, altered order, and misleading values. "
)


def code(rng: random.Random, prefix: str, size: int = 8) -> str:
    """Return an identifier of the form ``<prefix>-<random suffix>``.

    Args:
        rng: Random generator to draw from; a seeded generator gives a
            reproducible identifier.
        prefix: Text placed before the dash.
        size: Number of random characters after the dash.

    Returns:
        ``prefix``, a dash, then ``size`` characters drawn from the uppercase
        ASCII letters and the decimal digits.
    """
    symbols = string.ascii_uppercase + string.digits
    suffix = []
    # One rng.choice call per character keeps the generator's draw sequence
    # stable for a given seed.
    for _ in range(size):
        suffix.append(rng.choice(symbols))
    return "-".join((prefix, "".join(suffix)))


def filler_words(rng: random.Random, target_words: int) -> str:
    """Build padding text containing at least ``target_words`` words.

    Each round appends the ``FILLER`` paragraph followed by one sentence
    carrying a fresh random ``NOISE-`` marker, so the result may overshoot
    the target by up to one round.

    Args:
        rng: Random generator used for the noise markers.
        target_words: Minimum number of whitespace-separated words wanted.
            Zero or a negative value yields an empty string.

    Returns:
        The appended pieces joined with single spaces.
    """
    boilerplate_words = len(FILLER.split())
    pieces = []
    word_total = 0
    # Keep a running word count instead of re-joining and re-splitting the
    # whole text on every iteration (which was quadratic in the output size).
    while word_total < target_words:
        marker = f"Noise marker {code(rng, 'NOISE', 5)} is not the answer."
        pieces.append(FILLER)
        pieces.append(marker)
        word_total += boilerplate_words + len(marker.split())
    return " ".join(pieces)


def make_needle_example(rng: random.Random, idx: int, slots: int, distractor_words: int) -> dict:
    """Create one "needle" drill: many records, one of which is queried.

    Args:
        rng: Random generator driving every random choice in the example.
        idx: Example number, embedded in the drill's heading line.
        slots: Number of key/value records to generate.
        distractor_words: Word budget for padding; half is requested before
            the records and half after them.

    Returns:
        A dict with ``mode`` set to ``"needle"``, the queried ``target_key``
        and its ``target_value``, and the full prompt under ``text`` (the
        prompt ends with the answer line).
    """
    records = []
    record_lines = []
    for _ in range(slots):
        # Draw order (key, value, subject, relation) must stay fixed so a
        # given seed keeps producing the same example.
        rec_key = code(rng, "KEY")
        rec_value = code(rng, "VAL")
        rec_subject = rng.choice(SUBJECTS)
        rec_relation = rng.choice(RELATIONS)
        records.append((rec_key, rec_value, rec_subject, rec_relation))
        record_lines.append(
            f"Record {rec_key}: subject={rec_subject}; relation={rec_relation}; value={rec_value}."
        )

    # Only the rendered lines are shuffled; the target is then picked from
    # the tuples in their original generation order.
    rng.shuffle(record_lines)
    target_key, target_value, target_subject, target_relation = rng.choice(records)

    half_budget = distractor_words // 2
    leading_noise = filler_words(rng, half_budget)
    trailing_noise = filler_words(rng, half_budget)

    sections = (
        f"Long-context memory drill {idx}.",
        "Store each record exactly. Later, answer from the records, not from nearby guesses.",
        leading_noise,
        "\n".join(record_lines),
        trailing_noise,
        f"Question: For {target_key}, what value is attached to relation {target_relation} for {target_subject}?",
        f"Answer: {target_value}",
    )
    text = "\n\n".join(sections)
    return {
        "mode": "needle",
        "target_key": target_key,
        "target_value": target_value,
        "text": text,
    }


def make_multihop_example(rng: random.Random, idx: int, chain_len: int, distractor_words: int) -> dict:
    """Create one "multihop" drill: a chain of links ending in a terminal value.

    Args:
        rng: Random generator driving every random choice in the example.
        idx: Example number, embedded in the drill's heading line.
        chain_len: Number of links in the true chain (``chain_len + 1`` nodes).
        distractor_words: Word budget for padding; half is requested before
            the link lines and half after them.

    Returns:
        A dict with ``mode`` set to ``"multihop"``, the chain's ``start``
        node, the terminal ``target_value``, and the full prompt under
        ``text`` (the prompt ends with the answer line).
    """
    chain = []
    for _ in range(chain_len + 1):
        chain.append(code(rng, "NODE"))
    final_value = code(rng, "FINAL")

    # True chain: one line per consecutive node pair, plus the terminal line.
    link_lines = [f"Link {src} -> {dst}." for src, dst in zip(chain, chain[1:])]
    link_lines.append(f"Terminal {chain[-1]} has value {final_value}.")

    # Unrelated links between freshly generated nodes; at least four of them.
    decoy_count = max(4, chain_len * 2)
    for _ in range(decoy_count):
        decoy_src = code(rng, "NODE")
        decoy_dst = code(rng, "NODE")
        link_lines.append(f"Link {decoy_src} -> {decoy_dst}.")

    rng.shuffle(link_lines)

    half_budget = distractor_words // 2
    leading_noise = filler_words(rng, half_budget)
    trailing_noise = filler_words(rng, half_budget)

    sections = (
        f"Long-context chain drill {idx}.",
        "Follow only exact links. Distractor links may look similar but are unrelated.",
        leading_noise,
        "\n".join(link_lines),
        trailing_noise,
        f"Question: Starting from {chain[0]}, follow the links to the terminal. What terminal value is reached?",
        f"Answer: {final_value}",
    )
    text = "\n\n".join(sections)
    return {
        "mode": "multihop",
        "start": chain[0],
        "target_value": final_value,
        "text": text,
    }


def main() -> int:
    """Parse command-line flags and write the curriculum as a JSONL file.

    Flags:
        --out: Output path (required); missing parent directories are created.
        --examples: Number of rows to write.
        --mode: ``needle``, ``multihop``, or ``mixed`` (even indices are
            needle rows, odd indices are multihop rows).
        --slots: Records per needle example.
        --chain_len: Links per multihop example.
        --distractor_words: Padding word budget per example.
        --seed: Seed for the random generator, for reproducible output.

    Returns:
        0 on success. Invalid flag values end the program through
        ``parser.error`` (usage message on stderr, exit status 2).
    """
    parser = argparse.ArgumentParser(description="Generate AGILLM-4 long-context recall curriculum JSONL")
    parser.add_argument("--out", required=True)
    parser.add_argument("--examples", type=int, default=128)
    parser.add_argument("--mode", choices=["needle", "multihop", "mixed"], default="mixed")
    parser.add_argument("--slots", type=int, default=64)
    parser.add_argument("--chain_len", type=int, default=8)
    parser.add_argument("--distractor_words", type=int, default=1200)
    parser.add_argument("--seed", type=int, default=401)
    args = parser.parse_args()

    # Validate before touching the filesystem: these values would otherwise
    # fail with a bare IndexError inside the generators after the output file
    # has already been truncated, or produce a misleading summary line.
    if args.examples < 0:
        parser.error("--examples must be >= 0")
    if args.mode != "multihop" and args.slots < 1:
        parser.error("--slots must be >= 1 when needle examples are generated")
    if args.mode != "needle" and args.chain_len < 0:
        parser.error("--chain_len must be >= 0 when multihop examples are generated")

    rng = random.Random(args.seed)
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("w", encoding="utf-8") as handle:
        for idx in range(args.examples):
            mode = args.mode
            if mode == "mixed":
                # Alternate so both drill types are evenly represented.
                mode = "needle" if idx % 2 == 0 else "multihop"
            if mode == "needle":
                row = make_needle_example(rng, idx, args.slots, args.distractor_words)
            else:
                row = make_multihop_example(rng, idx, args.chain_len, args.distractor_words)
            # One JSON object per line (JSONL).
            handle.write(json.dumps(row, ensure_ascii=True) + "\n")
    print(f"wrote {args.examples} examples to {out}")
    return 0


# Run the CLI only when executed as a script; main()'s return value is passed
# to SystemExit and so becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())