Token Classification
Transformers
ONNX
Safetensors
English
Japanese
Chinese
bert
anime
filename-parsing
Eval Results (legacy)
Instructions to use ModerRAS/AniFileBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ModerRAS/AniFileBERT with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("token-classification", model="ModerRAS/AniFileBERT")# Load model directly from transformers import AutoTokenizer, AutoModelForTokenClassification tokenizer = AutoTokenizer.from_pretrained("ModerRAS/AniFileBERT") model = AutoModelForTokenClassification.from_pretrained("ModerRAS/AniFileBERT") - Notebooks
- Google Colab
- Kaggle
Add character-token DMHY training path
Browse files- README.md +29 -0
- colab_train.py +11 -10
- convert_to_char_dataset.py +201 -0
- datasets/AnimeName +1 -1
- train.py +13 -6
README.md
CHANGED
|
@@ -60,6 +60,11 @@ Common fansub group names (`Snow`, `LoliHouse`, `DMG`, `KTXP`, `Sakurato`, etc.)
|
|
| 60 |
and individual bracket characters (`[`, `]`, `(`, `)`) are included in the new
|
| 61 |
vocabulary.
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
## Evaluation
|
| 64 |
|
| 65 |
Balanced mixed-data A/B run (`50K` synthetic + `50K` DMHY weak labels, 1 epoch, batch size 128, seed 42):
|
|
@@ -139,6 +144,29 @@ The model loads the old 3000-token checkpoint, `resize_token_embeddings()` adds
|
|
| 139 |
trains the full model. About 96% of token occurrences are now covered (vs 90%
|
| 140 |
with the old 3000-token vocabulary).
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
### Regenerate datasets from source
|
| 143 |
|
| 144 |
```bash
|
|
@@ -178,6 +206,7 @@ the full training pipeline. Checkpoints are saved to your Drive automatically.
|
|
| 178 |
- `model.safetensors`, `config.json`, `vocab.json`: default fine-tuned model
|
| 179 |
- `train.py`, `dataset.py`, `tokenizer.py`, `model.py`: training pipeline
|
| 180 |
- `dmhy_dataset.py`, `mix_datasets.py`: weak-label export and dataset mixing
|
|
|
|
| 181 |
- `inference.py`: end-to-end filename parser CLI
|
| 182 |
- `export_onnx.py`: ONNX export for Android integration
|
| 183 |
- `exports/`: exported ONNX model and metadata
|
|
|
|
| 60 |
and individual bracket characters (`[`, `]`, `(`, `)`) are included in the new
|
| 61 |
vocabulary.
|
| 62 |
|
| 63 |
+
For character-token training, `datasets/AnimeName/vocab.char.json` is built
|
| 64 |
+
from the full `dmhy_weak_char.jsonl` export. The full DMHY weak dataset has
|
| 65 |
+
**6195 unique characters**, so the complete character vocab is only **6199**
|
| 66 |
+
entries including special tokens and reaches 100% token coverage.
|
| 67 |
+
|
| 68 |
## Evaluation
|
| 69 |
|
| 70 |
Balanced mixed-data A/B run (`50K` synthetic + `50K` DMHY weak labels, 1 epoch, batch size 128, seed 42):
|
|
|
|
| 144 |
trains the full model. About 96% of token occurrences are now covered (vs 90%
|
| 145 |
with the old 3000-token vocabulary).
|
| 146 |
|
| 147 |
+
### Character-token DMHY training
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
python convert_to_char_dataset.py \
|
| 151 |
+
--input datasets/AnimeName/dmhy_weak.jsonl \
|
| 152 |
+
--output datasets/AnimeName/dmhy_weak_char.jsonl \
|
| 153 |
+
--vocab-output datasets/AnimeName/vocab.char.json \
|
| 154 |
+
--manifest-output datasets/AnimeName/dmhy_weak_char.manifest.json
|
| 155 |
+
|
| 156 |
+
python train.py --tokenizer char \
|
| 157 |
+
--data-file datasets/AnimeName/dmhy_weak_char.jsonl \
|
| 158 |
+
--vocab-file datasets/AnimeName/vocab.char.json \
|
| 159 |
+
--save-dir checkpoints_char/dmhy-weak-char \
|
| 160 |
+
--epochs 1 --batch-size 64 \
|
| 161 |
+
--learning-rate 0.0003 --warmup-steps 300 \
|
| 162 |
+
--max-seq-length 128 --seed 42
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
The converter keeps source metadata and adds `tokenizer_variant`, source token
|
| 166 |
+
count, and character token count fields to each record. The char dataset's
|
| 167 |
+
p99 length is 107 characters, so `--max-seq-length 128` covers almost all rows
|
| 168 |
+
while leaving room for `[CLS]` and `[SEP]`.
|
| 169 |
+
|
| 170 |
### Regenerate datasets from source
|
| 171 |
|
| 172 |
```bash
|
|
|
|
| 206 |
- `model.safetensors`, `config.json`, `vocab.json`: default fine-tuned model
|
| 207 |
- `train.py`, `dataset.py`, `tokenizer.py`, `model.py`: training pipeline
|
| 208 |
- `dmhy_dataset.py`, `mix_datasets.py`: weak-label export and dataset mixing
|
| 209 |
+
- `convert_to_char_dataset.py`: full character-token projection for weak labels
|
| 210 |
- `inference.py`: end-to-end filename parser CLI
|
| 211 |
- `export_onnx.py`: ONNX export for Android integration
|
| 212 |
- `exports/`: exported ONNX model and metadata
|
colab_train.py
CHANGED
|
@@ -13,12 +13,12 @@ What it does:
|
|
| 13 |
- Mounts Google Drive (for persistent checkpoints)
|
| 14 |
- Clones AniFileBERT repo + AnimeName dataset submodule
|
| 15 |
- Installs PyTorch + Transformers dependencies
|
| 16 |
-
- Runs training:
|
| 17 |
- Saves final model to Drive
|
| 18 |
|
| 19 |
Output:
|
| 20 |
- Checkpoints saved to: MyDrive/AniFileBERT/checkpoints/
|
| 21 |
-
- Final model at: MyDrive/AniFileBERT/checkpoints/dmhy-
|
| 22 |
"""
|
| 23 |
|
| 24 |
import os
|
|
@@ -90,25 +90,26 @@ run("python -c 'import torch; print(f\"PyTorch {torch.__version__}, CUDA availab
|
|
| 90 |
print("\n" + "=" * 60)
|
| 91 |
print("STEP 5: Verify vocabulary")
|
| 92 |
print("=" * 60)
|
| 93 |
-
run("python -c 'import json; v=json.load(open(\"vocab.json\")); print(f\"
|
| 94 |
|
| 95 |
# ── 6. Run training ────────────────────────────────────────────
|
| 96 |
print("\n" + "=" * 60)
|
| 97 |
print("STEP 6: Train model")
|
| 98 |
print("=" * 60)
|
| 99 |
|
| 100 |
-
# The
|
| 101 |
-
#
|
| 102 |
-
SAVE_DIR = os.path.join(DRIVE_ROOT, "checkpoints", "dmhy-
|
| 103 |
|
| 104 |
run(
|
| 105 |
f"python train.py "
|
| 106 |
-
f"--
|
| 107 |
-
f"--
|
|
|
|
| 108 |
f"--save-dir {SAVE_DIR} "
|
| 109 |
-
f"--
|
| 110 |
-
f"--epochs 10 --batch-size 128 "
|
| 111 |
f"--learning-rate 0.0003 --warmup-steps 300 "
|
|
|
|
| 112 |
f"--seed 42 "
|
| 113 |
f"--no-shuffle"
|
| 114 |
)
|
|
|
|
| 13 |
- Mounts Google Drive (for persistent checkpoints)
|
| 14 |
- Clones AniFileBERT repo + AnimeName dataset submodule
|
| 15 |
- Installs PyTorch + Transformers dependencies
|
| 16 |
+
- Runs training: train a character-token model with the full DMHY vocab
|
| 17 |
- Saves final model to Drive
|
| 18 |
|
| 19 |
Output:
|
| 20 |
- Checkpoints saved to: MyDrive/AniFileBERT/checkpoints/
|
| 21 |
+
- Final model at: MyDrive/AniFileBERT/checkpoints/dmhy-weak-char/final/
|
| 22 |
"""
|
| 23 |
|
| 24 |
import os
|
|
|
|
| 90 |
print("\n" + "=" * 60)
|
| 91 |
print("STEP 5: Verify vocabulary")
|
| 92 |
print("=" * 60)
|
| 93 |
+
run("python -c 'import json; v=json.load(open(\"datasets/AnimeName/vocab.char.json\", encoding=\"utf-8\")); print(f\"Character vocab size: {len(v)} tokens\")'")
|
| 94 |
|
| 95 |
# ── 6. Run training ────────────────────────────────────────────
|
| 96 |
print("\n" + "=" * 60)
|
| 97 |
print("STEP 6: Train model")
|
| 98 |
print("=" * 60)
|
| 99 |
|
| 100 |
+
# The full DMHY character vocab is only 6199 tokens and covers every character
|
| 101 |
+
# occurrence in dmhy_weak_char.jsonl.
|
| 102 |
+
SAVE_DIR = os.path.join(DRIVE_ROOT, "checkpoints", "dmhy-weak-char")
|
| 103 |
|
| 104 |
run(
|
| 105 |
f"python train.py "
|
| 106 |
+
f"--tokenizer char "
|
| 107 |
+
f"--data-file datasets/AnimeName/dmhy_weak_char.jsonl "
|
| 108 |
+
f"--vocab-file datasets/AnimeName/vocab.char.json "
|
| 109 |
f"--save-dir {SAVE_DIR} "
|
| 110 |
+
f"--epochs 5 --batch-size 128 "
|
|
|
|
| 111 |
f"--learning-rate 0.0003 --warmup-steps 300 "
|
| 112 |
+
f"--max-seq-length 128 "
|
| 113 |
f"--seed 42 "
|
| 114 |
f"--no-shuffle"
|
| 115 |
)
|
convert_to_char_dataset.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert token-level anime filename JSONL datasets to character tokens.
|
| 2 |
+
|
| 3 |
+
Input records must contain parallel ``tokens`` and ``labels`` arrays. The
|
| 4 |
+
converter expands each original token into Unicode code points and projects BIO
|
| 5 |
+
labels onto the expanded sequence:
|
| 6 |
+
|
| 7 |
+
- ``B-X`` keeps ``B-X`` on the first character and uses ``I-X`` afterwards.
|
| 8 |
+
- ``I-X`` remains ``I-X`` on every character.
|
| 9 |
+
- ``O`` remains ``O`` on every character.
|
| 10 |
+
|
| 11 |
+
The script streams both input and output so it can process the full DMHY weak
|
| 12 |
+
dataset without loading hundreds of MB into memory.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
from collections import Counter
|
| 20 |
+
from datetime import datetime, timezone
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from statistics import mean
|
| 23 |
+
from typing import Iterable
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
SPECIAL_TOKENS = ("[PAD]", "[UNK]", "[CLS]", "[SEP]")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def projected_labels(token: str, label: str) -> tuple[list[str], list[str]]:
|
| 30 |
+
"""Return character tokens and projected BIO labels for one source token."""
|
| 31 |
+
chars = list(token)
|
| 32 |
+
if not chars:
|
| 33 |
+
return [], []
|
| 34 |
+
|
| 35 |
+
if label.startswith("B-"):
|
| 36 |
+
entity = label.split("-", 1)[1]
|
| 37 |
+
return chars, [label] + [f"I-{entity}"] * (len(chars) - 1)
|
| 38 |
+
if label.startswith("I-"):
|
| 39 |
+
return chars, [label] * len(chars)
|
| 40 |
+
return chars, [label] * len(chars)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def convert_record(record: dict) -> dict:
|
| 44 |
+
"""Convert one JSONL record while preserving non-token metadata."""
|
| 45 |
+
tokens = record["tokens"]
|
| 46 |
+
labels = record["labels"]
|
| 47 |
+
if len(tokens) != len(labels):
|
| 48 |
+
raise ValueError(
|
| 49 |
+
f"token/label length mismatch: {len(tokens)} tokens, {len(labels)} labels"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
char_tokens: list[str] = []
|
| 53 |
+
char_labels: list[str] = []
|
| 54 |
+
for token, label in zip(tokens, labels):
|
| 55 |
+
pieces, piece_labels = projected_labels(str(token), str(label))
|
| 56 |
+
char_tokens.extend(pieces)
|
| 57 |
+
char_labels.extend(piece_labels)
|
| 58 |
+
|
| 59 |
+
converted = dict(record)
|
| 60 |
+
converted["tokens"] = char_tokens
|
| 61 |
+
converted["labels"] = char_labels
|
| 62 |
+
converted["tokenizer_variant"] = "char"
|
| 63 |
+
converted["source_token_count"] = len(tokens)
|
| 64 |
+
converted["char_token_count"] = len(char_tokens)
|
| 65 |
+
return converted
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def iter_jsonl(path: Path) -> Iterable[dict]:
|
| 69 |
+
with path.open("r", encoding="utf-8") as handle:
|
| 70 |
+
for line_no, line in enumerate(handle, 1):
|
| 71 |
+
line = line.strip()
|
| 72 |
+
if not line:
|
| 73 |
+
continue
|
| 74 |
+
try:
|
| 75 |
+
yield json.loads(line)
|
| 76 |
+
except json.JSONDecodeError as exc:
|
| 77 |
+
raise ValueError(f"{path}:{line_no}: invalid JSON") from exc
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def build_vocab(counter: Counter[str], max_size: int | None = None) -> dict[str, int]:
|
| 81 |
+
"""Build a frequency-sorted vocab with fixed special-token IDs."""
|
| 82 |
+
vocab = {token: idx for idx, token in enumerate(SPECIAL_TOKENS)}
|
| 83 |
+
limit = None if max_size is None else max(max_size - len(vocab), 0)
|
| 84 |
+
for token, _count in counter.most_common(limit):
|
| 85 |
+
if token not in vocab:
|
| 86 |
+
vocab[token] = len(vocab)
|
| 87 |
+
return vocab
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def coverage(counter: Counter[str], vocab: dict[str, int]) -> float:
|
| 91 |
+
total = sum(counter.values())
|
| 92 |
+
if total == 0:
|
| 93 |
+
return 1.0
|
| 94 |
+
covered = sum(count for token, count in counter.items() if token in vocab)
|
| 95 |
+
return covered / total
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def percentile(values: list[int], pct: float) -> int:
|
| 99 |
+
if not values:
|
| 100 |
+
return 0
|
| 101 |
+
ordered = sorted(values)
|
| 102 |
+
index = min(len(ordered) - 1, round((pct / 100) * (len(ordered) - 1)))
|
| 103 |
+
return ordered[index]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def parse_args() -> argparse.Namespace:
|
| 107 |
+
parser = argparse.ArgumentParser(description="Convert JSONL token labels to character labels")
|
| 108 |
+
parser.add_argument("--input", required=True, help="Input token-level JSONL")
|
| 109 |
+
parser.add_argument("--output", required=True, help="Output character-level JSONL")
|
| 110 |
+
parser.add_argument("--vocab-output", required=True, help="Output vocab JSON")
|
| 111 |
+
parser.add_argument("--manifest-output", default=None, help="Output manifest JSON")
|
| 112 |
+
parser.add_argument("--max-vocab-size", type=int, default=None,
|
| 113 |
+
help="Optional vocab cap including special tokens")
|
| 114 |
+
parser.add_argument("--limit", type=int, default=None, help="Convert only the first N records")
|
| 115 |
+
parser.add_argument("--progress", type=int, default=50_000,
|
| 116 |
+
help="Print progress every N records")
|
| 117 |
+
return parser.parse_args()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def main() -> None:
|
| 121 |
+
args = parse_args()
|
| 122 |
+
input_path = Path(args.input)
|
| 123 |
+
output_path = Path(args.output)
|
| 124 |
+
vocab_path = Path(args.vocab_output)
|
| 125 |
+
manifest_path = (
|
| 126 |
+
Path(args.manifest_output)
|
| 127 |
+
if args.manifest_output
|
| 128 |
+
else output_path.with_suffix(".manifest.json")
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 132 |
+
vocab_path.parent.mkdir(parents=True, exist_ok=True)
|
| 133 |
+
manifest_path.parent.mkdir(parents=True, exist_ok=True)
|
| 134 |
+
|
| 135 |
+
char_counter: Counter[str] = Counter()
|
| 136 |
+
label_counter: Counter[str] = Counter()
|
| 137 |
+
row_count = 0
|
| 138 |
+
source_token_count = 0
|
| 139 |
+
char_token_count = 0
|
| 140 |
+
lengths: list[int] = []
|
| 141 |
+
examples: list[dict] = []
|
| 142 |
+
|
| 143 |
+
with output_path.open("w", encoding="utf-8", newline="\n") as out:
|
| 144 |
+
for record in iter_jsonl(input_path):
|
| 145 |
+
converted = convert_record(record)
|
| 146 |
+
out.write(json.dumps(converted, ensure_ascii=False, separators=(",", ":")) + "\n")
|
| 147 |
+
|
| 148 |
+
row_count += 1
|
| 149 |
+
source_token_count += converted["source_token_count"]
|
| 150 |
+
char_len = converted["char_token_count"]
|
| 151 |
+
char_token_count += char_len
|
| 152 |
+
lengths.append(char_len)
|
| 153 |
+
char_counter.update(converted["tokens"])
|
| 154 |
+
label_counter.update(converted["labels"])
|
| 155 |
+
if len(examples) < 5:
|
| 156 |
+
examples.append(converted)
|
| 157 |
+
|
| 158 |
+
if args.limit is not None and row_count >= args.limit:
|
| 159 |
+
break
|
| 160 |
+
if args.progress and row_count % args.progress == 0:
|
| 161 |
+
print(f"converted {row_count:,} rows; unique chars={len(char_counter):,}")
|
| 162 |
+
|
| 163 |
+
vocab = build_vocab(char_counter, args.max_vocab_size)
|
| 164 |
+
vocab_path.write_text(json.dumps(vocab, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
| 165 |
+
|
| 166 |
+
manifest = {
|
| 167 |
+
"created_at": datetime.now(timezone.utc).isoformat(),
|
| 168 |
+
"input": str(input_path),
|
| 169 |
+
"output": str(output_path),
|
| 170 |
+
"vocab_output": str(vocab_path),
|
| 171 |
+
"tokenizer_variant": "char",
|
| 172 |
+
"projection": {
|
| 173 |
+
"B-X": "first char keeps B-X; remaining chars become I-X",
|
| 174 |
+
"I-X": "all chars keep I-X",
|
| 175 |
+
"O": "all chars keep O",
|
| 176 |
+
},
|
| 177 |
+
"row_count": row_count,
|
| 178 |
+
"source_token_count": source_token_count,
|
| 179 |
+
"char_token_count": char_token_count,
|
| 180 |
+
"unique_char_count": len(char_counter),
|
| 181 |
+
"vocab_size": len(vocab),
|
| 182 |
+
"max_vocab_size": args.max_vocab_size,
|
| 183 |
+
"vocab_coverage": coverage(char_counter, vocab),
|
| 184 |
+
"label_counts": dict(label_counter),
|
| 185 |
+
"char_length": {
|
| 186 |
+
"min": min(lengths) if lengths else 0,
|
| 187 |
+
"mean": mean(lengths) if lengths else 0,
|
| 188 |
+
"p50": percentile(lengths, 50),
|
| 189 |
+
"p90": percentile(lengths, 90),
|
| 190 |
+
"p95": percentile(lengths, 95),
|
| 191 |
+
"p99": percentile(lengths, 99),
|
| 192 |
+
"max": max(lengths) if lengths else 0,
|
| 193 |
+
},
|
| 194 |
+
"examples": examples,
|
| 195 |
+
}
|
| 196 |
+
manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
| 197 |
+
print(json.dumps({k: v for k, v in manifest.items() if k != "examples"}, ensure_ascii=False, indent=2))
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
if __name__ == "__main__":
|
| 201 |
+
main()
|
datasets/AnimeName
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit 867350a1712e50cc71f5a9e81dd331ca46a7b1dd
|
train.py
CHANGED
|
@@ -82,6 +82,9 @@ def parse_args() -> argparse.Namespace:
|
|
| 82 |
help="Use only the first N samples for quick A/B smoke runs")
|
| 83 |
parser.add_argument("--rebuild-vocab", action="store_true",
|
| 84 |
help="Rebuild vocab from the selected data file before training")
|
|
|
|
|
|
|
|
|
|
| 85 |
parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
|
| 86 |
return parser.parse_args()
|
| 87 |
|
|
@@ -146,8 +149,9 @@ def main():
|
|
| 146 |
vocab_path = resolve_vocab_path(config.data_file, args.tokenizer, args.vocab_file)
|
| 147 |
tokenizer = create_tokenizer(args.tokenizer)
|
| 148 |
if args.rebuild_vocab or not os.path.isfile(vocab_path):
|
| 149 |
-
|
| 150 |
-
|
|
|
|
| 151 |
tokenizer = create_tokenizer(args.tokenizer, vocab_file=vocab_path)
|
| 152 |
print(f" Variant: {args.tokenizer}")
|
| 153 |
print(f" Vocab size: {tokenizer.vocab_size}")
|
|
@@ -171,8 +175,7 @@ def main():
|
|
| 171 |
total_params = print_model_summary(model)
|
| 172 |
|
| 173 |
if total_params >= 5_000_000:
|
| 174 |
-
print("WARNING: Model exceeds 5M
|
| 175 |
-
sys.exit(1)
|
| 176 |
|
| 177 |
split_idx = int(len(all_data) * config.train_split)
|
| 178 |
train_data = all_data[:split_idx]
|
|
@@ -206,6 +209,10 @@ def main():
|
|
| 206 |
print(f" Train samples: {len(train_dataset)}")
|
| 207 |
print(f" Eval samples: {len(eval_dataset)}")
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
# Training arguments
|
| 210 |
training_args = TrainingArguments(
|
| 211 |
output_dir=config.save_dir,
|
|
@@ -218,14 +225,14 @@ def main():
|
|
| 218 |
learning_rate=config.learning_rate,
|
| 219 |
weight_decay=config.weight_decay,
|
| 220 |
warmup_steps=config.warmup_steps,
|
| 221 |
-
use_cpu=
|
| 222 |
report_to="none",
|
| 223 |
save_total_limit=2,
|
| 224 |
load_best_model_at_end=True,
|
| 225 |
metric_for_best_model="f1",
|
| 226 |
greater_is_better=True,
|
| 227 |
dataloader_num_workers=config.num_workers,
|
| 228 |
-
fp16=
|
| 229 |
)
|
| 230 |
|
| 231 |
# Data collator
|
|
|
|
| 82 |
help="Use only the first N samples for quick A/B smoke runs")
|
| 83 |
parser.add_argument("--rebuild-vocab", action="store_true",
|
| 84 |
help="Rebuild vocab from the selected data file before training")
|
| 85 |
+
parser.add_argument("--max-vocab-size", type=int, default=None,
|
| 86 |
+
help="Optional vocab cap used with --rebuild-vocab")
|
| 87 |
+
parser.add_argument("--cpu", action="store_true", help="Force CPU training")
|
| 88 |
parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
|
| 89 |
return parser.parse_args()
|
| 90 |
|
|
|
|
| 149 |
vocab_path = resolve_vocab_path(config.data_file, args.tokenizer, args.vocab_file)
|
| 150 |
tokenizer = create_tokenizer(args.tokenizer)
|
| 151 |
if args.rebuild_vocab or not os.path.isfile(vocab_path):
|
| 152 |
+
max_vocab_size = args.max_vocab_size if args.max_vocab_size is not None else config.vocab_size
|
| 153 |
+
print(f" Building {args.tokenizer} vocab: {vocab_path} (max_size={max_vocab_size})")
|
| 154 |
+
build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=max_vocab_size)
|
| 155 |
tokenizer = create_tokenizer(args.tokenizer, vocab_file=vocab_path)
|
| 156 |
print(f" Variant: {args.tokenizer}")
|
| 157 |
print(f" Vocab size: {tokenizer.vocab_size}")
|
|
|
|
| 175 |
total_params = print_model_summary(model)
|
| 176 |
|
| 177 |
if total_params >= 5_000_000:
|
| 178 |
+
print("WARNING: Model exceeds the historical 5M target; continuing because vocab size is configurable.")
|
|
|
|
| 179 |
|
| 180 |
split_idx = int(len(all_data) * config.train_split)
|
| 181 |
train_data = all_data[:split_idx]
|
|
|
|
| 209 |
print(f" Train samples: {len(train_dataset)}")
|
| 210 |
print(f" Eval samples: {len(eval_dataset)}")
|
| 211 |
|
| 212 |
+
use_cpu = args.cpu or not torch.cuda.is_available()
|
| 213 |
+
use_fp16 = not use_cpu
|
| 214 |
+
print(f" Device: {'CPU' if use_cpu else 'CUDA'}")
|
| 215 |
+
|
| 216 |
# Training arguments
|
| 217 |
training_args = TrainingArguments(
|
| 218 |
output_dir=config.save_dir,
|
|
|
|
| 225 |
learning_rate=config.learning_rate,
|
| 226 |
weight_decay=config.weight_decay,
|
| 227 |
warmup_steps=config.warmup_steps,
|
| 228 |
+
use_cpu=use_cpu,
|
| 229 |
report_to="none",
|
| 230 |
save_total_limit=2,
|
| 231 |
load_best_model_at_end=True,
|
| 232 |
metric_for_best_model="f1",
|
| 233 |
greater_is_better=True,
|
| 234 |
dataloader_num_workers=config.num_workers,
|
| 235 |
+
fp16=use_fp16,
|
| 236 |
)
|
| 237 |
|
| 238 |
# Data collator
|