import random
import sys
from collections.abc import Sequence

import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch.distributed.tensor.placement_types import (Partial, Placement,
                                                      Replicate, Shard)

import activation

from .utils import assert_close, opcheck

DTYPES = [torch.float32]
NUM_TOKENS = [512]
SEQUENCE_DIMS = [0, 1]
D = [16]
SEEDS = [0]

from activation.parallel_style import ResidualSequenceParallel
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

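
# One NCCL process group for the whole session; the tests require launching
# with torchrun and at least two ranks.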
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for sequence parallel")
        return

    try:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        pytest.skip("Failed to initialize torch.distributed")

    if dist.get_world_size() < 2:
        pytest.skip("Need at least 2 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=2 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`")

    yield
    dist.destroy_process_group()

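
# Minimal wrapper module so parallelize_module can target the layer by its
# submodule name ("fused_add_rms_norm").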
class Model(torch.nn.Module):

    def __init__(self, num_tokens, d) -> None:
        super().__init__()
        self.fused_add_rms_norm = activation.layers.FusedAddRMSNorm(d)

    def forward(
            self, x: torch.Tensor,
            residual: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        return self.fused_add_rms_norm(x, residual=residual)

| | @pytest.mark.parametrize("num_tokens", NUM_TOKENS) |
| | @pytest.mark.parametrize("d", D) |
| | @pytest.mark.parametrize("dtype", DTYPES) |
| | @pytest.mark.parametrize("seed", SEEDS) |
| | @pytest.mark.parametrize("sequence_dim", SEQUENCE_DIMS) |
| | @pytest.mark.parametrize("x_requires_grad", [True, False]) |
| | @pytest.mark.parametrize("residual_requires_grad", [True, False]) |
| | def test_fused_add_rms_norm_sequence_parallel( |
| | num_tokens: int, |
| | d: int, |
| | dtype: torch.dtype, |
| | seed: int, |
| | sequence_dim: int, |
| | x_requires_grad: bool, |
| | residual_requires_grad: bool, |
| | ) -> None: |
| | if num_tokens % dist.get_world_size() != 0: |
| | |
| | pytest.skip("num_tokens must be divisible by world_size for sharding") |
| |
|
| | if not x_requires_grad and not residual_requires_grad: |
| | pytest.skip("For now, at least one of x or residual must require grad") |
| |
|
    random.seed(seed)
    torch.manual_seed(seed)

    num_ranks = dist.get_world_size()
    rank = dist.get_rank()
    mesh = init_device_mesh("cuda", (num_ranks, ), mesh_dim_names=("shard", ))

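    # 2-D (tokens, hidden) input when sharding dim 0; 3-D (batch, tokens,
    # hidden) input when sharding dim 1.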
    match sequence_dim:
        case 0:
            x_shape = (num_tokens, d)
        case 1:
            BATCH_SIZE = 2
            x_shape = (BATCH_SIZE, num_tokens, d)
        case _:
            raise ValueError(f"Invalid sequence_dim: {sequence_dim}")

    x = torch.randn(x_shape, dtype=dtype, requires_grad=x_requires_grad).cuda()
    residual = torch.randn(x_shape,
                           dtype=dtype,
                           requires_grad=residual_requires_grad).cuda()
    weight = torch.ones(d, dtype=dtype, requires_grad=True).cuda()
    eps = 1e-05

    if x_requires_grad:
        x.retain_grad()
    if residual_requires_grad:
        residual.retain_grad()
    weight.retain_grad()

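    # Detached, full-size copies used by the unsharded reference model.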
    x_ref = x.detach().clone().requires_grad_(True)
    residual_ref = residual.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)

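    # Sharded path: apply ResidualSequenceParallel and feed inputs sharded
    # along the sequence dimension.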
    model_sharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_sharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight)
    parallelize_module(model_sharded, mesh, {
        "fused_add_rms_norm":
        ResidualSequenceParallel(sequence_dim=sequence_dim)
    })

    x_sharded = DTensor.from_local(
        x.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
        placements=(Shard(sequence_dim), ),
        device_mesh=mesh,
    )
    residual_sharded = DTensor.from_local(
        residual.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
        placements=(Shard(sequence_dim), ),
        device_mesh=mesh,
    )

    y, add_output = model_sharded(x_sharded, residual_sharded)

    y_from_sharded = y.full_tensor()
    add_output_from_sharded = add_output.full_tensor()

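    # Reference path: run the same weights and full inputs through an
    # unsharded copy of the model.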
    model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_unsharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight_ref)

    y_from_unsharded, add_output_from_unsharded = model_unsharded(
        x_ref, residual_ref)

    assert_close(y_from_sharded, y_from_unsharded)
    assert_close(add_output_from_sharded, add_output_from_unsharded)

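    # Backward: use the same random cotangents for both outputs so the sharded
    # and unsharded gradients are directly comparable.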
    y_grad = torch.randn_like(y_from_unsharded)
    add_output_grad = torch.randn_like(add_output_from_unsharded)

    (y_grad * y_from_sharded +
     add_output_grad * add_output_from_sharded).sum().backward()
    (y_grad * y_from_unsharded +
     add_output_grad * add_output_from_unsharded).sum().backward()

    weight_grad_from_sharded = model_sharded.fused_add_rms_norm.weight.grad._local_tensor
    weight_grad_from_unsharded = model_unsharded.fused_add_rms_norm.weight.grad

    assert (x.grad is None) ^ x_requires_grad
    assert (residual.grad is None) ^ residual_requires_grad

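    # In the sharded path each rank only accumulates gradients for its local
    # sequence shard, so sum the per-rank gradients before comparing them with
    # the unsharded reference.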
    torch.distributed.all_reduce(weight_grad_from_sharded,
                                 op=torch.distributed.ReduceOp.SUM)

    if x.grad is not None:
        torch.distributed.all_reduce(x.grad, op=torch.distributed.ReduceOp.SUM)
        assert_close(x.grad, x_ref.grad)
    if residual.grad is not None:
        torch.distributed.all_reduce(residual.grad,
                                     op=torch.distributed.ReduceOp.SUM)
        assert_close(residual.grad, residual_ref.grad)

    assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)