Spaces:
Sleeping
Sleeping
| """Fine-tune distilbert-base-multilingual-cased for 6-class intent classification. | |
| Intents: booking, complaint, farewell, greeting, inquiry, other. | |
| Languages covered: AR, EN, FR (data was stratified across all of them). | |
| Inputs: | |
| data/processed/intent/ (HuggingFace DatasetDict) | |
| data/processed/intent/labels.json | |
| Outputs: | |
| models/intent_classifier/ (best model + tokenizer) | |
| models/intent_classifier/eval_results.json (test metrics + classification report) | |
| models/intent_classifier/runs/ (training checkpoints + logs) | |
| GPU notes for GTX 1650 (3.6 GB VRAM): same as train_lang_detector.py — we use | |
| fp16 + gradient_checkpointing + batch=8 with grad_accum=2. | |
| Usage: | |
| python src/train_intent.py | |
| python src/train_intent.py --epochs 3 | |
| python src/train_intent.py --quick | |
| """ | |
| from __future__ import annotations | |
| import os | |
| # Reduce CUDA memory fragmentation on tight-VRAM GPUs (must precede torch import). | |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") | |
| import argparse | |
| import inspect | |
| import json | |
| import shutil | |
| import sys | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| from datasets import load_from_disk | |
| from sklearn.metrics import ( | |
| accuracy_score, classification_report, f1_score, | |
| precision_score, recall_score, | |
| ) | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| DataCollatorWithPadding, | |
| Trainer, | |
| TrainingArguments, | |
| set_seed, | |
| ) | |
# --- Filesystem layout -------------------------------------------------------
# Everything is anchored at the repository root, two levels above this file
# (the script lives in src/).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJECT_ROOT.joinpath("data", "processed", "intent")
LABELS_FILE = DATA_DIR.joinpath("labels.json")
OUT_DIR = PROJECT_ROOT.joinpath("models", "intent_classifier")
RUNS_DIR = OUT_DIR.joinpath("runs")

# --- Model & hyper-parameters -----------------------------------------------
# The original spec asked for xlm-roberta-base, but at ~270M parameters its
# optimizer step does not fit in a GTX 1650's 3.6 GB of VRAM: Adam state alone
# is ~2.2 GB, and even Adafactor allocates a ~770 MB temporary for grad**2 over
# the embedding matrix. distilbert-base-multilingual-cased (134M params, 104
# languages including AR/EN/FR) is the usual stand-in, typically within 1-3%
# F1 of XLM-R on classification tasks.
MODEL_NAME = "distilbert-base-multilingual-cased"
MAX_LENGTH, SEED = 128, 42  # tokenizer truncation length; global RNG seed
def _trainer_with_tokenizer(tokenizer, **kwargs: Any) -> Trainer:
    """Build a ``Trainer`` and hand it *tokenizer* under whichever keyword
    the installed transformers release accepts.

    Newer releases take ``processing_class``; older ones only know
    ``tokenizer``. When both are present, ``processing_class`` wins. If
    neither keyword exists, the tokenizer is not forwarded at all.
    All other keyword arguments are passed to ``Trainer`` unchanged.
    """
    accepted = inspect.signature(Trainer.__init__).parameters
    for keyword in ("processing_class", "tokenizer"):
        if keyword in accepted:
            kwargs[keyword] = tokenizer
            break
    return Trainer(**kwargs)
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this training script.

    The parser description is the first line of the module docstring
    (empty when docstrings are stripped, e.g. under ``python -OO``).
    """
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n")[0])
    parser.add_argument("--epochs", type=int, default=5,
                        help="Number of training epochs (default 5).")
    parser.add_argument("--batch-size", type=int, default=8,
                        help="Per-device train batch size (default 8 — fits comfortably "
                             "with distilbert-multilingual on a 3.6 GB GPU).")
    parser.add_argument("--quick", action="store_true",
                        help="Sanity smoke test: 1 epoch, 500 train rows.")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--optim", type=str, default="adamw_torch",
                        help="Optimizer name. AdamW fits with the smaller distilbert model.")
    return parser.parse_args()


def _load_labels() -> tuple[dict[str, int], dict[int, str], list[str]]:
    """Read ``LABELS_FILE`` and return ``(label_to_id, id_to_label, label_names)``.

    JSON object keys are always strings, so ``id_to_label`` keys are cast back
    to ``int``. ``label_names`` is ordered by id 0..N-1, which assumes the ids
    in the file are contiguous.
    """
    payload = json.loads(LABELS_FILE.read_text(encoding="utf-8"))
    label_to_id: dict[str, int] = payload["label_to_id"]
    id_to_label: dict[int, str] = {int(k): v for k, v in payload["id_to_label"].items()}
    label_names = [id_to_label[i] for i in range(len(id_to_label))]
    return label_to_id, id_to_label, label_names


def _compute_metrics(eval_pred) -> dict[str, float]:
    """``Trainer`` metrics callback: accuracy plus weighted/macro F1 and
    weighted precision/recall.

    ``eval_pred`` unpacks to ``(logits, labels)``. When the model returned a
    tuple of outputs, the first element is taken as the logits.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted", zero_division=0),
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "precision": precision_score(labels, preds, average="weighted", zero_division=0),
        "recall": recall_score(labels, preds, average="weighted", zero_division=0),
    }


def _per_language_metrics(languages, pred_ids: np.ndarray,
                          true_ids: np.ndarray) -> dict[str, dict[str, float]]:
    """Compute accuracy and weighted/macro F1 separately for each language.

    ``languages`` must be row-aligned with ``pred_ids`` / ``true_ids``.
    Languages are reported in sorted order.
    """
    # Materialize first: newer `datasets` releases may return a lazy column object.
    languages = list(languages)
    lang_arr = np.asarray(languages)
    per_lang: dict[str, dict[str, float]] = {}
    for lang in sorted(set(languages)):
        mask = lang_arr == lang
        lp, lt = pred_ids[mask], true_ids[mask]
        per_lang[lang] = {
            "n": int(mask.sum()),
            "accuracy": float(accuracy_score(lt, lp)),
            "f1_weighted": float(f1_score(lt, lp, average="weighted", zero_division=0)),
            "f1_macro": float(f1_score(lt, lp, average="macro", zero_division=0)),
        }
    return per_lang


def main() -> int:
    """Fine-tune ``MODEL_NAME`` for intent classification and evaluate on TEST.

    Saves the best model, the tokenizer, a copy of labels.json and
    eval_results.json into ``OUT_DIR``, with checkpoints under ``RUNS_DIR``.
    Returns the process exit code (0 on success).
    """
    args = _parse_args()
    set_seed(SEED)

    n_epochs = 1 if args.quick else args.epochs
    grad_accum = 2  # effective batch = batch_size * grad_accum (16 at the default batch of 8)
    # fp16 AMP requires a CUDA device in most transformers releases; forcing it
    # on makes TrainingArguments raise on CPU-only machines.
    use_fp16 = torch.cuda.is_available()

    print("=" * 72)
    print(f"Train intent classifier ({MODEL_NAME})")
    print("=" * 72)
    print(f" Data dir : {DATA_DIR}")
    print(f" Out dir : {OUT_DIR}")
    print(f" Epochs : {n_epochs}{' (QUICK)' if args.quick else ''}")
    print(f" Batch : {args.batch_size} (effective ≈ {args.batch_size * grad_accum} via accum)")
    print(f" Optimizer: {args.optim}")
    print(f" fp16 : {use_fp16}")

    # --- Labels --------------------------------------------------------------
    label_to_id, id_to_label, label_names = _load_labels()
    num_labels = len(label_names)
    print(f" Labels : {label_to_id} ({num_labels} classes)")

    # --- Datasets ------------------------------------------------------------
    ds = load_from_disk(str(DATA_DIR))
    print(f" Splits : train={len(ds['train'])} val={len(ds['validation'])} "
          f"test={len(ds['test'])}")
    if args.quick:
        ds["train"] = ds["train"].shuffle(seed=SEED).select(range(min(500, len(ds["train"]))))
        ds["validation"] = ds["validation"].select(range(min(120, len(ds["validation"]))))
        ds["test"] = ds["test"].select(range(min(120, len(ds["test"]))))
        print(f" QUICK : sliced to {len(ds['train'])}/{len(ds['validation'])}/{len(ds['test'])}")

    # --- Tokenize ------------------------------------------------------------
    print("\nLoading tokenizer & model ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tokenize(batch: dict[str, list]) -> dict[str, Any]:
        return tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH)

    # Only "label" survives in the tokenized copy; `ds` itself keeps every
    # raw column, which the per-language breakdown relies on below.
    drop_cols = [c for c in ds["train"].column_names if c != "label"]
    ds_tok = ds.map(tokenize, batched=True, remove_columns=drop_cols,
                    desc="Tokenizing")

    # --- Model ---------------------------------------------------------------
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label=id_to_label,
        label2id=label_to_id,
    )
    # Free any lingering CUDA blocks before optimizer states are allocated.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- Training arguments --------------------------------------------------
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    training_args_kwargs = dict(
        output_dir=str(RUNS_DIR),
        num_train_epochs=n_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,  # eval has no grads -> larger ok
        gradient_accumulation_steps=grad_accum,
        optim=args.optim,
        learning_rate=args.lr,
        warmup_steps=10 if args.quick else 100,
        weight_decay=0.01,
        fp16=use_fp16,
        gradient_checkpointing=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,        # keep only the current-best checkpoint
        save_only_model=True,      # skip optimizer/scheduler state — saves ~340 MB/ckpt
        load_best_model_at_end=True,
        metric_for_best_model="f1_macro",  # macro is more sensitive to minority classes
        greater_is_better=True,
        logging_steps=50,
        report_to="none",
        dataloader_num_workers=0,
        seed=SEED,
    )
    try:
        training_args = TrainingArguments(**training_args_kwargs)
    except TypeError:
        # Older transformers releases only know the `evaluation_strategy` spelling.
        training_args_kwargs["evaluation_strategy"] = training_args_kwargs.pop("eval_strategy")
        training_args = TrainingArguments(**training_args_kwargs)

    # --- Trainer -------------------------------------------------------------
    trainer = _trainer_with_tokenizer(
        tokenizer,
        model=model,
        args=training_args,
        train_dataset=ds_tok["train"],
        eval_dataset=ds_tok["validation"],
        data_collator=DataCollatorWithPadding(tokenizer),
        compute_metrics=_compute_metrics,
    )

    # --- Train ---------------------------------------------------------------
    print("\nStarting training ...")
    train_result = trainer.train()
    print(f" ✓ training done. final loss = {train_result.training_loss:.4f}")

    # --- Save best model (load_best_model_at_end restored it) ----------------
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(OUT_DIR))
    tokenizer.save_pretrained(str(OUT_DIR))
    shutil.copy(LABELS_FILE, OUT_DIR / "labels.json")

    # --- Final evaluation on test set ---------------------------------------
    print("\nEvaluating on TEST split ...")
    # A single predict() pass yields both the logits and the "test_"-prefixed
    # metrics; calling evaluate() as well would run inference on TEST twice.
    test_pred = trainer.predict(ds_tok["test"], metric_key_prefix="test")
    test_metrics = test_pred.metrics
    predictions = test_pred.predictions
    test_logits = predictions[0] if isinstance(predictions, tuple) else predictions
    pred_ids = np.argmax(test_logits, axis=-1)
    true_ids = test_pred.label_ids

    report_kwargs = dict(
        labels=list(range(num_labels)),
        target_names=label_names,
        zero_division=0,
    )
    report_dict = classification_report(true_ids, pred_ids, output_dict=True, **report_kwargs)
    report_text = classification_report(true_ids, pred_ids, **report_kwargs)
    print("\nClassification report on TEST:")
    print(report_text)

    # --- Per-language breakdown ---------------------------------------------
    # ds["test"] still has its raw columns and was sliced exactly like
    # ds_tok["test"] in --quick mode, so its rows line up with pred_ids.
    per_lang: dict[str, dict[str, float]] = {}
    if "language" in ds["test"].column_names:
        per_lang = _per_language_metrics(ds["test"]["language"], pred_ids, true_ids)
        print("\nPer-language metrics on TEST:")
        for lang, m in per_lang.items():
            print(f" {lang}: n={m['n']} acc={m['accuracy']:.4f} "
                  f"f1_w={m['f1_weighted']:.4f} f1_m={m['f1_macro']:.4f}")

    # --- Save eval_results.json ---------------------------------------------
    payload = {
        "model_name": MODEL_NAME,
        "task": "intent",
        "num_labels": num_labels,
        "labels": label_to_id,
        "test_metrics": {k: float(v) for k, v in test_metrics.items()
                         if isinstance(v, (int, float))},
        "classification_report": report_dict,
        "per_language": per_lang,
        "training": {
            "epochs": n_epochs,
            "per_device_batch": args.batch_size,
            "grad_accum": grad_accum,
            "effective_batch": args.batch_size * grad_accum,
            "learning_rate": args.lr,
            "warmup_steps": training_args_kwargs.get("warmup_steps"),
            "fp16": use_fp16,
            "final_train_loss": float(train_result.training_loss),
        },
    }
    # Explicit UTF-8: ensure_ascii=False may emit non-ASCII text, which the
    # platform default encoding (e.g. cp1252 on Windows) cannot always encode.
    (OUT_DIR / "eval_results.json").write_text(
        json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\n✓ Saved model to {OUT_DIR}")
    print(f"✓ Saved eval_results.json to {OUT_DIR / 'eval_results.json'}")
    return 0
if __name__ == "__main__":
    # Run the pipeline; Ctrl-C exits with the conventional SIGINT status 130.
    try:
        exit_code = main()
    except KeyboardInterrupt:
        print("\nAborted by user.")
        exit_code = 130
    sys.exit(exit_code)