File size: 1,835 Bytes
327fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn.functional as F


def fixed_cross_entropy(
    source,
    target,
    num_items_in_batch: int | None = None,
    ignore_index: int = -100,
    weight=None,
    **kwargs,
):
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(
        source,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
        weight=weight,
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss


def WeightedCausalLMLoss(
    logits,
    labels,
    image_vocab_size: int,
    image_loss_weight: float = 1.0,
    image_token_ratio: float = 2.4,
    num_items_in_batch: int | None = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Next-token cross-entropy with an optional extra weight on image-token classes.

    Args:
        logits: Model outputs of shape ``(..., seq_len, vocab_size)``.
        labels: Token ids of shape ``(..., seq_len)``.
        image_vocab_size: Number of image-token classes, assumed to occupy the
            LAST ``image_vocab_size`` entries of the vocabulary —
            TODO(review): confirm this layout against the tokenizer.
        image_loss_weight: Class weight applied to the image vocab; ``1.0``
            disables weighting entirely.
        image_token_ratio: Assumed image:text token ratio, used only to rescale
            the weighted loss back to the unweighted magnitude.
        num_items_in_batch: Optional global token count; when given, the loss
            is summed and divided by it instead of averaged.
        ignore_index: Label value excluded from the loss.

    Returns:
        Scalar loss tensor.
    """
    # Upcast to float32 so the loss is computed at full precision.
    logits = logits.float()
    labels = labels.to(logits.device)

    # Right-pad the labels with ignore_index and drop position 0, so the label
    # at position t is exactly the token logits[..., t, :] must predict.
    shift_labels = F.pad(labels, (0, 1), value=ignore_index)[..., 1:].contiguous()

    # Per-class weights over the vocabulary. NOTE(review): if image_vocab_size
    # were 0, weight[-0:] would cover the WHOLE vocab — callers are expected to
    # pass a positive size whenever image_loss_weight != 1.0.
    weight = None
    if image_loss_weight != 1.0:
        weight = torch.ones(logits.size(-1), device=logits.device)
        weight[-image_vocab_size:] = image_loss_weight

    # Flatten to (N, vocab) / (N,) and keep labels on the logits' device
    # (supports model parallelism).
    vocab_size = logits.size(-1)
    flat_logits = logits.view(-1, vocab_size)
    flat_labels = shift_labels.view(-1).to(flat_logits.device)

    loss = fixed_cross_entropy(
        flat_logits,
        flat_labels,
        num_items_in_batch,
        ignore_index,
        weight=weight,
        **kwargs,
    )

    # Rescale so the expected magnitude matches the unweighted loss under the
    # assumed image:text token ratio: multiply by (1 + r) / (1 + r * w).
    if image_loss_weight != 1.0:
        loss = loss * (1.0 + image_token_ratio) / (1.0 + image_token_ratio * image_loss_weight)

    return loss