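"""Test FusedAddRMSNorm under sequence parallelism.

Runs the same fused residual-add + RMS norm through a sequence-parallel
(sharded) module and an unsharded reference, then checks that the forward
outputs and the gradients for x, residual, and weight all match.

Requires at least two ranks, e.g.:

    torchrun --nproc-per-node=2 --local-ranks-filter 0 -m pytest \
        test_rms_norm_sequence_parallel.py
"""
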
import random

import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed.tensor.placement_types import Replicate

import activation
from activation.parallel_style import ResidualSequenceParallel

from .utils import assert_close

DTYPES = [torch.float32]
NUM_TOKENS = [512]  # Arbitrary values for testing
SEQUENCE_DIMS = [0, 1]  # 0 is for [T, D] (packed), 1 is for [B, S, D]
D = [16]  # Arbitrary values for testing
SEEDS = [0]


@pytest.fixture(scope="session", autouse=True)
def init_dist():
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for sequence parallel")

    try:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        pytest.skip("Failed to initialize torch.distributed")

    if dist.get_world_size() < 2:
        # Skipping here bypasses the teardown after `yield`, so clean up the
        # process group before bailing out.
        dist.destroy_process_group()
        pytest.skip("Need at least 2 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=2 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`")

    yield
    dist.destroy_process_group()


class Model(torch.nn.Module):

    def __init__(self, num_tokens, d) -> None:
        super().__init__()
        self.fused_add_rms_norm = activation.layers.FusedAddRMSNorm(d)

    def forward(
        self, x: torch.Tensor, residual: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # FusedAddRMSNorm returns (normalized_output, residual_add_output).
        return self.fused_add_rms_norm(x, residual)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("sequence_dim", SEQUENCE_DIMS)
@pytest.mark.parametrize("x_requires_grad", [True, False])
@pytest.mark.parametrize("residual_requires_grad", [True, False])
def test_fused_add_rms_norm_sequence_parallel(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    sequence_dim: int,
    x_requires_grad: bool,
    residual_requires_grad: bool,
) -> None:
    if num_tokens % dist.get_world_size() != 0:
        # It hangs at `y.full_tensor()` if not divisible
        pytest.skip("num_tokens must be divisible by world_size for sharding")

    if not x_requires_grad and not residual_requires_grad:
        pytest.skip("For now, at least one of x or residual must require grad")

    random.seed(seed)
    torch.manual_seed(seed)

    num_ranks = dist.get_world_size()
    mesh = init_device_mesh("cuda", (num_ranks, ), mesh_dim_names=("shard", ))

    match sequence_dim:
        case 0:
            x_shape = (num_tokens, d)
        case 1:
            BATCH_SIZE = 2
            x_shape = (BATCH_SIZE, num_tokens, d)
        case _:
            raise ValueError(f"Invalid sequence_dim: {sequence_dim}")

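    # torch.manual_seed above is identical on every rank, so all ranks
    # construct the same full inputs.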
    # `.cuda()` returns non-leaf copies of the CPU tensors, so their `.grad`
    # is only populated if we explicitly retain it below.
    x = torch.randn(x_shape, dtype=dtype, requires_grad=x_requires_grad).cuda()
    residual = torch.randn(x_shape,
                           dtype=dtype,
                           requires_grad=residual_requires_grad).cuda()
    weight = torch.ones(d, dtype=dtype, requires_grad=True).cuda()

    if x_requires_grad:
        x.retain_grad()
    if residual_requires_grad:
        residual.retain_grad()
    weight.retain_grad()

    # Detached copies of x, residual, and weight for the unsharded reference.
    x_ref = x.detach().clone().requires_grad_(True)
    residual_ref = residual.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)

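    # Sharded path: apply the ResidualSequenceParallel plan to the
    # FusedAddRMSNorm submodule via parallelize_module.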
    model_sharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_sharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight)
    parallelize_module(model_sharded, mesh, {
        "fused_add_rms_norm":
        ResidualSequenceParallel(sequence_dim=sequence_dim)
    })

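    # Wrap the full tensors as replicated DTensors on the mesh; the
    # ResidualSequenceParallel plan is then expected to reshard them along
    # `sequence_dim`.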
    x_replicate = DTensor.from_local(
        x,
        placements=(Replicate(), ),
        device_mesh=mesh,
    )
    residual_replicate = DTensor.from_local(
        residual,
        placements=(Replicate(), ),
        device_mesh=mesh,
    )

    y, add_output = model_sharded(x_replicate, residual_replicate)

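    # `full_tensor()` gathers/replicates the distributed output so each rank
    # can compare it against the unsharded reference.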
    y_from_sharded = y.full_tensor()
    add_output_from_sharded = add_output.full_tensor()

    model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_unsharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight_ref)

    y_from_unsharded, add_output_from_unsharded = model_unsharded(
        x_ref, residual_ref)

    assert_close(y_from_sharded, y_from_unsharded)
    assert_close(add_output_from_sharded, add_output_from_unsharded)

    # Backward
    y_grad = torch.randn_like(y_from_unsharded)
    add_output_grad = torch.randn_like(add_output_from_unsharded)

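    # Drive both outputs with the same upstream grads and backprop through a
    # single scalar so sharded and unsharded grads are directly comparable.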
    (y_grad * y_from_sharded +
     add_output_grad * add_output_from_sharded).sum().backward()
    (y_grad * y_from_unsharded +
     add_output_grad * add_output_from_unsharded).sum().backward()

    weight_grad_from_sharded = (
        model_sharded.fused_add_rms_norm.weight.grad.full_tensor())
    weight_grad_from_unsharded = model_unsharded.fused_add_rms_norm.weight.grad

    # `.grad` must be populated exactly for the inputs that were marked as
    # requiring grad.
    assert (x.grad is None) ^ x_requires_grad
    assert (residual.grad is None) ^ residual_requires_grad

    if x_requires_grad:
        assert_close(x.grad, x_ref.grad)
    if residual_requires_grad:
        assert_close(residual.grad, residual_ref.grad)

    assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)