# Source path: khala/models/Megatron/tests/unit_tests/test_process_groups_config.py
# Uploaded by: multimodalart (HF Staff)
# Commit d1f1097 (verified): Initial best-effort ZeroGPU port of Khala song generation
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import torch.distributed as dist
from megatron.core.process_groups_config import ProcessGroupCollection
class TestProcessGroupsConfig:
    """Attribute set/get checks on ``ProcessGroupCollection`` instances.

    Each test builds stand-in process groups with
    ``mocker.Mock(spec=dist.ProcessGroup)``, assigns them to a freshly
    constructed ``ProcessGroupCollection``, and verifies what can be read
    back afterwards.
    """

    def test_transformer_process_groups(self, mocker):
        """Assign ``tp`` and ``pp`` on a new collection and read them back."""
        # Two independent stand-ins, one per attribute under test.
        tp_group, pp_group = (
            mocker.Mock(spec=dist.ProcessGroup) for _ in range(2)
        )

        collection = ProcessGroupCollection()

        # Attributes are populated after construction, not via the constructor.
        collection.tp = tp_group
        collection.pp = pp_group

        # Reading back yields the objects that were assigned.
        assert collection.tp == tp_group
        assert collection.pp == pp_group

        # Assigned names are reported as present ...
        for assigned_name in ("tp", "pp"):
            assert hasattr(collection, assigned_name)
        # ... while ``cp``, which this test never assigns, is not.
        assert not hasattr(collection, "cp")

    def test_grad_comm_process_groups(self, mocker):
        """Assign ``dp`` on a new collection and read it back."""
        dp_group = mocker.Mock(spec=dist.ProcessGroup)

        collection = ProcessGroupCollection()
        collection.dp = dp_group

        # Reading back yields the object that was assigned.
        assert collection.dp == dp_group

        # ``dp`` was assigned above; ``dp_cp`` was never assigned in this test.
        assert hasattr(collection, "dp")
        assert not hasattr(collection, "dp_cp")

    def test_hierarchical_context_parallel_groups(self, mocker):
        """Store a two-element list under ``hcp`` and check each element."""
        first_group, second_group = (
            mocker.Mock(spec=dist.ProcessGroup) for _ in range(2)
        )

        collection = ProcessGroupCollection()
        collection.hcp = [first_group, second_group]

        # The stored value is still a list of the same length ...
        assert isinstance(collection.hcp, list)
        assert len(collection.hcp) == 2

        # ... and each position holds the group that was placed there.
        for position, expected_group in enumerate((first_group, second_group)):
            assert collection.hcp[position] == expected_group