File size: 5,845 Bytes
151bb5a a35a092 151bb5a a35a092 151bb5a a35a092 151bb5a a35a092 151bb5a a35a092 151bb5a a35a092 151bb5a a35a092 151bb5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import random
import sys
from collections.abc import Sequence
import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch.distributed.tensor.placement_types import (Partial, Placement,
Replicate, Shard)
import activation
from .utils import assert_close, opcheck
# Parametrization grids for the sequence-parallel test in this module.
DTYPES: list[torch.dtype] = [torch.float32]
# Length of the sharded sequence dimension (arbitrary value for testing).
NUM_TOKENS: list[int] = [512]
# Which input dimension is the sequence dimension:
#   0 -> inputs are [T, D] (packed tokens)
#   1 -> inputs are [B, S, D]
SEQUENCE_DIMS: list[int] = [0, 1]
# Hidden size (arbitrary value for testing).
D: list[int] = [16]
SEEDS: list[int] = [0]
from activation.parallel_style import ResidualSequenceParallel
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    """Create the NCCL process group for the session and destroy it afterwards.

    Every test in the session is skipped when torch is older than 2.8, when
    the process group cannot be initialised, or when fewer than two ranks are
    available.
    """
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for sequence parallel")

    try:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        # init may have succeeded before set_device failed; release the group
        # because the teardown after ``yield`` will not run on skip.
        if dist.is_initialized():
            dist.destroy_process_group()
        pytest.skip("Failed to initialize torch.distributed")

    if dist.get_world_size() < 2:
        # Skipping raises out of the fixture, so the teardown after ``yield``
        # is never reached; destroy the group we just created first.
        dist.destroy_process_group()
        pytest.skip("Need at least 2 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=2 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`")

    yield
    dist.destroy_process_group()
class Model(torch.nn.Module):
    """Thin wrapper holding a single ``FusedAddRMSNorm`` submodule.

    The submodule attribute name ``fused_add_rms_norm`` is the key the test
    uses in its ``parallelize_module`` plan.
    """

    def __init__(self, num_tokens, d) -> None:
        # ``num_tokens`` is accepted for call-site symmetry but not used.
        super().__init__()
        self.fused_add_rms_norm = activation.layers.FusedAddRMSNorm(d)

    def forward(self, x: torch.Tensor,
                residual: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the fused add + RMSNorm layer.

        Returns the pair the callers unpack as ``(y, add_output)``;
        ``add_output`` is presumably ``x + residual`` — confirm against
        ``FusedAddRMSNorm``.
        """
        return self.fused_add_rms_norm(x, residual)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("sequence_dim", SEQUENCE_DIMS)
@pytest.mark.parametrize("x_requires_grad", [True, False])
@pytest.mark.parametrize("residual_requires_grad", [True, False])
def test_fused_add_rms_norm_sequence_parallel(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
sequence_dim: int,
x_requires_grad: bool,
residual_requires_grad: bool,
) -> None:
if num_tokens % dist.get_world_size() != 0:
# It hangs at `y.full_tensor()` if not divisible
pytest.skip("num_tokens must be divisible by world_size for sharding")
if not x_requires_grad and not residual_requires_grad:
pytest.skip("For now, at least one of x or residual must require grad")
random.seed(seed)
torch.manual_seed(seed)
num_ranks = dist.get_world_size()
rank = dist.get_rank()
mesh = init_device_mesh("cuda", (num_ranks, ), mesh_dim_names=("shard", ))
match sequence_dim:
case 0:
x_shape = (num_tokens, d)
case 1:
BATCH_SIZE = 2
x_shape = (BATCH_SIZE, num_tokens, d)
case _:
raise ValueError(f"Invalid sequence_dim: {sequence_dim}")
x = torch.randn(x_shape, dtype=dtype, requires_grad=x_requires_grad).cuda()
residual = torch.randn(x_shape,
dtype=dtype,
requires_grad=residual_requires_grad).cuda()
weight = torch.ones(d, dtype=dtype, requires_grad=True).cuda()
eps = 1e-05
if x_requires_grad:
x.retain_grad()
if residual_requires_grad:
residual.retain_grad()
weight.retain_grad()
# Copy x, weight for reference
x_ref = x.detach().clone().requires_grad_(True)
residual_ref = residual.detach().clone().requires_grad_(True)
weight_ref = weight.detach().clone().requires_grad_(True)
model_sharded = Model(num_tokens, d).to(dtype=dtype).cuda()
model_sharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight)
parallelize_module(model_sharded, mesh, {
"fused_add_rms_norm":
ResidualSequenceParallel(sequence_dim=sequence_dim)
})
x_replicate = DTensor.from_local(
x,
placements=(Replicate(), ),
device_mesh=mesh,
)
residual_replicate = DTensor.from_local(
residual,
placements=(Replicate(), ),
device_mesh=mesh,
)
y, add_output = model_sharded(x_replicate, residual_replicate)
y_from_sharded = y.full_tensor()
add_output_from_sharded = add_output.full_tensor()
model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
model_unsharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight_ref)
y_from_unsharded, add_output_from_unsharded = model_unsharded(
x_ref, residual_ref)
assert_close(y_from_sharded, y_from_unsharded)
assert_close(add_output_from_sharded, add_output_from_unsharded)
# Backward
y_grad = torch.randn_like(y_from_unsharded)
add_output_grad = torch.randn_like(add_output_from_unsharded)
(y_grad * y_from_sharded +
add_output_grad * add_output_from_sharded).sum().backward()
(y_grad * y_from_unsharded +
add_output_grad * add_output_from_unsharded).sum().backward()
weight_grad_from_sharded = model_sharded.fused_add_rms_norm.weight.grad.full_tensor(
)
weight_grad_from_unsharded = model_unsharded.fused_add_rms_norm.weight.grad
assert (x.grad is None) ^ x_requires_grad
assert (residual.grad is None) ^ residual_requires_grad
if x_requires_grad:
assert_close(x.grad, x_ref.grad)
if residual_requires_grad:
assert_close(residual.grad, residual_ref.grad)
assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
|