Token Classification
Transformers
ONNX
Safetensors
English
Japanese
Chinese
bert
anime
filename-parsing
Eval Results (legacy)
Instructions to use ModerRAS/AniFileBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ModerRAS/AniFileBERT with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="ModerRAS/AniFileBERT")

# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("ModerRAS/AniFileBERT")
model = AutoModelForTokenClassification.from_pretrained("ModerRAS/AniFileBERT")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Training script for anime filename parser. | |
| Trains a Tiny BERT model for token classification on synthetic anime filename data. | |
| Uses HuggingFace Trainer; trains on CUDA with fp16 when available, otherwise (or with --cpu) on CPU. | |
| Usage: | |
| python train.py | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import tempfile | |
| import argparse | |
| import random | |
| from collections import Counter | |
| from typing import Dict, List, Optional | |
| import numpy as np | |
| import torch | |
| from transformers import ( | |
| Trainer, | |
| TrainingArguments, | |
| DataCollatorForTokenClassification, | |
| BertForTokenClassification, | |
| ) | |
| from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score | |
| from config import Config | |
| from tokenizer import AnimeTokenizer, create_tokenizer, load_tokenizer | |
| from model import create_model, print_model_summary, count_parameters | |
| from dataset import AnimeDataset, labels_for_tokenizer | |
| from inference import parse_filename, postprocess | |
def compute_metrics(p):
    """Turn raw Trainer predictions into seqeval precision/recall/F1/accuracy.

    Args:
        p: A ``(predictions, labels)`` pair. ``predictions`` is reduced with
            ``argmax`` over axis 2 to obtain one label id per position.

    Returns:
        Dict with ``precision``, ``recall``, ``f1`` and ``accuracy`` computed
        by seqeval over the label-string sequences.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)
    id2label = Config().id2label

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        # Positions labelled -100 are the ignored index and carry no gold label.
        kept = [
            (predicted_id, gold_id)
            for predicted_id, gold_id in zip(predicted_row, gold_row)
            if gold_id != -100
        ]
        true_predictions.append([id2label[predicted_id] for predicted_id, _ in kept])
        true_labels.append([id2label[gold_id] for _, gold_id in kept])

    # seqeval takes (y_true, y_pred) as lists of label-string sequences.
    return {
        "precision": precision_score(true_labels, true_predictions),
        "recall": recall_score(true_labels, true_predictions),
        "f1": f1_score(true_labels, true_predictions),
        "accuracy": accuracy_score(true_labels, true_predictions),
    }
def parse_args() -> argparse.Namespace:
    """Build the training CLI and parse ``sys.argv``.

    Most options default to ``None`` so the caller can tell "not given" apart
    from an explicit value. ``--tensorboard`` / ``--no-tensorboard`` share one
    destination whose default is ``True``.

    Returns:
        The parsed argument namespace.
    """
    parser = argparse.ArgumentParser(description="Train anime filename parser")
    add = parser.add_argument

    # Data, tokenizer and output locations.
    add(
        "--tokenizer",
        choices=["regex", "char"],
        default=None,
        help="Tokenizer variant for A/B testing. Defaults to dataset metadata",
    )
    add("--data-file", default=None, help="Training JSONL file")
    add(
        "--vocab-file",
        default=None,
        help="Tokenizer vocab JSON. Defaults to data/vocab.json or data/vocab.char.json",
    )
    add("--save-dir", default=None, help="Checkpoint output directory")
    add("--init-model-dir", default=None, help="Optional checkpoint to fine-tune from")

    # Optimisation hyper-parameters.
    add("--epochs", type=float, default=None, help="Number of training epochs")
    add("--batch-size", type=int, default=None, help="Per-device train/eval batch size")
    add("--learning-rate", type=float, default=None, help="Learning rate")
    add("--warmup-steps", type=int, default=None, help="Warmup steps")
    add("--train-split", type=float, default=None, help="Train split ratio")
    add("--max-seq-length", type=int, default=None, help="Maximum sequence length")
    add("--seed", type=int, default=42, help="Random seed")

    # Dataset size and vocabulary construction.
    add(
        "--limit-samples",
        type=int,
        default=None,
        help="Use only the first N samples for quick A/B smoke runs",
    )
    add(
        "--rebuild-vocab",
        action="store_true",
        help="Rebuild vocab from the selected data file before training",
    )
    add(
        "--max-vocab-size",
        type=int,
        default=None,
        help="Optional vocab cap used with --rebuild-vocab",
    )

    # Checkpointing and runtime behaviour.
    add(
        "--checkpoint-steps",
        type=int,
        default=None,
        help="Save resumable checkpoints every N steps instead of only at epoch end",
    )
    add(
        "--save-total-limit",
        type=int,
        default=2,
        help="Maximum number of checkpoints to keep",
    )
    add(
        "--gradient-accumulation-steps",
        type=int,
        default=1,
        help="Accumulate gradients across this many steps",
    )
    add(
        "--num-workers",
        type=int,
        default=None,
        help="DataLoader worker count. Defaults to config.num_workers",
    )
    add("--cpu", action="store_true", help="Force CPU training")
    add("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
    add(
        "--resume-from-checkpoint",
        default=None,
        help="Resume Trainer state from a checkpoint directory, or 'auto' for the latest checkpoint",
    )

    # Logging and post-training evaluation.
    add(
        "--tensorboard",
        dest="tensorboard",
        action="store_true",
        help="Log metrics to TensorBoard in addition to stdout/checkpoints",
    )
    add(
        "--no-tensorboard",
        dest="tensorboard",
        action="store_false",
        help="Disable TensorBoard logging",
    )
    add(
        "--experiment-name",
        default=None,
        help="Optional experiment name written to run_metadata.json",
    )
    add(
        "--parse-eval-limit",
        type=int,
        default=512,
        help="Run field exact-match evaluation on up to N eval samples after training; 0 disables it",
    )

    # Model architecture overrides.
    add("--hidden-size", type=int, default=None, help="Override BERT hidden size")
    add("--num-hidden-layers", type=int, default=None, help="Override BERT layer count")
    add("--num-attention-heads", type=int, default=None, help="Override BERT attention heads")
    add("--intermediate-size", type=int, default=None, help="Override BERT FFN intermediate size")

    # Both tensorboard flags write the same dest; logging is on unless disabled.
    parser.set_defaults(tensorboard=True)
    return parser.parse_args()
def detect_tokenizer_variant(
    data_file: str,
    explicit_variant: Optional[str],
    explicit_vocab_path: Optional[str],
    sample_size: int = 256,
) -> str:
    """Decide which tokenizer variant to use for ``data_file``.

    Resolution order:
      1. ``explicit_variant`` when given (the data file is not opened).
      2. The single ``tokenizer_variant`` value found in the sampled rows.
      3. ``"char"`` when the explicit vocab file name contains ``.char``.
      4. ``"char"`` when at least 95% of sampled rows have ``tokens`` equal to
         the characters of ``filename``.
      5. ``"regex"`` otherwise.

    Args:
        data_file: Path to the JSONL dataset.
        explicit_variant: Variant forced by the caller, if any.
        explicit_vocab_path: Vocab path given by the caller, if any.
        sample_size: Maximum number of non-blank rows to inspect.

    Returns:
        The selected tokenizer variant name.

    Raises:
        ValueError: If the sampled rows declare more than one variant.
    """
    if explicit_variant:
        return explicit_variant

    declared_variants = set()
    char_rows = 0
    rows_seen = 0
    with open(data_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if rows_seen >= sample_size:
                break
            text = raw_line.strip()
            if not text:
                # Blank lines do not count towards the sample.
                continue
            row = json.loads(text)
            rows_seen += 1

            declared = row.get("tokenizer_variant")
            if declared:
                declared_variants.add(declared)

            row_tokens = row.get("tokens", [])
            name = row.get("filename")
            # A row is "char-like" when its tokens are exactly the filename's characters.
            if name is not None and row_tokens == list(name):
                char_rows += 1

    if len(declared_variants) > 1:
        raise ValueError(f"Mixed tokenizer_variant values in {data_file}: {sorted(declared_variants)}")
    if declared_variants:
        (only_variant,) = declared_variants
        return only_variant

    vocab_name = os.path.basename(explicit_vocab_path).lower() if explicit_vocab_path else ""
    if ".char" in vocab_name:
        return "char"

    mostly_char_rows = rows_seen > 0 and char_rows / rows_seen >= 0.95
    return "char" if mostly_char_rows else "regex"
def resolve_vocab_path(data_file: str, tokenizer_variant: str, explicit_path: Optional[str]) -> str:
    """Return the vocab file path to use.

    An explicit path wins. Otherwise the vocab is expected next to the data
    file: ``vocab.json`` for the ``regex`` variant and ``vocab.char.json`` for
    any other variant.
    """
    if not explicit_path:
        vocab_name = {"regex": "vocab.json"}.get(tokenizer_variant, "vocab.char.json")
        data_dir = os.path.dirname(data_file)
        return os.path.join(data_dir, vocab_name)
    return explicit_path
def latest_checkpoint(save_dir: str) -> Optional[str]:
    """Find the ``checkpoint-<step>`` sub-directory with the highest step.

    Args:
        save_dir: Directory expected to contain checkpoint sub-directories.

    Returns:
        Path of the checkpoint directory with the largest numeric suffix, or
        ``None`` when ``save_dir`` is not a directory or holds no such entry.
    """
    if not os.path.isdir(save_dir):
        return None

    found = []
    for entry in os.listdir(save_dir):
        candidate = os.path.join(save_dir, entry)
        if not (entry.startswith("checkpoint-") and os.path.isdir(candidate)):
            continue
        # The step is whatever follows the last dash; skip non-numeric suffixes.
        suffix = entry.rsplit("-", 1)[-1]
        try:
            found.append((int(suffix), candidate))
        except ValueError:
            continue

    return max(found)[1] if found else None
def validate_dataset_tokenizer_metadata(data: List[Dict], tokenizer_variant: str) -> None:
    """Ensure rows that declare a ``tokenizer_variant`` agree with the selection.

    Rows without a (truthy) ``tokenizer_variant`` are ignored, so a dataset
    with no metadata always passes.

    Raises:
        ValueError: If the declared variants are anything other than exactly
            ``{tokenizer_variant}``.
    """
    declared = set()
    for row in data:
        value = row.get("tokenizer_variant")
        if value:
            declared.add(value)

    if not declared or declared == {tokenizer_variant}:
        return

    raise ValueError(
        f"Dataset tokenizer_variant {sorted(declared)} does not match selected tokenizer "
        f"'{tokenizer_variant}'. Pass --tokenizer explicitly only when this is intentional."
    )
def load_jsonl(data_file: str, limit: Optional[int] = None) -> List[Dict]:
    """Load rows from a JSONL file, optionally stopping after ``limit`` rows.

    Blank lines are skipped and do not count towards the limit. Reading stops
    as soon as the limit is reached, so later lines are never parsed.

    Args:
        data_file: Path to a UTF-8 JSONL file.
        limit: Maximum number of rows to return; ``None`` reads the whole file.
            A limit of zero or less yields an empty list.

    Returns:
        The parsed rows in file order.
    """
    rows: List[Dict] = []
    with open(data_file, "r", encoding="utf-8") as f:
        for line in f:
            # Check the cap before parsing so limit <= 0 returns nothing
            # instead of always letting the first row through.
            if limit is not None and len(rows) >= limit:
                break
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    return rows
def normalize_field_value(field: str, value) -> Optional[str]:
    """Canonicalise a parsed field value so gold and prediction compare fairly.

    * ``None`` stays ``None``.
    * ``episode`` / ``season``: the integer value as a string when ``int()``
      accepts it, otherwise the stripped, lower-cased text.
    * ``resolution`` / ``source``: stripped, lower-cased, ``_`` replaced by ``-``.
    * Any other field: stripped, lower-cased, whitespace runs collapsed to one space.
    """
    if value is None:
        return None

    if field in {"episode", "season"}:
        try:
            number = int(value)
        except (TypeError, ValueError):
            return str(value).strip().lower()
        return str(number)

    lowered = str(value).strip().lower()
    if field in {"resolution", "source"}:
        return lowered.replace("_", "-")
    return " ".join(lowered.split())
def parse_exact_metrics(
    samples: List[Dict],
    model: BertForTokenClassification,
    tokenizer: AnimeTokenizer,
    id2label: Dict[int, str],
    max_length: int,
    limit: Optional[int],
) -> Dict:
    """Measure end-to-end field exact match between gold labels and predictions.

    For every sample that has a ``filename``, the gold fields are produced by
    running ``postprocess`` over the gold labels and the predicted fields by
    ``parse_filename``; both are compared per field after normalisation.

    Args:
        samples: Dataset rows; rows without a truthy ``filename`` are skipped.
        model: Model passed to ``parse_filename``; switched to eval mode here.
        tokenizer: Tokenizer used for both gold and predicted parses.
        id2label: Label id to label string mapping passed to ``parse_filename``.
        max_length: Maximum sequence length; gold tokens are cut to
            ``max_length - 2`` (presumably leaving room for two special
            tokens, TODO confirm).
        limit: Evaluate at most this many samples when positive; ``None`` or a
            non-positive value evaluates all of them.

    Returns:
        Dict with per-field accuracy/correct/total, full-match
        accuracy/correct/total, ``sample_count`` and up to 20 ``failures``.
    """
    fields = ["group", "title", "season", "episode", "resolution", "source", "special"]
    # Gold value is blanked for these fields unless the gold labels contain the entity.
    label_gated_fields = (("episode", "EPISODE"), ("season", "SEASON"))
    max_reported_failures = 20

    selected = [sample for sample in samples if sample.get("filename")]
    if limit is not None and limit > 0:
        selected = selected[:limit]

    field_hits = dict.fromkeys(fields, 0)
    field_seen = dict.fromkeys(fields, 0)
    full_hits = 0
    full_seen = 0
    failures: List[Dict] = []

    model.eval()
    for sample in selected:
        filename = sample["filename"]

        # Gold side: rebuild the fields from the reference labels.
        tokens, gold_labels = labels_for_tokenizer(sample, tokenizer)
        token_budget = max(0, max_length - 2)
        tokens = tokens[:token_budget]
        gold_labels = gold_labels[:token_budget]
        gold = postprocess(tokens, gold_labels, tokenizer=tokenizer, filename=filename, use_rules=True)
        gold_entities = {
            label.split("-", 1)[1]
            for label in gold_labels
            if label.startswith(("B-", "I-"))
        }
        for gated_field, entity in label_gated_fields:
            if entity not in gold_entities:
                gold[gated_field] = None

        # Predicted side: full inference pipeline on the raw filename.
        pred = parse_filename(
            filename,
            model,
            tokenizer,
            id2label,
            max_length=max_length,
            debug=False,
            use_rules=True,
            constrain_bio=True,
        )

        mismatches: Dict[str, Dict[str, Optional[str]]] = {}
        for field in fields:
            gold_value = normalize_field_value(field, gold.get(field))
            pred_value = normalize_field_value(field, pred.get(field))
            field_seen[field] += 1
            if gold_value == pred_value:
                field_hits[field] += 1
            else:
                mismatches[field] = {"gold": gold_value, "pred": pred_value}

        full_seen += 1
        if not mismatches:
            full_hits += 1
        elif len(failures) < max_reported_failures:
            failures.append(
                {
                    "filename": filename,
                    "errors": mismatches,
                    "gold": {field: gold.get(field) for field in fields},
                    "pred": {field: pred.get(field) for field in fields},
                }
            )

    field_accuracy = {
        field: (field_hits[field] / field_seen[field] if field_seen[field] else 0.0)
        for field in fields
    }
    return {
        "sample_count": full_seen,
        "field_accuracy": field_accuracy,
        "field_correct": {field: field_hits[field] for field in fields},
        "field_total": {field: field_seen[field] for field in fields},
        "full_match_accuracy": full_hits / full_seen if full_seen else 0.0,
        "full_match_correct": full_hits,
        "full_match_total": full_seen,
        "failures": failures,
    }
def remap_token_embeddings(
    model: BertForTokenClassification,
    old_vocab: Dict[str, int],
    new_vocab: Dict[str, int],
    pad_token_id: int,
) -> int:
    """Swap the model's input embedding table for one laid out by ``new_vocab``.

    ``resize_token_embeddings()`` keeps rows by numeric id, which is wrong when
    the old and new tokenizers give the same id to different tokens. Here each
    row is matched by token string instead; tokens missing from ``old_vocab``
    keep a fresh random initialisation.

    Args:
        model: Model whose input embeddings are replaced in place.
        old_vocab: Token to id mapping the current embeddings were trained with.
        new_vocab: Token to id mapping for the new embedding table.
        pad_token_id: Id used as ``padding_idx`` of the new table.

    Returns:
        Number of tokens whose embedding row was copied from the old table.
    """
    source_weight = model.get_input_embeddings().weight.data
    source_rows = source_weight.shape[0]
    width = source_weight.shape[1]
    new_size = len(new_vocab)

    replacement = torch.nn.Embedding(
        new_size,
        width,
        padding_idx=pad_token_id,
        device=source_weight.device,
        dtype=source_weight.dtype,
    )
    init_std = getattr(model.config, "initializer_range", 0.02)
    torch.nn.init.normal_(replacement.weight, mean=0.0, std=init_std)

    target_weight = replacement.weight.data
    # The random init overwrote the padding row; reset it to zeros.
    if pad_token_id is not None and 0 <= pad_token_id < new_size:
        target_weight[pad_token_id].zero_()

    copied = 0
    for token, new_id in new_vocab.items():
        old_id = old_vocab.get(token)
        # Only reuse rows that actually exist in the old table.
        if old_id is not None and old_id < source_rows:
            target_weight[new_id].copy_(source_weight[old_id])
            copied += 1

    model.set_input_embeddings(replacement)
    model.config.vocab_size = new_size
    return copied
def build_vocab_from_data(data: List[Dict], tokenizer: AnimeTokenizer, vocab_path: str,
                          max_size: Optional[int] = None) -> None:
    """Build the tokenizer vocab from ``data`` and write it to ``vocab_path`` as JSON.

    Args:
        data: Dataset rows; each is tokenized through ``labels_for_tokenizer``.
        tokenizer: Tokenizer whose vocab is (re)built in place.
        vocab_path: Output JSON path; its parent directory is created if needed.
        max_size: Optional cap forwarded to ``tokenizer.build_vocab``.
    """
    # Only the token half of each (tokens, labels) pair is needed for the vocab.
    token_lists: List[List[str]] = [labels_for_tokenizer(row, tokenizer)[0] for row in data]
    tokenizer.build_vocab(token_lists, max_size=max_size)

    target_dir = os.path.dirname(vocab_path) or "."
    os.makedirs(target_dir, exist_ok=True)
    with open(vocab_path, "w", encoding="utf-8") as out:
        json.dump(tokenizer.get_vocab(), out, ensure_ascii=False, indent=2)
def _apply_cli_overrides(config: Config, args: argparse.Namespace, tokenizer_variant: str) -> None:
    """Copy CLI values that were actually provided onto ``config``.

    Also applies the ``char``-variant defaults (separate checkpoint directory,
    sequence length of at least 128) and keeps ``max_position_embeddings`` at
    least as large as ``max_seq_length``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``.
    """
    if args.save_dir is not None:
        config.save_dir = args.save_dir
    elif tokenizer_variant == "char":
        config.save_dir = "./checkpoints_char"

    # (argparse attribute, config attribute): the CLI value wins when given.
    simple_overrides = (
        ("epochs", "num_epochs"),
        ("batch_size", "batch_size"),
        ("learning_rate", "learning_rate"),
        ("warmup_steps", "warmup_steps"),
        ("train_split", "train_split"),
        ("num_workers", "num_workers"),
        ("hidden_size", "hidden_size"),
        ("num_hidden_layers", "num_hidden_layers"),
        ("num_attention_heads", "num_attention_heads"),
        ("intermediate_size", "intermediate_size"),
    )
    for arg_name, config_name in simple_overrides:
        value = getattr(args, arg_name)
        if value is not None:
            setattr(config, config_name, value)

    if args.max_seq_length is not None:
        config.max_seq_length = args.max_seq_length
    elif tokenizer_variant == "char":
        config.max_seq_length = max(config.max_seq_length, 128)

    if config.hidden_size % config.num_attention_heads != 0:
        raise ValueError(
            f"hidden_size ({config.hidden_size}) must be divisible by "
            f"num_attention_heads ({config.num_attention_heads})."
        )
    config.max_position_embeddings = max(config.max_position_embeddings, config.max_seq_length)


def _write_jsonl(path: str, rows: List[Dict]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSONL, one object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def main():
    """Train the filename parser end to end.

    Steps: parse CLI arguments and merge them into ``Config``; seed the RNGs;
    load, shuffle and validate the dataset; load or build the tokenizer vocab;
    create the model (or load and remap a fine-tuning checkpoint); split into
    train/eval; train with the HuggingFace ``Trainer``; save the final model,
    tokenizer and ``run_metadata.json``; then write the Trainer eval metrics
    and, unless disabled, the field exact-match metrics.

    Raises:
        ValueError: If fewer than two samples are loaded, or the configuration
            is inconsistent (see ``_apply_cli_overrides``).
    """
    import shutil  # only needed to clean up the temporary split directory

    args = parse_args()
    config = Config()
    if args.data_file is not None:
        config.data_file = args.data_file
    tokenizer_variant = detect_tokenizer_variant(config.data_file, args.tokenizer, args.vocab_file)
    _apply_cli_overrides(config, args, tokenizer_variant)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print("Loading dataset...")
    all_data = load_jsonl(config.data_file, args.limit_samples)
    if len(all_data) < 2:
        raise ValueError("Need at least two samples so train/eval split is non-empty.")
    if not args.no_shuffle:
        random.shuffle(all_data)
    validate_dataset_tokenizer_metadata(all_data, tokenizer_variant)

    # Load tokenizer, building the vocab first when requested or missing.
    print("Loading tokenizer...")
    vocab_path = resolve_vocab_path(config.data_file, tokenizer_variant, args.vocab_file)
    tokenizer = create_tokenizer(tokenizer_variant)
    if args.rebuild_vocab or not os.path.isfile(vocab_path):
        max_vocab_size = args.max_vocab_size if args.max_vocab_size is not None else config.vocab_size
        print(f" Building {tokenizer_variant} vocab: {vocab_path} (max_size={max_vocab_size})")
        build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=max_vocab_size)
    tokenizer = create_tokenizer(tokenizer_variant, vocab_file=vocab_path)
    print(f" Variant: {tokenizer_variant}")
    print(f" Vocab size: {tokenizer.vocab_size}")
    print(f" Max sequence length: {config.max_seq_length}")
    if torch.cuda.is_available() and not args.cpu:
        print(f" CUDA device: {torch.cuda.get_device_name(0)}")
        print(" Mixed precision: fp16")

    # Update config with actual vocab size
    config.vocab_size = tokenizer.vocab_size

    # Create model
    if args.init_model_dir:
        print(f"Loading model for fine-tuning: {args.init_model_dir}")
        model = BertForTokenClassification.from_pretrained(args.init_model_dir)
        init_tokenizer = load_tokenizer(args.init_model_dir, tokenizer_variant)
        init_vocab = init_tokenizer.get_vocab()
        embedding_size = model.get_input_embeddings().weight.shape[0]
        if len(init_vocab) != embedding_size:
            print(
                " WARNING: init checkpoint tokenizer vocab length does not match model embedding size "
                f"({len(init_vocab):,} vs {embedding_size:,}). Prefer a self-consistent checkpoint."
            )
        init_variant = getattr(init_tokenizer, "tokenizer_variant", None)
        if init_variant != tokenizer_variant:
            print(f" WARNING: tokenizer variant changes during fine-tune: {init_variant} -> {tokenizer_variant}")
            print(" Token embeddings will be remapped by token string; unmatched tokens are newly initialized.")
        # Remap whenever the vocab differs in size or in token-to-id layout.
        if model.config.vocab_size != config.vocab_size or init_vocab != tokenizer.get_vocab():
            copied = remap_token_embeddings(
                model=model,
                old_vocab=init_vocab,
                new_vocab=tokenizer.get_vocab(),
                pad_token_id=tokenizer.pad_token_id,
            )
            print(
                f" Remapped token embeddings: copied {copied:,}/{config.vocab_size:,} "
                f"tokens from init checkpoint"
            )
        model.config.num_labels = config.num_labels
        model.config.id2label = config.id2label
        model.config.label2id = config.label2id
    else:
        print("Creating model...")
        model: BertForTokenClassification = create_model(config)
    total_params = print_model_summary(model)
    if total_params >= 5_000_000:
        print("WARNING: Model exceeds the historical 5M target; continuing because vocab size is configurable.")

    # Clamp the split so both halves contain at least one sample.
    split_idx = int(len(all_data) * config.train_split)
    split_idx = max(1, min(len(all_data) - 1, split_idx))
    train_data = all_data[:split_idx]
    eval_data = all_data[split_idx:]

    # Write the split files into a directory unique to this run. Fixed names in
    # the shared temp dir would let concurrent runs overwrite each other's data.
    split_dir = tempfile.mkdtemp(prefix="anime_split_")
    try:
        train_file = os.path.join(split_dir, "anime_train.jsonl")
        eval_file = os.path.join(split_dir, "anime_eval.jsonl")
        _write_jsonl(train_file, train_data)
        _write_jsonl(eval_file, eval_data)

        train_dataset = AnimeDataset(
            data_path=train_file,
            tokenizer=tokenizer,
            label2id=config.label2id,
            max_length=config.max_seq_length,
        )
        eval_dataset = AnimeDataset(
            data_path=eval_file,
            tokenizer=tokenizer,
            label2id=config.label2id,
            max_length=config.max_seq_length,
        )
        print(f" Train samples: {len(train_dataset)}")
        print(f" Eval samples: {len(eval_dataset)}")

        use_cpu = args.cpu or not torch.cuda.is_available()
        use_fp16 = not use_cpu
        print(f" Device: {'CPU' if use_cpu else 'CUDA'}")
        eval_save_strategy = "steps" if args.checkpoint_steps else "epoch"

        # Training arguments
        training_args = TrainingArguments(
            output_dir=config.save_dir,
            num_train_epochs=config.num_epochs,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=config.batch_size,
            eval_strategy=eval_save_strategy,
            save_strategy=eval_save_strategy,
            eval_steps=args.checkpoint_steps,
            save_steps=args.checkpoint_steps,
            logging_steps=config.log_interval,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            warmup_steps=config.warmup_steps,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            use_cpu=use_cpu,
            report_to=["tensorboard"] if args.tensorboard else "none",
            save_total_limit=args.save_total_limit,
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            greater_is_better=True,
            dataloader_num_workers=config.num_workers,
            dataloader_pin_memory=not use_cpu,
            fp16=use_fp16,
            # Forward the CLI seed; otherwise the Trainer re-seeds with its own
            # default and the seed recorded in run_metadata.json is misleading.
            seed=args.seed,
        )

        # Data collator
        data_collator = DataCollatorForTokenClassification(tokenizer)

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )

        # Train
        print("Starting training...")
        resume_from_checkpoint = args.resume_from_checkpoint
        if resume_from_checkpoint == "auto":
            resume_from_checkpoint = latest_checkpoint(config.save_dir)
            if resume_from_checkpoint:
                print(f"Resuming from latest checkpoint: {resume_from_checkpoint}")
            else:
                print("No checkpoint found; starting a fresh training run.")
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

        # Set proper label mappings in model config before saving
        model.config.id2label = config.id2label
        model.config.label2id = config.label2id
        model.config.tokenizer_variant = tokenizer_variant
        model.config.max_seq_length = config.max_seq_length

        # Save final model
        final_save_path = os.path.join(config.save_dir, "final")
        trainer.save_model(final_save_path)
        tokenizer.save_pretrained(final_save_path)
        metadata = {
            "experiment_name": args.experiment_name,
            "data_file": config.data_file,
            "tokenizer_variant": tokenizer_variant,
            "vocab_file": vocab_path,
            "vocab_size": tokenizer.vocab_size,
            "max_seq_length": config.max_seq_length,
            "hidden_size": config.hidden_size,
            "num_hidden_layers": config.num_hidden_layers,
            "num_attention_heads": config.num_attention_heads,
            "intermediate_size": config.intermediate_size,
            "train_samples": len(train_dataset),
            "eval_samples": len(eval_dataset),
            "epochs": config.num_epochs,
            "batch_size": config.batch_size,
            "learning_rate": config.learning_rate,
            "warmup_steps": config.warmup_steps,
            "seed": args.seed,
            "device": "cpu" if use_cpu else "cuda",
            "fp16": use_fp16,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            "dataloader_num_workers": config.num_workers,
        }
        with open(os.path.join(final_save_path, "run_metadata.json"), "w", encoding="utf-8") as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"Model saved to: {final_save_path}")

        # Final evaluation
        print("\nFinal evaluation:")
        eval_results = trainer.evaluate()
        for key, value in eval_results.items():
            print(f" {key}: {value:.4f}")
        with open(os.path.join(final_save_path, "trainer_eval_metrics.json"), "w", encoding="utf-8") as f:
            json.dump({key: float(value) for key, value in eval_results.items()}, f, ensure_ascii=False, indent=2)

        if args.parse_eval_limit != 0:
            # A negative limit means "evaluate every eval sample".
            parse_limit = args.parse_eval_limit if args.parse_eval_limit and args.parse_eval_limit > 0 else None
            parse_metrics = parse_exact_metrics(
                eval_data,
                trainer.model,
                tokenizer,
                config.id2label,
                config.max_seq_length,
                parse_limit,
            )
            with open(os.path.join(final_save_path, "parse_eval_metrics.json"), "w", encoding="utf-8") as f:
                json.dump(parse_metrics, f, ensure_ascii=False, indent=2)
            print("\nParse exact-match evaluation:")
            print(
                f" full_match: {parse_metrics['full_match_correct']}/"
                f"{parse_metrics['full_match_total']} ({parse_metrics['full_match_accuracy']:.4f})"
            )
            for field, accuracy in parse_metrics["field_accuracy"].items():
                correct = parse_metrics["field_correct"][field]
                total = parse_metrics["field_total"][field]
                print(f" {field}: {correct}/{total} ({accuracy:.4f})")
    finally:
        # The split files are only scratch input for the datasets; remove them
        # once training and evaluation are over (or have failed).
        shutil.rmtree(split_dir, ignore_errors=True)
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()