| """ |
| Loss functions for SLM training. |
| |
| Cross-entropy loss with optional label smoothing. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional |
|
|
|
|
class LanguageModelingLoss(nn.Module):
    """Cross-entropy loss for autoregressive language modeling.

    Handles:
    - Automatic shifting of labels (predict token t+1 from position t)
    - Ignoring padding tokens (default index -100)
    - Optional label smoothing
    """

    def __init__(
        self,
        vocab_size: int,
        label_smoothing: float = 0.0,
        ignore_index: int = -100,
    ):
        """Initialize loss function.

        Args:
            vocab_size: Size of vocabulary
            label_smoothing: Label smoothing factor (0.0 = no smoothing)
            ignore_index: Index to ignore in loss calculation (padding)
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.label_smoothing = label_smoothing
        self.ignore_index = ignore_index

        # Configured once; masking and smoothing are handled internally
        # by nn.CrossEntropyLoss.
        self.ce_loss = nn.CrossEntropyLoss(
            ignore_index=ignore_index,
            label_smoothing=label_smoothing,
        )

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        shift_labels: bool = True,
    ) -> torch.Tensor:
        """Compute loss.

        Args:
            logits: Model output logits [batch_size, seq_len, vocab_size]
            labels: Target token IDs [batch_size, seq_len]
            shift_labels: Whether to shift labels (for autoregressive LM)

        Returns:
            Scalar loss tensor (mean over non-ignored positions)
        """
        if shift_labels:
            # Align predictions with next-token targets: drop the last
            # logit (nothing to predict after it) and the first label
            # (nothing predicts it).  Distinct names avoid shadowing the
            # `shift_labels` parameter.
            target_logits = logits[..., :-1, :]
            target_labels = labels[..., 1:]
        else:
            target_logits = logits
            target_labels = labels

        # reshape (unlike view) tolerates non-contiguous tensors, e.g.
        # the slices above or transposed model outputs on either branch.
        flat_logits = target_logits.reshape(-1, self.vocab_size)
        flat_labels = target_labels.reshape(-1)

        return self.ce_loss(flat_logits, flat_labels)
|
|
|
|
def compute_perplexity(loss: torch.Tensor) -> torch.Tensor:
    """Convert a cross-entropy loss into perplexity.

    Args:
        loss: Cross-entropy loss value

    Returns:
        Perplexity (exp of loss)
    """
    perplexity = loss.exp()
    return perplexity
|
|
|
|
def compute_accuracy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Compute next-token prediction accuracy.

    Labels are shifted by one position (autoregressive convention) and
    positions equal to ``ignore_index`` are excluded from the count.

    Args:
        logits: Model output logits [batch_size, seq_len, vocab_size]
        labels: Target token IDs [batch_size, seq_len]
        ignore_index: Index to ignore in accuracy calculation

    Returns:
        Accuracy as a scalar tensor in [0, 1]; 0.0 when every position
        is ignored (rather than NaN from a 0/0 division).
    """
    # Align predictions with next-token targets.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Greedy prediction per position.
    predictions = shift_logits.argmax(dim=-1)

    # Positions that actually count toward accuracy.
    mask = shift_labels != ignore_index
    valid = mask.sum()
    if valid == 0:
        # All positions are padding: avoid NaN from dividing by zero.
        return torch.zeros((), device=logits.device)

    correct = (predictions == shift_labels) & mask
    return correct.sum().float() / valid.float()
|
|