# GenD-Sentinel / src/loss.py
# yermandy's picture
# init
# c29babb
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.losses.unifalign import alignment, uniformity
from src.utils import logger
from .config import Loss as LossConfig
@dataclass
class LossInputs:
logits_labels: None | torch.Tensor = None
labels: None | torch.Tensor = None
l2_embeddings: None | torch.Tensor = None
@dataclass
class LossOutputs:
ce_labels: None | float = None
uniformity: None | float = None
alignment_labels: None | float = None
compactness: None | float = None
total: int | torch.Tensor = 0
class Loss(nn.Module):
    """Weighted sum of the loss terms enabled in the configuration.

    A term is computed only when its coefficient on ``config`` is truthy and
    the tensors it needs are present on the ``LossInputs``:

    * ``config.ce_labels``: cross-entropy on ``logits_labels`` / ``labels``
      with ``config.label_smoothing``.
    * ``config.alignment_labels``: ``alignment(l2_embeddings, labels)``.
    * ``config.uniformity``: ``uniformity(l2_embeddings)``.
    """

    def __init__(self, config: LossConfig):
        """Store the loss configuration.

        Args:
            config: Supplies the coefficients ``ce_labels``,
                ``alignment_labels`` and ``uniformity`` as well as
                ``label_smoothing``, all read in ``forward``.
        """
        super().__init__()
        self.config = config

    def forward(
        self,
        inputs: LossInputs,
    ) -> LossOutputs:
        """Compute every enabled loss term and accumulate the total.

        Args:
            inputs: Optional logits, labels and embeddings. Terms whose
                inputs are ``None`` are skipped.

        Returns:
            A ``LossOutputs`` whose per-term fields hold the weighted term as
            a Python float (or ``None`` when the term was skipped) and whose
            ``total`` is the sum of the weighted terms. ``total`` stays the
            int ``0`` when no term was computed, and is replaced by a zero
            tensor derived from the inputs when the sum is NaN.
        """
        loss_outputs = LossOutputs()
        config = self.config

        if inputs.logits_labels is not None and config.ce_labels:
            term = config.ce_labels * F.cross_entropy(
                inputs.logits_labels,
                inputs.labels,
                label_smoothing=config.label_smoothing,
            )
            loss_outputs.ce_labels = term.item()
            loss_outputs.total += term

        if inputs.l2_embeddings is not None:
            # Embeddings are not normalized here; they are only verified to
            # have unit L2 norm along dim 1.
            # See 3.1 https://arxiv.org/pdf/2004.11362
            l2_embeddings = inputs.l2_embeddings
            expected_norms = torch.ones(
                l2_embeddings.size(0),
                device=l2_embeddings.device,
                dtype=l2_embeddings.dtype,
            )
            if not torch.allclose(l2_embeddings.norm(p=2, dim=1), expected_norms):
                logger.print_warning_once("[yellow]Embeddings are not normalized")

            if inputs.labels is not None and config.alignment_labels:
                term = config.alignment_labels * alignment(l2_embeddings, inputs.labels)
                loss_outputs.alignment_labels = term.item()
                loss_outputs.total += term

            if config.uniformity:
                term = config.uniformity * uniformity(l2_embeddings)
                loss_outputs.uniformity = term.item()
                loss_outputs.total += term

        # ``total`` is still the dataclass default (int 0) only when no term
        # was added above.
        if isinstance(loss_outputs.total, int):
            logger.print_warning_once("[yellow]Total loss is 0. Check if loss coefficients are set correctly.")

        if isinstance(loss_outputs.total, torch.Tensor) and loss_outputs.total.isnan():
            logger.print_warning("[yellow]Total loss is nan")
            # Replace the NaN total with a zero derived from an input tensor.
            # ``total`` can only be a tensor if at least one input tensor was
            # provided, so the anchor is never None. Logits are preferred, but
            # the embeddings are used when only embedding terms were computed.
            anchor = inputs.logits_labels if inputs.logits_labels is not None else inputs.l2_embeddings
            # NaN * 0 and inf * 0 are NaN, so non-finite entries are zeroed
            # first and the product is taken element-wise before summing; the
            # resulting value is then exactly zero.
            finite = torch.nan_to_num(anchor, nan=0.0, posinf=0.0, neginf=0.0)
            loss_outputs.total = (finite * 0).sum()

        return loss_outputs

    def __call__(self, inputs: LossInputs) -> LossOutputs:
        """Delegate to ``nn.Module.__call__`` with a precisely typed signature."""
        return super().__call__(inputs)