# Upstream provenance (was raw scrape residue, not valid Python):
# author nameissakthi, commit c27df58 — "Remove pycache, add gitignore"
"""
Loss functions for SLM training.
Cross-entropy loss with optional label smoothing.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class LanguageModelingLoss(nn.Module):
    """Cross-entropy loss for autoregressive language modeling.

    Handles:
    - Automatic shifting of labels so position t predicts token t+1
    - Ignoring padding tokens (default label ``-100``)
    - Optional label smoothing
    """

    def __init__(
        self,
        vocab_size: int,
        label_smoothing: float = 0.0,
        ignore_index: int = -100,
    ):
        """Initialize loss function.

        Args:
            vocab_size: Size of vocabulary (last dimension of logits).
            label_smoothing: Label smoothing factor (0.0 = no smoothing).
            ignore_index: Label value excluded from the loss (padding).
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.label_smoothing = label_smoothing
        self.ignore_index = ignore_index
        # nn.CrossEntropyLoss handles both masking and smoothing natively.
        self.ce_loss = nn.CrossEntropyLoss(
            ignore_index=ignore_index,
            label_smoothing=label_smoothing,
        )

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        shift_labels: bool = True,
    ) -> torch.Tensor:
        """Compute the cross-entropy loss.

        Args:
            logits: Model output logits [batch_size, seq_len, vocab_size].
            labels: Target token IDs [batch_size, seq_len].
            shift_labels: Whether to shift labels by one position
                (standard for autoregressive LM training).

        Returns:
            Scalar loss tensor (mean over non-ignored positions).
        """
        # Use distinct local names: the original shadowed the
        # `shift_labels` parameter with a tensor of the same name.
        if shift_labels:
            # Position t predicts token t+1: drop the last logit step
            # and the first label step.
            pred_logits = logits[..., :-1, :].contiguous()
            targets = labels[..., 1:].contiguous()
        else:
            pred_logits = logits
            targets = labels

        # reshape (not view): the unshifted path may receive
        # non-contiguous tensors, on which .view() would raise.
        flat_logits = pred_logits.reshape(-1, self.vocab_size)
        flat_labels = targets.reshape(-1)
        return self.ce_loss(flat_logits, flat_labels)
def compute_perplexity(loss: torch.Tensor) -> torch.Tensor:
    """Convert a cross-entropy loss into perplexity.

    Args:
        loss: Cross-entropy loss value.

    Returns:
        Perplexity, i.e. ``exp(loss)``.
    """
    perplexity = loss.exp()
    return perplexity
def compute_accuracy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Compute next-token prediction accuracy.

    Args:
        logits: Model output logits [batch_size, seq_len, vocab_size].
        labels: Target token IDs [batch_size, seq_len].
        ignore_index: Label value excluded from the accuracy calculation.

    Returns:
        Accuracy over non-ignored positions as a scalar tensor.
        Returns 0.0 when every position is ignored (the original
        divided 0 by 0 and produced NaN).
    """
    # Shift for autoregressive prediction: logit at position t is
    # scored against the label at position t+1.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]

    predictions = shift_logits.argmax(dim=-1)

    # Only positions whose label is not the padding sentinel count.
    mask = shift_labels != ignore_index
    valid = mask.sum()
    if valid == 0:
        # All-padding batch: avoid 0/0 -> NaN.
        return torch.zeros((), device=logits.device)

    correct = (predictions == shift_labels) & mask
    return correct.sum().float() / valid.float()