# deberta-goemotions / focal_loss_trainer.py
from datasets import load_dataset
from sklearn.metrics import precision_score, recall_score, f1_score
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
Trainer,
TrainingArguments,
)
import argparse
import logging.config
import numpy as np
import torch
import torch.nn.functional as F
from utils import default_logging_config, get_torch_device, show_class_distribution
logger = logging.getLogger(__name__)
class FocalTrainer(Trainer):
def __init__(
self,
alpha=0.25,
gamma=2.0,
label_smoothing=0.0,
*args,
**kwargs
):
"""
alpha: weight for positive examples in focal loss (0 < alpha <= 1)
gamma: focusing parameter
label_smoothing: how much to smooth the 0/1 labels
"""
super().__init__(*args, **kwargs)
self.alpha = alpha
self.gamma = gamma
self.label_smoothing = label_smoothing
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
"""
Forward pass + Focal loss with optional label smoothing.
"""
labels = inputs.get("labels")
outputs = model(**{k: v for k, v in inputs.items() if k != "labels"})
logits = outputs.logits
        # 1) Apply label smoothing if needed
        labels = labels.float()
        if self.label_smoothing > 0:
            smoothed_labels = apply_label_smoothing(labels, self.label_smoothing)
        else:
            smoothed_labels = labels
# 2) Compute focal loss
loss = focal_loss(
logits=logits,
targets=smoothed_labels,
alpha=self.alpha,
gamma=self.gamma,
reduction="mean", # or "sum"
)
return (loss, outputs) if return_outputs else loss
def apply_label_smoothing(labels: torch.Tensor, epsilon: float) -> torch.Tensor:
"""
For multi-label:
1 -> (1 - epsilon), 0 -> epsilon
"""
return (1.0 - epsilon) * labels + epsilon * (1.0 - labels)
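# Worked example (illustrative): with epsilon = 0.05, a multi-hot row
# [1, 0, 1] becomes [0.95, 0.05, 0.95] -- positives shrink to (1 - epsilon)
# and negatives rise to epsilon.
#
#   >>> apply_label_smoothing(torch.tensor([1.0, 0.0, 1.0]), 0.05)
#   tensor([0.9500, 0.0500, 0.9500])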
def find_best_threshold(logits, labels, step=0.01, metric="f1_macro"):
"""
logits: np.array of shape [num_samples, num_labels]
labels: np.array of shape [num_samples, num_labels]
step: how fine-grained to search the thresholds
metric: which metric key to maximize, e.g. "f1_micro" or "f1_macro".
Returns:
best_thresh: float
best_scores: dict with the metrics at that threshold
"""
best_thresh = 0.5
best_metric_val = 0.0
best_scores = None
# Convert logits -> probs once for efficiency
probs = 1 / (1 + np.exp(-logits))
thresholds = np.arange(0.0, 1.0 + step, step)
for t in thresholds:
preds = (probs >= t).astype(int)
precision_micro = precision_score(labels, preds, average="micro", zero_division=0)
recall_micro = recall_score(labels, preds, average="micro", zero_division=0)
f1_micro = f1_score(labels, preds, average="micro", zero_division=0)
precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
recall_macro = recall_score(labels, preds, average="macro", zero_division=0)
f1_macro = f1_score(labels, preds, average="macro", zero_division=0)
scores = {
"precision_micro": precision_micro,
"recall_micro": recall_micro,
"f1_micro": f1_micro,
"precision_macro": precision_macro,
"recall_macro": recall_macro,
"f1_macro": f1_macro,
}
if scores[metric] > best_metric_val:
best_metric_val = scores[metric]
best_thresh = t
best_scores = scores
return best_thresh, best_scores
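# Usage sketch (illustrative, assuming a trained `trainer` and the tokenized
# validation split built in main() below):
#
#   preds = trainer.predict(tokenized_ds["validation"])
#   best_t, scores = find_best_threshold(
#       preds.predictions, preds.label_ids, step=0.01, metric="f1_macro")
#   # best_t is the single global cutoff that maximizes macro-F1 on this split.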
def find_per_label_thresholds(logits, labels, step=0.01):
"""
For each label individually, find the threshold in [0,1] that maximizes its own F1
(or precision/recall, etc.). Then we'll store them as a vector of thresholds.
Args:
logits: np.array, shape [num_samples, num_labels]
labels: np.array, shape [num_samples, num_labels]
step: float, step size to search thresholds from 0..1
average_metric: 'f1_micro' or 'f1_macro' for final reference,
but note that we do label-by-label search here.
Returns:
per_label_thresholds: array/list of shape [num_labels]
"""
num_samples, num_labels = logits.shape
probs = 1 / (1 + np.exp(-logits))
per_label_thresholds = np.zeros(num_labels)
for i in range(num_labels):
label_probs = probs[:, i]
label_true = labels[:, i]
best_thresh = 0.5
best_f1 = 0.0
# We only focus on label i's F1, ignoring other labels
for t in np.arange(0.0, 1.0 + step, step):
preds_i = (label_probs >= t).astype(int)
f1_i = f1_score(label_true, preds_i, zero_division=0)
if f1_i > best_f1:
best_f1 = f1_i
best_thresh = t
per_label_thresholds[i] = best_thresh
return per_label_thresholds
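# Usage sketch (illustrative): thresholds are searched independently per label,
# so frequent and rare labels can end up with very different cutoffs.
#
#   thresholds = find_per_label_thresholds(preds.predictions, preds.label_ids)
#   # thresholds has shape (num_labels,); apply it with
#   # multi_label_metrics_per_label_thresholds() defined below.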
def focal_loss(
logits: torch.Tensor,
targets: torch.Tensor,
alpha: float = 0.25,
gamma: float = 2.0,
reduction: str = "mean",
) -> torch.Tensor:
"""
Compute focal loss for multi-label classification.
Args:
logits: [batch_size, num_labels] raw output of model
targets: [batch_size, num_labels] 0/1 (or smoothed) ground-truth labels
alpha: weighting factor for positive examples (0 < alpha <= 1)
gamma: exponent for down-weighting easy examples
reduction: "mean", "sum", or "none" to aggregate across batch/labels
Returns:
A scalar loss (if reduction != "none").
"""
# 1) Sigmoid to get probabilities
probs = torch.sigmoid(logits)
# 2) Compute binary cross-entropy element-wise (no reduction yet)
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
# "bce" is shape [batch, num_labels]
# 3) For focal loss, let p_t = p if y=1 else (1-p) if y=0
# We can compute p_t elementwise:
pt = targets * probs + (1 - targets) * (1 - probs)
# shape [batch, num_labels]
# 4) Apply the focal term: alpha * (1 - pt)^gamma
# Usually, alpha is used for the positives; for multi-label, we might
# treat alpha as the weight for positives only. Here's a simple approach:
alpha_factor = alpha * targets + (1 - alpha) * (1 - targets)
# shape [batch, num_labels]
focal_factor = alpha_factor * (1.0 - pt).pow(gamma)
# shape [batch, num_labels]
# 5) Final focal loss is focal_factor * bce
loss = focal_factor * bce # shape [batch, num_labels]
if reduction == "mean":
return loss.mean()
elif reduction == "sum":
return loss.sum()
else:
return loss # shape [batch, num_labels]
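# Worked example (illustrative): with alpha=0.25 and gamma=2, a confident
# correct positive (p=0.9, y=1) gets focal factor 0.25 * (1 - 0.9)^2 = 0.0025,
# while a hard positive (p=0.1, y=1) keeps 0.25 * (1 - 0.1)^2 = 0.2025, so easy
# examples contribute far less to the loss than hard ones.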
def multi_hot_encode_labels(example, num_labels):
"""
Convert a list of label indices -> multi-hot vector of length `num_labels`.
"""
new_labels = [0.0] * num_labels
for lbl in example["orig_labels"]:
new_labels[lbl] = 1.0
example["labels"] = new_labels
return example
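# Example (illustrative): with num_labels=5, an example whose "orig_labels" is
# [0, 3] gets "labels" = [1.0, 0.0, 0.0, 1.0, 0.0].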
def multi_label_metrics(eval_pred, threshold=0.5):
"""
eval_pred: (predictions, labels) from Trainer
- predictions: [batch_size, num_labels] raw logits
- labels: [batch_size, num_labels] binary 0/1
threshold: single float in [0,1] for deciding positive/negative.
We'll do:
1) Sigmoid on logits -> probabilities
    2) Threshold at the given threshold (default 0.5)
3) Compare to ground truth
4) Return micro/macro P/R/F1
"""
logits, labels = eval_pred
probs = 1 / (1 + np.exp(-logits)) # sigmoid
preds = (probs >= threshold).astype(int)
# Micro-averaged P/R/F counts total true positives, false positives, false negatives across all classes
precision_micro = precision_score(labels, preds, average="micro", zero_division=0)
recall_micro = recall_score(labels, preds, average="micro", zero_division=0)
f1_micro = f1_score(labels, preds, average="micro", zero_division=0)
precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
recall_macro = recall_score(labels, preds, average="macro", zero_division=0)
f1_macro = f1_score(labels, preds, average="macro", zero_division=0)
return {
"precision_micro": precision_micro,
"recall_micro": recall_micro,
"f1_micro": f1_micro,
"precision_macro": precision_macro,
"recall_macro": recall_macro,
"f1_macro": f1_macro,
}
def multi_label_metrics_per_label_thresholds(logits, labels, thresholds):
"""
Apply a custom threshold for each label, then compute micro/macro P/R/F1.
Args:
logits: np.array, shape [num_samples, num_labels]
labels: np.array, shape [num_samples, num_labels]
thresholds: list/array of shape [num_labels] with each label's threshold
"""
probs = 1 / (1 + np.exp(-logits))
num_samples, num_labels = probs.shape
# We'll build preds array of same shape
preds = np.zeros_like(labels, dtype=int)
for i in range(num_labels):
preds[:, i] = (probs[:, i] >= thresholds[i]).astype(int)
precision_micro = precision_score(labels, preds, average="micro", zero_division=0)
recall_micro = recall_score(labels, preds, average="micro", zero_division=0)
f1_micro = f1_score(labels, preds, average="micro", zero_division=0)
precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
recall_macro = recall_score(labels, preds, average="macro", zero_division=0)
f1_macro = f1_score(labels, preds, average="macro", zero_division=0)
return {
"precision_micro": precision_micro,
"recall_micro": recall_micro,
"f1_micro": f1_micro,
"precision_macro": precision_macro,
"recall_macro": recall_macro,
"f1_macro": f1_macro,
}
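# Usage sketch (illustrative), continuing from find_per_label_thresholds():
#
#   metrics = multi_label_metrics_per_label_thresholds(
#       preds.predictions, preds.label_ids, thresholds)
#   # Per-label thresholds often improve macro-F1 over a single global cutoff,
#   # since rare labels typically need lower thresholds than frequent ones.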
def main():
logging.config.dictConfig(default_logging_config)
    args = parse_args()
    logging.getLogger().setLevel(args.log_level)
    logger.info(f"Arguments: {args}")
# 1. Load the GoEmotions dataset
logger.info("Loading the GoEmotions dataset (train/validation/test).")
dataset = load_dataset("go_emotions")
# 2. Rename "labels" -> "orig_labels"
dataset = dataset.rename_column("labels", "orig_labels")
# The "go_emotions" dataset typically includes label metadata:
# dataset["train"].features["orig_labels"].feature.names
# This is a list of label names.
label_names = dataset["train"].features["orig_labels"].feature.names
num_labels = len(label_names)
logger.info(f"Detected num_labels={num_labels} from dataset metadata.")
logger.info(f"Label names = {label_names}")
# Show distribution for train/val/test
show_class_distribution(dataset, "train", label_names)
show_class_distribution(dataset, "validation", label_names)
show_class_distribution(dataset, "test", label_names)
# Now map your multi-hot function onto a brand-new 'labels' column (as floats):
dataset = dataset.map(lambda ex: multi_hot_encode_labels(ex, num_labels=num_labels))
# Splits: dataset["train"], dataset["validation"], dataset["test"]
# Each sample has 'text' (string) and 'labels' (list of emotion IDs).
dataset = dataset.remove_columns("orig_labels")
    # 3. Create tokenizer & model
logger.info(f"Loading tokenizer/model from {args.model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
config = AutoConfig.from_pretrained(args.model_name_or_path)
config.num_labels = num_labels
# problem_type="multi_label_classification" ensures we get a BCEWithLogits loss at training time
config.problem_type = "multi_label_classification"
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name_or_path,
config=config,
)
    logger.info(f"model problem type: {model.config.problem_type}")
torch_device = get_torch_device()
model.to(torch_device)
    # 4. Tokenize the dataset
logger.info("Tokenizing dataset...")
# We'll map a tokenize_function over the entire dataset
def tokenize_func(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=512,
padding="max_length"
)
tokenized_ds = dataset.map(tokenize_func, batched=True)
tokenized_ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
# 5. Prepare trainer
logger.info("Preparing TrainingArguments and Trainer.")
training_args = TrainingArguments(
eval_strategy=args.eval_strategy if args.eval else "no",
learning_rate=args.learning_rate,
logging_dir=f"{args.output_dir}/logs",
logging_steps=100,
output_dir=args.output_dir,
overwrite_output_dir=True, # For demo. In production, you might not want to overwrite.
        save_strategy=args.eval_strategy if args.eval else "no",
        load_best_model_at_end=args.eval,
metric_for_best_model="f1_macro",
greater_is_better=True,
num_train_epochs=args.num_epochs,
per_device_eval_batch_size=args.batch_size,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
warmup_ratio=0.1,
weight_decay=0.01,
)
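    # With the defaults (--batch_size 8, --gradient_accumulation_steps 8), the
    # effective batch size is 8 x 8 = 64 per device.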
trainer_class = FocalTrainer
trainer = trainer_class(
alpha=args.focal_alpha,
gamma=args.focal_gamma,
label_smoothing=args.label_smoothing,
model=model,
args=training_args,
train_dataset=tokenized_ds["train"],
eval_dataset=tokenized_ds["validation"],
compute_metrics=multi_label_metrics,
)
# 6. Training
if args.train:
logger.info("Starting training...")
trainer.train()
logger.info("Training complete. Saving final model.")
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
# 7. Validation / Evaluate + threshold search
if args.eval or args.test:
eval_metrics = trainer.evaluate(eval_dataset=tokenized_ds["validation"])
logger.info(f"Validation metrics: {eval_metrics}")
# Now do threshold search explicitly
logger.info("Searching for best global threshold on validation set...")
preds_output = trainer.predict(tokenized_ds["validation"])
logits_val, labels_val = preds_output.predictions, preds_output.label_ids
best_thresh, best_scores = find_best_threshold(logits_val, labels_val, step=0.01, metric="f1_macro")
logger.info(f"Best global threshold={best_thresh:.2f}, best_scores={best_scores}")
# Per-label threshold search:
per_label_thresh = find_per_label_thresholds(logits_val, labels_val, step=0.01)
logger.info(f"Found per-label thresholds: {per_label_thresh}")
# Evaluate how well that does on validation:
val_metrics_per_label = multi_label_metrics_per_label_thresholds(logits_val, labels_val, per_label_thresh)
logger.info(f"Validation metrics with per-label thresholds: {val_metrics_per_label}")
# 8. Test / Prediction
if args.test:
logger.info("Predicting on test set using threshold={best_thresh:.2f}...")
preds_test = trainer.predict(tokenized_ds["test"])
logits_test, labels_test = preds_test.predictions, preds_test.label_ids
test_metrics_per_label = multi_label_metrics_per_label_thresholds(logits_test, labels_test, per_label_thresh)
logger.info(f"Test metrics with per-label thresholds: {test_metrics_per_label}")
def parse_args():
parser = argparse.ArgumentParser(description="Train multi-label classifier on GoEmotions using DeBERTa-v3-base.")
parser.add_argument("--output_dir", type=str, default="./focal_loss_new",
help="Output directory for checkpoints.")
parser.add_argument("--model_name_or_path", type=str, default="microsoft/deberta-v3-base",
help="Any valid HF model name (or local path) for sequence classification.")
parser.add_argument("--num_epochs", type=int, default=8, help="Number of training epochs.")
parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
parser.add_argument("--eval_strategy", default="steps", choices=["epoch", "steps"],
help="How frequently to do evaluation steps.")
parser.add_argument("--batch_size", type=int, default=8, help="Per-device batch size.")
parser.add_argument("--focal_alpha", type=float, default=0.25)
parser.add_argument("--focal_gamma", type=float, default=2.0)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
help="Effective batch size = batch size x gradient accumulation steps.")
parser.add_argument("--label_smoothing", type=float, default=0.05,
help="Label smoothing factor for BCE loss (0 = no smoothing).")
parser.add_argument("--log_level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
parser.add_argument("--train", action="store_true", help="Whether to run training.")
parser.add_argument("--eval", action="store_true", help="Whether to run eval on validation split.")
parser.add_argument("--test", action="store_true", help="Whether to run predictions on test split.")
return parser.parse_args()
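# Example invocation (illustrative):
#
#   python focal_loss_trainer.py --train --eval \
#       --model_name_or_path microsoft/deberta-v3-base \
#       --output_dir ./focal_loss_new \
#       --batch_size 8 --gradient_accumulation_steps 8 --num_epochs 8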
if __name__ == "__main__":
main()