File size: 9,846 Bytes
a402b9b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 | """
Test file to verify the correctness of parallel group calculations.
This test validates that the parallel group initialization creates the correct
groups for different parallelism configurations including:
- Tensor parallelism (TP)
- Pipeline parallelism (PP)
- Attention context parallelism (attn_cp)
- Attention data parallelism (attn_dp)
- MoE expert parallelism (EP)
- MoE data parallelism (moe_dp)
These tests call the ACTUAL initialize_model_parallel() function with mocked
distributed backend to verify the group construction logic.
## How These Tests Work
initialize_model_parallel() creates ALL groups for ALL ranks in a single call.
For example, when creating TP groups with tp_size=2 and world_size=8:
    group_ranks = [[0,1], [2,3], [4,5], [6,7]]  # ALL groups created
    _TP = init_model_parallel_group(group_ranks, local_rank, ...)
ALL ranks call this function and get the same complete group structure. Each rank
then figures out which specific group(s) it belongs to.
Our tests:
1. Mock the distributed backend (no real GPUs needed)
2. Mock init_model_parallel_group to capture the group_ranks parameter
3. Call the real initialize_model_parallel()
4. Verify group_ranks contains the expected complete group structure
We only need to simulate rank 0 because we're testing the group creation logic,
not the per-rank group membership logic.
"""
from __future__ import annotations
import sys
from unittest.mock import Mock, patch
import pytest
from sglang.test.ci.ci_register import register_cuda_ci
# Register this file with the CI harness under the
# "stage-b-test-small-1-gpu" suite.
# NOTE(review): est_time is presumably the expected runtime used for CI
# scheduling -- confirm the unit against sglang.test.ci.ci_register.
register_cuda_ci(est_time=5, suite="stage-b-test-small-1-gpu")
# Import the actual parallel_state module. pytest.importorskip returns the
# imported module, or skips the tests in this file when it cannot be imported.
parallel_state = pytest.importorskip("sglang.srt.distributed.parallel_state")
def test_parallel_group_construction_tp8_attn_cp2():
    """
    Test parallel group construction for 8 GPU configuration with:
    - tensor_model_parallel_size = 8
    - pipeline_model_parallel_size = 1
    - attention_context_model_parallel_size = 2

    Expected groups based on docstring example:
        1 tensor model-parallel group:
            [g0, g1, g2, g3, g4, g5, g6, g7]
        4 attention context-parallel groups:
            [g0, g4], [g1, g5], [g2, g6], [g3, g7]

    This test calls the ACTUAL initialize_model_parallel() and verifies the
    ``group_ranks`` it hands to init_model_parallel_group(). The
    torch.distributed queries, get_world_group() and
    init_model_parallel_group() are mocked, so no process group is created.

    Note: only rank 0 is simulated (get_rank() is patched to return 0); the
    assertions check the complete group structure captured from the call,
    not per-rank membership.
    """
    world_size = 8

    # group_name -> group_ranks, recorded by the stub below.
    created_groups = {}

    def mock_init_model_parallel_group(group_ranks, local_rank, backend, **kwargs):
        # Record the ranks under the group_name keyword the caller supplied.
        group_name = kwargs.get("group_name", "unknown")
        created_groups[group_name] = group_ranks
        # Return a stand-in group object exposing a device_group attribute.
        mock_group = Mock()
        mock_group.device_group = Mock()
        return mock_group

    # Mock the distributed backend and reset the module-level group handles
    # for the duration of the test; patch restores the originals on exit.
    with patch.object(parallel_state, "_WORLD", None), patch.object(
        parallel_state, "_TP", None
    ), patch.object(parallel_state, "_ATTN_CP", None), patch.object(
        parallel_state, "_ATTN_TP", None
    ), patch.object(
        parallel_state, "_PP", None
    ), patch(
        "torch.distributed.is_initialized", return_value=True
    ), patch(
        "torch.distributed.get_world_size", return_value=world_size
    ), patch(
        "torch.distributed.get_rank", return_value=0
    ), patch(
        "torch.distributed.get_backend", return_value="nccl"
    ), patch.object(
        parallel_state,
        "init_model_parallel_group",
        side_effect=mock_init_model_parallel_group,
    ), patch.object(
        parallel_state, "get_world_group"
    ) as mock_world_group:
        # Mock world group
        mock_world = Mock()
        mock_world.device_group = Mock()
        mock_world.local_rank = 0
        mock_world_group.return_value = mock_world

        try:
            # Call the actual function
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size=world_size,
                pipeline_model_parallel_size=1,
                attention_context_model_parallel_size=2,
            )

            # Verify TP groups
            tp_groups = created_groups.get("tp", [])
            assert len(tp_groups) == 1, f"Expected 1 TP group, got {len(tp_groups)}"
            assert tp_groups[0] == list(
                range(world_size)
            ), f"Wrong TP group: {tp_groups[0]}"

            # Verify ATTN_CP groups
            attn_cp_groups = created_groups.get("attn_cp", [])
            assert (
                len(attn_cp_groups) == 4
            ), f"Expected 4 ATTN_CP groups, got {len(attn_cp_groups)}"
            expected_attn_cp = [[0, 4], [1, 5], [2, 6], [3, 7]]
            assert (
                attn_cp_groups == expected_attn_cp
            ), f"Wrong ATTN_CP groups: {attn_cp_groups}"

            print("TP=8, Attn CP=2 group construction verified")
        finally:
            # Cleanup must run even when an assertion above fails, so that
            # whatever initialize_model_parallel() installed is torn down
            # before the next test.
            parallel_state.destroy_model_parallel()
def test_parallel_group_construction_tp8_moe_ep4_cp2():
    """
    Test parallel group construction for 8 GPU configuration with:
    - tensor_model_parallel_size = 8
    - expert_model_parallel_size = 4
    - pipeline_model_parallel_size = 1
    - moe_data_model_parallel_size = 2

    Expected groups:
        1 tensor model-parallel group:
            [g0, g1, g2, g3, g4, g5, g6, g7]
        2 MoE expert-parallel groups:
            [g0, g1, g2, g3], [g4, g5, g6, g7]
        4 MoE data-parallel groups:
            [g0, g4], [g1, g5], [g2, g6], [g3, g7]

    The ACTUAL initialize_model_parallel() is called; the torch.distributed
    queries, get_world_group() and init_model_parallel_group() are mocked,
    and the ``group_ranks`` passed to the latter are captured per group name
    and compared with the lists above.
    """
    world_size = 8

    # group_name -> group_ranks, recorded by the stub below.
    created_groups = {}

    def mock_init_model_parallel_group(group_ranks, local_rank, backend, **kwargs):
        # Record the ranks under the group_name keyword the caller supplied.
        group_name = kwargs.get("group_name", "unknown")
        created_groups[group_name] = group_ranks
        # Return a stand-in group object exposing a device_group attribute.
        mock_group = Mock()
        mock_group.device_group = Mock()
        return mock_group

    # Mock the distributed backend and reset the module-level group handles
    # for the duration of the test; patch restores the originals on exit.
    with patch.object(parallel_state, "_WORLD", None), patch.object(
        parallel_state, "_TP", None
    ), patch.object(parallel_state, "_MOE_EP", None), patch.object(
        parallel_state, "_MOE_DP", None
    ), patch.object(
        parallel_state, "_MOE_TP", None
    ), patch.object(
        parallel_state, "_PP", None
    ), patch(
        "torch.distributed.is_initialized", return_value=True
    ), patch(
        "torch.distributed.get_world_size", return_value=world_size
    ), patch(
        "torch.distributed.get_rank", return_value=0
    ), patch(
        "torch.distributed.get_backend", return_value="nccl"
    ), patch.object(
        parallel_state,
        "init_model_parallel_group",
        side_effect=mock_init_model_parallel_group,
    ), patch.object(
        parallel_state, "get_world_group"
    ) as mock_world_group:
        # Mock world group
        mock_world = Mock()
        mock_world.device_group = Mock()
        mock_world.local_rank = 0
        mock_world_group.return_value = mock_world

        try:
            # Call the actual function
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size=world_size,
                expert_model_parallel_size=4,
                pipeline_model_parallel_size=1,
                moe_data_model_parallel_size=2,
            )

            # Verify TP groups
            tp_groups = created_groups.get("tp", [])
            assert len(tp_groups) == 1, f"Expected 1 TP group, got {len(tp_groups)}"
            assert tp_groups[0] == list(
                range(world_size)
            ), f"Wrong TP group: {tp_groups[0]}"

            # Verify MOE_EP groups
            moe_ep_groups = created_groups.get("moe_ep", [])
            assert (
                len(moe_ep_groups) == 2
            ), f"Expected 2 MOE_EP groups, got {len(moe_ep_groups)}"
            expected_moe_ep = [[0, 1, 2, 3], [4, 5, 6, 7]]
            assert (
                moe_ep_groups == expected_moe_ep
            ), f"Wrong MOE_EP groups: {moe_ep_groups}"

            # Verify MOE_DP groups
            moe_dp_groups = created_groups.get("moe_dp", [])
            assert (
                len(moe_dp_groups) == 4
            ), f"Expected 4 MOE_DP groups, got {len(moe_dp_groups)}"
            expected_moe_dp = [[0, 4], [1, 5], [2, 6], [3, 7]]
            assert (
                moe_dp_groups == expected_moe_dp
            ), f"Wrong MOE_DP groups: {moe_dp_groups}"

            # The configuration under test is MoE data parallelism (DP), not
            # context parallelism, so the message says DP.
            print("TP=8, MoE EP=4, MoE DP=2 group construction verified")
        finally:
            # Cleanup must run even when an assertion above fails, so that
            # whatever initialize_model_parallel() installed is torn down
            # before the next test.
            parallel_state.destroy_model_parallel()
if __name__ == "__main__":
    # Script entry point: run both tests directly (no GPUs required) and
    # report the outcome through the process exit status.
    import sys

    exit_code = 0
    try:
        for run_test in (
            test_parallel_group_construction_tp8_attn_cp2,
            test_parallel_group_construction_tp8_moe_ep4_cp2,
        ):
            run_test()
    except AssertionError as e:
        # A test expectation was not met.
        print(f"\n Test failed: {e}")
        exit_code = 1
    except Exception as e:
        # Anything else is an error in the test setup or the code under test;
        # show the full traceback for diagnosis.
        print(f"\n Unexpected error: {e}")
        import traceback

        traceback.print_exc()
        exit_code = 1
    sys.exit(exit_code)
|