# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import torch.distributed as dist
from megatron.core.process_groups_config import ProcessGroupCollection
class TestProcessGroupsConfig:
    """Smoke tests for attribute handling on ProcessGroupCollection."""

    def test_transformer_process_groups(self, mocker):
        """Assign tp/pp groups on a ProcessGroupCollection and read them back."""
        # Two stand-in process groups, specced against dist.ProcessGroup.
        tp_group, pp_group = (mocker.Mock(spec=dist.ProcessGroup) for _ in range(2))

        # Start from an empty collection and populate it afterwards.
        collection = ProcessGroupCollection()
        collection.tp = tp_group
        collection.pp = pp_group

        # The stored groups must compare equal to what was assigned.
        assert collection.tp == tp_group
        assert collection.pp == pp_group

        # Assigned attributes are present on the instance ...
        assert hasattr(collection, 'tp')
        assert hasattr(collection, 'pp')
        # ... while cp was never assigned, so the test expects it to be absent.
        assert not hasattr(collection, 'cp')

    def test_grad_comm_process_groups(self, mocker):
        """Assign a dp group on a ProcessGroupCollection and read it back."""
        # Stand-in process group, specced against dist.ProcessGroup.
        dp_group = mocker.Mock(spec=dist.ProcessGroup)

        # Start from an empty collection and populate it afterwards.
        collection = ProcessGroupCollection()
        collection.dp = dp_group

        # The stored group must compare equal to what was assigned.
        assert collection.dp == dp_group

        # dp is present; dp_cp was never assigned, so it is expected to be absent.
        assert hasattr(collection, 'dp')
        assert not hasattr(collection, 'dp_cp')

    def test_hierarchical_context_parallel_groups(self, mocker):
        """Store a list of groups under hcp and check list-style access."""
        # Two stand-in process groups, specced against dist.ProcessGroup.
        first_group, second_group = (
            mocker.Mock(spec=dist.ProcessGroup) for _ in range(2)
        )

        # Start from an empty collection and attach the group list.
        collection = ProcessGroupCollection()
        collection.hcp = [first_group, second_group]

        # The attribute must come back as a two-element list, in order.
        assert isinstance(collection.hcp, list)
        assert len(collection.hcp) == 2
        assert collection.hcp[0] == first_group
        assert collection.hcp[1] == second_group
|