import torch import torch.nn.functional as F def fixed_cross_entropy( source, target, num_items_in_batch: int | None = None, ignore_index: int = -100, weight=None, **kwargs, ): reduction = "sum" if num_items_in_batch is not None else "mean" loss = F.cross_entropy( source, target, ignore_index=ignore_index, reduction=reduction, weight=weight, ) if reduction == "sum": loss = loss / num_items_in_batch return loss def WeightedCausalLMLoss( logits, labels, image_vocab_size: int, image_loss_weight: float = 1.0, image_token_ratio: float = 2.4, num_items_in_batch: int | None = None, ignore_index: int = -100, **kwargs, ): # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() labels = labels.to(logits.device) # Shift so that tokens < n predict n labels = F.pad(labels, (0, 1), value=ignore_index) shift_labels = labels[..., 1:].contiguous() # Compute loss weight if image_loss_weight != 1.0: weight = torch.ones(logits.size(-1), device=logits.device) weight[-image_vocab_size:] = image_loss_weight else: weight = None # Flatten the tokens logits = logits.view(-1, logits.size(-1)) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(logits.device) loss = fixed_cross_entropy( logits, shift_labels, num_items_in_batch, ignore_index, weight=weight, **kwargs, ) # Scale the loss if image_loss_weight != 1.0: denom = 1.0 + (image_token_ratio * image_loss_weight) scale = (1.0 + image_token_ratio) / denom loss = scale * loss return loss