"""Fine-tune xlm-roberta-base for token-level NER (BIO tagging). Tag set (9): O, B-PER, I-PER, B-LOC, I-LOC, B-ORG, I-ORG, B-DATE, I-DATE Inputs: data/processed/ner/ (HuggingFace DatasetDict; word-tokenized) data/processed/ner/labels.json Outputs: models/ner_model/ (best model + tokenizer) models/ner_model/eval_results.json (per-entity F1 from seqeval) models/ner_model/runs/ (training checkpoints + logs) Implementation notes: - The raw tokens come pre-tokenized at WORD level (whitespace-split). XLM-R uses SentencePiece subwords, so we re-tokenize with `is_split_into_words=True` and align labels to subwords: * first subword of each word -> word's tag * inner subwords -> -100 (ignored by the loss) * special tokens -> -100 This matches the standard HuggingFace NER recipe. - Metrics use seqeval (entity-level): a span counts as correct only if BOTH boundary AND type match — much stricter than token-level accuracy. GPU notes for GTX 1650 (3.6 GB VRAM): same recipe — fp16 + gradient_checkpointing + batch=8 with grad_accum=2. Usage: python src/train_ner.py python src/train_ner.py --epochs 3 python src/train_ner.py --quick """ from __future__ import annotations import os # Reduce CUDA memory fragmentation on tight-VRAM GPUs (must precede torch import). 
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import argparse
import inspect
import json
import shutil
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
from datasets import load_from_disk
from seqeval.metrics import (
    classification_report as seq_classification_report,
    f1_score as seq_f1,
    precision_score as seq_p,
    recall_score as seq_r,
)
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)

PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJECT_ROOT / "data" / "processed" / "ner"
LABELS_FILE = DATA_DIR / "labels.json"
OUT_DIR = PROJECT_ROOT / "models" / "ner_model"
RUNS_DIR = OUT_DIR / "runs"

# Spec called for xlm-roberta-base, but at 270M params it does not fit in
# the GTX 1650's 3.6 GB VRAM during optimizer step — Adam needs ~2.2 GB of
# state and even Adafactor allocates a ~770 MB temp tensor for grad**2 over
# the embedding matrix. distilbert-base-multilingual-cased is the standard
# substitute: 134M params, trained on 104 languages (AR/EN/FR included),
# typically within 1-3% F1 of XLM-R on classification.
MODEL_NAME = "distilbert-base-multilingual-cased"
MAX_LENGTH = 128
SEED = 42

# Label value that the token-classification loss skips.
_IGNORE_INDEX = -100
# Single source of truth for gradient accumulation (banner, Trainer, report).
_GRAD_ACCUM_STEPS = 2
# Row caps applied to each split by --quick.
_QUICK_TRAIN_ROWS = 500
_QUICK_EVAL_ROWS = 120


def _trainer_with_tokenizer(tokenizer, **kwargs: Any) -> Trainer:
    """Construct a Trainer, passing the tokenizer under whichever name is accepted.

    Uses `processing_class` when `Trainer.__init__` declares it, otherwise
    `tokenizer`. If the signature declares neither, the tokenizer is not
    passed at all.
    """
    params = inspect.signature(Trainer.__init__).parameters
    if "processing_class" in params:
        kwargs["processing_class"] = tokenizer
    elif "tokenizer" in params:
        kwargs["tokenizer"] = tokenizer
    return Trainer(**kwargs)


def _align_labels(word_ids: list[int | None], word_tag_ids: list[int]) -> list[int]:
    """Project word-level tag IDs onto the subword positions of one sequence.

    Args:
        word_ids: For each subword position, the index of the source word, or
            None for positions that belong to no word (special tokens).
        word_tag_ids: Tag ID of each source word.

    Returns:
        One label per subword position: the word's tag on the first subword
        of each word, `_IGNORE_INDEX` everywhere else.
    """
    label_ids: list[int] = []
    previous_word: int | None = None
    for wid in word_ids:
        if wid is None or wid == previous_word:
            # Special token, or a continuation subword of the current word.
            label_ids.append(_IGNORE_INDEX)
        else:
            label_ids.append(int(word_tag_ids[wid]))
        previous_word = wid
    return label_ids


def _decode_predictions(
    predictions: Any,
    labels: Any,
    id_to_label: dict[int, str],
) -> tuple[list[list[str]], list[list[str]]]:
    """Drop ignored positions and convert the remaining IDs to label strings.

    Args:
        predictions: Per-sequence iterables of predicted label IDs.
        labels: Per-sequence iterables of gold label IDs, with
            `_IGNORE_INDEX` marking positions to skip.
        id_to_label: Mapping from label ID to label string.

    Returns:
        `(predicted, gold)` as parallel lists of per-sequence label strings.
    """
    true_preds: list[list[str]] = []
    true_labels: list[list[str]] = []
    for pred_seq, lab_seq in zip(predictions, labels):
        kept = [
            (int(p), int(l))
            for p, l in zip(pred_seq, lab_seq)
            if l != _IGNORE_INDEX
        ]
        true_preds.append([id_to_label[p] for p, _ in kept])
        true_labels.append([id_to_label[l] for _, l in kept])
    return true_preds, true_labels


def _per_language_metrics(
    languages: Any,
    true_preds: list[list[str]],
    true_labels: list[list[str]],
) -> dict[str, dict[str, float]]:
    """Compute seqeval precision/recall/F1 separately for each language.

    The three arguments are parallel: element i of each describes test row i.
    Rows are grouped in a single pass; the result is keyed by language in
    sorted order, and "n" is the number of rows in that group.
    """
    grouped: dict[str, tuple[list[list[str]], list[list[str]]]] = {}
    for lang, preds, labels in zip(languages, true_preds, true_labels):
        bucket = grouped.setdefault(lang, ([], []))
        bucket[0].append(preds)
        bucket[1].append(labels)
    return {
        lang: {
            "n": len(preds),
            "f1": float(seq_f1(labels, preds)),
            "precision": float(seq_p(labels, preds)),
            "recall": float(seq_r(labels, preds)),
        }
        for lang, (preds, labels) in sorted(grouped.items())
    }


def _to_jsonable(obj: Any) -> Any:
    """Recursively convert numpy scalars/arrays into plain Python values.

    seqeval's classification_report returns numpy scalars (e.g. int64
    'support'), which json.dumps can't serialize.
    """
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    return obj


def main() -> int:
    """Train the MODEL_NAME checkpoint for token-level NER and evaluate it.

    Parses CLI flags, tokenizes the dataset found in DATA_DIR, fine-tunes the
    model, saves the model, tokenizer and labels.json to OUT_DIR, evaluates on
    the test split and writes OUT_DIR/eval_results.json.

    Returns:
        Process exit code; 0 when the run completes. Failures propagate as
        exceptions.
    """
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n")[0])
    parser.add_argument("--epochs", type=int, default=5,
                        help="Number of training epochs (default 5).")
    parser.add_argument("--batch-size", type=int, default=8,
                        help="Per-device train batch size (default 8 — fits comfortably "
                             "with distilbert-multilingual on a 3.6 GB GPU).")
    parser.add_argument("--quick", action="store_true",
                        help="Sanity smoke test: 1 epoch, 500 train rows.")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--optim", type=str, default="adamw_torch",
                        help="Optimizer name. AdamW fits with the smaller distilbert model.")
    args = parser.parse_args()

    set_seed(SEED)
    effective_batch = args.batch_size * _GRAD_ACCUM_STEPS

    print("=" * 72)
    # Report the checkpoint that is actually trained, not the one in the spec.
    print(f"Train NER model ({MODEL_NAME})")
    print("=" * 72)
    print(f" Data dir : {DATA_DIR}")
    print(f" Out dir : {OUT_DIR}")
    print(f" Epochs : {args.epochs}{' (QUICK)' if args.quick else ''}")
    print(f" Batch : {args.batch_size} (effective ≈ {effective_batch} via accum)")
    print(f" Optimizer: {args.optim}")

    # --- Labels --------------------------------------------------------------
    labels_payload = json.loads(LABELS_FILE.read_text(encoding="utf-8"))
    label_to_id: dict[str, int] = labels_payload["label_to_id"]
    id_to_label: dict[int, str] = {
        int(k): v for k, v in labels_payload["id_to_label"].items()
    }
    # Indexing 0..n-1 raises KeyError early if the label IDs are not contiguous.
    label_names = [id_to_label[i] for i in range(len(id_to_label))]
    num_labels = len(label_names)
    print(f" Labels : {label_to_id}")

    # --- Datasets ------------------------------------------------------------
    ds = load_from_disk(str(DATA_DIR))
    print(f" Splits : train={len(ds['train'])} val={len(ds['validation'])} "
          f"test={len(ds['test'])}")
    if args.quick:
        ds["train"] = ds["train"].shuffle(seed=SEED).select(
            range(min(_QUICK_TRAIN_ROWS, len(ds["train"])))
        )
        ds["validation"] = ds["validation"].select(
            range(min(_QUICK_EVAL_ROWS, len(ds["validation"])))
        )
        ds["test"] = ds["test"].select(range(min(_QUICK_EVAL_ROWS, len(ds["test"]))))
        print(f" QUICK : sliced to {len(ds['train'])}/{len(ds['validation'])}/{len(ds['test'])}")

    # --- Tokenize + align labels to subwords --------------------------------
    print("\nLoading tokenizer & model ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tokenize_and_align(batch: dict[str, list]) -> dict[str, Any]:
        """Subword-tokenize a batch of word lists and attach aligned labels."""
        tokenized = tokenizer(
            batch["tokens"],
            is_split_into_words=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        tokenized["labels"] = [
            _align_labels(tokenized.word_ids(batch_index=i), word_tag_ids)
            for i, word_tag_ids in enumerate(batch["ner_tag_ids"])
        ]
        return tokenized

    # Every raw column except "language" is dropped after tokenization.
    drop_cols = [c for c in ds["train"].column_names if c not in ("language",)]
    ds_tok = ds.map(
        tokenize_and_align,
        batched=True,
        remove_columns=drop_cols,
        desc="Tokenizing + aligning",
    )

    # --- Model ---------------------------------------------------------------
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label=id_to_label,
        label2id=label_to_id,
    )

    # Free any lingering CUDA blocks before optimizer states are allocated.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- Training arguments --------------------------------------------------
    n_epochs = 1 if args.quick else args.epochs
    warmup_steps = 10 if args.quick else 100
    # fp16 mixed precision needs a CUDA device; requesting it on a CPU-only
    # host makes TrainingArguments fail, so enable it only when CUDA exists.
    use_fp16 = torch.cuda.is_available()
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    training_args_kwargs: dict[str, Any] = dict(
        output_dir=str(RUNS_DIR),
        num_train_epochs=n_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,  # eval has no grads -> larger ok
        gradient_accumulation_steps=_GRAD_ACCUM_STEPS,
        optim=args.optim,
        learning_rate=args.lr,
        warmup_steps=warmup_steps,
        weight_decay=0.01,
        fp16=use_fp16,
        gradient_checkpointing=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,    # cap retained checkpoints to save disk
        save_only_model=True,  # skip optimizer/scheduler state in checkpoints
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=50,
        report_to="none",
        dataloader_num_workers=0,
        seed=SEED,
    )
    # Adapt to the installed transformers version by inspecting the signature
    # (same approach as _trainer_with_tokenizer) instead of catching TypeError,
    # which could hide an unrelated argument error.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" not in ta_params:
        training_args_kwargs["evaluation_strategy"] = training_args_kwargs.pop("eval_strategy")
    if "save_only_model" not in ta_params:
        # Only a disk-space optimisation; safe to omit where unsupported.
        training_args_kwargs.pop("save_only_model")
    training_args = TrainingArguments(**training_args_kwargs)

    # --- Metrics (seqeval, entity-level) -------------------------------------
    def compute_metrics(eval_pred) -> dict[str, float]:
        """Entity-level precision/recall/F1 for one evaluation pass."""
        logits, labels = eval_pred
        if isinstance(logits, tuple):
            logits = logits[0]
        preds = np.argmax(logits, axis=-1)
        true_preds, true_labels = _decode_predictions(preds, labels, id_to_label)
        return {
            "f1": seq_f1(true_labels, true_preds),
            "precision": seq_p(true_labels, true_preds),
            "recall": seq_r(true_labels, true_preds),
        }

    # --- Trainer -------------------------------------------------------------
    trainer = _trainer_with_tokenizer(
        tokenizer,
        model=model,
        args=training_args,
        train_dataset=ds_tok["train"],
        eval_dataset=ds_tok["validation"],
        data_collator=DataCollatorForTokenClassification(tokenizer),
        compute_metrics=compute_metrics,
    )

    # --- Train ---------------------------------------------------------------
    print("\nStarting training ...")
    train_result = trainer.train()
    print(f" ✓ training done. final loss = {train_result.training_loss:.4f}")

    # --- Save best model -----------------------------------------------------
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(OUT_DIR))
    tokenizer.save_pretrained(str(OUT_DIR))
    shutil.copy(LABELS_FILE, OUT_DIR / "labels.json")

    # --- Final evaluation on test set ---------------------------------------
    print("\nEvaluating on TEST split ...")
    # One predict() pass yields both the raw predictions and the "test_"-prefixed
    # metrics, so the test split is not run through the model twice.
    test_pred = trainer.predict(ds_tok["test"], metric_key_prefix="test")
    test_metrics = dict(test_pred.metrics or {})
    if isinstance(test_pred.predictions, tuple):
        test_logits = test_pred.predictions[0]
    else:
        test_logits = test_pred.predictions
    pred_ids = np.argmax(test_logits, axis=-1)
    true_preds, true_labels = _decode_predictions(
        pred_ids, test_pred.label_ids, id_to_label
    )

    report_dict = seq_classification_report(
        true_labels, true_preds, output_dict=True, zero_division=0,
    )
    report_text = seq_classification_report(true_labels, true_preds, zero_division=0)
    print("\nEntity-level classification report on TEST:")
    print(report_text)

    # --- Per-language breakdown ---------------------------------------------
    # ds["test"] is the same (possibly --quick-sliced) split that was tokenized,
    # so its rows line up with the predictions; no reload from disk is needed.
    per_lang: dict[str, dict[str, float]] = {}
    if "language" in ds["test"].column_names:
        per_lang = _per_language_metrics(ds["test"]["language"], true_preds, true_labels)
        print("\nPer-language entity-level metrics on TEST:")
        for lang, m in per_lang.items():
            print(f" {lang}: n={m['n']} P={m['precision']:.4f} "
                  f"R={m['recall']:.4f} F1={m['f1']:.4f}")

    # --- Save eval_results.json ---------------------------------------------
    payload = {
        "model_name": MODEL_NAME,
        "task": "ner",
        "num_labels": num_labels,
        "labels": label_to_id,
        "test_metrics": {
            k: float(v)
            for k, v in test_metrics.items()
            if isinstance(v, (int, float, np.integer, np.floating))
        },
        "classification_report": _to_jsonable(report_dict),
        "per_language": per_lang,
        "training": {
            "epochs": n_epochs,
            "per_device_batch": args.batch_size,
            "grad_accum": _GRAD_ACCUM_STEPS,
            "effective_batch": effective_batch,
            "learning_rate": args.lr,
            "warmup_steps": warmup_steps,
            "fp16": use_fp16,
            "final_train_loss": float(train_result.training_loss),
        },
    }
    # ensure_ascii=False can emit non-ASCII characters, so pin the encoding
    # rather than relying on the platform default.
    (OUT_DIR / "eval_results.json").write_text(
        json.dumps(payload, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )

    print(f"\n✓ Saved model to {OUT_DIR}")
    print(f"✓ Saved eval_results.json to {OUT_DIR / 'eval_results.json'}")
    return 0


if __name__ == "__main__":
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        print("\nAborted by user.")
        sys.exit(130)