# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. import torch.distributed as dist from megatron.core.process_groups_config import ProcessGroupCollection class TestProcessGroupsConfig: """Simple tests for process group dataclasses.""" def test_transformer_process_groups(self, mocker): """Test basic functionality of TransformerProcessGroups.""" mock_pg1 = mocker.Mock(spec=dist.ProcessGroup) mock_pg2 = mocker.Mock(spec=dist.ProcessGroup) # Create instance model_pgs = ProcessGroupCollection() # Test setting attributes after creation model_pgs.tp = mock_pg1 model_pgs.pp = mock_pg2 # Test accessing attributes assert model_pgs.tp == mock_pg1 assert model_pgs.pp == mock_pg2 # Test attribute existence assert hasattr(model_pgs, 'tp') assert hasattr(model_pgs, 'pp') assert not hasattr(model_pgs, 'cp') # Not set yet def test_grad_comm_process_groups(self, mocker): """Test basic functionality of ProcessGroupCollection.""" # Create mock process groups mock_pg = mocker.Mock(spec=dist.ProcessGroup) # Create instance grad_pgs = ProcessGroupCollection() # Test setting attributes after creation grad_pgs.dp = mock_pg # Test accessing attributes assert grad_pgs.dp == mock_pg # Test attribute existence assert hasattr(grad_pgs, 'dp') assert not hasattr(grad_pgs, 'dp_cp') # Not set yet def test_hierarchical_context_parallel_groups(self, mocker): """Test setting and accessing the hierarchical context parallel list.""" # Create mock process groups mock_pg1 = mocker.Mock(spec=dist.ProcessGroup) mock_pg2 = mocker.Mock(spec=dist.ProcessGroup) # Create instance model_pgs = ProcessGroupCollection() # Set the hierarchical context parallel groups model_pgs.hcp = [mock_pg1, mock_pg2] # Test list access assert isinstance(model_pgs.hcp, list) assert len(model_pgs.hcp) == 2 assert model_pgs.hcp[0] == mock_pg1 assert model_pgs.hcp[1] == mock_pg2