"""Loss functions, trainer subclass, and metric helpers for entity-sentiment models.

Supports two families of prompting modes: "marker"/"qa_m" (direct 3-way
negative/neutral/positive classification) and a binary QA mode ("QA-B") whose
yes/no answers are regrouped into (neg, neu, pos) triplets before scoring.
"""
from collections import Counter

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch import nn
from torch.utils.data import DataLoader
from transformers import Trainer

from src.models.dataset import EntitySentimentDataset


def compute_class_weights(examples: list[dict], n_classes: int) -> torch.Tensor:
    """Return inverse-frequency weights, one per class.

    Each weight is ``total / (n_classes * count)`` so a perfectly balanced
    dataset yields all-ones. Classes absent from ``examples`` fall back to a
    pseudo-count of 1 to avoid division by zero.

    Args:
        examples: dicts carrying an integer ``"label"`` key in ``range(n_classes)``.
        n_classes: number of classes to produce weights for.

    Returns:
        Float tensor of shape ``(n_classes,)``.
    """
    counts = Counter(e["label"] for e in examples)
    total = sum(counts.values())
    weights = [total / (n_classes * counts.get(i, 1)) for i in range(n_classes)]
    return torch.tensor(weights, dtype=torch.float)


def focal_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    weight: torch.Tensor,
    gamma: float = 2.0,
) -> torch.Tensor:
    """Class-weighted focal loss.

    Down-weights well-classified examples by ``(1 - p_t) ** gamma`` on top of
    a class-weighted cross-entropy, focusing training on hard examples.

    Args:
        logits: ``(batch, n_classes)`` raw scores.
        labels: ``(batch,)`` integer class ids.
        weight: per-class weights forwarded to ``F.cross_entropy``.
        gamma: focusing parameter; 0 recovers weighted cross-entropy.

    Returns:
        Scalar mean loss.
    """
    ce = F.cross_entropy(logits, labels, weight=weight, reduction="none")
    probs = F.softmax(logits, dim=-1)
    # p_t: predicted probability assigned to the true class of each example.
    pt = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
    return ((1 - pt) ** gamma * ce).mean()


class WeightedLossTrainer(Trainer):
    """HF ``Trainer`` that replaces the model's loss with a class-weighted one.

    Args:
        class_weights: per-class weights (see ``compute_class_weights``).
        loss_fn: ``"focal"`` for focal loss; anything else uses weighted
            cross-entropy.
        focal_gamma: gamma used when ``loss_fn == "focal"``.
    """

    def __init__(
        self,
        *args,
        class_weights: torch.Tensor,
        loss_fn: str = "cross_entropy",
        focal_gamma: float = 2.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self.loss_fn = loss_fn
        self.focal_gamma = focal_gamma

    def compute_loss(self, model, inputs, return_outputs: bool = False, **kwargs):
        # Pop labels so the model does not compute its own (unweighted) loss;
        # **kwargs absorbs newer Trainer arguments such as num_items_in_batch.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # Weights live on CPU until first use; move them to wherever the model ran.
        w = self.class_weights.to(outputs.logits.device)
        if self.loss_fn == "focal":
            loss = focal_loss(outputs.logits, labels, weight=w, gamma=self.focal_gamma)
        else:
            loss = nn.CrossEntropyLoss(weight=w)(outputs.logits, labels)
        return (loss, outputs) if return_outputs else loss


def reconstruct_triplets(
    yes_probs: np.ndarray, bin_labels: np.ndarray
) -> tuple[list[int], list[int]]:
    """Group consecutive (neg, neu, pos) triplets and take argmax.

    Every three consecutive binary "yes" probabilities correspond to the
    negative/neutral/positive questions for one entity; the argmax within each
    triplet is that entity's 3-way prediction. A trailing incomplete triplet
    (fewer than 3 remaining items) is dropped.

    Returns:
        ``(preds3, labels3)`` — parallel lists of 3-way class ids.
    """
    preds3, labels3 = [], []
    for i in range(0, len(yes_probs) - 2, 3):
        preds3.append(int(np.argmax(yes_probs[i : i + 3])))
        labels3.append(int(np.argmax(bin_labels[i : i + 3])))
    return preds3, labels3


def make_compute_metrics(mode: str):
    """Build a ``compute_metrics`` callable for the given prompting mode.

    ``"marker"``/``"qa_m"``: direct 3-way classification metrics (macro-F1 plus
    per-class F1). Any other mode (QA-B): binary yes/no metrics plus a 3-way
    macro-F1 obtained by regrouping predictions with ``reconstruct_triplets``.
    """
    if mode in ("marker", "qa_m"):

        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            preds = np.argmax(logits, axis=-1)
            macro_f1 = f1_score(labels, preds, average="macro")
            per_class = f1_score(labels, preds, average=None, labels=[0, 1, 2])
            return {
                "macro_f1": macro_f1,
                "f1_negative": per_class[0],
                "f1_neutral": per_class[1],
                "f1_positive": per_class[2],
            }

    else:

        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            preds = np.argmax(logits, axis=-1)
            bin_acc = float((preds == labels).mean())
            bin_f1 = float(f1_score(labels, preds, average="binary", pos_label=1))
            # P(yes) for each binary question: softmax over the two logits,
            # keep the probability of class 1.
            yes_probs = F.softmax(
                torch.tensor(logits, dtype=torch.float), dim=-1
            )[:, 1].numpy()
            preds3, labels3 = reconstruct_triplets(yes_probs, labels)
            # FIX: sklearn's contract is f1_score(y_true, y_pred); the original
            # passed (preds3, labels3). Macro-F1 is symmetric under the swap
            # (per-class F1 exchanges precision and recall), so the value is
            # unchanged today, but the corrected order matches sklearn's
            # convention and evaluate_qa_b_test, and stays correct if the
            # averaging mode or labels argument ever changes.
            macro_f1 = (
                float(f1_score(labels3, preds3, average="macro")) if preds3 else 0.0
            )
            return {
                "macro_f1": macro_f1,
                "bin_accuracy": bin_acc,
                "bin_f1_yes": bin_f1,
            }

    return compute_metrics


def evaluate_qa_b_test(
    model,
    tokenizer,
    test_exs: list[dict],
    max_len: int,
    batch_size: int,
    device: torch.device,
) -> tuple[float, list[int], list[int]]:
    """Evaluate a QA-B model on the test set at the triplet (entity) level.

    Runs batched inference without gradients, collects P(yes) per binary
    question, regroups consecutive questions into (neg, neu, pos) triplets,
    and scores the resulting 3-way predictions.

    Args:
        model: sequence-classification model already placed on ``device``.
        tokenizer: tokenizer passed through to ``EntitySentimentDataset``.
        test_exs: test examples in the dataset's expected dict format.
        max_len: max tokenized sequence length.
        batch_size: inference batch size.
        device: device to run inference on.

    Returns:
        ``(macro_f1, preds3, labels3)``.
    """
    ds = EntitySentimentDataset(test_exs, tokenizer, max_len)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
    all_yes_probs, all_bin_labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            all_yes_probs.extend(F.softmax(logits, dim=-1)[:, 1].cpu().tolist())
            all_bin_labels.extend(batch["labels"].tolist())
    preds3, labels3 = reconstruct_triplets(
        np.array(all_yes_probs), np.array(all_bin_labels)
    )
    macro_f1 = f1_score(labels3, preds3, average="macro")
    return macro_f1, preds3, labels3