dcostenco committed on
Commit
0dfe3d5
·
verified ·
1 Parent(s): 7d5ab03

Add training/build_4b_v43_corpus.py

Browse files
Files changed (1) hide show
  1. training/build_4b_v43_corpus.py +238 -0
training/build_4b_v43_corpus.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ build_4b_v43_corpus.py — Corpus builder for Prism Coder 4B v43.
4
+
5
+ Same source mix as 14B v44 (topfive_v2, combined_aac_full, layer3, grounded_recall).
6
+ Outputs two files for mlx_lm.lora: train.jsonl + valid.jsonl in a target directory.
7
+
8
+ Required sources:
9
+ - topfive_v2.train.7b.jsonl v2 mix (40% AAC / 12% abstention / 12% safety / 36% tool-use)
10
+ - combined_aac_full.jsonl 57k clinical AAC corpus, subsampled to 7000
11
+ - layer3_corpus.jsonl 45 rows × 5 oversample — MANDATORY for 3-layer AAC arch
12
+ - grounded_recall_corpus.jsonl 40 rows × 5 oversample — cascade/verifier compatibility
13
+
14
+ Usage:
15
+ python3 build_4b_v43_corpus.py [--out-dir DIR] [--valid-frac 0.05] [--seed 42]
16
+
17
+ Hard-gate audit runs before writing. Exits 1 on any hard-gate failure; the approximate safety/abstention check only prints a warning.
18
+ """
19
+ import json
20
+ import random
21
+ import sys
22
+ from pathlib import Path
23
+ import argparse
24
+
25
# Root directories that hold the source corpora (absolute, machine-specific).
BASE_DIR = Path("/Users/admin/synalux-private/prism-training/data/topfive")
PRISM_DATA = Path("/Users/admin/prism/training/data")

# Logical source name -> JSONL file loaded for it by main().
SOURCES = dict(
    v2_base=BASE_DIR / "topfive_v2.train.7b.jsonl",
    combined_aac=BASE_DIR / "combined_aac_full.jsonl",
    layer3=BASE_DIR / "layer3_corpus.jsonl",
    grounded_recall=PRISM_DATA / "grounded_recall_corpus.jsonl",
)

# Number of combined_aac rows kept after shuffling.
AAC_FULL_SUBSAMPLE = 7000
# Repetition factors applied to the two small corpora.
LAYER3_OVERSAMPLE = 5
GROUNDED_OVERSAMPLE = 5

# Thresholds checked by the corpus audit before anything is written.
GATE = dict(
    min_total=20000,
    min_aac_frac=0.35,
    min_layer3=40,
    min_grounded=80,
    min_tool_calls=5000,
    min_safety=500,
)
47
+
48
+
49
def messages_to_chatml(messages: list) -> str:
    """Render a list of chat messages as one ChatML-formatted string.

    Every message is wrapped as ``<|im_start|>ROLE``, a newline, the content,
    then ``<|im_end|>``; the wrapped turns are joined with newlines.

    Args:
        messages: Message dicts. A missing ``role`` defaults to ``"user"``;
            a missing or ``None`` ``content`` is rendered as an empty string.

    Returns:
        The ChatML text, or ``""`` when ``messages`` is empty.
    """
    parts = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content")
        # dict.get's default only applies when the key is absent; an explicit
        # null would otherwise be formatted as the literal text "None".
        if content is None:
            content = ""
        parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
    return "\n".join(parts)
56
+
57
+
58
def normalize(row: dict) -> dict:
    """Reduce a raw corpus row to ``text`` / ``_bucket`` / ``source``.

    A row that already has ``text`` keeps it; otherwise a row with
    ``messages`` has them rendered through ``messages_to_chatml``. ``_bucket``
    and ``source`` are copied from the row, defaulting to ``""``. A row with
    neither ``text`` nor ``messages`` is returned unchanged (same object).
    """
    if "text" in row:
        text = row["text"]
    elif "messages" in row:
        text = messages_to_chatml(row["messages"])
    else:
        # Nothing to convert: hand the row back untouched.
        return row

    bucket = row.get("_bucket", "")
    origin = row.get("source", "")
    return {"text": text, "_bucket": bucket, "source": origin}
68
+
69
+
70
def load_jsonl(path: Path) -> list[dict]:
    """Load a JSON Lines file into a list of parsed rows.

    Blank lines are ignored. Lines that are not valid JSON are skipped, but
    the number skipped (and the first offending line number) is reported on
    stderr so corrupt sources do not shrink the corpus unnoticed.

    Args:
        path: Location of the ``.jsonl`` file, read as UTF-8.

    Returns:
        The parsed values in file order.

    Raises:
        SystemExit: With exit code 1 when ``path`` does not exist.
    """
    if not path.exists():
        print(f"FATAL: {path} missing — aborting", file=sys.stderr)
        sys.exit(1)

    rows = []
    skipped = 0
    first_bad = None
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with path.open(encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                # Still best-effort, but keep a tally instead of hiding it.
                skipped += 1
                if first_bad is None:
                    first_bad = lineno

    if skipped:
        print(
            f"WARN: {path}: skipped {skipped} malformed JSON line(s), "
            f"first at line {first_bad}",
            file=sys.stderr,
        )
    return rows
84
+
85
+
86
def oversample(rows: list[dict], factor: int, rng: random.Random) -> list[dict]:
    """Repeat ``rows`` ``factor`` times, shuffling each repetition separately.

    The input list is left untouched; each pass shuffles a shallow copy with
    ``rng`` and appends it, so the result holds ``factor * len(rows)``
    references to the original row objects.
    """
    result = []
    remaining = factor
    while remaining > 0:
        shuffled = list(rows)
        rng.shuffle(shuffled)
        result += shuffled
        remaining -= 1
    return result
93
+
94
+
95
def _row_text(row: dict) -> str:
    """Return the row's ``text``, treating a missing or ``None`` value as ``""``."""
    return row.get("text") or ""


def _tag_rows(rows: list[dict], bucket: str, default_source: str) -> None:
    """Force ``_bucket`` on every row and fill in ``source`` when it is empty."""
    for r in rows:
        r["_bucket"] = bucket
        # normalize() always emits a "source" key (possibly ""), so
        # setdefault() would never apply the fallback; test emptiness instead.
        if not r.get("source"):
            r["source"] = default_source


def _audit_corpus(all_rows: list[dict]) -> bool:
    """Print the gate audit and composition summary for ``all_rows``.

    Returns:
        True when at least one hard gate failed. The safety/abstention check
        is approximate and only warns; it never sets the failure flag.
    """
    print("\n=== CORPUS AUDIT (mandatory — failures are fatal) ===")
    failed = False
    total = len(all_rows)

    if total < GATE["min_total"]:
        print(f" FAIL: total {total} < {GATE['min_total']}")
        failed = True
    else:
        print(f" OK: total {total} >= {GATE['min_total']}")

    aac_count = sum(1 for r in all_rows if r.get("_bucket") == "aac")
    # Guard the division so an empty corpus fails the gates instead of
    # crashing with ZeroDivisionError.
    aac_frac = aac_count / total if total else 0.0
    print(f" AAC fraction: {aac_frac:.1%} ({aac_count} rows)")
    if aac_frac < GATE["min_aac_frac"]:
        print(f" FAIL: AAC fraction {aac_frac:.1%} < {GATE['min_aac_frac']:.1%}")
        failed = True
    else:
        print(f" OK: AAC fraction >= {GATE['min_aac_frac']:.1%}")

    # Layer3 rows are recognised by a marker in the text, not by bucket.
    layer3_count = sum(1 for r in all_rows if "[LAYER3" in _row_text(r))
    print(f" Layer3 examples: {layer3_count}")
    if layer3_count < GATE["min_layer3"]:
        print(f" FAIL: layer3 {layer3_count} < {GATE['min_layer3']}")
        failed = True
    else:
        print(f" OK: layer3 >= {GATE['min_layer3']}")

    gr_count = sum(1 for r in all_rows if r.get("_bucket") == "grounded_recall")
    print(f" Grounded recall rows: {gr_count}")
    if gr_count < GATE["min_grounded"]:
        print(f" FAIL: grounded_recall {gr_count} < {GATE['min_grounded']}")
        failed = True
    else:
        print(f" OK: grounded_recall >= {GATE['min_grounded']}")

    # "tool_call" also matches the "<tool_call>" tag, so one test suffices.
    tool_count = sum(1 for r in all_rows if "tool_call" in _row_text(r))
    print(f" Tool-call SFT rows: {tool_count}")
    if tool_count < GATE["min_tool_calls"]:
        print(f" FAIL: tool_calls {tool_count} < {GATE['min_tool_calls']}")
        failed = True
    else:
        print(f" OK: tool_calls >= {GATE['min_tool_calls']}")

    safety_markers = ("abstain", "cannot", "refuse", "I should not", "harmful")
    safety_count = sum(
        1 for r in all_rows if any(t in _row_text(r) for t in safety_markers)
    )
    print(f" Safety/abstention rows (approx): {safety_count}")
    if safety_count < GATE["min_safety"]:
        # Keyword heuristic only, hence a warning rather than a hard failure.
        print(f" WARN: safety {safety_count} < {GATE['min_safety']}")
    else:
        print(" OK: safety/abstention present")

    print("\n=== Composition summary ===")
    for bucket in ["aac", "abstention", "safety", "tool_use", "layer3", "grounded_recall"]:
        n = sum(1 for r in all_rows if r.get("_bucket") == bucket)
        pct = n / total * 100 if total else 0
        print(f" {bucket:>20}: {n:>6} ({pct:5.1f}%)")

    return failed


def _write_jsonl(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines, one object per line."""
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding must
    # be pinned rather than left to the platform default.
    with path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def main():
    """Build the corpus, audit it, and write ``train.jsonl`` / ``valid.jsonl``.

    Steps: parse CLI options, verify every file in ``SOURCES`` exists, load
    and normalize the four sources (subsampling combined_aac, oversampling
    layer3 and grounded_recall), shuffle, run the gate audit, then strip
    metadata and write the train/valid split into ``--out-dir``.

    Raises:
        SystemExit: Code 1 when a source file is missing or a hard gate
            fails; code 2 (argparse) when ``--valid-frac`` is outside [0, 1).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--out-dir", type=Path, default=Path("/tmp/4b_v43_data"),
                   help="Output directory — will contain train.jsonl and valid.jsonl")
    p.add_argument("--valid-frac", type=float, default=0.05,
                   help="Fraction of data held out for validation (default 0.05)")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    # A fraction >= 1 would leave the training set empty.
    if not 0.0 <= args.valid_frac < 1.0:
        p.error("--valid-frac must be >= 0 and < 1")

    # All shuffles draw from this one RNG, in a fixed order, for reproducibility.
    rng = random.Random(args.seed)

    print("=== Prism Coder 4B v43 Corpus Builder ===\n")
    print("Checking source files...")
    for name, path in SOURCES.items():
        if not path.exists():
            print(f" FATAL: {name} missing: {path}", file=sys.stderr)
            sys.exit(1)
        print(f" OK: {name} ({path.name})")
    print()

    # 1. v2 base mix
    v2_rows = [normalize(r) for r in load_jsonl(SOURCES["v2_base"])]
    print(f"v2 base mix: {len(v2_rows):>6} rows")

    # 2. combined_aac_full subsample
    aac_full = load_jsonl(SOURCES["combined_aac"])
    rng.shuffle(aac_full)
    aac_sample = [normalize(r) for r in aac_full[:AAC_FULL_SUBSAMPLE]]
    _tag_rows(aac_sample, "aac", "combined_aac_full")
    print(f"combined_aac subsamp: {len(aac_sample):>6} rows (from {len(aac_full)} available)")

    # 3. layer3 oversample (MANDATORY)
    layer3_base = load_jsonl(SOURCES["layer3"])
    layer3_rows = oversample([normalize(r) for r in layer3_base], LAYER3_OVERSAMPLE, rng)
    _tag_rows(layer3_rows, "layer3", "layer3_corpus")
    print(f"layer3 (×{LAYER3_OVERSAMPLE}): {len(layer3_rows):>6} rows")

    # 4. grounded_recall oversample
    gr_base = load_jsonl(SOURCES["grounded_recall"])
    gr_rows = oversample([normalize(r) for r in gr_base], GROUNDED_OVERSAMPLE, rng)
    _tag_rows(gr_rows, "grounded_recall", "grounded_recall_corpus")
    print(f"grounded_recall (×{GROUNDED_OVERSAMPLE}): {len(gr_rows):>6} rows")

    # 5. Compose + shuffle
    all_rows = v2_rows + aac_sample + layer3_rows + gr_rows
    rng.shuffle(all_rows)
    print(f"\nTotal before filter: {len(all_rows):>6} rows")

    # 6. Audit (on all_rows while bucket tags still present)
    if _audit_corpus(all_rows):
        print("\nFATAL: Corpus audit failed — DO NOT use this corpus for training")
        sys.exit(1)

    # 7. Strip metadata, write train/valid split
    final = [{"text": r["text"]} for r in all_rows if _row_text(r).strip()]
    rng.shuffle(final)

    # NOTE(review): the split happens after oversampling, so identical texts
    # from the oversampled sources can land in both train and valid — confirm
    # this overlap is acceptable for how validation loss is used.
    n_valid = max(1, int(len(final) * args.valid_frac))
    valid_rows = final[:n_valid]
    train_rows = final[n_valid:]

    args.out_dir.mkdir(parents=True, exist_ok=True)
    train_path = args.out_dir / "train.jsonl"
    valid_path = args.out_dir / "valid.jsonl"

    _write_jsonl(train_path, train_rows)
    _write_jsonl(valid_path, valid_rows)

    print("\n✅ All gates passed — corpus ready")
    print(f" Train: {len(train_rows):>6} rows → {train_path}")
    print(f" Valid: {len(valid_rows):>6} rows → {valid_path}")
    print("\nNext: bash /Users/admin/synalux-private/prism-training/train_4b_v43_local.sh")


if __name__ == "__main__":
    main()