Spaces:
Sleeping
Sleeping
"""
Focal loss with optional label smoothing for LaM-SLidE autoencoder training.
"""
import torch
import torch.nn.functional as F
def compute_focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask: torch.Tensor,
    gamma: float = 2.0,
    label_smoothing: float = 0.1,
) -> torch.Tensor:
    """
    Compute focal loss with label smoothing over masked positions.

    Focal Loss: FL(p_t) = (1 - p_t)^gamma * CE(smooth_targets)

    Args:
        logits: (B, N, C) unnormalized logits.
        targets: (B, N) integer class indices in [0, C).
        mask: (B, N) validity mask (nonzero = position contributes to the loss).
        gamma: Focal focusing parameter; gamma=0 reduces to plain (smoothed) CE.
        label_smoothing: Label smoothing factor eps; 0 disables smoothing.

    Returns:
        Scalar loss averaged over valid positions. An all-invalid mask yields
        0.0 rather than NaN (guarded 0/0 division).
    """
    B, N, C = logits.shape
    logits_flat = logits.view(-1, C)    # (B*N, C)
    targets_flat = targets.view(-1)     # (B*N,)
    mask_flat = mask.view(-1).float()   # (B*N,)

    # Log-probabilities; gather at the target index avoids a one-hot matrix.
    log_probs = F.log_softmax(logits_flat, dim=-1)  # (B*N, C)
    log_p_t = log_probs.gather(dim=-1, index=targets_flat.unsqueeze(-1)).squeeze(-1)  # (B*N,)
    p_t = log_p_t.exp()  # (B*N,)

    # Cross-entropy with label smoothing, still without one-hot:
    #   smooth_target = (1 - eps) at target class, eps/C elsewhere
    #   CE = -(1-eps) * log_p_t - (eps/C) * sum(log_probs)
    #      = -(1-eps) * log_p_t - eps * mean(log_probs)
    if label_smoothing > 0:
        mean_log_probs = log_probs.mean(dim=-1)  # (B*N,)
        ce_loss = -(1 - label_smoothing) * log_p_t - label_smoothing * mean_log_probs
    else:
        ce_loss = -log_p_t

    # Focal modulation (1 - p_t)^gamma down-weights easy, high-confidence examples.
    focal_weight = (1 - p_t) ** gamma
    loss = focal_weight * ce_loss

    # Masked mean. Clamp the denominator so a batch with no valid positions
    # returns 0.0 instead of NaN (numerator is already 0 in that case).
    denom = mask_flat.sum().clamp(min=1.0)
    return (loss * mask_flat).sum() / denom