# deberta-goemotions / pos_weight_trainer.py
from datasets import load_dataset
from sklearn.metrics import precision_score, recall_score, f1_score
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
Trainer,
TrainingArguments,
)
import argparse
import logging.config
import numpy as np
import torch
import torch.nn as nn
from utils import get_torch_device
logger = logging.getLogger(__name__)
class WeightedTrainer(Trainer):
def __init__(self, pos_weight=None, label_smoothing=0.0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.label_smoothing = label_smoothing
self.pos_weight = pos_weight
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
"""
Override the default to apply BCEWithLogitsLoss + pos_weight manually.
"""
labels = inputs.get("labels")
outputs = model(**{k: v for k, v in inputs.items() if k != "labels"})
logits = outputs.logits
# If we do label smoothing, push 1->(1 - epsilon), 0->epsilon
# Example: if smoothing=0.1, then 1->0.9, 0->0.1
# In practice, for multi-label tasks, you might prefer smaller smoothing, e.g. 0.05 or 0.1
        labels = labels.float()
        if self.label_smoothing > 0.0:
            epsilon = self.label_smoothing
            smoothed_labels = (1.0 - epsilon) * labels + epsilon * (1.0 - labels)
        else:
            smoothed_labels = labels
# Build our BCEWithLogitsLoss
if self.pos_weight is not None:
loss_fct = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight.to(logits.device))
else:
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits, smoothed_labels)
return (loss, outputs) if return_outputs else loss
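# Note: PyTorch's BCEWithLogitsLoss applies pos_weight only to the positive term of the per-element loss,
#   loss_i = -[ pos_weight_i * y_i * log(sigmoid(x_i)) + (1 - y_i) * log(1 - sigmoid(x_i)) ]
# so pos_weight_i > 1 makes missed positives for label i more costly than false positives.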
def compute_class_weights(train_dataset):
"""
Compute per-label pos_weight for BCEWithLogitsLoss.
pos_weight[i] = (#negatives[i]) / (#positives[i]) (in training set)
    train_dataset: a HF dataset split with a "labels" column,
        where each entry is a multi-hot vector of length [num_labels].
"""
all_labels = np.array(train_dataset["labels"]) # shape: (num_samples, num_labels)
positives = all_labels.sum(axis=0)
negatives = len(all_labels) - positives
# Avoid dividing by zero:
full_ratio = negatives / (positives + 1e-5)
# When labels have extreme frequency differences (e.g. “grief” with 77 vs. “neutral” with 14,219 in train),
# using linear pos_weight = neg/pos can produce very large weights for the rare labels—often leading the model
# to “spam” positives and cause low precision. Square‐root weighting is a softer approach.
pos_weight = torch.sqrt(torch.tensor(full_ratio))
    # Additionally, cap the weights so that no label's pos_weight exceeds 5.
    pos_weight = torch.clamp(pos_weight, max=5.0)
return torch.as_tensor(pos_weight, dtype=torch.float)
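# Worked example for compute_class_weights (approximate, assuming the full GoEmotions train split of
# ~43,410 rows and the counts mentioned above):
#   "grief":   ~77 positives     -> neg/pos ~ 563 -> sqrt ~ 23.7 -> clamped to 5.0
#   "neutral": ~14,219 positives -> neg/pos ~ 2.1 -> sqrt ~ 1.4  (unaffected by the cap)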
def find_best_threshold(logits, labels, step=0.01, metric="f1_macro"):
"""
logits: np.array of shape [num_samples, num_labels]
labels: np.array of shape [num_samples, num_labels]
step: how fine-grained to search the thresholds
metric: which metric key to maximize, e.g. "f1_micro" or "f1_macro".
Returns:
best_thresh: float
best_scores: dict with the metrics at that threshold
"""
best_thresh = 0.5
best_metric_val = 0.0
best_scores = None
# Convert logits -> probs once for efficiency
probs = 1 / (1 + np.exp(-logits))
thresholds = np.arange(0.0, 1.0 + step, step)
for t in thresholds:
preds = (probs >= t).astype(int)
precision_micro = precision_score(labels, preds, average="micro", zero_division=0)
recall_micro = recall_score(labels, preds, average="micro", zero_division=0)
f1_micro = f1_score(labels, preds, average="micro", zero_division=0)
precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
recall_macro = recall_score(labels, preds, average="macro", zero_division=0)
f1_macro = f1_score(labels, preds, average="macro", zero_division=0)
scores = {
"precision_micro": precision_micro,
"recall_micro": recall_micro,
"f1_micro": f1_micro,
"precision_macro": precision_macro,
"recall_macro": recall_macro,
"f1_macro": f1_macro,
}
if scores[metric] > best_metric_val:
best_metric_val = scores[metric]
best_thresh = t
best_scores = scores
return best_thresh, best_scores
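# Illustrative usage (this mirrors what main() does on the validation split below):
#   best_t, scores = find_best_threshold(logits_val, labels_val, step=0.01, metric="f1_macro")
#   preds = (1 / (1 + np.exp(-logits_val)) >= best_t).astype(int)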
def find_per_label_thresholds(logits, labels, step=0.01):
"""
    For each label individually, find the threshold in [0, 1] that maximizes that label's
    binary F1 score, then return all thresholds as a vector.
    Args:
        logits: np.array, shape [num_samples, num_labels]
        labels: np.array, shape [num_samples, num_labels]
        step: float, step size used to search thresholds from 0 to 1
Returns:
per_label_thresholds: array/list of shape [num_labels]
"""
num_samples, num_labels = logits.shape
probs = 1 / (1 + np.exp(-logits))
per_label_thresholds = np.zeros(num_labels)
for i in range(num_labels):
label_probs = probs[:, i]
label_true = labels[:, i]
best_thresh = 0.5
best_f1 = 0.0
# We only focus on label i's F1, ignoring other labels
for t in np.arange(0.0, 1.0 + step, step):
preds_i = (label_probs >= t).astype(int)
f1_i = f1_score(label_true, preds_i, zero_division=0)
if f1_i > best_f1:
best_f1 = f1_i
best_thresh = t
per_label_thresholds[i] = best_thresh
return per_label_thresholds
def multi_hot_encode_labels(example, num_labels):
"""
Convert a list of label indices -> multi-hot vector of length `num_labels`.
"""
new_labels = [0.0] * num_labels
for lbl in example["orig_labels"]:
new_labels[lbl] = 1.0
example["labels"] = new_labels
return example
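# Example: with num_labels=4 and example["orig_labels"] == [0, 2], this sets
# example["labels"] = [1.0, 0.0, 1.0, 0.0].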
def multi_label_metrics(eval_pred, threshold=0.5):
"""
eval_pred: (predictions, labels) from Trainer
- predictions: [batch_size, num_labels] raw logits
- labels: [batch_size, num_labels] binary 0/1
threshold: single float in [0,1] for deciding positive/negative.
We'll do:
1) Sigmoid on logits -> probabilities
    2) Threshold at `threshold` (0.5 by default)
3) Compare to ground truth
4) Return micro/macro P/R/F1
"""
logits, labels = eval_pred
probs = 1 / (1 + np.exp(-logits)) # sigmoid
preds = (probs >= threshold).astype(int)
# Micro-averaged P/R/F counts total true positives, false positives, false negatives across all classes
precision_micro = precision_score(labels, preds, average="micro", zero_division=0)
recall_micro = recall_score(labels, preds, average="micro", zero_division=0)
f1_micro = f1_score(labels, preds, average="micro", zero_division=0)
precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
recall_macro = recall_score(labels, preds, average="macro", zero_division=0)
f1_macro = f1_score(labels, preds, average="macro", zero_division=0)
return {
"precision_micro": precision_micro,
"recall_micro": recall_micro,
"f1_micro": f1_micro,
"precision_macro": precision_macro,
"recall_macro": recall_macro,
"f1_macro": f1_macro,
}
def multi_label_metrics_per_label_thresholds(logits, labels, thresholds):
"""
Apply a custom threshold for each label, then compute micro/macro P/R/F1.
Args:
logits: np.array, shape [num_samples, num_labels]
labels: np.array, shape [num_samples, num_labels]
thresholds: list/array of shape [num_labels] with each label's threshold
"""
probs = 1 / (1 + np.exp(-logits))
num_samples, num_labels = probs.shape
# We'll build preds array of same shape
preds = np.zeros_like(labels, dtype=int)
for i in range(num_labels):
preds[:, i] = (probs[:, i] >= thresholds[i]).astype(int)
precision_micro = precision_score(labels, preds, average="micro", zero_division=0)
recall_micro = recall_score(labels, preds, average="micro", zero_division=0)
f1_micro = f1_score(labels, preds, average="micro", zero_division=0)
precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
recall_macro = recall_score(labels, preds, average="macro", zero_division=0)
f1_macro = f1_score(labels, preds, average="macro", zero_division=0)
return {
"precision_micro": precision_micro,
"recall_micro": recall_micro,
"f1_micro": f1_micro,
"precision_macro": precision_macro,
"recall_macro": recall_macro,
"f1_macro": f1_macro,
}
def main():
args = parse_args()
setup_logging(args.log_level)
logger.info(f"Arguments: {args}")
# 1. Load the GoEmotions dataset
logger.info("Loading the GoEmotions dataset (train/validation/test).")
dataset = load_dataset("go_emotions")
# 2. Rename "labels" -> "orig_labels"
dataset = dataset.rename_column("labels", "orig_labels")
# The "go_emotions" dataset typically includes label metadata:
# dataset["train"].features["orig_labels"].feature.names
# This is a list of label names.
label_names = dataset["train"].features["orig_labels"].feature.names
num_labels = len(label_names)
logger.info(f"Detected num_labels={num_labels} from dataset metadata.")
logger.info(f"Label names = {label_names}")
# Show distribution for train/val/test
show_class_distribution(dataset, "train", label_names)
show_class_distribution(dataset, "validation", label_names)
show_class_distribution(dataset, "test", label_names)
# Now map your multi-hot function onto a brand-new 'labels' column (as floats):
dataset = dataset.map(lambda ex: multi_hot_encode_labels(ex, num_labels=num_labels))
# Splits: dataset["train"], dataset["validation"], dataset["test"]
# Each sample has 'text' (string) and 'labels' (list of emotion IDs).
dataset = dataset.remove_columns("orig_labels")
    # 3. Create tokenizer & model
logger.info(f"Loading tokenizer/model from {args.model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
config = AutoConfig.from_pretrained(args.model_name_or_path)
config.num_labels = num_labels
# problem_type="multi_label_classification" ensures we get a BCEWithLogits loss at training time
config.problem_type = "multi_label_classification"
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name_or_path,
config=config,
)
logging.info(f"model problem type: {model.config.problem_type}")
torch_device = get_torch_device()
model.to(torch_device)
    # 4. Tokenize the dataset
logger.info("Tokenizing dataset...")
# We'll map a tokenize_function over the entire dataset
def tokenize_func(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=512,
padding="max_length"
)
tokenized_ds = dataset.map(tokenize_func, batched=True)
tokenized_ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
# 5. Prepare trainer
logger.info("Preparing TrainingArguments and Trainer.")
training_args = TrainingArguments(
eval_strategy=args.eval_strategy if args.eval else "no",
learning_rate=args.learning_rate,
logging_dir=f"{args.output_dir}/logs",
logging_steps=100,
output_dir=args.output_dir,
overwrite_output_dir=True, # For demo. In production, you might not want to overwrite.
        save_strategy=args.eval_strategy if args.eval else "no",
        load_best_model_at_end=args.eval,  # requires matching eval/save strategies
        metric_for_best_model="f1_macro",
        greater_is_better=True,
num_train_epochs=args.num_epochs,
per_device_eval_batch_size=args.batch_size,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
warmup_ratio=0.1,
weight_decay=0.01,
)
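    # With the argparse defaults below (per-device batch_size=8, gradient_accumulation_steps=8),
    # the effective train batch size per device is 8 * 8 = 64.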
logger.info("Computing class weights from train dataset...")
pos_weight = compute_class_weights(tokenized_ds["train"])
    # compute_class_weights always returns a tensor, so the WeightedTrainer is always used here.
    trainer = WeightedTrainer(
        pos_weight=pos_weight,
        label_smoothing=args.label_smoothing,
model=model,
args=training_args,
train_dataset=tokenized_ds["train"] if args.train else None,
eval_dataset=tokenized_ds["validation"] if args.eval else None,
compute_metrics=multi_label_metrics,
)
# 6. Training
if args.train:
logger.info("Starting training...")
trainer.train()
logger.info("Training complete. Saving final model.")
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
# 7. Validation / Evaluate + threshold search
if args.eval or args.test:
eval_metrics = trainer.evaluate(eval_dataset=tokenized_ds["validation"])
logger.info(f"Validation metrics: {eval_metrics}")
# Now do threshold search explicitly
logger.info("Searching for best global threshold on validation set...")
preds_output = trainer.predict(tokenized_ds["validation"])
logits_val, labels_val = preds_output.predictions, preds_output.label_ids
best_thresh, best_scores = find_best_threshold(logits_val, labels_val, step=0.01, metric="f1_macro")
logger.info(f"Best global threshold={best_thresh:.2f}, best_scores={best_scores}")
# Per-label threshold search:
per_label_thresh = find_per_label_thresholds(logits_val, labels_val, step=0.01)
logger.info(f"Found per-label thresholds: {per_label_thresh}")
# Evaluate how well that does on validation:
val_metrics_per_label = multi_label_metrics_per_label_thresholds(logits_val, labels_val, per_label_thresh)
logger.info(f"Validation metrics with per-label thresholds: {val_metrics_per_label}")
model.config.best_global_threshold = float(best_thresh) # cast to float for JSON-safe
model.config.per_label_thresholds = [float(x) for x in per_label_thresh]
model.config.label_names = label_names
# Re-save the model/config (overwrites the existing config.json in output_dir).
model.save_pretrained(args.output_dir)
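        # The extra attributes above are written into config.json, so inference code (not part of this
        # script; shown only as a sketch) should be able to recover them, e.g.:
        #   cfg = AutoConfig.from_pretrained(args.output_dir)
        #   thresholds = np.array(cfg.per_label_thresholds)
        #   preds = (probs >= thresholds).astype(int)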
# 8. Test / Prediction
if args.test:
logger.info("Predicting on test set using threshold={best_thresh:.2f}...")
preds_test = trainer.predict(tokenized_ds["test"])
logits_test, labels_test = preds_test.predictions, preds_test.label_ids
test_metrics_per_label = multi_label_metrics_per_label_thresholds(logits_test, labels_test, per_label_thresh)
logger.info(f"Test metrics with per-label thresholds: {test_metrics_per_label}")
def parse_args():
parser = argparse.ArgumentParser(description="Train multi-label classifier on GoEmotions using DeBERTa-v3-base.")
parser.add_argument("--output_dir", type=str, default="./pos_weight_new",
help="Output directory for checkpoints.")
parser.add_argument("--model_name_or_path", type=str, default="microsoft/deberta-v3-base",
help="Any valid HF model name (or local path) for sequence classification.")
parser.add_argument("--num_epochs", type=int, default=8, help="Number of training epochs.")
parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
parser.add_argument("--eval_strategy", default="steps", choices=["epoch", "steps"],
help="How frequently to do evaluation steps.")
parser.add_argument("--batch_size", type=int, default=8, help="Per-device batch size.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
help="Effective batch size = batch size x gradient accumulation steps.")
parser.add_argument("--label_smoothing", type=float, default=0.05,
help="Label smoothing factor for BCE loss (0 = no smoothing).")
parser.add_argument("--log_level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
parser.add_argument("--train", action="store_true", help="Whether to run training.")
parser.add_argument("--eval", action="store_true", help="Whether to run eval on validation split.")
parser.add_argument("--test", action="store_true", help="Whether to run predictions on test split.")
return parser.parse_args()
def setup_logging(log_level):
logging.config.dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "default",
},
},
"loggers": {
"": {
"level": log_level,
"handlers": ["console"],
},
},
}
)
def show_class_distribution(dataset, split_name, label_names):
"""
Print how many samples contain each label in the chosen split.
This helps identify imbalance.
- dataset[split_name] is a huggingface Dataset
- label_names: list of label names in the dataset
"""
from collections import Counter
label_counter = Counter()
num_samples = len(dataset[split_name])
# Each sample's `orig_labels` is a list of label indices
for ex in dataset[split_name]["orig_labels"]:
label_counter.update(ex)
logger.info(f"\n--- Class distribution for split '{split_name}' ({num_samples} samples) ---")
for idx, label_name in enumerate(label_names):
logger.info(f"{idx:02d} ({label_name}): count = {label_counter[idx]}")
logger.info("---------------------------------------------------------------\n")
if __name__ == "__main__":
main()