from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.losses.unifalign import alignment, uniformity
from src.utils import logger
from .config import Loss as LossConfig
@dataclass
class LossInputs:
logits_labels: None | torch.Tensor = None
labels: None | torch.Tensor = None
l2_embeddings: None | torch.Tensor = None
@dataclass
class LossOutputs:
ce_labels: None | float = None
uniformity: None | float = None
alignment_labels: None | float = None
compactness: None | float = None
total: int | torch.Tensor = 0
class Loss(nn.Module):
    """Combine the loss terms enabled in a ``LossConfig`` into one total.

    Each coefficient read from the config (``ce_labels``, ``alignment_labels``,
    ``uniformity``) both enables its term (when truthy) and scales it.
    """

    def __init__(self, config: LossConfig):
        """Store ``config``; its coefficients are read on every forward pass."""
        super().__init__()
        self.config = config

    def forward(
        self,
        inputs: LossInputs,
    ) -> LossOutputs:
        """Compute every enabled loss term and accumulate their sum.

        Args:
            inputs: Tensors to score. Fields left as ``None`` disable the
                terms that depend on them.

        Returns:
            A ``LossOutputs`` whose per-term fields are Python floats and
            whose ``total`` is the tensor sum of the weighted terms. ``total``
            stays the int ``0`` when no term was computed. When the sum is
            NaN, ``total`` is replaced by a zero derived from one of the
            input tensors.
        """
        loss_outputs = LossOutputs()
        config = self.config

        if inputs.logits_labels is not None:
            if config.ce_labels:
                term = config.ce_labels * F.cross_entropy(
                    inputs.logits_labels,
                    inputs.labels,
                    label_smoothing=config.label_smoothing,
                )
                loss_outputs.ce_labels = term.item()
                loss_outputs.total += term

        if inputs.l2_embeddings is not None:
            # The embeddings are expected to arrive already L2-normalized
            # (see section 3.1 of https://arxiv.org/pdf/2004.11362); this
            # method only checks that, it does not normalize them.
            l2_embeddings = inputs.l2_embeddings
            unit_norms = torch.ones(
                l2_embeddings.size(0),
                device=l2_embeddings.device,
                dtype=l2_embeddings.dtype,
            )
            if not torch.allclose(l2_embeddings.norm(p=2, dim=1), unit_norms):
                logger.print_warning_once("[yellow]Embeddings are not normalized")

            if inputs.labels is not None:
                if config.alignment_labels:
                    term = config.alignment_labels * alignment(l2_embeddings, inputs.labels)
                    loss_outputs.alignment_labels = term.item()
                    loss_outputs.total += term

            # Uniformity is computed from the embeddings alone, so it is not
            # gated on the presence of labels.
            if config.uniformity:
                term = config.uniformity * uniformity(l2_embeddings)
                loss_outputs.uniformity = term.item()
                loss_outputs.total += term

        # ``total`` is still the dataclass default (int 0) only when no term
        # above was added to it.
        if isinstance(loss_outputs.total, int):
            logger.print_warning_once("[yellow]Total loss is 0. Check if loss coefficients are set correctly.")

        if isinstance(loss_outputs.total, torch.Tensor) and loss_outputs.total.isnan():
            logger.print_warning("[yellow]Total loss is nan")
            # A tensor total implies at least one of the two inputs was given,
            # so fall back to the embeddings when no logits were supplied
            # instead of dereferencing ``None``.
            anchor = inputs.logits_labels if inputs.logits_labels is not None else inputs.l2_embeddings
            # NaN * 0 and inf * 0 are both NaN, so sanitize the anchor first;
            # multiplying before summing also avoids overflow in the sum.
            loss_outputs.total = (torch.nan_to_num(anchor) * 0).sum()

        return loss_outputs

    def __call__(self, inputs: LossInputs) -> LossOutputs:
        """Typed pass-through to ``nn.Module.__call__``."""
        return super().__call__(inputs)
|