multilingual-chatbot / src /train_intent.py
momenalhamza's picture
Deploy chatbot: code + RAG + Qwen (3 BERT classifiers loaded from HF Hub)
469ef7f verified
"""Fine-tune xlm-roberta-base for 6-class intent classification.
Intents: booking, complaint, farewell, greeting, inquiry, other.
Languages covered: AR, EN, FR (data was stratified across all of them).
Inputs:
data/processed/intent/ (HuggingFace DatasetDict)
data/processed/intent/labels.json
Outputs:
models/intent_classifier/ (best model + tokenizer)
models/intent_classifier/eval_results.json (test metrics + classification report)
models/intent_classifier/runs/ (training checkpoints + logs)
GPU notes for GTX 1650 (3.6 GB VRAM): same as train_lang_detector.py — we use
fp16 + gradient_checkpointing + batch=8 with grad_accum=2.
Usage:
python src/train_intent.py
python src/train_intent.py --epochs 3
python src/train_intent.py --quick
"""
from __future__ import annotations
import os
# Reduce CUDA memory fragmentation on tight-VRAM GPUs (must precede torch import).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import argparse
import inspect
import json
import shutil
import sys
from pathlib import Path
from typing import Any
import numpy as np
import torch
from datasets import load_from_disk
from sklearn.metrics import (
accuracy_score, classification_report, f1_score,
precision_score, recall_score,
)
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
set_seed,
)
# Repository root: this file lives in <root>/src/, so go up two levels.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Inputs produced by the preprocessing step.
DATA_DIR = PROJECT_ROOT.joinpath("data", "processed", "intent")
LABELS_FILE = DATA_DIR.joinpath("labels.json")

# Final artefacts (best model, tokenizer, eval report) and Trainer checkpoints.
OUT_DIR = PROJECT_ROOT.joinpath("models", "intent_classifier")
RUNS_DIR = OUT_DIR.joinpath("runs")

# Why not xlm-roberta-base (the model the spec asked for)? At ~270M params
# it overflows the GTX 1650's 3.6 GB of VRAM at the optimizer step: Adam's
# state alone is ~2.2 GB, and even Adafactor materialises a ~770 MB
# temporary for grad**2 over the embedding matrix. The multilingual
# DistilBERT below is the usual stand-in — 134M params, pretrained on 104
# languages including AR/EN/FR, and typically 1-3% F1 behind XLM-R on
# classification tasks.
MODEL_NAME = "distilbert-base-multilingual-cased"

MAX_LENGTH = 128  # tokenizer truncation length, in tokens
SEED = 42         # shared by set_seed, dataset shuffling and TrainingArguments
def _trainer_with_tokenizer(tokenizer, **kwargs: Any) -> Trainer:
    """Build a ``Trainer``, handing *tokenizer* over under the right keyword.

    Newer transformers releases accept the tokenizer as ``processing_class``,
    while older ones call it ``tokenizer``. The installed ``Trainer.__init__``
    signature is inspected, and the first of those two names it accepts is
    used, with ``processing_class`` preferred. If it accepts neither, the
    tokenizer is simply not passed. All other keyword arguments are forwarded
    unchanged.
    """
    accepted = inspect.signature(Trainer.__init__).parameters
    for keyword in ("processing_class", "tokenizer"):
        if keyword in accepted:
            kwargs[keyword] = tokenizer
            break
    return Trainer(**kwargs)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this training script."""
    # The module docstring's first line doubles as the CLI description;
    # `__doc__` is None under `python -OO`, so fall back to an empty string.
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n")[0])
    parser.add_argument("--epochs", type=int, default=5,
                        help="Number of training epochs (default 5).")
    parser.add_argument("--batch-size", type=int, default=8,
                        help="Per-device train batch size (default 8 — fits comfortably "
                             "with distilbert-multilingual on a 3.6 GB GPU).")
    parser.add_argument("--quick", action="store_true",
                        help="Sanity smoke test: 1 epoch, 500 train rows.")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--optim", type=str, default="adamw_torch",
                        help="Optimizer name. AdamW fits with the smaller distilbert model.")
    return parser


def _compute_metrics(eval_pred) -> dict[str, float]:
    """Score one evaluation pass for the Trainer.

    Args:
        eval_pred: The Trainer's evaluation output, unpacked as
            ``(logits, label_ids)``.

    Returns:
        ``accuracy``, weighted F1 (``f1``), macro F1 (``f1_macro``, the
        checkpoint-selection metric used in ``main``), and weighted
        ``precision`` / ``recall``.
    """
    logits, labels = eval_pred
    # Predictions arrive as a tuple when the model emits extra outputs;
    # the logits are the first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted", zero_division=0),
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "precision": precision_score(labels, preds, average="weighted", zero_division=0),
        "recall": recall_score(labels, preds, average="weighted", zero_division=0),
    }


def _per_language_metrics(languages, pred_ids: np.ndarray,
                          true_ids: np.ndarray) -> dict[str, dict[str, float]]:
    """Compute accuracy and weighted/macro F1 separately for each language tag.

    Args:
        languages: Iterable of per-row language tags, aligned row-for-row
            with *pred_ids* and *true_ids*.
        pred_ids: Predicted class ids.
        true_ids: Gold class ids.

    Returns:
        ``{lang: {"n", "accuracy", "f1_weighted", "f1_macro"}}``, with
        languages inserted in sorted order.
    """
    tags = list(languages)
    tag_array = np.asarray(tags)
    per_lang: dict[str, dict[str, float]] = {}
    for lang in sorted(set(tags)):
        # `lang` comes from `tags`, so every mask selects at least one row.
        mask = tag_array == lang
        lang_pred = pred_ids[mask]
        lang_true = true_ids[mask]
        per_lang[lang] = {
            "n": int(mask.sum()),
            "accuracy": float(accuracy_score(lang_true, lang_pred)),
            "f1_weighted": float(f1_score(lang_true, lang_pred, average="weighted", zero_division=0)),
            "f1_macro": float(f1_score(lang_true, lang_pred, average="macro", zero_division=0)),
        }
    return per_lang


def main() -> int:
    """Fine-tune ``MODEL_NAME`` for intent classification and report test metrics.

    Steps:
      1. Load the label map (``LABELS_FILE``) and the DatasetDict (``DATA_DIR``).
      2. Train with the HuggingFace Trainer, keeping the checkpoint with the
         best validation macro-F1.
      3. Save model, tokenizer and ``labels.json`` into ``OUT_DIR``.
      4. Write overall, per-class and per-language test metrics to
         ``OUT_DIR / "eval_results.json"``.

    Returns:
        Process exit code: 0 on success, 1 if the label map is missing.
    """
    args = _build_arg_parser().parse_args()
    set_seed(SEED)

    n_epochs = 1 if args.quick else args.epochs
    grad_accum = 2  # effective batch = batch * 2 = 16 with the default batch (matches spec)
    # AMP fp16 needs a GPU: transformers refuses fp16=True on CPU-only hosts,
    # so enable it only when CUDA is actually present.
    use_fp16 = torch.cuda.is_available()

    print("=" * 72)
    print(f"Train intent classifier ({MODEL_NAME})")
    print("=" * 72)
    print(f" Data dir : {DATA_DIR}")
    print(f" Out dir : {OUT_DIR}")
    print(f" Epochs : {n_epochs}{' (QUICK)' if args.quick else ''}")
    print(f" Batch : {args.batch_size} (effective ≈ {args.batch_size * grad_accum} via accum)")
    print(f" Optimizer: {args.optim}")
    print(f" fp16 : {use_fp16}")

    # --- Labels --------------------------------------------------------------
    if not LABELS_FILE.is_file():
        print(f"ERROR: label map not found: {LABELS_FILE}", file=sys.stderr)
        return 1
    labels_payload = json.loads(LABELS_FILE.read_text(encoding="utf-8"))
    label_to_id: dict[str, int] = labels_payload["label_to_id"]
    # JSON object keys are always strings; the model config wants int ids.
    id_to_label: dict[int, str] = {int(k): v for k, v in labels_payload["id_to_label"].items()}
    label_names = [id_to_label[i] for i in range(len(id_to_label))]
    num_labels = len(label_names)
    print(f" Labels : {label_to_id}")

    # --- Datasets ------------------------------------------------------------
    ds = load_from_disk(str(DATA_DIR))
    print(f" Splits : train={len(ds['train'])} val={len(ds['validation'])} "
          f"test={len(ds['test'])}")
    if args.quick:
        ds["train"] = ds["train"].shuffle(seed=SEED).select(range(min(500, len(ds["train"]))))
        ds["validation"] = ds["validation"].select(range(min(120, len(ds["validation"]))))
        ds["test"] = ds["test"].select(range(min(120, len(ds["test"]))))
        print(f" QUICK : sliced to {len(ds['train'])}/{len(ds['validation'])}/{len(ds['test'])}")

    # --- Tokenize ------------------------------------------------------------
    print("\nLoading tokenizer & model ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tokenize(batch: dict[str, list]) -> dict[str, Any]:
        # No padding here: DataCollatorWithPadding pads each batch dynamically.
        return tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH)

    # Keep only `label` next to the tokenizer outputs. `ds` itself is left
    # untouched (map returns a new DatasetDict), so metadata such as a
    # `language` column stays available for the per-language breakdown.
    drop_cols = [c for c in ds["train"].column_names if c != "label"]
    ds_tok = ds.map(tokenize, batched=True, remove_columns=drop_cols,
                    desc="Tokenizing")

    # --- Model ---------------------------------------------------------------
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label=id_to_label,
        label2id=label_to_id,
    )
    # Free any lingering CUDA blocks before optimizer states are allocated.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- Training arguments --------------------------------------------------
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    training_args_kwargs = dict(
        output_dir=str(RUNS_DIR),
        num_train_epochs=n_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,  # eval has no grads -> larger ok
        gradient_accumulation_steps=grad_accum,
        optim=args.optim,
        learning_rate=args.lr,
        warmup_steps=10 if args.quick else 100,
        weight_decay=0.01,
        fp16=use_fp16,
        gradient_checkpointing=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Rotate old checkpoints. With load_best_model_at_end, Trainer never
        # deletes the best one, so up to 2 (best + latest) may remain on disk.
        save_total_limit=1,
        save_only_model=True,  # skip optimizer/scheduler state — saves ~340 MB/ckpt
        load_best_model_at_end=True,
        metric_for_best_model="f1_macro",  # macro is more sensitive to minority classes
        greater_is_better=True,
        logging_steps=50,
        report_to="none",
        dataloader_num_workers=0,
        seed=SEED,
    )
    try:
        training_args = TrainingArguments(**training_args_kwargs)
    except TypeError:
        # Older transformers releases spell this argument `evaluation_strategy`.
        training_args_kwargs["evaluation_strategy"] = training_args_kwargs.pop("eval_strategy")
        training_args = TrainingArguments(**training_args_kwargs)

    # --- Trainer -------------------------------------------------------------
    trainer = _trainer_with_tokenizer(
        tokenizer,
        model=model,
        args=training_args,
        train_dataset=ds_tok["train"],
        eval_dataset=ds_tok["validation"],
        data_collator=DataCollatorWithPadding(tokenizer),
        compute_metrics=_compute_metrics,
    )

    # --- Train ---------------------------------------------------------------
    print("\nStarting training ...")
    train_result = trainer.train()
    print(f" ✓ training done. final loss = {train_result.training_loss:.4f}")

    # --- Save best model -----------------------------------------------------
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(OUT_DIR))
    tokenizer.save_pretrained(str(OUT_DIR))
    shutil.copy(LABELS_FILE, OUT_DIR / "labels.json")

    # --- Final evaluation on test set ---------------------------------------
    print("\nEvaluating on TEST split ...")
    # A single predict() pass yields both the "test_"-prefixed metrics and the
    # raw logits, so test-set inference runs only once.
    test_pred = trainer.predict(ds_tok["test"], metric_key_prefix="test")
    test_metrics = test_pred.metrics
    test_logits = (test_pred.predictions[0]
                   if isinstance(test_pred.predictions, tuple)
                   else test_pred.predictions)
    pred_ids = np.argmax(test_logits, axis=-1)
    true_ids = test_pred.label_ids

    report_kwargs = dict(
        labels=list(range(num_labels)),  # list every class, even if absent from test
        target_names=label_names,
        zero_division=0,
    )
    report_dict = classification_report(true_ids, pred_ids, output_dict=True, **report_kwargs)
    report_text = classification_report(true_ids, pred_ids, **report_kwargs)
    print("\nClassification report on TEST:")
    print(report_text)

    # --- Per-language breakdown ---------------------------------------------
    # `ds["test"]` is the untokenized split (already sliced in --quick mode)
    # and is row-aligned with `ds_tok["test"]`, so it lines up with pred_ids.
    test_split = ds["test"]
    per_lang: dict[str, dict[str, float]] = {}
    if "language" in test_split.column_names:
        per_lang = _per_language_metrics(test_split["language"], pred_ids, true_ids)
    print("\nPer-language metrics on TEST:")
    for lang, m in per_lang.items():
        print(f" {lang}: n={m['n']} acc={m['accuracy']:.4f} "
              f"f1_w={m['f1_weighted']:.4f} f1_m={m['f1_macro']:.4f}")

    # --- Save eval_results.json ---------------------------------------------
    payload = {
        "model_name": MODEL_NAME,
        "task": "intent",
        "num_labels": num_labels,
        "labels": label_to_id,
        "test_metrics": {k: float(v) for k, v in test_metrics.items()
                         if isinstance(v, (int, float))},
        "classification_report": report_dict,
        "per_language": per_lang,
        "training": {
            "epochs": n_epochs,
            "per_device_batch": args.batch_size,
            "grad_accum": grad_accum,
            "effective_batch": args.batch_size * grad_accum,
            "learning_rate": args.lr,
            "warmup_steps": training_args_kwargs.get("warmup_steps"),
            "fp16": use_fp16,
            "final_train_loss": float(train_result.training_loss),
        },
    }
    (OUT_DIR / "eval_results.json").write_text(
        json.dumps(payload, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"\n✓ Saved model to {OUT_DIR}")
    print(f"✓ Saved eval_results.json to {OUT_DIR / 'eval_results.json'}")
    return 0
if __name__ == "__main__":
    # Run the trainer and translate Ctrl-C into the conventional 130
    # (128 + SIGINT) exit status instead of a traceback.
    try:
        exit_code = main()
    except KeyboardInterrupt:
        print("\nAborted by user.")
        exit_code = 130
    sys.exit(exit_code)