Add Rust cached synthetic training runner
AGENTS.md
CHANGED
|
@@ -117,29 +117,31 @@ cargo run --release --manifest-path tools\schema_v2_synthetic_augment\Cargo.toml
|
|
| 117 |
```
|
| 118 |
|
| 119 |
Preferred synthetic follow-up training is a second stage from the best repaired
|
| 120 |
-
hard-focus checkpoint, not a replacement for hard-focus.
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
|
| 125 |
```powershell
|
| 126 |
-
.\.venv\Scripts\python.exe -m
|
| 127 |
-
--data-file data\schema_v2_hard_focus_char_seed63.jsonl `
|
| 128 |
-
--extra-data-file data\schema_v2_synthetic_aug.jsonl `
|
| 129 |
-
--extra-data-repeat 3 `
|
| 130 |
-
--vocab-file datasets\AnimeName\vocab.char.json `
|
| 131 |
-
--save-dir checkpoints\schema-v2-best-hardfocus-synth-pathleaf `
|
| 132 |
-
--init-model-dir checkpoints\ablation-schema-v2-hardfocus-cache-repaired-from-baseline-seed62-10epoch-rerun\final `
|
| 133 |
-
--epochs 2 --batch-size 512 --learning-rate 0.00004 --warmup-steps 120 `
|
| 134 |
-
--max-seq-length 128 --train-split 0.995 --num-workers 0 `
|
| 135 |
-
--checkpoint-steps 1000 --save-total-limit 3 --no-periodic-eval `
|
| 136 |
-
--bf16 --auto-find-batch-size `
|
| 137 |
-
--parse-eval-limit 2048 `
|
| 138 |
-
--case-eval-file data\parser_regression_cases.json `
|
| 139 |
-
--case-eval-output reports\schema_v2_best_hardfocus_synth_pathleaf_case_metrics.json `
|
| 140 |
-
--seed 63 --experiment-name schema-v2-best-hardfocus-synth-pathleaf
|
| 141 |
```
|
| 142 |
|
| 143 |
Export for Android:
|
| 144 |
|
| 145 |
```bash
|
|
|
|
| 117 |
```
|
| 118 |
|
| 119 |
Preferred synthetic follow-up training is a second stage from the best repaired
|
| 120 |
+
hard-focus checkpoint, not a replacement for hard-focus. Keep this path
|
| 121 |
+
Rust-cache-first: build one combined encoded cache from hard-focus JSONL plus
|
| 122 |
+
synthetic JSONL, then train from that cache. Do not pass `--extra-data-file` to
|
| 123 |
+
`anifilebert.train` together with `--encoded-cache-dir`.
|
| 124 |
+
|
| 125 |
+
Use the local wrapper, which calls Rust `tools/encoded_dataset_cache` with
|
| 126 |
+
multiple `--input` values and then launches `anifilebert.train` against the
|
| 127 |
+
combined cache:
|
| 128 |
|
| 129 |
```powershell
|
| 130 |
+
.\.venv\Scripts\python.exe -m tools.train_schema_v2_synthetic
|
| 131 |
```
|
| 132 |
|
| 133 |
+
The wrapper defaults to:
|
| 134 |
+
|
| 135 |
+
- primary data: `data\schema_v2_hard_focus_char_seed63.jsonl`
|
| 136 |
+
- synthetic data: `data\schema_v2_synthetic_aug.jsonl`
|
| 137 |
+
- synthetic repeat: `3`
|
| 138 |
+
- encoded cache: `data\encoded_cache\schema_v2_hard_focus_seed63_synth_pathleaf_repeat3`
|
| 139 |
+
- init checkpoint: `checkpoints\ablation-schema-v2-hardfocus-cache-repaired-from-baseline-seed62-10epoch-rerun\final`
|
| 140 |
+
- output checkpoint: `checkpoints\schema-v2-best-hardfocus-synth-pathleaf-cache`
|
| 141 |
+
|
| 142 |
+
Use `--force-cache` to rebuild the combined cache after changing either JSONL file,
|
| 143 |
+
the vocab, label schema, max length, split ratio, seed, or repeat count.
|
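The rebuild rule in the AGENTS.md text above depends on remembering every input the cache is derived from. One possible way to detect a stale cache is to fingerprint those inputs. This is a minimal sketch, not part of this commit: `cache_fingerprint` is a hypothetical helper, and it assumes the files are small enough to hash in full.

```python
import hashlib
import json
from pathlib import Path


def cache_fingerprint(files, settings):
    """Hash file contents plus scalar settings into one hex digest.

    Hypothetical helper, not part of the repository. `files` are the paths the
    cache is built from (JSONL inputs, vocab, label schema); `settings` holds
    values such as max length, split ratio, seed, and repeat count.
    """
    digest = hashlib.sha256()
    for path in sorted(str(p) for p in files):
        digest.update(path.encode("utf-8"))
        digest.update(Path(path).read_bytes())
    # sort_keys makes the digest independent of dict insertion order.
    digest.update(json.dumps(settings, sort_keys=True).encode("utf-8"))
    return digest.hexdigest()
```

A caller could store the digest next to the cache and pass `--force-cache` to the wrapper whenever a freshly computed digest differs from the stored one.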
| 144 |
+
|
| 145 |
Export for Android:
|
| 146 |
|
| 147 |
```bash
|
tools/encoded_dataset_cache/README.md
CHANGED
|
@@ -21,6 +21,26 @@ cargo run --release --manifest-path tools\encoded_dataset_cache\Cargo.toml -- `
|
|
| 21 |
--threads 16
|
| 22 |
```
|
| 23 |
|
| 24 |
Use the cache in training:
|
| 25 |
|
| 26 |
```powershell
|
|
|
|
| 21 |
--threads 16
|
| 22 |
```
|
| 23 |
|
| 24 |
+
Multiple JSONL inputs can be encoded into one deterministic train/eval split.
|
| 25 |
+
Pass `--input-repeat` once per `--input` when an augmentation source should be
|
| 26 |
+
upweighted:
|
| 27 |
+
|
| 28 |
+
```powershell
|
| 29 |
+
cargo run --release --manifest-path tools\encoded_dataset_cache\Cargo.toml -- `
|
| 30 |
+
--input data\schema_v2_hard_focus_char_seed63.jsonl `
|
| 31 |
+
--input data\schema_v2_synthetic_aug.jsonl `
|
| 32 |
+
--input-repeat 1 `
|
| 33 |
+
--input-repeat 3 `
|
| 34 |
+
--vocab-file datasets\AnimeName\vocab.char.json `
|
| 35 |
+
--label-schema-file label_schema.json `
|
| 36 |
+
--output-dir data\encoded_cache\schema_v2_hard_focus_seed63_synth_pathleaf_repeat3 `
|
| 37 |
+
--max-length 128 `
|
| 38 |
+
--train-split 0.995 `
|
| 39 |
+
--seed 63 `
|
| 40 |
+
--shard-size 25000 `
|
| 41 |
+
--threads 16
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
Use the cache in training:
|
| 45 |
|
| 46 |
```powershell
|
tools/encoded_dataset_cache/src/main.rs
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
use anyhow::{bail, Context, Result};
|
| 2 |
use clap::Parser;
|
| 3 |
use fancy_regex::Regex as FancyRegex;
|
|
@@ -78,13 +80,16 @@ const SEPARATOR_CHARS: &[char] = &[' ', '\t', '-', '_', '.', '|', '~', '~'];
|
|
| 78 |
)]
|
| 79 |
struct Args {
|
| 80 |
#[arg(long)]
|
| 81 |
-
input: PathBuf,
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
#[arg(long)]
|
| 84 |
-
vocab_file: PathBuf,
|
| 85 |
|
| 86 |
#[arg(long)]
|
| 87 |
-
output_dir: PathBuf,
|
| 88 |
|
| 89 |
#[arg(long, default_value = "label_schema.json")]
|
| 90 |
label_schema_file: PathBuf,
|
|
@@ -109,6 +114,15 @@ struct Args {
|
|
| 109 |
|
| 110 |
#[arg(long, default_value_t = 0)]
|
| 111 |
threads: usize,
|
| 112 |
}
|
| 113 |
|
| 114 |
#[derive(Debug, Deserialize)]
|
|
@@ -160,6 +174,24 @@ struct SplitSummary {
|
|
| 160 |
|
| 161 |
fn main() -> Result<()> {
|
| 162 |
let args = Args::parse();
|
| 163 |
if args.max_length < 4 {
|
| 164 |
bail!("--max-length must be at least 4");
|
| 165 |
}
|
|
@@ -177,9 +209,10 @@ fn main() -> Result<()> {
|
|
| 177 |
}
|
| 178 |
|
| 179 |
let started = Instant::now();
|
| 180 |
-
let vocab = load_vocab(
|
| 181 |
let label_ids = load_label_ids(&args.label_schema_file)?;
|
| 182 |
-
let
|
|
|
|
| 183 |
if rows.len() < 2 {
|
| 184 |
bail!("need at least two rows to build train/eval cache");
|
| 185 |
}
|
|
@@ -192,10 +225,10 @@ fn main() -> Result<()> {
|
|
| 192 |
let split_idx = split_idx.max(1).min(rows.len() - 1);
|
| 193 |
let (train_rows, eval_rows) = rows.split_at(split_idx);
|
| 194 |
|
| 195 |
-
fs::create_dir_all(
|
| 196 |
format!(
|
| 197 |
"failed to create output directory {}",
|
| 198 |
-
|
| 199 |
)
|
| 200 |
})?;
|
| 201 |
|
|
@@ -207,25 +240,26 @@ fn main() -> Result<()> {
|
|
| 207 |
let train_summary = write_split(
|
| 208 |
"train",
|
| 209 |
train_rows,
|
| 210 |
-
|
| 211 |
&context,
|
| 212 |
args.shard_size,
|
| 213 |
)?;
|
| 214 |
let eval_summary = write_split(
|
| 215 |
"eval",
|
| 216 |
eval_rows,
|
| 217 |
-
|
| 218 |
&context,
|
| 219 |
args.shard_size,
|
| 220 |
)?;
|
| 221 |
-
write_eval_records(eval_rows, &
|
| 222 |
|
| 223 |
let manifest = json!({
|
| 224 |
"format": "anifilebert.encoded_dataset_cache.v1",
|
| 225 |
-
"input": args.input,
|
| 226 |
-
"
|
|
|
|
| 227 |
"label_schema_file": args.label_schema_file,
|
| 228 |
-
"output_dir":
|
| 229 |
"max_length": args.max_length,
|
| 230 |
"shard_size": args.shard_size,
|
| 231 |
"limit_rows": args.limit_rows,
|
|
@@ -238,7 +272,7 @@ fn main() -> Result<()> {
|
|
| 238 |
"eval_records": "eval_records.jsonl",
|
| 239 |
"elapsed_seconds": started.elapsed().as_secs_f64(),
|
| 240 |
});
|
| 241 |
-
let manifest_path =
|
| 242 |
fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
|
| 243 |
.with_context(|| format!("failed to write {}", manifest_path.display()))?;
|
| 244 |
println!("{}", serde_json::to_string_pretty(&manifest)?);
|
|
@@ -293,14 +327,67 @@ fn load_label_ids(path: &Path) -> Result<HashMap<String, i16>> {
|
|
| 293 |
.collect())
|
| 294 |
}
|
| 295 |
|
| 296 |
-
fn
|
| 297 |
let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
|
| 298 |
let reader = BufReader::new(file);
|
| 299 |
let mut rows = Vec::new();
|
| 300 |
for (idx, line) in reader.lines().enumerate() {
|
| 301 |
-
if limit_rows > 0 && rows.len() >= limit_rows {
|
| 302 |
-
break;
|
| 303 |
-
}
|
| 304 |
let raw_line = line.with_context(|| format!("failed reading line {}", idx + 1))?;
|
| 305 |
if raw_line.trim().is_empty() {
|
| 306 |
continue;
|
|
|
|
| 1 |
+
mod regex_benchmark;
|
| 2 |
+
|
| 3 |
use anyhow::{bail, Context, Result};
|
| 4 |
use clap::Parser;
|
| 5 |
use fancy_regex::Regex as FancyRegex;
|
|
|
|
| 80 |
)]
|
| 81 |
struct Args {
|
| 82 |
#[arg(long)]
|
| 83 |
+
input: Vec<PathBuf>,
|
| 84 |
+
|
| 85 |
+
#[arg(long, value_name = "N")]
|
| 86 |
+
input_repeat: Vec<usize>,
|
| 87 |
|
| 88 |
#[arg(long)]
|
| 89 |
+
vocab_file: Option<PathBuf>,
|
| 90 |
|
| 91 |
#[arg(long)]
|
| 92 |
+
output_dir: Option<PathBuf>,
|
| 93 |
|
| 94 |
#[arg(long, default_value = "label_schema.json")]
|
| 95 |
label_schema_file: PathBuf,
|
|
|
|
| 114 |
|
| 115 |
#[arg(long, default_value_t = 0)]
|
| 116 |
threads: usize,
|
| 117 |
+
|
| 118 |
+
#[arg(long)]
|
| 119 |
+
regex_benchmark_input: Option<PathBuf>,
|
| 120 |
+
|
| 121 |
+
#[arg(long, default_value_t = 0)]
|
| 122 |
+
regex_benchmark_limit_rows: usize,
|
| 123 |
+
|
| 124 |
+
#[arg(long, default_value_t = 3)]
|
| 125 |
+
regex_benchmark_repeat: usize,
|
| 126 |
}
|
| 127 |
|
| 128 |
#[derive(Debug, Deserialize)]
|
|
|
|
| 174 |
|
| 175 |
fn main() -> Result<()> {
|
| 176 |
let args = Args::parse();
|
| 177 |
+
if let Some(input) = &args.regex_benchmark_input {
|
| 178 |
+
return regex_benchmark::run(
|
| 179 |
+
input,
|
| 180 |
+
args.regex_benchmark_limit_rows,
|
| 181 |
+
args.regex_benchmark_repeat,
|
| 182 |
+
);
|
| 183 |
+
}
|
| 184 |
+
if args.input.is_empty() {
|
| 185 |
+
bail!("at least one --input is required");
|
| 186 |
+
}
|
| 187 |
+
let vocab_file = args
|
| 188 |
+
.vocab_file
|
| 189 |
+
.as_ref()
|
| 190 |
+
.context("--vocab-file is required when building an encoded cache")?;
|
| 191 |
+
let output_dir = args
|
| 192 |
+
.output_dir
|
| 193 |
+
.as_ref()
|
| 194 |
+
.context("--output-dir is required when building an encoded cache")?;
|
| 195 |
if args.max_length < 4 {
|
| 196 |
bail!("--max-length must be at least 4");
|
| 197 |
}
|
|
|
|
| 209 |
}
|
| 210 |
|
| 211 |
let started = Instant::now();
|
| 212 |
+
let vocab = load_vocab(vocab_file)?;
|
| 213 |
let label_ids = load_label_ids(&args.label_schema_file)?;
|
| 214 |
+
let input_repeats = resolve_input_repeats(&args.input, &args.input_repeat)?;
|
| 215 |
+
let (mut rows, input_summaries) = load_input_rows(&args.input, &input_repeats, args.limit_rows)?;
|
| 216 |
if rows.len() < 2 {
|
| 217 |
bail!("need at least two rows to build train/eval cache");
|
| 218 |
}
|
|
|
|
| 225 |
let split_idx = split_idx.max(1).min(rows.len() - 1);
|
| 226 |
let (train_rows, eval_rows) = rows.split_at(split_idx);
|
| 227 |
|
| 228 |
+
fs::create_dir_all(output_dir).with_context(|| {
|
| 229 |
format!(
|
| 230 |
"failed to create output directory {}",
|
| 231 |
+
output_dir.display()
|
| 232 |
)
|
| 233 |
})?;
|
| 234 |
|
|
|
|
| 240 |
let train_summary = write_split(
|
| 241 |
"train",
|
| 242 |
train_rows,
|
| 243 |
+
output_dir,
|
| 244 |
&context,
|
| 245 |
args.shard_size,
|
| 246 |
)?;
|
| 247 |
let eval_summary = write_split(
|
| 248 |
"eval",
|
| 249 |
eval_rows,
|
| 250 |
+
output_dir,
|
| 251 |
&context,
|
| 252 |
args.shard_size,
|
| 253 |
)?;
|
| 254 |
+
write_eval_records(eval_rows, &output_dir.join("eval_records.jsonl"))?;
|
| 255 |
|
| 256 |
let manifest = json!({
|
| 257 |
"format": "anifilebert.encoded_dataset_cache.v1",
|
| 258 |
+
"input": args.input.first(),
|
| 259 |
+
"inputs": input_summaries,
|
| 260 |
+
"vocab_file": vocab_file,
|
| 261 |
"label_schema_file": args.label_schema_file,
|
| 262 |
+
"output_dir": output_dir,
|
| 263 |
"max_length": args.max_length,
|
| 264 |
"shard_size": args.shard_size,
|
| 265 |
"limit_rows": args.limit_rows,
|
|
|
|
| 272 |
"eval_records": "eval_records.jsonl",
|
| 273 |
"elapsed_seconds": started.elapsed().as_secs_f64(),
|
| 274 |
});
|
| 275 |
+
let manifest_path = output_dir.join("manifest.json");
|
| 276 |
fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
|
| 277 |
.with_context(|| format!("failed to write {}", manifest_path.display()))?;
|
| 278 |
println!("{}", serde_json::to_string_pretty(&manifest)?);
|
|
|
|
| 327 |
.collect())
|
| 328 |
}
|
| 329 |
|
| 330 |
+
fn resolve_input_repeats(inputs: &[PathBuf], repeats: &[usize]) -> Result<Vec<usize>> {
|
| 331 |
+
if repeats.is_empty() {
|
| 332 |
+
return Ok(vec![1; inputs.len()]);
|
| 333 |
+
}
|
| 334 |
+
if repeats.len() == 1 {
|
| 335 |
+
return Ok(vec![repeats[0].max(1); inputs.len()]);
|
| 336 |
+
}
|
| 337 |
+
if repeats.len() != inputs.len() {
|
| 338 |
+
bail!(
|
| 339 |
+
"--input-repeat must be omitted, passed once for all inputs, or passed once per --input ({} inputs, {} repeats)",
|
| 340 |
+
inputs.len(),
|
| 341 |
+
repeats.len()
|
| 342 |
+
);
|
| 343 |
+
}
|
| 344 |
+
Ok(repeats.iter().map(|repeat| (*repeat).max(1)).collect())
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
fn load_input_rows(
|
| 348 |
+
inputs: &[PathBuf],
|
| 349 |
+
repeats: &[usize],
|
| 350 |
+
limit_rows: usize,
|
| 351 |
+
) -> Result<(Vec<SourceRow>, Vec<Value>)> {
|
| 352 |
+
let mut combined = Vec::new();
|
| 353 |
+
let mut summaries = Vec::new();
|
| 354 |
+
for (path, repeat) in inputs.iter().zip(repeats.iter()) {
|
| 355 |
+
let rows = load_rows(path)?;
|
| 356 |
+
let samples = rows.len();
|
| 357 |
+
let mut written = 0usize;
|
| 358 |
+
for _ in 0..*repeat {
|
| 359 |
+
for row in &rows {
|
| 360 |
+
if limit_rows > 0 && combined.len() >= limit_rows {
|
| 361 |
+
break;
|
| 362 |
+
}
|
| 363 |
+
let mut row = row.clone();
|
| 364 |
+
row.row_index = combined.len();
|
| 365 |
+
combined.push(row);
|
| 366 |
+
written += 1;
|
| 367 |
+
}
|
| 368 |
+
if limit_rows > 0 && combined.len() >= limit_rows {
|
| 369 |
+
break;
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
summaries.push(json!({
|
| 373 |
+
"path": path,
|
| 374 |
+
"samples": samples,
|
| 375 |
+
"repeat": repeat,
|
| 376 |
+
"effective_samples": samples * repeat,
|
| 377 |
+
"written_rows": written,
|
| 378 |
+
}));
|
| 379 |
+
if limit_rows > 0 && combined.len() >= limit_rows {
|
| 380 |
+
break;
|
| 381 |
+
}
|
| 382 |
+
}
|
| 383 |
+
Ok((combined, summaries))
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
fn load_rows(path: &Path) -> Result<Vec<SourceRow>> {
|
| 387 |
let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
|
| 388 |
let reader = BufReader::new(file);
|
| 389 |
let mut rows = Vec::new();
|
| 390 |
for (idx, line) in reader.lines().enumerate() {
|
|
|
|
|
|
|
|
|
|
| 391 |
let raw_line = line.with_context(|| format!("failed reading line {}", idx + 1))?;
|
| 392 |
if raw_line.trim().is_empty() {
|
| 393 |
continue;
|
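The repeat handling added to `main.rs` in this diff can be easier to reason about in a compact form. The following is a Python port written for illustration only, not code from the repository. It works on rows that are already loaded into lists of dicts, and the name `combine_rows` is an assumption standing in for `load_input_rows`.

```python
def resolve_input_repeats(num_inputs, repeats):
    """Mirror the Rust rule: omitted, one value for all inputs, or one per input."""
    if not repeats:
        return [1] * num_inputs
    if len(repeats) == 1:
        # A single value is applied to every input, floored at 1.
        return [max(1, repeats[0])] * num_inputs
    if len(repeats) != num_inputs:
        raise ValueError(
            f"{num_inputs} inputs but {len(repeats)} repeats"
        )
    return [max(1, r) for r in repeats]


def combine_rows(sources, repeats, limit_rows=0):
    """Concatenate each source `repeat` times, stopping at `limit_rows` if > 0."""
    combined = []
    summaries = []
    for rows, repeat in zip(sources, repeats):
        written = 0
        for _ in range(repeat):
            for row in rows:
                if limit_rows > 0 and len(combined) >= limit_rows:
                    break
                # Re-index so row_index is the position in the combined list.
                combined.append({**row, "row_index": len(combined)})
                written += 1
            if limit_rows > 0 and len(combined) >= limit_rows:
                break
        summaries.append({
            "samples": len(rows),
            "repeat": repeat,
            "effective_samples": len(rows) * repeat,
            "written_rows": written,
        })
        if limit_rows > 0 and len(combined) >= limit_rows:
            break
    return combined, summaries
```

Two details follow from this logic. A single `--input-repeat` value is applied to every input, including the primary one, so upweighting only the synthetic source needs one value per `--input`. When the row limit is reached, `written_rows` can be smaller than `effective_samples`, and inputs after the cutoff get no summary entry.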
tools/encoded_dataset_cache/src/{bin/regex_benchmark.rs → regex_benchmark.rs}
RENAMED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
use anyhow::{ensure, Context, Result};
|
| 2 |
-
use clap::Parser;
|
| 3 |
use fancy_regex::Regex as FancyRegex;
|
| 4 |
use regex::Regex;
|
| 5 |
use serde_json::Value;
|
|
@@ -24,37 +23,21 @@ const CJK_MARKER_PATTERN: &str = r"(?:[一二三四五六七八九十兩两貳
|
|
| 24 |
const SPECIAL_CONTEXT_PREFIX_PATTERN: &str =
|
| 25 |
r"(?i)^(?:[\[\(【《]\s*(?:menu|menus|bdmenu|ncop|nced|op|ed|ova|oad|sp)\s*[\]\)】》]\s*){0,2}";
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
about = "Compare regex vs fancy-regex workload costs for AniFileBERT cache preprocessing"
|
| 30 |
-
)]
|
| 31 |
-
struct Args {
|
| 32 |
-
#[arg(long)]
|
| 33 |
-
input: PathBuf,
|
| 34 |
|
| 35 |
-
|
| 36 |
-
limit_rows: usize,
|
| 37 |
-
|
| 38 |
-
#[arg(long, default_value_t = 3)]
|
| 39 |
-
repeat: usize,
|
| 40 |
-
}
|
| 41 |
-
|
| 42 |
-
fn main() -> Result<()> {
|
| 43 |
-
let args = Args::parse();
|
| 44 |
-
ensure!(args.repeat > 0, "--repeat must be greater than 0");
|
| 45 |
-
|
| 46 |
-
let filenames = load_filenames(&args.input, args.limit_rows)?;
|
| 47 |
if filenames.is_empty() {
|
| 48 |
-
anyhow::bail!("no filenames loaded from {}",
|
| 49 |
}
|
| 50 |
|
| 51 |
let selective = SelectivePatterns::new()?;
|
| 52 |
let fancy_all = FancyAllPatterns::new()?;
|
| 53 |
|
| 54 |
let (selective_seconds, selective_count) =
|
| 55 |
-
time_repeated(
|
| 56 |
let (fancy_seconds, fancy_count) =
|
| 57 |
-
time_repeated(
|
| 58 |
ensure!(
|
| 59 |
selective_count == fancy_count,
|
| 60 |
"selective and fancy-all match counts differ: selective={}, fancy_all={}",
|
|
@@ -71,7 +54,7 @@ fn main() -> Result<()> {
|
|
| 71 |
"{}",
|
| 72 |
serde_json::json!({
|
| 73 |
"rows": filenames.len(),
|
| 74 |
-
"repeat":
|
| 75 |
"selective_seconds": selective_seconds,
|
| 76 |
"fancy_all_seconds": fancy_seconds,
|
| 77 |
"ratio": ratio,
|
|
|
|
| 1 |
use anyhow::{ensure, Context, Result};
|
|
|
|
| 2 |
use fancy_regex::Regex as FancyRegex;
|
| 3 |
use regex::Regex;
|
| 4 |
use serde_json::Value;
|
|
|
|
| 23 |
const SPECIAL_CONTEXT_PREFIX_PATTERN: &str =
|
| 24 |
r"(?i)^(?:[\[\(【《]\s*(?:menu|menus|bdmenu|ncop|nced|op|ed|ova|oad|sp)\s*[\]\)】》]\s*){0,2}";
|
| 25 |
|
| 26 |
+
pub fn run(input: &PathBuf, limit_rows: usize, repeat: usize) -> Result<()> {
|
| 27 |
+
ensure!(repeat > 0, "--regex-benchmark-repeat must be greater than 0");
|
| 28 |
|
| 29 |
+
let filenames = load_filenames(input, limit_rows)?;
|
| 30 |
if filenames.is_empty() {
|
| 31 |
+
anyhow::bail!("no filenames loaded from {}", input.display());
|
| 32 |
}
|
| 33 |
|
| 34 |
let selective = SelectivePatterns::new()?;
|
| 35 |
let fancy_all = FancyAllPatterns::new()?;
|
| 36 |
|
| 37 |
let (selective_seconds, selective_count) =
|
| 38 |
+
time_repeated(repeat, || run_selective(&filenames, &selective))?;
|
| 39 |
let (fancy_seconds, fancy_count) =
|
| 40 |
+
time_repeated(repeat, || run_fancy_all(&filenames, &fancy_all))?;
|
| 41 |
ensure!(
|
| 42 |
selective_count == fancy_count,
|
| 43 |
"selective and fancy-all match counts differ: selective={}, fancy_all={}",
|
|
|
|
| 54 |
"{}",
|
| 55 |
serde_json::json!({
|
| 56 |
"rows": filenames.len(),
|
| 57 |
+
"repeat": repeat,
|
| 58 |
"selective_seconds": selective_seconds,
|
| 59 |
"fancy_all_seconds": fancy_seconds,
|
| 60 |
"ratio": ratio,
|
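The body of `time_repeated` is not shown in this diff, so the following is only a sketch of the general pattern the benchmark appears to follow: run two matching strategies the same number of times, time each, and require identical match counts before comparing timings. The pattern, the digit prefilter, and the choice to report total rather than average seconds are all assumptions made for illustration.

```python
import re
import time

# Illustrative pattern only; the real benchmark uses its own regex set.
EPISODE = re.compile(r"(?i)\bep?\s*\d+")


def count_regex_all(names):
    """Run the regex against every filename."""
    return sum(1 for name in names if EPISODE.search(name))


def count_selective(names):
    """Skip the regex when a cheap necessary condition (a digit) fails."""
    return sum(
        1
        for name in names
        if any(ch.isdigit() for ch in name) and EPISODE.search(name)
    )


def time_repeated(repeat, fn):
    """Call fn `repeat` times; return (total seconds, last result)."""
    if repeat <= 0:
        raise ValueError("repeat must be greater than 0")
    result = None
    started = time.perf_counter()
    for _ in range(repeat):
        result = fn()
    return time.perf_counter() - started, result
```

Checking that both strategies return the same count guards against a prefilter that silently drops true matches, which would make the faster timing meaningless.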
tools/train_schema_v2_synthetic.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
r"""Local schema v2 synthetic-augmentation training runner.
|
| 3 |
+
|
| 4 |
+
This wrapper keeps the training structure reproducible:
|
| 5 |
+
|
| 6 |
+
1. Build a combined Rust encoded cache from the hard-focus JSONL plus synthetic
|
| 7 |
+
augmentation JSONL.
|
| 8 |
+
2. Train with ``anifilebert.train --encoded-cache-dir`` so Python training never
|
| 9 |
+
has to re-split raw mixed JSONL in a non-comparable way.
|
| 10 |
+
|
| 11 |
+
Typical usage from the repo root on the local Windows GPU machine:
|
| 12 |
+
|
| 13 |
+
.\.venv\Scripts\python.exe -m tools.train_schema_v2_synthetic
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import datetime as dt
|
| 20 |
+
import json
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
import shlex
|
| 23 |
+
import shutil
|
| 24 |
+
import subprocess
|
| 25 |
+
import sys
|
| 26 |
+
from typing import Any, Sequence
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def utc_now() -> str:
|
| 30 |
+
return dt.datetime.now(dt.timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def command_text(args: Sequence[Any]) -> str:
|
| 34 |
+
return " ".join(shlex.quote(str(arg)) for arg in args)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def run(args: Sequence[Any], *, dry_run: bool, command_log: list[dict[str, Any]]) -> None:
|
| 38 |
+
entry: dict[str, Any] = {
|
| 39 |
+
"cmd": command_text(args),
|
| 40 |
+
"started_at": utc_now(),
|
| 41 |
+
"dry_run": dry_run,
|
| 42 |
+
}
|
| 43 |
+
command_log.append(entry)
|
| 44 |
+
print(f"\n$ {entry['cmd']}")
|
| 45 |
+
if dry_run:
|
| 46 |
+
entry["returncode"] = 0
|
| 47 |
+
entry["finished_at"] = utc_now()
|
| 48 |
+
return
|
| 49 |
+
proc = subprocess.Popen(
|
| 50 |
+
list(map(str, args)),
|
| 51 |
+
stdout=subprocess.PIPE,
|
| 52 |
+
stderr=subprocess.STDOUT,
|
| 53 |
+
text=True,
|
| 54 |
+
encoding="utf-8",
|
| 55 |
+
errors="replace",
|
| 56 |
+
bufsize=1,
|
| 57 |
+
)
|
| 58 |
+
assert proc.stdout is not None
|
| 59 |
+
for line in proc.stdout:
|
| 60 |
+
print(line, end="")
|
| 61 |
+
proc.wait()
|
| 62 |
+
entry["returncode"] = proc.returncode
|
| 63 |
+
entry["finished_at"] = utc_now()
|
| 64 |
+
if proc.returncode != 0:
|
| 65 |
+
raise RuntimeError(f"Command failed with exit code {proc.returncode}: {entry['cmd']}")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def parse_args() -> argparse.Namespace:
|
| 69 |
+
parser = argparse.ArgumentParser(description="Train schema v2 hard-focus + synthetic augmentation with Rust cache")
|
| 70 |
+
parser.add_argument("--primary-data-file", default="data/schema_v2_hard_focus_char_seed63.jsonl")
|
| 71 |
+
parser.add_argument("--synthetic-data-file", default="data/schema_v2_synthetic_aug.jsonl")
|
| 72 |
+
parser.add_argument("--synthetic-repeat", type=int, default=3)
|
| 73 |
+
parser.add_argument("--vocab-file", default="datasets/AnimeName/vocab.char.json")
|
| 74 |
+
parser.add_argument("--label-schema-file", default="label_schema.json")
|
| 75 |
+
parser.add_argument("--encoded-cache-dir", default="data/encoded_cache/schema_v2_hard_focus_seed63_synth_pathleaf_repeat3")
|
| 76 |
+
parser.add_argument("--save-dir", default="checkpoints/schema-v2-best-hardfocus-synth-pathleaf-cache")
|
| 77 |
+
parser.add_argument("--init-model-dir", default="checkpoints/ablation-schema-v2-hardfocus-cache-repaired-from-baseline-seed62-10epoch-rerun/final")
|
| 78 |
+
parser.add_argument("--case-eval-output", default="reports/schema_v2_best_hardfocus_synth_pathleaf_cache_case_metrics.json")
|
| 79 |
+
parser.add_argument("--experiment-name", default="schema-v2-best-hardfocus-synth-pathleaf-cache")
|
| 80 |
+
parser.add_argument("--max-length", type=int, default=128)
|
| 81 |
+
parser.add_argument("--train-split", type=float, default=0.995)
|
| 82 |
+
parser.add_argument("--seed", type=int, default=63)
|
| 83 |
+
parser.add_argument("--shard-size", type=int, default=25000)
|
| 84 |
+
parser.add_argument("--threads", type=int, default=16)
|
| 85 |
+
parser.add_argument("--epochs", type=float, default=2)
|
| 86 |
+
parser.add_argument("--batch-size", type=int, default=512)
|
| 87 |
+
parser.add_argument("--learning-rate", type=float, default=0.00004)
|
| 88 |
+
parser.add_argument("--warmup-steps", type=int, default=120)
|
| 89 |
+
parser.add_argument("--checkpoint-steps", type=int, default=1000)
|
| 90 |
+
parser.add_argument("--save-total-limit", type=int, default=3)
|
| 91 |
+
parser.add_argument("--parse-eval-limit", type=int, default=2048)
|
| 92 |
+
parser.add_argument("--case-eval-file", default="data/parser_regression_cases.json")
|
| 93 |
+
parser.add_argument("--force-cache", action="store_true", help="Delete and rebuild the encoded cache even if manifest exists")
|
| 94 |
+
parser.add_argument("--skip-cache", action="store_true", help="Reuse the existing encoded cache")
|
| 95 |
+
parser.add_argument("--dry-run", action="store_true")
|
| 96 |
+
return parser.parse_args()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def main() -> None:
|
| 100 |
+
args = parse_args()
|
| 101 |
+
command_log: list[dict[str, Any]] = []
|
| 102 |
+
cache_dir = Path(args.encoded_cache_dir)
|
| 103 |
+
manifest_path = cache_dir / "manifest.json"
|
| 104 |
+
|
| 105 |
+
if args.force_cache and cache_dir.exists():
|
| 106 |
+
print(f"Removing existing cache: {cache_dir}")
|
| 107 |
+
if not args.dry_run:
|
| 108 |
+
shutil.rmtree(cache_dir)
|
| 109 |
+
|
| 110 |
+
if not args.skip_cache and not manifest_path.exists():
|
| 111 |
+
cache_cmd = [
|
| 112 |
+
"cargo", "run", "--release",
|
| 113 |
+
"--manifest-path", "tools/encoded_dataset_cache/Cargo.toml",
|
| 114 |
+
"--",
|
| 115 |
+
"--input", args.primary_data_file,
|
| 116 |
+
"--input", args.synthetic_data_file,
|
| 117 |
+
"--input-repeat", "1",
|
| 118 |
+
"--input-repeat", str(max(1, args.synthetic_repeat)),
|
| 119 |
+
"--vocab-file", args.vocab_file,
|
| 120 |
+
"--label-schema-file", args.label_schema_file,
|
| 121 |
+
"--output-dir", args.encoded_cache_dir,
|
| 122 |
+
"--max-length", str(args.max_length),
|
| 123 |
+
"--train-split", str(args.train_split),
|
| 124 |
+
"--seed", str(args.seed),
|
| 125 |
+
"--shard-size", str(args.shard_size),
|
| 126 |
+
"--threads", str(args.threads),
|
| 127 |
+
]
|
| 128 |
+
run(cache_cmd, dry_run=args.dry_run, command_log=command_log)
|
| 129 |
+
else:
|
| 130 |
+
print(f"Using existing encoded cache: {cache_dir}")
|
| 131 |
+
|
| 132 |
+
train_cmd = [
|
| 133 |
+
sys.executable, "-m", "anifilebert.train",
|
| 134 |
+
"--tokenizer", "char",
|
| 135 |
+
"--data-file", args.primary_data_file,
|
| 136 |
+
"--vocab-file", args.vocab_file,
|
| 137 |
+
"--encoded-cache-dir", args.encoded_cache_dir,
|
| 138 |
+
"--save-dir", args.save_dir,
|
| 139 |
+
"--init-model-dir", args.init_model_dir,
|
| 140 |
+
"--epochs", str(args.epochs),
|
| 141 |
+
"--batch-size", str(args.batch_size),
|
| 142 |
+
"--learning-rate", str(args.learning_rate),
|
| 143 |
+
"--warmup-steps", str(args.warmup_steps),
|
| 144 |
+
"--max-seq-length", str(args.max_length),
|
| 145 |
+
"--train-split", str(args.train_split),
|
| 146 |
+
"--num-workers", "0",
|
| 147 |
+
"--checkpoint-steps", str(args.checkpoint_steps),
|
| 148 |
+
"--save-total-limit", str(args.save_total_limit),
|
| 149 |
+
"--no-periodic-eval",
|
| 150 |
+
"--bf16",
|
| 151 |
+
"--auto-find-batch-size",
|
| 152 |
+
"--parse-eval-limit", str(args.parse_eval_limit),
|
| 153 |
+
"--case-eval-file", args.case_eval_file,
|
| 154 |
+
"--case-eval-output", args.case_eval_output,
|
| 155 |
+
"--seed", str(args.seed),
|
| 156 |
+
"--experiment-name", args.experiment_name,
|
| 157 |
+
]
|
| 158 |
+
run(train_cmd, dry_run=args.dry_run, command_log=command_log)
|
| 159 |
+
|
| 160 |
+
run_manifest = {
|
| 161 |
+
"name": args.experiment_name,
|
| 162 |
+
"started_at": command_log[0]["started_at"] if command_log else utc_now(),
|
| 163 |
+
"finished_at": utc_now(),
|
| 164 |
+
"primary_data_file": args.primary_data_file,
|
| 165 |
+
"synthetic_data_file": args.synthetic_data_file,
|
| 166 |
+
"synthetic_repeat": args.synthetic_repeat,
|
| 167 |
+
"encoded_cache_dir": args.encoded_cache_dir,
|
| 168 |
+
"save_dir": args.save_dir,
|
| 169 |
+
"init_model_dir": args.init_model_dir,
|
| 170 |
+
"commands": command_log,
|
| 171 |
+
}
|
| 172 |
+
manifest_output = Path(args.save_dir) / "schema_v2_synthetic_train_manifest.json"
|
| 173 |
+
print(f"Writing run manifest: {manifest_output}")
|
| 174 |
+
if not args.dry_run:
|
| 175 |
+
manifest_output.parent.mkdir(parents=True, exist_ok=True)
|
| 176 |
+
manifest_output.write_text(json.dumps(run_manifest, ensure_ascii=False, indent=2), encoding="utf-8")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
if __name__ == "__main__":
|
| 180 |
+
main()
|
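The cache decision at the top of `main()` combines three flags with the state of the cache directory, and some combinations are easy to misread. The sketch below restates that branch logic as a pure function so the cases can be enumerated. `plan_cache_steps` is a hypothetical name introduced here, and the step labels are illustrative.

```python
def plan_cache_steps(force_cache, skip_cache, cache_exists, manifest_exists,
                     dry_run=False):
    """Return the ordered cache actions the wrapper's branches would take."""
    steps = []
    if force_cache and cache_exists:
        steps.append("remove")
        if not dry_run:
            # The directory is really deleted, so the manifest is gone too.
            cache_exists = manifest_exists = False
    if not skip_cache and not manifest_exists:
        steps.append("build")
    else:
        steps.append("reuse")
    return steps


if __name__ == "__main__":
    print(plan_cache_steps(False, False, False, False))  # ['build']
    print(plan_cache_steps(False, False, True, True))    # ['reuse']
    print(plan_cache_steps(True, False, True, True))     # ['remove', 'build']
    print(plan_cache_steps(True, True, True, True))      # ['remove', 'reuse']
```

Two cases stand out. Passing `--force-cache` together with `--skip-cache` removes the cache and then reports reusing it, so training would presumably find nothing to load. With `--dry-run`, `--force-cache` skips the deletion, the manifest still exists, and the cargo command is therefore not printed.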