File size: 5,765 Bytes
fb11af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from itertools import accumulate
from typing import List, Tuple

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

from .comm import (
    get_ulysses_sequence_parallel_group,
    get_ulysses_sequence_parallel_rank,
    get_ulysses_sequence_parallel_world_size,
)


def unpadding_tensor_for_seqeunce_parallel(x: Tensor, dim: int, unpadded_dim_size: int, group: ProcessGroup = None):
    """
    Strip the trailing sequence-parallel padding from ``x`` along ``dim``.

    The padding length is derived from the original (pre-padding) size: it is
    ``sp_world - unpadded_dim_size % sp_world``, i.e. the amount needed to round that size up to
    a multiple of the sequence-parallel world size.

    Args:
        x: padded tensor.
        dim: dimension that was padded.
        unpadded_dim_size: size of ``dim`` before padding was applied.
        group: sequence-parallel process group; the Ulysses SP group is used when omitted.

    Returns:
        ``x`` itself when there is no usable group or the original size was already divisible,
        otherwise ``x`` with the trailing padding sliced off.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        return x

    sp_world = get_ulysses_sequence_parallel_world_size(group)
    remainder = unpadded_dim_size % sp_world
    if not remainder:
        # Already a multiple of the world size: nothing was padded.
        return x

    padding_size = sp_world - remainder
    assert (padding_size + unpadded_dim_size) % sp_world == 0
    return unpad_tensor(x, dim=dim, padding_size=padding_size)


def padding_tensor_for_seqeunce_parallel(x: Tensor, dim: int, group: ProcessGroup = None) -> Tensor:
    """
    Right-pad ``x`` along ``dim`` so that its size there is a multiple of the sequence-parallel world size.

    The padding is appended at the end of ``dim`` via ``pad_tensor`` using its default fill value (0).
    The inverse operation is ``unpadding_tensor_for_seqeunce_parallel`` called with the original size.

    Args:
        x: tensor to pad.
        dim: dimension to pad.
        group: sequence-parallel process group; the Ulysses SP group is used when omitted.

    Returns:
        ``x`` unchanged when there is no usable group or ``x.shape[dim]`` is already divisible by
        the world size; otherwise a new tensor with ``sp_world - x.shape[dim] % sp_world`` extra
        elements along ``dim``.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    if not group:
        return x
    sp_world = get_ulysses_sequence_parallel_world_size(group)
    remainder = x.shape[dim] % sp_world
    if remainder:
        # Round the size up to the next multiple of sp_world.
        x = pad_tensor(x, dim, sp_world - remainder)
    return x


def pad_tensor(x: Tensor, dim: int, padding_size: int, padding_value: int = 0) -> Tensor:
    """
    Append ``padding_size`` elements filled with ``padding_value`` to the end of ``x`` along ``dim``.

    The filler block has the same shape as ``x`` except along ``dim``, and uses ``x``'s dtype and
    device. The result is a new tensor produced by concatenation.
    """
    filler_shape = [*x.shape]
    filler_shape[dim] = padding_size
    filler = torch.full(filler_shape, padding_value, dtype=x.dtype, device=x.device)
    return torch.cat((x, filler), dim=dim)


def unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
    """
    Drop the last ``padding_size`` elements of ``x`` along ``dim`` (the inverse of ``pad_tensor``).

    Returns a view of ``x``. ``padding_size == 0`` returns the full tensor. A ``padding_size`` that
    is negative or larger than ``x.shape[dim]`` raises an error from ``Tensor.narrow`` instead of
    silently producing a wrong-sized result.
    """
    # ``slice(0, -padding_size)`` would collapse to ``slice(0, 0)`` (an empty result) when
    # padding_size == 0, so compute the kept length explicitly.
    return x.narrow(dim, 0, x.shape[dim] - padding_size)


def remove_last_rank_padding(x: Tensor, dim: int, unpad_dim_size: int, group: ProcessGroup = None) -> Tensor:
    """
    Strip trailing sequence-parallel padding from ``x`` along ``dim``.

    The padding length is ``sp_world - unpad_dim_size % sp_world``, matching what
    ``padding_tensor_for_seqeunce_parallel`` appends to a dimension of size ``unpad_dim_size``.

    Args:
        x: tensor to trim.
        dim: dimension to trim along.
        unpad_dim_size: size of the padded dimension before padding was applied.
        group: sequence-parallel process group; the Ulysses SP group is used when omitted.

    Returns:
        ``x`` unchanged when there is no usable group or ``unpad_dim_size`` is already a multiple
        of the world size (no padding was added); otherwise ``x`` with the trailing padding removed.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    if not group:
        return x
    sp_world = get_ulysses_sequence_parallel_world_size(group)
    remainder = unpad_dim_size % sp_world
    if remainder == 0:
        # Divisible sizes are never padded, so there is nothing to strip on any rank -- including
        # the last one, which would otherwise lose ``sp_world`` real elements.
        return x
    pad = sp_world - remainder
    # NOTE(review): despite the function's name, every rank strips ``pad`` elements on this path,
    # not only the last one; confirm with callers that this matches how ``x`` is laid out.
    assert (pad + x.shape[dim]) % sp_world == 0
    slc = [slice(None)] * x.dim()
    slc[dim] = slice(0, -pad)  # pad > 0 here, so the negative stop is well-defined
    return x[tuple(slc)]


def has_overlap(x1, x2, y1, y2) -> Tuple[bool, int]:
    """
    Check whether the intervals ``[x1, x2)`` and ``[y1, y2)`` intersect.

    Returns ``(overlaps, length)`` where ``length = min(x2, y2) - max(x1, y1)``. Intervals that only
    touch (e.g. ``x2 == y1``) are reported as non-overlapping, and ``length`` is zero or negative
    whenever ``overlaps`` is False.
    """
    start = max(x1, y1)
    end = min(x2, y2)
    return start < end, end - start


def all2all_splits(image_lens: List, image_lens_per_rank: List, sp_size: int, sp_rank: int) -> Tuple[List, List]:
    """
    Compute the per-rank split sizes of ``sp_rank``'s all-to-all exchange of image tokens.

    Images are assigned to source ranks in contiguous groups of ``ceil(len(image_lens) / sp_size)``
    (image ``i`` lives on rank ``i // step``). Rank ``r`` owns the global token range
    ``[cu_seqlens[r], cu_seqlens[r + 1])``, where ``cu_seqlens`` is the running sum of
    ``image_lens_per_rank``.

    Args:
        image_lens: number of tokens of each image, in global order.
        image_lens_per_rank: number of image tokens falling into each rank's range.
        sp_size: sequence-parallel world size.
        sp_rank: rank for which the splits are computed.

    Returns:
        ``(in_splits, out_splits)``:
        - ``in_splits[r]`` is the number of tokens of images held by ``sp_rank`` that fall into
          rank ``r``'s range.
        - ``out_splits[r]`` is the number of tokens of images held by rank ``r`` that fall into
          ``sp_rank``'s range.
        Presumably these are the send / receive split sizes of an all-to-all; confirm against
        the caller.
    """
    assert sum(image_lens) == sum(image_lens_per_rank)
    num_images = len(image_lens)
    sp_step = (num_images + sp_size - 1) // sp_size
    in_splits = [0] * sp_size
    out_splits = [0] * sp_size
    # Prefix sums: rank r owns global token positions [cu_seqlens[r], cu_seqlens[r + 1]).
    cu_seqlens = [0, *accumulate(image_lens_per_rank[:sp_size])]

    num_tokens = 0  # global offset of the current image's first token
    for image_idx, image_len in enumerate(image_lens):
        src_rank = image_idx // sp_step
        assigned = 0
        for dst_rank in range(sp_size):
            overlap, overlap_len = has_overlap(
                num_tokens, num_tokens + image_len, cu_seqlens[dst_rank], cu_seqlens[dst_rank + 1]
            )
            if not overlap:
                continue
            assigned += overlap_len
            if dst_rank == sp_rank:
                out_splits[src_rank] += overlap_len
            if src_rank == sp_rank:
                in_splits[dst_rank] += overlap_len
        # The per-rank pieces of an image must add up to the whole image.
        assert assigned == image_len
        num_tokens += image_len

    return in_splits, out_splits


def vlm_images_a2a_meta(
    sp_rank: int, sp_size: int, image_lens: List, image_masks: torch.Tensor
) -> Tuple[List, List, torch.Tensor]:
    """
    Build the all-to-all metadata that routes vision-encoder outputs to the ranks that consume them.

    After the vision encoder has been load-balanced across the sequence-parallel group (images
    split into per-rank batches), every image token must be moved to the rank whose sequence
    chunk contains it before the language model runs. This function computes that routing.

    ``image_masks`` is split along dim 1 (treated as the sequence dimension) into ``sp_size``
    contiguous chunks of ``ceil(seq_len / sp_size)`` columns; trailing chunks may be shorter or
    empty. Assumes ``image_masks`` is a (batch, seq_len) mask whose sum counts image tokens --
    TODO confirm with callers.

    Args:
        sp_rank: rank for which metadata is computed.
        sp_size: sequence-parallel world size.
        image_lens: number of tokens of each image, in global order.
        image_masks: image-token mask; its total must equal ``sum(image_lens)``.

    Returns:
        ``(in_splits, out_splits, local_image_masks)`` where the splits come from
        ``all2all_splits`` and ``local_image_masks`` is this rank's chunk of ``image_masks``.
    """
    assert sum(image_lens) == image_masks.sum().item(), (
        f"The sum of image_lens must be equal to the number of tokens, {image_lens} vs {image_masks.sum().item()}"
    )
    seq_len = image_masks.shape[1]
    chunk = (seq_len + sp_size - 1) // sp_size

    chunk_sizes = []
    for r in range(sp_size):
        begin = min(chunk * r, seq_len)
        end = min(chunk * (r + 1), seq_len)
        chunk_sizes.append(end - begin)

    rank_masks = image_masks.split(chunk_sizes, dim=1)
    image_lens_per_rank = [piece.sum().item() for piece in rank_masks]
    in_splits, out_splits = all2all_splits(image_lens, image_lens_per_rank, sp_size, sp_rank)
    return in_splits, out_splits, rank_masks[sp_rank]