Token Classification
Transformers
ONNX
Safetensors
English
Japanese
Chinese
bert
anime
filename-parsing
Eval Results (legacy)
Instructions to use ModerRAS/AniFileBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ModerRAS/AniFileBERT with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="ModerRAS/AniFileBERT")

# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("ModerRAS/AniFileBERT")
model = AutoModelForTokenClassification.from_pretrained("ModerRAS/AniFileBERT")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Train AniFileBERT for structured anime filename parsing. | |
| The training loop keeps the existing PyTorch/Transformers stack, writes | |
| Hugging Face checkpoints, records token/entity metrics, and also evaluates | |
| end-to-end parser exact-match on held-out filenames and fixed real-world cases. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import argparse | |
| import random | |
| import subprocess | |
| import threading | |
| import time | |
| import gc | |
| from collections import Counter | |
| from ctypes import POINTER, Structure, byref, c_int, c_uint, c_ulonglong, c_void_p, cdll | |
| from ctypes import util as ctypes_util | |
| from typing import Dict, List, Optional, Sequence | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import SequentialSampler | |
| from transformers import ( | |
| Trainer, | |
| TrainingArguments, | |
| BertForTokenClassification, | |
| TrainerCallback, | |
| ) | |
| from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score | |
| from .config import Config | |
| from .tokenizer import AnimeTokenizer, create_tokenizer, load_tokenizer | |
| from .model import create_model, print_model_summary, count_parameters | |
| from .dataset import AnimeItemsDataset, EncodedAnimeDataset, labels_for_tokenizer | |
| from .inference import parse_filename, postprocess | |
| from .virtual_dataset import DatasetRangeView, ShardedEncodedDataset | |
def compute_metrics(p):
    """Compute token-level and entity-level metrics using seqeval.

    Args:
        p: A ``(predictions, label_ids)`` pair such as a Hugging Face
            ``EvalPrediction``. ``predictions`` holds per-token scores that are
            arg-maxed over axis 2; ``label_ids`` uses ``-100`` for positions
            that must be ignored (special tokens / padding).

    Returns:
        Dict with seqeval ``precision``, ``recall``, ``f1`` (entity-level) and
        ``accuracy`` (token-level).
    """
    # Index rather than tuple-unpack so an EvalPrediction that also carries
    # inputs (three elements) is still accepted.
    predictions, label_ids = p[0], p[1]
    # If the model returned extra outputs, predictions is a tuple; the first
    # element is presumably the logits.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_ids = np.argmax(predictions, axis=2)
    id2label = Config().id2label

    true_predictions = []
    true_labels = []
    for pred_seq, label_seq in zip(predicted_ids, label_ids):
        preds = []
        lbls = []
        for pred_id, label_id in zip(pred_seq, label_seq):
            # -100 marks ignored positions (special tokens).
            if label_id == -100:
                continue
            # int() so numpy integer scalars index id2label like plain ints.
            preds.append(id2label[int(pred_id)])
            lbls.append(id2label[int(label_id)])
        true_predictions.append(preds)
        true_labels.append(lbls)

    # Entity-level metrics (via seqeval); accuracy is per-token.
    return {
        "precision": precision_score(true_labels, true_predictions),
        "recall": recall_score(true_labels, true_predictions),
        "f1": f1_score(true_labels, true_predictions),
        "accuracy": accuracy_score(true_labels, true_predictions),
    }
def parse_args() -> argparse.Namespace:
    """Define the training command-line interface and parse ``sys.argv``.

    Many options default to None; per their help text they are then resolved
    elsewhere (e.g. from dataset metadata or config values), not here.
    ``--tf32`` and ``--tensorboard`` default to enabled via ``set_defaults`` and
    are switched off with their ``--no-*`` counterparts.

    Returns:
        The parsed argument namespace.
    """
    parser = argparse.ArgumentParser(description="Train anime filename parser")
    # --- Data sources and tokenizer selection ---
    parser.add_argument("--tokenizer", choices=["regex", "char"], default=None,
                        help="Tokenizer variant for A/B testing. Defaults to dataset metadata")
    parser.add_argument("--data-file", default=None, help="Primary training JSONL file")
    parser.add_argument("--extra-data-file", action="append", default=[],
                        help="Additional training JSONL file. Can be passed multiple times.")
    parser.add_argument("--extra-data-repeat", type=int, default=1,
                        help="Repeat each extra dataset this many times after loading")
    parser.add_argument("--virtual-dataset-dir", default=None,
                        help="Pre-encoded shard directory generated by tools/virtual_dataset_generator")
    parser.add_argument("--vocab-file", default=None,
                        help="Tokenizer vocab JSON. Defaults to data/vocab.json or data/vocab.char.json")
    parser.add_argument("--save-dir", default=None, help="Checkpoint output directory")
    parser.add_argument("--init-model-dir", default=None, help="Optional checkpoint to fine-tune from")
    # --- Optimisation schedule ---
    parser.add_argument("--epochs", type=float, default=None, help="Number of training epochs")
    parser.add_argument("--max-steps", type=int, default=-1,
                        help="Override epoch-based training and stop after this many optimizer steps")
    parser.add_argument("--batch-size", type=int, default=None, help="Per-device train/eval batch size")
    parser.add_argument("--learning-rate", type=float, default=None, help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=None, help="Warmup steps")
    parser.add_argument("--train-split", type=float, default=None, help="Train split ratio")
    parser.add_argument("--max-seq-length", type=int, default=None, help="Maximum sequence length")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--limit-samples", type=int, default=None,
                        help="Use only the first N samples for quick A/B smoke runs")
    # --- In-memory data augmentation ---
    parser.add_argument("--augment-partial-samples", type=int, default=0,
                        help="Generate this many partial BIO-span samples in memory before training")
    parser.add_argument("--augment-permutation-samples", type=int, default=0,
                        help="Generate this many random BIO-span permutation samples in memory before training")
    parser.add_argument("--augment-special-samples", type=int, default=0,
                        help="Generate this many special-only/title+special samples such as Menu01 in memory")
    parser.add_argument("--augment-max-chars", type=int, default=160,
                        help="Maximum character length for generated augmentation samples")
    # --- Vocabulary ---
    parser.add_argument("--rebuild-vocab", action="store_true",
                        help="Rebuild vocab from the selected data file before training")
    parser.add_argument("--max-vocab-size", type=int, default=None,
                        help="Optional vocab cap used with --rebuild-vocab")
    # --- Checkpointing / evaluation cadence ---
    parser.add_argument("--checkpoint-steps", type=int, default=None,
                        help="Save resumable checkpoints every N steps instead of only at epoch end")
    parser.add_argument("--save-total-limit", type=int, default=2,
                        help="Maximum number of checkpoints to keep")
    parser.add_argument("--no-periodic-eval", action="store_true",
                        help="Skip Trainer's scheduled train-time eval/load-best-model; final evaluation still runs")
    # --- Data loading / memory ---
    parser.add_argument("--keep-raw-dataset", action="store_true",
                        help="Keep raw JSONL dictionaries in memory after encoded datasets are built")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1,
                        help="Accumulate gradients across this many steps")
    parser.add_argument("--num-workers", type=int, default=None,
                        help="DataLoader worker count. Defaults to config.num_workers")
    parser.add_argument("--prefetch-factor", type=int, default=None,
                        help="DataLoader prefetch factor when workers are enabled")
    parser.add_argument("--persistent-workers", action="store_true",
                        help="Keep DataLoader workers alive between epochs")
    parser.add_argument("--lazy-dataset", action="store_true",
                        help="Tokenize samples lazily in DataLoader workers instead of pre-encoding tensors")
    parser.add_argument("--apply-label-repairs", action="store_true",
                        help="Apply runtime deterministic label repairs while building training tensors")
    parser.add_argument("--encoded-dataset-device", choices=["cpu", "cuda"], default="cpu",
                        help="Store pre-encoded dataset tensors on this device; cuda requires --num-workers 0")
    # --- Numerics / performance ---
    parser.add_argument("--bf16", action="store_true",
                        help="Use bfloat16 mixed precision on CUDA instead of fp16")
    parser.add_argument("--no-mixed-precision", action="store_true",
                        help="Disable fp16/bf16 mixed precision even when CUDA is available")
    # --tf32 / --no-tf32 share dest="tf32"; its default is set below.
    parser.add_argument("--tf32", dest="tf32", action="store_true",
                        help="Enable TF32 matmul/cudnn kernels on CUDA")
    parser.add_argument("--no-tf32", dest="tf32", action="store_false",
                        help="Disable TF32 matmul/cudnn kernels")
    parser.add_argument("--torch-compile", action="store_true",
                        help="Enable torch.compile through Hugging Face Trainer")
    parser.add_argument("--auto-find-batch-size", action="store_true",
                        help="Let Trainer reduce batch size automatically on CUDA OOM")
    parser.add_argument("--perf-log-steps", type=int, default=100,
                        help="Sample training throughput, memory, and GPU stats every N steps; 0 disables")
    parser.add_argument("--perf-sample-interval", type=float, default=1.0,
                        help="Background NVML sampling interval in seconds during training; 0 disables")
    parser.add_argument("--cpu", action="store_true", help="Force CPU training")
    parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
    parser.add_argument("--resume-from-checkpoint", default=None,
                        help="Resume Trainer state from a checkpoint directory, or 'auto' for the latest checkpoint")
    # --- Logging / experiment tracking ---
    # --tensorboard / --no-tensorboard share dest="tensorboard"; its default is set below.
    parser.add_argument("--tensorboard", dest="tensorboard", action="store_true",
                        help="Log metrics to TensorBoard in addition to stdout/checkpoints")
    parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false",
                        help="Disable TensorBoard logging")
    parser.add_argument("--experiment-name", default=None,
                        help="Optional experiment name written to run_metadata.json")
    # --- Post-training parser evaluation ---
    parser.add_argument("--parse-eval-limit", type=int, default=512,
                        help="Run field exact-match evaluation on up to N eval samples after training; 0 disables it")
    parser.add_argument("--case-eval-file", default=os.path.join("data", "parser_regression_cases.json"),
                        help="Fixed real-world parser regression case file evaluated after training")
    parser.add_argument("--case-eval-output", default=None,
                        help="Optional output path for fixed case metrics; defaults to final/case_metrics.json")
    parser.add_argument("--no-case-eval", action="store_true",
                        help="Skip fixed real-world parser regression evaluation")
    # --- Model architecture overrides ---
    parser.add_argument("--hidden-size", type=int, default=None, help="Override BERT hidden size")
    parser.add_argument("--num-hidden-layers", type=int, default=None, help="Override BERT layer count")
    parser.add_argument("--num-attention-heads", type=int, default=None, help="Override BERT attention heads")
    parser.add_argument("--intermediate-size", type=int, default=None, help="Override BERT FFN intermediate size")
    # Paired store_true/store_false flags: enabled unless the --no-* form is given.
    parser.set_defaults(tf32=True)
    parser.set_defaults(tensorboard=True)
    return parser.parse_args()
def detect_tokenizer_variant(
    data_file: str,
    explicit_variant: Optional[str],
    explicit_vocab_path: Optional[str],
    sample_size: int = 256,
) -> str:
    """Infer the tokenizer variant ("regex" or "char") a JSONL dataset was built for.

    Resolution order:
      1. ``explicit_variant`` when truthy (CLI override).
      2. The single ``tokenizer_variant`` tag found in the first ``sample_size``
         non-blank rows; several distinct tags raise ``ValueError``.
      3. ``"char"`` if the vocab filename contains ``.char`` (case-insensitive).
      4. ``"char"`` if at least 95% of inspected rows satisfy
         ``tokens == list(filename)``.
      5. ``"regex"`` otherwise.
    """
    if explicit_variant:
        return explicit_variant

    seen_variants = set()
    char_rows = 0
    row_count = 0
    with open(data_file, "r", encoding="utf-8") as handle:
        for raw in handle:
            if row_count >= sample_size:
                break
            raw = raw.strip()
            if not raw:
                continue
            row = json.loads(raw)
            row_count += 1
            tagged = row.get("tokenizer_variant")
            if tagged:
                seen_variants.add(tagged)
            name = row.get("filename")
            # A char-tokenized row has exactly one token per filename character.
            if name is not None and row.get("tokens", []) == list(name):
                char_rows += 1

    if len(seen_variants) > 1:
        raise ValueError(f"Mixed tokenizer_variant values in {data_file}: {sorted(seen_variants)}")
    if seen_variants:
        return seen_variants.pop()

    vocab_name = os.path.basename(explicit_vocab_path).lower() if explicit_vocab_path else ""
    if ".char" in vocab_name:
        return "char"
    if row_count and char_rows / row_count >= 0.95:
        return "char"
    return "regex"
def detect_tokenizer_variant_from_files(
    data_files: List[str],
    explicit_variant: Optional[str],
    explicit_vocab_path: Optional[str],
) -> str:
    """Infer a single tokenizer variant shared by every dataset in ``data_files``.

    Args:
        data_files: JSONL dataset paths to inspect.
        explicit_variant: CLI override; returned as-is when truthy.
        explicit_vocab_path: Vocab path hint forwarded to per-file detection.

    Returns:
        The common tokenizer variant.

    Raises:
        ValueError: if no data files are given (and no explicit variant), or
            if the datasets resolve to different variants.
    """
    if explicit_variant:
        return explicit_variant
    if not data_files:
        # Without this guard next(iter(...)) on an empty set would leak a bare
        # StopIteration instead of an actionable error.
        raise ValueError("No data files given; cannot infer tokenizer variant")
    variants = {
        detect_tokenizer_variant(path, None, explicit_vocab_path)
        for path in data_files
    }
    if len(variants) > 1:
        raise ValueError(f"Mixed tokenizer variants across datasets: {sorted(variants)}")
    return next(iter(variants))
def resolve_vocab_path(data_file: str, tokenizer_variant: str, explicit_path: Optional[str]) -> str:
    """Pick the vocab file: the explicit path if truthy, else one next to ``data_file``.

    The implicit vocab lives in ``data_file``'s directory and is ``vocab.json``
    for the ``"regex"`` tokenizer and ``vocab.char.json`` for any other variant.
    """
    if explicit_path:
        return explicit_path
    data_dir = os.path.dirname(data_file)
    if tokenizer_variant == "regex":
        return os.path.join(data_dir, "vocab.json")
    return os.path.join(data_dir, "vocab.char.json")
def latest_checkpoint(save_dir: str) -> Optional[str]:
    """Return the path of the highest-step ``checkpoint-<step>`` subdirectory of ``save_dir``.

    Entries that are not directories, or whose text after the last ``-`` is
    not an integer, are ignored. Equal steps are tie-broken by the larger path
    string. Returns None when ``save_dir`` does not exist or holds no usable
    checkpoint.
    """
    if not os.path.isdir(save_dir):
        return None

    def step_of(entry: str) -> Optional[int]:
        try:
            return int(entry.rsplit("-", 1)[-1])
        except ValueError:
            return None

    candidates = []
    for entry in os.listdir(save_dir):
        if not entry.startswith("checkpoint-"):
            continue
        full_path = os.path.join(save_dir, entry)
        if not os.path.isdir(full_path):
            continue
        step = step_of(entry)
        if step is not None:
            candidates.append((step, full_path))

    best = max(candidates, default=None)
    return None if best is None else best[1]
def validate_dataset_tokenizer_metadata(data: List[Dict], tokenizer_variant: str) -> None:
    """Check that every row tagged with ``tokenizer_variant`` matches the selected tokenizer.

    Rows whose ``tokenizer_variant`` is missing or falsy are ignored.

    Raises:
        ValueError: if any tagged row names a different variant.
    """
    tagged = set()
    for row in data:
        value = row.get("tokenizer_variant")
        if value:
            tagged.add(value)
    if not tagged or tagged == {tokenizer_variant}:
        return
    raise ValueError(
        f"Dataset tokenizer_variant {sorted(tagged)} does not match selected tokenizer "
        f"'{tokenizer_variant}'. Pass --tokenizer explicitly only when this is intentional."
    )
def load_jsonl(data_file: str, limit: Optional[int] = None) -> List[Dict]:
    """Load non-blank JSONL rows, stopping early for smoke runs.

    Args:
        data_file: Path to a UTF-8 JSONL file (one JSON value per line).
        limit: Maximum number of rows to return; None means no limit. A limit
            of 0 or less yields an empty list.

    Returns:
        Parsed rows in file order.
    """
    data: List[Dict] = []
    with open(data_file, "r", encoding="utf-8") as f:
        for line in f:
            # Check the limit before parsing, so limit <= 0 returns no rows
            # (checking only after append would always let one row through).
            if limit is not None and len(data) >= limit:
                break
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
    return data
def load_training_sources(
    primary_data_file: str,
    extra_data_files: List[str],
    extra_repeat: int,
    limit: Optional[int] = None,
) -> tuple[List[Dict], List[Dict]]:
    """Load the primary dataset plus extra datasets and describe where rows came from.

    ``limit`` applies to the primary file only. Each extra file is loaded in
    full and appended ``max(1, extra_repeat)`` times; repeats reuse the same row
    dict objects rather than copies.

    Returns:
        ``(rows, sources)``: the concatenated rows, and one metadata dict per
        input file with ``role``, ``path``, ``samples``, ``repeat`` and
        ``effective_samples``.
    """
    primary_rows = load_jsonl(primary_data_file, limit)
    combined: List[Dict] = list(primary_rows)
    source_info: List[Dict] = [
        {
            "role": "primary",
            "path": primary_data_file,
            "samples": len(primary_rows),
            "repeat": 1,
            "effective_samples": len(primary_rows),
        }
    ]

    copies = max(1, extra_repeat)
    for extra_path in extra_data_files:
        extra_rows = load_jsonl(extra_path, None)
        combined.extend(extra_rows * copies)
        source_info.append(
            {
                "role": "extra",
                "path": extra_path,
                "samples": len(extra_rows),
                "repeat": copies,
                "effective_samples": len(extra_rows) * copies,
            }
        )
    return combined, source_info
def extract_entities_from_labels(tokens: Sequence[str], labels: Sequence[str]) -> Dict[str, List[str]]:
    """Group BIO-labelled tokens into entity text spans.

    ``B-X`` opens a new ``X`` span; ``I-X`` extends the open span only when it
    is also ``X``. Any other label (``O`` or a mismatched ``I-``) closes the open
    span. Span text is the concatenation of ``str(token)``; iteration stops at
    the shorter of ``tokens`` and ``labels``.

    Returns:
        Mapping from entity type to its span texts, in order of appearance.
    """
    spans: Dict[str, List[str]] = {}
    current_type: Optional[str] = None
    buffer: List[str] = []

    def flush() -> None:
        if current_type and buffer:
            spans.setdefault(current_type, []).append("".join(buffer))

    for token, label in zip(tokens, labels):
        if label.startswith("I-") and current_type == label[2:]:
            buffer.append(str(token))
            continue
        flush()
        if label.startswith("B-"):
            current_type = label[2:]
            buffer = [str(token)]
        else:
            current_type = None
            buffer = []
    flush()
    return spans
def char_item_from_spans(filename: str, spans: Sequence[tuple[str, str]], source: str) -> Optional[Dict]:
    """Build a char-tokenized BIO training item by locating ordered ``(text, entity)`` spans.

    ``filename`` is stripped first; an empty result yields None. Each span is
    searched for after the end of the previous match, falling back to a search
    from the start of the string; if it is not found at all the item is
    rejected (None). Empty span texts are skipped. An overlapping later match
    overwrites earlier labels.
    """
    cleaned = filename.strip()
    if not cleaned:
        return None

    chars = list(cleaned)
    bio = ["O"] * len(chars)
    position = 0
    for span_text, entity in spans:
        if not span_text:
            continue
        begin = cleaned.find(span_text, position)
        if begin < 0:
            begin = cleaned.find(span_text)
            if begin < 0:
                return None
        finish = begin + len(span_text)
        bio[begin:finish] = [f"B-{entity}"] + [f"I-{entity}"] * (len(span_text) - 1)
        position = finish

    return {
        "filename": cleaned,
        "tokens": chars,
        "labels": bio,
        "tokenizer_variant": "char",
        "source": source,
    }
def entity_keep_probability(entity: str) -> float:
    """Return the probability that ``entity`` is kept when building a permuted sample.

    Entity types not listed below default to 0.5.
    """
    keep_rates = {
        "GROUP": 0.35,
        "SEASON": 0.35,
        "SPECIAL": 0.3,
        "TITLE": 0.65,
        "RESOLUTION": 0.65,
        "SOURCE": 0.65,
        "EPISODE": 0.7,
    }
    if entity in keep_rates:
        return keep_rates[entity]
    return 0.5
def build_partial_augmented_item(item: Dict, max_chars: int) -> List[Dict]:
    """Derive short char-level samples containing only a subset of ``item``'s entities.

    The first non-blank (stripped) TITLE, SEASON, EPISODE, SPECIAL, RESOLUTION
    and SOURCE spans of ``item`` fill fixed templates such as
    ``"<title> <season>"`` or ``"<episode> [<resolution>][<source>]"``.
    Templates whose text length is outside ``[2, max_chars]``, or whose spans
    cannot be located, are dropped. GROUP spans are not used.
    """
    found = extract_entities_from_labels(item.get("tokens", []), item.get("labels", []))

    def first_value(entity: str) -> Optional[str]:
        for candidate in found.get(entity, []):
            stripped = candidate.strip()
            if stripped:
                return stripped
        return None

    title = first_value("TITLE")
    season = first_value("SEASON")
    episode = first_value("EPISODE")
    special = first_value("SPECIAL")
    resolution = first_value("RESOLUTION")
    source = first_value("SOURCE")

    candidates: List[tuple[str, List[tuple[str, str]]]] = []
    if title:
        candidates.append((title, [(title, "TITLE")]))
        if season:
            candidates.append((f"{title} {season}", [(title, "TITLE"), (season, "SEASON")]))
    if episode:
        candidates.append((episode, [(episode, "EPISODE")]))
        if resolution:
            candidates.append((f"{episode} [{resolution}]", [(episode, "EPISODE"), (resolution, "RESOLUTION")]))
            if source:
                candidates.append(
                    (
                        f"{episode} [{resolution}][{source}]",
                        [(episode, "EPISODE"), (resolution, "RESOLUTION"), (source, "SOURCE")],
                    )
                )
    if special:
        candidates.append((special, [(special, "SPECIAL")]))
        if title:
            candidates.append((f"{title} - {special}", [(title, "TITLE"), (special, "SPECIAL")]))

    results: List[Dict] = []
    for text, spans in candidates:
        if not 2 <= len(text) <= max_chars:
            continue
        built = char_item_from_spans(text, spans, "train_partial_augmentation")
        if built is not None:
            results.append(built)
    return results
def build_permutation_augmented_item(item: Dict, rng: random.Random, max_chars: int) -> Optional[Dict]:
    """Create one char-level sample from a random, shuffled subset of ``item``'s entities.

    Each entity type present in ``item`` is kept with probability
    ``entity_keep_probability(entity)``. The kept types are shuffled, each
    contributes one randomly chosen (stripped) span value, and the values are
    joined with one randomly chosen separator.

    All randomness comes from ``rng`` and the order of its calls is part of the
    behaviour, so results are reproducible for a given seed and call sequence.

    Args:
        item: Source row with ``tokens`` and BIO ``labels``.
        rng: Random source.
        max_chars: Upper bound on the generated text length.

    Returns:
        The generated item, or None if ``item`` has no entities, the joined text
        length is outside ``[2, max_chars]``, or a span cannot be located.
    """
    entities = extract_entities_from_labels(item.get("tokens", []), item.get("labels", []))
    available = [
        entity
        for entity in ("GROUP", "TITLE", "SEASON", "EPISODE", "SPECIAL", "RESOLUTION", "SOURCE")
        if entities.get(entity)
    ]
    if not available:
        return None
    selected = [
        entity
        for entity in available
        if rng.random() < entity_keep_probability(entity)
    ]
    # Guarantee at least one entity survives.
    if not selected:
        selected = [rng.choice(available)]
    # If no TITLE/EPISODE/SPECIAL was kept, add one more entity: a random
    # not-yet-selected one, or (if all are selected) any available one.
    # NOTE(review): the added entity need not be TITLE/EPISODE/SPECIAL, and the
    # `available` fallback can duplicate an already-selected entity — confirm
    # this is intended.
    if "TITLE" not in selected and "EPISODE" not in selected and "SPECIAL" not in selected:
        extras = [entity for entity in available if entity not in selected]
        selected.append(rng.choice(extras or available))
    rng.shuffle(selected)
    separators = [" ", " - ", ".", "_", "]["]
    sep = rng.choice(separators)
    parts: List[str] = []
    spans: List[tuple[str, str]] = []
    for entity in selected:
        values = [value.strip() for value in entities.get(entity, []) if value.strip()]
        if not values:
            continue
        value = rng.choice(values)
        # Optionally bracket non-title/season values; the brackets are not part
        # of the span, so they end up labelled "O".
        if entity in {"GROUP", "EPISODE", "SPECIAL", "RESOLUTION", "SOURCE"} and rng.random() < 0.35:
            parts.append(f"[{value}]")
        else:
            parts.append(value)
        spans.append((value, entity))
    text = sep.join(parts).strip()
    if not (2 <= len(text) <= max_chars):
        return None
    return char_item_from_spans(text, spans, "train_permutation_augmentation")
def build_special_augmented_item(data: List[Dict], rng: random.Random, max_chars: int) -> Optional[Dict]:
    """Create one synthetic SPECIAL-only or ``"<title> - <special>"`` char-level sample.

    Up to 16 rows of ``data`` are drawn at random (with replacement), and their
    TITLE spans of 2-80 stripped characters are collected. When any exist, one
    is used as the title with probability 0.55. The SPECIAL text is drawn from a
    fixed set of Menu/OP/ED/NCOP/NCED/CM/PV patterns with random numbers.

    All randomness comes from ``rng`` and the order of its calls is part of the
    behaviour, so results are reproducible for a given seed and call sequence.

    Returns:
        The generated item, or None if the text exceeds ``max_chars`` or a span
        cannot be located.
    """
    base_titles: List[str] = []
    for _ in range(min(16, len(data))):
        item = data[rng.randrange(len(data))]
        entities = extract_entities_from_labels(item.get("tokens", []), item.get("labels", []))
        base_titles.extend(value.strip() for value in entities.get("TITLE", []) if 2 <= len(value.strip()) <= 80)
    title = rng.choice(base_titles) if base_titles else None
    # Every candidate below is built (consuming its randint draws) before
    # rng.choice picks one, so the RNG sequence does not depend on the pick.
    special = rng.choice(
        [
            f"Menu{rng.randint(1, 24):02d}",
            f"Menu {rng.randint(1, 24):02d}",
            f"BDMenu{rng.randint(1, 24):02d}",
            f"BD Menu{rng.randint(1, 24):02d}",
            f"Menu{rng.randint(1, 24):02d}-01",
            "Menu",
            f"OP{rng.randint(1, 6):02d}",
            f"ED E{rng.randint(1, 24):02d}",
            f"NCOP{rng.randint(1, 6):02d}",
            f"NCED{rng.randint(1, 6):02d}",
            f"CM{rng.randint(1, 12):02d}",
            f"PV{rng.randint(1, 12):02d}",
        ]
    )
    # rng.random() is only consumed when a title is available (short-circuit).
    if title and rng.random() < 0.55:
        text = f"{title} - {special}"
        spans = [(title, "TITLE"), (special, "SPECIAL")]
    else:
        text = special
        spans = [(special, "SPECIAL")]
    if len(text) > max_chars:
        return None
    return char_item_from_spans(text, spans, "train_special_augmentation")
def _rss_mb_from_psutil() -> Optional[float]:
    """Return current RSS in MiB via psutil, or None if psutil is missing or fails."""
    try:
        import psutil  # type: ignore
        return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    except Exception:
        # psutil is optional; any failure just means "try the next source".
        return None


def _rss_mb_from_windows() -> Optional[float]:
    """Return the current working-set size in MiB via psapi on Windows, else None."""
    if os.name != "nt":
        return None
    try:
        import ctypes
        from ctypes import wintypes

        class PROCESS_MEMORY_COUNTERS(ctypes.Structure):
            _fields_ = [
                ("cb", wintypes.DWORD),
                ("PageFaultCount", wintypes.DWORD),
                ("PeakWorkingSetSize", ctypes.c_size_t),
                ("WorkingSetSize", ctypes.c_size_t),
                ("QuotaPeakPagedPoolUsage", ctypes.c_size_t),
                ("QuotaPagedPoolUsage", ctypes.c_size_t),
                ("QuotaPeakNonPagedPoolUsage", ctypes.c_size_t),
                ("QuotaNonPagedPoolUsage", ctypes.c_size_t),
                ("PagefileUsage", ctypes.c_size_t),
                ("PeakPagefileUsage", ctypes.c_size_t),
            ]

        counters = PROCESS_MEMORY_COUNTERS()
        counters.cb = ctypes.sizeof(counters)
        handle = ctypes.windll.kernel32.GetCurrentProcess()
        if ctypes.windll.psapi.GetProcessMemoryInfo(handle, ctypes.byref(counters), counters.cb):
            return float(counters.WorkingSetSize) / (1024 * 1024)
    except Exception:
        # Best effort: any ctypes/psapi failure falls through to the next source.
        pass
    return None


def _rss_mb_from_proc_statm() -> Optional[float]:
    """Return current RSS in MiB from Linux ``/proc/self/statm``, or None if unavailable."""
    try:
        with open("/proc/self/statm", "r", encoding="ascii") as statm:
            # Second field is the resident set size, measured in pages.
            resident_pages = int(statm.read().split()[1])
        page_size = os.sysconf("SC_PAGE_SIZE")
    except (OSError, ValueError, IndexError, AttributeError):
        return None
    return resident_pages * page_size / (1024 * 1024)


def _peak_rss_mb_from_resource() -> Optional[float]:
    """Return the *peak* RSS in MiB from ``getrusage``, or None if unavailable."""
    try:
        import resource  # type: ignore
        usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        # ru_maxrss is reported in bytes on macOS and in KiB on Linux.
        if sys.platform == "darwin":
            usage = usage / 1024 / 1024
        else:
            usage = usage / 1024
        return float(usage)
    except Exception:
        return None


def process_memory_mb() -> Optional[float]:
    """Return this process's resident memory in MiB, or None if it cannot be read.

    Sources are tried in order, each best-effort:
      1. psutil (current RSS), when installed;
      2. Windows psapi ``GetProcessMemoryInfo`` (current working set);
      3. Linux ``/proc/self/statm`` (current RSS);
      4. ``resource.getrusage`` ``ru_maxrss``. This is the *peak* RSS, so it is
         only a last resort where no current-RSS source is available.
    """
    for probe in (
        _rss_mb_from_psutil,
        _rss_mb_from_windows,
        _rss_mb_from_proc_statm,
        _peak_rss_mb_from_resource,
    ):
        value = probe()
        if value is not None:
            return value
    return None
class NvmlSampler:
    """Tiny NVML binding for runtime GPU telemetry without adding dependencies.

    The NVML shared library is loaded through ctypes and initialised, and a
    handle to GPU index 0 is opened. Any failure along the way (library not
    found, missing symbol, non-zero NVML return code) leaves the sampler
    unavailable, and ``sample()`` then returns an empty dict instead of raising.
    """

    # Sensor type passed to nvmlDeviceGetTemperature.
    NVML_TEMPERATURE_GPU = 0

    def __init__(self):
        self._lib = None  # loaded CDLL; set only when initialisation fully succeeds
        self._handle = None  # c_void_p device handle for GPU 0
        self._available = False
        self._init()

    def _candidate_names(self) -> List[str]:
        """Return NVML library names/paths to try loading, in order."""
        names = []
        found = ctypes_util.find_library("nvidia-ml")
        if found:
            names.append(found)
        if os.name == "nt":
            names.extend(
                [
                    os.path.join(os.environ.get("SystemRoot", r"C:\Windows"), "System32", "nvml.dll"),
                    "nvml.dll",
                    "nvidia-ml.dll",
                ]
            )
        else:
            names.extend(["libnvidia-ml.so.1", "libnvidia-ml.so"])
        return names

    def _init(self) -> None:
        """Load NVML, declare the ctypes signatures used, and open GPU 0.

        Returns silently (sampler unavailable) on any failure.
        """
        lib = None
        for name in self._candidate_names():
            try:
                lib = cdll.LoadLibrary(name)
                break
            except Exception:
                continue
        if lib is None:
            return

        class NVMLUtilization(Structure):
            _fields_ = [("gpu", c_uint), ("memory", c_uint)]

        class NVMLMemory(Structure):
            _fields_ = [("total", c_ulonglong), ("free", c_ulonglong), ("used", c_ulonglong)]

        self.NVMLUtilization = NVMLUtilization
        self.NVMLMemory = NVMLMemory

        handle = c_void_p()
        try:
            # Attribute access on a CDLL raises AttributeError for a missing
            # symbol, so signature setup must sit inside the try: a loadable
            # but old/unrelated library should disable sampling, not raise.
            lib.nvmlInit_v2.restype = c_int
            lib.nvmlDeviceGetHandleByIndex_v2.argtypes = [c_uint, POINTER(c_void_p)]
            lib.nvmlDeviceGetHandleByIndex_v2.restype = c_int
            if lib.nvmlInit_v2() != 0:
                return
            if lib.nvmlDeviceGetHandleByIndex_v2(0, byref(handle)) != 0:
                return
        except Exception:
            return

        # Declare query signatures once rather than on every sample() call.
        # A missing query symbol only disables that one metric.
        query_signatures = {
            "nvmlDeviceGetUtilizationRates": [c_void_p, POINTER(NVMLUtilization)],
            "nvmlDeviceGetMemoryInfo": [c_void_p, POINTER(NVMLMemory)],
            "nvmlDeviceGetTemperature": [c_void_p, c_uint, POINTER(c_uint)],
            "nvmlDeviceGetPowerUsage": [c_void_p, POINTER(c_uint)],
        }
        for func_name, argtypes in query_signatures.items():
            try:
                getattr(lib, func_name).argtypes = argtypes
            except AttributeError:
                pass

        self._lib = lib
        self._handle = handle
        self._available = True

    def available(self) -> bool:
        """Return whether NVML was initialised and GPU 0 was opened."""
        return self._available

    def _query(self, func_name: str, *args) -> bool:
        """Call NVML ``func_name(handle, *args)``; True only if it returned 0 (success)."""
        try:
            return getattr(self._lib, func_name)(self._handle, *args) == 0
        except Exception:
            return False

    def sample(self) -> Dict[str, Optional[float]]:
        """Read utilisation, memory, temperature and power for GPU 0.

        Metrics whose NVML call fails are omitted. Returns an empty dict when
        the sampler is unavailable.
        """
        if not self._available or self._lib is None or self._handle is None:
            return {}
        stats: Dict[str, Optional[float]] = {}

        util_rates = self.NVMLUtilization()
        if self._query("nvmlDeviceGetUtilizationRates", byref(util_rates)):
            stats["gpu_util_percent"] = float(util_rates.gpu)
            stats["gpu_memory_util_percent"] = float(util_rates.memory)

        memory = self.NVMLMemory()
        if self._query("nvmlDeviceGetMemoryInfo", byref(memory)):
            stats["gpu_memory_used_mb"] = float(memory.used) / (1024 * 1024)
            stats["gpu_memory_total_mb"] = float(memory.total) / (1024 * 1024)

        temperature = c_uint()
        if self._query("nvmlDeviceGetTemperature", self.NVML_TEMPERATURE_GPU, byref(temperature)):
            stats["gpu_temperature_c"] = float(temperature.value)

        # Power usage is read in milliwatts and reported in watts.
        power_mw = c_uint()
        if self._query("nvmlDeviceGetPowerUsage", byref(power_mw)):
            stats["gpu_power_w"] = float(power_mw.value) / 1000.0

        return stats
_NVML_SAMPLER: Optional[NvmlSampler] = None


def query_nvml() -> Dict[str, Optional[float]]:
    """Sample GPU stats through a lazily created, module-wide ``NvmlSampler``.

    The sampler is constructed on first call, cached in ``_NVML_SAMPLER``, and
    reused afterwards.
    """
    global _NVML_SAMPLER
    sampler = _NVML_SAMPLER
    if sampler is None:
        sampler = NvmlSampler()
        _NVML_SAMPLER = sampler
    return sampler.sample()
def query_nvidia_smi() -> Dict[str, Optional[float]]:
    """Read GPU stats by running ``nvidia-smi`` (2 second timeout).

    Returns an empty dict if the command cannot be run, exits non-zero, or
    prints nothing. Otherwise only the first output line is parsed, mapping
    ``gpu_util_percent``, ``gpu_memory_used_mb``, ``gpu_memory_total_mb`` and
    ``gpu_power_w`` to floats, or to None for non-numeric fields.
    """
    command = [
        "nvidia-smi",
        "--query-gpu=utilization.gpu,memory.used,memory.total,power.draw",
        "--format=csv,noheader,nounits",
    ]
    try:
        completed = subprocess.run(command, check=False, capture_output=True, text=True, timeout=2)
    except Exception:
        return {}

    output = completed.stdout.strip()
    if completed.returncode != 0 or not output:
        return {}

    fields = [piece.strip() for piece in output.splitlines()[0].split(",")]
    names = ("gpu_util_percent", "gpu_memory_used_mb", "gpu_memory_total_mb", "gpu_power_w")
    result: Dict[str, Optional[float]] = {}
    for name, raw in zip(names, fields):
        try:
            result[name] = float(raw)
        except ValueError:
            result[name] = None
    return result
def cuda_memory_stats_mb() -> Dict[str, float]:
    """Return PyTorch CUDA allocator counters in MiB, or an empty dict without CUDA."""
    if not torch.cuda.is_available():
        return {}
    mib = 1024 * 1024
    readers = (
        ("cuda_allocated_mb", torch.cuda.memory_allocated),
        ("cuda_reserved_mb", torch.cuda.memory_reserved),
        ("cuda_max_allocated_mb", torch.cuda.max_memory_allocated),
        ("cuda_max_reserved_mb", torch.cuda.max_memory_reserved),
    )
    return {key: read() / mib for key, read in readers}
def snapshot_perf_stats() -> Dict[str, Optional[float]]:
    """Collect one point-in-time telemetry sample.

    Combines process resident memory, PyTorch CUDA allocator stats, and GPU
    stats. GPU stats come from NVML when it yields anything, otherwise from
    ``nvidia-smi``.
    """
    stats: Dict[str, Optional[float]] = {"process_rss_mb": process_memory_mb()}
    stats.update(cuda_memory_stats_mb())
    stats.update(query_nvml() or query_nvidia_smi())
    return stats
| class TrainingPerfCallback(TrainerCallback): | |
| """Lightweight runtime telemetry for spotting data-pipeline starvation.""" | |
| def __init__(self, batch_size: int, sequence_length: int, log_steps: int, sample_interval: float): | |
| self.batch_size = batch_size | |
| self.sequence_length = sequence_length | |
| self.log_steps = max(0, log_steps) | |
| self.sample_interval = max(0.0, sample_interval) | |
| self.samples: List[Dict[str, Optional[float]]] = [] | |
| self.background_samples: List[Dict[str, Optional[float]]] = [] | |
| self._last_step = 0 | |
| self._last_time: Optional[float] = None | |
| self._start_time: Optional[float] = None | |
| self._training = False | |
| self._stop_event = threading.Event() | |
| self._thread: Optional[threading.Thread] = None | |
| def on_train_begin(self, args, state, control, **kwargs): | |
| now = time.perf_counter() | |
| self._start_time = now | |
| self._last_time = now | |
| self._last_step = int(state.global_step) | |
| self._training = True | |
| self._stop_event.clear() | |
| if self.sample_interval > 0: | |
| self._thread = threading.Thread(target=self._background_sample_loop, daemon=True) | |
| self._thread.start() | |
| def on_train_end(self, args, state, control, **kwargs): | |
| self._training = False | |
| self._stop_event.set() | |
| if self._thread is not None: | |
| self._thread.join(timeout=max(self.sample_interval * 2, 1.0)) | |
| self._thread = None | |
| def on_log(self, args, state, control, logs=None, **kwargs): | |
| if not self._training: | |
| return | |
| step = int(state.global_step) | |
| if self.log_steps <= 0 or step <= 0 or step % self.log_steps != 0: | |
| return | |
| self._record_sample(step) | |
| def on_step_end(self, args, state, control, **kwargs): | |
| if not self._training: | |
| return | |
| step = int(state.global_step) | |
| if self.log_steps <= 0 or step <= 0 or step % self.log_steps != 0: | |
| return | |
| if self.samples and self.samples[-1].get("step") == float(step): | |
| return | |
| self._record_sample(step) | |
| def _record_sample(self, step: int) -> None: | |
| if self.samples and self.samples[-1].get("step") == float(step): | |
| return | |
| now = time.perf_counter() | |
| last_time = self._last_time or now | |
| elapsed = max(now - last_time, 1e-9) | |
| step_delta = max(step - self._last_step, 0) | |
| samples_per_second = step_delta * self.batch_size / elapsed | |
| tokens_per_second = samples_per_second * self.sequence_length | |
| stats = snapshot_perf_stats() | |
| sample: Dict[str, Optional[float]] = { | |
| "step": float(step), | |
| "elapsed_seconds": now - (self._start_time or now), | |
| "window_seconds": elapsed, | |
| "steps_per_second": step_delta / elapsed, | |
| "samples_per_second": samples_per_second, | |
| "tokens_per_second": tokens_per_second, | |
| } | |
| sample.update(stats) | |
| self.samples.append(sample) | |
| print( | |
| " perf " | |
| f"step={step} " | |
| f"samples/s={samples_per_second:.1f} " | |
| f"tokens/s={tokens_per_second:.0f} " | |
| f"rss={stats.get('process_rss_mb') or 0:.0f}MB " | |
| f"cuda_alloc={stats.get('cuda_allocated_mb') or 0:.0f}MB " | |
| f"gpu_util={stats.get('gpu_util_percent') if stats.get('gpu_util_percent') is not None else 'n/a'}%" | |
| ) | |
| self._last_time = now | |
| self._last_step = step | |
| def _background_sample_loop(self) -> None: | |
| while not self._stop_event.wait(self.sample_interval): | |
| if not self._training: | |
| continue | |
| sample = snapshot_perf_stats() | |
| sample["elapsed_seconds"] = ( | |
| time.perf_counter() - self._start_time | |
| if self._start_time is not None | |
| else None | |
| ) | |
| self.background_samples.append(sample) | |
def summary(self) -> Dict:
    """Build the perf report: the raw samples plus numeric aggregates.

    For each non-empty sample group (step-driven samples and background
    samples), the avg/max/min of every tracked metric are emitted under a
    ``step_`` or ``background_`` prefix. Afterwards, unprefixed
    ``<metric>_avg`` / ``<metric>_max`` values computed from the step samples
    are appended; presumably these are kept for older readers of this report.
    Metrics with no non-None values are omitted.
    """
    tracked_metrics = (
        "samples_per_second",
        "tokens_per_second",
        "process_rss_mb",
        "cuda_max_allocated_mb",
        "gpu_util_percent",
        "gpu_memory_util_percent",
        "gpu_power_w",
        "gpu_temperature_c",
    )
    report: Dict[str, object] = {
        "sample_count": len(self.samples),
        "samples": self.samples,
        "background_sample_count": len(self.background_samples),
        "background_samples": self.background_samples,
    }

    def present_values(rows, metric):
        # Only keep readings that were actually available (not None).
        return [float(row[metric]) for row in rows if row.get(metric) is not None]

    for prefix, rows in (("step", self.samples), ("background", self.background_samples)):
        if not rows:
            continue
        for metric in tracked_metrics:
            observed = present_values(rows, metric)
            if not observed:
                continue
            report[f"{prefix}_{metric}_avg"] = sum(observed) / len(observed)
            report[f"{prefix}_{metric}_max"] = max(observed)
            report[f"{prefix}_{metric}_min"] = min(observed)

    if self.samples or self.background_samples:
        for metric in tracked_metrics:
            observed = present_values(self.samples, metric)
            if observed:
                report[f"{metric}_avg"] = sum(observed) / len(observed)
                report[f"{metric}_max"] = max(observed)
    return report
class FastTokenClassificationCollator:
    """Batch already-padded token-classification features by stacking them.

    No padding is performed; every feature must already have equal-length
    tensors. ``input_ids`` and ``labels`` are cast to int64 and
    ``attention_mask`` to bool. Any other keys are stacked unchanged.
    """

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        stacked: Dict[str, torch.Tensor] = {}
        for name in features[0].keys():
            column = torch.stack([row[name] for row in features])
            if name in ("input_ids", "labels"):
                column = column.long()
            elif name == "attention_mask":
                column = column.to(dtype=torch.bool)
            stacked[name] = column
        return stacked
class OrderedTrainer(Trainer):
    """Trainer variant that preserves pre-shuffled order for virtual datasets.

    A training dataset with a truthy ``preserve_order`` attribute is iterated
    with a ``SequentialSampler``, so the order it was written in is kept. Any
    other dataset falls back to the stock Trainer sampler.
    """

    def _get_train_sampler(self, train_dataset=None):
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        if getattr(dataset, "preserve_order", False):
            return SequentialSampler(dataset)
        # Some Trainer releases declare ``_get_train_sampler(self)`` with no
        # dataset parameter and call it without one. Newer releases pass the
        # dataset explicitly. Forward the argument only when we received one,
        # so the fallback works with both signatures instead of raising
        # TypeError on the older one.
        if train_dataset is None:
            return super()._get_train_sampler()
        return super()._get_train_sampler(train_dataset)
def augment_training_data(
    data: List[Dict],
    partial_count: int,
    permutation_count: int,
    special_count: int,
    max_chars: int,
    seed: int,
) -> tuple[List[Dict], Dict]:
    """Append generated partial/permutation/special samples without modifying source JSONL.

    Generated items are de-duplicated by ``filename`` against the source items
    (keyed by ``filename``, or by their joined ``tokens`` when there is none)
    and against each other. Each generator is retried a bounded number of
    times, so fewer items than requested may be produced.

    Args:
        data: source samples; never mutated.
        partial_count: number of partial-span samples to add.
        permutation_count: number of permutation samples to add.
        special_count: number of special samples to add.
        max_chars: length cap passed through to the item builders.
        seed: seed for the private ``random.Random`` driving all sampling.

    Returns:
        ``(data + augmented, meta)``, where ``meta`` reports the requested and
        written counts per kind plus ``max_chars``.
    """
    rng = random.Random(seed)
    augmented: List[Dict] = []
    seen = {
        item.get("filename") or "".join(str(token) for token in item.get("tokens", []))
        for item in data
    }

    def accept(item: Optional[Dict]) -> bool:
        """Append ``item`` if it is not None and unseen; return True when appended."""
        if item is None:
            return False
        key = item["filename"]
        if key in seen:
            return False
        seen.add(key)
        augmented.append(item)
        return True

    # The partial and permutation generators derive items from a randomly
    # chosen source sample. With no source data there is nothing to derive
    # from, and ``rng.choice([])`` would raise IndexError.
    has_source = bool(data)

    partial_written = 0
    if partial_count > 0 and has_source:
        candidates: List[Dict] = []
        attempts = 0
        max_attempts = max(partial_count * 20, len(data))
        # Over-generate (4x) and shuffle so the kept subset is not biased
        # toward the first source items drawn.
        while len(candidates) < partial_count * 4 and attempts < max_attempts:
            attempts += 1
            candidates.extend(build_partial_augmented_item(rng.choice(data), max_chars))
        rng.shuffle(candidates)
        for item in candidates:
            if not accept(item):
                continue
            partial_written += 1
            if partial_written >= partial_count:
                break

    permutation_written = 0
    attempts = 0
    max_attempts = max(permutation_count * 20, 100)
    while has_source and permutation_written < permutation_count and attempts < max_attempts:
        attempts += 1
        if accept(build_permutation_augmented_item(rng.choice(data), rng, max_chars)):
            permutation_written += 1

    special_written = 0
    attempts = 0
    max_attempts = max(special_count * 20, 100)
    while special_written < special_count and attempts < max_attempts:
        attempts += 1
        if accept(build_special_augmented_item(data, rng, max_chars)):
            special_written += 1

    meta = {
        "partial_requested": partial_count,
        "partial_written": partial_written,
        "permutation_requested": permutation_count,
        "permutation_written": permutation_written,
        "special_requested": special_count,
        "special_written": special_written,
        "max_chars": max_chars,
    }
    return data + augmented, meta
def normalize_field_value(field: str, value) -> Optional[str]:
    """Canonicalize a parsed field value so gold and predicted values compare fairly.

    - ``None`` stays ``None``.
    - ``episode`` / ``season``:
      - integer-like values become their decimal form (``"01"`` -> ``"1"``,
        ``3.0`` -> ``"3"``);
      - non-integer floats keep their fraction (``12.5`` -> ``"12.5"``);
      - anything else is stripped and lower-cased.
    - ``resolution`` / ``source``: stripped, lower-cased, ``_`` replaced by ``-``.
    - any other field: lower-cased with whitespace runs collapsed to one space.
    """
    if value is None:
        return None
    if field in {"episode", "season"}:
        # int() truncates floats (12.5 -> 12), which would make half-numbered
        # episodes compare equal to the whole episode before them. It also
        # raises an uncaught OverflowError on infinities. Only whole floats
        # are collapsed to ints.
        if isinstance(value, (float, np.floating)):
            number = float(value)
            return str(int(number)) if number.is_integer() else str(number)
        try:
            return str(int(value))
        except (TypeError, ValueError):
            return str(value).strip().lower()
    text = str(value).strip()
    if field in {"resolution", "source"}:
        return text.lower().replace("_", "-")
    return " ".join(text.lower().split())
def parse_exact_metrics(
    samples: List[Dict],
    model: BertForTokenClassification,
    tokenizer: AnimeTokenizer,
    id2label: Dict[int, str],
    max_length: int,
    limit: Optional[int],
    constrain_bio: bool = True,
    max_failures: int = 20,
) -> Dict:
    """Evaluate end-to-end field exact match on filenames, not just token loss.

    How each sample is scored:

    - Gold fields are rebuilt with ``postprocess`` from the sample's tokens and
      labels. These are truncated to ``max_length - 2`` positions, presumably
      to leave room for [CLS]/[SEP].
    - ``episode`` / ``season`` are forced to ``None`` in the gold record when
      the sample carries no matching entity label.
    - Predictions come from ``parse_filename``.
    - Gold and predicted values are compared after ``normalize_field_value``.

    The model is switched to eval mode for the run. Its previous train/eval
    mode is restored afterwards, so calling this mid-training does not leave
    dropout disabled.

    Args:
        samples: labelled records; entries without a ``filename`` are skipped.
        model, tokenizer, id2label, max_length: forwarded to ``parse_filename``.
        limit: when positive, evaluate at most this many samples.
        constrain_bio: forwarded to ``parse_filename``.
        max_failures: maximum number of mismatching samples kept in
            ``failures``. Defaults to 20, the previously hard-coded value.

    Returns:
        Dict with per-field accuracy/correct/total, full-record exact-match
        accuracy/correct/total, ``sample_count``, ``constrain_bio`` and up to
        ``max_failures`` example ``failures``.
    """
    fields = ["group", "title", "season", "episode", "resolution", "source", "special"]
    selected = [sample for sample in samples if sample.get("filename")]
    if limit is not None and limit > 0:
        selected = selected[:limit]
    counter: Counter = Counter()
    failures: List[Dict] = []
    available = max(0, max_length - 2)
    was_training = model.training
    model.eval()
    try:
        for sample in selected:
            filename = sample["filename"]
            tokens, gold_labels = labels_for_tokenizer(sample, tokenizer)
            tokens = tokens[:available]
            gold_labels = gold_labels[:available]
            gold = postprocess(tokens, gold_labels, tokenizer=tokenizer)
            gold_entities = {
                label.split("-", 1)[1] for label in gold_labels if label.startswith(("B-", "I-"))
            }
            # A field with no gold span must be expected as absent, not as
            # whatever default postprocess fills in.
            for optional_field, entity in (("episode", "EPISODE"), ("season", "SEASON")):
                if entity not in gold_entities:
                    gold[optional_field] = None
            pred = parse_filename(
                filename,
                model,
                tokenizer,
                id2label,
                max_length=max_length,
                debug=False,
                constrain_bio=constrain_bio,
            )
            full_match = True
            field_errors: Dict[str, Dict[str, Optional[str]]] = {}
            for field in fields:
                gold_value = normalize_field_value(field, gold.get(field))
                pred_value = normalize_field_value(field, pred.get(field))
                counter[f"{field}_total"] += 1
                if gold_value == pred_value:
                    counter[f"{field}_correct"] += 1
                else:
                    full_match = False
                    field_errors[field] = {"gold": gold_value, "pred": pred_value}
            counter["full_total"] += 1
            if full_match:
                counter["full_correct"] += 1
            elif len(failures) < max_failures:
                failures.append(
                    {
                        "filename": filename,
                        "errors": field_errors,
                        "gold": {field: gold.get(field) for field in fields},
                        "pred": {field: pred.get(field) for field in fields},
                    }
                )
    finally:
        model.train(was_training)
    field_accuracy = {}
    for field in fields:
        total = counter.get(f"{field}_total", 0)
        correct = counter.get(f"{field}_correct", 0)
        field_accuracy[field] = correct / total if total else 0.0
    total = counter.get("full_total", 0)
    correct = counter.get("full_correct", 0)
    return {
        "constrain_bio": constrain_bio,
        "sample_count": total,
        "field_accuracy": field_accuracy,
        "field_correct": {field: counter.get(f"{field}_correct", 0) for field in fields},
        "field_total": {field: counter.get(f"{field}_total", 0) for field in fields},
        "full_match_accuracy": correct / total if total else 0.0,
        "full_match_correct": correct,
        "full_match_total": total,
        "failures": failures,
    }
def parse_exact_metrics_all_modes(
    samples: List[Dict],
    model: BertForTokenClassification,
    tokenizer: AnimeTokenizer,
    id2label: Dict[int, str],
    max_length: int,
    limit: Optional[int],
) -> Dict:
    """Run parse exact-match evaluation with raw and BIO-constrained decoding.

    Returns ``{"primary_metric": "normalized_only", "modes": {...}}``, where
    ``modes`` maps ``"model_only"`` (``constrain_bio=False``) and
    ``"normalized_only"`` (``constrain_bio=True``) to their
    ``parse_exact_metrics`` results.
    """
    mode_constraints = (("model_only", False), ("normalized_only", True))
    results: Dict[str, Dict] = {}
    for mode_name, constrain in mode_constraints:
        results[mode_name] = parse_exact_metrics(
            samples,
            model,
            tokenizer,
            id2label,
            max_length,
            limit,
            constrain_bio=constrain,
        )
    return {"primary_metric": "normalized_only", "modes": results}
def remap_token_embeddings(
    model: BertForTokenClassification,
    old_vocab: Dict[str, int],
    new_vocab: Dict[str, int],
    pad_token_id: int,
) -> int:
    """
    Replace the model's input embedding table with one laid out for ``new_vocab``.

    ``resize_token_embeddings()`` keeps rows by numeric ID, which corrupts the
    table when two tokenizers give the same ID to different tokens. This
    function matches rows by token string instead:

    - The new table has ``len(new_vocab)`` rows.
    - All rows are first drawn from a normal distribution (mean 0, std
      ``config.initializer_range``, default 0.02).
    - The padding row is then zeroed when ``pad_token_id`` is in range.
    - Finally, every token that also exists in ``old_vocab`` with an ID inside
      the old table gets its old vector copied.

    Also sets ``model.config.vocab_size``. Returns the number of copied rows.
    """
    source_table = model.get_input_embeddings().weight.data
    old_rows = source_table.shape[0]
    vocab_len = len(new_vocab)
    replacement = torch.nn.Embedding(
        vocab_len,
        source_table.shape[1],
        padding_idx=pad_token_id,
        device=source_table.device,
        dtype=source_table.dtype,
    )
    init_std = getattr(model.config, "initializer_range", 0.02)
    torch.nn.init.normal_(replacement.weight, mean=0.0, std=init_std)
    target_table = replacement.weight.data
    if pad_token_id is not None and 0 <= pad_token_id < vocab_len:
        target_table[pad_token_id].zero_()
    copied = 0
    for token, new_id in new_vocab.items():
        old_id = old_vocab.get(token)
        if old_id is not None and old_id < old_rows:
            target_table[new_id].copy_(source_table[old_id])
            copied += 1
    model.set_input_embeddings(replacement)
    model.config.vocab_size = vocab_len
    return copied
def build_vocab_from_data(data: List[Dict], tokenizer: AnimeTokenizer, vocab_path: str,
                          max_size: Optional[int] = None) -> None:
    """Fit ``tokenizer``'s vocabulary on ``data`` and save it to ``vocab_path`` as JSON.

    Each item is tokenized through ``labels_for_tokenizer``; its labels are
    ignored. The vocabulary is capped at ``max_size`` by
    ``tokenizer.build_vocab``. The parent directory of ``vocab_path`` is
    created if it is missing.
    """
    token_label_pairs = (labels_for_tokenizer(item, tokenizer) for item in data)
    corpus: List[List[str]] = [tokens for tokens, _labels in token_label_pairs]
    tokenizer.build_vocab(corpus, max_size=max_size)
    target_dir = os.path.dirname(vocab_path) or "."
    os.makedirs(target_dir, exist_ok=True)
    with open(vocab_path, "w", encoding="utf-8") as handle:
        json.dump(tokenizer.get_vocab(), handle, ensure_ascii=False, indent=2)
def main():
    """Train AniFileBERT end to end from command-line arguments.

    Steps, in order:

    1. Apply ``parse_args()`` overrides to ``Config``.
    2. Load and optionally augment the training sources.
    3. Build or load the tokenizer vocab.
    4. Create a model, or fine-tune one from ``--init-model-dir``.
    5. Prepare datasets in virtual-sharded, lazy, or pre-encoded mode.
    6. Train with ``OrderedTrainer``.
    7. Save ``<save_dir>/final`` with run metadata, perf metrics, trainer eval
       metrics, parse exact-match metrics and, optionally, fixed-case
       regression metrics.

    Raises:
        ValueError: for inconsistent settings. This covers a head count that
            does not divide the hidden size, fewer than two samples, span
            augmentation without the char tokenizer, and a virtual dataset
            length mismatch. It also covers a CUDA encoded-dataset device
            combined with CPU training or dataloader workers.
    """
    args = parse_args()
    config = Config()
    if args.data_file is not None:
        config.data_file = args.data_file
    training_files = [config.data_file] + list(args.extra_data_file or [])
    tokenizer_variant = detect_tokenizer_variant_from_files(training_files, args.tokenizer, args.vocab_file)
    if args.save_dir is not None:
        config.save_dir = args.save_dir
    elif tokenizer_variant == "char":
        config.save_dir = "./checkpoints_char"
    if args.epochs is not None:
        config.num_epochs = args.epochs
    if args.batch_size is not None:
        config.batch_size = args.batch_size
    if args.learning_rate is not None:
        config.learning_rate = args.learning_rate
    if args.warmup_steps is not None:
        config.warmup_steps = args.warmup_steps
    if args.train_split is not None:
        config.train_split = args.train_split
    if args.num_workers is not None:
        config.num_workers = args.num_workers
    if args.max_seq_length is not None:
        config.max_seq_length = args.max_seq_length
    elif tokenizer_variant == "char":
        # Char tokenization presumably yields longer sequences; enforce a floor of 128.
        config.max_seq_length = max(config.max_seq_length, 128)
    if args.hidden_size is not None:
        config.hidden_size = args.hidden_size
    if args.num_hidden_layers is not None:
        config.num_hidden_layers = args.num_hidden_layers
    if args.num_attention_heads is not None:
        config.num_attention_heads = args.num_attention_heads
    if args.intermediate_size is not None:
        config.intermediate_size = args.intermediate_size
    if config.hidden_size % config.num_attention_heads != 0:
        raise ValueError(
            f"hidden_size ({config.hidden_size}) must be divisible by "
            f"num_attention_heads ({config.num_attention_heads})."
        )
    # Position embeddings must cover the longest sequence the model will see.
    config.max_position_embeddings = max(config.max_position_embeddings, config.max_seq_length)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print("Loading dataset...")
    load_started_at = time.perf_counter()
    all_data, data_sources = load_training_sources(
        primary_data_file=config.data_file,
        extra_data_files=list(args.extra_data_file or []),
        extra_repeat=args.extra_data_repeat,
        limit=args.limit_samples,
    )
    augmentation_metadata = {
        "partial_requested": 0,
        "partial_written": 0,
        "permutation_requested": 0,
        "permutation_written": 0,
        "special_requested": 0,
        "special_written": 0,
        "max_chars": args.augment_max_chars,
    }
    if args.augment_partial_samples or args.augment_permutation_samples or args.augment_special_samples:
        if tokenizer_variant != "char":
            raise ValueError("Training-time BIO span augmentation currently requires --tokenizer char.")
        all_data, augmentation_metadata = augment_training_data(
            data=all_data,
            partial_count=args.augment_partial_samples,
            permutation_count=args.augment_permutation_samples,
            special_count=args.augment_special_samples,
            max_chars=args.augment_max_chars,
            seed=args.seed + 1009,
        )
    load_finished_at = time.perf_counter()
    if len(all_data) < 2:
        raise ValueError("Need at least two samples so train/eval split is non-empty.")
    if not args.no_shuffle:
        random.shuffle(all_data)
    validate_dataset_tokenizer_metadata(all_data, tokenizer_variant)

    # Load tokenizer
    print("Loading tokenizer...")
    vocab_path = resolve_vocab_path(config.data_file, tokenizer_variant, args.vocab_file)
    tokenizer = create_tokenizer(tokenizer_variant)
    if args.rebuild_vocab or not os.path.isfile(vocab_path):
        max_vocab_size = args.max_vocab_size if args.max_vocab_size is not None else config.vocab_size
        print(f" Building {tokenizer_variant} vocab: {vocab_path} (max_size={max_vocab_size})")
        build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=max_vocab_size)
    # Reload from the vocab file so training uses exactly what is on disk.
    tokenizer = create_tokenizer(tokenizer_variant, vocab_file=vocab_path)
    print(f" Variant: {tokenizer_variant}")
    print(f" Vocab size: {tokenizer.vocab_size}")
    print(f" Max sequence length: {config.max_seq_length}")
    if torch.cuda.is_available() and not args.cpu:
        print(f" CUDA device: {torch.cuda.get_device_name(0)}")
    # Update config with actual vocab size
    config.vocab_size = tokenizer.vocab_size

    # Create model
    if args.init_model_dir:
        print(f"Loading model for fine-tuning: {args.init_model_dir}")
        model = BertForTokenClassification.from_pretrained(args.init_model_dir)
        init_tokenizer = load_tokenizer(args.init_model_dir, tokenizer_variant)
        init_vocab = init_tokenizer.get_vocab()
        embedding_size = model.get_input_embeddings().weight.shape[0]
        if len(init_vocab) != embedding_size:
            print(
                " WARNING: init checkpoint tokenizer vocab length does not match model embedding size "
                f"({len(init_vocab):,} vs {embedding_size:,}). Prefer a self-consistent checkpoint."
            )
        init_variant = getattr(init_tokenizer, "tokenizer_variant", None)
        if init_variant != tokenizer_variant:
            print(f" WARNING: tokenizer variant changes during fine-tune: {init_variant} -> {tokenizer_variant}")
            print(" Token embeddings will be remapped by token string; unmatched tokens are newly initialized.")
        if model.config.vocab_size != config.vocab_size or init_vocab != tokenizer.get_vocab():
            copied = remap_token_embeddings(
                model=model,
                old_vocab=init_vocab,
                new_vocab=tokenizer.get_vocab(),
                pad_token_id=tokenizer.pad_token_id,
            )
            print(
                f" Remapped token embeddings: copied {copied:,}/{config.vocab_size:,} "
                f"tokens from init checkpoint"
            )
        model.config.num_labels = config.num_labels
        model.config.id2label = config.id2label
        model.config.label2id = config.label2id
    else:
        print("Creating model...")
        model: BertForTokenClassification = create_model(config)
    total_params = print_model_summary(model)
    if total_params >= 5_000_000:
        print("WARNING: Model exceeds the historical 5M target; continuing because vocab size is configurable.")
    use_cpu = args.cpu or not torch.cuda.is_available()

    # Head/tail split of the (optionally shuffled) data, clamped so both sides are non-empty.
    split_idx = int(len(all_data) * config.train_split)
    split_idx = max(1, min(len(all_data) - 1, split_idx))
    train_data = all_data[:split_idx]
    eval_data = all_data[split_idx:]
    encode_started_at = time.perf_counter()
    if args.virtual_dataset_dir:
        virtual_dataset = ShardedEncodedDataset(args.virtual_dataset_dir)
        if virtual_dataset.max_length != config.max_seq_length:
            raise ValueError(
                f"Virtual dataset max_length {virtual_dataset.max_length} does not match "
                f"configured max_seq_length {config.max_seq_length}"
            )
        train_dataset = virtual_dataset
        eval_dataset = EncodedAnimeDataset(
            data=eval_data,
            tokenizer=tokenizer,
            label2id=config.label2id,
            max_length=config.max_seq_length,
            device=torch.device("cpu"),
            apply_label_repairs=args.apply_label_repairs,
        )
        dataset_mode = "virtual-sharded"
        if not args.keep_raw_dataset:
            # Drop raw references so the training records can be freed; eval_data stays for parse eval.
            train_data = []
            all_data = []
            gc.collect()
    elif args.lazy_dataset:
        train_dataset = AnimeItemsDataset(
            data=train_data,
            tokenizer=tokenizer,
            label2id=config.label2id,
            max_length=config.max_seq_length,
            apply_label_repairs=args.apply_label_repairs,
        )
        eval_dataset = AnimeItemsDataset(
            data=eval_data,
            tokenizer=tokenizer,
            label2id=config.label2id,
            max_length=config.max_seq_length,
            apply_label_repairs=args.apply_label_repairs,
        )
        dataset_mode = "lazy"
    else:
        encoded_device = torch.device(args.encoded_dataset_device)
        if encoded_device.type == "cuda" and use_cpu:
            raise ValueError("--encoded-dataset-device cuda cannot be used with CPU training.")
        if encoded_device.type == "cuda" and config.num_workers > 0:
            raise ValueError("--encoded-dataset-device cuda requires --num-workers 0 to avoid worker duplication.")
        train_dataset = EncodedAnimeDataset(
            data=train_data,
            tokenizer=tokenizer,
            label2id=config.label2id,
            max_length=config.max_seq_length,
            device=encoded_device,
            apply_label_repairs=args.apply_label_repairs,
        )
        eval_dataset = EncodedAnimeDataset(
            data=eval_data,
            tokenizer=tokenizer,
            label2id=config.label2id,
            max_length=config.max_seq_length,
            device=encoded_device,
            apply_label_repairs=args.apply_label_repairs,
        )
        dataset_mode = "encoded"
        if not args.keep_raw_dataset:
            train_data = []
            all_data = []
            gc.collect()
    encode_finished_at = time.perf_counter()
    print(f" Train samples: {len(train_dataset)}")
    print(f" Eval samples: {len(eval_dataset)}")
    print(f" Dataset mode: {dataset_mode}")
    print(f" Load time: {load_finished_at - load_started_at:.2f}s")
    print(f" Encode time: {encode_finished_at - encode_started_at:.2f}s")

    use_bf16 = bool(args.bf16 and not use_cpu)
    # fp16 is the CUDA default unless bf16 was requested or mixed precision is disabled; never on CPU.
    use_fp16 = bool((not use_cpu) and not use_bf16 and not args.no_mixed_precision)
    if torch.cuda.is_available() and not use_cpu and args.tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    print(f" Device: {'CPU' if use_cpu else 'CUDA'}")
    if not use_cpu:
        print(f" Mixed precision: {'bf16' if use_bf16 else ('fp16' if use_fp16 else 'disabled')}")
        print(f" TF32: {'enabled' if args.tf32 else 'disabled'}")
    eval_save_strategy = "no" if args.no_periodic_eval else ("steps" if args.checkpoint_steps else "epoch")
    save_strategy = "steps" if args.checkpoint_steps else "epoch"
    dataloader_prefetch_factor = args.prefetch_factor
    if dataloader_prefetch_factor is None:
        dataloader_prefetch_factor = 4 if config.num_workers > 0 else None
    persistent_workers = bool(args.persistent_workers and config.num_workers > 0)
    # Only the pre-encoded mode places tensors on --encoded-dataset-device; the
    # other modes ignore that flag. Pinning CUDA-resident tensors fails, so
    # disable pinning only in that case. The device string is parsed (rather
    # than compared to "cuda") so "cuda:0"-style values are recognised,
    # matching the checks used when the encoded datasets were built.
    encoded_on_cuda = (
        dataset_mode == "encoded"
        and torch.device(args.encoded_dataset_device).type == "cuda"
    )
    dataloader_pin_memory = bool((not use_cpu) and not encoded_on_cuda)
    if args.lazy_dataset and config.num_workers == 0:
        print(" WARNING: lazy dataset mode is slower with zero workers; consider --num-workers 4+.")

    # Training arguments
    training_args = TrainingArguments(
        output_dir=config.save_dir,
        num_train_epochs=config.num_epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        eval_strategy=eval_save_strategy,
        save_strategy=save_strategy,
        eval_steps=args.checkpoint_steps if eval_save_strategy == "steps" else None,
        save_steps=args.checkpoint_steps,
        logging_steps=config.log_interval,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_steps=config.warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        use_cpu=use_cpu,
        report_to=["tensorboard"] if args.tensorboard else "none",
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=not args.no_periodic_eval,
        metric_for_best_model="f1",
        greater_is_better=True,
        dataloader_num_workers=config.num_workers,
        dataloader_pin_memory=dataloader_pin_memory,
        dataloader_prefetch_factor=dataloader_prefetch_factor,
        dataloader_persistent_workers=persistent_workers,
        fp16=use_fp16,
        bf16=use_bf16,
        tf32=args.tf32 and not use_cpu,
        torch_compile=bool(args.torch_compile and not use_cpu),
        auto_find_batch_size=bool(args.auto_find_batch_size and not use_cpu),
        include_num_input_tokens_seen=True,
    )
    # Data collator
    data_collator = FastTokenClassificationCollator()

    # Trainer.state.global_step counts optimizer updates. Each update consumes
    # train_batch_size * gradient_accumulation_steps * world_size samples (the
    # Trainer's own "total train batch size"). Passing only the per-device batch
    # size under-reported samples/s and tokens/s by that factor.
    # NOTE: with --auto-find-batch-size the Trainer may shrink the real batch,
    # making these rates an upper-bound estimate.
    samples_per_optimizer_step = (
        training_args.train_batch_size
        * training_args.gradient_accumulation_steps
        * training_args.world_size
    )
    # Trainer
    perf_callback = TrainingPerfCallback(
        batch_size=samples_per_optimizer_step,
        sequence_length=config.max_seq_length,
        log_steps=args.perf_log_steps,
        sample_interval=args.perf_sample_interval,
    )
    trainer = OrderedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[perf_callback],
    )

    # Train
    print("Starting training...")
    resume_from_checkpoint = args.resume_from_checkpoint
    if resume_from_checkpoint == "auto":
        resume_from_checkpoint = latest_checkpoint(config.save_dir)
        if resume_from_checkpoint:
            print(f"Resuming from latest checkpoint: {resume_from_checkpoint}")
        else:
            print("No checkpoint found; starting a fresh training run.")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Set proper label mappings in model config before saving
    model.config.id2label = config.id2label
    model.config.label2id = config.label2id
    model.config.tokenizer_variant = tokenizer_variant
    model.config.max_seq_length = config.max_seq_length

    # Save final model
    final_save_path = os.path.join(config.save_dir, "final")
    trainer.save_model(final_save_path)
    tokenizer.save_pretrained(final_save_path)
    metadata = {
        "experiment_name": args.experiment_name,
        "data_file": config.data_file,
        "data_sources": data_sources,
        "augmentation": augmentation_metadata,
        "dataset_mode": dataset_mode,
        "virtual_dataset_dir": args.virtual_dataset_dir,
        "apply_label_repairs": args.apply_label_repairs,
        "keep_raw_dataset": args.keep_raw_dataset,
        "tokenizer_variant": tokenizer_variant,
        "vocab_file": vocab_path,
        "vocab_size": tokenizer.vocab_size,
        "max_seq_length": config.max_seq_length,
        "hidden_size": config.hidden_size,
        "num_hidden_layers": config.num_hidden_layers,
        "num_attention_heads": config.num_attention_heads,
        "intermediate_size": config.intermediate_size,
        "train_samples": len(train_dataset),
        "eval_samples": len(eval_dataset),
        "load_seconds": load_finished_at - load_started_at,
        "encode_seconds": encode_finished_at - encode_started_at,
        "epochs": config.num_epochs,
        "max_steps": args.max_steps,
        "batch_size": config.batch_size,
        "learning_rate": config.learning_rate,
        "warmup_steps": config.warmup_steps,
        "seed": args.seed,
        "device": "cpu" if use_cpu else "cuda",
        "fp16": use_fp16,
        "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
        "dataloader_num_workers": config.num_workers,
        "dataloader_prefetch_factor": dataloader_prefetch_factor,
        "dataloader_persistent_workers": persistent_workers,
        "dataloader_pin_memory": dataloader_pin_memory,
        "encoded_dataset_device": args.encoded_dataset_device if not args.lazy_dataset else None,
        "mixed_precision": "bf16" if use_bf16 else ("fp16" if use_fp16 else "none"),
        "tf32": bool(args.tf32 and not use_cpu),
        "torch_compile": bool(args.torch_compile and not use_cpu),
        "auto_find_batch_size": bool(args.auto_find_batch_size and not use_cpu),
        "perf_log_steps": args.perf_log_steps,
        "perf_sample_interval": args.perf_sample_interval,
        "periodic_eval": not args.no_periodic_eval,
    }
    with open(os.path.join(final_save_path, "run_metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)
    print(f"Model saved to: {final_save_path}")
    with open(os.path.join(final_save_path, "perf_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(perf_callback.summary(), f, ensure_ascii=False, indent=2)
    train_runtime = None
    if trainer.state.log_history:
        for entry in reversed(trainer.state.log_history):
            if "train_runtime" in entry:
                train_runtime = entry["train_runtime"]
                break
    if train_runtime is not None:
        print(f" Train runtime: {train_runtime:.2f}s")
        print(f" Total wall time (load+encode+train): {(load_finished_at - load_started_at) + (encode_finished_at - encode_started_at) + train_runtime:.2f}s")

    # Final evaluation
    print("\nFinal evaluation:")
    eval_results = trainer.evaluate()
    for key, value in eval_results.items():
        print(f" {key}: {value:.4f}")
    with open(os.path.join(final_save_path, "trainer_eval_metrics.json"), "w", encoding="utf-8") as f:
        json.dump({key: float(value) for key, value in eval_results.items()}, f, ensure_ascii=False, indent=2)
    if args.parse_eval_limit != 0:
        # A positive limit caps the evaluated samples; otherwise evaluate all of eval_data.
        parse_limit = args.parse_eval_limit if args.parse_eval_limit and args.parse_eval_limit > 0 else None
        parse_metrics = parse_exact_metrics_all_modes(
            eval_data,
            trainer.model,
            tokenizer,
            config.id2label,
            config.max_seq_length,
            parse_limit,
        )
        with open(os.path.join(final_save_path, "parse_eval_metrics.json"), "w", encoding="utf-8") as f:
            json.dump(parse_metrics, f, ensure_ascii=False, indent=2)
        print("\nParse exact-match evaluation:")
        for mode_name, mode_metrics in parse_metrics["modes"].items():
            print(
                f" {mode_name}: {mode_metrics['full_match_correct']}/"
                f"{mode_metrics['full_match_total']} ({mode_metrics['full_match_accuracy']:.4f})"
            )
    if not args.no_case_eval:
        if args.case_eval_file and os.path.isfile(args.case_eval_file):
            from tools.evaluate_parser_cases import evaluate_case_modes
            case_metrics = evaluate_case_modes(
                model_dir=final_save_path,
                case_file=args.case_eval_file,
                tokenizer_variant=tokenizer_variant,
                max_length=config.max_seq_length,
            )
            case_output = args.case_eval_output or os.path.join(final_save_path, "case_metrics.json")
            os.makedirs(os.path.dirname(case_output) or ".", exist_ok=True)
            with open(case_output, "w", encoding="utf-8") as f:
                json.dump(case_metrics, f, ensure_ascii=False, indent=2)
            print("\nFixed case regression evaluation:")
            for mode_name, mode_metrics in case_metrics["modes"].items():
                print(
                    f" {mode_name}: {mode_metrics['full_correct']}/"
                    f"{mode_metrics['case_count']} ({mode_metrics['full_accuracy']:.4f})"
                )
            primary = case_metrics["modes"][case_metrics["primary_metric"]]
            if primary["failures"]:
                print(f" primary failures: {len(primary['failures'])} (see {case_output})")
        elif args.case_eval_file:
            print(f"\nSkipping fixed case regression evaluation; file not found: {args.case_eval_file}")
if __name__ == "__main__":
    # Script entry point: run the full train / save / evaluate pipeline.
    main()