score-ae / src /utils /loss.py
hroth's picture
Upload 90 files
b57c46e verified
raw
history blame
1.88 kB
"""
Focal loss with optional label smoothing for LaM-SLidE autoencoder training.
"""
import torch
import torch.nn.functional as F
def compute_focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask: torch.Tensor,
    gamma: float = 2.0,
    label_smoothing: float = 0.1,
) -> torch.Tensor:
    """
    Compute masked focal loss with optional label smoothing.

    Focal Loss: FL(p_t) = (1 - p_t)^gamma * CE(smooth_targets)

    Args:
        logits: (B, N, C) unnormalized logits
        targets: (B, N) class indices
        mask: (B, N) validity mask (1 = valid, 0 = ignored)
        gamma: Focal focusing parameter; gamma=0 reduces to plain CE
        label_smoothing: Label smoothing factor in [0, 1)

    Returns:
        Scalar loss averaged over valid (masked-in) positions.
        Returns 0 if the mask is entirely empty (no valid positions).
    """
    B, N, C = logits.shape
    logits_flat = logits.view(-1, C)    # (B*N, C)
    targets_flat = targets.view(-1)     # (B*N,)
    mask_flat = mask.view(-1).float()   # (B*N,)

    # Log-probabilities over classes
    log_probs = F.log_softmax(logits_flat, dim=-1)  # (B*N, C)

    # log_prob and prob at the target class via gather (avoids building a one-hot)
    log_p_t = log_probs.gather(dim=-1, index=targets_flat.unsqueeze(-1)).squeeze(-1)  # (B*N,)
    p_t = log_p_t.exp()  # (B*N,)

    # Cross-entropy with label smoothing, without materializing one-hot targets:
    #   smooth_target = (1-eps) at the target class, eps/C elsewhere
    #   CE = -sum(smooth * log_probs)
    #      = -(1-eps)*log_p_t - (eps/C)*sum(log_probs)
    #      = -(1-eps)*log_p_t - eps*mean(log_probs)
    if label_smoothing > 0:
        mean_log_probs = log_probs.mean(dim=-1)  # (B*N,)
        ce_loss = -(1 - label_smoothing) * log_p_t - label_smoothing * mean_log_probs
    else:
        ce_loss = -log_p_t

    # Focal modulation: down-weight easy examples where p_t is already high
    focal_weight = (1 - p_t) ** gamma
    loss = focal_weight * ce_loss

    # Mean over valid positions only. Clamp the denominator so an all-masked
    # batch yields 0 instead of NaN (0/0), which would poison training.
    denom = mask_flat.sum().clamp(min=1.0)
    loss = (loss * mask_flat).sum() / denom
    return loss