"""
Focal loss with optional label smoothing for LaM-SLidE autoencoder training.
"""

import torch
import torch.nn.functional as F


def compute_focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask: torch.Tensor,
    gamma: float = 2.0,
    label_smoothing: float = 0.1,
) -> torch.Tensor:
    """
    Compute a masked focal loss with optional label smoothing.

    Focal Loss: FL(p_t) = (1 - p_t)^gamma * CE(smooth_targets)

    The smoothed cross-entropy term equals
    ``F.cross_entropy(..., label_smoothing=label_smoothing)``: the target
    class receives weight ``(1 - eps) + eps / C`` and every other class
    ``eps / C``. The focal weight uses the unsmoothed probability of the
    target class, ``p_t``.

    Args:
        logits: (..., C) unnormalized logits, e.g. (B, N, C). The tensor
            does not need to be contiguous.
        targets: (...) integer class indices matching the leading dims of
            ``logits``. Entries at positions where ``mask`` is zero are
            ignored and may hold any value (e.g. a -100 padding index).
        mask: (...) validity mask (bool or numeric). Each position's loss
            is weighted by its mask value; zero entries are excluded.
        gamma: Focal focusing parameter; 0 reduces to plain (smoothed) CE.
        label_smoothing: Label smoothing factor eps; values <= 0 disable
            smoothing.

    Returns:
        Scalar loss: the mask-weighted mean of the per-position loss.
        Returns 0 (instead of NaN) when the mask is entirely zero.
    """
    num_classes = logits.shape[-1]

    # reshape (not view) so non-contiguous inputs such as transposed or
    # sliced tensors are accepted.
    logits_flat = logits.reshape(-1, num_classes)  # (M, C)
    mask_flat = mask.reshape(-1).float()  # (M,)
    valid = mask_flat != 0  # (M,) positions that contribute to the loss

    # gather needs int64 indices, and masked-out positions may carry padding
    # values outside [0, C). Replace those with a harmless 0 so gather cannot
    # index out of bounds; their contribution is zeroed by the mask below.
    targets_flat = targets.reshape(-1).long()  # (M,)
    targets_flat = torch.where(valid, targets_flat, torch.zeros_like(targets_flat))

    # Compute log probabilities
    log_probs = F.log_softmax(logits_flat, dim=-1)  # (M, C)

    # Log-prob and prob of the target class via gather (avoids one-hot).
    log_p_t = log_probs.gather(dim=-1, index=targets_flat.unsqueeze(-1)).squeeze(-1)  # (M,)
    p_t = log_p_t.exp()  # (M,)

    # Cross-entropy with label smoothing, without materializing one-hot:
    #   smooth_target = (1 - eps) + eps/C at the target, eps/C elsewhere
    #   CE = -sum(smooth * log_probs)
    #      = -(1 - eps) * log_p_t - (eps/C) * sum(log_probs)
    #      = -(1 - eps) * log_p_t - eps * mean(log_probs)
    if label_smoothing > 0:
        mean_log_probs = log_probs.mean(dim=-1)  # (M,)
        ce_loss = -(1 - label_smoothing) * log_p_t - label_smoothing * mean_log_probs
    else:
        ce_loss = -log_p_t

    # Focal modulation: down-weight positions the model already predicts well.
    focal_weight = (1 - p_t) ** gamma

    # Mask-weighted mean. When the mask is all zero the numerator is exactly
    # 0, so clamping the denominator yields 0 instead of 0/0 = NaN without a
    # device-to-host sync.
    weighted_sum = (focal_weight * ce_loss * mask_flat).sum()
    denom = mask_flat.sum().clamp_min(torch.finfo(mask_flat.dtype).tiny)
    return weighted_sum / denom