| import math |
| from torch import nn |
| import torch |
| import torch.nn.functional as F |
|
|
|
|
class SentimentWeightedLoss(nn.Module):
    """Binary cross-entropy on logits with dynamic per-sample weighting.

    Each sample's BCE term is scaled by the product of two weights:

    * length weight: ``sqrt(num_tokens) / sqrt(max_tokens_in_batch)``,
      clamped to ``[min_len_weight_sqrt, 1.0]``.
    * confidence weight: ``2 * |sigmoid(logit) - 0.5|``, which lies in
      ``[0, 1]`` (higher confidence => larger weight).

    The combined weights are divided by their batch mean and the weighted
    BCE terms are averaged into a scalar.

    Note:
        The confidence weight is computed from ``logits`` without detaching,
        so gradients also flow through the weights.
    """

    def __init__(self, min_len_weight_sqrt: float = 0.1):
        """
        Args:
            min_len_weight_sqrt (float): Lower bound applied to the length
                weight so that very short samples still contribute.
                Must satisfy ``0.0 <= min_len_weight_sqrt <= 1.0``.

        Raises:
            ValueError: If ``min_len_weight_sqrt`` is outside ``[0, 1]``.
        """
        super().__init__()
        if not (0.0 <= min_len_weight_sqrt <= 1.0):
            raise ValueError("min_len_weight_sqrt must be between 0.0 and 1.0.")

        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.min_len_weight_sqrt = min_len_weight_sqrt

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the weighted loss.

        All three inputs are flattened, so ``[B]`` and ``[B, 1]`` layouts are
        both accepted.

        Args:
            logits (torch.Tensor): Raw model outputs, one per sample.
            targets (torch.Tensor): Binary labels, one per sample.
            lengths (torch.Tensor): Token count of each sample.

        Returns:
            torch.Tensor: Scalar loss. For an empty ``lengths`` tensor a zero
            that is still derived from ``logits`` is returned.
        """
        logits_flat = logits.reshape(-1)

        # Empty batch: ``max()`` below would fail. Return a zero that is
        # derived from ``logits`` so it stays attached to the autograd graph.
        if lengths.numel() == 0:
            return logits_flat.sum() * 0.0

        # Flatten targets too; otherwise a [B, 1] target would not match the
        # flattened logits.
        targets_flat = targets.reshape(-1).float()
        base_loss = self.bce(logits_flat, targets_flat)

        prob = torch.sigmoid(logits_flat)
        confidence_weight = (prob - 0.5).abs() * 2

        # Flatten lengths (a [B, 1] tensor would otherwise broadcast to
        # [B, B]) and move them next to the logits.
        lengths_flat = lengths.reshape(-1).to(logits_flat.device).float()
        max_len = lengths_flat.max().item()
        if max_len == 0:
            # Every sample has length 0: avoid 0 / 0 = NaN, weight uniformly.
            length_weight = torch.ones_like(lengths_flat)
        else:
            length_weight = torch.sqrt(lengths_flat) / math.sqrt(max_len)
        length_weight = length_weight.clamp(self.min_len_weight_sqrt, 1.0)

        weights = confidence_weight * length_weight
        if weights.sum() > 1e-8:
            weights = weights / (weights.mean() + 1e-8)
        else:
            # All confidence weights are ~0 (e.g. every logit is exactly 0).
            # Normalising would zero out the loss and its gradient, so fall
            # back to uniform weights.
            weights = torch.ones_like(weights)

        return (base_loss * weights).mean()
|
|
|
|
|
|
|
|
class SentimentFocalLoss(nn.Module):
    """Focal-style binary loss with label smoothing and sample weighting.

    This loss function incorporates:
        1. Base BCEWithLogitsLoss (unreduced).
        2. Label smoothing.
        3. Focal modulation to focus more on hard examples (can be reversed
           to focus on easy examples).
        4. Sample weighting based on review length.
        5. Sample weighting based on prediction confidence.

    The final loss for each sample is calculated roughly as:
        Loss_sample = FocalModulator(pt, gamma)
                      * BCE(logits, smoothed_targets)
                      * NormalizedExternalWeight
        NormalizedExternalWeight = (ConfidenceWeight * LengthWeight)
                                   / Mean(ConfidenceWeight * LengthWeight)

    Note:
        The modulator and the confidence weight are computed from ``logits``
        without detaching, so gradients also flow through them.
    """

    def __init__(self, gamma_focal: float = 0.1, label_smoothing_epsilon: float = 0.05):
        """
        Args:
            gamma_focal (float): Gamma parameter for the focal modulation.
                - If gamma_focal > 0 (e.g., 2.0), the modulator is
                  ``(1 - pt) ** gamma``: easy examples are down-weighted.
                - If gamma_focal < 0 (e.g., -2.0), the modulator is
                  ``pt ** |gamma|``: hard examples are down-weighted.
                - If gamma_focal = 0, no focal modulation is applied.
            label_smoothing_epsilon (float): Epsilon for label smoothing
                (0.0 <= epsilon < 1.0). Hard labels (0, 1) become
                (epsilon, 1 - epsilon); 0.0 disables smoothing.

        Raises:
            ValueError: If ``label_smoothing_epsilon`` is outside ``[0, 1)``.
        """
        super().__init__()
        if not (0.0 <= label_smoothing_epsilon < 1.0):
            raise ValueError("label_smoothing_epsilon must be between 0.0 and <1.0.")

        self.gamma_focal = gamma_focal
        self.label_smoothing_epsilon = label_smoothing_epsilon
        self.bce_loss_no_reduction = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, logits: torch.Tensor, targets: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Compute the custom loss.

        Args:
            logits (torch.Tensor): Raw logits from the model. Expected shape [B] or [B, 1].
            targets (torch.Tensor): Ground truth labels (0 or 1). Expected shape [B] or [B, 1].
            lengths (torch.Tensor): Number of tokens in each review. Expected shape [B].

        Returns:
            torch.Tensor: The computed scalar loss. For an empty batch a zero
            that is still derived from ``logits`` is returned.
        """
        logits_flat = logits.reshape(-1)

        # Empty batch: return a zero derived from ``logits`` so that it keeps
        # the logits' dtype/device and stays attached to the autograd graph.
        if logits_flat.numel() == 0:
            return logits_flat.sum() * 0.0

        hard_targets = targets.reshape(-1).float()

        # 1. Label smoothing: 1 -> 1 - eps, 0 -> eps.
        eps = self.label_smoothing_epsilon
        if eps > 0:
            targets_for_bce = hard_targets * (1.0 - eps) + (1.0 - hard_targets) * eps
        else:
            targets_for_bce = hard_targets

        # 2. Per-sample BCE against the (possibly smoothed) targets.
        base_bce_loss_terms = self.bce_loss_no_reduction(logits_flat, targets_for_bce)

        # 3. Focal modulation. ``pt`` is the probability assigned to the true
        #    (unsmoothed) class.
        probs = torch.sigmoid(logits_flat)
        pt = torch.where(hard_targets.bool(), probs, 1.0 - probs)

        if self.gamma_focal > 0:
            # Standard focal loss: confident-correct samples shrink.
            focal_modulator = (1.0 - pt + 1e-8).pow(self.gamma_focal)
        elif self.gamma_focal < 0:
            # Reversed focal loss: poorly-predicted samples shrink.
            focal_modulator = (pt + 1e-8).pow(abs(self.gamma_focal))
        else:
            focal_modulator = torch.ones_like(pt)

        modulated_loss_terms = focal_modulator * base_bce_loss_terms

        # 4. Confidence weight in [0, 1]: distance of the probability from 0.5.
        confidence_w = (probs - 0.5).abs() * 2.0

        # 5. Length weight in [0, 1]: sqrt(len) relative to the batch maximum.
        #    Lengths are moved next to the logits so the product below does
        #    not fail on a device mismatch.
        lengths_flat = lengths.reshape(-1).to(logits_flat.device).float()
        max_len_in_batch = lengths_flat.max().item()
        if max_len_in_batch == 0:
            # All lengths are zero: weight every sample equally.
            length_w = torch.ones_like(lengths_flat)
        else:
            length_w = torch.sqrt(lengths_flat) / (math.sqrt(max_len_in_batch) + 1e-8)
            length_w = torch.clamp(length_w, 0.0, 1.0)

        # 6. Combine and normalise the external weights by their mean.
        external_weights = confidence_w * length_w
        if external_weights.sum() > 1e-8:
            normalized_external_weights = external_weights / (external_weights.mean() + 1e-8)
        else:
            # Degenerate case (all weights ~0): fall back to uniform weights
            # instead of zeroing out the whole batch.
            normalized_external_weights = torch.ones_like(external_weights)

        # 7. Apply the weights and average over the batch.
        return (modulated_loss_terms * normalized_external_weights).mean()
|
|