|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
class MaskedBCELoss(torch.nn.Module):
    """
    Binary cross-entropy loss with explicit masking support.

    The loss is computed element-wise and only the valid (non-padded)
    elements, as indicated by a mask, contribute to the reduced value.
    Both logit and probability inputs are supported, and the reduction
    over valid elements is configurable.

    Padded positions are neutralised *before* the element-wise loss is
    evaluated, so NaN/inf values stored in padding (for example ``-inf``
    logits) cannot leak into the loss or into the gradients.

    Masking semantics can be adapted to match PyTorch-style padding
    conventions or custom masking schemes via ``valid_pad``.
    """

    def __init__(
            self,
            reduction: str = 'mean',
            valid_pad: bool = True,
            eps: float = 1e-7,
            logits: bool = True
    ):
        """
        Initialize the masked binary cross-entropy loss.

        Args:
            reduction (str, optional): Reduction applied over valid elements.
                Must be either `'mean'` or `'sum'`. Defaults to `'mean'`.
            valid_pad (bool, optional): Mask interpretation mode. If True,
                `True` values in the mask mark valid (non-padded) positions.
                If False, `True` values mark padded positions, following
                PyTorch-style padding conventions. Defaults to True.
            eps (float, optional): Small constant used to clamp probability
                inputs into `[eps, 1 - eps]` when `logits=False`.
                Defaults to 1e-7.
            logits (bool, optional): Whether predictions are logits. If True,
                `binary_cross_entropy_with_logits` is used; otherwise the
                standard `binary_cross_entropy` is applied. Defaults to True.

        Raises:
            ValueError: If an unsupported reduction mode is provided.
        """
        super().__init__()

        if reduction not in ('mean', 'sum'):
            raise ValueError("[MASKED-BCE] Reduction must be 'mean' or 'sum'")

        self.reduction = reduction
        self.valid_pad = valid_pad
        self.logits = logits
        self.eps = eps

        # Select the element-wise loss once; `forward` always calls it with
        # reduction='none' and performs the masked reduction itself.
        if logits:
            self.loss = torch.nn.functional.binary_cross_entropy_with_logits
        else:
            self.loss = torch.nn.functional.binary_cross_entropy

    def forward(
            self,
            x: torch.Tensor,
            y: torch.Tensor,
            mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the masked binary cross-entropy loss.

        Args:
            x (torch.Tensor): Model predictions with shape (B, S). If
                `logits=True`, values are interpreted as logits; otherwise,
                as probabilities in [0, 1].
            y (torch.Tensor): Ground-truth binary labels with shape (B, S).
            mask (torch.Tensor): Boolean mask tensor with shape (B, S).
                If `valid_pad=True`, `True` indicates valid positions.
                If `valid_pad=False`, `True` indicates padded positions.

        Returns:
            torch.Tensor: Scalar tensor containing the reduced loss value.
            With `'mean'` reduction and no valid element the result is 0.

        Raises:
            ValueError: If `self.reduction` was changed to an unsupported
                value after construction.
        """
        if self.valid_pad:
            valid_mask = mask
        else:
            valid_mask = torch.logical_not(mask)

        # Boolean view of the mask, used only to select which inputs are kept.
        keep = valid_mask.bool()

        # Cast first so the eps clamp below is applied in float32; in half
        # precision `1 - eps` rounds to 1.0 and the upper clamp would be a no-op.
        x = x.float()
        y = y.float()

        # Replace padded entries by a harmless constant. Multiplying the loss
        # by the mask afterwards is not enough: NaN * 0 is still NaN, and the
        # backward pass of the element-wise loss would also produce NaN.
        x = torch.where(keep, x, torch.zeros_like(x))
        y = torch.where(keep, y, torch.zeros_like(y))

        if not self.logits:
            x = x.clamp(self.eps, 1.0 - self.eps)

        loss_per_token = self.loss(x, y, reduction='none')

        # Zero out the (finite) loss computed on the neutralised padding.
        masked_loss = loss_per_token * valid_mask.float()

        if self.reduction == 'mean':
            # clamp(min=1) avoids 0/0 when every position is padded.
            denom = valid_mask.sum().clamp(min=1)
            return masked_loss.sum() / denom
        elif self.reduction == 'sum':
            return masked_loss.sum()
        else:
            raise ValueError("[MASKED-BCE] Reduction must be 'mean' or 'sum'")
|
|
|
|
|
|
|
|
class WindowDiffLoss(torch.nn.Module):
    """
    WindowDiff loss module for masked binary label sequences.

    Hypothesis and reference are compared through a sliding window of
    size ``k``: for every window the number of positives in the
    hypothesis is compared with the number of positives in the reference.

    Two variants are available:

    * ``relaxed=True`` delegates to ``masked_window_diff_loss``, which uses
      the absolute difference of the window sums.
    * ``relaxed=False`` delegates to ``original_window_diff``, the original
      count-based metric, which is not differentiable.

    Why emphasize?

    Let ``y`` be a vector of 0/1 labels where 1 is the positive class.
    Classes are balanced when

        mean(y) = 0.5

    With unbalanced data ``mean(y) != 0.5``. We look for a constant ``e``
    (the emphasis factor, unrelated to the window size ``k``) that
    compensates for the imbalance:

        mean(e * y) = 0.5
        e * mean(y) = 0.5
        e = 0.5 / mean(y)
        e = 0.5 * len(y) / sum(y)

    When ``emphasis=True`` this factor is computed per sample by the
    relaxed loss and applied to both hypothesis and reference. It has no
    effect on the non-relaxed variant.
    """

    def __init__(self, k: int = 1, normalize: bool = False, relaxed: bool = False,
                 emphasis: bool = False):
        """
        Initializes the WindowDiff loss function.

        :param k: Window size.
        :param normalize: If True, normalize the loss by the window size k
            (relaxed variant only).
        :param relaxed: If True, use the relaxed version of the WindowDiff loss.
        :param emphasis: If True, apply the per-sample emphasis factor described
            in the class docstring (relaxed variant only). Defaults to False,
            which reproduces the previous behaviour.
        """
        super().__init__()
        self.k = k
        self.normalize = normalize
        self.relaxed = relaxed
        self.emphasis = emphasis

    def forward(self, x: torch.Tensor, y: torch.Tensor, label_mask: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the WindowDiff loss function.

        :param x: Hypothesis logits or probabilities (B, S)
        :param y: Ground truth binary sequence (B, S)
        :param label_mask: Binary mask indicating valid labels (B, S)
        :return: Scalar loss (float tensor).
        """
        if self.relaxed:
            return masked_window_diff_loss(
                x, y, label_mask, self.k, self.normalize, emphasis=self.emphasis
            )
        # NOTE(review): the original metric casts its inputs with `.int()`, so it
        # presumably expects already-binarised hypotheses — confirm with callers.
        return original_window_diff(x, y, label_mask, self.k)
|
|
|
|
|
|
|
|
def masked_window_diff_loss(x: torch.Tensor, y: torch.Tensor, label_mask: torch.Tensor, k: int,
                            normalize: bool = False, emphasis: bool = False) -> torch.Tensor:
    """
    Computes a differentiable WindowDiff loss across a batch whose samples
    have a variable number of valid positions.

    For every sample the valid positions are selected with ``label_mask``,
    window sums of size ``k`` (stride 1) are taken over hypothesis and
    reference, and the mean absolute difference of those sums is the
    sample loss. Samples with fewer than ``k`` valid positions are skipped.
    The result is the mean of the sample losses.

    :param x: (B, S) predicted logits or probabilities
    :param y: (B, S) ground truth binary labels
    :param label_mask: (B, S) boolean mask with valid candidate positions
    :param k: window size
    :param normalize: whether to divide each diff by k
    :param emphasis: whether to emphasize positive class
    :return: scalar tensor loss, connected to the autograd graph of ``x``.
        If no sample has at least ``k`` valid positions, a fresh zero tensor
        with ``requires_grad=True`` is returned instead.
    """
    x = x.float()
    y = y.float()

    sample_losses = []

    for b in range(x.size(0)):
        mask_b = label_mask[b].bool()
        x_b = x[b][mask_b]
        y_b = y[b][mask_b]

        # Not enough valid positions to form a single window.
        if x_b.numel() < k:
            continue

        if emphasis:
            # Emphasis factor 0.5 * len(y) / sum(y); the 1e-6 guards against
            # division by zero.
            # NOTE(review): a sample without positives yields a very large
            # factor here — confirm this is intended.
            emph = 0.5 * len(y_b) / ((y_b == 1).float().sum() + 1e-6)
            x_b = emph * x_b
            y_b = emph * y_b

        # Sliding windows of size k, stride 1, summed per window.
        x_win = x_b.unfold(0, k, 1).sum(dim=1)
        y_win = y_b.unfold(0, k, 1).sum(dim=1)
        diff = (x_win - y_win).abs()

        if normalize:
            diff = diff / k

        sample_losses.append(diff.mean())

    if not sample_losses:
        return torch.tensor(0.0, device=x.device, requires_grad=True)

    # Return the averaged tensor itself. Wrapping it in `torch.tensor(...)`
    # would copy the value and detach it from the graph, so `backward()`
    # could never reach `x`.
    return torch.stack(sample_losses).mean()
|
|
|
|
|
|
|
|
def original_window_diff(
        hyp: torch.Tensor,
        ref: torch.Tensor,
        label_mask: torch.Tensor,
        k: int
) -> torch.Tensor:
    """
    Original (non-differentiable) WindowDiff, batched and mask-aware.

    Both inputs are cast to integers. For every sample, the positions
    selected by the mask are scanned with a window of size ``k`` (stride 1);
    a window counts as an error when the sum of the hypothesis differs from
    the sum of the reference inside it. Samples with fewer than ``k``
    selected positions contribute no windows.

    :param hyp: (B, S) binary hypothesis {0,1}
    :param ref: (B, S) binary reference {0,1}
    :param label_mask: (B, S) mask of valid positions
    :param k: window size
    :return: scalar torch.Tensor with the fraction of erroneous windows over
        the whole batch, or 0.0 when no window could be formed
    """
    hyp_int = hyp.int()
    ref_int = ref.int()
    device = hyp_int.device

    mismatched = 0
    window_total = 0

    for idx in range(hyp_int.size(0)):
        keep = label_mask[idx].bool()
        hyp_seq = hyp_int[idx][keep]
        ref_seq = ref_int[idx][keep]

        if hyp_seq.numel() >= k:
            # Per-window counts of positives for hypothesis and reference.
            hyp_counts = hyp_seq.unfold(0, k, 1).sum(dim=1)
            ref_counts = ref_seq.unfold(0, k, 1).sum(dim=1)

            disagreement = (hyp_counts != ref_counts).int()

            # `.item()` turns the count into a Python number, so the running
            # totals are plain integers rather than tensors.
            mismatched += disagreement.sum().item()
            window_total += disagreement.numel()

    if window_total > 0:
        return torch.tensor(
            mismatched / window_total,
            device=device,
            dtype=torch.float
        )

    return torch.tensor(0.0, device=device)
|
|
|
|
|
|
|
|
|
|
|
|