multilingual-chatbot / src /train_ner.py
momenalhamza's picture
Deploy chatbot: code + RAG + Qwen (3 BERT classifiers loaded from HF Hub)
469ef7f verified
"""Fine-tune xlm-roberta-base for token-level NER (BIO tagging).
Tag set (9):
O, B-PER, I-PER, B-LOC, I-LOC, B-ORG, I-ORG, B-DATE, I-DATE
Inputs:
data/processed/ner/ (HuggingFace DatasetDict; word-tokenized)
data/processed/ner/labels.json
Outputs:
models/ner_model/ (best model + tokenizer)
models/ner_model/eval_results.json (per-entity F1 from seqeval)
models/ner_model/runs/ (training checkpoints + logs)
Implementation notes:
- The raw tokens come pre-tokenized at WORD level (whitespace-split).
The model tokenizer works on subwords (WordPiece for the default
distilbert-base-multilingual-cased; SentencePiece for XLM-R), so we
re-tokenize with `is_split_into_words=True` and align labels to subwords:
* first subword of each word -> word's tag
* inner subwords -> -100 (ignored by the loss)
* special tokens -> -100
This matches the standard HuggingFace NER recipe.
- Metrics use seqeval (entity-level): a span counts as correct only if
BOTH boundary AND type match — much stricter than token-level accuracy.
GPU notes for GTX 1650 (3.6 GB VRAM): same recipe — fp16 +
gradient_checkpointing + batch=8 with grad_accum=2.
Usage:
python src/train_ner.py
python src/train_ner.py --epochs 3
python src/train_ner.py --quick
"""
from __future__ import annotations
import os
# Reduce CUDA memory fragmentation on tight-VRAM GPUs (must precede torch import).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import argparse
import inspect
import json
import shutil
import sys
from pathlib import Path
from typing import Any
import numpy as np
import torch
from datasets import load_from_disk
from seqeval.metrics import (
classification_report as seq_classification_report,
f1_score as seq_f1, precision_score as seq_p, recall_score as seq_r,
)
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
Trainer,
TrainingArguments,
set_seed,
)
# Filesystem layout, resolved relative to the repository root (the parent of
# the directory that contains this script).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJECT_ROOT / "data" / "processed" / "ner"  # dataset read with load_from_disk
LABELS_FILE = DATA_DIR / "labels.json"  # holds the label_to_id / id_to_label maps
OUT_DIR = PROJECT_ROOT / "models" / "ner_model"  # final model, tokenizer, labels, eval_results.json
RUNS_DIR = OUT_DIR / "runs"  # passed to TrainingArguments as output_dir
# Spec called for xlm-roberta-base, but at 270M params it does not fit in
# the GTX 1650's 3.6 GB VRAM during optimizer step — Adam needs ~2.2 GB of
# state and even Adafactor allocates a ~770 MB temp tensor for grad**2 over
# the embedding matrix. distilbert-base-multilingual-cased is the standard
# substitute: 134M params, trained on 104 languages (AR/EN/FR included),
# typically within 1-3% F1 of XLM-R on classification.
MODEL_NAME = "distilbert-base-multilingual-cased"
MAX_LENGTH = 128  # truncation limit handed to the tokenizer as max_length
SEED = 42  # shared by set_seed, the --quick shuffle and TrainingArguments
def _trainer_with_tokenizer(tokenizer, **kwargs: Any) -> Trainer:
    """Build a ``Trainer``, handing over the tokenizer under the keyword it accepts.

    The signature of ``Trainer.__init__`` is inspected for ``processing_class``
    first and ``tokenizer`` second; the first name found receives
    ``tokenizer``. If neither name is present the tokenizer is not forwarded.

    Args:
        tokenizer: Object to pass as the tokenizer / processing class.
        **kwargs: Remaining keyword arguments forwarded to ``Trainer``.

    Returns:
        The constructed ``Trainer`` instance.
    """
    accepted = inspect.signature(Trainer.__init__).parameters
    # Order matters: ``processing_class`` wins when both names are accepted.
    for candidate in ("processing_class", "tokenizer"):
        if candidate in accepted:
            kwargs[candidate] = tokenizer
            break
    return Trainer(**kwargs)
def main() -> int:
    """Fine-tune ``MODEL_NAME`` for token-level NER and evaluate it.

    Pipeline:
      1. Parse CLI flags and load the label maps from ``LABELS_FILE``.
      2. Load the dataset from ``DATA_DIR`` (optionally sliced by ``--quick``).
      3. Tokenize with ``is_split_into_words=True`` and align word-level tags
         to subwords (first subword keeps the tag, everything else is -100).
      4. Train with the HuggingFace ``Trainer``; the best checkpoint is picked
         by entity-level F1 on the validation split.
      5. Save model, tokenizer and ``labels.json`` to ``OUT_DIR``.
      6. Evaluate on the test split and write overall, per-entity and
         per-language metrics to ``OUT_DIR / "eval_results.json"``.

    Returns:
        Process exit code. Always 0 on completion; failures propagate as
        exceptions.
    """
    parser = argparse.ArgumentParser(description=__doc__.split("\n")[0])
    parser.add_argument("--epochs", type=int, default=5,
                        help="Number of training epochs (default 5).")
    parser.add_argument("--batch-size", type=int, default=8,
                        help="Per-device train batch size (default 8 — fits comfortably "
                             "with distilbert-multilingual on a 3.6 GB GPU).")
    parser.add_argument("--quick", action="store_true",
                        help="Sanity smoke test: 1 epoch, 500 train rows.")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--optim", type=str, default="adamw_torch",
                        help="Optimizer name. AdamW fits with the smaller distilbert model.")
    args = parser.parse_args()
    set_seed(SEED)

    # Single source of truth for values that are both used and reported.
    grad_accum = 2
    # fp16 mixed precision is a GPU feature; only request it when CUDA is
    # present so the script (and --quick smoke tests) also run on CPU hosts.
    use_fp16 = torch.cuda.is_available()

    # --- Labels --------------------------------------------------------------
    # Loaded before the banner so the banner can report the real tag count.
    labels_payload = json.loads(LABELS_FILE.read_text(encoding="utf-8"))
    label_to_id: dict[str, int] = labels_payload["label_to_id"]
    # JSON object keys are always strings; the model config wants int ids.
    id_to_label: dict[int, str] = {
        int(k): v for k, v in labels_payload["id_to_label"].items()
    }
    label_names = [id_to_label[i] for i in range(len(id_to_label))]
    num_labels = len(label_names)

    print("=" * 72)
    # Report the model that is actually trained rather than a hard-coded name.
    print(f"Train NER model ({MODEL_NAME}, {num_labels} BIO tags)")
    print("=" * 72)
    print(f" Data dir : {DATA_DIR}")
    print(f" Out dir : {OUT_DIR}")
    print(f" Epochs : {args.epochs}{' (QUICK)' if args.quick else ''}")
    print(f" Batch : {args.batch_size} "
          f"(effective ≈ {args.batch_size * grad_accum} via accum)")
    print(f" Optimizer: {args.optim}")
    print(f" Precision: {'fp16' if use_fp16 else 'fp32 (no CUDA device)'}")
    print(f" Labels : {label_to_id}")

    # --- Datasets ------------------------------------------------------------
    ds = load_from_disk(str(DATA_DIR))
    print(f" Splits : train={len(ds['train'])} val={len(ds['validation'])} "
          f"test={len(ds['test'])}")
    if args.quick:
        ds["train"] = ds["train"].shuffle(seed=SEED).select(
            range(min(500, len(ds["train"])))
        )
        ds["validation"] = ds["validation"].select(
            range(min(120, len(ds["validation"])))
        )
        ds["test"] = ds["test"].select(range(min(120, len(ds["test"]))))
        print(f" QUICK : sliced to {len(ds['train'])}/{len(ds['validation'])}/"
              f"{len(ds['test'])}")

    # --- Tokenize + align labels to subwords --------------------------------
    print("\nLoading tokenizer & model ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tokenize_and_align(batch: dict[str, list]) -> dict[str, Any]:
        """Tokenize a batch of word lists and build subword-level labels."""
        tokenized = tokenizer(
            batch["tokens"],
            is_split_into_words=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        all_labels = []
        for i, word_tag_ids in enumerate(batch["ner_tag_ids"]):
            word_ids = tokenized.word_ids(batch_index=i)
            previous_word: int | None = None
            label_ids: list[int] = []
            for wid in word_ids:
                if wid is None:
                    # Special tokens (CLS / SEP / PAD)
                    label_ids.append(-100)
                elif wid != previous_word:
                    # First subword of a word -> use the word's tag
                    label_ids.append(int(word_tag_ids[wid]))
                else:
                    # Inner subword -> ignore in loss
                    label_ids.append(-100)
                previous_word = wid
            all_labels.append(label_ids)
        tokenized["labels"] = all_labels
        return tokenized

    # Every raw column except "language" is dropped after tokenization.
    drop_cols = [c for c in ds["train"].column_names if c not in ("language",)]
    ds_tok = ds.map(
        tokenize_and_align, batched=True,
        remove_columns=drop_cols, desc="Tokenizing + aligning",
    )

    # --- Model ---------------------------------------------------------------
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label=id_to_label,
        label2id=label_to_id,
    )
    # Free any lingering CUDA blocks before optimizer states are allocated.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- Training arguments --------------------------------------------------
    n_epochs = 1 if args.quick else args.epochs
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    training_args_kwargs = dict(
        output_dir=str(RUNS_DIR),
        num_train_epochs=n_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,  # eval has no grads -> larger ok
        gradient_accumulation_steps=grad_accum,  # effective batch = batch * grad_accum
        optim=args.optim,
        learning_rate=args.lr,
        warmup_steps=100 if not args.quick else 10,
        weight_decay=0.01,
        fp16=use_fp16,
        gradient_checkpointing=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,  # keep only the current-best checkpoint
        save_only_model=True,  # skip optimizer/scheduler state in checkpoints
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=50,
        report_to="none",
        dataloader_num_workers=0,
        seed=SEED,
    )
    try:
        training_args = TrainingArguments(**training_args_kwargs)
    except TypeError:
        # Older transformers releases spell this keyword `evaluation_strategy`.
        training_args_kwargs["evaluation_strategy"] = training_args_kwargs.pop(
            "eval_strategy"
        )
        training_args = TrainingArguments(**training_args_kwargs)

    # --- Metrics (seqeval, entity-level) -------------------------------------
    def _decode(
        predictions: np.ndarray, labels: np.ndarray
    ) -> tuple[list[list[str]], list[list[str]]]:
        """Drop -100 positions; convert remaining IDs to label strings."""
        true_preds: list[list[str]] = []
        true_labels: list[list[str]] = []
        for pred_seq, lab_seq in zip(predictions, labels):
            tp, tl = [], []
            for p, l in zip(pred_seq, lab_seq):
                if l == -100:
                    continue
                tp.append(id_to_label[int(p)])
                tl.append(id_to_label[int(l)])
            true_preds.append(tp)
            true_labels.append(tl)
        return true_preds, true_labels

    def compute_metrics(eval_pred) -> dict[str, float]:
        """Entity-level precision / recall / F1 for the Trainer eval loop."""
        logits, labels = eval_pred
        if isinstance(logits, tuple):
            logits = logits[0]
        preds = np.argmax(logits, axis=-1)
        true_preds, true_labels = _decode(preds, labels)
        return {
            "f1": seq_f1(true_labels, true_preds),
            "precision": seq_p(true_labels, true_preds),
            "recall": seq_r(true_labels, true_preds),
        }

    # --- Trainer -------------------------------------------------------------
    trainer = _trainer_with_tokenizer(
        tokenizer,
        model=model,
        args=training_args,
        train_dataset=ds_tok["train"],
        eval_dataset=ds_tok["validation"],
        data_collator=DataCollatorForTokenClassification(tokenizer),
        compute_metrics=compute_metrics,
    )

    # --- Train ---------------------------------------------------------------
    print("\nStarting training ...")
    train_result = trainer.train()
    print(f" ✓ training done. final loss = {train_result.training_loss:.4f}")

    # --- Save best model -----------------------------------------------------
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(OUT_DIR))
    tokenizer.save_pretrained(str(OUT_DIR))
    shutil.copy(LABELS_FILE, OUT_DIR / "labels.json")

    # --- Final evaluation on test set ---------------------------------------
    print("\nEvaluating on TEST split ...")
    test_metrics = trainer.evaluate(ds_tok["test"], metric_key_prefix="test")
    test_pred = trainer.predict(ds_tok["test"])
    if isinstance(test_pred.predictions, tuple):
        test_logits = test_pred.predictions[0]
    else:
        test_logits = test_pred.predictions
    pred_ids = np.argmax(test_logits, axis=-1)
    true_preds, true_labels = _decode(pred_ids, test_pred.label_ids)
    report_dict = seq_classification_report(
        true_labels, true_preds, output_dict=True, zero_division=0,
    )
    report_text = seq_classification_report(true_labels, true_preds, zero_division=0)
    print("\nEntity-level classification report on TEST:")
    print(report_text)

    # --- Per-language breakdown ---------------------------------------------
    # `ds["test"]` is the exact split (already sliced in --quick mode) that
    # `ds_tok["test"]` was mapped from, so its rows line up with the
    # predictions; no need to reload from disk and repeat the slicing.
    test_with_lang = ds["test"]
    per_lang: dict[str, dict[str, float]] = {}
    if "language" in test_with_lang.column_names:
        languages = test_with_lang["language"]
        for lang in sorted(set(languages)):
            mask = [la == lang for la in languages]
            sub_preds = [tp for tp, m in zip(true_preds, mask) if m]
            sub_labels = [tl for tl, m in zip(true_labels, mask) if m]
            if not sub_preds:
                continue
            per_lang[lang] = {
                "n": int(sum(mask)),
                "f1": float(seq_f1(sub_labels, sub_preds)),
                "precision": float(seq_p(sub_labels, sub_preds)),
                "recall": float(seq_r(sub_labels, sub_preds)),
            }
        print("\nPer-language entity-level metrics on TEST:")
        for lang, m in per_lang.items():
            print(f" {lang}: n={m['n']} P={m['precision']:.4f} "
                  f"R={m['recall']:.4f} F1={m['f1']:.4f}")

    # --- Save eval_results.json ---------------------------------------------
    # seqeval's classification_report returns numpy scalars (e.g. int64 'support'),
    # which json.dumps can't serialize. Convert recursively.
    def _to_jsonable(obj: Any) -> Any:
        """Recursively replace numpy containers/scalars with builtin types."""
        if isinstance(obj, dict):
            return {k: _to_jsonable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [_to_jsonable(v) for v in obj]
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return obj

    payload = {
        "model_name": MODEL_NAME,
        "task": "ner",
        "num_labels": num_labels,
        "labels": label_to_id,
        "test_metrics": {
            k: float(v) for k, v in test_metrics.items()
            if isinstance(v, (int, float, np.integer, np.floating))
        },
        "classification_report": _to_jsonable(report_dict),
        "per_language": per_lang,
        "training": {
            "epochs": n_epochs,
            "per_device_batch": args.batch_size,
            "grad_accum": grad_accum,
            "effective_batch": args.batch_size * grad_accum,
            "learning_rate": args.lr,
            "warmup_steps": training_args_kwargs.get("warmup_steps"),
            "fp16": use_fp16,  # what was actually used, not what was hoped for
            "final_train_loss": float(train_result.training_loss),
        },
    }
    # ensure_ascii=False can emit non-ASCII characters, so pin the encoding
    # instead of relying on the platform default.
    (OUT_DIR / "eval_results.json").write_text(
        json.dumps(payload, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"\n✓ Saved model to {OUT_DIR}")
    print(f"✓ Saved eval_results.json to {OUT_DIR / 'eval_results.json'}")
    return 0
if __name__ == "__main__":
    # Script entry point: run training and hand its return code to the shell.
    try:
        exit_code = main()
    except KeyboardInterrupt:
        # Ctrl-C: report it and use the conventional 128 + SIGINT status.
        print("\nAborted by user.")
        exit_code = 130
    sys.exit(exit_code)