""" Train AniFileBERT for structured anime filename parsing. The training loop keeps the existing PyTorch/Transformers stack, writes Hugging Face checkpoints, records token/entity metrics, and also evaluates end-to-end parser exact-match on held-out filenames and fixed real-world cases. """ import os import sys import json import argparse import random import subprocess import threading import time import gc from collections import Counter from ctypes import POINTER, Structure, byref, c_int, c_uint, c_ulonglong, c_void_p, cdll from ctypes import util as ctypes_util from typing import Dict, List, Optional, Sequence import numpy as np import torch from torch.utils.data import SequentialSampler from transformers import ( Trainer, TrainingArguments, BertForTokenClassification, TrainerCallback, ) from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score from .config import Config from .tokenizer import AnimeTokenizer, create_tokenizer, load_tokenizer from .model import create_model, print_model_summary, count_parameters from .dataset import AnimeItemsDataset, EncodedAnimeDataset, labels_for_tokenizer from .inference import parse_filename, postprocess from .virtual_dataset import DatasetRangeView, ShardedEncodedDataset def compute_metrics(p): """Compute token-level and entity-level metrics using seqeval.""" predictions, labels = p predictions = np.argmax(predictions, axis=2) # Remove ignored index (special tokens) true_predictions = [] true_labels = [] id2label = Config().id2label for pred_seq, label_seq in zip(predictions, labels): preds = [] lbls = [] for p, l in zip(pred_seq, label_seq): if l != -100: preds.append(id2label[p]) lbls.append(id2label[l]) true_predictions.append(preds) true_labels.append(lbls) # Entity-level metrics (via seqeval) return { "precision": precision_score(true_labels, true_predictions), "recall": recall_score(true_labels, true_predictions), "f1": f1_score(true_labels, true_predictions), "accuracy": accuracy_score(true_labels, true_predictions), } def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train anime filename parser") parser.add_argument("--tokenizer", choices=["regex", "char"], default=None, help="Tokenizer variant for A/B testing. Defaults to dataset metadata") parser.add_argument("--data-file", default=None, help="Primary training JSONL file") parser.add_argument("--extra-data-file", action="append", default=[], help="Additional training JSONL file. Can be passed multiple times.") parser.add_argument("--extra-data-repeat", type=int, default=1, help="Repeat each extra dataset this many times after loading") parser.add_argument("--virtual-dataset-dir", default=None, help="Pre-encoded shard directory generated by tools/virtual_dataset_generator") parser.add_argument("--vocab-file", default=None, help="Tokenizer vocab JSON. Defaults to data/vocab.json or data/vocab.char.json") parser.add_argument("--save-dir", default=None, help="Checkpoint output directory") parser.add_argument("--init-model-dir", default=None, help="Optional checkpoint to fine-tune from") parser.add_argument("--epochs", type=float, default=None, help="Number of training epochs") parser.add_argument("--max-steps", type=int, default=-1, help="Override epoch-based training and stop after this many optimizer steps") parser.add_argument("--batch-size", type=int, default=None, help="Per-device train/eval batch size") parser.add_argument("--learning-rate", type=float, default=None, help="Learning rate") parser.add_argument("--warmup-steps", type=int, default=None, help="Warmup steps") parser.add_argument("--train-split", type=float, default=None, help="Train split ratio") parser.add_argument("--max-seq-length", type=int, default=None, help="Maximum sequence length") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--limit-samples", type=int, default=None, help="Use only the first N samples for quick A/B smoke runs") parser.add_argument("--augment-partial-samples", type=int, default=0, help="Generate this many partial BIO-span samples in memory before training") parser.add_argument("--augment-permutation-samples", type=int, default=0, help="Generate this many random BIO-span permutation samples in memory before training") parser.add_argument("--augment-special-samples", type=int, default=0, help="Generate this many special-only/title+special samples such as Menu01 in memory") parser.add_argument("--augment-max-chars", type=int, default=160, help="Maximum character length for generated augmentation samples") parser.add_argument("--rebuild-vocab", action="store_true", help="Rebuild vocab from the selected data file before training") parser.add_argument("--max-vocab-size", type=int, default=None, help="Optional vocab cap used with --rebuild-vocab") parser.add_argument("--checkpoint-steps", type=int, default=None, help="Save resumable checkpoints every N steps instead of only at epoch end") parser.add_argument("--save-total-limit", type=int, default=2, help="Maximum number of checkpoints to keep") parser.add_argument("--no-periodic-eval", action="store_true", help="Skip Trainer's scheduled train-time eval/load-best-model; final evaluation still runs") parser.add_argument("--keep-raw-dataset", action="store_true", help="Keep raw JSONL dictionaries in memory after encoded datasets are built") parser.add_argument("--gradient-accumulation-steps", type=int, default=1, help="Accumulate gradients across this many steps") parser.add_argument("--num-workers", type=int, default=None, help="DataLoader worker count. Defaults to config.num_workers") parser.add_argument("--prefetch-factor", type=int, default=None, help="DataLoader prefetch factor when workers are enabled") parser.add_argument("--persistent-workers", action="store_true", help="Keep DataLoader workers alive between epochs") parser.add_argument("--lazy-dataset", action="store_true", help="Tokenize samples lazily in DataLoader workers instead of pre-encoding tensors") parser.add_argument("--apply-label-repairs", action="store_true", help="Apply runtime deterministic label repairs while building training tensors") parser.add_argument("--encoded-dataset-device", choices=["cpu", "cuda"], default="cpu", help="Store pre-encoded dataset tensors on this device; cuda requires --num-workers 0") parser.add_argument("--bf16", action="store_true", help="Use bfloat16 mixed precision on CUDA instead of fp16") parser.add_argument("--no-mixed-precision", action="store_true", help="Disable fp16/bf16 mixed precision even when CUDA is available") parser.add_argument("--tf32", dest="tf32", action="store_true", help="Enable TF32 matmul/cudnn kernels on CUDA") parser.add_argument("--no-tf32", dest="tf32", action="store_false", help="Disable TF32 matmul/cudnn kernels") parser.add_argument("--torch-compile", action="store_true", help="Enable torch.compile through Hugging Face Trainer") parser.add_argument("--auto-find-batch-size", action="store_true", help="Let Trainer reduce batch size automatically on CUDA OOM") parser.add_argument("--perf-log-steps", type=int, default=100, help="Sample training throughput, memory, and GPU stats every N steps; 0 disables") parser.add_argument("--perf-sample-interval", type=float, default=1.0, help="Background NVML sampling interval in seconds during training; 0 disables") parser.add_argument("--cpu", action="store_true", help="Force CPU training") parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split") parser.add_argument("--resume-from-checkpoint", default=None, help="Resume Trainer state from a checkpoint directory, or 'auto' for the latest checkpoint") parser.add_argument("--tensorboard", dest="tensorboard", action="store_true", help="Log metrics to TensorBoard in addition to stdout/checkpoints") parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false", help="Disable TensorBoard logging") parser.add_argument("--experiment-name", default=None, help="Optional experiment name written to run_metadata.json") parser.add_argument("--parse-eval-limit", type=int, default=512, help="Run field exact-match evaluation on up to N eval samples after training; 0 disables it") parser.add_argument("--case-eval-file", default=os.path.join("data", "parser_regression_cases.json"), help="Fixed real-world parser regression case file evaluated after training") parser.add_argument("--case-eval-output", default=None, help="Optional output path for fixed case metrics; defaults to final/case_metrics.json") parser.add_argument("--no-case-eval", action="store_true", help="Skip fixed real-world parser regression evaluation") parser.add_argument("--hidden-size", type=int, default=None, help="Override BERT hidden size") parser.add_argument("--num-hidden-layers", type=int, default=None, help="Override BERT layer count") parser.add_argument("--num-attention-heads", type=int, default=None, help="Override BERT attention heads") parser.add_argument("--intermediate-size", type=int, default=None, help="Override BERT FFN intermediate size") parser.set_defaults(tf32=True) parser.set_defaults(tensorboard=True) return parser.parse_args() def detect_tokenizer_variant( data_file: str, explicit_variant: Optional[str], explicit_vocab_path: Optional[str], sample_size: int = 256, ) -> str: """Infer tokenizer variant from CLI, dataset metadata, or vocab filename.""" if explicit_variant: return explicit_variant variants = set() char_like = 0 inspected = 0 with open(data_file, "r", encoding="utf-8") as f: for line in f: if inspected >= sample_size: break line = line.strip() if not line: continue item = json.loads(line) inspected += 1 variant = item.get("tokenizer_variant") if variant: variants.add(variant) tokens = item.get("tokens", []) filename = item.get("filename") if filename is not None and tokens == list(filename): char_like += 1 if len(variants) == 1: return next(iter(variants)) if len(variants) > 1: raise ValueError(f"Mixed tokenizer_variant values in {data_file}: {sorted(variants)}") if explicit_vocab_path and ".char" in os.path.basename(explicit_vocab_path).lower(): return "char" if inspected and char_like / inspected >= 0.95: return "char" return "regex" def detect_tokenizer_variant_from_files( data_files: List[str], explicit_variant: Optional[str], explicit_vocab_path: Optional[str], ) -> str: if explicit_variant: return explicit_variant variants = { detect_tokenizer_variant(path, None, explicit_vocab_path) for path in data_files } if len(variants) > 1: raise ValueError(f"Mixed tokenizer variants across datasets: {sorted(variants)}") return next(iter(variants)) def resolve_vocab_path(data_file: str, tokenizer_variant: str, explicit_path: Optional[str]) -> str: if explicit_path: return explicit_path name = "vocab.json" if tokenizer_variant == "regex" else "vocab.char.json" return os.path.join(os.path.dirname(data_file), name) def latest_checkpoint(save_dir: str) -> Optional[str]: if not os.path.isdir(save_dir): return None checkpoints = [] for name in os.listdir(save_dir): if not name.startswith("checkpoint-"): continue path = os.path.join(save_dir, name) if not os.path.isdir(path): continue try: step = int(name.split("-")[-1]) except ValueError: continue checkpoints.append((step, path)) if not checkpoints: return None return max(checkpoints)[1] def validate_dataset_tokenizer_metadata(data: List[Dict], tokenizer_variant: str) -> None: variants = {item.get("tokenizer_variant") for item in data if item.get("tokenizer_variant")} if variants and variants != {tokenizer_variant}: raise ValueError( f"Dataset tokenizer_variant {sorted(variants)} does not match selected tokenizer " f"'{tokenizer_variant}'. Pass --tokenizer explicitly only when this is intentional." ) def load_jsonl(data_file: str, limit: Optional[int] = None) -> List[Dict]: """Load JSONL rows, stopping early for smoke runs.""" data: List[Dict] = [] with open(data_file, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue data.append(json.loads(line)) if limit is not None and len(data) >= limit: break return data def load_training_sources( primary_data_file: str, extra_data_files: List[str], extra_repeat: int, limit: Optional[int] = None, ) -> tuple[List[Dict], List[Dict]]: """Load primary plus extra datasets while preserving source metadata.""" sources: List[Dict] = [] primary = load_jsonl(primary_data_file, limit) all_data: List[Dict] = list(primary) sources.append( { "role": "primary", "path": primary_data_file, "samples": len(primary), "repeat": 1, "effective_samples": len(primary), } ) repeat = max(1, extra_repeat) for path in extra_data_files: rows = load_jsonl(path, None) for _ in range(repeat): all_data.extend(rows) sources.append( { "role": "extra", "path": path, "samples": len(rows), "repeat": repeat, "effective_samples": len(rows) * repeat, } ) return all_data, sources def extract_entities_from_labels(tokens: Sequence[str], labels: Sequence[str]) -> Dict[str, List[str]]: """Extract contiguous BIO entity text spans from token/label arrays.""" entities: Dict[str, List[str]] = {} active_entity: Optional[str] = None active_tokens: List[str] = [] for token, label in zip(tokens, labels): if label.startswith("B-"): if active_entity and active_tokens: entities.setdefault(active_entity, []).append("".join(active_tokens)) active_entity = label[2:] active_tokens = [str(token)] elif label.startswith("I-") and active_entity == label[2:]: active_tokens.append(str(token)) else: if active_entity and active_tokens: entities.setdefault(active_entity, []).append("".join(active_tokens)) active_entity = None active_tokens = [] if active_entity and active_tokens: entities.setdefault(active_entity, []).append("".join(active_tokens)) return entities def char_item_from_spans(filename: str, spans: Sequence[tuple[str, str]], source: str) -> Optional[Dict]: """Create a char-tokenized BIO item from ordered text/entity spans.""" filename = filename.strip() if not filename: return None tokens = list(filename) labels = ["O"] * len(tokens) cursor = 0 for text, entity in spans: if not text: continue start = filename.find(text, cursor) if start < 0: start = filename.find(text) if start < 0: return None end = start + len(text) labels[start] = f"B-{entity}" for idx in range(start + 1, end): labels[idx] = f"I-{entity}" cursor = end return { "filename": filename, "tokens": tokens, "labels": labels, "tokenizer_variant": "char", "source": source, } def entity_keep_probability(entity: str) -> float: return { "GROUP": 0.35, "TITLE": 0.65, "SEASON": 0.35, "EPISODE": 0.7, "SPECIAL": 0.3, "RESOLUTION": 0.65, "SOURCE": 0.65, }.get(entity, 0.5) def build_partial_augmented_item(item: Dict, max_chars: int) -> List[Dict]: entities = extract_entities_from_labels(item.get("tokens", []), item.get("labels", [])) title = next((value.strip() for value in entities.get("TITLE", []) if value.strip()), None) season = next((value.strip() for value in entities.get("SEASON", []) if value.strip()), None) episode = next((value.strip() for value in entities.get("EPISODE", []) if value.strip()), None) special = next((value.strip() for value in entities.get("SPECIAL", []) if value.strip()), None) resolution = next((value.strip() for value in entities.get("RESOLUTION", []) if value.strip()), None) source = next((value.strip() for value in entities.get("SOURCE", []) if value.strip()), None) specs: List[tuple[str, List[tuple[str, str]]]] = [] if title: specs.append((title, [(title, "TITLE")])) if title and season: specs.append((f"{title} {season}", [(title, "TITLE"), (season, "SEASON")])) if episode: specs.append((episode, [(episode, "EPISODE")])) if episode and resolution: specs.append((f"{episode} [{resolution}]", [(episode, "EPISODE"), (resolution, "RESOLUTION")])) if episode and resolution and source: specs.append( ( f"{episode} [{resolution}][{source}]", [(episode, "EPISODE"), (resolution, "RESOLUTION"), (source, "SOURCE")], ) ) if special: specs.append((special, [(special, "SPECIAL")])) if title and special: specs.append((f"{title} - {special}", [(title, "TITLE"), (special, "SPECIAL")])) augmented: List[Dict] = [] for text, spans in specs: if 2 <= len(text) <= max_chars: generated = char_item_from_spans(text, spans, "train_partial_augmentation") if generated is not None: augmented.append(generated) return augmented def build_permutation_augmented_item(item: Dict, rng: random.Random, max_chars: int) -> Optional[Dict]: entities = extract_entities_from_labels(item.get("tokens", []), item.get("labels", [])) available = [ entity for entity in ("GROUP", "TITLE", "SEASON", "EPISODE", "SPECIAL", "RESOLUTION", "SOURCE") if entities.get(entity) ] if not available: return None selected = [ entity for entity in available if rng.random() < entity_keep_probability(entity) ] if not selected: selected = [rng.choice(available)] if "TITLE" not in selected and "EPISODE" not in selected and "SPECIAL" not in selected: extras = [entity for entity in available if entity not in selected] selected.append(rng.choice(extras or available)) rng.shuffle(selected) separators = [" ", " - ", ".", "_", "]["] sep = rng.choice(separators) parts: List[str] = [] spans: List[tuple[str, str]] = [] for entity in selected: values = [value.strip() for value in entities.get(entity, []) if value.strip()] if not values: continue value = rng.choice(values) if entity in {"GROUP", "EPISODE", "SPECIAL", "RESOLUTION", "SOURCE"} and rng.random() < 0.35: parts.append(f"[{value}]") else: parts.append(value) spans.append((value, entity)) text = sep.join(parts).strip() if not (2 <= len(text) <= max_chars): return None return char_item_from_spans(text, spans, "train_permutation_augmentation") def build_special_augmented_item(data: List[Dict], rng: random.Random, max_chars: int) -> Optional[Dict]: base_titles: List[str] = [] for _ in range(min(16, len(data))): item = data[rng.randrange(len(data))] entities = extract_entities_from_labels(item.get("tokens", []), item.get("labels", [])) base_titles.extend(value.strip() for value in entities.get("TITLE", []) if 2 <= len(value.strip()) <= 80) title = rng.choice(base_titles) if base_titles else None special = rng.choice( [ f"Menu{rng.randint(1, 24):02d}", f"Menu {rng.randint(1, 24):02d}", f"BDMenu{rng.randint(1, 24):02d}", f"BD Menu{rng.randint(1, 24):02d}", f"Menu{rng.randint(1, 24):02d}-01", "Menu", f"OP{rng.randint(1, 6):02d}", f"ED E{rng.randint(1, 24):02d}", f"NCOP{rng.randint(1, 6):02d}", f"NCED{rng.randint(1, 6):02d}", f"CM{rng.randint(1, 12):02d}", f"PV{rng.randint(1, 12):02d}", ] ) if title and rng.random() < 0.55: text = f"{title} - {special}" spans = [(title, "TITLE"), (special, "SPECIAL")] else: text = special spans = [(special, "SPECIAL")] if len(text) > max_chars: return None return char_item_from_spans(text, spans, "train_special_augmentation") def process_memory_mb() -> Optional[float]: try: import psutil # type: ignore return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024) except Exception: pass if os.name == "nt": try: import ctypes from ctypes import wintypes class PROCESS_MEMORY_COUNTERS(ctypes.Structure): _fields_ = [ ("cb", wintypes.DWORD), ("PageFaultCount", wintypes.DWORD), ("PeakWorkingSetSize", ctypes.c_size_t), ("WorkingSetSize", ctypes.c_size_t), ("QuotaPeakPagedPoolUsage", ctypes.c_size_t), ("QuotaPagedPoolUsage", ctypes.c_size_t), ("QuotaPeakNonPagedPoolUsage", ctypes.c_size_t), ("QuotaNonPagedPoolUsage", ctypes.c_size_t), ("PagefileUsage", ctypes.c_size_t), ("PeakPagefileUsage", ctypes.c_size_t), ] counters = PROCESS_MEMORY_COUNTERS() counters.cb = ctypes.sizeof(counters) handle = ctypes.windll.kernel32.GetCurrentProcess() if ctypes.windll.psapi.GetProcessMemoryInfo(handle, ctypes.byref(counters), counters.cb): return float(counters.WorkingSetSize) / (1024 * 1024) except Exception: pass try: import resource # type: ignore usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss if sys.platform == "darwin": usage = usage / 1024 / 1024 else: usage = usage / 1024 return float(usage) except Exception: return None class NvmlSampler: """Tiny NVML binding for runtime GPU telemetry without adding dependencies.""" NVML_TEMPERATURE_GPU = 0 def __init__(self): self._lib = None self._handle = None self._available = False self._init() def _candidate_names(self) -> List[str]: names = [] found = ctypes_util.find_library("nvidia-ml") if found: names.append(found) if os.name == "nt": names.extend( [ os.path.join(os.environ.get("SystemRoot", r"C:\Windows"), "System32", "nvml.dll"), "nvml.dll", "nvidia-ml.dll", ] ) else: names.extend(["libnvidia-ml.so.1", "libnvidia-ml.so"]) return names def _init(self) -> None: for name in self._candidate_names(): try: lib = cdll.LoadLibrary(name) break except Exception: continue else: return class NVMLUtilization(Structure): _fields_ = [("gpu", c_uint), ("memory", c_uint)] class NVMLMemory(Structure): _fields_ = [("total", c_ulonglong), ("free", c_ulonglong), ("used", c_ulonglong)] self.NVMLUtilization = NVMLUtilization self.NVMLMemory = NVMLMemory lib.nvmlInit_v2.restype = c_int lib.nvmlDeviceGetHandleByIndex_v2.argtypes = [c_uint, POINTER(c_void_p)] lib.nvmlDeviceGetHandleByIndex_v2.restype = c_int handle = c_void_p() try: if lib.nvmlInit_v2() != 0: return if lib.nvmlDeviceGetHandleByIndex_v2(0, byref(handle)) != 0: return except Exception: return self._lib = lib self._handle = handle self._available = True @property def available(self) -> bool: return self._available def sample(self) -> Dict[str, Optional[float]]: if not self._available or self._lib is None or self._handle is None: return {} stats: Dict[str, Optional[float]] = {} try: util_rates = self.NVMLUtilization() self._lib.nvmlDeviceGetUtilizationRates.argtypes = [c_void_p, POINTER(self.NVMLUtilization)] if self._lib.nvmlDeviceGetUtilizationRates(self._handle, byref(util_rates)) == 0: stats["gpu_util_percent"] = float(util_rates.gpu) stats["gpu_memory_util_percent"] = float(util_rates.memory) except Exception: pass try: memory = self.NVMLMemory() self._lib.nvmlDeviceGetMemoryInfo.argtypes = [c_void_p, POINTER(self.NVMLMemory)] if self._lib.nvmlDeviceGetMemoryInfo(self._handle, byref(memory)) == 0: stats["gpu_memory_used_mb"] = float(memory.used) / (1024 * 1024) stats["gpu_memory_total_mb"] = float(memory.total) / (1024 * 1024) except Exception: pass try: temperature = c_uint() self._lib.nvmlDeviceGetTemperature.argtypes = [c_void_p, c_uint, POINTER(c_uint)] if self._lib.nvmlDeviceGetTemperature(self._handle, self.NVML_TEMPERATURE_GPU, byref(temperature)) == 0: stats["gpu_temperature_c"] = float(temperature.value) except Exception: pass try: power_mw = c_uint() self._lib.nvmlDeviceGetPowerUsage.argtypes = [c_void_p, POINTER(c_uint)] if self._lib.nvmlDeviceGetPowerUsage(self._handle, byref(power_mw)) == 0: stats["gpu_power_w"] = float(power_mw.value) / 1000.0 except Exception: pass return stats _NVML_SAMPLER: Optional[NvmlSampler] = None def query_nvml() -> Dict[str, Optional[float]]: global _NVML_SAMPLER if _NVML_SAMPLER is None: _NVML_SAMPLER = NvmlSampler() return _NVML_SAMPLER.sample() def query_nvidia_smi() -> Dict[str, Optional[float]]: try: result = subprocess.run( [ "nvidia-smi", "--query-gpu=utilization.gpu,memory.used,memory.total,power.draw", "--format=csv,noheader,nounits", ], check=False, capture_output=True, text=True, timeout=2, ) except Exception: return {} if result.returncode != 0 or not result.stdout.strip(): return {} first_line = result.stdout.strip().splitlines()[0] values = [part.strip() for part in first_line.split(",")] keys = ["gpu_util_percent", "gpu_memory_used_mb", "gpu_memory_total_mb", "gpu_power_w"] stats: Dict[str, Optional[float]] = {} for key, value in zip(keys, values): try: stats[key] = float(value) except ValueError: stats[key] = None return stats def cuda_memory_stats_mb() -> Dict[str, float]: if not torch.cuda.is_available(): return {} return { "cuda_allocated_mb": torch.cuda.memory_allocated() / (1024 * 1024), "cuda_reserved_mb": torch.cuda.memory_reserved() / (1024 * 1024), "cuda_max_allocated_mb": torch.cuda.max_memory_allocated() / (1024 * 1024), "cuda_max_reserved_mb": torch.cuda.max_memory_reserved() / (1024 * 1024), } def snapshot_perf_stats() -> Dict[str, Optional[float]]: stats: Dict[str, Optional[float]] = {} stats["process_rss_mb"] = process_memory_mb() stats.update(cuda_memory_stats_mb()) gpu_stats = query_nvml() if not gpu_stats: gpu_stats = query_nvidia_smi() stats.update(gpu_stats) return stats class TrainingPerfCallback(TrainerCallback): """Lightweight runtime telemetry for spotting data-pipeline starvation.""" def __init__(self, batch_size: int, sequence_length: int, log_steps: int, sample_interval: float): self.batch_size = batch_size self.sequence_length = sequence_length self.log_steps = max(0, log_steps) self.sample_interval = max(0.0, sample_interval) self.samples: List[Dict[str, Optional[float]]] = [] self.background_samples: List[Dict[str, Optional[float]]] = [] self._last_step = 0 self._last_time: Optional[float] = None self._start_time: Optional[float] = None self._training = False self._stop_event = threading.Event() self._thread: Optional[threading.Thread] = None def on_train_begin(self, args, state, control, **kwargs): now = time.perf_counter() self._start_time = now self._last_time = now self._last_step = int(state.global_step) self._training = True self._stop_event.clear() if self.sample_interval > 0: self._thread = threading.Thread(target=self._background_sample_loop, daemon=True) self._thread.start() def on_train_end(self, args, state, control, **kwargs): self._training = False self._stop_event.set() if self._thread is not None: self._thread.join(timeout=max(self.sample_interval * 2, 1.0)) self._thread = None def on_log(self, args, state, control, logs=None, **kwargs): if not self._training: return step = int(state.global_step) if self.log_steps <= 0 or step <= 0 or step % self.log_steps != 0: return self._record_sample(step) def on_step_end(self, args, state, control, **kwargs): if not self._training: return step = int(state.global_step) if self.log_steps <= 0 or step <= 0 or step % self.log_steps != 0: return if self.samples and self.samples[-1].get("step") == float(step): return self._record_sample(step) def _record_sample(self, step: int) -> None: if self.samples and self.samples[-1].get("step") == float(step): return now = time.perf_counter() last_time = self._last_time or now elapsed = max(now - last_time, 1e-9) step_delta = max(step - self._last_step, 0) samples_per_second = step_delta * self.batch_size / elapsed tokens_per_second = samples_per_second * self.sequence_length stats = snapshot_perf_stats() sample: Dict[str, Optional[float]] = { "step": float(step), "elapsed_seconds": now - (self._start_time or now), "window_seconds": elapsed, "steps_per_second": step_delta / elapsed, "samples_per_second": samples_per_second, "tokens_per_second": tokens_per_second, } sample.update(stats) self.samples.append(sample) print( " perf " f"step={step} " f"samples/s={samples_per_second:.1f} " f"tokens/s={tokens_per_second:.0f} " f"rss={stats.get('process_rss_mb') or 0:.0f}MB " f"cuda_alloc={stats.get('cuda_allocated_mb') or 0:.0f}MB " f"gpu_util={stats.get('gpu_util_percent') if stats.get('gpu_util_percent') is not None else 'n/a'}%" ) self._last_time = now self._last_step = step def _background_sample_loop(self) -> None: while not self._stop_event.wait(self.sample_interval): if not self._training: continue sample = snapshot_perf_stats() sample["elapsed_seconds"] = ( time.perf_counter() - self._start_time if self._start_time is not None else None ) self.background_samples.append(sample) def summary(self) -> Dict: numeric_keys = [ "samples_per_second", "tokens_per_second", "process_rss_mb", "cuda_max_allocated_mb", "gpu_util_percent", "gpu_memory_util_percent", "gpu_power_w", "gpu_temperature_c", ] summary: Dict[str, object] = { "sample_count": len(self.samples), "samples": self.samples, "background_sample_count": len(self.background_samples), "background_samples": self.background_samples, } sample_groups = { "step": self.samples, "background": self.background_samples, } for prefix, samples in sample_groups.items(): if not samples: continue for key in numeric_keys: values = [ float(sample[key]) for sample in samples if sample.get(key) is not None ] if values: summary[f"{prefix}_{key}_avg"] = sum(values) / len(values) summary[f"{prefix}_{key}_max"] = max(values) summary[f"{prefix}_{key}_min"] = min(values) if not self.samples and not self.background_samples: return summary for key in numeric_keys: values = [ float(sample[key]) for sample in self.samples if sample.get(key) is not None ] if values: summary[f"{key}_avg"] = sum(values) / len(values) summary[f"{key}_max"] = max(values) return summary class FastTokenClassificationCollator: """Stack already padded token-classification tensors without extra work.""" def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: batch = { key: torch.stack([feature[key] for feature in features]) for key in features[0].keys() } if "input_ids" in batch: batch["input_ids"] = batch["input_ids"].long() if "labels" in batch: batch["labels"] = batch["labels"].long() if "attention_mask" in batch: batch["attention_mask"] = batch["attention_mask"].to(dtype=torch.bool) return batch class OrderedTrainer(Trainer): """Trainer variant that preserves pre-shuffled order for virtual datasets.""" def _get_train_sampler(self, train_dataset=None): dataset = train_dataset if train_dataset is not None else self.train_dataset if getattr(dataset, "preserve_order", False): return SequentialSampler(dataset) return super()._get_train_sampler(train_dataset) def augment_training_data( data: List[Dict], partial_count: int, permutation_count: int, special_count: int, max_chars: int, seed: int, ) -> tuple[List[Dict], Dict]: """Append generated partial/permutation samples without modifying source JSONL.""" rng = random.Random(seed) augmented: List[Dict] = [] seen = { item.get("filename") or "".join(str(token) for token in item.get("tokens", [])) for item in data } partial_written = 0 if partial_count > 0: candidates: List[Dict] = [] attempts = 0 max_attempts = max(partial_count * 20, len(data)) while len(candidates) < partial_count * 4 and attempts < max_attempts: attempts += 1 candidates.extend(build_partial_augmented_item(rng.choice(data), max_chars)) rng.shuffle(candidates) for item in candidates: key = item["filename"] if key in seen: continue seen.add(key) augmented.append(item) partial_written += 1 if partial_written >= partial_count: break permutation_written = 0 attempts = 0 while permutation_written < permutation_count and attempts < max(permutation_count * 20, 100): attempts += 1 item = build_permutation_augmented_item(rng.choice(data), rng, max_chars) if item is None: continue key = item["filename"] if key in seen: continue seen.add(key) augmented.append(item) permutation_written += 1 special_written = 0 attempts = 0 while special_written < special_count and attempts < max(special_count * 20, 100): attempts += 1 item = build_special_augmented_item(data, rng, max_chars) if item is None: continue key = item["filename"] if key in seen: continue seen.add(key) augmented.append(item) special_written += 1 meta = { "partial_requested": partial_count, "partial_written": partial_written, "permutation_requested": permutation_count, "permutation_written": permutation_written, "special_requested": special_count, "special_written": special_written, "max_chars": max_chars, } return data + augmented, meta def normalize_field_value(field: str, value) -> Optional[str]: if value is None: return None if field in {"episode", "season"}: try: return str(int(value)) except (TypeError, ValueError): return str(value).strip().lower() text = str(value).strip() if field in {"resolution", "source"}: return text.lower().replace("_", "-") return " ".join(text.lower().split()) def parse_exact_metrics( samples: List[Dict], model: BertForTokenClassification, tokenizer: AnimeTokenizer, id2label: Dict[int, str], max_length: int, limit: Optional[int], constrain_bio: bool = True, ) -> Dict: """Evaluate end-to-end field exact match on filenames, not just token loss.""" fields = ["group", "title", "season", "episode", "resolution", "source", "special"] selected = [sample for sample in samples if sample.get("filename")] if limit is not None and limit > 0: selected = selected[:limit] counter: Counter = Counter() failures: List[Dict] = [] model.eval() for sample in selected: filename = sample["filename"] tokens, gold_labels = labels_for_tokenizer(sample, tokenizer) available = max(0, max_length - 2) tokens = tokens[:available] gold_labels = gold_labels[:available] gold = postprocess(tokens, gold_labels, tokenizer=tokenizer) gold_entities = {label.split("-", 1)[1] for label in gold_labels if label.startswith(("B-", "I-"))} for optional_field, entity in (("episode", "EPISODE"), ("season", "SEASON")): if entity not in gold_entities: gold[optional_field] = None pred = parse_filename( filename, model, tokenizer, id2label, max_length=max_length, debug=False, constrain_bio=constrain_bio, ) full_match = True field_errors: Dict[str, Dict[str, Optional[str]]] = {} for field in fields: gold_value = normalize_field_value(field, gold.get(field)) pred_value = normalize_field_value(field, pred.get(field)) counter[f"{field}_total"] += 1 if gold_value == pred_value: counter[f"{field}_correct"] += 1 else: full_match = False field_errors[field] = {"gold": gold_value, "pred": pred_value} counter["full_total"] += 1 if full_match: counter["full_correct"] += 1 elif len(failures) < 20: failures.append( { "filename": filename, "errors": field_errors, "gold": {field: gold.get(field) for field in fields}, "pred": {field: pred.get(field) for field in fields}, } ) field_accuracy = {} for field in fields: total = counter.get(f"{field}_total", 0) correct = counter.get(f"{field}_correct", 0) field_accuracy[field] = correct / total if total else 0.0 total = counter.get("full_total", 0) correct = counter.get("full_correct", 0) return { "constrain_bio": constrain_bio, "sample_count": total, "field_accuracy": field_accuracy, "field_correct": {field: counter.get(f"{field}_correct", 0) for field in fields}, "field_total": {field: counter.get(f"{field}_total", 0) for field in fields}, "full_match_accuracy": correct / total if total else 0.0, "full_match_correct": correct, "full_match_total": total, "failures": failures, } def parse_exact_metrics_all_modes( samples: List[Dict], model: BertForTokenClassification, tokenizer: AnimeTokenizer, id2label: Dict[int, str], max_length: int, limit: Optional[int], ) -> Dict: modes = { "model_only": {"constrain_bio": False}, "normalized_only": {"constrain_bio": True}, } return { "primary_metric": "normalized_only", "modes": { name: parse_exact_metrics( samples, model, tokenizer, id2label, max_length, limit, constrain_bio=settings["constrain_bio"], ) for name, settings in modes.items() }, } def remap_token_embeddings( model: BertForTokenClassification, old_vocab: Dict[str, int], new_vocab: Dict[str, int], pad_token_id: int, ) -> int: """ Replace the input embedding table for a changed vocabulary. resize_token_embeddings() preserves rows by numeric ID, which is unsafe when two tokenizers assign different tokens to the same ID. This remaps by token string and randomly initializes tokens that do not exist in the old vocab. """ old_embeddings = model.get_input_embeddings() old_weight = old_embeddings.weight.data embedding_dim = old_weight.shape[1] new_embeddings = torch.nn.Embedding( len(new_vocab), embedding_dim, padding_idx=pad_token_id, device=old_weight.device, dtype=old_weight.dtype, ) torch.nn.init.normal_( new_embeddings.weight, mean=0.0, std=getattr(model.config, "initializer_range", 0.02), ) if pad_token_id is not None and 0 <= pad_token_id < len(new_vocab): new_embeddings.weight.data[pad_token_id].zero_() copied = 0 for token, new_id in new_vocab.items(): old_id = old_vocab.get(token) if old_id is None or old_id >= old_weight.shape[0]: continue new_embeddings.weight.data[new_id].copy_(old_weight[old_id]) copied += 1 model.set_input_embeddings(new_embeddings) model.config.vocab_size = len(new_vocab) return copied def build_vocab_from_data(data: List[Dict], tokenizer: AnimeTokenizer, vocab_path: str, max_size: Optional[int] = None) -> None: token_lists: List[List[str]] = [] for item in data: tokens, _labels = labels_for_tokenizer(item, tokenizer) token_lists.append(tokens) tokenizer.build_vocab(token_lists, max_size=max_size) save_dir = os.path.dirname(vocab_path) or "." os.makedirs(save_dir, exist_ok=True) with open(vocab_path, "w", encoding="utf-8") as f: json.dump(tokenizer.get_vocab(), f, ensure_ascii=False, indent=2) def main(): args = parse_args() config = Config() if args.data_file is not None: config.data_file = args.data_file training_files = [config.data_file] + list(args.extra_data_file or []) tokenizer_variant = detect_tokenizer_variant_from_files(training_files, args.tokenizer, args.vocab_file) if args.save_dir is not None: config.save_dir = args.save_dir elif tokenizer_variant == "char": config.save_dir = "./checkpoints_char" if args.epochs is not None: config.num_epochs = args.epochs if args.batch_size is not None: config.batch_size = args.batch_size if args.learning_rate is not None: config.learning_rate = args.learning_rate if args.warmup_steps is not None: config.warmup_steps = args.warmup_steps if args.train_split is not None: config.train_split = args.train_split if args.num_workers is not None: config.num_workers = args.num_workers if args.max_seq_length is not None: config.max_seq_length = args.max_seq_length elif tokenizer_variant == "char": config.max_seq_length = max(config.max_seq_length, 128) if args.hidden_size is not None: config.hidden_size = args.hidden_size if args.num_hidden_layers is not None: config.num_hidden_layers = args.num_hidden_layers if args.num_attention_heads is not None: config.num_attention_heads = args.num_attention_heads if args.intermediate_size is not None: config.intermediate_size = args.intermediate_size if config.hidden_size % config.num_attention_heads != 0: raise ValueError( f"hidden_size ({config.hidden_size}) must be divisible by " f"num_attention_heads ({config.num_attention_heads})." ) config.max_position_embeddings = max(config.max_position_embeddings, config.max_seq_length) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) print("Loading dataset...") load_started_at = time.perf_counter() all_data, data_sources = load_training_sources( primary_data_file=config.data_file, extra_data_files=list(args.extra_data_file or []), extra_repeat=args.extra_data_repeat, limit=args.limit_samples, ) augmentation_metadata = { "partial_requested": 0, "partial_written": 0, "permutation_requested": 0, "permutation_written": 0, "special_requested": 0, "special_written": 0, "max_chars": args.augment_max_chars, } if args.augment_partial_samples or args.augment_permutation_samples or args.augment_special_samples: if tokenizer_variant != "char": raise ValueError("Training-time BIO span augmentation currently requires --tokenizer char.") all_data, augmentation_metadata = augment_training_data( data=all_data, partial_count=args.augment_partial_samples, permutation_count=args.augment_permutation_samples, special_count=args.augment_special_samples, max_chars=args.augment_max_chars, seed=args.seed + 1009, ) load_finished_at = time.perf_counter() if len(all_data) < 2: raise ValueError("Need at least two samples so train/eval split is non-empty.") if not args.no_shuffle: random.shuffle(all_data) validate_dataset_tokenizer_metadata(all_data, tokenizer_variant) # Load tokenizer print("Loading tokenizer...") vocab_path = resolve_vocab_path(config.data_file, tokenizer_variant, args.vocab_file) tokenizer = create_tokenizer(tokenizer_variant) if args.rebuild_vocab or not os.path.isfile(vocab_path): max_vocab_size = args.max_vocab_size if args.max_vocab_size is not None else config.vocab_size print(f" Building {tokenizer_variant} vocab: {vocab_path} (max_size={max_vocab_size})") build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=max_vocab_size) tokenizer = create_tokenizer(tokenizer_variant, vocab_file=vocab_path) print(f" Variant: {tokenizer_variant}") print(f" Vocab size: {tokenizer.vocab_size}") print(f" Max sequence length: {config.max_seq_length}") if torch.cuda.is_available() and not args.cpu: print(f" CUDA device: {torch.cuda.get_device_name(0)}") # Update config with actual vocab size config.vocab_size = tokenizer.vocab_size # Create model if args.init_model_dir: print(f"Loading model for fine-tuning: {args.init_model_dir}") model = BertForTokenClassification.from_pretrained(args.init_model_dir) init_tokenizer = load_tokenizer(args.init_model_dir, tokenizer_variant) init_vocab = init_tokenizer.get_vocab() embedding_size = model.get_input_embeddings().weight.shape[0] if len(init_vocab) != embedding_size: print( " WARNING: init checkpoint tokenizer vocab length does not match model embedding size " f"({len(init_vocab):,} vs {embedding_size:,}). Prefer a self-consistent checkpoint." ) init_variant = getattr(init_tokenizer, "tokenizer_variant", None) if init_variant != tokenizer_variant: print(f" WARNING: tokenizer variant changes during fine-tune: {init_variant} -> {tokenizer_variant}") print(" Token embeddings will be remapped by token string; unmatched tokens are newly initialized.") if model.config.vocab_size != config.vocab_size or init_vocab != tokenizer.get_vocab(): copied = remap_token_embeddings( model=model, old_vocab=init_vocab, new_vocab=tokenizer.get_vocab(), pad_token_id=tokenizer.pad_token_id, ) print( f" Remapped token embeddings: copied {copied:,}/{config.vocab_size:,} " f"tokens from init checkpoint" ) model.config.num_labels = config.num_labels model.config.id2label = config.id2label model.config.label2id = config.label2id else: print("Creating model...") model: BertForTokenClassification = create_model(config) total_params = print_model_summary(model) if total_params >= 5_000_000: print("WARNING: Model exceeds the historical 5M target; continuing because vocab size is configurable.") use_cpu = args.cpu or not torch.cuda.is_available() split_idx = int(len(all_data) * config.train_split) split_idx = max(1, min(len(all_data) - 1, split_idx)) train_data = all_data[:split_idx] eval_data = all_data[split_idx:] encode_started_at = time.perf_counter() if args.virtual_dataset_dir: virtual_dataset = ShardedEncodedDataset(args.virtual_dataset_dir) if virtual_dataset.max_length != config.max_seq_length: raise ValueError( f"Virtual dataset max_length {virtual_dataset.max_length} does not match " f"configured max_seq_length {config.max_seq_length}" ) train_dataset = virtual_dataset eval_dataset = EncodedAnimeDataset( data=eval_data, tokenizer=tokenizer, label2id=config.label2id, max_length=config.max_seq_length, device=torch.device("cpu"), apply_label_repairs=args.apply_label_repairs, ) dataset_mode = "virtual-sharded" if not args.keep_raw_dataset: train_data = [] all_data = [] gc.collect() elif args.lazy_dataset: train_dataset = AnimeItemsDataset( data=train_data, tokenizer=tokenizer, label2id=config.label2id, max_length=config.max_seq_length, apply_label_repairs=args.apply_label_repairs, ) eval_dataset = AnimeItemsDataset( data=eval_data, tokenizer=tokenizer, label2id=config.label2id, max_length=config.max_seq_length, apply_label_repairs=args.apply_label_repairs, ) dataset_mode = "lazy" else: encoded_device = torch.device(args.encoded_dataset_device) if encoded_device.type == "cuda" and use_cpu: raise ValueError("--encoded-dataset-device cuda cannot be used with CPU training.") if encoded_device.type == "cuda" and config.num_workers > 0: raise ValueError("--encoded-dataset-device cuda requires --num-workers 0 to avoid worker duplication.") train_dataset = EncodedAnimeDataset( data=train_data, tokenizer=tokenizer, label2id=config.label2id, max_length=config.max_seq_length, device=encoded_device, apply_label_repairs=args.apply_label_repairs, ) eval_dataset = EncodedAnimeDataset( data=eval_data, tokenizer=tokenizer, label2id=config.label2id, max_length=config.max_seq_length, device=encoded_device, apply_label_repairs=args.apply_label_repairs, ) dataset_mode = "encoded" if not args.keep_raw_dataset: train_data = [] all_data = [] gc.collect() encode_finished_at = time.perf_counter() print(f" Train samples: {len(train_dataset)}") print(f" Eval samples: {len(eval_dataset)}") print(f" Dataset mode: {dataset_mode}") print(f" Load time: {load_finished_at - load_started_at:.2f}s") print(f" Encode time: {encode_finished_at - encode_started_at:.2f}s") use_bf16 = bool(args.bf16 and not use_cpu) use_fp16 = bool((not use_cpu) and not use_bf16 and not args.no_mixed_precision) if use_cpu and args.no_mixed_precision: use_fp16 = False if torch.cuda.is_available() and not use_cpu and args.tf32: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True print(f" Device: {'CPU' if use_cpu else 'CUDA'}") if not use_cpu: print(f" Mixed precision: {'bf16' if use_bf16 else ('fp16' if use_fp16 else 'disabled')}") print(f" TF32: {'enabled' if args.tf32 else 'disabled'}") eval_save_strategy = "no" if args.no_periodic_eval else ("steps" if args.checkpoint_steps else "epoch") save_strategy = "steps" if args.checkpoint_steps else "epoch" dataloader_prefetch_factor = args.prefetch_factor if dataloader_prefetch_factor is None: dataloader_prefetch_factor = 4 if config.num_workers > 0 else None persistent_workers = bool(args.persistent_workers and config.num_workers > 0) dataloader_pin_memory = bool((not use_cpu) and not (not args.lazy_dataset and args.encoded_dataset_device == "cuda")) if args.lazy_dataset and config.num_workers == 0: print(" WARNING: lazy dataset mode is slower with zero workers; consider --num-workers 4+.") # Training arguments training_args = TrainingArguments( output_dir=config.save_dir, num_train_epochs=config.num_epochs, max_steps=args.max_steps, per_device_train_batch_size=config.batch_size, per_device_eval_batch_size=config.batch_size, eval_strategy=eval_save_strategy, save_strategy=save_strategy, eval_steps=args.checkpoint_steps if eval_save_strategy == "steps" else None, save_steps=args.checkpoint_steps, logging_steps=config.log_interval, learning_rate=config.learning_rate, weight_decay=config.weight_decay, warmup_steps=config.warmup_steps, gradient_accumulation_steps=args.gradient_accumulation_steps, use_cpu=use_cpu, report_to=["tensorboard"] if args.tensorboard else "none", save_total_limit=args.save_total_limit, load_best_model_at_end=not args.no_periodic_eval, metric_for_best_model="f1", greater_is_better=True, dataloader_num_workers=config.num_workers, dataloader_pin_memory=dataloader_pin_memory, dataloader_prefetch_factor=dataloader_prefetch_factor, dataloader_persistent_workers=persistent_workers, fp16=use_fp16, bf16=use_bf16, tf32=args.tf32 and not use_cpu, torch_compile=bool(args.torch_compile and not use_cpu), auto_find_batch_size=bool(args.auto_find_batch_size and not use_cpu), include_num_input_tokens_seen=True, ) # Data collator data_collator = FastTokenClassificationCollator() # Trainer perf_callback = TrainingPerfCallback( batch_size=config.batch_size, sequence_length=config.max_seq_length, log_steps=args.perf_log_steps, sample_interval=args.perf_sample_interval, ) trainer = OrderedTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator, compute_metrics=compute_metrics, callbacks=[perf_callback], ) # Train print("Starting training...") resume_from_checkpoint = args.resume_from_checkpoint if resume_from_checkpoint == "auto": resume_from_checkpoint = latest_checkpoint(config.save_dir) if resume_from_checkpoint: print(f"Resuming from latest checkpoint: {resume_from_checkpoint}") else: print("No checkpoint found; starting a fresh training run.") trainer.train(resume_from_checkpoint=resume_from_checkpoint) # Set proper label mappings in model config before saving model.config.id2label = config.id2label model.config.label2id = config.label2id model.config.tokenizer_variant = tokenizer_variant model.config.max_seq_length = config.max_seq_length # Save final model final_save_path = os.path.join(config.save_dir, "final") trainer.save_model(final_save_path) tokenizer.save_pretrained(final_save_path) metadata = { "experiment_name": args.experiment_name, "data_file": config.data_file, "data_sources": data_sources, "augmentation": augmentation_metadata, "dataset_mode": dataset_mode, "virtual_dataset_dir": args.virtual_dataset_dir, "apply_label_repairs": args.apply_label_repairs, "keep_raw_dataset": args.keep_raw_dataset, "tokenizer_variant": tokenizer_variant, "vocab_file": vocab_path, "vocab_size": tokenizer.vocab_size, "max_seq_length": config.max_seq_length, "hidden_size": config.hidden_size, "num_hidden_layers": config.num_hidden_layers, "num_attention_heads": config.num_attention_heads, "intermediate_size": config.intermediate_size, "train_samples": len(train_dataset), "eval_samples": len(eval_dataset), "load_seconds": load_finished_at - load_started_at, "encode_seconds": encode_finished_at - encode_started_at, "epochs": config.num_epochs, "max_steps": args.max_steps, "batch_size": config.batch_size, "learning_rate": config.learning_rate, "warmup_steps": config.warmup_steps, "seed": args.seed, "device": "cpu" if use_cpu else "cuda", "fp16": use_fp16, "gradient_accumulation_steps": training_args.gradient_accumulation_steps, "dataloader_num_workers": config.num_workers, "dataloader_prefetch_factor": dataloader_prefetch_factor, "dataloader_persistent_workers": persistent_workers, "dataloader_pin_memory": dataloader_pin_memory, "encoded_dataset_device": args.encoded_dataset_device if not args.lazy_dataset else None, "mixed_precision": "bf16" if use_bf16 else ("fp16" if use_fp16 else "none"), "tf32": bool(args.tf32 and not use_cpu), "torch_compile": bool(args.torch_compile and not use_cpu), "auto_find_batch_size": bool(args.auto_find_batch_size and not use_cpu), "perf_log_steps": args.perf_log_steps, "perf_sample_interval": args.perf_sample_interval, "periodic_eval": not args.no_periodic_eval, } with open(os.path.join(final_save_path, "run_metadata.json"), "w", encoding="utf-8") as f: json.dump(metadata, f, ensure_ascii=False, indent=2) print(f"Model saved to: {final_save_path}") with open(os.path.join(final_save_path, "perf_metrics.json"), "w", encoding="utf-8") as f: json.dump(perf_callback.summary(), f, ensure_ascii=False, indent=2) train_runtime = None if trainer.state.log_history: for entry in reversed(trainer.state.log_history): if "train_runtime" in entry: train_runtime = entry["train_runtime"] break if train_runtime is not None: print(f" Train runtime: {train_runtime:.2f}s") print(f" Total wall time (load+encode+train): {(load_finished_at - load_started_at) + (encode_finished_at - encode_started_at) + train_runtime:.2f}s") # Final evaluation print("\nFinal evaluation:") eval_results = trainer.evaluate() for key, value in eval_results.items(): print(f" {key}: {value:.4f}") with open(os.path.join(final_save_path, "trainer_eval_metrics.json"), "w", encoding="utf-8") as f: json.dump({key: float(value) for key, value in eval_results.items()}, f, ensure_ascii=False, indent=2) if args.parse_eval_limit != 0: parse_limit = args.parse_eval_limit if args.parse_eval_limit and args.parse_eval_limit > 0 else None parse_metrics = parse_exact_metrics_all_modes( eval_data, trainer.model, tokenizer, config.id2label, config.max_seq_length, parse_limit, ) with open(os.path.join(final_save_path, "parse_eval_metrics.json"), "w", encoding="utf-8") as f: json.dump(parse_metrics, f, ensure_ascii=False, indent=2) print("\nParse exact-match evaluation:") for mode_name, mode_metrics in parse_metrics["modes"].items(): print( f" {mode_name}: {mode_metrics['full_match_correct']}/" f"{mode_metrics['full_match_total']} ({mode_metrics['full_match_accuracy']:.4f})" ) if not args.no_case_eval: if args.case_eval_file and os.path.isfile(args.case_eval_file): from tools.evaluate_parser_cases import evaluate_case_modes case_metrics = evaluate_case_modes( model_dir=final_save_path, case_file=args.case_eval_file, tokenizer_variant=tokenizer_variant, max_length=config.max_seq_length, ) case_output = args.case_eval_output or os.path.join(final_save_path, "case_metrics.json") os.makedirs(os.path.dirname(case_output) or ".", exist_ok=True) with open(case_output, "w", encoding="utf-8") as f: json.dump(case_metrics, f, ensure_ascii=False, indent=2) print("\nFixed case regression evaluation:") for mode_name, mode_metrics in case_metrics["modes"].items(): print( f" {mode_name}: {mode_metrics['full_correct']}/" f"{mode_metrics['case_count']} ({mode_metrics['full_accuracy']:.4f})" ) primary = case_metrics["modes"][case_metrics["primary_metric"]] if primary["failures"]: print(f" primary failures: {len(primary['failures'])} (see {case_output})") elif args.case_eval_file: print(f"\nSkipping fixed case regression evaluation; file not found: {args.case_eval_file}") if __name__ == "__main__": main()