ModerRAS commited on
Commit
a61b883
·
1 Parent(s): 95246c7

Add Rust cached synthetic training runner

Browse files
AGENTS.md CHANGED
@@ -117,29 +117,31 @@ cargo run --release --manifest-path tools\schema_v2_synthetic_augment\Cargo.toml
117
  ```
118
 
119
  Preferred synthetic follow-up training is a second stage from the best repaired
120
- hard-focus checkpoint, not a replacement for hard-focus. Do not combine
121
- `--encoded-cache-dir` with `--extra-data-file`; use the raw hard-focus JSONL
122
- when mixing synthetic augmentation, or rebuild a combined Rust encoded cache.
123
- Use native Windows Python from `.venv` after confirming CUDA works:
 
 
 
 
124
 
125
  ```powershell
126
- .\.venv\Scripts\python.exe -m anifilebert.train --tokenizer char `
127
- --data-file data\schema_v2_hard_focus_char_seed63.jsonl `
128
- --extra-data-file data\schema_v2_synthetic_aug.jsonl `
129
- --extra-data-repeat 3 `
130
- --vocab-file datasets\AnimeName\vocab.char.json `
131
- --save-dir checkpoints\schema-v2-best-hardfocus-synth-pathleaf `
132
- --init-model-dir checkpoints\ablation-schema-v2-hardfocus-cache-repaired-from-baseline-seed62-10epoch-rerun\final `
133
- --epochs 2 --batch-size 512 --learning-rate 0.00004 --warmup-steps 120 `
134
- --max-seq-length 128 --train-split 0.995 --num-workers 0 `
135
- --checkpoint-steps 1000 --save-total-limit 3 --no-periodic-eval `
136
- --bf16 --auto-find-batch-size `
137
- --parse-eval-limit 2048 `
138
- --case-eval-file data\parser_regression_cases.json `
139
- --case-eval-output reports\schema_v2_best_hardfocus_synth_pathleaf_case_metrics.json `
140
- --seed 63 --experiment-name schema-v2-best-hardfocus-synth-pathleaf
141
  ```
142
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  Export for Android:
144
 
145
  ```bash
 
117
  ```
118
 
119
  Preferred synthetic follow-up training is a second stage from the best repaired
120
+ hard-focus checkpoint, not a replacement for hard-focus. Keep this path
121
+ Rust-cache-first: build one combined encoded cache from hard-focus JSONL plus
122
+ synthetic JSONL, then train from that cache. Do not pass `--extra-data-file` to
123
+ `anifilebert.train` together with `--encoded-cache-dir`.
124
+
125
+ Use the local wrapper, which calls Rust `tools/encoded_dataset_cache` with
126
+ multiple `--input` values and then launches `anifilebert.train` against the
127
+ combined cache:
128
 
129
  ```powershell
130
+ .\.venv\Scripts\python.exe -m tools.train_schema_v2_synthetic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  ```
132
 
133
+ The wrapper defaults to:
134
+
135
+ - primary data: `data\schema_v2_hard_focus_char_seed63.jsonl`
136
+ - synthetic data: `data\schema_v2_synthetic_aug.jsonl`
137
+ - synthetic repeat: `3`
138
+ - encoded cache: `data\encoded_cache\schema_v2_hard_focus_seed63_synth_pathleaf_repeat3`
139
+ - init checkpoint: `checkpoints\ablation-schema-v2-hardfocus-cache-repaired-from-baseline-seed62-10epoch-rerun\final`
140
+ - output checkpoint: `checkpoints\schema-v2-best-hardfocus-synth-pathleaf-cache`
141
+
142
+ Use `--force-cache` to rebuild the combined cache after changing either JSONL,
143
+ vocab, label schema, max length, split ratio, seed, or repeat count.
144
+
145
  Export for Android:
146
 
147
  ```bash
tools/encoded_dataset_cache/README.md CHANGED
@@ -21,6 +21,26 @@ cargo run --release --manifest-path tools\encoded_dataset_cache\Cargo.toml -- `
21
  --threads 16
22
  ```
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  Use the cache in training:
25
 
26
  ```powershell
 
21
  --threads 16
22
  ```
23
 
24
+ Multiple JSONL inputs can be encoded into one deterministic train/eval split.
25
+ Pass `--input-repeat` once per `--input` when an augmentation source should be
26
+ upweighted:
27
+
28
+ ```powershell
29
+ cargo run --release --manifest-path tools\encoded_dataset_cache\Cargo.toml -- `
30
+ --input data\schema_v2_hard_focus_char_seed63.jsonl `
31
+ --input data\schema_v2_synthetic_aug.jsonl `
32
+ --input-repeat 1 `
33
+ --input-repeat 3 `
34
+ --vocab-file datasets\AnimeName\vocab.char.json `
35
+ --label-schema-file label_schema.json `
36
+ --output-dir data\encoded_cache\schema_v2_hard_focus_seed63_synth_pathleaf_repeat3 `
37
+ --max-length 128 `
38
+ --train-split 0.995 `
39
+ --seed 63 `
40
+ --shard-size 25000 `
41
+ --threads 16
42
+ ```
43
+
44
  Use the cache in training:
45
 
46
  ```powershell
tools/encoded_dataset_cache/src/main.rs CHANGED
@@ -1,3 +1,5 @@
 
 
1
  use anyhow::{bail, Context, Result};
2
  use clap::Parser;
3
  use fancy_regex::Regex as FancyRegex;
@@ -78,13 +80,16 @@ const SEPARATOR_CHARS: &[char] = &[' ', '\t', '-', '_', '.', '|', '~', '~'];
78
  )]
79
  struct Args {
80
  #[arg(long)]
81
- input: PathBuf,
 
 
 
82
 
83
  #[arg(long)]
84
- vocab_file: PathBuf,
85
 
86
  #[arg(long)]
87
- output_dir: PathBuf,
88
 
89
  #[arg(long, default_value = "label_schema.json")]
90
  label_schema_file: PathBuf,
@@ -109,6 +114,15 @@ struct Args {
109
 
110
  #[arg(long, default_value_t = 0)]
111
  threads: usize,
 
 
 
 
 
 
 
 
 
112
  }
113
 
114
  #[derive(Debug, Deserialize)]
@@ -160,6 +174,24 @@ struct SplitSummary {
160
 
161
  fn main() -> Result<()> {
162
  let args = Args::parse();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  if args.max_length < 4 {
164
  bail!("--max-length must be at least 4");
165
  }
@@ -177,9 +209,10 @@ fn main() -> Result<()> {
177
  }
178
 
179
  let started = Instant::now();
180
- let vocab = load_vocab(&args.vocab_file)?;
181
  let label_ids = load_label_ids(&args.label_schema_file)?;
182
- let mut rows = load_rows(&args.input, args.limit_rows)?;
 
183
  if rows.len() < 2 {
184
  bail!("need at least two rows to build train/eval cache");
185
  }
@@ -192,10 +225,10 @@ fn main() -> Result<()> {
192
  let split_idx = split_idx.max(1).min(rows.len() - 1);
193
  let (train_rows, eval_rows) = rows.split_at(split_idx);
194
 
195
- fs::create_dir_all(&args.output_dir).with_context(|| {
196
  format!(
197
  "failed to create output directory {}",
198
- args.output_dir.display()
199
  )
200
  })?;
201
 
@@ -207,25 +240,26 @@ fn main() -> Result<()> {
207
  let train_summary = write_split(
208
  "train",
209
  train_rows,
210
- &args.output_dir,
211
  &context,
212
  args.shard_size,
213
  )?;
214
  let eval_summary = write_split(
215
  "eval",
216
  eval_rows,
217
- &args.output_dir,
218
  &context,
219
  args.shard_size,
220
  )?;
221
- write_eval_records(eval_rows, &args.output_dir.join("eval_records.jsonl"))?;
222
 
223
  let manifest = json!({
224
  "format": "anifilebert.encoded_dataset_cache.v1",
225
- "input": args.input,
226
- "vocab_file": args.vocab_file,
 
227
  "label_schema_file": args.label_schema_file,
228
- "output_dir": args.output_dir,
229
  "max_length": args.max_length,
230
  "shard_size": args.shard_size,
231
  "limit_rows": args.limit_rows,
@@ -238,7 +272,7 @@ fn main() -> Result<()> {
238
  "eval_records": "eval_records.jsonl",
239
  "elapsed_seconds": started.elapsed().as_secs_f64(),
240
  });
241
- let manifest_path = args.output_dir.join("manifest.json");
242
  fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
243
  .with_context(|| format!("failed to write {}", manifest_path.display()))?;
244
  println!("{}", serde_json::to_string_pretty(&manifest)?);
@@ -293,14 +327,67 @@ fn load_label_ids(path: &Path) -> Result<HashMap<String, i16>> {
293
  .collect())
294
  }
295
 
296
- fn load_rows(path: &Path, limit_rows: usize) -> Result<Vec<SourceRow>> {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
298
  let reader = BufReader::new(file);
299
  let mut rows = Vec::new();
300
  for (idx, line) in reader.lines().enumerate() {
301
- if limit_rows > 0 && rows.len() >= limit_rows {
302
- break;
303
- }
304
  let raw_line = line.with_context(|| format!("failed reading line {}", idx + 1))?;
305
  if raw_line.trim().is_empty() {
306
  continue;
 
1
+ mod regex_benchmark;
2
+
3
  use anyhow::{bail, Context, Result};
4
  use clap::Parser;
5
  use fancy_regex::Regex as FancyRegex;
 
80
  )]
81
  struct Args {
82
  #[arg(long)]
83
+ input: Vec<PathBuf>,
84
+
85
+ #[arg(long, value_name = "N")]
86
+ input_repeat: Vec<usize>,
87
 
88
  #[arg(long)]
89
+ vocab_file: Option<PathBuf>,
90
 
91
  #[arg(long)]
92
+ output_dir: Option<PathBuf>,
93
 
94
  #[arg(long, default_value = "label_schema.json")]
95
  label_schema_file: PathBuf,
 
114
 
115
  #[arg(long, default_value_t = 0)]
116
  threads: usize,
117
+
118
+ #[arg(long)]
119
+ regex_benchmark_input: Option<PathBuf>,
120
+
121
+ #[arg(long, default_value_t = 0)]
122
+ regex_benchmark_limit_rows: usize,
123
+
124
+ #[arg(long, default_value_t = 3)]
125
+ regex_benchmark_repeat: usize,
126
  }
127
 
128
  #[derive(Debug, Deserialize)]
 
174
 
175
  fn main() -> Result<()> {
176
  let args = Args::parse();
177
+ if let Some(input) = &args.regex_benchmark_input {
178
+ return regex_benchmark::run(
179
+ input,
180
+ args.regex_benchmark_limit_rows,
181
+ args.regex_benchmark_repeat,
182
+ );
183
+ }
184
+ if args.input.is_empty() {
185
+ bail!("at least one --input is required");
186
+ }
187
+ let vocab_file = args
188
+ .vocab_file
189
+ .as_ref()
190
+ .context("--vocab-file is required when building an encoded cache")?;
191
+ let output_dir = args
192
+ .output_dir
193
+ .as_ref()
194
+ .context("--output-dir is required when building an encoded cache")?;
195
  if args.max_length < 4 {
196
  bail!("--max-length must be at least 4");
197
  }
 
209
  }
210
 
211
  let started = Instant::now();
212
+ let vocab = load_vocab(vocab_file)?;
213
  let label_ids = load_label_ids(&args.label_schema_file)?;
214
+ let input_repeats = resolve_input_repeats(&args.input, &args.input_repeat)?;
215
+ let (mut rows, input_summaries) = load_input_rows(&args.input, &input_repeats, args.limit_rows)?;
216
  if rows.len() < 2 {
217
  bail!("need at least two rows to build train/eval cache");
218
  }
 
225
  let split_idx = split_idx.max(1).min(rows.len() - 1);
226
  let (train_rows, eval_rows) = rows.split_at(split_idx);
227
 
228
+ fs::create_dir_all(output_dir).with_context(|| {
229
  format!(
230
  "failed to create output directory {}",
231
+ output_dir.display()
232
  )
233
  })?;
234
 
 
240
  let train_summary = write_split(
241
  "train",
242
  train_rows,
243
+ output_dir,
244
  &context,
245
  args.shard_size,
246
  )?;
247
  let eval_summary = write_split(
248
  "eval",
249
  eval_rows,
250
+ output_dir,
251
  &context,
252
  args.shard_size,
253
  )?;
254
+ write_eval_records(eval_rows, &output_dir.join("eval_records.jsonl"))?;
255
 
256
  let manifest = json!({
257
  "format": "anifilebert.encoded_dataset_cache.v1",
258
+ "input": args.input.first(),
259
+ "inputs": input_summaries,
260
+ "vocab_file": vocab_file,
261
  "label_schema_file": args.label_schema_file,
262
+ "output_dir": output_dir,
263
  "max_length": args.max_length,
264
  "shard_size": args.shard_size,
265
  "limit_rows": args.limit_rows,
 
272
  "eval_records": "eval_records.jsonl",
273
  "elapsed_seconds": started.elapsed().as_secs_f64(),
274
  });
275
+ let manifest_path = output_dir.join("manifest.json");
276
  fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
277
  .with_context(|| format!("failed to write {}", manifest_path.display()))?;
278
  println!("{}", serde_json::to_string_pretty(&manifest)?);
 
327
  .collect())
328
  }
329
 
330
+ fn resolve_input_repeats(inputs: &[PathBuf], repeats: &[usize]) -> Result<Vec<usize>> {
331
+ if repeats.is_empty() {
332
+ return Ok(vec![1; inputs.len()]);
333
+ }
334
+ if repeats.len() == 1 {
335
+ return Ok(vec![repeats[0].max(1); inputs.len()]);
336
+ }
337
+ if repeats.len() != inputs.len() {
338
+ bail!(
339
+ "--input-repeat must be omitted, passed once for all inputs, or passed once per --input ({} inputs, {} repeats)",
340
+ inputs.len(),
341
+ repeats.len()
342
+ );
343
+ }
344
+ Ok(repeats.iter().map(|repeat| (*repeat).max(1)).collect())
345
+ }
346
+
347
+ fn load_input_rows(
348
+ inputs: &[PathBuf],
349
+ repeats: &[usize],
350
+ limit_rows: usize,
351
+ ) -> Result<(Vec<SourceRow>, Vec<Value>)> {
352
+ let mut combined = Vec::new();
353
+ let mut summaries = Vec::new();
354
+ for (path, repeat) in inputs.iter().zip(repeats.iter()) {
355
+ let rows = load_rows(path)?;
356
+ let samples = rows.len();
357
+ let mut written = 0usize;
358
+ for _ in 0..*repeat {
359
+ for row in &rows {
360
+ if limit_rows > 0 && combined.len() >= limit_rows {
361
+ break;
362
+ }
363
+ let mut row = row.clone();
364
+ row.row_index = combined.len();
365
+ combined.push(row);
366
+ written += 1;
367
+ }
368
+ if limit_rows > 0 && combined.len() >= limit_rows {
369
+ break;
370
+ }
371
+ }
372
+ summaries.push(json!({
373
+ "path": path,
374
+ "samples": samples,
375
+ "repeat": repeat,
376
+ "effective_samples": samples * repeat,
377
+ "written_rows": written,
378
+ }));
379
+ if limit_rows > 0 && combined.len() >= limit_rows {
380
+ break;
381
+ }
382
+ }
383
+ Ok((combined, summaries))
384
+ }
385
+
386
+ fn load_rows(path: &Path) -> Result<Vec<SourceRow>> {
387
  let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
388
  let reader = BufReader::new(file);
389
  let mut rows = Vec::new();
390
  for (idx, line) in reader.lines().enumerate() {
 
 
 
391
  let raw_line = line.with_context(|| format!("failed reading line {}", idx + 1))?;
392
  if raw_line.trim().is_empty() {
393
  continue;
tools/encoded_dataset_cache/src/{bin/regex_benchmark.rs → regex_benchmark.rs} RENAMED
@@ -1,5 +1,4 @@
1
  use anyhow::{ensure, Context, Result};
2
- use clap::Parser;
3
  use fancy_regex::Regex as FancyRegex;
4
  use regex::Regex;
5
  use serde_json::Value;
@@ -24,37 +23,21 @@ const CJK_MARKER_PATTERN: &str = r"(?:[一二三四五六七八九十兩两貳
24
  const SPECIAL_CONTEXT_PREFIX_PATTERN: &str =
25
  r"(?i)^(?:[\[\(【《]\s*(?:menu|menus|bdmenu|ncop|nced|op|ed|ova|oad|sp)\s*[\]\)】》]\s*){0,2}";
26
 
27
- #[derive(Parser, Debug)]
28
- #[command(
29
- about = "Compare regex vs fancy-regex workload costs for AniFileBERT cache preprocessing"
30
- )]
31
- struct Args {
32
- #[arg(long)]
33
- input: PathBuf,
34
 
35
- #[arg(long, default_value_t = 0)]
36
- limit_rows: usize,
37
-
38
- #[arg(long, default_value_t = 3)]
39
- repeat: usize,
40
- }
41
-
42
- fn main() -> Result<()> {
43
- let args = Args::parse();
44
- ensure!(args.repeat > 0, "--repeat must be greater than 0");
45
-
46
- let filenames = load_filenames(&args.input, args.limit_rows)?;
47
  if filenames.is_empty() {
48
- anyhow::bail!("no filenames loaded from {}", args.input.display());
49
  }
50
 
51
  let selective = SelectivePatterns::new()?;
52
  let fancy_all = FancyAllPatterns::new()?;
53
 
54
  let (selective_seconds, selective_count) =
55
- time_repeated(args.repeat, || run_selective(&filenames, &selective))?;
56
  let (fancy_seconds, fancy_count) =
57
- time_repeated(args.repeat, || run_fancy_all(&filenames, &fancy_all))?;
58
  ensure!(
59
  selective_count == fancy_count,
60
  "selective and fancy-all match counts differ: selective={}, fancy_all={}",
@@ -71,7 +54,7 @@ fn main() -> Result<()> {
71
  "{}",
72
  serde_json::json!({
73
  "rows": filenames.len(),
74
- "repeat": args.repeat,
75
  "selective_seconds": selective_seconds,
76
  "fancy_all_seconds": fancy_seconds,
77
  "ratio": ratio,
 
1
  use anyhow::{ensure, Context, Result};
 
2
  use fancy_regex::Regex as FancyRegex;
3
  use regex::Regex;
4
  use serde_json::Value;
 
23
  const SPECIAL_CONTEXT_PREFIX_PATTERN: &str =
24
  r"(?i)^(?:[\[\(【《]\s*(?:menu|menus|bdmenu|ncop|nced|op|ed|ova|oad|sp)\s*[\]\)】》]\s*){0,2}";
25
 
26
+ pub fn run(input: &PathBuf, limit_rows: usize, repeat: usize) -> Result<()> {
27
+ ensure!(repeat > 0, "--regex-benchmark-repeat must be greater than 0");
 
 
 
 
 
28
 
29
+ let filenames = load_filenames(input, limit_rows)?;
 
 
 
 
 
 
 
 
 
 
 
30
  if filenames.is_empty() {
31
+ anyhow::bail!("no filenames loaded from {}", input.display());
32
  }
33
 
34
  let selective = SelectivePatterns::new()?;
35
  let fancy_all = FancyAllPatterns::new()?;
36
 
37
  let (selective_seconds, selective_count) =
38
+ time_repeated(repeat, || run_selective(&filenames, &selective))?;
39
  let (fancy_seconds, fancy_count) =
40
+ time_repeated(repeat, || run_fancy_all(&filenames, &fancy_all))?;
41
  ensure!(
42
  selective_count == fancy_count,
43
  "selective and fancy-all match counts differ: selective={}, fancy_all={}",
 
54
  "{}",
55
  serde_json::json!({
56
  "rows": filenames.len(),
57
+ "repeat": repeat,
58
  "selective_seconds": selective_seconds,
59
  "fancy_all_seconds": fancy_seconds,
60
  "ratio": ratio,
tools/train_schema_v2_synthetic.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ r"""Local schema v2 synthetic-augmentation training runner.
3
+
4
+ This wrapper keeps the training structure reproducible:
5
+
6
+ 1. Build a combined Rust encoded cache from the hard-focus JSONL plus synthetic
7
+ augmentation JSONL.
8
+ 2. Train with ``anifilebert.train --encoded-cache-dir`` so Python training never
9
+ has to re-split raw mixed JSONL in a non-comparable way.
10
+
11
+ Typical usage from the repo root on the local Windows GPU machine:
12
+
13
+ .\.venv\Scripts\python.exe -m tools.train_schema_v2_synthetic
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import datetime as dt
20
+ import json
21
+ from pathlib import Path
22
+ import shlex
23
+ import shutil
24
+ import subprocess
25
+ import sys
26
+ from typing import Any, Sequence
27
+
28
+
29
+ def utc_now() -> str:
30
+ return dt.datetime.now(dt.timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
31
+
32
+
33
+ def command_text(args: Sequence[Any]) -> str:
34
+ return " ".join(shlex.quote(str(arg)) for arg in args)
35
+
36
+
37
+ def run(args: Sequence[Any], *, dry_run: bool, command_log: list[dict[str, Any]]) -> None:
38
+ entry: dict[str, Any] = {
39
+ "cmd": command_text(args),
40
+ "started_at": utc_now(),
41
+ "dry_run": dry_run,
42
+ }
43
+ command_log.append(entry)
44
+ print(f"\n$ {entry['cmd']}")
45
+ if dry_run:
46
+ entry["returncode"] = 0
47
+ entry["finished_at"] = utc_now()
48
+ return
49
+ proc = subprocess.Popen(
50
+ list(map(str, args)),
51
+ stdout=subprocess.PIPE,
52
+ stderr=subprocess.STDOUT,
53
+ text=True,
54
+ encoding="utf-8",
55
+ errors="replace",
56
+ bufsize=1,
57
+ )
58
+ assert proc.stdout is not None
59
+ for line in proc.stdout:
60
+ print(line, end="")
61
+ proc.wait()
62
+ entry["returncode"] = proc.returncode
63
+ entry["finished_at"] = utc_now()
64
+ if proc.returncode != 0:
65
+ raise RuntimeError(f"Command failed with exit code {proc.returncode}: {entry['cmd']}")
66
+
67
+
68
+ def parse_args() -> argparse.Namespace:
69
+ parser = argparse.ArgumentParser(description="Train schema v2 hard-focus + synthetic augmentation with Rust cache")
70
+ parser.add_argument("--primary-data-file", default="data/schema_v2_hard_focus_char_seed63.jsonl")
71
+ parser.add_argument("--synthetic-data-file", default="data/schema_v2_synthetic_aug.jsonl")
72
+ parser.add_argument("--synthetic-repeat", type=int, default=3)
73
+ parser.add_argument("--vocab-file", default="datasets/AnimeName/vocab.char.json")
74
+ parser.add_argument("--label-schema-file", default="label_schema.json")
75
+ parser.add_argument("--encoded-cache-dir", default="data/encoded_cache/schema_v2_hard_focus_seed63_synth_pathleaf_repeat3")
76
+ parser.add_argument("--save-dir", default="checkpoints/schema-v2-best-hardfocus-synth-pathleaf-cache")
77
+ parser.add_argument("--init-model-dir", default="checkpoints/ablation-schema-v2-hardfocus-cache-repaired-from-baseline-seed62-10epoch-rerun/final")
78
+ parser.add_argument("--case-eval-output", default="reports/schema_v2_best_hardfocus_synth_pathleaf_cache_case_metrics.json")
79
+ parser.add_argument("--experiment-name", default="schema-v2-best-hardfocus-synth-pathleaf-cache")
80
+ parser.add_argument("--max-length", type=int, default=128)
81
+ parser.add_argument("--train-split", type=float, default=0.995)
82
+ parser.add_argument("--seed", type=int, default=63)
83
+ parser.add_argument("--shard-size", type=int, default=25000)
84
+ parser.add_argument("--threads", type=int, default=16)
85
+ parser.add_argument("--epochs", type=float, default=2)
86
+ parser.add_argument("--batch-size", type=int, default=512)
87
+ parser.add_argument("--learning-rate", type=float, default=0.00004)
88
+ parser.add_argument("--warmup-steps", type=int, default=120)
89
+ parser.add_argument("--checkpoint-steps", type=int, default=1000)
90
+ parser.add_argument("--save-total-limit", type=int, default=3)
91
+ parser.add_argument("--parse-eval-limit", type=int, default=2048)
92
+ parser.add_argument("--case-eval-file", default="data/parser_regression_cases.json")
93
+ parser.add_argument("--force-cache", action="store_true", help="Delete and rebuild the encoded cache even if manifest exists")
94
+ parser.add_argument("--skip-cache", action="store_true", help="Reuse the existing encoded cache")
95
+ parser.add_argument("--dry-run", action="store_true")
96
+ return parser.parse_args()
97
+
98
+
99
+ def main() -> None:
100
+ args = parse_args()
101
+ command_log: list[dict[str, Any]] = []
102
+ cache_dir = Path(args.encoded_cache_dir)
103
+ manifest_path = cache_dir / "manifest.json"
104
+
105
+ if args.force_cache and cache_dir.exists():
106
+ print(f"Removing existing cache: {cache_dir}")
107
+ if not args.dry_run:
108
+ shutil.rmtree(cache_dir)
109
+
110
+ if not args.skip_cache and not manifest_path.exists():
111
+ cache_cmd = [
112
+ "cargo", "run", "--release",
113
+ "--manifest-path", "tools/encoded_dataset_cache/Cargo.toml",
114
+ "--",
115
+ "--input", args.primary_data_file,
116
+ "--input", args.synthetic_data_file,
117
+ "--input-repeat", "1",
118
+ "--input-repeat", str(max(1, args.synthetic_repeat)),
119
+ "--vocab-file", args.vocab_file,
120
+ "--label-schema-file", args.label_schema_file,
121
+ "--output-dir", args.encoded_cache_dir,
122
+ "--max-length", str(args.max_length),
123
+ "--train-split", str(args.train_split),
124
+ "--seed", str(args.seed),
125
+ "--shard-size", str(args.shard_size),
126
+ "--threads", str(args.threads),
127
+ ]
128
+ run(cache_cmd, dry_run=args.dry_run, command_log=command_log)
129
+ else:
130
+ print(f"Using existing encoded cache: {cache_dir}")
131
+
132
+ train_cmd = [
133
+ sys.executable, "-m", "anifilebert.train",
134
+ "--tokenizer", "char",
135
+ "--data-file", args.primary_data_file,
136
+ "--vocab-file", args.vocab_file,
137
+ "--encoded-cache-dir", args.encoded_cache_dir,
138
+ "--save-dir", args.save_dir,
139
+ "--init-model-dir", args.init_model_dir,
140
+ "--epochs", str(args.epochs),
141
+ "--batch-size", str(args.batch_size),
142
+ "--learning-rate", str(args.learning_rate),
143
+ "--warmup-steps", str(args.warmup_steps),
144
+ "--max-seq-length", str(args.max_length),
145
+ "--train-split", str(args.train_split),
146
+ "--num-workers", "0",
147
+ "--checkpoint-steps", str(args.checkpoint_steps),
148
+ "--save-total-limit", str(args.save_total_limit),
149
+ "--no-periodic-eval",
150
+ "--bf16",
151
+ "--auto-find-batch-size",
152
+ "--parse-eval-limit", str(args.parse_eval_limit),
153
+ "--case-eval-file", args.case_eval_file,
154
+ "--case-eval-output", args.case_eval_output,
155
+ "--seed", str(args.seed),
156
+ "--experiment-name", args.experiment_name,
157
+ ]
158
+ run(train_cmd, dry_run=args.dry_run, command_log=command_log)
159
+
160
+ run_manifest = {
161
+ "name": args.experiment_name,
162
+ "started_at": command_log[0]["started_at"] if command_log else utc_now(),
163
+ "finished_at": utc_now(),
164
+ "primary_data_file": args.primary_data_file,
165
+ "synthetic_data_file": args.synthetic_data_file,
166
+ "synthetic_repeat": args.synthetic_repeat,
167
+ "encoded_cache_dir": args.encoded_cache_dir,
168
+ "save_dir": args.save_dir,
169
+ "init_model_dir": args.init_model_dir,
170
+ "commands": command_log,
171
+ }
172
+ manifest_output = Path(args.save_dir) / "schema_v2_synthetic_train_manifest.json"
173
+ print(f"Writing run manifest: {manifest_output}")
174
+ if not args.dry_run:
175
+ manifest_output.parent.mkdir(parents=True, exist_ok=True)
176
+ manifest_output.write_text(json.dumps(run_manifest, ensure_ascii=False, indent=2), encoding="utf-8")
177
+
178
+
179
+ if __name__ == "__main__":
180
+ main()