ModerRAS commited on
Commit
c705a32
·
1 Parent(s): ed49faa

Add Rust encoded dataset cache

Browse files
.gitignore CHANGED
@@ -16,6 +16,7 @@ data/**/*.jsonl
16
  !data/test_smoke.jsonl
17
  data/**/*.db
18
  data/**/*.sqlite
 
19
  data/generated/
20
  reports/
21
  reports/generated/
 
16
  !data/test_smoke.jsonl
17
  data/**/*.db
18
  data/**/*.sqlite
19
+ data/encoded_cache/
20
  data/generated/
21
  reports/
22
  reports/generated/
AGENTS.md CHANGED
@@ -60,6 +60,38 @@ Train the current default character tokenizer:
60
  uv run python -m anifilebert.train --tokenizer char --data-file datasets/AnimeName/dmhy_weak_char.jsonl --vocab-file datasets/AnimeName/vocab.char.json --save-dir checkpoints/dmhy-char-full --init-model-dir . --epochs 2 --batch-size 256 --learning-rate 0.00008 --warmup-steps 300 --max-seq-length 128 --train-split 0.98 --num-workers 4 --checkpoint-steps 1000 --save-total-limit 3 --parse-eval-limit 2048 --case-eval-file data/parser_regression_cases.json --seed 52 --experiment-name dmhy-char-full
61
  ```
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  Export for Android:
64
 
65
  ```bash
@@ -134,6 +166,9 @@ land under `MyDrive/AniFileBERT/worker/jobs/<job-id>/`.
134
  before publishing parser changes.
135
  - For dataset alignment, tokenizer, model, or training-loop changes, run
136
  `python -m tools.test_train_small --limit-samples 5000 --epochs 2` when practical.
 
 
 
137
  - For export changes, run `python -m tools.export_onnx ...` and confirm the exporter
138
  reports a small PyTorch/ONNX logits difference.
139
  - For performance-sensitive inference changes, run `uv run python -m tools.benchmark_inference ...`
@@ -147,6 +182,8 @@ land under `MyDrive/AniFileBERT/worker/jobs/<job-id>/`.
147
  `test_checkpoints*/`, and `ab_checkpoints*/`.
148
  - Most `data/**/*.jsonl` files are generated and ignored. The small checked-in
149
  fixtures are `data/synthetic_small.jsonl` and `data/test_smoke.jsonl`.
 
 
150
  - For real training, choose exactly one current dataset:
151
  `datasets/AnimeName/dmhy_weak.jsonl` for regex tokenization or
152
  `datasets/AnimeName/dmhy_weak_char.jsonl` for character tokenization.
 
60
  uv run python -m anifilebert.train --tokenizer char --data-file datasets/AnimeName/dmhy_weak_char.jsonl --vocab-file datasets/AnimeName/vocab.char.json --save-dir checkpoints/dmhy-char-full --init-model-dir . --epochs 2 --batch-size 256 --learning-rate 0.00008 --warmup-steps 300 --max-seq-length 128 --train-split 0.98 --num-workers 4 --checkpoint-steps 1000 --save-total-limit 3 --parse-eval-limit 2048 --case-eval-file data/parser_regression_cases.json --seed 52 --experiment-name dmhy-char-full
61
  ```
62
 
63
+ For large generated or hard-focus JSONL files, pre-encode train/eval shards
64
+ with Rust before training to avoid the slow Python startup encode path:
65
+
66
+ ```powershell
67
+ cargo run --release --manifest-path tools\encoded_dataset_cache\Cargo.toml -- `
68
+ --input data\schema_v2_hard_focus_char_seed63.jsonl `
69
+ --vocab-file datasets\AnimeName\vocab.char.json `
70
+ --label-schema-file label_schema.json `
71
+ --output-dir data\encoded_cache\schema_v2_hard_focus_char_seed63 `
72
+ --max-length 128 `
73
+ --train-split 0.95 `
74
+ --seed 63 `
75
+ --shard-size 25000 `
76
+ --threads 16
77
+ ```
78
+
79
+ Then pass the generated cache to training with the same data/vocab/max-length,
80
+ split, and seed:
81
+
82
+ ```powershell
83
+ .\.venv\Scripts\python.exe -m anifilebert.train --tokenizer char `
84
+ --data-file data\schema_v2_hard_focus_char_seed63.jsonl `
85
+ --vocab-file datasets\AnimeName\vocab.char.json `
86
+ --encoded-cache-dir data\encoded_cache\schema_v2_hard_focus_char_seed63 `
87
+ --max-seq-length 128 --train-split 0.95 --seed 63
88
+ ```
89
+
90
+ Do not combine `--encoded-cache-dir` with `--extra-data-file`,
91
+ `--limit-samples`, `--rebuild-vocab`, training-time augmentation, or
92
+ `--apply-label-repairs`. Regenerate the cache after changing the JSONL, vocab,
93
+ label schema, max length, split ratio, or seed.
94
+
95
  Export for Android:
96
 
97
  ```bash
 
166
  before publishing parser changes.
167
  - For dataset alignment, tokenizer, model, or training-loop changes, run
168
  `python -m tools.test_train_small --limit-samples 5000 --epochs 2` when practical.
169
+ - For Rust encoded-cache changes, run `cargo check --manifest-path tools\encoded_dataset_cache\Cargo.toml`,
170
+ generate a small cache with `--limit-rows`, and verify `python -m anifilebert.train`
171
+ can start with `--encoded-cache-dir`.
172
  - For export changes, run `python -m tools.export_onnx ...` and confirm the exporter
173
  reports a small PyTorch/ONNX logits difference.
174
  - For performance-sensitive inference changes, run `uv run python -m tools.benchmark_inference ...`
 
182
  `test_checkpoints*/`, and `ab_checkpoints*/`.
183
  - Most `data/**/*.jsonl` files are generated and ignored. The small checked-in
184
  fixtures are `data/synthetic_small.jsonl` and `data/test_smoke.jsonl`.
185
+ - Rust encoded dataset caches under `data/encoded_cache/` are generated
186
+ artifacts and should not be committed.
187
  - For real training, choose exactly one current dataset:
188
  `datasets/AnimeName/dmhy_weak.jsonl` for regex tokenization or
189
  `datasets/AnimeName/dmhy_weak_char.jsonl` for character tokenization.
anifilebert/train.py CHANGED
@@ -93,6 +93,8 @@ def parse_args() -> argparse.Namespace:
93
  help="Repeat each extra dataset this many times after loading")
94
  parser.add_argument("--virtual-dataset-dir", default=None,
95
  help="Pre-encoded shard directory generated by tools/virtual_dataset_generator")
 
 
96
  parser.add_argument("--vocab-file", default=None,
97
  help="Tokenizer vocab JSON. Defaults to data/vocab.json or data/vocab.char.json")
98
  parser.add_argument("--save-dir", default=None, help="Checkpoint output directory")
@@ -275,6 +277,31 @@ def latest_checkpoint(save_dir: str) -> Optional[str]:
275
  return max(checkpoints)[1]
276
 
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  def validate_dataset_tokenizer_metadata(data: List[Dict], tokenizer_variant: str) -> None:
279
  variants = {item.get("tokenizer_variant") for item in data if item.get("tokenizer_variant")}
280
  if variants and variants != {tokenizer_variant}:
@@ -1285,12 +1312,6 @@ def main():
1285
 
1286
  print("Loading dataset...")
1287
  load_started_at = time.perf_counter()
1288
- all_data, data_sources = load_training_sources(
1289
- primary_data_file=config.data_file,
1290
- extra_data_files=list(args.extra_data_file or []),
1291
- extra_repeat=args.extra_data_repeat,
1292
- limit=args.limit_samples,
1293
- )
1294
  augmentation_metadata = {
1295
  "partial_requested": 0,
1296
  "partial_written": 0,
@@ -1300,23 +1321,60 @@ def main():
1300
  "special_written": 0,
1301
  "max_chars": args.augment_max_chars,
1302
  }
1303
- if args.augment_partial_samples or args.augment_permutation_samples or args.augment_special_samples:
1304
- if tokenizer_variant != "char":
1305
- raise ValueError("Training-time BIO span augmentation currently requires --tokenizer char.")
1306
- all_data, augmentation_metadata = augment_training_data(
1307
- data=all_data,
1308
- partial_count=args.augment_partial_samples,
1309
- permutation_count=args.augment_permutation_samples,
1310
- special_count=args.augment_special_samples,
1311
- max_chars=args.augment_max_chars,
1312
- seed=args.seed + 1009,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1313
  )
 
 
 
 
 
 
 
 
 
 
 
1314
  load_finished_at = time.perf_counter()
1315
- if len(all_data) < 2:
1316
- raise ValueError("Need at least two samples so train/eval split is non-empty.")
1317
- if not args.no_shuffle:
1318
- random.shuffle(all_data)
1319
- validate_dataset_tokenizer_metadata(all_data, tokenizer_variant)
 
 
 
 
 
1320
 
1321
  # Load tokenizer
1322
  print("Loading tokenizer...")
@@ -1396,13 +1454,36 @@ def main():
1396
  print("WARNING: Model exceeds the historical 5M target; continuing because vocab size is configurable.")
1397
 
1398
  use_cpu = args.cpu or not torch.cuda.is_available()
1399
- split_idx = int(len(all_data) * config.train_split)
1400
- split_idx = max(1, min(len(all_data) - 1, split_idx))
1401
- train_data = all_data[:split_idx]
1402
- eval_data = all_data[split_idx:]
 
1403
 
1404
  encode_started_at = time.perf_counter()
1405
- if args.virtual_dataset_dir:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1406
  virtual_dataset = ShardedEncodedDataset(args.virtual_dataset_dir)
1407
  if virtual_dataset.max_length != config.max_seq_length:
1408
  raise ValueError(
@@ -1584,6 +1665,7 @@ def main():
1584
  "data_sources": data_sources,
1585
  "augmentation": augmentation_metadata,
1586
  "dataset_mode": dataset_mode,
 
1587
  "virtual_dataset_dir": args.virtual_dataset_dir,
1588
  "apply_label_repairs": args.apply_label_repairs,
1589
  "keep_raw_dataset": args.keep_raw_dataset,
 
93
  help="Repeat each extra dataset this many times after loading")
94
  parser.add_argument("--virtual-dataset-dir", default=None,
95
  help="Pre-encoded shard directory generated by tools/virtual_dataset_generator")
96
+ parser.add_argument("--encoded-cache-dir", default=None,
97
+ help="Split train/eval encoded shard cache generated by tools/encoded_dataset_cache")
98
  parser.add_argument("--vocab-file", default=None,
99
  help="Tokenizer vocab JSON. Defaults to data/vocab.json or data/vocab.char.json")
100
  parser.add_argument("--save-dir", default=None, help="Checkpoint output directory")
 
277
  return max(checkpoints)[1]
278
 
279
 
280
+ def load_encoded_cache_manifest(cache_dir: str) -> Dict:
281
+ manifest_path = os.path.join(cache_dir, "manifest.json")
282
+ if not os.path.isfile(manifest_path):
283
+ raise FileNotFoundError(f"Encoded cache manifest not found: {manifest_path}")
284
+ with open(manifest_path, "r", encoding="utf-8") as f:
285
+ manifest = json.load(f)
286
+ if manifest.get("format") != "anifilebert.encoded_dataset_cache.v1":
287
+ raise ValueError(f"Unsupported encoded cache manifest: {manifest_path}")
288
+ return manifest
289
+
290
+
291
+ def encoded_cache_split_dir(cache_dir: str, manifest: Dict, split: str) -> str:
292
+ split_meta = manifest.get(split) or {}
293
+ relative_dir = split_meta.get("directory") or split
294
+ return os.path.join(cache_dir, relative_dir)
295
+
296
+
297
+ def load_encoded_cache_eval_data(cache_dir: str, manifest: Dict) -> List[Dict]:
298
+ relative_path = manifest.get("eval_records") or "eval_records.jsonl"
299
+ eval_path = os.path.join(cache_dir, relative_path)
300
+ if not os.path.isfile(eval_path):
301
+ raise FileNotFoundError(f"Encoded cache eval records not found: {eval_path}")
302
+ return load_jsonl(eval_path)
303
+
304
+
305
  def validate_dataset_tokenizer_metadata(data: List[Dict], tokenizer_variant: str) -> None:
306
  variants = {item.get("tokenizer_variant") for item in data if item.get("tokenizer_variant")}
307
  if variants and variants != {tokenizer_variant}:
 
1312
 
1313
  print("Loading dataset...")
1314
  load_started_at = time.perf_counter()
 
 
 
 
 
 
1315
  augmentation_metadata = {
1316
  "partial_requested": 0,
1317
  "partial_written": 0,
 
1321
  "special_written": 0,
1322
  "max_chars": args.augment_max_chars,
1323
  }
1324
+ encoded_cache_manifest = None
1325
+ if args.encoded_cache_dir:
1326
+ if args.extra_data_file:
1327
+ raise ValueError("--encoded-cache-dir cannot be combined with --extra-data-file.")
1328
+ if args.limit_samples is not None:
1329
+ raise ValueError("--encoded-cache-dir cannot be combined with --limit-samples.")
1330
+ if args.rebuild_vocab:
1331
+ raise ValueError("--encoded-cache-dir requires an existing vocab; do not pass --rebuild-vocab.")
1332
+ if args.augment_partial_samples or args.augment_permutation_samples or args.augment_special_samples:
1333
+ raise ValueError("--encoded-cache-dir cannot be combined with training-time augmentation.")
1334
+ if args.apply_label_repairs:
1335
+ raise ValueError("--encoded-cache-dir expects labels already repaired; do not pass --apply-label-repairs.")
1336
+ encoded_cache_manifest = load_encoded_cache_manifest(args.encoded_cache_dir)
1337
+ eval_data = load_encoded_cache_eval_data(args.encoded_cache_dir, encoded_cache_manifest)
1338
+ train_data: List[Dict] = []
1339
+ all_data: List[Dict] = []
1340
+ data_sources = [
1341
+ {
1342
+ "role": "encoded_cache",
1343
+ "path": args.encoded_cache_dir,
1344
+ "samples": int(encoded_cache_manifest.get("source_rows", 0)),
1345
+ "repeat": 1,
1346
+ "effective_samples": int(encoded_cache_manifest.get("source_rows", 0)),
1347
+ }
1348
+ ]
1349
+ else:
1350
+ all_data, data_sources = load_training_sources(
1351
+ primary_data_file=config.data_file,
1352
+ extra_data_files=list(args.extra_data_file or []),
1353
+ extra_repeat=args.extra_data_repeat,
1354
+ limit=args.limit_samples,
1355
  )
1356
+ if args.augment_partial_samples or args.augment_permutation_samples or args.augment_special_samples:
1357
+ if tokenizer_variant != "char":
1358
+ raise ValueError("Training-time BIO span augmentation currently requires --tokenizer char.")
1359
+ all_data, augmentation_metadata = augment_training_data(
1360
+ data=all_data,
1361
+ partial_count=args.augment_partial_samples,
1362
+ permutation_count=args.augment_permutation_samples,
1363
+ special_count=args.augment_special_samples,
1364
+ max_chars=args.augment_max_chars,
1365
+ seed=args.seed + 1009,
1366
+ )
1367
  load_finished_at = time.perf_counter()
1368
+ if args.encoded_cache_dir:
1369
+ if not eval_data:
1370
+ raise ValueError("Encoded cache eval_records.jsonl is empty.")
1371
+ validate_dataset_tokenizer_metadata(eval_data, tokenizer_variant)
1372
+ else:
1373
+ if len(all_data) < 2:
1374
+ raise ValueError("Need at least two samples so train/eval split is non-empty.")
1375
+ if not args.no_shuffle:
1376
+ random.shuffle(all_data)
1377
+ validate_dataset_tokenizer_metadata(all_data, tokenizer_variant)
1378
 
1379
  # Load tokenizer
1380
  print("Loading tokenizer...")
 
1454
  print("WARNING: Model exceeds the historical 5M target; continuing because vocab size is configurable.")
1455
 
1456
  use_cpu = args.cpu or not torch.cuda.is_available()
1457
+ if not args.encoded_cache_dir:
1458
+ split_idx = int(len(all_data) * config.train_split)
1459
+ split_idx = max(1, min(len(all_data) - 1, split_idx))
1460
+ train_data = all_data[:split_idx]
1461
+ eval_data = all_data[split_idx:]
1462
 
1463
  encode_started_at = time.perf_counter()
1464
+ if args.encoded_cache_dir:
1465
+ assert encoded_cache_manifest is not None
1466
+ train_cache_dir = encoded_cache_split_dir(args.encoded_cache_dir, encoded_cache_manifest, "train")
1467
+ eval_cache_dir = encoded_cache_split_dir(args.encoded_cache_dir, encoded_cache_manifest, "eval")
1468
+ train_dataset = ShardedEncodedDataset(train_cache_dir)
1469
+ eval_dataset = ShardedEncodedDataset(eval_cache_dir)
1470
+ for split_name, dataset in (("train", train_dataset), ("eval", eval_dataset)):
1471
+ if dataset.max_length != config.max_seq_length:
1472
+ raise ValueError(
1473
+ f"Encoded cache {split_name} max_length {dataset.max_length} does not match "
1474
+ f"configured max_seq_length {config.max_seq_length}"
1475
+ )
1476
+ if len(eval_dataset) != len(eval_data):
1477
+ raise ValueError(
1478
+ f"Encoded cache eval rows ({len(eval_dataset)}) do not match eval_records.jsonl "
1479
+ f"({len(eval_data)}). Regenerate the cache."
1480
+ )
1481
+ dataset_mode = "encoded-cache-sharded"
1482
+ if not args.keep_raw_dataset:
1483
+ all_data = []
1484
+ train_data = []
1485
+ gc.collect()
1486
+ elif args.virtual_dataset_dir:
1487
  virtual_dataset = ShardedEncodedDataset(args.virtual_dataset_dir)
1488
  if virtual_dataset.max_length != config.max_seq_length:
1489
  raise ValueError(
 
1665
  "data_sources": data_sources,
1666
  "augmentation": augmentation_metadata,
1667
  "dataset_mode": dataset_mode,
1668
+ "encoded_cache_dir": args.encoded_cache_dir,
1669
  "virtual_dataset_dir": args.virtual_dataset_dir,
1670
  "apply_label_repairs": args.apply_label_repairs,
1671
  "keep_raw_dataset": args.keep_raw_dataset,
tools/encoded_dataset_cache/Cargo.lock ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is automatically @generated by Cargo.
2
+ # It is not intended for manual editing.
3
+ version = 4
4
+
5
+ [[package]]
6
+ name = "aho-corasick"
7
+ version = "1.1.4"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
10
+ dependencies = [
11
+ "memchr",
12
+ ]
13
+
14
+ [[package]]
15
+ name = "anifilebert-encoded-dataset-cache"
16
+ version = "0.1.0"
17
+ dependencies = [
18
+ "anyhow",
19
+ "clap",
20
+ "rand",
21
+ "rayon",
22
+ "regex",
23
+ "serde",
24
+ "serde_json",
25
+ ]
26
+
27
+ [[package]]
28
+ name = "anstream"
29
+ version = "1.0.0"
30
+ source = "registry+https://github.com/rust-lang/crates.io-index"
31
+ checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d"
32
+ dependencies = [
33
+ "anstyle",
34
+ "anstyle-parse",
35
+ "anstyle-query",
36
+ "anstyle-wincon",
37
+ "colorchoice",
38
+ "is_terminal_polyfill",
39
+ "utf8parse",
40
+ ]
41
+
42
+ [[package]]
43
+ name = "anstyle"
44
+ version = "1.0.14"
45
+ source = "registry+https://github.com/rust-lang/crates.io-index"
46
+ checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000"
47
+
48
+ [[package]]
49
+ name = "anstyle-parse"
50
+ version = "1.0.0"
51
+ source = "registry+https://github.com/rust-lang/crates.io-index"
52
+ checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e"
53
+ dependencies = [
54
+ "utf8parse",
55
+ ]
56
+
57
+ [[package]]
58
+ name = "anstyle-query"
59
+ version = "1.1.5"
60
+ source = "registry+https://github.com/rust-lang/crates.io-index"
61
+ checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
62
+ dependencies = [
63
+ "windows-sys",
64
+ ]
65
+
66
+ [[package]]
67
+ name = "anstyle-wincon"
68
+ version = "3.0.11"
69
+ source = "registry+https://github.com/rust-lang/crates.io-index"
70
+ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
71
+ dependencies = [
72
+ "anstyle",
73
+ "once_cell_polyfill",
74
+ "windows-sys",
75
+ ]
76
+
77
+ [[package]]
78
+ name = "anyhow"
79
+ version = "1.0.102"
80
+ source = "registry+https://github.com/rust-lang/crates.io-index"
81
+ checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
82
+
83
+ [[package]]
84
+ name = "cfg-if"
85
+ version = "1.0.4"
86
+ source = "registry+https://github.com/rust-lang/crates.io-index"
87
+ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
88
+
89
+ [[package]]
90
+ name = "clap"
91
+ version = "4.6.1"
92
+ source = "registry+https://github.com/rust-lang/crates.io-index"
93
+ checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51"
94
+ dependencies = [
95
+ "clap_builder",
96
+ "clap_derive",
97
+ ]
98
+
99
+ [[package]]
100
+ name = "clap_builder"
101
+ version = "4.6.0"
102
+ source = "registry+https://github.com/rust-lang/crates.io-index"
103
+ checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f"
104
+ dependencies = [
105
+ "anstream",
106
+ "anstyle",
107
+ "clap_lex",
108
+ "strsim",
109
+ ]
110
+
111
+ [[package]]
112
+ name = "clap_derive"
113
+ version = "4.6.1"
114
+ source = "registry+https://github.com/rust-lang/crates.io-index"
115
+ checksum = "f2ce8604710f6733aa641a2b3731eaa1e8b3d9973d5e3565da11800813f997a9"
116
+ dependencies = [
117
+ "heck",
118
+ "proc-macro2",
119
+ "quote",
120
+ "syn",
121
+ ]
122
+
123
+ [[package]]
124
+ name = "clap_lex"
125
+ version = "1.1.0"
126
+ source = "registry+https://github.com/rust-lang/crates.io-index"
127
+ checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
128
+
129
+ [[package]]
130
+ name = "colorchoice"
131
+ version = "1.0.5"
132
+ source = "registry+https://github.com/rust-lang/crates.io-index"
133
+ checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
134
+
135
+ [[package]]
136
+ name = "crossbeam-deque"
137
+ version = "0.8.6"
138
+ source = "registry+https://github.com/rust-lang/crates.io-index"
139
+ checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
140
+ dependencies = [
141
+ "crossbeam-epoch",
142
+ "crossbeam-utils",
143
+ ]
144
+
145
+ [[package]]
146
+ name = "crossbeam-epoch"
147
+ version = "0.9.18"
148
+ source = "registry+https://github.com/rust-lang/crates.io-index"
149
+ checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
150
+ dependencies = [
151
+ "crossbeam-utils",
152
+ ]
153
+
154
+ [[package]]
155
+ name = "crossbeam-utils"
156
+ version = "0.8.21"
157
+ source = "registry+https://github.com/rust-lang/crates.io-index"
158
+ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
159
+
160
+ [[package]]
161
+ name = "either"
162
+ version = "1.16.0"
163
+ source = "registry+https://github.com/rust-lang/crates.io-index"
164
+ checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e"
165
+
166
+ [[package]]
167
+ name = "getrandom"
168
+ version = "0.2.17"
169
+ source = "registry+https://github.com/rust-lang/crates.io-index"
170
+ checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0"
171
+ dependencies = [
172
+ "cfg-if",
173
+ "libc",
174
+ "wasi",
175
+ ]
176
+
177
+ [[package]]
178
+ name = "heck"
179
+ version = "0.5.0"
180
+ source = "registry+https://github.com/rust-lang/crates.io-index"
181
+ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
182
+
183
+ [[package]]
184
+ name = "is_terminal_polyfill"
185
+ version = "1.70.2"
186
+ source = "registry+https://github.com/rust-lang/crates.io-index"
187
+ checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
188
+
189
+ [[package]]
190
+ name = "itoa"
191
+ version = "1.0.18"
192
+ source = "registry+https://github.com/rust-lang/crates.io-index"
193
+ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682"
194
+
195
+ [[package]]
196
+ name = "libc"
197
+ version = "0.2.186"
198
+ source = "registry+https://github.com/rust-lang/crates.io-index"
199
+ checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66"
200
+
201
+ [[package]]
202
+ name = "memchr"
203
+ version = "2.8.1"
204
+ source = "registry+https://github.com/rust-lang/crates.io-index"
205
+ checksum = "6b947ae49db0d222b1dbc6b113ce7248a3fc3a6ca21b696717bfc000ba4484d8"
206
+
207
+ [[package]]
208
+ name = "once_cell_polyfill"
209
+ version = "1.70.2"
210
+ source = "registry+https://github.com/rust-lang/crates.io-index"
211
+ checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
212
+
213
+ [[package]]
214
+ name = "ppv-lite86"
215
+ version = "0.2.21"
216
+ source = "registry+https://github.com/rust-lang/crates.io-index"
217
+ checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
218
+ dependencies = [
219
+ "zerocopy",
220
+ ]
221
+
222
+ [[package]]
223
+ name = "proc-macro2"
224
+ version = "1.0.106"
225
+ source = "registry+https://github.com/rust-lang/crates.io-index"
226
+ checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
227
+ dependencies = [
228
+ "unicode-ident",
229
+ ]
230
+
231
+ [[package]]
232
+ name = "quote"
233
+ version = "1.0.45"
234
+ source = "registry+https://github.com/rust-lang/crates.io-index"
235
+ checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
236
+ dependencies = [
237
+ "proc-macro2",
238
+ ]
239
+
240
+ [[package]]
241
+ name = "rand"
242
+ version = "0.8.6"
243
+ source = "registry+https://github.com/rust-lang/crates.io-index"
244
+ checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a"
245
+ dependencies = [
246
+ "libc",
247
+ "rand_chacha",
248
+ "rand_core",
249
+ ]
250
+
251
+ [[package]]
252
+ name = "rand_chacha"
253
+ version = "0.3.1"
254
+ source = "registry+https://github.com/rust-lang/crates.io-index"
255
+ checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
256
+ dependencies = [
257
+ "ppv-lite86",
258
+ "rand_core",
259
+ ]
260
+
261
+ [[package]]
262
+ name = "rand_core"
263
+ version = "0.6.4"
264
+ source = "registry+https://github.com/rust-lang/crates.io-index"
265
+ checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
266
+ dependencies = [
267
+ "getrandom",
268
+ ]
269
+
270
+ [[package]]
271
+ name = "rayon"
272
+ version = "1.12.0"
273
+ source = "registry+https://github.com/rust-lang/crates.io-index"
274
+ checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d"
275
+ dependencies = [
276
+ "either",
277
+ "rayon-core",
278
+ ]
279
+
280
+ [[package]]
281
+ name = "rayon-core"
282
+ version = "1.13.0"
283
+ source = "registry+https://github.com/rust-lang/crates.io-index"
284
+ checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
285
+ dependencies = [
286
+ "crossbeam-deque",
287
+ "crossbeam-utils",
288
+ ]
289
+
290
+ [[package]]
291
+ name = "regex"
292
+ version = "1.12.3"
293
+ source = "registry+https://github.com/rust-lang/crates.io-index"
294
+ checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276"
295
+ dependencies = [
296
+ "aho-corasick",
297
+ "memchr",
298
+ "regex-automata",
299
+ "regex-syntax",
300
+ ]
301
+
302
+ [[package]]
303
+ name = "regex-automata"
304
+ version = "0.4.14"
305
+ source = "registry+https://github.com/rust-lang/crates.io-index"
306
+ checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f"
307
+ dependencies = [
308
+ "aho-corasick",
309
+ "memchr",
310
+ "regex-syntax",
311
+ ]
312
+
313
+ [[package]]
314
+ name = "regex-syntax"
315
+ version = "0.8.10"
316
+ source = "registry+https://github.com/rust-lang/crates.io-index"
317
+ checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
318
+
319
+ [[package]]
320
+ name = "serde"
321
+ version = "1.0.228"
322
+ source = "registry+https://github.com/rust-lang/crates.io-index"
323
+ checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
324
+ dependencies = [
325
+ "serde_core",
326
+ "serde_derive",
327
+ ]
328
+
329
+ [[package]]
330
+ name = "serde_core"
331
+ version = "1.0.228"
332
+ source = "registry+https://github.com/rust-lang/crates.io-index"
333
+ checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
334
+ dependencies = [
335
+ "serde_derive",
336
+ ]
337
+
338
+ [[package]]
339
+ name = "serde_derive"
340
+ version = "1.0.228"
341
+ source = "registry+https://github.com/rust-lang/crates.io-index"
342
+ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
343
+ dependencies = [
344
+ "proc-macro2",
345
+ "quote",
346
+ "syn",
347
+ ]
348
+
349
+ [[package]]
350
+ name = "serde_json"
351
+ version = "1.0.150"
352
+ source = "registry+https://github.com/rust-lang/crates.io-index"
353
+ checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9"
354
+ dependencies = [
355
+ "itoa",
356
+ "memchr",
357
+ "serde",
358
+ "serde_core",
359
+ "zmij",
360
+ ]
361
+
362
+ [[package]]
363
+ name = "strsim"
364
+ version = "0.11.1"
365
+ source = "registry+https://github.com/rust-lang/crates.io-index"
366
+ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
367
+
368
+ [[package]]
369
+ name = "syn"
370
+ version = "2.0.117"
371
+ source = "registry+https://github.com/rust-lang/crates.io-index"
372
+ checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
373
+ dependencies = [
374
+ "proc-macro2",
375
+ "quote",
376
+ "unicode-ident",
377
+ ]
378
+
379
+ [[package]]
380
+ name = "unicode-ident"
381
+ version = "1.0.24"
382
+ source = "registry+https://github.com/rust-lang/crates.io-index"
383
+ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
384
+
385
+ [[package]]
386
+ name = "utf8parse"
387
+ version = "0.2.2"
388
+ source = "registry+https://github.com/rust-lang/crates.io-index"
389
+ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
390
+
391
+ [[package]]
392
+ name = "wasi"
393
+ version = "0.11.1+wasi-snapshot-preview1"
394
+ source = "registry+https://github.com/rust-lang/crates.io-index"
395
+ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
396
+
397
+ [[package]]
398
+ name = "windows-link"
399
+ version = "0.2.1"
400
+ source = "registry+https://github.com/rust-lang/crates.io-index"
401
+ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
402
+
403
+ [[package]]
404
+ name = "windows-sys"
405
+ version = "0.61.2"
406
+ source = "registry+https://github.com/rust-lang/crates.io-index"
407
+ checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc"
408
+ dependencies = [
409
+ "windows-link",
410
+ ]
411
+
412
+ [[package]]
413
+ name = "zerocopy"
414
+ version = "0.8.50"
415
+ source = "registry+https://github.com/rust-lang/crates.io-index"
416
+ checksum = "3b065d4f0e55f82fae73202e189638116a87c55ab6b8e6c2721e13dd9d854ad1"
417
+ dependencies = [
418
+ "zerocopy-derive",
419
+ ]
420
+
421
+ [[package]]
422
+ name = "zerocopy-derive"
423
+ version = "0.8.50"
424
+ source = "registry+https://github.com/rust-lang/crates.io-index"
425
+ checksum = "0b631b19d36a892ab55420c92dbc83ccd79274f25be714855d3074aa71cab639"
426
+ dependencies = [
427
+ "proc-macro2",
428
+ "quote",
429
+ "syn",
430
+ ]
431
+
432
+ [[package]]
433
+ name = "zmij"
434
+ version = "1.0.21"
435
+ source = "registry+https://github.com/rust-lang/crates.io-index"
436
+ checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"
tools/encoded_dataset_cache/Cargo.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [package]
2
+ name = "anifilebert-encoded-dataset-cache"
3
+ version = "0.1.0"
4
+ edition = "2021"
5
+
6
+ [dependencies]
7
+ anyhow = "1.0"
8
+ clap = { version = "4.5", features = ["derive"] }
9
+ rand = "0.8"
10
+ rayon = "1.10"
11
+ regex = "1.11"
12
+ serde = { version = "1.0", features = ["derive"] }
13
+ serde_json = "1.0"
tools/encoded_dataset_cache/README.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AniFileBERT encoded dataset cache
2
+
3
+ Builds split train/eval `.npy` shard caches for `anifilebert.train`.
4
+
5
+ The tool mirrors the Python char-tokenizer training encoder for JSONL rows with
6
+ `filename`, `tokens`, and `labels`, including projection from source tokens to
7
+ character labels and the structural media-label repairs used by training.
8
+
9
+ Example:
10
+
11
+ ```powershell
12
+ cargo run --release --manifest-path tools\encoded_dataset_cache\Cargo.toml -- `
13
+ --input data\schema_v2_hard_focus_char_seed63.jsonl `
14
+ --vocab-file datasets\AnimeName\vocab.char.json `
15
+ --label-schema-file label_schema.json `
16
+ --output-dir data\encoded_cache\schema_v2_hard_focus_char_seed63 `
17
+ --max-length 128 `
18
+ --train-split 0.95 `
19
+ --seed 63 `
20
+ --shard-size 25000 `
21
+ --threads 16
22
+ ```
23
+
24
+ Use the cache in training:
25
+
26
+ ```powershell
27
+ .\.venv\Scripts\python.exe -m anifilebert.train `
28
+ --tokenizer char `
29
+ --data-file data\schema_v2_hard_focus_char_seed63.jsonl `
30
+ --vocab-file datasets\AnimeName\vocab.char.json `
31
+ --encoded-cache-dir data\encoded_cache\schema_v2_hard_focus_char_seed63 `
32
+ --max-seq-length 128
33
+ ```
tools/encoded_dataset_cache/src/main.rs ADDED
@@ -0,0 +1,909 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use anyhow::{bail, Context, Result};
2
+ use clap::Parser;
3
+ use rand::rngs::StdRng;
4
+ use rand::seq::SliceRandom;
5
+ use rand::SeedableRng;
6
+ use rayon::prelude::*;
7
+ use regex::Regex;
8
+ use serde::{Deserialize, Serialize};
9
+ use serde_json::{json, Value};
10
+ use std::collections::HashMap;
11
+ use std::fs::{self, File};
12
+ use std::io::{BufRead, BufReader, BufWriter, Write};
13
+ use std::path::{Path, PathBuf};
14
+ use std::sync::OnceLock;
15
+ use std::time::Instant;
16
+
17
+ const FALLBACK_LABELS: [&str; 37] = [
18
+ "O",
19
+ "B-TITLE_CHS",
20
+ "I-TITLE_CHS",
21
+ "B-TITLE_CHT",
22
+ "I-TITLE_CHT",
23
+ "B-TITLE_JPN",
24
+ "I-TITLE_JPN",
25
+ "B-TITLE_LATIN",
26
+ "I-TITLE_LATIN",
27
+ "B-TITLE_MIXED",
28
+ "I-TITLE_MIXED",
29
+ "B-PATH_TITLE_CHS",
30
+ "I-PATH_TITLE_CHS",
31
+ "B-PATH_TITLE_CHT",
32
+ "I-PATH_TITLE_CHT",
33
+ "B-PATH_TITLE_JPN",
34
+ "I-PATH_TITLE_JPN",
35
+ "B-PATH_TITLE_LATIN",
36
+ "I-PATH_TITLE_LATIN",
37
+ "B-PATH_TITLE_MIXED",
38
+ "I-PATH_TITLE_MIXED",
39
+ "B-PATH_SEASON",
40
+ "I-PATH_SEASON",
41
+ "B-SEASON",
42
+ "I-SEASON",
43
+ "B-EPISODE",
44
+ "I-EPISODE",
45
+ "B-SPECIAL",
46
+ "I-SPECIAL",
47
+ "B-GROUP",
48
+ "I-GROUP",
49
+ "B-RESOLUTION",
50
+ "I-RESOLUTION",
51
+ "B-SOURCE",
52
+ "I-SOURCE",
53
+ "B-TAG",
54
+ "I-TAG",
55
+ ];
56
+
57
+ const SOURCE_TOKEN_PATTERN: &str = r"WEB[-_ ]?DL|WEB[-_ ]?Rip|BDRip|BluRay|BDMV|BD|DVDRip|DVD|TVRip|HDTV|Netflix|NF|AMZN|Baha|CR|ABEMA|DSNP|U[-_ ]?NEXT|Hulu|AT[-_ ]?X|x26[45]|h\.?26[45]|HEVC|AVC|AV1|AAC\d*(?:\.\d+)?|AAC|FLAC|MP3|DTS|Opus|CHS|CHT|GB|BIG5|JPN?|JPSC|JPTC|繁中|简中";
58
+
59
+ static RESOLUTION_RE: OnceLock<Regex> = OnceLock::new();
60
+ static SOURCE_RE: OnceLock<Regex> = OnceLock::new();
61
+ static SOURCE_TAG_RE: OnceLock<Regex> = OnceLock::new();
62
+ static SPECIAL_TAG_RE: OnceLock<Regex> = OnceLock::new();
63
+ static SPECIAL_CODE_RE: OnceLock<Regex> = OnceLock::new();
64
+
65
+ #[derive(Parser, Debug)]
66
+ #[command(
67
+ about = "Build split train/eval encoded AniFileBERT shard caches",
68
+ version
69
+ )]
70
+ struct Args {
71
+ #[arg(long)]
72
+ input: PathBuf,
73
+
74
+ #[arg(long)]
75
+ vocab_file: PathBuf,
76
+
77
+ #[arg(long)]
78
+ output_dir: PathBuf,
79
+
80
+ #[arg(long, default_value = "label_schema.json")]
81
+ label_schema_file: PathBuf,
82
+
83
+ #[arg(long, default_value_t = 128)]
84
+ max_length: usize,
85
+
86
+ #[arg(long, default_value_t = 25_000)]
87
+ shard_size: usize,
88
+
89
+ #[arg(long, default_value_t = 0)]
90
+ limit_rows: usize,
91
+
92
+ #[arg(long, default_value_t = 0.98)]
93
+ train_split: f64,
94
+
95
+ #[arg(long, default_value_t = 42)]
96
+ seed: u64,
97
+
98
+ #[arg(long)]
99
+ no_shuffle: bool,
100
+
101
+ #[arg(long, default_value_t = 0)]
102
+ threads: usize,
103
+ }
104
+
105
+ #[derive(Debug, Deserialize)]
106
+ struct LabelSchema {
107
+ labels: Vec<String>,
108
+ }
109
+
110
+ #[derive(Clone)]
111
+ struct SourceRow {
112
+ row_index: usize,
113
+ raw_line: String,
114
+ filename: Option<String>,
115
+ tokens: Vec<String>,
116
+ labels: Vec<String>,
117
+ tokenizer_variant: Option<String>,
118
+ }
119
+
120
+ #[derive(Clone)]
121
+ struct Vocab {
122
+ ids: HashMap<String, u16>,
123
+ pad_id: u16,
124
+ unk_id: u16,
125
+ cls_id: u16,
126
+ sep_id: u16,
127
+ }
128
+
129
+ #[derive(Clone)]
130
+ struct EncodeContext {
131
+ vocab: Vocab,
132
+ label_ids: HashMap<String, i16>,
133
+ max_length: usize,
134
+ }
135
+
136
+ #[derive(Serialize)]
137
+ struct ShardManifest {
138
+ rows: usize,
139
+ input_ids: String,
140
+ attention_mask: String,
141
+ labels: String,
142
+ }
143
+
144
+ #[derive(Serialize)]
145
+ struct SplitSummary {
146
+ split: String,
147
+ rows: usize,
148
+ shards: usize,
149
+ directory: String,
150
+ }
151
+
152
+ fn main() -> Result<()> {
153
+ let args = Args::parse();
154
+ if args.max_length < 4 {
155
+ bail!("--max-length must be at least 4");
156
+ }
157
+ if args.shard_size == 0 {
158
+ bail!("--shard-size must be positive");
159
+ }
160
+ if !(0.0..1.0).contains(&args.train_split) {
161
+ bail!("--train-split must be > 0 and < 1");
162
+ }
163
+ if args.threads > 0 {
164
+ rayon::ThreadPoolBuilder::new()
165
+ .num_threads(args.threads)
166
+ .build_global()
167
+ .context("failed to configure rayon thread pool")?;
168
+ }
169
+
170
+ let started = Instant::now();
171
+ let vocab = load_vocab(&args.vocab_file)?;
172
+ let label_ids = load_label_ids(&args.label_schema_file)?;
173
+ let mut rows = load_rows(&args.input, args.limit_rows)?;
174
+ if rows.len() < 2 {
175
+ bail!("need at least two rows to build train/eval cache");
176
+ }
177
+
178
+ if !args.no_shuffle {
179
+ let mut rng = StdRng::seed_from_u64(args.seed);
180
+ rows.shuffle(&mut rng);
181
+ }
182
+ let split_idx = ((rows.len() as f64) * args.train_split) as usize;
183
+ let split_idx = split_idx.max(1).min(rows.len() - 1);
184
+ let (train_rows, eval_rows) = rows.split_at(split_idx);
185
+
186
+ fs::create_dir_all(&args.output_dir).with_context(|| {
187
+ format!(
188
+ "failed to create output directory {}",
189
+ args.output_dir.display()
190
+ )
191
+ })?;
192
+
193
+ let context = EncodeContext {
194
+ vocab,
195
+ label_ids,
196
+ max_length: args.max_length,
197
+ };
198
+ let train_summary = write_split(
199
+ "train",
200
+ train_rows,
201
+ &args.output_dir,
202
+ &context,
203
+ args.shard_size,
204
+ )?;
205
+ let eval_summary = write_split(
206
+ "eval",
207
+ eval_rows,
208
+ &args.output_dir,
209
+ &context,
210
+ args.shard_size,
211
+ )?;
212
+ write_eval_records(eval_rows, &args.output_dir.join("eval_records.jsonl"))?;
213
+
214
+ let manifest = json!({
215
+ "format": "anifilebert.encoded_dataset_cache.v1",
216
+ "input": args.input,
217
+ "vocab_file": args.vocab_file,
218
+ "label_schema_file": args.label_schema_file,
219
+ "output_dir": args.output_dir,
220
+ "max_length": args.max_length,
221
+ "shard_size": args.shard_size,
222
+ "limit_rows": args.limit_rows,
223
+ "source_rows": train_rows.len() + eval_rows.len(),
224
+ "train_split": args.train_split,
225
+ "seed": args.seed,
226
+ "shuffle": !args.no_shuffle,
227
+ "train": train_summary,
228
+ "eval": eval_summary,
229
+ "eval_records": "eval_records.jsonl",
230
+ "elapsed_seconds": started.elapsed().as_secs_f64(),
231
+ });
232
+ let manifest_path = args.output_dir.join("manifest.json");
233
+ fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
234
+ .with_context(|| format!("failed to write {}", manifest_path.display()))?;
235
+ println!("{}", serde_json::to_string_pretty(&manifest)?);
236
+ Ok(())
237
+ }
238
+
239
+ fn load_vocab(path: &Path) -> Result<Vocab> {
240
+ let text = fs::read_to_string(path)
241
+ .with_context(|| format!("failed to read vocab {}", path.display()))?;
242
+ let raw: HashMap<String, u64> =
243
+ serde_json::from_str(&text).with_context(|| format!("invalid vocab {}", path.display()))?;
244
+ let mut ids = HashMap::with_capacity(raw.len());
245
+ for (token, id) in raw {
246
+ if id > u16::MAX as u64 {
247
+ bail!("vocab id for token '{token}' exceeds u16: {id}");
248
+ }
249
+ ids.insert(token, id as u16);
250
+ }
251
+ let special = |token: &str| -> Result<u16> {
252
+ ids.get(token)
253
+ .copied()
254
+ .with_context(|| format!("vocab is missing special token {token}"))
255
+ };
256
+ Ok(Vocab {
257
+ pad_id: special("[PAD]")?,
258
+ unk_id: special("[UNK]")?,
259
+ cls_id: special("[CLS]")?,
260
+ sep_id: special("[SEP]")?,
261
+ ids,
262
+ })
263
+ }
264
+
265
+ fn load_label_ids(path: &Path) -> Result<HashMap<String, i16>> {
266
+ let labels = match fs::read_to_string(path) {
267
+ Ok(text) => {
268
+ serde_json::from_str::<LabelSchema>(&text)
269
+ .with_context(|| format!("invalid label schema {}", path.display()))?
270
+ .labels
271
+ }
272
+ Err(_) => FALLBACK_LABELS
273
+ .iter()
274
+ .map(|label| (*label).to_string())
275
+ .collect(),
276
+ };
277
+ if labels.is_empty() {
278
+ bail!("label schema has no labels");
279
+ }
280
+ Ok(labels
281
+ .into_iter()
282
+ .enumerate()
283
+ .map(|(idx, label)| (label, idx as i16))
284
+ .collect())
285
+ }
286
+
287
+ fn load_rows(path: &Path, limit_rows: usize) -> Result<Vec<SourceRow>> {
288
+ let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
289
+ let reader = BufReader::new(file);
290
+ let mut rows = Vec::new();
291
+ for (idx, line) in reader.lines().enumerate() {
292
+ if limit_rows > 0 && rows.len() >= limit_rows {
293
+ break;
294
+ }
295
+ let raw_line = line.with_context(|| format!("failed reading line {}", idx + 1))?;
296
+ if raw_line.trim().is_empty() {
297
+ continue;
298
+ }
299
+ let value: Value = serde_json::from_str(&raw_line)
300
+ .with_context(|| format!("failed to parse JSONL line {}", idx + 1))?;
301
+ let tokens = string_array_field(&value, "tokens", idx + 1)?;
302
+ let labels = string_array_field(&value, "labels", idx + 1)?;
303
+ if tokens.len() != labels.len() {
304
+ bail!(
305
+ "line {} has mismatched token/label lengths: {} vs {}",
306
+ idx + 1,
307
+ tokens.len(),
308
+ labels.len()
309
+ );
310
+ }
311
+ rows.push(SourceRow {
312
+ row_index: idx,
313
+ raw_line,
314
+ filename: value
315
+ .get("filename")
316
+ .and_then(Value::as_str)
317
+ .map(ToOwned::to_owned),
318
+ tokens,
319
+ labels,
320
+ tokenizer_variant: value
321
+ .get("tokenizer_variant")
322
+ .and_then(Value::as_str)
323
+ .map(ToOwned::to_owned),
324
+ });
325
+ }
326
+ Ok(rows)
327
+ }
328
+
329
+ fn string_array_field(value: &Value, field: &str, line_no: usize) -> Result<Vec<String>> {
330
+ let array = value
331
+ .get(field)
332
+ .and_then(Value::as_array)
333
+ .with_context(|| format!("line {line_no} missing array field '{field}'"))?;
334
+ array
335
+ .iter()
336
+ .map(|item| match item {
337
+ Value::String(text) => Ok(text.clone()),
338
+ other => Ok(match other {
339
+ Value::Null => String::new(),
340
+ _ => other.to_string(),
341
+ }),
342
+ })
343
+ .collect()
344
+ }
345
+
346
+ fn write_split(
347
+ split: &str,
348
+ rows: &[SourceRow],
349
+ output_dir: &Path,
350
+ context: &EncodeContext,
351
+ shard_size: usize,
352
+ ) -> Result<SplitSummary> {
353
+ let split_dir = output_dir.join(split);
354
+ fs::create_dir_all(&split_dir)
355
+ .with_context(|| format!("failed to create {}", split_dir.display()))?;
356
+ let chunks = rows
357
+ .chunks(shard_size)
358
+ .enumerate()
359
+ .collect::<Vec<(usize, &[SourceRow])>>();
360
+ let shards = chunks
361
+ .par_iter()
362
+ .map(|(shard_idx, chunk)| write_shard(split, *shard_idx, chunk, &split_dir, context))
363
+ .collect::<Result<Vec<_>>>()?;
364
+
365
+ let manifest = json!({
366
+ "format": "anifilebert.virtual_dataset.shards.v1",
367
+ "generated_by": "tools/encoded_dataset_cache",
368
+ "split": split,
369
+ "max_length": context.max_length,
370
+ "total_rows": rows.len(),
371
+ "shards": shards,
372
+ });
373
+ let manifest_path = split_dir.join("manifest.json");
374
+ fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
375
+ .with_context(|| format!("failed to write {}", manifest_path.display()))?;
376
+ Ok(SplitSummary {
377
+ split: split.to_string(),
378
+ rows: rows.len(),
379
+ shards: chunks.len(),
380
+ directory: split.to_string(),
381
+ })
382
+ }
383
+
384
+ fn write_shard(
385
+ split: &str,
386
+ shard_idx: usize,
387
+ rows: &[SourceRow],
388
+ split_dir: &Path,
389
+ context: &EncodeContext,
390
+ ) -> Result<ShardManifest> {
391
+ let capacity = rows.len().saturating_mul(context.max_length);
392
+ let mut input_ids = Vec::with_capacity(capacity);
393
+ let mut attention_mask = Vec::with_capacity(capacity);
394
+ let mut labels = Vec::with_capacity(capacity);
395
+ for row in rows {
396
+ let encoded = encode_row(row, context)
397
+ .with_context(|| format!("failed to encode source line {}", row.row_index + 1))?;
398
+ input_ids.extend_from_slice(&encoded.0);
399
+ attention_mask.extend_from_slice(&encoded.1);
400
+ labels.extend_from_slice(&encoded.2);
401
+ }
402
+
403
+ let base = format!("part-{split}-s{shard_idx:06}");
404
+ let input_name = format!("{base}.input_ids.npy");
405
+ let mask_name = format!("{base}.attention_mask.npy");
406
+ let label_name = format!("{base}.labels.npy");
407
+ write_npy_u16(
408
+ &split_dir.join(&input_name),
409
+ &input_ids,
410
+ rows.len(),
411
+ context.max_length,
412
+ )?;
413
+ write_npy_u8(
414
+ &split_dir.join(&mask_name),
415
+ &attention_mask,
416
+ rows.len(),
417
+ context.max_length,
418
+ )?;
419
+ write_npy_i16(
420
+ &split_dir.join(&label_name),
421
+ &labels,
422
+ rows.len(),
423
+ context.max_length,
424
+ )?;
425
+ Ok(ShardManifest {
426
+ rows: rows.len(),
427
+ input_ids: input_name,
428
+ attention_mask: mask_name,
429
+ labels: label_name,
430
+ })
431
+ }
432
+
433
+ fn encode_row(row: &SourceRow, context: &EncodeContext) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> {
434
+ let (tokens, labels) = labels_for_char_tokenizer(row);
435
+ let mut input_ids = vec![context.vocab.pad_id; context.max_length];
436
+ let mut attention_mask = vec![0u8; context.max_length];
437
+ let mut label_ids = vec![-100i16; context.max_length];
438
+
439
+ input_ids[0] = context.vocab.cls_id;
440
+ attention_mask[0] = 1;
441
+ let available = context.max_length.saturating_sub(2);
442
+ let token_count = tokens.len().min(labels.len()).min(available);
443
+ for idx in 0..token_count {
444
+ input_ids[idx + 1] = token_id(&context.vocab, &tokens[idx]);
445
+ attention_mask[idx + 1] = 1;
446
+ let label = canonical_bio_label(&labels[idx]);
447
+ label_ids[idx + 1] = context
448
+ .label_ids
449
+ .get(&label)
450
+ .copied()
451
+ .with_context(|| format!("unknown label '{label}'"))?;
452
+ }
453
+ let sep_pos = token_count + 1;
454
+ input_ids[sep_pos] = context.vocab.sep_id;
455
+ attention_mask[sep_pos] = 1;
456
+ Ok((input_ids, attention_mask, label_ids))
457
+ }
458
+
459
+ fn labels_for_char_tokenizer(row: &SourceRow) -> (Vec<String>, Vec<String>) {
460
+ if row.tokenizer_variant.as_deref() == Some("char") {
461
+ if let Some(filename) = row.filename.as_deref() {
462
+ let filename_chars = chars_as_strings(filename);
463
+ if row.tokens == filename_chars {
464
+ return (row.tokens.clone(), row.labels.clone());
465
+ }
466
+ }
467
+ }
468
+
469
+ if let Some(filename) = row.filename.as_deref() {
470
+ if let Some(projected) = project_labels_from_filename(filename, &row.tokens, &row.labels) {
471
+ let (tokens, mut labels) = projected;
472
+ repair_structural_meta_labels(filename, &mut labels);
473
+ return (tokens, labels);
474
+ }
475
+ }
476
+
477
+ let (tokens, mut labels) = align_tokens_to_chars(&row.tokens, &row.labels);
478
+ if let Some(filename) = row.filename.as_deref() {
479
+ repair_structural_meta_labels(filename, &mut labels);
480
+ }
481
+ (tokens, labels)
482
+ }
483
+
484
+ fn project_labels_from_filename(
485
+ filename: &str,
486
+ source_tokens: &[String],
487
+ source_labels: &[String],
488
+ ) -> Option<(Vec<String>, Vec<String>)> {
489
+ let offsets = token_offsets_in_text(filename, source_tokens)?;
490
+ if offsets.len() != source_labels.len() {
491
+ return None;
492
+ }
493
+ let char_len = filename.chars().count();
494
+ let mut char_entities: Vec<Option<String>> = vec![None; char_len];
495
+ for ((token, label), (mut start, mut end)) in source_tokens
496
+ .iter()
497
+ .zip(source_labels.iter())
498
+ .zip(offsets.into_iter())
499
+ {
500
+ let Some(entity) = bio_entity(label) else {
501
+ continue;
502
+ };
503
+ if is_wrapped_token(token) && end > start + 1 {
504
+ start += 1;
505
+ end -= 1;
506
+ }
507
+ for pos in start..end.min(char_entities.len()) {
508
+ char_entities[pos] = Some(entity.clone());
509
+ }
510
+ }
511
+
512
+ let tokens = chars_as_strings(filename);
513
+ let mut labels = Vec::with_capacity(tokens.len());
514
+ let mut active_entity: Option<String> = None;
515
+ for entity in char_entities {
516
+ match entity {
517
+ Some(entity) => {
518
+ let prefix = if active_entity.as_deref() == Some(entity.as_str()) {
519
+ "I"
520
+ } else {
521
+ "B"
522
+ };
523
+ labels.push(format!("{prefix}-{entity}"));
524
+ active_entity = Some(entity);
525
+ }
526
+ None => {
527
+ labels.push("O".to_string());
528
+ active_entity = None;
529
+ }
530
+ }
531
+ }
532
+ Some((tokens, labels))
533
+ }
534
+
535
+ fn token_offsets_in_text(text: &str, tokens: &[String]) -> Option<Vec<(usize, usize)>> {
536
+ let mut offsets = Vec::with_capacity(tokens.len());
537
+ let mut cursor = 0usize;
538
+ for token in tokens {
539
+ if token.is_empty() {
540
+ let char_cursor = char_index_at_byte(text, cursor);
541
+ offsets.push((char_cursor, char_cursor));
542
+ continue;
543
+ }
544
+ let relative = text.get(cursor..)?.find(token)?;
545
+ let start_byte = cursor + relative;
546
+ let end_byte = start_byte + token.len();
547
+ offsets.push((
548
+ char_index_at_byte(text, start_byte),
549
+ char_index_at_byte(text, end_byte),
550
+ ));
551
+ cursor = end_byte;
552
+ }
553
+ Some(offsets)
554
+ }
555
+
556
+ fn align_tokens_to_chars(tokens: &[String], labels: &[String]) -> (Vec<String>, Vec<String>) {
557
+ let mut char_tokens = Vec::new();
558
+ let mut char_labels = Vec::new();
559
+ for (token, label) in tokens.iter().zip(labels.iter()) {
560
+ let chars = chars_as_strings(token);
561
+ if chars.is_empty() {
562
+ continue;
563
+ }
564
+ let label = label.as_str();
565
+ if label.starts_with("B-") {
566
+ let entity = label
567
+ .split_once('-')
568
+ .map(|(_, entity)| entity)
569
+ .unwrap_or("");
570
+ char_labels.push(label.to_string());
571
+ char_labels.extend((1..chars.len()).map(|_| format!("I-{entity}")));
572
+ } else if label.starts_with("I-") {
573
+ char_labels.extend((0..chars.len()).map(|_| label.to_string()));
574
+ } else {
575
+ char_labels.extend((0..chars.len()).map(|_| label.to_string()));
576
+ }
577
+ char_tokens.extend(chars);
578
+ }
579
+ (char_tokens, char_labels)
580
+ }
581
+
582
+ fn repair_structural_meta_labels(text: &str, labels: &mut [String]) {
583
+ if labels.len() != text.chars().count() {
584
+ return;
585
+ }
586
+ let episode_end = first_episode_span_end(labels);
587
+ for (inner_start, inner_end) in bracket_inner_spans(text) {
588
+ let bracket_start = inner_start.saturating_sub(1);
589
+ if bracket_start < episode_end {
590
+ continue;
591
+ }
592
+ let inner = chars_range_to_string(text, inner_start, inner_end);
593
+ let (trim_start, trim_end) = trimmed_bounds(&inner);
594
+ if trim_start >= trim_end {
595
+ continue;
596
+ }
597
+ let clean = chars_slice_to_string(&inner, trim_start, trim_end);
598
+ let clean_start = inner_start + trim_start;
599
+ let clean_end = inner_start + trim_end;
600
+
601
+ if special_tag_re().is_match(&clean) || special_code_re().is_match(&clean) {
602
+ label_span_if_safe(labels, clean_start, clean_end, "SPECIAL");
603
+ continue;
604
+ }
605
+ if source_tag_re().is_match(&clean) {
606
+ label_span_if_safe(labels, clean_start, clean_end, "SOURCE");
607
+ continue;
608
+ }
609
+
610
+ for mat in resolution_re().find_iter(&inner) {
611
+ if !has_ascii_token_boundaries(&inner, mat.start(), mat.end()) {
612
+ continue;
613
+ }
614
+ let start = inner_start + char_index_at_byte(&inner, mat.start());
615
+ let end = inner_start + char_index_at_byte(&inner, mat.end());
616
+ label_span_if_safe(labels, start, end, "RESOLUTION");
617
+ }
618
+ for mat in source_re().find_iter(&inner) {
619
+ if !has_ascii_token_boundaries(&inner, mat.start(), mat.end()) {
620
+ continue;
621
+ }
622
+ let start = inner_start + char_index_at_byte(&inner, mat.start());
623
+ let end = inner_start + char_index_at_byte(&inner, mat.end());
624
+ label_span_if_safe(labels, start, end, "SOURCE");
625
+ }
626
+ }
627
+
628
+ for mat in resolution_re().find_iter(text) {
629
+ if !has_ascii_token_boundaries(text, mat.start(), mat.end()) {
630
+ continue;
631
+ }
632
+ let start = char_index_at_byte(text, mat.start());
633
+ if start < episode_end {
634
+ continue;
635
+ }
636
+ let end = char_index_at_byte(text, mat.end());
637
+ label_span_if_safe(labels, start, end, "RESOLUTION");
638
+ }
639
+ for mat in source_re().find_iter(text) {
640
+ if !has_ascii_token_boundaries(text, mat.start(), mat.end()) {
641
+ continue;
642
+ }
643
+ let start = char_index_at_byte(text, mat.start());
644
+ if start < episode_end {
645
+ continue;
646
+ }
647
+ let end = char_index_at_byte(text, mat.end());
648
+ label_span_if_safe(labels, start, end, "SOURCE");
649
+ }
650
+ }
651
+
652
+ fn first_episode_span_end(labels: &[String]) -> usize {
653
+ let mut idx = 0usize;
654
+ while idx < labels.len() {
655
+ if label_entity(&labels[idx]) == Some("EPISODE") {
656
+ let mut end = idx + 1;
657
+ while end < labels.len() && label_entity(&labels[end]) == Some("EPISODE") {
658
+ end += 1;
659
+ }
660
+ return end;
661
+ }
662
+ idx += 1;
663
+ }
664
+ 0
665
+ }
666
+
667
+ fn bracket_inner_spans(text: &str) -> Vec<(usize, usize)> {
668
+ let chars = text.chars().collect::<Vec<_>>();
669
+ let mut spans = Vec::new();
670
+ let mut idx = 0usize;
671
+ while idx < chars.len() {
672
+ let close = match chars[idx] {
673
+ '[' => ']',
674
+ '(' => ')',
675
+ '【' => '】',
676
+ '《' => '》',
677
+ _ => {
678
+ idx += 1;
679
+ continue;
680
+ }
681
+ };
682
+ if let Some(relative_end) = chars[idx + 1..].iter().position(|ch| *ch == close) {
683
+ let end = idx + 1 + relative_end;
684
+ spans.push((idx + 1, end));
685
+ idx = end + 1;
686
+ } else {
687
+ idx += 1;
688
+ }
689
+ }
690
+ spans
691
+ }
692
+
693
+ fn trimmed_bounds(text: &str) -> (usize, usize) {
694
+ let chars = text.chars().collect::<Vec<_>>();
695
+ let mut start = 0usize;
696
+ let mut end = chars.len();
697
+ while start < end && chars[start].is_whitespace() {
698
+ start += 1;
699
+ }
700
+ while end > start && chars[end - 1].is_whitespace() {
701
+ end -= 1;
702
+ }
703
+ (start, end)
704
+ }
705
+
706
+ fn chars_range_to_string(text: &str, start: usize, end: usize) -> String {
707
+ text.chars()
708
+ .skip(start)
709
+ .take(end.saturating_sub(start))
710
+ .collect()
711
+ }
712
+
713
+ fn chars_slice_to_string(text: &str, start: usize, end: usize) -> String {
714
+ text.chars()
715
+ .skip(start)
716
+ .take(end.saturating_sub(start))
717
+ .collect()
718
+ }
719
+
720
+ fn label_span_if_safe(labels: &mut [String], start: usize, end: usize, entity: &str) {
721
+ if start >= end || end > labels.len() {
722
+ return;
723
+ }
724
+ if labels[start..end].iter().any(|label| {
725
+ matches!(
726
+ label_entity(label),
727
+ Some("GROUP" | "EPISODE" | "SEASON" | "PATH_SEASON")
728
+ )
729
+ }) {
730
+ return;
731
+ }
732
+ let previous_same = start > 0 && label_entity(&labels[start - 1]) == Some(entity);
733
+ let mut first = !previous_same;
734
+ for label in labels.iter_mut().take(end).skip(start) {
735
+ *label = if first {
736
+ format!("B-{entity}")
737
+ } else {
738
+ format!("I-{entity}")
739
+ };
740
+ first = false;
741
+ }
742
+ }
743
+
744
+ fn has_ascii_token_boundaries(text: &str, start: usize, end: usize) -> bool {
745
+ let previous_ok = text[..start]
746
+ .chars()
747
+ .next_back()
748
+ .map(|ch| !ch.is_ascii_alphanumeric())
749
+ .unwrap_or(true);
750
+ let next_ok = text[end..]
751
+ .chars()
752
+ .next()
753
+ .map(|ch| !ch.is_ascii_alphanumeric())
754
+ .unwrap_or(true);
755
+ previous_ok && next_ok
756
+ }
757
+
758
+ fn label_entity(label: &str) -> Option<&str> {
759
+ let (prefix, entity) = label.split_once('-')?;
760
+ if prefix == "B" || prefix == "I" {
761
+ Some(entity)
762
+ } else {
763
+ None
764
+ }
765
+ }
766
+
767
+ fn resolution_re() -> &'static Regex {
768
+ RESOLUTION_RE
769
+ .get_or_init(|| Regex::new(r"(?i)(?:\d{3,4}p|\d[kK]|\d{3,4}[xX×]\d{3,4})").unwrap())
770
+ }
771
+
772
+ fn source_re() -> &'static Regex {
773
+ SOURCE_RE.get_or_init(|| Regex::new(&format!(r"(?i)(?:{SOURCE_TOKEN_PATTERN})")).unwrap())
774
+ }
775
+
776
+ fn source_tag_re() -> &'static Regex {
777
+ SOURCE_TAG_RE.get_or_init(|| {
778
+ Regex::new(&format!(
779
+ r"(?i)^(?:{SOURCE_TOKEN_PATTERN})(?:\s*(?:[&+/,_-]|,\s*)\s*(?:{SOURCE_TOKEN_PATTERN}))*$"
780
+ ))
781
+ .unwrap()
782
+ })
783
+ }
784
+
785
+ fn special_tag_re() -> &'static Regex {
786
+ SPECIAL_TAG_RE.get_or_init(|| {
787
+ Regex::new(r"(?i)^(?:檢索|检索|搜索|搜寻|搜尋|别名|別名|alias|search|keyword)\s*[::].+")
788
+ .unwrap()
789
+ })
790
+ }
791
+
792
+ fn special_code_re() -> &'static Regex {
793
+ SPECIAL_CODE_RE.get_or_init(|| {
794
+ Regex::new(r"(?i)^(?:NCOP|NCED|OP|ED|PV|CM)\d*$|^IV\d+$|^(?:OVA|OAD|SP)\d*$").unwrap()
795
+ })
796
+ }
797
+
798
+ fn chars_as_strings(text: &str) -> Vec<String> {
799
+ text.chars().map(|ch| ch.to_string()).collect()
800
+ }
801
+
802
+ fn char_index_at_byte(text: &str, byte_index: usize) -> usize {
803
+ text[..byte_index].chars().count()
804
+ }
805
+
806
+ fn bio_entity(label: &str) -> Option<String> {
807
+ let (prefix, entity) = label.split_once('-')?;
808
+ if prefix == "B" || prefix == "I" {
809
+ Some(entity.to_string())
810
+ } else {
811
+ None
812
+ }
813
+ }
814
+
815
+ fn is_wrapped_token(token: &str) -> bool {
816
+ let mut chars = token.chars();
817
+ let Some(first) = chars.next() else {
818
+ return false;
819
+ };
820
+ let Some(last) = token.chars().last() else {
821
+ return false;
822
+ };
823
+ matches!(first, '[' | '【' | '(' | '《') && matches!(last, ']' | '】' | ')' | '》')
824
+ }
825
+
826
+ fn canonical_bio_label(label: &str) -> String {
827
+ let Some((prefix, entity)) = label.split_once('-') else {
828
+ return if label == "O" {
829
+ "O".to_string()
830
+ } else {
831
+ label.to_string()
832
+ };
833
+ };
834
+ if prefix != "B" && prefix != "I" {
835
+ return label.to_string();
836
+ }
837
+ let canonical_entity = match entity {
838
+ "TITLE" => "TITLE_MIXED",
839
+ "PATH_TITLE" => "PATH_TITLE_MIXED",
840
+ other => other,
841
+ };
842
+ format!("{prefix}-{canonical_entity}")
843
+ }
844
+
845
+ fn token_id(vocab: &Vocab, token: &str) -> u16 {
846
+ *vocab.ids.get(token).unwrap_or(&vocab.unk_id)
847
+ }
848
+
849
+ fn write_eval_records(rows: &[SourceRow], path: &Path) -> Result<()> {
850
+ let mut writer = BufWriter::new(
851
+ File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
852
+ );
853
+ for row in rows {
854
+ writer.write_all(row.raw_line.as_bytes())?;
855
+ writer.write_all(b"\n")?;
856
+ }
857
+ Ok(())
858
+ }
859
+
860
+ fn write_npy_u16(path: &Path, data: &[u16], rows: usize, cols: usize) -> Result<()> {
861
+ let mut writer = BufWriter::new(
862
+ File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
863
+ );
864
+ write_npy_header(&mut writer, "<u2", rows, cols)?;
865
+ for value in data {
866
+ writer.write_all(&value.to_le_bytes())?;
867
+ }
868
+ Ok(())
869
+ }
870
+
871
+ fn write_npy_u8(path: &Path, data: &[u8], rows: usize, cols: usize) -> Result<()> {
872
+ let mut writer = BufWriter::new(
873
+ File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
874
+ );
875
+ write_npy_header(&mut writer, "|u1", rows, cols)?;
876
+ writer.write_all(data)?;
877
+ Ok(())
878
+ }
879
+
880
+ fn write_npy_i16(path: &Path, data: &[i16], rows: usize, cols: usize) -> Result<()> {
881
+ let mut writer = BufWriter::new(
882
+ File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
883
+ );
884
+ write_npy_header(&mut writer, "<i2", rows, cols)?;
885
+ for value in data {
886
+ writer.write_all(&value.to_le_bytes())?;
887
+ }
888
+ Ok(())
889
+ }
890
+
891
+ fn write_npy_header<W: Write>(writer: &mut W, descr: &str, rows: usize, cols: usize) -> Result<()> {
892
+ let mut header = format!(
893
+ "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}",
894
+ descr, rows, cols
895
+ )
896
+ .into_bytes();
897
+ let preamble_len = 10usize;
898
+ let pad_len = (16 - ((preamble_len + header.len() + 1) % 16)) % 16;
899
+ header.extend(std::iter::repeat(b' ').take(pad_len));
900
+ header.push(b'\n');
901
+ if header.len() > u16::MAX as usize {
902
+ bail!("npy header too large");
903
+ }
904
+ writer.write_all(b"\x93NUMPY")?;
905
+ writer.write_all(&[1, 0])?;
906
+ writer.write_all(&(header.len() as u16).to_le_bytes())?;
907
+ writer.write_all(&header)?;
908
+ Ok(())
909
+ }