Spaces:
Sleeping
Sleeping
| """Fine-tune distilbert-base-multilingual-cased for token-level NER (BIO tagging). | |
| Tag set (9): | |
| O, B-PER, I-PER, B-LOC, I-LOC, B-ORG, I-ORG, B-DATE, I-DATE | |
| Inputs: | |
| data/processed/ner/ (HuggingFace DatasetDict; word-tokenized) | |
| data/processed/ner/labels.json | |
| Outputs: | |
| models/ner_model/ (best model + tokenizer) | |
| models/ner_model/eval_results.json (per-entity F1 from seqeval) | |
| models/ner_model/runs/ (training checkpoints + logs) | |
| Implementation notes: | |
| - The raw tokens come pre-tokenized at WORD level (whitespace-split). | |
| The model's tokenizer splits words further into subwords (WordPiece for | |
| distilbert-base-multilingual-cased), so we re-tokenize with | |
| `is_split_into_words=True` and align labels to subwords: | |
| * first subword of each word -> word's tag | |
| * inner subwords -> -100 (ignored by the loss) | |
| * special tokens -> -100 | |
| This matches the standard HuggingFace NER recipe. | |
| - Metrics use seqeval (entity-level): a span counts as correct only if | |
| BOTH boundary AND type match — much stricter than token-level accuracy. | |
| GPU notes for GTX 1650 (3.6 GB VRAM): fp16 + gradient_checkpointing + | |
| per-device batch 8 with grad_accum=2 (effective batch 16). | |
| Usage: | |
| python src/train_ner.py | |
| python src/train_ner.py --epochs 3 | |
| python src/train_ner.py --quick | |
| """ | |
| from __future__ import annotations | |
| import os | |
| # Reduce CUDA memory fragmentation on tight-VRAM GPUs (must precede torch import). | |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") | |
| import argparse | |
| import inspect | |
| import json | |
| import shutil | |
| import sys | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| from datasets import load_from_disk | |
| from seqeval.metrics import ( | |
| classification_report as seq_classification_report, | |
| f1_score as seq_f1, precision_score as seq_p, recall_score as seq_r, | |
| ) | |
| from transformers import ( | |
| AutoModelForTokenClassification, | |
| AutoTokenizer, | |
| DataCollatorForTokenClassification, | |
| Trainer, | |
| TrainingArguments, | |
| set_seed, | |
| ) | |
# Filesystem layout, anchored at the repository root: two levels above this
# file (the script is run as `python src/train_ner.py`).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJECT_ROOT.joinpath("data", "processed", "ner")
LABELS_FILE = DATA_DIR.joinpath("labels.json")
OUT_DIR = PROJECT_ROOT.joinpath("models", "ner_model")
RUNS_DIR = OUT_DIR.joinpath("runs")

# Model choice. The spec asked for xlm-roberta-base (~270M params), which does
# not fit a GTX 1650 (3.6 GB VRAM) at the optimizer step: Adam's moment buffers
# alone need ~2.2 GB, and even Adafactor materialises a ~770 MB grad**2
# temporary over the embedding matrix.
# distilbert-base-multilingual-cased (134M params, 104 languages including
# AR/EN/FR) is the usual substitute. It is reportedly within 1-3% F1 of XLM-R
# on classification; that gap has not been verified for NER here.
MODEL_NAME = "distilbert-base-multilingual-cased"
MAX_LENGTH = 128  # maximum subword tokens per example handed to the tokenizer
SEED = 42
def _trainer_with_tokenizer(tokenizer, **kwargs: Any) -> Trainer:
    """Build a ``Trainer``, passing *tokenizer* under whichever keyword exists.

    The ``Trainer.__init__`` signature is inspected, with ``processing_class``
    checked before ``tokenizer``. If neither keyword is present, the tokenizer
    is not forwarded at all. All other keyword arguments pass through unchanged.
    """
    accepted = inspect.signature(Trainer.__init__).parameters
    for kwarg_name in ("processing_class", "tokenizer"):
        if kwarg_name in accepted:
            kwargs[kwarg_name] = tokenizer
            break
    return Trainer(**kwargs)
# Default gradient-accumulation steps (effective batch = per-device batch x this).
DEFAULT_GRAD_ACCUM = 2
# Row caps applied by --quick: train is shuffled before slicing, while the
# validation/test splits keep their leading rows.
QUICK_TRAIN_ROWS = 500
QUICK_EVAL_ROWS = 120
# Label id ignored by the loss and by the metrics (special tokens, inner subwords).
IGNORE_INDEX = -100


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options."""
    # `__doc__` is None under `python -OO`, so guard before splitting it.
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n")[0])
    parser.add_argument("--epochs", type=int, default=5,
                        help="Number of training epochs (default 5).")
    parser.add_argument("--batch-size", type=int, default=8,
                        help="Per-device train batch size (default 8 — fits comfortably "
                             "with distilbert-multilingual on a 3.6 GB GPU).")
    parser.add_argument("--grad-accum", type=int, default=DEFAULT_GRAD_ACCUM,
                        help="Gradient-accumulation steps "
                             f"(default {DEFAULT_GRAD_ACCUM}).")
    parser.add_argument("--quick", action="store_true",
                        help="Sanity smoke test: 1 epoch, 500 train rows.")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--optim", type=str, default="adamw_torch",
                        help="Optimizer name. AdamW fits with the smaller distilbert model.")
    args = parser.parse_args()
    if args.grad_accum < 1:
        parser.error("--grad-accum must be >= 1")
    return args


def _load_labels() -> tuple[dict[str, int], dict[int, str]]:
    """Read the label maps from LABELS_FILE.

    Returns:
        ``(label_to_id, id_to_label)``. The ``id_to_label`` keys are converted
        back to ``int`` because JSON object keys are always strings.
    """
    payload = json.loads(LABELS_FILE.read_text(encoding="utf-8"))
    label_to_id: dict[str, int] = payload["label_to_id"]
    id_to_label = {int(k): v for k, v in payload["id_to_label"].items()}
    return label_to_id, id_to_label


def _align_word_labels(word_ids: list[int | None], word_tag_ids: list[int]) -> list[int]:
    """Project word-level tag ids onto subword positions.

    The first subword of each word receives that word's tag. Inner subwords
    and special tokens (``word_id is None``) receive IGNORE_INDEX.
    """
    label_ids: list[int] = []
    previous_word: int | None = None
    for wid in word_ids:
        if wid is None:
            label_ids.append(IGNORE_INDEX)
        elif wid != previous_word:
            label_ids.append(int(word_tag_ids[wid]))
        else:
            label_ids.append(IGNORE_INDEX)
        previous_word = wid
    return label_ids


def _argmax_predictions(predictions: Any) -> np.ndarray:
    """Return per-token argmax ids. If *predictions* is a tuple, its first element is used as the logits."""
    logits = predictions[0] if isinstance(predictions, tuple) else predictions
    return np.argmax(logits, axis=-1)


def _decode_predictions(
    predictions: np.ndarray, labels: np.ndarray, id_to_label: dict[int, str],
) -> tuple[list[list[str]], list[list[str]]]:
    """Drop IGNORE_INDEX positions and map the remaining ids to tag strings.

    Returns:
        ``(pred_tags, gold_tags)``: one list of tag strings per sequence, in
        the nested-list form seqeval consumes.
    """
    pred_tags: list[list[str]] = []
    gold_tags: list[list[str]] = []
    for pred_seq, gold_seq in zip(predictions, labels):
        kept = [(p, g) for p, g in zip(pred_seq, gold_seq) if g != IGNORE_INDEX]
        pred_tags.append([id_to_label[int(p)] for p, _ in kept])
        gold_tags.append([id_to_label[int(g)] for _, g in kept])
    return pred_tags, gold_tags


def _seqeval_scores(gold_tags: list[list[str]], pred_tags: list[list[str]]) -> dict[str, float]:
    """Entity-level precision / recall / F1 computed with seqeval's defaults."""
    return {
        "f1": float(seq_f1(gold_tags, pred_tags)),
        "precision": float(seq_p(gold_tags, pred_tags)),
        "recall": float(seq_r(gold_tags, pred_tags)),
    }


def _per_language_metrics(
    languages: list[Any], pred_tags: list[list[str]], gold_tags: list[list[str]],
) -> dict[str, dict[str, float]]:
    """Compute entity-level scores separately for each language value.

    *languages* must be aligned row-for-row with *pred_tags* / *gold_tags*.
    Rows are grouped in a single pass, and the result is keyed in sorted
    language order.
    """
    grouped: dict[Any, tuple[list[list[str]], list[list[str]]]] = {}
    for lang, pred, gold in zip(languages, pred_tags, gold_tags):
        preds, golds = grouped.setdefault(lang, ([], []))
        preds.append(pred)
        golds.append(gold)
    return {
        lang: {"n": len(grouped[lang][0]), **_seqeval_scores(grouped[lang][1], grouped[lang][0])}
        for lang in sorted(grouped)
    }


def _to_jsonable(obj: Any) -> Any:
    """Recursively convert numpy scalars/arrays (e.g. seqeval's int64 support) to JSON types."""
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    return obj


def _build_training_args(
    args: argparse.Namespace, *, n_epochs: int, warmup_steps: int, use_fp16: bool,
) -> TrainingArguments:
    """Create the TrainingArguments for this run.

    Older transformers releases only accept ``evaluation_strategy``. If
    ``eval_strategy`` is rejected with a TypeError, the call is retried under
    the old name.
    """
    kwargs: dict[str, Any] = dict(
        output_dir=str(RUNS_DIR),
        num_train_epochs=n_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,  # eval has no grads -> larger ok
        gradient_accumulation_steps=args.grad_accum,
        optim=args.optim,
        learning_rate=args.lr,
        warmup_steps=warmup_steps,
        weight_decay=0.01,
        fp16=use_fp16,
        gradient_checkpointing=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        # With load_best_model_at_end, transformers also retains the best
        # checkpoint, so up to two (best + most recent) may exist on disk.
        save_total_limit=1,
        save_only_model=True,  # skip optimizer/scheduler state to save disk space
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=50,
        report_to="none",
        dataloader_num_workers=0,
        seed=SEED,
    )
    try:
        return TrainingArguments(**kwargs)
    except TypeError:
        kwargs["evaluation_strategy"] = kwargs.pop("eval_strategy")
        return TrainingArguments(**kwargs)


def main() -> int:
    """Fine-tune MODEL_NAME for token-level NER and write artifacts to OUT_DIR.

    The run proceeds in these steps:

    1. Parse the CLI options.
    2. Load the label maps and the DatasetDict from DATA_DIR.
    3. Tokenize the words and align their tags to subwords.
    4. Train with per-epoch validation; the best checkpoint is chosen by
       seqeval entity F1.
    5. Save the model, tokenizer and labels.json to OUT_DIR.
    6. Score the test split (overall, per entity type and per language) and
       write eval_results.json.

    Returns:
        Process exit code (0 on success).
    """
    args = _parse_args()
    set_seed(SEED)

    # --- Labels --------------------------------------------------------------
    label_to_id, id_to_label = _load_labels()
    # Indexing 0..N-1 raises KeyError if the ids are not contiguous.
    label_names = [id_to_label[i] for i in range(len(id_to_label))]
    num_labels = len(label_names)

    n_epochs = 1 if args.quick else args.epochs
    warmup_steps = 10 if args.quick else 100
    effective_batch = args.batch_size * args.grad_accum
    # TrainingArguments rejects fp16=True without a supported accelerator, so
    # mixed precision is only requested when CUDA is available.
    use_fp16 = torch.cuda.is_available()

    print("=" * 72)
    print(f"Train NER model ({MODEL_NAME}, {num_labels} BIO tags)")
    print("=" * 72)
    print(f"  Data dir : {DATA_DIR}")
    print(f"  Out dir  : {OUT_DIR}")
    print(f"  Epochs   : {n_epochs}{' (QUICK)' if args.quick else ''}")
    print(f"  Batch    : {args.batch_size}  (effective ≈ {effective_batch} via accum)")
    print(f"  Optimizer: {args.optim}")
    print(f"  FP16     : {use_fp16}")
    print(f"  Labels   : {label_to_id}")

    # --- Datasets ------------------------------------------------------------
    ds = load_from_disk(str(DATA_DIR))
    print(f"  Splits   : train={len(ds['train'])}  val={len(ds['validation'])}  "
          f"test={len(ds['test'])}")
    if args.quick:
        n_train = min(QUICK_TRAIN_ROWS, len(ds["train"]))
        ds["train"] = ds["train"].shuffle(seed=SEED).select(range(n_train))
        for split in ("validation", "test"):
            ds[split] = ds[split].select(range(min(QUICK_EVAL_ROWS, len(ds[split]))))
        print(f"  QUICK    : sliced to {len(ds['train'])}/{len(ds['validation'])}/{len(ds['test'])}")

    # --- Tokenize + align labels to subwords --------------------------------
    print("\nLoading tokenizer & model ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tokenize_and_align(batch: dict[str, list]) -> dict[str, Any]:
        """Tokenize pre-split words and attach subword-aligned ``labels``."""
        tokenized = tokenizer(
            batch["tokens"],
            is_split_into_words=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        tokenized["labels"] = [
            _align_word_labels(tokenized.word_ids(batch_index=i), word_tag_ids)
            for i, word_tag_ids in enumerate(batch["ner_tag_ids"])
        ]
        return tokenized

    # All raw columns (including `language`) are dropped from the tokenized
    # copy. The per-language breakdown below reads `language` from `ds["test"]`,
    # whose rows line up one-to-one with `ds_tok["test"]`.
    ds_tok = ds.map(
        tokenize_and_align,
        batched=True,
        remove_columns=ds["train"].column_names,
        desc="Tokenizing + aligning",
    )

    # --- Model ---------------------------------------------------------------
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label=id_to_label,
        label2id=label_to_id,
    )
    # Free any lingering CUDA blocks before optimizer states are allocated.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- Training arguments + Trainer ----------------------------------------
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    training_args = _build_training_args(
        args, n_epochs=n_epochs, warmup_steps=warmup_steps, use_fp16=use_fp16,
    )

    def compute_metrics(eval_pred) -> dict[str, float]:
        """Entity-level seqeval scores for the Trainer's evaluation loop."""
        logits, labels = eval_pred
        pred_tags, gold_tags = _decode_predictions(
            _argmax_predictions(logits), labels, id_to_label,
        )
        return _seqeval_scores(gold_tags, pred_tags)

    trainer = _trainer_with_tokenizer(
        tokenizer,
        model=model,
        args=training_args,
        train_dataset=ds_tok["train"],
        eval_dataset=ds_tok["validation"],
        data_collator=DataCollatorForTokenClassification(tokenizer),
        compute_metrics=compute_metrics,
    )

    # --- Train ---------------------------------------------------------------
    print("\nStarting training ...")
    train_result = trainer.train()
    print(f"  ✓ training done. final loss = {train_result.training_loss:.4f}")

    # --- Save best model -----------------------------------------------------
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(OUT_DIR))
    tokenizer.save_pretrained(str(OUT_DIR))
    shutil.copy(LABELS_FILE, OUT_DIR / "labels.json")

    # --- Final evaluation on test set ---------------------------------------
    print("\nEvaluating on TEST split ...")
    # A single predict() pass yields both the logits and the `test_*` metrics,
    # so inference over the test split runs only once.
    test_pred = trainer.predict(ds_tok["test"], metric_key_prefix="test")
    test_metrics = dict(test_pred.metrics)
    if trainer.state.epoch is not None:
        test_metrics["epoch"] = trainer.state.epoch
    pred_tags, gold_tags = _decode_predictions(
        _argmax_predictions(test_pred.predictions), test_pred.label_ids, id_to_label,
    )

    report_dict = seq_classification_report(
        gold_tags, pred_tags, output_dict=True, zero_division=0,
    )
    report_text = seq_classification_report(gold_tags, pred_tags, zero_division=0)
    print("\nEntity-level classification report on TEST:")
    print(report_text)

    # --- Per-language breakdown ---------------------------------------------
    per_lang: dict[str, dict[str, float]] = {}
    if "language" in ds["test"].column_names:
        per_lang = _per_language_metrics(list(ds["test"]["language"]), pred_tags, gold_tags)
        print("\nPer-language entity-level metrics on TEST:")
        for lang, m in per_lang.items():
            print(f"  {lang}: n={m['n']}  P={m['precision']:.4f}  "
                  f"R={m['recall']:.4f}  F1={m['f1']:.4f}")

    # --- Save eval_results.json ---------------------------------------------
    payload = {
        "model_name": MODEL_NAME,
        "task": "ner",
        "num_labels": num_labels,
        "labels": label_to_id,
        "test_metrics": {k: float(v) for k, v in test_metrics.items()
                         if isinstance(v, (int, float, np.integer, np.floating))},
        "classification_report": _to_jsonable(report_dict),
        "per_language": per_lang,
        "training": {
            "epochs": n_epochs,
            "per_device_batch": args.batch_size,
            "grad_accum": args.grad_accum,
            "effective_batch": effective_batch,
            "learning_rate": args.lr,
            "warmup_steps": warmup_steps,
            "fp16": use_fp16,
            "final_train_loss": float(train_result.training_loss),
        },
    }
    # Explicit UTF-8: ensure_ascii=False may emit non-ASCII characters that the
    # platform default encoding cannot represent.
    (OUT_DIR / "eval_results.json").write_text(
        json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8",
    )
    print(f"\n✓ Saved model to {OUT_DIR}")
    print(f"✓ Saved eval_results.json to {OUT_DIR / 'eval_results.json'}")
    return 0
if __name__ == "__main__":
    # 130 = 128 + SIGINT, the conventional shell status after Ctrl-C.
    status = 130
    try:
        status = main()
    except KeyboardInterrupt:
        print("\nAborted by user.")
    sys.exit(status)