# Copied verbatim from vortex
import torch
import logging
log = logging.getLogger(__name__)
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
"""Get the dim for the local rank derived from splitting dim on world_size processes.
The split may not be even across the world_size processes.
"""
multiple = dim // multiple_of
div = multiple // world_size
mod = multiple % world_size
local_multiple = div + int(local_rank < mod)
return local_multiple * multiple_of
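# Illustrative example (not part of the original vortex file): splitting
# dim=10 across world_size=4 gives the lower ranks one extra unit each:
#
#   >>> [get_dim_for_local_rank(10, 4, rank) for rank in range(4)]
#   [3, 3, 2, 2]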
def grab_first_if_tuple(x):
    # Exact type check: tuple subclasses (e.g. named tuples) pass through as-is.
    if type(x) is tuple:
        return x[0]
    else:
        return x
def interleave(z_pre):
    """Regroup interleaved projections [x1, x2, v, x1, x2, v, ...] into
    contiguous blocks [x1..., x2..., v...] along the projection dimension."""
    if len(z_pre.shape) == 3:  # non-cached path: projections along dim 1
x1 = z_pre[:, 0::3, :]
x2 = z_pre[:, 1::3, :]
v = z_pre[:, 2::3, :]
z_pre = torch.cat([x1, x2, v], dim=1)
return z_pre
else:
x1 = z_pre[..., 0::3]
x2 = z_pre[..., 1::3]
v = z_pre[..., 2::3]
        z_pre = torch.cat([x1, x2, v], dim=-1)
return z_pre
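# Illustrative example (not from the original file): for a (batch, channels,
# seq) tensor whose channel dim holds interleaved triplets, the function
# regroups them into contiguous x1/x2/v blocks:
#
#   >>> z = torch.arange(6).reshape(1, 6, 1)   # channels [0, 1, 2, 3, 4, 5]
#   >>> interleave(z).flatten().tolist()
#   [0, 3, 1, 4, 2, 5]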
def column_split(x, num_heads, head_size):
"""Split a tensor with `num_heads` alongside the head dimension, instead of
across heads. Fixed to three projections
"""
# FIXME: merge cases
if len(x.shape) == 2:
x_reshaped = x.reshape(
x.shape[0],
num_heads,
3 * head_size,
)
x2, x1, v = (
x_reshaped[..., :head_size],
x_reshaped[..., head_size : 2 * head_size],
x_reshaped[..., 2 * head_size :],
)
x2, x1, v = (
x2.reshape(x2.shape[0], -1),
x1.reshape(x1.shape[0], -1),
v.reshape(v.shape[0], -1),
)
return x2, x1, v
else:
x = x.reshape(
x.shape[0],
num_heads,
3 * head_size,
x.shape[2],
)
x2, x1, v = (
x[:, :, :head_size],
x[
:,
:,
head_size : 2 * head_size,
],
x[:, :, 2 * head_size :],
)
x2, x1, v = (
x2.reshape(x2.shape[0], -1, x2.shape[-1]),
x1.reshape(x1.shape[0], -1, x1.shape[-1]),
v.reshape(v.shape[0], -1, v.shape[-1]),
)
return x2, x1, v
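# Illustrative example (not from the original file): with num_heads=1 and
# head_size=2, each head's 3 * head_size block is sliced into (x2, x1, v):
#
#   >>> x = torch.arange(6).reshape(1, 6)
#   >>> [t.tolist() for t in column_split(x, num_heads=1, head_size=2)]
#   [[[0, 1]], [[2, 3]], [[4, 5]]]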
def load_checkpoint(model, checkpoint_path):
if checkpoint_path is None:
log.warning("Using random weights (dry-run)")
return
log.info(f"Loading {checkpoint_path}")
    # We must allowlist BytesIO, as fp8-enabled checkpoints store this type
    # in Transformer Engine layers' _extra keys; otherwise, loading with
    # weights_only=True fails.
import io
torch.serialization.add_safe_globals([io.BytesIO])
with torch.inference_mode():
state = torch.load(
checkpoint_path,
            # Make sure we override the device location specified in the
            # checkpoint dictionary (e.g. checkpoints may record "cuda:0"
            # as the location for all layers, which would not work in the
            # multi-GPU case).
map_location="cpu",
            # This is an optimization: with mmap, we don't read the whole
            # checkpoint dictionary from disk into CPU memory in one go;
            # instead, PyTorch only loads the relevant tensors into CPU
            # memory when we are about to copy them to the GPU.
mmap=True,
# Make sure PyTorch is not issuing a warning regarding potential
# security issues.
weights_only=True,
)
model.to_bfloat16_except_pr_lc(to_float32=True)
model.custom_load_state_dict(state)
model.to_bfloat16_except_pr_lc()
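# Usage sketch (the path is hypothetical; `model` must provide the
# custom_load_state_dict and to_bfloat16_except_pr_lc methods used above):
#
#   >>> load_checkpoint(model, "path/to/checkpoint.pt")
#   >>> load_checkpoint(model, None)  # dry-run with random weights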
def move_to_device(module, device):
"""Recursively moves all parameters and buffers to the specified device."""
for child in module.children():
move_to_device(child, device)
for param in module.parameters(recurse=False):
if param.device != device:
param.data = param.data.to(device)
for buf in module.buffers(recurse=False):
if buf.device != device:
buf.data = buf.data.to(device)
module.to(device)
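# Usage sketch (illustrative): move an entire model, parameters and buffers
# included, to a given device:
#
#   >>> move_to_device(model, torch.device("cuda:0"))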
def fixup_fp8_extra_states(module):
"""Recursively fixes device location of TE's Linear fp8 extra states."""
for child in module.children():
fixup_fp8_extra_states(child)
    # TE Linear uses the default "cuda" device to load its extra state, which
    # causes trouble when the layer is moved to another GPU. Instead, the
    # extra state should be loaded using the parameters' device.
torch_load = torch.load
    def overridden_load(state, map_location):
device = next(module.parameters()).device
return torch_load(state, map_location=device)
if hasattr(module, "fp8_meta"):
log.debug(f"Reloading fp8 extra state to a proper device for {module}")
from unittest.mock import patch
        with patch("torch.load", new=overridden_load):
module.set_extra_state(module.get_extra_state())
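# Usage sketch (illustrative): apply after moving a Transformer Engine model
# to its target device, so fp8 extra states are reloaded onto the parameters'
# device:
#
#   >>> move_to_device(model, torch.device("cuda:1"))
#   >>> fixup_fp8_extra_states(model)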
def fixup_te_workspace():
"""TE uses single workspace tensor for all calls, disregarding that inputs
may be on separate GPUs. This patches TE's Linear module to use per-device
workspaces."""
from functools import lru_cache
@lru_cache
def te_cublas_get_workspace_per_device(device):
log.info(f"Fixup applied: Allocating cublas workspace for {device=}")
import transformer_engine.pytorch.module.base as tebase
with torch.cuda.device(device):
tebase._cublas_workspace = None # Force get_workspace() to reallocate tensor
return tebase.get_workspace()
def get_workspace():
return te_cublas_get_workspace_per_device(torch.cuda.current_device())
import transformer_engine.pytorch.module.linear as telinear
telinear.get_workspace = get_workspace
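# Usage sketch (illustrative): call once at startup, before running TE Linear
# layers on multiple devices; the patched get_workspace() then allocates one
# cuBLAS workspace per device via the lru_cache above.
#
#   >>> fixup_te_workspace()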
def get_init_from_string(init_str):
    if isinstance(init_str, str):
if init_str == "torch.nn.init.zeros_":
return torch.nn.init.zeros_
elif init_str == "torch.nn.init.xavier_uniform_":
return torch.nn.init.xavier_uniform_
elif init_str == "torch.nn.init.xavier_normal_":
return torch.nn.init.xavier_normal_
else:
raise ValueError(f"Unrecognized init {init_str}")
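# Illustrative example (not from the original file): resolving an init string
# to the corresponding torch.nn.init function and applying it:
#
#   >>> init_fn = get_init_from_string("torch.nn.init.zeros_")
#   >>> init_fn(torch.empty(2, 2)).tolist()
#   [[0.0, 0.0], [0.0, 0.0]]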
def print_rank_0(message, debug=False, end="\n"):
"""Print from rank 0 only."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True, end=end)
else:
print(message, flush=True, end=end)
class dotdict(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
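# Illustrative example (not from the original file): note that because
# __getattr__ is dict.get, missing keys yield None instead of raising
# AttributeError:
#
#   >>> cfg = dotdict({"lr": 1e-3})
#   >>> cfg.lr
#   0.001
#   >>> cfg.missing is None
#   True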
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
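# Illustrative example (not from the original file): divide() returns the
# quotient for exact divisions and asserts otherwise:
#
#   >>> divide(12, 4)
#   3
#   >>> divide(10, 4)  # AssertionError: 10 is not divisible by 4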
class Lambda(torch.nn.Module):
def __init__(self, func):
super().__init__()
self.func = func
def forward(self, x):
return self.func(x)
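# Illustrative example (not from the original file): wrapping an arbitrary
# function as a module, e.g. for use inside torch.nn.Sequential:
#
#   >>> scale = Lambda(lambda x: 2 * x)
#   >>> scale(torch.ones(2)).tolist()
#   [2.0, 2.0]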
class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [first, last]"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
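# Illustrative example (not from the original file): rank 1 of 4 owns the
# second quarter of a 50,000-token vocabulary:
#
#   >>> VocabUtility.vocab_range_from_global_vocab_size(50000, rank=1, world_size=4)
#   (12500, 25000)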