|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
def fixed_cross_entropy( |
|
|
source, |
|
|
target, |
|
|
num_items_in_batch: int | None = None, |
|
|
ignore_index: int = -100, |
|
|
weight=None, |
|
|
**kwargs, |
|
|
): |
|
|
reduction = "sum" if num_items_in_batch is not None else "mean" |
|
|
loss = F.cross_entropy( |
|
|
source, |
|
|
target, |
|
|
ignore_index=ignore_index, |
|
|
reduction=reduction, |
|
|
weight=weight, |
|
|
) |
|
|
if reduction == "sum": |
|
|
loss = loss / num_items_in_batch |
|
|
return loss |
|
|
|
|
|
|
|
|
def WeightedCausalLMLoss(
    logits,
    labels,
    image_vocab_size: int,
    image_loss_weight: float = 1.0,
    image_token_ratio: float = 2.4,
    num_items_in_batch: int | None = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Causal-LM loss that re-weights image-token targets.

    Assumes the last ``image_vocab_size`` entries of the vocabulary are image
    tokens; their loss contribution is scaled by ``image_loss_weight`` via a
    per-class weight vector, and the total is then renormalized so its
    magnitude stays comparable to the unweighted loss (using the expected
    image/text token mix ``image_token_ratio``).

    Args:
        logits: Model outputs of shape ``(batch, seq_len, vocab)``.
        labels: Target token ids of shape ``(batch, seq_len)``.
        image_vocab_size: Number of trailing vocabulary ids that are image
            tokens.
        image_loss_weight: Multiplier applied to image-token loss terms.
            ``1.0`` disables weighting entirely.
        image_token_ratio: Expected ratio of image tokens to text tokens,
            used only for the final renormalization.
        num_items_in_batch: Optional global token count for sum/count
            normalization (gradient accumulation).
        ignore_index: Label value excluded from the loss.
        **kwargs: Forwarded to ``fixed_cross_entropy``.

    Returns:
        Scalar loss tensor.
    """
    # Upcast to fp32 for a numerically stable cross-entropy.
    logits = logits.float()
    labels = labels.to(logits.device)

    # Shift so token t predicts token t+1: pad labels on the right with
    # ignore_index and drop the first label, instead of slicing the logits.
    labels = F.pad(labels, (0, 1), value=ignore_index)
    shift_labels = labels[..., 1:].contiguous()

    if image_loss_weight != 1.0:
        # Image tokens occupy the tail of the vocabulary.
        weight = torch.ones(logits.size(-1), device=logits.device)
        weight[-image_vocab_size:] = image_loss_weight
    else:
        weight = None

    # Use reshape, not view: .float() is a no-op (no copy) on an already-fp32
    # tensor, so logits may be non-contiguous here and view would raise.
    logits = logits.reshape(-1, logits.size(-1))
    shift_labels = shift_labels.reshape(-1)

    loss = fixed_cross_entropy(
        logits,
        shift_labels,
        num_items_in_batch,
        ignore_index,
        weight=weight,
        **kwargs,
    )

    if image_loss_weight != 1.0:
        # Renormalize so the weighted loss keeps roughly the same magnitude
        # as the unweighted one under the expected image/text token mix.
        denom = 1.0 + (image_token_ratio * image_loss_weight)
        scale = (1.0 + image_token_ratio) / denom
        loss = scale * loss

    return loss
|
|
|