| from datasets import load_dataset |
| from sklearn.metrics import precision_score, recall_score, f1_score |
| from transformers import ( |
| AutoConfig, |
| AutoModelForSequenceClassification, |
| AutoTokenizer, |
| Trainer, |
| TrainingArguments, |
| ) |
| import argparse |
| import logging.config |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
|
|
| from utils import default_logging_config, get_torch_device, show_class_distribution |
|
|
# Module-level logger; logging is configured in main() via default_logging_config.
logger = logging.getLogger(__name__)
|
|
|
|
class FocalTrainer(Trainer):
    """Hugging Face Trainer that optimizes a focal loss (with optional label
    smoothing) instead of the default loss, for multi-label classification."""

    def __init__(
        self,
        alpha=0.25,
        gamma=2.0,
        label_smoothing=0.0,
        *args,
        **kwargs
    ):
        """
        Args:
            alpha: weight for positive examples in focal loss (0 < alpha <= 1).
            gamma: focusing parameter; larger values down-weight easy examples.
            label_smoothing: how much to smooth the 0/1 labels (0 disables it).

        All other positional/keyword arguments are forwarded to Trainer.
        """
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Forward pass + focal loss with optional label smoothing.

        Returns the scalar loss, or ``(loss, outputs)`` when
        ``return_outputs`` is True (the Trainer API contract).
        """
        labels = inputs.get("labels")
        # Strip labels from the forward call so the model does not compute
        # its own built-in loss.
        outputs = model(**{k: v for k, v in inputs.items() if k != "labels"})
        logits = outputs.logits

        # Cast once so BOTH branches hand float targets to the loss (the
        # original only cast in the no-smoothing branch), and reuse the
        # module-level helper instead of duplicating the smoothing formula.
        targets = labels.float()
        if self.label_smoothing > 0:
            targets = apply_label_smoothing(targets, self.label_smoothing)

        loss = focal_loss(
            logits=logits,
            targets=targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction="mean",
        )

        return (loss, outputs) if return_outputs else loss
|
|
|
|
def apply_label_smoothing(labels: torch.Tensor, epsilon: float) -> torch.Tensor:
    """
    Smooth multi-label targets: 1 -> (1 - epsilon), 0 -> epsilon.
    """
    # Equivalent to (1 - eps) * labels + eps * (1 - labels) for 0/1 targets:
    # positives move down by eps, negatives move up by eps.
    return labels + epsilon * (1.0 - 2.0 * labels)
|
|
|
|
def find_best_threshold(logits, labels, step=0.01, metric="f1_macro"):
    """
    Grid-search a single global decision threshold in [0, 1] that maximizes
    the requested metric on the given predictions.

    Args:
        logits: np.array of shape [num_samples, num_labels] (raw model outputs).
        labels: np.array of shape [num_samples, num_labels] with 0/1 entries.
        step: how fine-grained to search the thresholds.
        metric: which metric key to maximize, e.g. "f1_micro" or "f1_macro".

    Returns:
        best_thresh: float, the threshold that maximized `metric`.
        best_scores: dict with all micro/macro P/R/F1 values at that threshold.
    """
    best_thresh = 0.5
    # FIX: start below any reachable score so best_scores is always populated.
    # The original initialized to 0.0 with a strict `>` comparison, so when
    # every threshold scored exactly 0 it returned best_scores=None.
    best_metric_val = -1.0
    best_scores = None

    probs = 1 / (1 + np.exp(-logits))  # element-wise sigmoid

    thresholds = np.arange(0.0, 1.0 + step, step)
    for t in thresholds:
        preds = (probs >= t).astype(int)
        scores = {
            "precision_micro": precision_score(labels, preds, average="micro", zero_division=0),
            "recall_micro": recall_score(labels, preds, average="micro", zero_division=0),
            "f1_micro": f1_score(labels, preds, average="micro", zero_division=0),
            "precision_macro": precision_score(labels, preds, average="macro", zero_division=0),
            "recall_macro": recall_score(labels, preds, average="macro", zero_division=0),
            "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
        }

        if scores[metric] > best_metric_val:
            best_metric_val = scores[metric]
            best_thresh = t
            best_scores = scores

    return best_thresh, best_scores
|
|
|
|
def find_per_label_thresholds(logits, labels, step=0.01):
    """
    For each label individually, find the threshold in [0, 1] that maximizes
    that label's own (binary) F1. Labels whose F1 never exceeds 0 keep the
    default threshold of 0.5.

    Args:
        logits: np.array, shape [num_samples, num_labels]
        labels: np.array, shape [num_samples, num_labels]
        step: float, step size to search thresholds from 0..1

    Returns:
        per_label_thresholds: np.array of shape [num_labels]
    """
    # FIX: dropped the unused `num_samples` unpacking and the docstring's
    # reference to a nonexistent `average_metric` parameter.
    num_labels = logits.shape[1]
    probs = 1 / (1 + np.exp(-logits))  # element-wise sigmoid

    per_label_thresholds = np.zeros(num_labels)

    # Hoist the threshold grid out of the per-label loop (it never changes).
    thresholds = np.arange(0.0, 1.0 + step, step)
    for i in range(num_labels):
        label_probs = probs[:, i]
        label_true = labels[:, i]

        best_thresh = 0.5
        best_f1 = 0.0

        for t in thresholds:
            preds_i = (label_probs >= t).astype(int)
            f1_i = f1_score(label_true, preds_i, zero_division=0)
            if f1_i > best_f1:
                best_f1 = f1_i
                best_thresh = t

        per_label_thresholds[i] = best_thresh

    return per_label_thresholds
|
|
|
|
def focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Compute focal loss for multi-label classification.

    Args:
        logits: [batch_size, num_labels] raw output of model
        targets: [batch_size, num_labels] 0/1 (or smoothed) ground-truth labels
        alpha: weighting factor for positive examples (0 < alpha <= 1)
        gamma: exponent for down-weighting easy examples
        reduction: "mean", "sum", or "none" to aggregate across batch/labels

    Returns:
        A scalar loss (if reduction != "none"), else the per-element tensor.

    Raises:
        ValueError: if `reduction` is not one of "mean", "sum", "none".
    """
    # FIX: reject unknown reduction strings up front. The original silently
    # returned the unreduced tensor for any typo (e.g. "Mean"), which would
    # surface much later as a shape error in the training loop.
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Unknown reduction: {reduction!r}")

    probs = torch.sigmoid(logits)

    # Per-element BCE, computed from logits for numerical stability.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")

    # pt = model's probability for the true class of each element.
    pt = targets * probs + (1 - targets) * (1 - probs)

    # alpha for positives, (1 - alpha) for negatives.
    alpha_factor = alpha * targets + (1 - alpha) * (1 - targets)

    # Down-weight easy examples (pt close to 1) by (1 - pt)^gamma.
    focal_factor = alpha_factor * (1.0 - pt).pow(gamma)

    loss = focal_factor * bce

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
|
|
|
|
def multi_hot_encode_labels(example, num_labels):
    """
    Convert a list of label indices -> multi-hot vector of length `num_labels`.

    Reads `example["orig_labels"]`, writes `example["labels"]`, and returns
    the (mutated) example, as datasets.map expects.
    """
    active = set(example["orig_labels"])
    example["labels"] = [1.0 if idx in active else 0.0 for idx in range(num_labels)]
    return example
|
|
|
|
def multi_label_metrics(eval_pred, threshold=0.5):
    """
    Compute micro/macro precision, recall and F1 for multi-label predictions.

    Args:
        eval_pred: (predictions, labels) from Trainer
            - predictions: [batch_size, num_labels] raw logits
            - labels: [batch_size, num_labels] binary 0/1
        threshold: single float in [0, 1] for deciding positive/negative.

    Pipeline: sigmoid on logits -> probabilities, threshold, compare to the
    ground truth, and return micro/macro P/R/F1 as a flat dict.
    """
    logits, labels = eval_pred
    preds = (1 / (1 + np.exp(-logits)) >= threshold).astype(int)

    metrics = {}
    for avg in ("micro", "macro"):
        metrics[f"precision_{avg}"] = precision_score(labels, preds, average=avg, zero_division=0)
        metrics[f"recall_{avg}"] = recall_score(labels, preds, average=avg, zero_division=0)
        metrics[f"f1_{avg}"] = f1_score(labels, preds, average=avg, zero_division=0)
    return metrics
|
|
|
|
def multi_label_metrics_per_label_thresholds(logits, labels, thresholds):
    """
    Apply a custom threshold for each label, then compute micro/macro P/R/F1.

    Args:
        logits: np.array, shape [num_samples, num_labels]
        labels: np.array, shape [num_samples, num_labels]
        thresholds: list/array of shape [num_labels] with each label's threshold

    Returns:
        Dict with micro and macro precision/recall/F1.
    """
    probs = 1 / (1 + np.exp(-logits))  # element-wise sigmoid

    # FIX: broadcast the per-label thresholds across the sample axis in one
    # vectorized comparison; the original looped over labels filling a
    # preallocated array (and also unpacked an unused `num_samples`).
    preds = (probs >= np.asarray(thresholds)).astype(int)

    return {
        "precision_micro": precision_score(labels, preds, average="micro", zero_division=0),
        "recall_micro": recall_score(labels, preds, average="micro", zero_division=0),
        "f1_micro": f1_score(labels, preds, average="micro", zero_division=0),
        "precision_macro": precision_score(labels, preds, average="macro", zero_division=0),
        "recall_macro": recall_score(labels, preds, average="macro", zero_division=0),
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
    }
|
|
|
|
def main():
    """End-to-end pipeline: load GoEmotions, train/evaluate a multi-label
    classifier with focal loss, then tune global and per-label thresholds."""
    logging.config.dictConfig(default_logging_config)

    args = parse_args()
    logger.info(f"Arguments: {args}")

    logger.info("Loading the GoEmotions dataset (train/validation/test).")
    dataset = load_dataset("go_emotions")

    # Keep the raw label-index lists under a different name; "labels" will
    # hold the multi-hot vectors the model trains on.
    dataset = dataset.rename_column("labels", "orig_labels")

    label_names = dataset["train"].features["orig_labels"].feature.names
    num_labels = len(label_names)
    logger.info(f"Detected num_labels={num_labels} from dataset metadata.")
    logger.info(f"Label names = {label_names}")

    show_class_distribution(dataset, "train", label_names)
    show_class_distribution(dataset, "validation", label_names)
    show_class_distribution(dataset, "test", label_names)

    # Index lists -> multi-hot float vectors, then drop the raw lists.
    dataset = dataset.map(lambda ex: multi_hot_encode_labels(ex, num_labels=num_labels))
    dataset = dataset.remove_columns("orig_labels")

    logger.info(f"Loading tokenizer/model from {args.model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)

    config = AutoConfig.from_pretrained(args.model_name_or_path)
    config.num_labels = num_labels
    config.problem_type = "multi_label_classification"

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
    )
    # FIX: use the module logger (the original called logging.info, which
    # goes through the root logger and bypasses this module's configuration).
    logger.info(f"model problem type: {model.config.problem_type}")
    torch_device = get_torch_device()
    model.to(torch_device)

    logger.info("Tokenizing dataset...")

    def tokenize_func(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=512,
            padding="max_length"
        )

    tokenized_ds = dataset.map(tokenize_func, batched=True)
    tokenized_ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    logger.info("Preparing TrainingArguments and Trainer.")
    # FIX: keep eval/save strategies in sync. load_best_model_at_end requires
    # the two strategies to match, and the original passed
    # save_strategy=args.eval_strategy (e.g. "steps") with eval_strategy="no"
    # whenever --eval was omitted, which makes TrainingArguments raise.
    eval_strategy = args.eval_strategy if args.eval else "no"
    training_args = TrainingArguments(
        eval_strategy=eval_strategy,
        learning_rate=args.learning_rate,
        logging_dir=f"{args.output_dir}/logs",
        logging_steps=100,
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        save_strategy=eval_strategy,
        # Only restore the best checkpoint when evaluation actually runs.
        load_best_model_at_end=args.eval,
        metric_for_best_model="f1_macro",
        greater_is_better=True,
        num_train_epochs=args.num_epochs,
        per_device_eval_batch_size=args.batch_size,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=0.1,
        weight_decay=0.01,
    )

    trainer = FocalTrainer(
        alpha=args.focal_alpha,
        gamma=args.focal_gamma,
        label_smoothing=args.label_smoothing,
        model=model,
        args=training_args,
        train_dataset=tokenized_ds["train"],
        eval_dataset=tokenized_ds["validation"],
        compute_metrics=multi_label_metrics,
    )

    if args.train:
        logger.info("Starting training...")
        trainer.train()
        logger.info("Training complete. Saving final model.")
        trainer.save_model(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

    if args.eval or args.test:
        eval_metrics = trainer.evaluate(eval_dataset=tokenized_ds["validation"])
        logger.info(f"Validation metrics: {eval_metrics}")

        logger.info("Searching for best global threshold on validation set...")
        preds_output = trainer.predict(tokenized_ds["validation"])
        logits_val, labels_val = preds_output.predictions, preds_output.label_ids

        best_thresh, best_scores = find_best_threshold(logits_val, labels_val, step=0.01, metric="f1_macro")
        logger.info(f"Best global threshold={best_thresh:.2f}, best_scores={best_scores}")

        per_label_thresh = find_per_label_thresholds(logits_val, labels_val, step=0.01)
        logger.info(f"Found per-label thresholds: {per_label_thresh}")

        val_metrics_per_label = multi_label_metrics_per_label_thresholds(logits_val, labels_val, per_label_thresh)
        logger.info(f"Validation metrics with per-label thresholds: {val_metrics_per_label}")

        if args.test:
            # FIX: this log call was missing the f-prefix, so the literal text
            # "{best_thresh:.2f}" was printed instead of the value.
            logger.info(f"Predicting on test set using threshold={best_thresh:.2f}...")
            preds_test = trainer.predict(tokenized_ds["test"])
            logits_test, labels_test = preds_test.predictions, preds_test.label_ids

            test_metrics_per_label = multi_label_metrics_per_label_thresholds(logits_test, labels_test, per_label_thresh)
            logger.info(f"Test metrics with per-label thresholds: {test_metrics_per_label}")
|
|
|
def parse_args():
    """Build the CLI and return the parsed argument namespace."""
    p = argparse.ArgumentParser(
        description="Train multi-label classifier on GoEmotions using DeBERTa-v3-base.")

    # Paths and model selection.
    p.add_argument("--output_dir", type=str, default="./focal_loss_new",
                   help="Output directory for checkpoints.")
    p.add_argument("--model_name_or_path", type=str, default="microsoft/deberta-v3-base",
                   help="Any valid HF model name (or local path) for sequence classification.")

    # Optimization schedule.
    p.add_argument("--num_epochs", type=int, default=8, help="Number of training epochs.")
    p.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
    p.add_argument("--eval_strategy", default="steps", choices=["epoch", "steps"],
                   help="How frequently to do evaluation steps.")

    # Batch sizing and loss hyperparameters.
    p.add_argument("--batch_size", type=int, default=8, help="Per-device batch size.")
    p.add_argument("--focal_alpha", type=float, default=0.25)
    p.add_argument("--focal_gamma", type=float, default=2.0)
    p.add_argument("--gradient_accumulation_steps", type=int, default=8,
                   help="Effective batch size = batch size x gradient accumulation steps.")
    p.add_argument("--label_smoothing", type=float, default=0.05,
                   help="Label smoothing factor for BCE loss (0 = no smoothing).")

    # Logging and pipeline-stage switches.
    p.add_argument("--log_level", type=str, default="INFO",
                   choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
    p.add_argument("--train", action="store_true", help="Whether to run training.")
    p.add_argument("--eval", action="store_true", help="Whether to run eval on validation split.")
    p.add_argument("--test", action="store_true", help="Whether to run predictions on test split.")

    return p.parse_args()
|
|
|
|
# Standard script entry point: run the full pipeline when executed directly.
if __name__ == "__main__":
    main()
|