# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, List, Optional, Union from treelib import Tree class ProcessGroupNode: """ Class to store the attributes of a distributed process group Attributes ---------- name : str Name of the process group size : Optional[int] Optional, number of processes in the process group """ def __init__( self, name: str, size: Optional[int] = None, ): """ Constructor for the ProcessGroupNode class Parameters ---------- name : str Name of the process group size : Optional[int] Optional, size of the process group """ self.name = name self.size = size def __str__(self): """ String representation of the process group node Returns ------- str String representation of the process group node """ return f"ProcessGroupNode(name={self.name}, size={self.size}, " def __repr__(self): """ String representation of the process group node Returns ------- str String representation of the process group node """ return self.__str__() class ProcessGroupConfig: """ Class to define the configuration of a model's parallel process group structure as a tree. Each node of the tree is of type `ProcessGroupNode`. Once the process group config structure (i.e, the tree structure) is set, it is sufficient to set only the sizes for each leaf process group. Then, the size of every parent group can be automatically computed as the product reduction of the sub-tree of that parent group node. Examples -------- >>> from physicsnemo.distributed import ProcessGroupNode, ProcessGroupConfig >>> >>> # Create world group that contains all processes that are part of this job >>> world = ProcessGroupNode("world") >>> >>> # Create the process group config with the highest level process group >>> config = ProcessGroupConfig(world) >>> >>> # Create model and data parallel sub-groups >>> # Sub-groups of a single node are guaranteed to be orthogonal by construction >>> # Nodes can be added with either the name of the node or the node itself >>> config.add_node(ProcessGroupNode("model_parallel"), parent=world) >>> config.add_node(ProcessGroupNode("data_parallel"), parent="world") >>> >>> # Create spatial and channel parallel sub-groups >>> config.add_node(ProcessGroupNode("spatial_parallel"), parent="model_parallel") >>> config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel") >>> >>> config.leaf_groups() ['data_parallel', 'spatial_parallel', 'channel_parallel'] >>> >>> # Set leaf group sizes >>> # Note: product of all leaf-node sizes should be the world size >>> group_sizes = {"channel_parallel": 3, "spatial_parallel": 2, "data_parallel": 4} >>> config.set_leaf_group_sizes(group_sizes) # Update all parent group sizes too >>> config.get_node("model_parallel").size 6 """ def __init__(self, node: ProcessGroupNode): """ Constructor to the ProcessGroupConfig class Parameters ---------- node : ProcessGroupNode Root node of the tree, typically would be 'world' Note, it is generally recommended to set the child groups for 'world' to 'model_parallel' and 'data_parallel' to aid with distributed data parallel training unless there is a specific reason to choose a different structure """ self.root = node self.root_id = node.name self.tree = Tree() self.tree.create_node(node.name, node.name, data=node) def add_node(self, node: ProcessGroupNode, parent=Union[str, ProcessGroupNode]): """ Add a node to the process group config Parameters ---------- node : ProcessGroupNode The new node to be added to the config parent : Union[str, ProcessGroupNode] Parent node of the node to be added. Should already be in the config. If str, it is the name of the parent node. Otherwise, the parent ProcessGroupNode itself. """ if isinstance(parent, ProcessGroupNode): parent = parent.name self.tree.create_node(node.name, node.name, data=node, parent=parent) def get_node(self, name: str) -> ProcessGroupNode: """ Method to get the node given the name of the node Parameters ---------- name : str Name of the node to retrieve Returns ------- ProcessGroupNode Node with the given name from the config """ return self.tree.get_node(name).data def update_parent_sizes(self, verbose: bool = False) -> int: """ Method to update parent node sizes after setting the sizes for each leaf node Parameters ---------- verbose : bool If True, print a message each time a parent node size was updated Returns ------- int Size of the root node """ return _tree_product_reduction(self.tree, self.root_id, verbose=verbose) def leaf_groups(self) -> List[str]: """ Get a list of all leaf group names Returns ------- List[str] List of all leaf node names """ return [n.identifier for n in self.tree.leaves()] def set_leaf_group_sizes( self, group_sizes: Dict[str, int], update_parent_sizes: bool = True ): """ Set process group sizes for all leaf groups Parameters ---------- group_sizes : Dict[str, int] Dictionary with a mapping of each leaf group name to its size update_parent_sizes : bool Update all parent group sizes based on the leaf group if True If False, only set the leaf group sizes. """ for id, size in group_sizes.items(): if not self.tree.contains(id): raise AssertionError( f"Process group {id} is not in this process group config" ) node = self.tree.get_node(id) if not node.is_leaf(): raise AssertionError(f"Process group {id} is not a leaf group") node.data.size = size if update_parent_sizes: self.update_parent_sizes() def _tree_product_reduction(tree, node_id, verbose=False): """ Function to traverse a tree and compute the product reduction of the sub-tree for each node starting from `node_id` """ children = tree.children(node_id) node = tree.get_node(node_id) if not children: if node.data.size is None: raise AssertionError("Leaf nodes should have a valid size set") return node.data.size product = 1 for child in children: product *= _tree_product_reduction(tree, child.identifier) if node.data.size != product: if verbose: print( "Updating size of node " f"{node.data.name} from {node.data.size} to {product}" ) node.data.size = product return product