| """ |
| AAM Diffusion LLM — Loss Functions |
| |
| Implements various loss functions for training the diffusion model, |
| including MSE, MAE, Huber, and weighted variants. |
| |
| Analogi: Seperti Jin Soun mengukur seberapa jauh prediksinya |
| dari kenyataan — semakin besar gap, semakin besar "rasa sakit" |
| yang mendorong perbaikan. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from diffusion_llm.config.model_config import DiffusionConfig |
|
|
|
|
class DiffusionLoss(nn.Module):
    """Loss function for diffusion model training.

    Computes an element-wise loss between predicted and target values,
    averages it over the last dimension, optionally re-weights every
    sample according to the signal-to-noise ratio (SNR) of its timestep,
    and returns the overall mean as a scalar.

    Args:
        config: DiffusionConfig with loss hyperparameters. The fields read
            are ``loss_type`` ("mse", "mae" or "huber"), ``loss_weighting``
            ("min_snr", "p2"; any other value means unweighted) and, for
            P2 weighting, ``p2_gamma`` and ``p2_k``.
    """

    # Small constant that keeps the SNR divisions finite when
    # alpha_bar_t -> 1 (SNR denominator) or alpha_bar_t -> 0 (min-SNR ratio).
    _EPS = 1e-8

    def __init__(self, config: DiffusionConfig):
        super().__init__()
        self.config = config

    def forward(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        timestep: torch.Tensor,
        alphas_cumprod: torch.Tensor,
    ) -> torch.Tensor:
        """Compute diffusion loss.

        Args:
            predicted: Model output (predicted noise/x0/v). The leading
                dimension is assumed to be the batch dimension when SNR
                weighting is enabled — TODO confirm for all callers.
            target: Target values, same shape as ``predicted``.
            timestep: Integer timestep indices used to look up
                ``alphas_cumprod`` for weighting: one per sample, or a
                single scalar shared by the whole batch.
            alphas_cumprod: Cumulative product of alphas from scheduler.

        Returns:
            Scalar loss value.

        Raises:
            ValueError: If ``config.loss_type`` is not recognised.
        """
        loss_type = self.config.loss_type
        if loss_type == "mse":
            loss = F.mse_loss(predicted, target, reduction="none")
        elif loss_type == "mae":
            loss = F.l1_loss(predicted, target, reduction="none")
        elif loss_type == "huber":
            # smooth_l1_loss with its default beta=1.0 is the Huber loss
            # with delta=1.0.
            loss = F.smooth_l1_loss(predicted, target, reduction="none")
        else:
            raise ValueError(f"Unknown loss_type: {self.config.loss_type}")

        # Reduce only the last dimension; leading dims (batch first) are
        # kept so that per-sample weights can be applied below.
        loss = loss.mean(dim=-1)

        weighting = self.config.loss_weighting
        if weighting == "min_snr":
            loss = self._min_snr_weight(loss, timestep, alphas_cumprod)
        elif weighting == "p2":
            loss = self._p2_weight(loss, timestep, alphas_cumprod)
        # Any other weighting value (e.g. "none") leaves the loss unweighted.

        return loss.mean()

    def _snr(
        self,
        timestep: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        """Return SNR(t) = alpha_bar_t / (1 - alpha_bar_t) on ``device``.

        Both the schedule and the indices are moved to ``device`` first, so a
        CPU-resident schedule can be used with a GPU loss (and vice versa).
        """
        index = torch.as_tensor(timestep, device=device)
        alpha_bar = alphas_cumprod.to(device)[index]
        return alpha_bar / (1 - alpha_bar + self._EPS)

    @staticmethod
    def _broadcast_to_loss(weight: torch.Tensor, loss: torch.Tensor) -> torch.Tensor:
        """Append singleton dims to ``weight`` so it broadcasts along dim 0 of ``loss``.

        A ``(B,)`` weight becomes ``(B, 1, ..., 1)`` with ``loss.ndim`` dims,
        so each sample's weight scales all of that sample's entries
        regardless of the loss rank. A 0-d weight broadcasts to everything.
        """
        extra_dims = max(loss.ndim - weight.ndim, 0)
        return weight.reshape(tuple(weight.shape) + (1,) * extra_dims)

    def _min_snr_weight(
        self,
        loss: torch.Tensor,
        timestep: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        gamma: float = 5.0,
    ) -> torch.Tensor:
        """Min-SNR-gamma weighting (Hang et al., 2023).

        Scales each sample's loss by ``min(SNR(t), gamma) / SNR(t)``, which
        down-weights low-noise (high-SNR) timesteps.

        NOTE(review): this is the paper's weight for epsilon-prediction;
        x0- and v-prediction use different variants — confirm it matches
        the configured prediction target.
        """
        snr = self._snr(timestep, alphas_cumprod, loss.device)
        weight = torch.clamp(snr, max=gamma) / (snr + self._EPS)
        return loss * self._broadcast_to_loss(weight, loss)

    def _p2_weight(
        self,
        loss: torch.Tensor,
        timestep: torch.Tensor,
        alphas_cumprod: torch.Tensor,
    ) -> torch.Tensor:
        """P2 weighting (Choi et al., 2022).

        Scales each sample's loss by ``1 / (k + SNR(t)) ** gamma``, with
        ``k = config.p2_k`` and ``gamma = config.p2_gamma``.
        """
        snr = self._snr(timestep, alphas_cumprod, loss.device)
        # The exponent applies to (k + SNR), not to SNR alone.
        weight = 1.0 / (self.config.p2_k + snr) ** self.config.p2_gamma
        return loss * self._broadcast_to_loss(weight, loss)
|
|
|
|
def compute_loss(
    predicted: torch.Tensor,
    target: torch.Tensor,
    timestep: torch.Tensor,
    alphas_cumprod: torch.Tensor,
    loss_type: str = "mse",
    loss_weighting: str = "none",
) -> torch.Tensor:
    """Evaluate the diffusion loss in a single call, without keeping a module around.

    A throwaway ``DiffusionConfig`` is built from ``loss_type`` and
    ``loss_weighting``; every other config field keeps its default. A fresh
    ``DiffusionLoss`` built from that config is applied to the tensors.

    Args:
        predicted: Model output.
        target: Target values.
        timestep: Timestep indices.
        alphas_cumprod: Alpha cumulative products.
        loss_type: Loss function type, forwarded to the config.
        loss_weighting: Weighting strategy, forwarded to the config.

    Returns:
        Scalar loss value, exactly as computed by ``DiffusionLoss``.
    """
    return DiffusionLoss(
        DiffusionConfig(loss_type=loss_type, loss_weighting=loss_weighting)
    )(predicted, target, timestep, alphas_cumprod)
|
|