ModerRAS committed on
Commit
0779202
·
1 Parent(s): e34dc04

Add character-token DMHY training path

Browse files
Files changed (5) hide show
  1. README.md +29 -0
  2. colab_train.py +11 -10
  3. convert_to_char_dataset.py +201 -0
  4. datasets/AnimeName +1 -1
  5. train.py +13 -6
README.md CHANGED
@@ -60,6 +60,11 @@ Common fansub group names (`Snow`, `LoliHouse`, `DMG`, `KTXP`, `Sakurato`, etc.)
60
  and individual bracket characters (`[`, `]`, `(`, `)`) are included in the new
61
  vocabulary.
62
 
 
 
 
 
 
63
  ## Evaluation
64
 
65
  Balanced mixed-data A/B run (`50K` synthetic + `50K` DMHY weak labels, 1 epoch, batch size 128, seed 42):
@@ -139,6 +144,29 @@ The model loads the old 3000-token checkpoint, `resize_token_embeddings()` adds
139
  trains the full model. About 96% of token occurrences are now covered (vs 90%
140
  with the old 3000-token vocabulary).
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  ### Regenerate datasets from source
143
 
144
  ```bash
@@ -178,6 +206,7 @@ the full training pipeline. Checkpoints are saved to your Drive automatically.
178
  - `model.safetensors`, `config.json`, `vocab.json`: default fine-tuned model
179
  - `train.py`, `dataset.py`, `tokenizer.py`, `model.py`: training pipeline
180
  - `dmhy_dataset.py`, `mix_datasets.py`: weak-label export and dataset mixing
 
181
  - `inference.py`: end-to-end filename parser CLI
182
  - `export_onnx.py`: ONNX export for Android integration
183
  - `exports/`: exported ONNX model and metadata
 
60
  and individual bracket characters (`[`, `]`, `(`, `)`) are included in the new
61
  vocabulary.
62
 
63
+ For character-token training, `datasets/AnimeName/vocab.char.json` is built
64
+ from the full `dmhy_weak_char.jsonl` export. The full DMHY weak dataset has
65
+ **6195 unique characters**, so the complete character vocab is only **6199**
66
+ entries including special tokens and reaches 100% token coverage.
67
+
68
  ## Evaluation
69
 
70
  Balanced mixed-data A/B run (`50K` synthetic + `50K` DMHY weak labels, 1 epoch, batch size 128, seed 42):
 
144
  trains the full model. About 96% of token occurrences are now covered (vs 90%
145
  with the old 3000-token vocabulary).
146
 
147
+ ### Character-token DMHY training
148
+
149
+ ```bash
150
+ python convert_to_char_dataset.py \
151
+ --input datasets/AnimeName/dmhy_weak.jsonl \
152
+ --output datasets/AnimeName/dmhy_weak_char.jsonl \
153
+ --vocab-output datasets/AnimeName/vocab.char.json \
154
+ --manifest-output datasets/AnimeName/dmhy_weak_char.manifest.json
155
+
156
+ python train.py --tokenizer char \
157
+ --data-file datasets/AnimeName/dmhy_weak_char.jsonl \
158
+ --vocab-file datasets/AnimeName/vocab.char.json \
159
+ --save-dir checkpoints_char/dmhy-weak-char \
160
+ --epochs 1 --batch-size 64 \
161
+ --learning-rate 0.0003 --warmup-steps 300 \
162
+ --max-seq-length 128 --seed 42
163
+ ```
164
+
165
+ The converter keeps source metadata and adds `tokenizer_variant`, source token
166
+ count, and character token count fields to each record. The char dataset's
167
+ p99 length is 107 characters, so `--max-seq-length 128` covers almost all rows
168
+ while leaving room for `[CLS]` and `[SEP]`.
169
+
170
  ### Regenerate datasets from source
171
 
172
  ```bash
 
206
  - `model.safetensors`, `config.json`, `vocab.json`: default fine-tuned model
207
  - `train.py`, `dataset.py`, `tokenizer.py`, `model.py`: training pipeline
208
  - `dmhy_dataset.py`, `mix_datasets.py`: weak-label export and dataset mixing
209
+ - `convert_to_char_dataset.py`: full character-token projection for weak labels
210
  - `inference.py`: end-to-end filename parser CLI
211
  - `export_onnx.py`: ONNX export for Android integration
212
  - `exports/`: exported ONNX model and metadata
colab_train.py CHANGED
@@ -13,12 +13,12 @@ What it does:
13
  - Mounts Google Drive (for persistent checkpoints)
14
  - Clones AniFileBERT repo + AnimeName dataset submodule
15
  - Installs PyTorch + Transformers dependencies
16
- - Runs training: fine-tune from current checkpoint with 8000-token vocab
17
  - Saves final model to Drive
18
 
19
  Output:
20
  - Checkpoints saved to: MyDrive/AniFileBERT/checkpoints/
21
- - Final model at: MyDrive/AniFileBERT/checkpoints/dmhy-finetune/final/
22
  """
23
 
24
  import os
@@ -90,25 +90,26 @@ run("python -c 'import torch; print(f\"PyTorch {torch.__version__}, CUDA availab
90
  print("\n" + "=" * 60)
91
  print("STEP 5: Verify vocabulary")
92
  print("=" * 60)
93
- run("python -c 'import json; v=json.load(open(\"vocab.json\")); print(f\"Vocab size: {len(v)} tokens\")'")
94
 
95
  # ── 6. Run training ────────────────────────────────────────────
96
  print("\n" + "=" * 60)
97
  print("STEP 6: Train model")
98
  print("=" * 60)
99
 
100
- # The 8000-token vocab is already in datasets/AnimeName/vocab.json.
101
- # The old checkpoint (3000-token embedding) gets resized automatically.
102
- SAVE_DIR = os.path.join(DRIVE_ROOT, "checkpoints", "dmhy-finetune")
103
 
104
  run(
105
  f"python train.py "
106
- f"--data-file datasets/AnimeName/dmhy_weak.jsonl "
107
- f"--vocab-file datasets/AnimeName/vocab.json "
 
108
  f"--save-dir {SAVE_DIR} "
109
- f"--init-model-dir . "
110
- f"--epochs 10 --batch-size 128 "
111
  f"--learning-rate 0.0003 --warmup-steps 300 "
 
112
  f"--seed 42 "
113
  f"--no-shuffle"
114
  )
 
13
  - Mounts Google Drive (for persistent checkpoints)
14
  - Clones AniFileBERT repo + AnimeName dataset submodule
15
  - Installs PyTorch + Transformers dependencies
16
+ - Runs training: train a character-token model with the full DMHY vocab
17
  - Saves final model to Drive
18
 
19
  Output:
20
  - Checkpoints saved to: MyDrive/AniFileBERT/checkpoints/
21
+ - Final model at: MyDrive/AniFileBERT/checkpoints/dmhy-weak-char/final/
22
  """
23
 
24
  import os
 
90
  print("\n" + "=" * 60)
91
  print("STEP 5: Verify vocabulary")
92
  print("=" * 60)
93
+ run("python -c 'import json; v=json.load(open(\"datasets/AnimeName/vocab.char.json\", encoding=\"utf-8\")); print(f\"Character vocab size: {len(v)} tokens\")'")
94
 
95
  # ── 6. Run training ────────────────────────────────────────────
96
  print("\n" + "=" * 60)
97
  print("STEP 6: Train model")
98
  print("=" * 60)
99
 
100
+ # The full DMHY character vocab is only 6199 tokens and covers every character
101
+ # occurrence in dmhy_weak_char.jsonl.
102
+ SAVE_DIR = os.path.join(DRIVE_ROOT, "checkpoints", "dmhy-weak-char")
103
 
104
  run(
105
  f"python train.py "
106
+ f"--tokenizer char "
107
+ f"--data-file datasets/AnimeName/dmhy_weak_char.jsonl "
108
+ f"--vocab-file datasets/AnimeName/vocab.char.json "
109
  f"--save-dir {SAVE_DIR} "
110
+ f"--epochs 5 --batch-size 128 "
 
111
  f"--learning-rate 0.0003 --warmup-steps 300 "
112
+ f"--max-seq-length 128 "
113
  f"--seed 42 "
114
  f"--no-shuffle"
115
  )
convert_to_char_dataset.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Convert token-level anime filename JSONL datasets to character tokens.
2
+
3
+ Input records must contain parallel ``tokens`` and ``labels`` arrays. The
4
+ converter expands each original token into Unicode code points and projects BIO
5
+ labels onto the expanded sequence:
6
+
7
+ - ``B-X`` keeps ``B-X`` on the first character and uses ``I-X`` afterwards.
8
+ - ``I-X`` remains ``I-X`` on every character.
9
+ - ``O`` remains ``O`` on every character.
10
+
11
+ The script streams both input and output so it can process the full DMHY weak
12
+ dataset without loading hundreds of MB into memory.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import json
19
+ from collections import Counter
20
+ from datetime import datetime, timezone
21
+ from pathlib import Path
22
+ from statistics import mean
23
+ from typing import Iterable
24
+
25
+
26
+ SPECIAL_TOKENS = ("[PAD]", "[UNK]", "[CLS]", "[SEP]")
27
+
28
+
29
def projected_labels(token: str, label: str) -> tuple[list[str], list[str]]:
    """Return character tokens and projected BIO labels for one source token.

    Args:
        token: Source token; it is split into its individual characters.
        label: BIO label attached to the source token.

    Returns:
        A ``(chars, labels)`` pair of equal length. An empty token yields two
        empty lists. For ``B-X`` the first character keeps ``B-X`` and every
        following character gets ``I-X``; any other label (``I-X``, ``O``, ...)
        is repeated unchanged for every character.
    """
    chars = list(token)
    if not chars:
        return [], []

    if label.startswith("B-"):
        # Only the first character may open the entity; the rest continue it.
        continuation = f"I-{label[2:]}"
        return chars, [label] + [continuation] * (len(chars) - 1)

    # ``I-X`` and ``O`` project identically: one copy of the label per character.
    return chars, [label] * len(chars)
41
+
42
+
43
def convert_record(record: dict) -> dict:
    """Project one token-level record onto character tokens.

    The returned dict is a shallow copy of ``record`` in which ``tokens`` and
    ``labels`` are replaced by their character-level expansion, and the keys
    ``tokenizer_variant``, ``source_token_count`` and ``char_token_count``
    are set. All other keys are carried over unchanged.

    Raises:
        ValueError: If ``tokens`` and ``labels`` differ in length.
    """
    source_tokens = record["tokens"]
    source_labels = record["labels"]
    if len(source_tokens) != len(source_labels):
        raise ValueError(
            f"token/label length mismatch: {len(source_tokens)} tokens, {len(source_labels)} labels"
        )

    expanded_tokens: list[str] = []
    expanded_labels: list[str] = []
    for source_token, source_label in zip(source_tokens, source_labels):
        # Tokens and labels are coerced to str before projection.
        chars, char_tags = projected_labels(str(source_token), str(source_label))
        expanded_tokens += chars
        expanded_labels += char_tags

    return {
        **record,
        "tokens": expanded_tokens,
        "labels": expanded_labels,
        "tokenizer_variant": "char",
        "source_token_count": len(source_tokens),
        "char_token_count": len(expanded_tokens),
    }
66
+
67
+
68
def iter_jsonl(path: Path) -> Iterable[dict]:
    """Lazily yield one JSON object per non-blank line of a UTF-8 JSONL file.

    Lines are stripped of surrounding whitespace and blank lines are skipped.
    The file is read line by line, so it is never loaded into memory at once.

    Raises:
        ValueError: If a line is not valid JSON, or parses to something other
            than a JSON object. The message carries ``path:line_no`` so the
            offending row can be located.
    """
    with path.open("r", encoding="utf-8") as handle:
        for line_no, line in enumerate(handle, 1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path}:{line_no}: invalid JSON") from exc
            # Reject arrays/scalars here, where the line number is still known,
            # instead of letting them fail later on key access.
            if not isinstance(record, dict):
                raise ValueError(
                    f"{path}:{line_no}: expected a JSON object, got {type(record).__name__}"
                )
            yield record
78
+
79
+
80
def build_vocab(counter: Counter[str], max_size: int | None = None) -> dict[str, int]:
    """Build a frequency-sorted vocab with fixed special-token IDs.

    Args:
        counter: Token frequencies; more frequent tokens get lower IDs.
        max_size: Optional cap on the total vocab size, special tokens
            included. ``None`` keeps every counted token.

    Returns:
        A token-to-ID mapping. The entries of ``SPECIAL_TOKENS`` always occupy
        the first IDs, so the result is never smaller than that tuple even when
        ``max_size`` is lower.
    """
    vocab = {token: idx for idx, token in enumerate(SPECIAL_TOKENS)}
    for token, _count in counter.most_common():
        # Apply the cap to the vocab itself: a counted token that collides with
        # a special token must not consume one of the remaining slots.
        if max_size is not None and len(vocab) >= max_size:
            break
        if token not in vocab:
            vocab[token] = len(vocab)
    return vocab
88
+
89
+
90
def coverage(counter: Counter[str], vocab: dict[str, int]) -> float:
    """Return the fraction of counted token occurrences whose token is in ``vocab``.

    An empty counter (zero total occurrences) is reported as fully covered.
    """
    occurrences = sum(counter.values())
    if not occurrences:
        return 1.0

    known = 0
    for token, count in counter.items():
        if token in vocab:
            known += count
    return known / occurrences
96
+
97
+
98
def percentile(values: list[int], pct: float) -> int:
    """Return the element of ``values`` at the rounded ``pct`` rank.

    The values are sorted and the element at index
    ``round(pct / 100 * (len - 1))`` is returned, clamped to the last index.
    An empty list yields ``0``.
    """
    if values:
        ranked = sorted(values)
        last = len(ranked) - 1
        position = round((pct / 100) * last)
        return ranked[min(last, position)]
    return 0
104
+
105
+
106
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the converter's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain
            ``parse_args()`` call.

    Returns:
        Namespace with ``input``, ``output``, ``vocab_output``,
        ``manifest_output``, ``max_vocab_size``, ``limit`` and ``progress``.
    """
    parser = argparse.ArgumentParser(description="Convert JSONL token labels to character labels")
    parser.add_argument("--input", required=True, help="Input token-level JSONL")
    parser.add_argument("--output", required=True, help="Output character-level JSONL")
    parser.add_argument("--vocab-output", required=True, help="Output vocab JSON")
    parser.add_argument("--manifest-output", default=None, help="Output manifest JSON")
    parser.add_argument("--max-vocab-size", type=int, default=None,
                        help="Optional vocab cap including special tokens")
    parser.add_argument("--limit", type=int, default=None, help="Convert only the first N records")
    parser.add_argument("--progress", type=int, default=50_000,
                        help="Print progress every N records")
    return parser.parse_args(argv)
118
+
119
+
120
def main() -> None:
    """Convert the input JSONL to character tokens and write all artifacts.

    Records are streamed from ``--input`` through ``convert_record`` into
    ``--output``. Afterwards the frequency-sorted vocab is written to
    ``--vocab-output`` and a manifest with dataset statistics is written to
    ``--manifest-output`` (default: the output path with its suffix replaced
    by ``.manifest.json``). The manifest, minus its example rows, is also
    printed to stdout.

    ``--limit N`` converts exactly the first ``N`` records, so ``--limit 0``
    (or a negative value) produces an empty output file.
    """
    # Function-scope import so this block does not depend on a file-level edit.
    from itertools import islice

    args = parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    vocab_path = Path(args.vocab_output)
    manifest_path = (
        Path(args.manifest_output)
        if args.manifest_output
        else output_path.with_suffix(".manifest.json")
    )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    vocab_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    char_counter: Counter[str] = Counter()
    label_counter: Counter[str] = Counter()
    row_count = 0
    source_token_count = 0
    char_token_count = 0
    lengths: list[int] = []
    examples: list[dict] = []

    # Clamp so a negative limit behaves like 0 instead of raising in islice.
    limit = None if args.limit is None else max(args.limit, 0)

    with output_path.open("w", encoding="utf-8", newline="\n") as out:
        # islice enforces the limit before a record is converted or written.
        for record in islice(iter_jsonl(input_path), limit):
            converted = convert_record(record)
            out.write(json.dumps(converted, ensure_ascii=False, separators=(",", ":")) + "\n")

            row_count += 1
            source_token_count += converted["source_token_count"]
            char_len = converted["char_token_count"]
            char_token_count += char_len
            lengths.append(char_len)
            char_counter.update(converted["tokens"])
            label_counter.update(converted["labels"])
            # Keep the first few converted rows as manifest examples.
            if len(examples) < 5:
                examples.append(converted)

            # ``--progress 0`` disables progress output.
            if args.progress and row_count % args.progress == 0:
                print(f"converted {row_count:,} rows; unique chars={len(char_counter):,}")

    vocab = build_vocab(char_counter, args.max_vocab_size)
    vocab_path.write_text(json.dumps(vocab, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")

    manifest = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "input": str(input_path),
        "output": str(output_path),
        "vocab_output": str(vocab_path),
        "tokenizer_variant": "char",
        "projection": {
            "B-X": "first char keeps B-X; remaining chars become I-X",
            "I-X": "all chars keep I-X",
            "O": "all chars keep O",
        },
        "row_count": row_count,
        "source_token_count": source_token_count,
        "char_token_count": char_token_count,
        "unique_char_count": len(char_counter),
        "vocab_size": len(vocab),
        "max_vocab_size": args.max_vocab_size,
        "vocab_coverage": coverage(char_counter, vocab),
        "label_counts": dict(label_counter),
        "char_length": {
            "min": min(lengths) if lengths else 0,
            "mean": mean(lengths) if lengths else 0,
            "p50": percentile(lengths, 50),
            "p90": percentile(lengths, 90),
            "p95": percentile(lengths, 95),
            "p99": percentile(lengths, 99),
            "max": max(lengths) if lengths else 0,
        },
        "examples": examples,
    }
    manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
    print(json.dumps({k: v for k, v in manifest.items() if k != "examples"}, ensure_ascii=False, indent=2))
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()
datasets/AnimeName CHANGED
@@ -1 +1 @@
1
- Subproject commit 17c478b079deae90935a0c5392ee6138ea18b02f
 
1
+ Subproject commit 867350a1712e50cc71f5a9e81dd331ca46a7b1dd
train.py CHANGED
@@ -82,6 +82,9 @@ def parse_args() -> argparse.Namespace:
82
  help="Use only the first N samples for quick A/B smoke runs")
83
  parser.add_argument("--rebuild-vocab", action="store_true",
84
  help="Rebuild vocab from the selected data file before training")
 
 
 
85
  parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
86
  return parser.parse_args()
87
 
@@ -146,8 +149,9 @@ def main():
146
  vocab_path = resolve_vocab_path(config.data_file, args.tokenizer, args.vocab_file)
147
  tokenizer = create_tokenizer(args.tokenizer)
148
  if args.rebuild_vocab or not os.path.isfile(vocab_path):
149
- print(f" Building {args.tokenizer} vocab: {vocab_path} (max_size={config.vocab_size})")
150
- build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=config.vocab_size)
 
151
  tokenizer = create_tokenizer(args.tokenizer, vocab_file=vocab_path)
152
  print(f" Variant: {args.tokenizer}")
153
  print(f" Vocab size: {tokenizer.vocab_size}")
@@ -171,8 +175,7 @@ def main():
171
  total_params = print_model_summary(model)
172
 
173
  if total_params >= 5_000_000:
174
- print("WARNING: Model exceeds 5M parameter limit. Consider reducing hidden_size or layers.")
175
- sys.exit(1)
176
 
177
  split_idx = int(len(all_data) * config.train_split)
178
  train_data = all_data[:split_idx]
@@ -206,6 +209,10 @@ def main():
206
  print(f" Train samples: {len(train_dataset)}")
207
  print(f" Eval samples: {len(eval_dataset)}")
208
 
 
 
 
 
209
  # Training arguments
210
  training_args = TrainingArguments(
211
  output_dir=config.save_dir,
@@ -218,14 +225,14 @@ def main():
218
  learning_rate=config.learning_rate,
219
  weight_decay=config.weight_decay,
220
  warmup_steps=config.warmup_steps,
221
- use_cpu=False,
222
  report_to="none",
223
  save_total_limit=2,
224
  load_best_model_at_end=True,
225
  metric_for_best_model="f1",
226
  greater_is_better=True,
227
  dataloader_num_workers=config.num_workers,
228
- fp16=True
229
  )
230
 
231
  # Data collator
 
82
  help="Use only the first N samples for quick A/B smoke runs")
83
  parser.add_argument("--rebuild-vocab", action="store_true",
84
  help="Rebuild vocab from the selected data file before training")
85
+ parser.add_argument("--max-vocab-size", type=int, default=None,
86
+ help="Optional vocab cap used with --rebuild-vocab")
87
+ parser.add_argument("--cpu", action="store_true", help="Force CPU training")
88
  parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
89
  return parser.parse_args()
90
 
 
149
  vocab_path = resolve_vocab_path(config.data_file, args.tokenizer, args.vocab_file)
150
  tokenizer = create_tokenizer(args.tokenizer)
151
  if args.rebuild_vocab or not os.path.isfile(vocab_path):
152
+ max_vocab_size = args.max_vocab_size if args.max_vocab_size is not None else config.vocab_size
153
+ print(f" Building {args.tokenizer} vocab: {vocab_path} (max_size={max_vocab_size})")
154
+ build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=max_vocab_size)
155
  tokenizer = create_tokenizer(args.tokenizer, vocab_file=vocab_path)
156
  print(f" Variant: {args.tokenizer}")
157
  print(f" Vocab size: {tokenizer.vocab_size}")
 
175
  total_params = print_model_summary(model)
176
 
177
  if total_params >= 5_000_000:
178
+ print("WARNING: Model exceeds the historical 5M target; continuing because vocab size is configurable.")
 
179
 
180
  split_idx = int(len(all_data) * config.train_split)
181
  train_data = all_data[:split_idx]
 
209
  print(f" Train samples: {len(train_dataset)}")
210
  print(f" Eval samples: {len(eval_dataset)}")
211
 
212
+ use_cpu = args.cpu or not torch.cuda.is_available()
213
+ use_fp16 = not use_cpu
214
+ print(f" Device: {'CPU' if use_cpu else 'CUDA'}")
215
+
216
  # Training arguments
217
  training_args = TrainingArguments(
218
  output_dir=config.save_dir,
 
225
  learning_rate=config.learning_rate,
226
  weight_decay=config.weight_decay,
227
  warmup_steps=config.warmup_steps,
228
+ use_cpu=use_cpu,
229
  report_to="none",
230
  save_total_limit=2,
231
  load_best_model_at_end=True,
232
  metric_for_best_model="f1",
233
  greater_is_better=True,
234
  dataloader_num_workers=config.num_workers,
235
+ fp16=use_fp16,
236
  )
237
 
238
  # Data collator