File size: 5,124 Bytes
fb11af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup

from ...data.constants import IGNORE_INDEX
from .comm import get_ulysses_sequence_parallel_group, get_unified_sequence_parallel_group
from .ulysses import _Gather, _Slice
from .utils import pad_tensor, unpadding_tensor_for_seqeunce_parallel


def slice_input_tensor(
    x: Tensor,
    dim: int,
    padding: bool = True,
    padding_value: int = 0,
    group: ProcessGroup = None,
) -> Tensor:
    """
    Return this rank's contiguous chunk of ``x`` along ``dim`` for sequence parallelism.

    The chunk length is ``ceil(x.shape[dim] / world_size)`` and rank ``r`` receives the
    elements ``[r * chunk, (r + 1) * chunk)`` along ``dim``.

    Args:
        x: Tensor to slice.
        dim: Dimension to split; negative values are allowed.
        padding: If True and ``x.shape[dim]`` is not divisible by the group size, ``x`` is
            first padded along ``dim`` (via ``pad_tensor``) with ``padding_value`` so that
            every rank gets an equal-sized chunk. If False, no padding is done and the
            trailing ranks may receive a shorter (possibly empty) chunk.
        padding_value: Fill value used when padding.
        group: Process group to split across; defaults to the unified sequence
            parallel group.

    Returns:
        A contiguous tensor holding this rank's chunk, or ``x`` itself when no group
        is active.
    """
    group = get_unified_sequence_parallel_group() if group is None else group
    if not group:
        return x
    sp_rank = dist.get_rank(group)
    sp_world = dist.get_world_size(group)
    dim_size = x.shape[dim]
    # Ceil division: when padding is applied this equals padded_length / sp_world.
    unit = (dim_size + sp_world - 1) // sp_world
    remainder = dim_size % sp_world
    if padding and remainder:
        x = pad_tensor(x, dim, sp_world - remainder, padding_value)
    # Build a full index with one slice per dimension and restrict only ``dim``.
    # Index with a tuple explicitly rather than relying on the legacy
    # "list-of-slices treated as tuple" heuristic.
    index = [slice(None)] * x.dim()
    index[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
    return x[tuple(index)].contiguous()


def slice_input_tensor_scale_grad(
    x: Tensor,
    dim: int,
    group: ProcessGroup = None,
    scale_grad: bool = True,
) -> Tensor:
    """
    Slice ``x`` along ``dim`` for this rank using the autograd function ``_Slice``.

    Unlike ``slice_input_tensor``, this defaults to the Ulysses sequence parallel group
    and routes the slice through ``_Slice``. No padding is applied here.

    Args:
        x: Tensor to slice.
        dim: Dimension to split.
        group: Process group to split across; defaults to the Ulysses sequence
            parallel group.
        scale_grad: Passed through to ``_Slice`` (presumably controls gradient scaling
            in its backward pass — see ``.ulysses``; confirm).

    Returns:
        This rank's slice of ``x``, or ``x`` itself when no group is active.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    if not group:
        return x
    x = _Slice.apply(group, x, dim, scale_grad)
    return x


def gather_outputs(
    x: Tensor,
    gather_dim: int,
    padding_dim: Optional[int] = None,
    unpad_dim_size: Optional[int] = None,
    scale_grad=True,
    group: ProcessGroup = None,
):
    """
    Reassemble sequence-parallel shards of ``x`` into one tensor.

    Args:
        x: Local shard held by this rank.
        gather_dim: Dimension passed to ``_Gather`` as the gather dimension.
        padding_dim: If given, the gathered tensor is passed to
            ``unpadding_tensor_for_seqeunce_parallel`` along this dimension.
        unpad_dim_size: Size forwarded to the unpadding helper (presumably the
            original, pre-padding length — confirm).
        scale_grad: Forwarded to ``_Gather``.
        group: Process group; defaults to the unified sequence parallel group.

    Returns:
        The gathered (and optionally unpadded) tensor, or ``x`` unchanged when no
        group is active.
    """
    sp_group = group if group is not None else get_unified_sequence_parallel_group()
    if not sp_group:
        return x
    gathered = _Gather.apply(sp_group, x, gather_dim, scale_grad)
    if padding_dim is None:
        return gathered
    return unpadding_tensor_for_seqeunce_parallel(gathered, padding_dim, unpad_dim_size, sp_group)


def slice_position_embedding(position_embeddings: tuple, dim: int = 1, sp_group: dist.ProcessGroup = None):
    """
    Shard a ``(cos, sin)`` rotary position-embedding pair across a sequence parallel group.

    Intended for use on the output of a rotary embedding module such as
    ``LlamaRotaryEmbedding``, so that each rank keeps only its own positions.

    Args:
        position_embeddings: A ``(cos, sin)`` pair of tensors.
        dim: Dimension along which both tensors are sliced.
        sp_group: Sequence parallel group; when None nothing is sliced.

    Returns:
        The sliced ``(cos, sin)`` tuple, or ``position_embeddings`` unchanged when
        ``sp_group`` is None.
    """
    if sp_group is None:
        return position_embeddings
    cos, sin = position_embeddings
    # padding=False: presumably cos/sin already span the padded sequence length
    # (e.g. computed from padded position_ids) — confirm.
    return tuple(slice_input_tensor(emb, dim=dim, padding=False, group=sp_group) for emb in (cos, sin))


def sequence_parallel_preprocess(
    input_ids: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    sp_group: Optional[ProcessGroup] = None,
):
    """
    Preprocess a batch for sequence parallel training.

    When ``sp_group`` is given, let ``padding_size`` be the number of tokens needed to
    make the last dimension of ``input_ids`` divisible by the group size. Then:

    * ``input_ids`` is padded with 0 and sliced so each rank keeps its own chunk.
    * ``labels`` is shifted left by one position (``labels[..., 1:]``), right-padded
      with one ``IGNORE_INDEX``, then padded with ``IGNORE_INDEX`` and sliced like
      ``input_ids``.
    * ``position_ids`` and ``attention_mask`` are padded by ``padding_size`` along the
      last dimension but are NOT sliced.
    * ``cu_seqlens`` (if padding is needed) gets one extra boundary so the padding
      tokens form a single trailing segment.

    When ``sp_group`` is None, every argument is returned unchanged (labels are not
    shifted in that case).

    Args:
        input_ids: Input token ids.
        labels: Label token ids, or None.
        position_ids: Position ids, or None.
        attention_mask: Attention mask, or None.
        cu_seqlens: Cumulative sequence lengths (1-D), or None.
        sp_group: Sequence parallel process group, or None to disable preprocessing.

    Returns:
        Tuple ``(input_ids, labels, position_ids, attention_mask, cu_seqlens)``.
    """
    if sp_group is None:
        return input_ids, labels, position_ids, attention_mask, cu_seqlens

    sp_size = dist.get_world_size(sp_group)
    # Tokens needed to reach a multiple of sp_size (0 when already divisible).
    padding_size = (sp_size - (input_ids.shape[-1] % sp_size)) % sp_size

    # Slice input_ids among sequence parallel group
    input_ids = slice_input_tensor(input_ids, dim=-1, padding=True, padding_value=0, group=sp_group)

    # Slice labels among sequence parallel group
    if labels is not None:
        labels = labels[..., 1:].contiguous()  # shift labels
        labels = F.pad(labels, (0, 1), "constant", IGNORE_INDEX)  # pad to the same length as input_ids
        labels = slice_input_tensor(labels, dim=-1, padding=True, padding_value=IGNORE_INDEX, group=sp_group)

    # Padding position_ids (kept full length, not sliced)
    if position_ids is not None:
        position_ids = pad_tensor(position_ids, dim=-1, padding_size=padding_size, padding_value=0)

    # Padding attention_mask (kept full length, not sliced)
    if attention_mask is not None:
        # NOTE(review): padded positions are filled with 1 when position_ids are given,
        # 0 otherwise — the reason is not visible here; confirm with the attention path.
        attn_mask_padding_value = 1 if position_ids is not None else 0
        attention_mask = pad_tensor(
            attention_mask, dim=-1, padding_size=padding_size, padding_value=attn_mask_padding_value
        )

    # Padding cu_seqlens: append exactly ONE boundary (last + padding_size) so the padding
    # tokens become a single trailing segment. Appending `padding_size` entries would add
    # spurious zero-length segments. Computed on-device to avoid a host sync.
    if cu_seqlens is not None and padding_size > 0:
        cu_seqlens = torch.cat([cu_seqlens, cu_seqlens[-1:] + padding_size], dim=-1)

    return input_ids, labels, position_ids, attention_mask, cu_seqlens