from datasets import load_dataset
from sklearn.metrics import precision_score, recall_score, f1_score
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
import argparse
import logging.config
import numpy as np
import torch
import torch.nn as nn

from utils import get_torch_device

logger = logging.getLogger(__name__)


class WeightedTrainer(Trainer):
    def __init__(self, pos_weight=None, label_smoothing=0.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_smoothing = label_smoothing
        self.pos_weight = pos_weight

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Override the default to apply BCEWithLogitsLoss + pos_weight manually.
        """
        labels = inputs.get("labels")
        outputs = model(**{k: v for k, v in inputs.items() if k != "labels"})
        logits = outputs.logits

        if self.label_smoothing > 0.0:
            epsilon = self.label_smoothing
            labels = labels.float()
            smoothed_labels = (1.0 - epsilon) * labels + epsilon * (1.0 - labels)
        else:
            smoothed_labels = labels.float()

        if self.pos_weight is not None:
            loss_fct = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight.to(logits.device))
        else:
            loss_fct = nn.BCEWithLogitsLoss()

        loss = loss_fct(logits, smoothed_labels)

        return (loss, outputs) if return_outputs else loss

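
# Note on the smoothing above (illustration only, not executed): with
# label_smoothing = 0.05, a positive target 1.0 becomes 0.95 and a negative
# target 0.0 becomes 0.05, i.e. targets are pulled slightly toward the
# opposite class before BCEWithLogitsLoss is applied:
#
#   epsilon = 0.05
#   (1.0 - epsilon) * 1.0 + epsilon * (1.0 - 1.0)  # -> 0.95
#   (1.0 - epsilon) * 0.0 + epsilon * (1.0 - 0.0)  # -> 0.05
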

def compute_class_weights(train_dataset):
    """
    Compute per-label pos_weight for BCEWithLogitsLoss.
    The raw ratio pos_weight[i] = (#negatives[i]) / (#positives[i]) over the
    training set is square-rooted and clamped to keep the weights moderate.

    train_dataset: a HF dataset split with columns["labels"],
    each 'labels' is size [num_labels].
    """
    all_labels = np.array(train_dataset["labels"])
    positives = all_labels.sum(axis=0)
    negatives = len(all_labels) - positives

    # Raw negative/positive ratio per label (epsilon avoids division by zero).
    full_ratio = negatives / (positives + 1e-5)

    # Dampen the ratio with a square root so rare labels are not overweighted.
    pos_weight = torch.sqrt(torch.tensor(full_ratio))

    # Cap the weight so extremely rare labels cannot dominate the loss.
    pos_weight = torch.clamp(pos_weight, max=5.0)

    return torch.as_tensor(pos_weight, dtype=torch.float)

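
# Illustration (not executed), assuming a hypothetical training set of 10,000
# examples: a label with 500 positives has a raw ratio of 9,500 / 500 = 19.0,
# which becomes sqrt(19.0) ~= 4.36 after damping; a label with only 100
# positives (ratio 99.0, sqrt ~= 9.95) is clamped to the cap of 5.0.
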

def find_best_threshold(logits, labels, step=0.01, metric="f1_macro"):
    """
    logits: np.array of shape [num_samples, num_labels]
    labels: np.array of shape [num_samples, num_labels]
    step: how fine-grained to search the thresholds
    metric: which metric key to maximize, e.g. "f1_micro" or "f1_macro".
    Returns:
        best_thresh: float
        best_scores: dict with the metrics at that threshold
    """
    best_thresh = 0.5
    best_metric_val = 0.0
    best_scores = None

    # Convert logits to probabilities with a sigmoid.
    probs = 1 / (1 + np.exp(-logits))

    thresholds = np.arange(0.0, 1.0 + step, step)
    for t in thresholds:
        preds = (probs >= t).astype(int)
        precision_micro = precision_score(labels, preds, average="micro", zero_division=0)
        recall_micro = recall_score(labels, preds, average="micro", zero_division=0)
        f1_micro = f1_score(labels, preds, average="micro", zero_division=0)

        precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
        recall_macro = recall_score(labels, preds, average="macro", zero_division=0)
        f1_macro = f1_score(labels, preds, average="macro", zero_division=0)

        scores = {
            "precision_micro": precision_micro,
            "recall_micro": recall_micro,
            "f1_micro": f1_micro,
            "precision_macro": precision_macro,
            "recall_macro": recall_macro,
            "f1_macro": f1_macro,
        }

        if scores[metric] > best_metric_val:
            best_metric_val = scores[metric]
            best_thresh = t
            best_scores = scores

    return best_thresh, best_scores

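
# Typical use (sketch only; val_logits, val_labels and test_logits are
# hypothetical numpy arrays obtained from trainer.predict(...)):
#
#   best_t, scores = find_best_threshold(val_logits, val_labels, metric="f1_macro")
#   test_preds = (1 / (1 + np.exp(-test_logits)) >= best_t).astype(int)
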

def find_per_label_thresholds(logits, labels, step=0.01):
    """
    For each label individually, find the threshold in [0, 1] that maximizes
    that label's own F1. The results are returned as a vector of thresholds.

    Args:
        logits: np.array, shape [num_samples, num_labels]
        labels: np.array, shape [num_samples, num_labels]
        step: float, step size to search thresholds from 0..1

    Returns:
        per_label_thresholds: np.array of shape [num_labels]
    """
    num_samples, num_labels = logits.shape
    probs = 1 / (1 + np.exp(-logits))

    per_label_thresholds = np.zeros(num_labels)

    for i in range(num_labels):
        label_probs = probs[:, i]
        label_true = labels[:, i]

        best_thresh = 0.5
        best_f1 = 0.0

        for t in np.arange(0.0, 1.0 + step, step):
            preds_i = (label_probs >= t).astype(int)
            f1_i = f1_score(label_true, preds_i, zero_division=0)
            if f1_i > best_f1:
                best_f1 = f1_i
                best_thresh = t

        per_label_thresholds[i] = best_thresh

    return per_label_thresholds

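
# Illustration (not executed): for GoEmotions' 28 labels this returns a
# length-28 array, e.g. per_label_thresholds[3] is the probability cutoff that
# maximizes F1 for label index 3 on the data the search was run against.
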

def multi_hot_encode_labels(example, num_labels):
    """
    Convert a list of label indices -> multi-hot vector of length `num_labels`.
    """
    new_labels = [0.0] * num_labels
    for lbl in example["orig_labels"]:
        new_labels[lbl] = 1.0
    example["labels"] = new_labels
    return example

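
# Illustration (not executed): with num_labels=4, an example whose
# "orig_labels" is [0, 2] gets "labels" = [1.0, 0.0, 1.0, 0.0].
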

def multi_label_metrics(eval_pred, threshold=0.5):
    """
    eval_pred: (predictions, labels) from Trainer
        - predictions: [batch_size, num_labels] raw logits
        - labels: [batch_size, num_labels] binary 0/1
    threshold: single float in [0, 1] for deciding positive/negative.

    Steps:
        1) Sigmoid on logits -> probabilities
        2) Threshold at `threshold` (default 0.5)
        3) Compare to ground truth
        4) Return micro/macro P/R/F1
    """
    logits, labels = eval_pred
    probs = 1 / (1 + np.exp(-logits))
    preds = (probs >= threshold).astype(int)

    precision_micro = precision_score(labels, preds, average="micro", zero_division=0)
    recall_micro = recall_score(labels, preds, average="micro", zero_division=0)
    f1_micro = f1_score(labels, preds, average="micro", zero_division=0)

    precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
    recall_macro = recall_score(labels, preds, average="macro", zero_division=0)
    f1_macro = f1_score(labels, preds, average="macro", zero_division=0)

    return {
        "precision_micro": precision_micro,
        "recall_micro": recall_micro,
        "f1_micro": f1_micro,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
    }


def multi_label_metrics_per_label_thresholds(logits, labels, thresholds):
    """
    Apply a custom threshold for each label, then compute micro/macro P/R/F1.

    Args:
        logits: np.array, shape [num_samples, num_labels]
        labels: np.array, shape [num_samples, num_labels]
        thresholds: list/array of shape [num_labels] with each label's threshold
    """
    probs = 1 / (1 + np.exp(-logits))
    num_samples, num_labels = probs.shape

    # Binarize each column with its own threshold.
    preds = np.zeros_like(labels, dtype=int)
    for i in range(num_labels):
        preds[:, i] = (probs[:, i] >= thresholds[i]).astype(int)

    precision_micro = precision_score(labels, preds, average="micro", zero_division=0)
    recall_micro = recall_score(labels, preds, average="micro", zero_division=0)
    f1_micro = f1_score(labels, preds, average="micro", zero_division=0)

    precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
    recall_macro = recall_score(labels, preds, average="macro", zero_division=0)
    f1_macro = f1_score(labels, preds, average="macro", zero_division=0)

    return {
        "precision_micro": precision_micro,
        "recall_micro": recall_micro,
        "f1_micro": f1_micro,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
    }


def main():
    args = parse_args()
    setup_logging(args.log_level)
    logger.info(f"Arguments: {args}")

    logger.info("Loading the GoEmotions dataset (train/validation/test).")
    dataset = load_dataset("go_emotions")

    # Keep the original list-of-indices labels under a separate column name.
    dataset = dataset.rename_column("labels", "orig_labels")

    # Read the label names and number of labels from the dataset metadata.
    label_names = dataset["train"].features["orig_labels"].feature.names
    num_labels = len(label_names)
    logger.info(f"Detected num_labels={num_labels} from dataset metadata.")
    logger.info(f"Label names = {label_names}")

    # Log per-label counts to make the class imbalance visible.
    show_class_distribution(dataset, "train", label_names)
    show_class_distribution(dataset, "validation", label_names)
    show_class_distribution(dataset, "test", label_names)

    # Convert the index lists to multi-hot vectors and drop the original column.
    dataset = dataset.map(lambda ex: multi_hot_encode_labels(ex, num_labels=num_labels))
    dataset = dataset.remove_columns("orig_labels")

    logger.info(f"Loading tokenizer/model from {args.model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)

    config = AutoConfig.from_pretrained(args.model_name_or_path)
    config.num_labels = num_labels
    # Multi-label classification, so the model applies a BCE-style loss internally.
    config.problem_type = "multi_label_classification"

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
    )
    logger.info(f"model problem type: {model.config.problem_type}")
    torch_device = get_torch_device()
    model.to(torch_device)

    logger.info("Tokenizing dataset...")

    def tokenize_func(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=512,
            padding="max_length",
        )

    tokenized_ds = dataset.map(tokenize_func, batched=True)
    tokenized_ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    logger.info("Preparing TrainingArguments and Trainer.")
    training_args = TrainingArguments(
        eval_strategy=args.eval_strategy if args.eval else "no",
        learning_rate=args.learning_rate,
        logging_dir=f"{args.output_dir}/logs",
        logging_steps=100,
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        save_strategy=args.eval_strategy,

        # load_best_model_at_end requires evaluation, so only enable it when --eval is set.
        load_best_model_at_end=args.eval,
        metric_for_best_model="f1_macro",
        greater_is_better=True,
        num_train_epochs=args.num_epochs,
        per_device_eval_batch_size=args.batch_size,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=0.1,
        weight_decay=0.01,
    )

    logger.info("Computing class weights from train dataset...")
    pos_weight = compute_class_weights(tokenized_ds["train"])

    # compute_class_weights always returns a tensor, so we always use the weighted trainer.
    trainer = WeightedTrainer(
        pos_weight=pos_weight,
        label_smoothing=args.label_smoothing,
        model=model,
        args=training_args,
        train_dataset=tokenized_ds["train"] if args.train else None,
        eval_dataset=tokenized_ds["validation"] if args.eval else None,
        compute_metrics=multi_label_metrics,
    )

    if args.train:
        logger.info("Starting training...")
        trainer.train()
        logger.info("Training complete. Saving final model.")
        trainer.save_model(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

    if args.eval or args.test:
        eval_metrics = trainer.evaluate(eval_dataset=tokenized_ds["validation"])
        logger.info(f"Validation metrics: {eval_metrics}")

        # Search for the best single (global) threshold on the validation set.
        logger.info("Searching for best global threshold on validation set...")
        preds_output = trainer.predict(tokenized_ds["validation"])
        logits_val, labels_val = preds_output.predictions, preds_output.label_ids

        best_thresh, best_scores = find_best_threshold(logits_val, labels_val, step=0.01, metric="f1_macro")
        logger.info(f"Best global threshold={best_thresh:.2f}, best_scores={best_scores}")

        # Also search for one threshold per label.
        per_label_thresh = find_per_label_thresholds(logits_val, labels_val, step=0.01)
        logger.info(f"Found per-label thresholds: {per_label_thresh}")

        val_metrics_per_label = multi_label_metrics_per_label_thresholds(logits_val, labels_val, per_label_thresh)
        logger.info(f"Validation metrics with per-label thresholds: {val_metrics_per_label}")

        # Store the thresholds in the model config so inference can reuse them.
        model.config.best_global_threshold = float(best_thresh)
        model.config.per_label_thresholds = [float(x) for x in per_label_thresh]
        model.config.label_names = label_names

        model.save_pretrained(args.output_dir)

        if args.test:
            logger.info("Predicting on test set using per-label thresholds...")
            preds_test = trainer.predict(tokenized_ds["test"])
            logits_test, labels_test = preds_test.predictions, preds_test.label_ids

            test_metrics_per_label = multi_label_metrics_per_label_thresholds(logits_test, labels_test, per_label_thresh)
            logger.info(f"Test metrics with per-label thresholds: {test_metrics_per_label}")


def parse_args():
    parser = argparse.ArgumentParser(description="Train a multi-label classifier on GoEmotions using DeBERTa-v3-base.")
    parser.add_argument("--output_dir", type=str, default="./pos_weight_new",
                        help="Output directory for checkpoints.")
    parser.add_argument("--model_name_or_path", type=str, default="microsoft/deberta-v3-base",
                        help="Any valid HF model name (or local path) for sequence classification.")

    parser.add_argument("--num_epochs", type=int, default=8, help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
    parser.add_argument("--eval_strategy", default="steps", choices=["epoch", "steps"],
                        help="How frequently to run evaluation (once per epoch or every eval_steps).")

    parser.add_argument("--batch_size", type=int, default=8, help="Per-device batch size.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
                        help="Effective batch size = batch size x gradient accumulation steps.")
    parser.add_argument("--label_smoothing", type=float, default=0.05,
                        help="Label smoothing factor for BCE loss (0 = no smoothing).")

    parser.add_argument("--log_level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
    parser.add_argument("--train", action="store_true", help="Whether to run training.")
    parser.add_argument("--eval", action="store_true", help="Whether to run eval on the validation split.")
    parser.add_argument("--test", action="store_true", help="Whether to run predictions on the test split.")
    return parser.parse_args()

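
# Example invocation (the actual file name of this script is not specified here,
# so "train_goemotions.py" below is a placeholder):
#
#   python train_goemotions.py --train --eval --test \
#       --output_dir ./pos_weight_new --batch_size 8 --gradient_accumulation_steps 8
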

def setup_logging(log_level):
    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "formatter": "default",
                },
            },
            "loggers": {
                "": {
                    "level": log_level,
                    "handlers": ["console"],
                },
            },
        }
    )


def show_class_distribution(dataset, split_name, label_names):
    """
    Print how many samples contain each label in the chosen split.
    This helps identify imbalance.
    - dataset[split_name] is a huggingface Dataset
    - label_names: list of label names in the dataset
    """
    from collections import Counter

    label_counter = Counter()
    num_samples = len(dataset[split_name])

    # Each example's "orig_labels" is a list of label indices; count them all.
    for ex in dataset[split_name]["orig_labels"]:
        label_counter.update(ex)

    logger.info(f"\n--- Class distribution for split '{split_name}' ({num_samples} samples) ---")
    for idx, label_name in enumerate(label_names):
        logger.info(f"{idx:02d} ({label_name}): count = {label_counter[idx]}")
    logger.info("---------------------------------------------------------------\n")


if __name__ == "__main__":
    main()