Spaces:
Sleeping
Sleeping
File size: 8,160 Bytes
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union
from treelib import Tree
class ProcessGroupNode:
    """
    Class to store the attributes of a distributed process group

    Attributes
    ----------
    name : str
        Name of the process group
    size : Optional[int]
        Optional, number of processes in the process group. May be left as
        None and filled in later (e.g. by `ProcessGroupConfig`).
    """

    def __init__(
        self,
        name: str,
        size: Optional[int] = None,
    ):
        """
        Constructor for the ProcessGroupNode class

        Parameters
        ----------
        name : str
            Name of the process group
        size : Optional[int]
            Optional, size of the process group
        """
        self.name = name
        self.size = size

    def __str__(self):
        """
        String representation of the process group node

        Returns
        -------
        str
            String of the form ``ProcessGroupNode(name=<name>, size=<size>)``
        """
        # Previously this emitted a trailing ", " with no closing parenthesis.
        return f"ProcessGroupNode(name={self.name}, size={self.size})"

    def __repr__(self):
        """
        String representation of the process group node

        Returns
        -------
        str
            Same string as `__str__`
        """
        return self.__str__()
class ProcessGroupConfig:
    """
    Class to define the configuration of a model's parallel process group structure as a
    tree. Each node of the tree is of type `ProcessGroupNode`.

    Once the process group config structure (i.e., the tree structure) is set, it is
    sufficient to set only the sizes for each leaf process group. Then, the size of
    every parent group can be automatically computed as the product reduction of the
    sub-tree of that parent group node.

    Examples
    --------
    >>> from physicsnemo.distributed import ProcessGroupNode, ProcessGroupConfig
    >>>
    >>> # Create world group that contains all processes that are part of this job
    >>> world = ProcessGroupNode("world")
    >>>
    >>> # Create the process group config with the highest level process group
    >>> config = ProcessGroupConfig(world)
    >>>
    >>> # Create model and data parallel sub-groups
    >>> # Sub-groups of a single node are guaranteed to be orthogonal by construction
    >>> # Nodes can be added with either the name of the node or the node itself
    >>> config.add_node(ProcessGroupNode("model_parallel"), parent=world)
    >>> config.add_node(ProcessGroupNode("data_parallel"), parent="world")
    >>>
    >>> # Create spatial and channel parallel sub-groups
    >>> config.add_node(ProcessGroupNode("spatial_parallel"), parent="model_parallel")
    >>> config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel")
    >>>
    >>> config.leaf_groups()
    ['data_parallel', 'spatial_parallel', 'channel_parallel']
    >>>
    >>> # Set leaf group sizes
    >>> # Note: product of all leaf-node sizes should be the world size
    >>> group_sizes = {"channel_parallel": 3, "spatial_parallel": 2, "data_parallel": 4}
    >>> config.set_leaf_group_sizes(group_sizes)  # Update all parent group sizes too
    >>> config.get_node("model_parallel").size
    6
    """

    def __init__(self, node: ProcessGroupNode):
        """
        Constructor to the ProcessGroupConfig class

        Parameters
        ----------
        node : ProcessGroupNode
            Root node of the tree, typically would be 'world'
            Note, it is generally recommended to set the child groups for 'world'
            to 'model_parallel' and 'data_parallel' to aid with distributed
            data parallel training unless there is a specific reason to choose a
            different structure
        """
        self.root = node
        self.root_id = node.name
        self.tree = Tree()
        # The group name is used as both tag and identifier, so every lookup
        # in this class is done by group name.
        self.tree.create_node(node.name, node.name, data=node)

    def add_node(self, node: ProcessGroupNode, parent: Union[str, ProcessGroupNode]):
        """
        Add a node to the process group config

        Parameters
        ----------
        node : ProcessGroupNode
            The new node to be added to the config
        parent : Union[str, ProcessGroupNode]
            Parent node of the node to be added. Should already be in the config.
            If str, it is the name of the parent node. Otherwise, the parent
            ProcessGroupNode itself.
        """
        # `parent` was previously declared as `parent=Union[...]`, which made the
        # typing object a default value instead of an annotation.
        parent_id = parent.name if isinstance(parent, ProcessGroupNode) else parent
        self.tree.create_node(node.name, node.name, data=node, parent=parent_id)

    def get_node(self, name: str) -> ProcessGroupNode:
        """
        Method to get the node given the name of the node

        Parameters
        ----------
        name : str
            Name of the node to retrieve

        Returns
        -------
        ProcessGroupNode
            Node with the given name from the config
        """
        return self.tree.get_node(name).data

    def update_parent_sizes(self, verbose: bool = False) -> int:
        """
        Method to update parent node sizes after setting the sizes for each leaf node

        Parameters
        ----------
        verbose : bool
            If True, print a message each time a parent node size was updated

        Returns
        -------
        int
            Size of the root node
        """
        return _tree_product_reduction(self.tree, self.root_id, verbose=verbose)

    def leaf_groups(self) -> List[str]:
        """
        Get a list of all leaf group names

        Returns
        -------
        List[str]
            List of all leaf node names
        """
        return [leaf.identifier for leaf in self.tree.leaves()]

    def set_leaf_group_sizes(
        self, group_sizes: Dict[str, int], update_parent_sizes: bool = True
    ):
        """
        Set process group sizes for all leaf groups

        Parameters
        ----------
        group_sizes : Dict[str, int]
            Dictionary with a mapping of each leaf group name to its size
        update_parent_sizes : bool
            Update all parent group sizes based on the leaf group if True
            If False, only set the leaf group sizes.

        Raises
        ------
        AssertionError
            If a name in `group_sizes` is not in this config or is not a leaf
            group. All names are validated before any size is written, so a
            failing call leaves the config unchanged.
        """
        # Validate everything up front so an invalid entry cannot leave the
        # config with only part of the requested sizes applied.
        for name in group_sizes:
            if not self.tree.contains(name):
                raise AssertionError(
                    f"Process group {name} is not in this process group config"
                )
            if not self.tree.get_node(name).is_leaf():
                raise AssertionError(f"Process group {name} is not a leaf group")

        for name, size in group_sizes.items():
            self.tree.get_node(name).data.size = size

        if update_parent_sizes:
            self.update_parent_sizes()
def _tree_product_reduction(tree, node_id, verbose=False):
    """
    Function to traverse a tree and compute the product reduction of
    the sub-tree for each node starting from `node_id`.

    The traversal is bottom-up. Every non-leaf node visited gets its
    ``data.size`` set to the product of its children's sizes.

    Parameters
    ----------
    tree : treelib.Tree
        Tree whose nodes carry an object with ``name`` and ``size`` attributes
        in ``data`` (a `ProcessGroupNode` in this module)
    node_id : str
        Identifier of the node at which to start the reduction
    verbose : bool, optional
        If True, print a message whenever a node's size is changed. This
        applies to every node in the sub-tree, not only `node_id`.

    Returns
    -------
    int
        Size of the node `node_id` after the reduction

    Raises
    ------
    AssertionError
        If a leaf node in the sub-tree has no size set
    """
    node = tree.get_node(node_id)
    children = tree.children(node_id)

    if not children:
        if node.data.size is None:
            raise AssertionError(
                f"Leaf nodes should have a valid size set (node {node_id!r} has none)"
            )
        return node.data.size

    product = 1
    for child in children:
        # Forward `verbose`; previously only the root's update was ever reported.
        product *= _tree_product_reduction(tree, child.identifier, verbose=verbose)

    if node.data.size != product:
        if verbose:
            print(
                "Updating size of node "
                f"{node.data.name} from {node.data.size} to {product}"
            )
        node.data.size = product
    return product
|