Spaces:
Running
Running
| """DistilBERT fine-tuning and inference for intent classification.""" | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
| import torch | |
| from datasets import Dataset # type: ignore | |
| from loguru import logger | |
| from sklearn.metrics import ( | |
| classification_report, | |
| confusion_matrix, | |
| f1_score, | |
| precision_score, | |
| recall_score, | |
| accuracy_score, | |
| ) | |
| from transformers import ( # type: ignore | |
| DistilBertForSequenceClassification, | |
| DistilBertTokenizerFast, | |
| EarlyStoppingCallback, | |
| Trainer, | |
| TrainingArguments, | |
| ) | |
| from src.data.dataset import INTENT_CATEGORIES | |
# Lookup tables between intent label strings and integer class ids.
# Ids are assigned by enumerating the labels in sorted order; ID2LABEL is the
# exact inverse of LABEL2ID.
LABEL2ID: Dict[str, int] = dict(
    (name, position) for position, name in enumerate(sorted(INTENT_CATEGORIES))
)
ID2LABEL: Dict[int, str] = dict(
    (position, name) for name, position in LABEL2ID.items()
)
| def _tokenize(batch: dict, tokenizer, max_length: int) -> dict: | |
| return tokenizer( | |
| batch["text"], | |
| padding="max_length", | |
| truncation=True, | |
| max_length=max_length, | |
| ) | |
def _compute_metrics(eval_pred) -> Dict[str, float]:
    """Compute weighted classification metrics for an evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax of ``logits`` over the last axis.

    Returns:
        Dict with ``f1_weighted``, ``accuracy``, ``precision_weighted`` and
        ``recall_weighted``.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    # zero_division=0 is passed to every averaged metric (F1 included) so that
    # classes with no predicted samples score 0 without emitting a warning on
    # each evaluation step.
    return {
        "f1_weighted": f1_score(labels, preds, average="weighted", zero_division=0),
        "accuracy": accuracy_score(labels, preds),
        "precision_weighted": precision_score(
            labels, preds, average="weighted", zero_division=0
        ),
        "recall_weighted": recall_score(
            labels, preds, average="weighted", zero_division=0
        ),
    }
def train(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    cfg: dict,
    save_dir: str,
) -> Trainer:
    """Fine-tune DistilBERT for sequence classification.

    When no GPU is available the training set is stratified-subsampled to
    ``cpu_train_sample`` rows (default 3000) and the number of optimisation
    steps is capped at ``cpu_max_steps`` (default 300).

    Side effects: writes checkpoints, the final model/tokenizer (under
    ``<save_dir>/best``), ``training_history.json`` and the training-curve
    plot into ``save_dir``.

    Args:
        train_df: Training DataFrame with 'text' and 'label' columns.
        val_df: Validation DataFrame with 'text' and 'label' columns.
        cfg: Full config dict; only the ``"classifier"`` section is read.
        save_dir: Directory to save model checkpoints and artifacts.

    Returns:
        Fitted HuggingFace Trainer.

    Raises:
        ValueError: If either DataFrame contains a label that is not a key
            of ``LABEL2ID``.
    """
    cc = cfg["classifier"]
    seed = cc["seed"]
    use_gpu = torch.cuda.is_available()

    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)
    logger.info(f"GPU available: {use_gpu}")

    # On CPU, subsample training data to keep runtime feasible (~30 min)
    cpu_train_sample = cc.get("cpu_train_sample", 3000)
    if not use_gpu and len(train_df) > cpu_train_sample:
        logger.warning(
            f"No GPU detected. Subsampling training data to {cpu_train_sample} examples "
            f"(from {len(train_df)}) for feasible CPU training."
        )
        from sklearn.model_selection import train_test_split as _tts

        # The "test" side of the stratified split is kept as the subsample.
        _, train_df = _tts(
            train_df,
            test_size=min(cpu_train_sample / len(train_df), 0.9999),
            stratify=train_df["label"],
            random_state=seed,
        )
        train_df = train_df.reset_index(drop=True)
        logger.info(f"Subsampled training set: {len(train_df)} examples")

    tokenizer = DistilBertTokenizerFast.from_pretrained(cc["model_name"])
    model = DistilBertForSequenceClassification.from_pretrained(
        cc["model_name"],
        num_labels=cc["num_labels"],
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )

    max_length = cc["max_length"]
    train_ds = _encode_split(train_df, tokenizer, max_length, "Training")
    val_ds = _encode_split(val_df, tokenizer, max_length, "Validation")

    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Limit max_steps on CPU to avoid multi-hour runs
    cpu_max_steps = cc.get("cpu_max_steps", 300)
    batch_size = cc["batch_size"]
    # Ceiling division: the last, partially filled batch of an epoch is still
    # one optimisation step.
    steps_per_epoch = max(1, (len(train_df) + batch_size - 1) // batch_size)
    total_steps = steps_per_epoch * cc["epochs"]
    effective_max_steps = total_steps if use_gpu else min(total_steps, cpu_max_steps)
    logger.info(
        f"Steps per epoch: {steps_per_epoch} | Total: {total_steps} | "
        f"Effective max_steps: {effective_max_steps}"
    )

    # Evaluation and checkpointing share one interval so that the best
    # checkpoint can be restored at the end of training.
    eval_interval = max(10, effective_max_steps // 5)
    best_metric = cc["metric_for_best_model"]
    training_args = TrainingArguments(
        output_dir=save_dir,
        max_steps=effective_max_steps,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        learning_rate=cc["learning_rate"],
        weight_decay=cc["weight_decay"],
        warmup_steps=max(1, int(effective_max_steps * cc["warmup_ratio"])),
        eval_strategy="steps",
        eval_steps=eval_interval,
        save_strategy="steps",
        save_steps=eval_interval,
        load_best_model_at_end=cc["load_best_model_at_end"],
        metric_for_best_model=best_metric,
        # Loss-style metrics improve downwards; everything else (F1,
        # accuracy, ...) is treated as higher-is-better.
        greater_is_better=not str(best_metric).endswith("loss"),
        fp16=(cc["fp16"] and use_gpu),
        seed=seed,
        logging_steps=max(1, effective_max_steps // 20),
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=_compute_metrics,
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=cc["early_stopping_patience"])
        ],
    )

    logger.info("Starting DistilBERT fine-tuning…")
    trainer.train()
    logger.info("Training complete.")

    # Save best model and tokenizer
    best_dir = Path(save_dir) / "best"
    trainer.save_model(str(best_dir))
    tokenizer.save_pretrained(str(best_dir))
    logger.info(f"Best model saved → {best_dir}")

    # Save training history
    history = trainer.state.log_history
    history_path = Path(save_dir) / "training_history.json"
    with open(history_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    _plot_training_curves(history, save_dir)
    return trainer


def _encode_split(
    df: pd.DataFrame, tokenizer, max_length: int, split_name: str
) -> "Dataset":
    """Turn one 'text'/'label' DataFrame into a tokenized ``Dataset``.

    Args:
        df: DataFrame with 'text' and 'label' columns.
        tokenizer: Tokenizer forwarded to ``_tokenize``.
        max_length: Max token length forwarded to ``_tokenize``.
        split_name: Human-readable split name used in the error message.

    Returns:
        Dataset with an integer ``labels`` column plus the tokenizer outputs;
        the raw ``text`` column is removed.

    Raises:
        ValueError: If a label in ``df`` is not a key of ``LABEL2ID``.
    """
    label_ids = df["label"].map(LABEL2ID)
    unknown = label_ids.isna()
    if unknown.any():
        # Fail early: unmapped labels would otherwise become NaN targets.
        bad_labels = sorted(df.loc[unknown, "label"].astype(str).unique())
        raise ValueError(
            f"{split_name} data contains labels that are not in LABEL2ID: {bad_labels}"
        )

    frame = df[["text"]].copy()
    frame["labels"] = label_ids.astype("int64")
    # preserve_index=False keeps a non-default DataFrame index from being
    # stored as an extra dataset column.
    dataset = Dataset.from_pandas(frame, preserve_index=False)
    return dataset.map(
        lambda b: _tokenize(b, tokenizer, max_length),
        batched=True,
        remove_columns=["text"],
    )
def _plot_training_curves(history: list, save_dir: str) -> None:
    """Plot and save training loss and F1 curves.

    Entries containing ``"eval_loss"`` are treated as evaluation logs; other
    entries containing ``"loss"`` are treated as training logs. The figure is
    written to ``<save_dir>/training_curves.png``.

    Args:
        history: List of log dicts from trainer.state.log_history.
        save_dir: Directory to save the PNG (created if missing).
    """
    train_steps, train_losses = [], []
    eval_steps, eval_losses = [], []
    # F1 gets its own step axis: an evaluation entry without
    # "eval_f1_weighted" must not leave x and y with different lengths.
    f1_steps, eval_f1s = [], []
    for entry in history:
        if "eval_loss" in entry:
            eval_steps.append(entry["step"])
            eval_losses.append(entry["eval_loss"])
            if "eval_f1_weighted" in entry:
                f1_steps.append(entry["step"])
                eval_f1s.append(entry["eval_f1_weighted"])
        elif "loss" in entry:
            train_steps.append(entry["step"])
            train_losses.append(entry["loss"])

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    loss_ax, f1_ax = axes[0], axes[1]

    loss_ax.plot(train_steps, train_losses, label="Train Loss", color="steelblue")
    loss_ax.plot(eval_steps, eval_losses, label="Val Loss", color="coral")
    loss_ax.set_title("Training & Validation Loss")
    loss_ax.set_xlabel("Step")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend()

    if eval_f1s:
        f1_ax.plot(f1_steps, eval_f1s, label="Val F1 (weighted)", color="green")
        f1_ax.set_title("Validation F1 (Weighted)")
        f1_ax.set_xlabel("Step")
        f1_ax.set_ylabel("F1 Score")
        f1_ax.legend()

    plt.tight_layout()
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "training_curves.png"
    fig.savefig(path, dpi=150)
    plt.close(fig)
    logger.info(f"Saved training curves → {path}")
def evaluate(
    model_dir: str,
    test_df: pd.DataFrame,
    results_dir: str,
    batch_size: int = 32,
    max_length: int = 128,
) -> Dict:
    """Score a saved model on the test set and write the evaluation artifacts.

    Writes ``classification_report.json`` and ``confusion_matrix.png`` into
    ``results_dir`` and logs the text form of the classification report.

    Args:
        model_dir: Directory containing the saved model and tokenizer.
        test_df: Test DataFrame with 'text' and 'label' columns.
        results_dir: Output directory (created if missing).
        batch_size: Inference batch size.
        max_length: Max token length.

    Returns:
        Classification report dict.
    """
    out_dir = Path(results_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer, model = _load_model(model_dir)
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    model.to(device)
    model.eval()

    queries = test_df["text"].tolist()
    predicted = _batch_predict(
        queries, tokenizer, model, device, batch_size, max_length
    )
    class_names = sorted(INTENT_CATEGORIES)

    # Classification report: dict form is persisted, text form is logged.
    report = classification_report(
        test_df["label"], predicted, labels=class_names, output_dict=True
    )
    report_text = classification_report(
        test_df["label"], predicted, labels=class_names
    )
    logger.info(f"DistilBERT classification report:\n{report_text}")

    report_path = out_dir / "classification_report.json"
    with open(report_path, "w") as handle:
        json.dump(report, handle, indent=2)
    logger.info(f"Saved classification report → {report_path}")

    # Confusion matrix heatmap with rows = true labels, columns = predictions.
    matrix = confusion_matrix(test_df["label"], predicted, labels=class_names)
    fig, ax = plt.subplots(figsize=(10, 8))
    heatmap_options = {
        "annot": True,
        "fmt": "d",
        "cmap": "Blues",
        "xticklabels": class_names,
        "yticklabels": class_names,
        "ax": ax,
    }
    sns.heatmap(matrix, **heatmap_options)
    ax.set_title("DistilBERT Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    plt.tight_layout()

    cm_path = out_dir / "confusion_matrix.png"
    fig.savefig(cm_path, dpi=150)
    plt.close(fig)
    logger.info(f"Saved confusion matrix → {cm_path}")
    return report
def _load_model(
    model_dir: str,
) -> Tuple[DistilBertTokenizerFast, DistilBertForSequenceClassification]:
    """Load the tokenizer and the classification model stored in ``model_dir``.

    Args:
        model_dir: Directory passed to both ``from_pretrained`` calls.

    Returns:
        ``(tokenizer, model)`` tuple; the tokenizer is loaded first.
    """
    return (
        DistilBertTokenizerFast.from_pretrained(model_dir),
        DistilBertForSequenceClassification.from_pretrained(model_dir),
    )
| def _batch_predict( | |
| texts: List[str], | |
| tokenizer, | |
| model, | |
| device, | |
| batch_size: int, | |
| max_length: int, | |
| ) -> List[str]: | |
| """Run batched inference and return predicted label strings.""" | |
| all_preds = [] | |
| for i in range(0, len(texts), batch_size): | |
| batch = texts[i : i + batch_size] | |
| enc = tokenizer( | |
| batch, | |
| padding=True, | |
| truncation=True, | |
| max_length=max_length, | |
| return_tensors="pt", | |
| ).to(device) | |
| with torch.no_grad(): | |
| logits = model(**enc).logits | |
| pred_ids = logits.argmax(dim=-1).cpu().numpy() | |
| all_preds.extend([ID2LABEL[p] for p in pred_ids]) | |
| return all_preds | |
class IntentClassifier:
    """Wrapper for DistilBERT intent classification inference.

    The model and tokenizer are loaded once at construction, moved to CUDA
    when available (CPU otherwise) and put into eval mode.

    Args:
        model_dir: Path to saved model directory.
        max_length: Max token length for tokenizer.
    """

    def __init__(self, model_dir: str, max_length: int = 128) -> None:
        self.max_length = max_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer, self.model = _load_model(model_dir)
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"IntentClassifier loaded from {model_dir} on {self.device}")

    def predict(self, text: str) -> Tuple[str, float]:
        """Predict intent label and confidence for a single query.

        Args:
            text: Customer query string.

        Returns:
            (intent_label, confidence_score) tuple.
        """
        # Single-query prediction is the batch path with one element, so both
        # entry points share one implementation.
        return self.predict_batch([text])[0]

    def predict_batch(
        self, texts: List[str], batch_size: int = 32
    ) -> List[Tuple[str, float]]:
        """Predict intents for a batch of queries.

        Queries are run through the model in chunks of ``batch_size`` so that
        the size of a single forward pass stays bounded for long inputs.

        Args:
            texts: List of customer query strings. A bare string is treated
                as a single query.
            batch_size: Maximum number of queries per forward pass.

        Returns:
            List of (intent_label, confidence_score) tuples in input order;
            an empty list for empty input.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # A bare string would otherwise be split into characters by list().
        queries = [texts] if isinstance(texts, str) else list(texts)
        if not queries:
            # Nothing to encode; skip the tokenizer and model entirely.
            return []

        results: List[Tuple[str, float]] = []
        for start in range(0, len(queries), batch_size):
            chunk = queries[start : start + batch_size]
            enc = self.tokenizer(
                chunk,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)
            with torch.no_grad():
                logits = self.model(**enc).logits
            probs = torch.softmax(logits, dim=-1).cpu().numpy()
            for row in probs:
                pred_id = int(np.argmax(row))
                results.append((ID2LABEL[pred_id], float(row[pred_id])))
        return results