File size: 3,787 Bytes
b14638e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Paraphrase benchmark for the GENERALIZATION push — disjoint train/test templates.

The open metric on which Yaz (like every GRACE-class lookup editor) falls short is
*generalization*: an edit made in one phrasing should still hold under other phrasings.
We build a ZsRE-style split:
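  - TRAIN templates (8): the model trains on these.
  - HELD-OUT TEST templates (5): DISJOINT, never seen in training. The edit-transfer
    test probes only these, so transfer there reflects generalization to unseen
    phrasings rather than memorization of a trained string. (Disjointness is checked
    as exact strings only; some test templates reuse most of a train template's wording.)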

All templates END at the answer (causal LM: the capital is the next token after the prompt),
so routing supervision lands on the same answer position regardless of phrasing.

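Country->capital pairs are taken from the existing data/facts_50.jsonl (50 facts) and
deduplicated by country (first occurrence kept); the row counts below assume all 50
countries are distinct.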

Outputs:
  data/facts_para_train.jsonl       — 50 facts x 8 train templates (text=prefix+capital+".")
  data/probes_para_indist.jsonl     — reliability probes, TRAIN template #0
  data/probes_para_heldout.jsonl    — generalization probes, the 5 TEST templates (250 rows)
"""
from __future__ import annotations
import json
from pathlib import Path

# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parent.parent
# Source facts: JSON Lines rows carrying "country" and "capital" keys (read by pairs()).
SRC = ROOT / "data" / "facts_50.jsonl"

# All prefixes end with a space; text = prefix + capital + "."
# "{C}" is filled with the country name via str.format(C=country).
# List position is the "template_id" written to facts_para_train.jsonl;
# TRAIN_TEMPLATES[0] is also the prompt used for the in-distribution probes.
TRAIN_TEMPLATES = [
    "The capital of {C} is ",
    "{C}'s capital is ",
    "The capital city of {C} is ",
    "Capital of {C}: ",
    "In {C}, the capital is ",
    "{C} has its capital at ",
    "The country {C} has its capital, which is ",
    "Q: What is the capital of {C}? A: ",
]
# DISJOINT held-out phrasings — never trained on.
# List position is the "test_template_id" written to probes_para_heldout.jsonl.
# NOTE(review): disjointness is only verified as exact strings; some entries share
# most of their wording with a train template (e.g. "The administrative capital of
# {C} is " vs "The capital of {C} is "), so they are a weaker generalization test.
# NOTE(review): "seat of government" / "administrative capital" may name a different
# city than the listed capital for some countries (e.g. the Netherlands, Bolivia,
# South Africa) — confirm against facts_50.jsonl.
# The em dash below is non-ASCII; json.dumps (default ensure_ascii=True) writes it
# to the probe file as the escape sequence \u2014.
TEST_TEMPLATES = [
    "{C} — capital: ",
    "The seat of government of {C} is located in ",
    "If you visit {C}, the capital you arrive in is ",
    "The administrative capital of {C} is ",
    "Name the capital of {C}: ",
]


def pairs(src: "Path | None" = None) -> list[tuple[str, str]]:
    """Return the unique ``(country, capital)`` pairs from a facts JSONL file.

    Each non-blank line is parsed as a JSON object. Only the first row seen for
    each country is kept; later duplicates are dropped, and file order is preserved.

    Args:
        src: Path (or path string) of the JSONL file to read. ``None`` (the
            default) reads the module-level ``SRC`` file.

    Returns:
        List of ``(country, capital)`` tuples in first-occurrence order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a row lacks the ``"country"`` or ``"capital"`` key.
    """
    path = SRC if src is None else Path(src)
    seen: set[str] = set()
    out: list[tuple[str, str]] = []
    # Explicit UTF-8: country names may be non-ASCII, and the platform default
    # encoding (e.g. cp1252 on Windows) could mis-decode them.
    for line in path.read_text(encoding="utf-8").splitlines():
        # Skip blank *and* whitespace-only lines; json.loads would reject them.
        if not line.strip():
            continue
        row = json.loads(line)
        country = row["country"]
        if country in seen:
            continue
        seen.add(country)
        out.append((country, row["capital"]))
    return out


def _write_jsonl(path: Path, rows: list[dict]) -> None:
    """Write *rows* to *path* as JSON Lines, one ``json.dumps`` object per line.

    Every line, including the last, is newline-terminated. An empty *rows*
    yields an empty file rather than a file holding a single blank line.
    """
    path.write_text("".join(json.dumps(r) + "\n" for r in rows), encoding="utf-8")


def main():
    """Build the paraphrase benchmark files under ``ROOT / "data"``.

    Writes three files, then prints row counts:

    - ``facts_para_train.jsonl``: every fact x every train template.
    - ``probes_para_indist.jsonl``: one probe per fact, using train template #0.
    - ``probes_para_heldout.jsonl``: every fact x every held-out test template.

    Raises:
        ValueError: if a template string appears in both TRAIN_TEMPLATES and
            TEST_TEMPLATES. The held-out probes would then not be held out.
    """
    # Fail before writing anything: an overlapping template would silently turn
    # generalization probes into memorization probes.
    overlap = set(TRAIN_TEMPLATES) & set(TEST_TEMPLATES)
    if overlap:
        raise ValueError(f"train/test templates overlap: {sorted(overlap)}")

    ps = pairs()
    data_dir = ROOT / "data"

    # training facts: one phrasing per train template per fact, tagged with template_id
    train_rows = [
        {"country": c, "capital": cap, "template_id": tid,
         "text": tmpl.format(C=c) + cap + "."}
        for c, cap in ps
        for tid, tmpl in enumerate(TRAIN_TEMPLATES)
    ]
    _write_jsonl(data_dir / "facts_para_train.jsonl", train_rows)

    # reliability probes: in-distribution (train template #0).
    # "expected_first_byte" holds cap[0], the first *character*; it equals the
    # first UTF-8 byte only when the capital starts with an ASCII character.
    indist = [
        {"country": c, "capital": cap,
         "prompt": TRAIN_TEMPLATES[0].format(C=c), "expected_first_byte": cap[0]}
        for c, cap in ps
    ]
    _write_jsonl(data_dir / "probes_para_indist.jsonl", indist)

    # generalization probes: held-out templates (one row per country x test-template)
    held = [
        {"country": c, "capital": cap, "test_template_id": tid,
         "prompt": tmpl.format(C=c), "expected_first_byte": cap[0]}
        for c, cap in ps
        for tid, tmpl in enumerate(TEST_TEMPLATES)
    ]
    _write_jsonl(data_dir / "probes_para_heldout.jsonl", held)

    print(f"facts: {len(ps)}  train_rows: {len(train_rows)} ({len(TRAIN_TEMPLATES)} tmpl/fact)")
    print(f"indist probes: {len(indist)}  heldout probes: {len(held)} "
          f"({len(TEST_TEMPLATES)} tmpl/fact)")
    print("train/test templates are DISJOINT:",
          set(TRAIN_TEMPLATES).isdisjoint(set(TEST_TEMPLATES)))


# Script entry point: regenerates the three paraphrase JSONL files via main().
if __name__ == "__main__":
    main()