# markupdm/loss_utils.py
# NOTE(review): the lines below were Hugging Face Hub page residue, kept as a
# comment for provenance — uploaded by ktrk115 with MarkupDMForCausalLM,
# commit 327fa8d (verified).
import torch
import torch.nn.functional as F
def fixed_cross_entropy(
source,
target,
num_items_in_batch: int | None = None,
ignore_index: int = -100,
weight=None,
**kwargs,
):
reduction = "sum" if num_items_in_batch is not None else "mean"
loss = F.cross_entropy(
source,
target,
ignore_index=ignore_index,
reduction=reduction,
weight=weight,
)
if reduction == "sum":
loss = loss / num_items_in_batch
return loss
def WeightedCausalLMLoss(
    logits,
    labels,
    image_vocab_size: int,
    image_loss_weight: float = 1.0,
    image_token_ratio: float = 2.4,
    num_items_in_batch: int | None = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Next-token cross-entropy with optional re-weighting of image tokens.

    The last ``image_vocab_size`` entries of the vocabulary dimension receive
    class weight ``image_loss_weight`` (all others keep weight 1). When a
    non-unit weight is used, the loss is rescaled by
    ``(1 + image_token_ratio) / (1 + image_token_ratio * image_loss_weight)``
    to keep its magnitude comparable to the unweighted case, where
    ``image_token_ratio`` is the assumed image-to-text token ratio.
    ``num_items_in_batch`` and ``ignore_index`` are forwarded to
    ``fixed_cross_entropy``; extra keyword arguments are passed through.
    """
    # Compute in float32 to avoid precision issues with low-precision logits.
    logits = logits.float()
    labels = labels.to(logits.device)

    # Causal shift: append one ignored slot, then drop the first label so
    # position t is supervised by the token at position t + 1.
    shifted = F.pad(labels, (0, 1), value=ignore_index)[..., 1:].contiguous()

    # Per-class weights only when a non-unit image weight is requested; the
    # image tokens are assumed to occupy the tail of the vocabulary.
    use_weighting = image_loss_weight != 1.0
    if use_weighting:
        class_weight = torch.ones(logits.size(-1), device=logits.device)
        class_weight[-image_vocab_size:] = image_loss_weight
    else:
        class_weight = None

    # Flatten to (tokens, vocab) / (tokens,) and co-locate labels with the
    # logits shard (model parallelism).
    flat_logits = logits.view(-1, logits.size(-1))
    flat_labels = shifted.view(-1).to(flat_logits.device)

    loss = fixed_cross_entropy(
        flat_logits,
        flat_labels,
        num_items_in_batch,
        ignore_index,
        weight=class_weight,
        **kwargs,
    )

    if use_weighting:
        # Undo the average inflation introduced by the class weights.
        scale = (1.0 + image_token_ratio) / (
            1.0 + (image_token_ratio * image_loss_weight)
        )
        loss = scale * loss
    return loss