| | """
|
| | Distributed basic functions.
|
| | """
|
| |
|
| | import os
|
| | import torch
|
| | from torch import nn
|
| | import torch.distributed as dist
|
| | from torch.nn.parallel import DistributedDataParallel
|
| | from torch.distributed.fsdp._common_utils import _is_fsdp_flattened


def get_global_rank() -> int:
    """
    Get the global rank, the global index of the GPU.
    """
    return int(os.environ.get("RANK", "0"))


def get_local_rank() -> int:
    """
    Get the local rank, the local index of the GPU.
    """
    return int(os.environ.get("LOCAL_RANK", "0"))


def get_world_size() -> int:
    """
    Get the world size, the total number of GPUs.
    """
    return int(os.environ.get("WORLD_SIZE", "1"))


def get_device() -> torch.device:
    """
    Get the device of the current rank.
    """
    return torch.device("cuda", get_local_rank())
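
# NOTE: RANK, LOCAL_RANK and WORLD_SIZE are expected to be injected by the
# launcher (e.g. `torchrun --nproc_per_node=8 train.py`); the defaults above
# keep these helpers usable in a plain single-process run.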


def barrier_if_distributed(*args, **kwargs):
    """
    Synchronize all processes when running in a distributed context.
    """
    if dist.is_initialized():
        return dist.barrier(*args, **kwargs)


def init_torch(cudnn_benchmark=True):
    """
    Common PyTorch initialization configuration.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.cuda.set_device(get_local_rank())
    dist.init_process_group(
        backend="nccl",
        rank=get_global_rank(),
        world_size=get_world_size(),
    )
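
# A typical launch is a sketch like `torchrun --nproc_per_node=8 train.py`.
# Because no init_method is passed, init_process_group falls back to the
# "env://" rendezvous, so the launcher must also export MASTER_ADDR and
# MASTER_PORT.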


def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
    """
    Wrap a module with DistributedDataParallel on the current local rank.
    """
    return DistributedDataParallel(
        module=module,
        device_ids=[get_local_rank()],
        output_device=get_local_rank(),
        **kwargs,
    )
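
# Usage sketch (assumes init_torch() has already created the process group;
# `MyModel` is a placeholder, not part of this module):
#
#     init_torch()
#     model = MyModel().to(get_device())
#     model = convert_to_ddp(model, find_unused_parameters=False)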


def meta_param_init_fn(module: nn.Module) -> None:
    """
    Used for a model initialized on the meta device.
    Initializes meta params/buffers with empty tensors.
    Numerical correctness does not matter here: FSDP will sync the
    param/buffer state from rank 0 to the other ranks.
    """
    with torch.no_grad():
        for submodule in module.modules():
            for param_name, param in submodule.named_parameters(recurse=False):
                if not _is_fsdp_flattened(param) and param.is_meta:
                    materialized_param = nn.Parameter(torch.empty_like(param, device="cpu"))
                    setattr(submodule, param_name, materialized_param)
            for buffer_name, buffer in submodule.named_buffers(recurse=False):
                if not _is_fsdp_flattened(buffer) and buffer.is_meta:
                    materialized_buffer = torch.empty_like(buffer, device="cpu")
                    setattr(submodule, buffer_name, materialized_buffer)
    torch.cuda.empty_cache()
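
# Sketch of how this hook can be wired into FSDP (assumes the model was built
# under torch.device("meta") and that rank 0 holds the real weights):
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     model = FSDP(
#         model,
#         param_init_fn=meta_param_init_fn,
#         sync_module_states=True,   # broadcast rank-0 states to other ranks
#         device_id=get_local_rank(),
#     )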


def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
    """
    Materialize meta-device buffers that are not persistent in the state_dict.
    Handles special cases such as RotaryEmbedding.freqs.
    """
    with torch.no_grad():
        for submodule in module.modules():
            if hasattr(submodule, "freqs"):
                freqs = getattr(submodule, "freqs")
                if isinstance(freqs, torch.Tensor) and freqs.is_meta:
                    dim = submodule.dim

                    def rope_params(max_seq_len, dim, theta=10000):
                        assert dim % 2 == 0
                        freqs = torch.outer(
                            torch.arange(max_seq_len),
                            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
                        )
                        freqs = torch.polar(torch.ones_like(freqs), freqs)
                        return freqs

                    # NOTE: hidden size and head count are hard-coded here and
                    # override the `submodule.dim` read above.
                    dim = 5120
                    num_heads = 40
                    d = dim // num_heads
                    freqs_tensor = torch.cat([
                        rope_params(1024, d - 4 * (d // 6)),
                        rope_params(1024, 2 * (d // 6)),
                        rope_params(1024, 2 * (d // 6)),
                    ], dim=1).to(dtype=torch.cfloat, device="cpu")

                    setattr(submodule, "freqs", freqs_tensor)
                    print(f"Successfully materialized freqs for {submodule.__class__.__name__}")

    assert not any(b.is_meta for n, b in module.named_buffers())
    return module
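
# Usage sketch: in loading paths where non-persistent buffers (e.g.
# RotaryEmbedding.freqs) are not covered by the checkpoint, they can remain on
# the meta device and must be rebuilt explicitly:
#
#     model = meta_non_persistent_buffer_init_fn(model)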