File size: 4,972 Bytes
ee3e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication

from typing import List, Tuple, Union

import torch
import torch.distributed as dist

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device

# A tensor shape given as a torch.Size, a list of ints, or a tuple of ints of any length.
# NOTE: ``Tuple[int]`` would mean a 1-tuple; ``Tuple[int, ...]`` is the variable-length form.
TensorShape = Union[torch.Size, List[int], Tuple[int, ...]]


def send_meta_helper(obj, next_rank, tensor_kwargs):
    """Send the shape of a single tensor to ``next_rank``.

    Two blocking sends are issued, in this order: the number of dimensions
    (a 0-dim tensor), then the dimension sizes (a 1-D tensor).

    Args:
        obj (:class:`torch.Tensor`): Tensor whose shape is sent.
        next_rank (int): Destination rank.
        tensor_kwargs (dict): ``dtype``/``device`` keyword arguments used to
            build the metadata tensors.
    """
    shape = obj.size()
    ndims_msg = torch.tensor(len(shape), **tensor_kwargs)
    shape_msg = torch.tensor(shape, **tensor_kwargs)
    # The receiver needs the rank first so it can size the buffer for the shape.
    for msg in (ndims_msg, shape_msg):
        dist.send(msg, next_rank)


def send_obj_meta(obj, next_rank=None):
    """Sends obj meta information before sending a specific obj.

    Since the recipient must know the shape of the obj in p2p communications,
    meta information of the obj should be sent before communications. This function
    synchronizes with :func:`recv_obj_meta`.

    Wire order: the number of tensors, then for each tensor its number of
    dimensions followed by its dimension sizes.

    Note:
        A single tensor and a one-element sequence both send a count of 1, and
        :func:`recv_obj_meta` turns a count of 1 into a single
        :class:`torch.Size`, not a list.

    Args:
        obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
        next_rank (int, optional): The rank of the next member in the pipeline parallel
            group. Defaults to the next global rank in ``ParallelMode.PIPELINE``.

    Returns:
        None
    """
    if next_rank is None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

    tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
    # Normalize to a sequence so a bare tensor and a list share one code path.
    tensors = [obj] if isinstance(obj, torch.Tensor) else obj

    send_obj_nums = torch.tensor(len(tensors), **tensor_kwargs)
    dist.send(send_obj_nums, next_rank)
    for tensor_to_send in tensors:
        send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)


def recv_meta_helper(prev_rank, tensor_kwargs):
    """Receive the shape of a single tensor from ``prev_rank``.

    Mirrors :func:`send_meta_helper`: first a 0-dim tensor holding the number
    of dimensions is received, then a 1-D tensor of that length holding the
    dimension sizes.

    Args:
        prev_rank (int): Source rank.
        tensor_kwargs (dict): ``dtype``/``device`` keyword arguments used to
            allocate the receive buffers.

    Returns:
        :class:`torch.Tensor`: 1-D tensor with the received dimension sizes.
    """
    ndims_buf = torch.empty((), **tensor_kwargs)
    dist.recv(ndims_buf, prev_rank)
    num_dims = int(ndims_buf.item())

    shape_buf = torch.empty(num_dims, **tensor_kwargs)
    dist.recv(shape_buf, prev_rank)
    return shape_buf


def recv_obj_meta(prev_rank=None) -> Union[torch.Size, List[torch.Size]]:
    """Receives obj meta information before receiving a specific obj.

    Since the recipient must know the shape of the obj in p2p communications,
    meta information of the obj should be received before communications. This function
    synchronizes with :func:`send_obj_meta`.

    Args:
        prev_rank (int, optional): The rank of the source of the obj. Defaults to the
            previous global rank in ``ParallelMode.PIPELINE``.

    Returns:
        Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be
        received. A single :class:`torch.Size` is returned when the sender announced
        exactly one tensor; otherwise a list with one :class:`torch.Size` per tensor.
    """
    if prev_rank is None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

    tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
    recv_obj_nums = torch.empty((), **tensor_kwargs)
    dist.recv(recv_obj_nums, prev_rank)
    # Read the count once instead of once per use.
    num_objs = recv_obj_nums.item()

    if num_objs == 1:
        return torch.Size(recv_meta_helper(prev_rank, tensor_kwargs))

    # Shapes arrive in send order; the comprehension receives them sequentially.
    return [torch.Size(recv_meta_helper(prev_rank, tensor_kwargs)) for _ in range(num_objs)]


def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
    """Break a tensor into equal 1D chunks and return this rank's chunk.

    The tensor is flattened and divided into ``world_size`` chunks of
    ``numel // world_size`` elements, where ``world_size`` is the
    tensor-parallel world size. The chunk at index ``local_rank`` is returned.
    If ``numel`` is not divisible by ``world_size``, the trailing remainder
    elements belong to no rank's chunk.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be split before communication. It is
            flattened with ``view(-1)``, so its layout must allow a copy-free flatten
            (e.g. a contiguous tensor).
        new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
            If False, the returned chunk is a view that shares storage with ``tensor``.

    Returns:
        :class:`torch.Tensor`: The split tensor
    """
    partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR)
    start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR)
    end_index = start_index + partition_size
    flat = tensor.view(-1)
    if new_buffer:
        # Allocate via the shared device helper used elsewhere in this module.
        data = torch.empty(partition_size, dtype=tensor.dtype, device=get_current_device(), requires_grad=False)
        data.copy_(flat[start_index:end_index])
    else:
        data = flat[start_index:end_index]
    return data


def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Gather 1D chunks from all tensor-parallel ranks into one flat tensor.

    This is the inverse of :func:`split_tensor_into_1d_equal_chunks`. Every rank
    contributes ``numel(tensor)`` elements. Rank ``i``'s contribution lands at
    ``[i * numel, (i + 1) * numel)`` in the result.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.

    Returns:
        :class:`torch.Tensor`: The gathered 1-D tensor of ``world_size * numel(tensor)``
        elements.
    """
    world_size = gpc.get_world_size(ParallelMode.TENSOR)
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    # Allocate via the shared device helper used elsewhere in this module.
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=get_current_device(), requires_grad=False)
    # Views into ``gathered``, so all_gather writes directly into the result buffer.
    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR))
    return gathered