import logging

import torch

log = logging.getLogger(__name__)
|
|
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim over world_size processes.

    The split may be uneven: after dividing dim into units of multiple_of, the first
    (dim // multiple_of) % world_size ranks each receive one extra unit.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
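# Illustrative example (values assumed for the sake of the sketch): with dim=10,
# world_size=4, multiple_of=1, the per-rank dims are 3, 3, 2, 2 for local_rank 0..3,
# since 10 // 4 == 2 with remainder 2, so the first two ranks each get one extra unit.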
|
|
def grab_first_if_tuple(x):
    if x.__class__.__name__ == "tuple":
        return x[0]
    else:
        return x
|
|
def interleave(z_pre):
    """Regroup x1/x2/v projections stored interleaved every third channel into
    contiguous blocks (along dim 1 for 3D inputs, the last dim otherwise)."""
    if len(z_pre.shape) == 3:
        x1 = z_pre[:, 0::3, :]
        x2 = z_pre[:, 1::3, :]
        v = z_pre[:, 2::3, :]
        return torch.cat([x1, x2, v], dim=1)
    else:
        x1 = z_pre[..., 0::3]
        x2 = z_pre[..., 1::3]
        v = z_pre[..., 2::3]
        return torch.cat([x1, x2, v], dim=-1)
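# Illustrative example (assumed channel layout): for projections stored interleaved as
# [x1_0, x2_0, v_0, x1_1, x2_1, v_1] along the split dimension, interleave() regroups
# them into contiguous blocks [x1_0, x1_1, x2_0, x2_1, v_0, v_1].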
|
|
def column_split(x, num_heads, head_size):
    """Split a fused projection within each of `num_heads` heads, along the head
    dimension rather than across heads. Fixed to three projections (x2, x1, v).
    """
    if len(x.shape) == 2:
        x_reshaped = x.reshape(
            x.shape[0],
            num_heads,
            3 * head_size,
        )
        x2, x1, v = (
            x_reshaped[..., :head_size],
            x_reshaped[..., head_size : 2 * head_size],
            x_reshaped[..., 2 * head_size :],
        )
        x2, x1, v = (
            x2.reshape(x2.shape[0], -1),
            x1.reshape(x1.shape[0], -1),
            v.reshape(v.shape[0], -1),
        )
        return x2, x1, v
    else:
        x = x.reshape(
            x.shape[0],
            num_heads,
            3 * head_size,
            x.shape[2],
        )
        x2, x1, v = (
            x[:, :, :head_size],
            x[:, :, head_size : 2 * head_size],
            x[:, :, 2 * head_size :],
        )
        x2, x1, v = (
            x2.reshape(x2.shape[0], -1, x2.shape[-1]),
            x1.reshape(x1.shape[0], -1, x1.shape[-1]),
            v.reshape(v.shape[0], -1, v.shape[-1]),
        )
        return x2, x1, v
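# Illustrative example (assumed shapes): for x of shape (batch, num_heads * 3 * head_size),
# column_split(x, num_heads, head_size) returns x2, x1, v, each of shape
# (batch, num_heads * head_size); the 3D branch does the same while preserving the
# trailing dimension.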
|
|
def load_checkpoint(model, checkpoint_path):
    if checkpoint_path is None:
        log.warning("Using random weights (dry-run)")
        return
    log.info(f"Loading {checkpoint_path}")

    import io

    # Allow io.BytesIO objects stored in the checkpoint when loading with weights_only=True.
    torch.serialization.add_safe_globals([io.BytesIO])

    with torch.inference_mode():
        state = torch.load(
            checkpoint_path,
            map_location="cpu",
            mmap=True,
            weights_only=True,
        )
        model.to_bfloat16_except_pr_lc(to_float32=True)
        model.custom_load_state_dict(state)
        model.to_bfloat16_except_pr_lc()
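# Usage sketch (hypothetical model class and path; assumes the model exposes the
# to_bfloat16_except_pr_lc and custom_load_state_dict hooks used above):
#     model = MyModel(config)
#     load_checkpoint(model, "/path/to/checkpoint.pt")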
|
|
def move_to_device(module, device):
    """Recursively moves all parameters and buffers to the specified device."""
    for child in module.children():
        move_to_device(child, device)

    for param in module.parameters(recurse=False):
        if param.device != device:
            param.data = param.data.to(device)

    for buf in module.buffers(recurse=False):
        if buf.device != device:
            buf.data = buf.data.to(device)

    module.to(device)
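# Example: move_to_device(model, torch.device("cuda:0")) moves the whole module tree;
# passing a submodule moves only that subtree.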
|
|
def fixup_fp8_extra_states(module):
    """Recursively fixes device location of TE's Linear fp8 extra states."""
    for child in module.children():
        fixup_fp8_extra_states(child)

    torch_load = torch.load

    def overridden_load(state, map_location):
        # Ignore the requested map_location and load onto the module's own device.
        device = next(module.parameters()).device
        return torch_load(state, map_location=device)

    if hasattr(module, "fp8_meta"):
        log.debug(f"Reloading fp8 extra state to a proper device for {module}")
        from unittest.mock import patch

        with patch("torch.load", new=overridden_load):
            module.set_extra_state(module.get_extra_state())
|
|
def fixup_te_workspace():
    """TE uses a single workspace tensor for all calls, disregarding that inputs
    may be on separate GPUs. This patches TE's Linear module to use per-device
    workspaces."""
    from functools import lru_cache

    @lru_cache
    def te_cublas_get_workspace_per_device(device):
        log.info(f"Fixup applied: Allocating cublas workspace for {device=}")
        import transformer_engine.pytorch.module.base as tebase

        with torch.cuda.device(device):
            tebase._cublas_workspace = None
            return tebase.get_workspace()

    def get_workspace():
        return te_cublas_get_workspace_per_device(torch.cuda.current_device())

    import transformer_engine.pytorch.module.linear as telinear

    telinear.get_workspace = get_workspace
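# Usage sketch (assumption: call once per process, before running TE Linear layers that
# are placed on different GPUs):
#     fixup_te_workspace()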
|
|
def get_init_from_string(init_str):
    if isinstance(init_str, str):
        if init_str == "torch.nn.init.zeros_":
            return torch.nn.init.zeros_
        elif init_str == "torch.nn.init.xavier_uniform_":
            return torch.nn.init.xavier_uniform_
        elif init_str == "torch.nn.init.xavier_normal_":
            return torch.nn.init.xavier_normal_
        else:
            raise ValueError(f"Unrecognized init {init_str}")
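# Example: get_init_from_string("torch.nn.init.zeros_") returns torch.nn.init.zeros_,
# which can then be called on a parameter tensor. Non-string inputs fall through and
# the function returns None.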
|
|
def print_rank_0(message, debug=False, end="\n"):
    """Print from rank 0 only."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True, end=end)
    else:
        print(message, flush=True, end=end)
|
|
class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
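# Example: cfg = dotdict({"hidden_size": 1024}); cfg.hidden_size == 1024. Missing keys
# resolve to None (dict.get semantics) instead of raising AttributeError.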
|
|
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
|
|
def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
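# Example: divide(4096, 8) == 512, while divide(4096, 7) raises an AssertionError.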
|
|
class Lambda(torch.nn.Module):
    """Wraps a callable as an nn.Module so it can be used inside module containers."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
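# Example: Lambda(torch.nn.functional.gelu) behaves like a GELU layer and can be used
# inside an nn.Sequential.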
|
|
class VocabUtility:
    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
    partition. Note that indices are in [first, last)."""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
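# Example: with a global vocabulary of 512 split across world_size=4, rank 1 owns the
# half-open range [128, 256): VocabUtility.vocab_range_from_global_vocab_size(512, 1, 4) == (128, 256).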