# Copied verbatim from vortex
import torch
import logging

log = logging.getLogger(__name__)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
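

# Added illustrative sketch (not part of the original vortex module; the
# helper name below is hypothetical): splitting dim=10 over world_size=4
# gives ranks 0-1 three entries each and ranks 2-3 two each, so the
# per-rank dims sum back to the full dim.
def _example_get_dim_for_local_rank():
    dims = [get_dim_for_local_rank(10, 4, rank) for rank in range(4)]
    assert dims == [3, 3, 2, 2] and sum(dims) == 10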


def grab_first_if_tuple(x):
    if x.__class__.__name__ == "tuple":
        return x[0]
    else:
        return x


def interleave(z_pre):
    if len(z_pre.shape) == 3:  # non-cached
        x1 = z_pre[:, 0::3, :]
        x2 = z_pre[:, 1::3, :]
        v = z_pre[:, 2::3, :]
        z_pre = torch.cat([x1, x2, v], dim=1)
        return z_pre
    else:
        x1 = z_pre[..., 0::3]
        x2 = z_pre[..., 1::3]
        v = z_pre[..., 2::3]
        z_pre = torch.concat([x1, x2, v], dim=-1)
        return z_pre
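

# Added illustrative sketch (not part of the original vortex module; the
# helper name below is hypothetical): for a 3-D (non-cached) input whose
# channels arrive interleaved as (x1_0, x2_0, v_0, x1_1, x2_1, v_1), the
# output groups them per projection along dim 1.
def _example_interleave():
    z = torch.arange(6).reshape(1, 6, 1)
    out = interleave(z)
    assert out.flatten().tolist() == [0, 3, 1, 4, 2, 5]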


def column_split(x, num_heads, head_size):
    """Split a tensor with `num_heads` alongside the head dimension, instead of
    across heads. Fixed to three projections
    """
    # FIXME: merge cases
    if len(x.shape) == 2:
        x_reshaped = x.reshape(
            x.shape[0],
            num_heads,
            3 * head_size,
        )

        x2, x1, v = (
            x_reshaped[..., :head_size],
            x_reshaped[..., head_size : 2 * head_size],
            x_reshaped[..., 2 * head_size :],
        )
        x2, x1, v = (
            x2.reshape(x2.shape[0], -1),
            x1.reshape(x1.shape[0], -1),
            v.reshape(v.shape[0], -1),
        )
        return x2, x1, v
    else:
        x = x.reshape(
            x.shape[0],
            num_heads,
            3 * head_size,
            x.shape[2],
        )
        x2, x1, v = (
            x[:, :, :head_size],
            x[
                :,
                :,
                head_size : 2 * head_size,
            ],
            x[:, :, 2 * head_size :],
        )
        x2, x1, v = (
            x2.reshape(x2.shape[0], -1, x2.shape[-1]),
            x1.reshape(x1.shape[0], -1, x1.shape[-1]),
            v.reshape(v.shape[0], -1, v.shape[-1]),
        )
        return x2, x1, v
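

# Added illustrative sketch (not part of the original vortex module; the
# helper name below is hypothetical): with two heads and head_size=1, the
# (x2, x1, v) projections are pulled apart head by head rather than across
# the full feature width.
def _example_column_split():
    x = torch.arange(6).reshape(1, 6)  # (batch, num_heads * 3 * head_size)
    x2, x1, v = column_split(x, num_heads=2, head_size=1)
    assert x2.tolist() == [[0, 3]]
    assert x1.tolist() == [[1, 4]]
    assert v.tolist() == [[2, 5]]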


def load_checkpoint(model, checkpoint_path):
    if checkpoint_path is None:
        log.warning("Using random weights (dry-run)")
        return
    log.info(f"Loading {checkpoint_path}")

    # We must allowlist BytesIO because fp8-enabled checkpoints store this
    # type in Transformer Engine layers' _extra_state entries; without it,
    # weights_only=True refuses to load the checkpoint.
    import io

    torch.serialization.add_safe_globals([io.BytesIO])

    with torch.inference_mode():
        state = torch.load(
            checkpoint_path,
            # Override the device location recorded in the checkpoint
            # dictionary (e.g. a checkpoint may record "cuda:0" for all
            # layers, which would not work in the multi-GPU case).
            map_location="cpu",
            # This is an optimization: with mmap we don't read the whole
            # checkpoint dictionary from disk into CPU memory in one go;
            # instead, PyTorch loads only the relevant tensors into CPU
            # memory when we are about to copy them to the GPU.
            mmap=True,
            # Make sure PyTorch does not issue a warning about potential
            # security issues when unpickling.
            weights_only=True,
        )
        model.to_bfloat16_except_pr_lc(to_float32=True)

        model.custom_load_state_dict(state)

        model.to_bfloat16_except_pr_lc()


def move_to_device(module, device):
    """Recursively moves all parameters and buffers to the specified device."""
    for child in module.children():
        move_to_device(child, device)

    for param in module.parameters(recurse=False):
        if param.device != device:
            param.data = param.data.to(device)

    for buf in module.buffers(recurse=False):
        if buf.device != device:
            buf.data = buf.data.to(device)

    module.to(device)


def fixup_fp8_extra_states(module):
    """Recursively fixes device location of TE's Linear fp8 extra states."""
    for child in module.children():
        fixup_fp8_extra_states(child)

    # TE Linear loads its extra state onto the default "cuda" device, which
    # causes trouble when the layer has been moved to another GPU. Instead,
    # load the extra state onto the device of the module's own parameters.
    torch_load = torch.load

    def overridden_load(state, map_location):
        device = next(module.parameters()).device
        return torch_load(state, map_location=device)

    if hasattr(module, "fp8_meta"):
        log.debug(f"Reloading fp8 extra state to a proper device for {module}")
        from unittest.mock import patch

        with patch("torch.load", new=overridden_load):
            module.set_extra_state(module.get_extra_state())


def fixup_te_workspace():
    """TE uses single workspace tensor for all calls, disregarding that inputs
    may be on separate GPUs. This patches TE's Linear module to use per-device
    workspaces."""
    from functools import lru_cache

    @lru_cache
    def te_cublas_get_workspace_per_device(device):
        log.info(f"Fixup applied: Allocating cublas workspace for {device=}")
        import transformer_engine.pytorch.module.base as tebase

        with torch.cuda.device(device):
            tebase._cublas_workspace = None  # Force get_workspace() to reallocate tensor
            return tebase.get_workspace()

    def get_workspace():
        return te_cublas_get_workspace_per_device(torch.cuda.current_device())

    import transformer_engine.pytorch.module.linear as telinear

    telinear.get_workspace = get_workspace


def get_init_from_string(init_str):
    if isinstance(init_str, str):
        if init_str == "torch.nn.init.zeros_":
            return torch.nn.init.zeros_
        elif init_str == "torch.nn.init.xavier_uniform_":
            return torch.nn.init.xavier_uniform_
        elif init_str == "torch.nn.init.xavier_normal_":
            return torch.nn.init.xavier_normal_
        else:
            raise ValueError(f"Unrecognized init {init_str}")
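

# Added illustrative sketch (not part of the original vortex module; the
# helper name below is hypothetical): the string names map directly onto
# the corresponding torch.nn.init callables.
def _example_get_init_from_string():
    init_fn = get_init_from_string("torch.nn.init.zeros_")
    assert init_fn is torch.nn.init.zeros_
    weight = torch.empty(2, 2)
    init_fn(weight)
    assert weight.abs().sum().item() == 0.0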


def print_rank_0(message, debug=False, end="\n"):
    """Print from rank 0 only."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True, end=end)
    else:
        print(message, flush=True, end=end)


class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
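

# Added illustrative sketch (not part of the original vortex module; the
# helper name below is hypothetical): attribute access maps onto dict keys,
# and missing keys return None because __getattr__ is bound to dict.get.
def _example_dotdict():
    cfg = dotdict({"hidden_size": 4096})
    assert cfg.hidden_size == 4096
    assert cfg.missing_key is None
    cfg.num_heads = 32
    assert cfg["num_heads"] == 32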


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


class Lambda(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
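

# Added illustrative sketch (not part of the original vortex module; the
# helper name below is hypothetical): Lambda wraps an arbitrary callable so
# it can be dropped into module containers such as nn.Sequential.
def _example_lambda():
    double = Lambda(lambda t: 2 * t)
    out = double(torch.ones(3))
    assert torch.equal(out, torch.full((3,), 2.0))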


class VocabUtility:
    """Split the vocabulary into `world_size` chunks amd return the
    first and last index of the vocabulary belonging to the `rank`
    partition: Note that indices in [first, last]"""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
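

# Added illustrative sketch (not part of the original vortex module; the
# helper name below is hypothetical): a 50304-token vocabulary split over
# 4 ranks gives rank 1 the half-open range [12576, 25152).
def _example_vocab_utility():
    first, last = VocabUtility.vocab_range_from_global_vocab_size(50304, rank=1, world_size=4)
    assert (first, last) == (12576, 25152)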