File size: 8,160 Bytes
c3d0544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union

from treelib import Tree


class ProcessGroupNode:
    """
    Class to store the attributes of a distributed process group

    Attributes
    ----------
    name : str
        Name of the process group
    size : Optional[int]
        Optional, number of processes in the process group. ``None`` until a
        size has been assigned.
    """

    def __init__(
        self,
        name: str,
        size: Optional[int] = None,
    ):
        """
        Constructor for the ProcessGroupNode class

        Parameters
        ----------
        name : str
            Name of the process group
        size : Optional[int]
            Optional, size of the process group
        """
        self.name = name
        self.size = size

    def __str__(self) -> str:
        """
        String representation of the process group node

        Returns
        -------
        str
            Text of the form ``ProcessGroupNode(name=<name>, size=<size>)``
        """
        # The closing parenthesis must be part of the literal: the node only
        # has these two attributes, so nothing follows `size`.
        return f"ProcessGroupNode(name={self.name}, size={self.size})"

    def __repr__(self) -> str:
        """
        Debug representation of the process group node

        Returns
        -------
        str
            Same text as ``str(node)``
        """
        return self.__str__()


class ProcessGroupConfig:
    """
    Class to define the configuration of a model's parallel process group structure as a
    tree. Each node of the tree is of type `ProcessGroupNode`.

    Once the process group config structure (i.e, the tree structure) is set, it is
    sufficient to set only the sizes for each leaf process group. Then, the size of
    every parent group can be automatically computed as the product reduction of the
    sub-tree of that parent group node.

    Examples
    --------
    >>> from physicsnemo.distributed import ProcessGroupNode, ProcessGroupConfig
    >>>
    >>> # Create world group that contains all processes that are part of this job
    >>> world = ProcessGroupNode("world")
    >>>
    >>> # Create the process group config with the highest level process group
    >>> config = ProcessGroupConfig(world)
    >>>
    >>> # Create model and data parallel sub-groups
    >>> # Sub-groups of a single node are guaranteed to be orthogonal by construction
    >>> # Nodes can be added with either the name of the node or the node itself
    >>> config.add_node(ProcessGroupNode("model_parallel"), parent=world)
    >>> config.add_node(ProcessGroupNode("data_parallel"), parent="world")
    >>>
    >>> # Create spatial and channel parallel sub-groups
    >>> config.add_node(ProcessGroupNode("spatial_parallel"), parent="model_parallel")
    >>> config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel")
    >>>
    >>> config.leaf_groups()
    ['data_parallel', 'spatial_parallel', 'channel_parallel']
    >>>
    >>> # Set leaf group sizes
    >>> # Note: product of all leaf-node sizes should be the world size
    >>> group_sizes = {"channel_parallel": 3, "spatial_parallel": 2, "data_parallel": 4}
    >>> config.set_leaf_group_sizes(group_sizes)  # Update all parent group sizes too
    >>> config.get_node("model_parallel").size
    6
    """

    def __init__(self, node: ProcessGroupNode):
        """
        Constructor to the ProcessGroupConfig class

        Parameters
        ----------
        node : ProcessGroupNode
            Root node of the tree, typically would be 'world'
            Note, it is generally recommended to set the child groups for 'world'
            to 'model_parallel' and 'data_parallel' to aid with distributed
            data parallel training unless there is a specific reason to choose a
            different structure
        """
        self.root = node
        self.root_id = node.name
        self.tree = Tree()
        # The group name doubles as both the display tag and the unique
        # identifier of the tree node; the ProcessGroupNode rides along as data.
        self.tree.create_node(node.name, node.name, data=node)

    def add_node(
        self,
        node: ProcessGroupNode,
        parent: Optional[Union[str, ProcessGroupNode]] = None,
    ):
        """
        Add a node to the process group config

        Parameters
        ----------
        node : ProcessGroupNode
            The new node to be added to the config
        parent : Union[str, ProcessGroupNode]
            Parent node of the node to be added. Should already be in the config.
            If str, it is the name of the parent node. Otherwise, the parent
            ProcessGroupNode itself. A parent must always be supplied; the
            ``None`` default only keeps the signature call-compatible.

        Raises
        ------
        AssertionError
            If no parent is given. The root is created by the constructor, so
            every node added afterwards needs an explicit parent.
        """
        if parent is None:
            raise AssertionError(
                f"A parent must be specified to add process group {node.name}"
            )
        if isinstance(parent, ProcessGroupNode):
            # Tree nodes are keyed by group name, so reduce the object to its name
            parent = parent.name
        self.tree.create_node(node.name, node.name, data=node, parent=parent)

    def get_node(self, name: str) -> ProcessGroupNode:
        """
        Method to get the node given the name of the node

        Parameters
        ----------
        name : str
            Name of the node to retrieve

        Returns
        -------
        ProcessGroupNode
            Node with the given name from the config
        """
        # NOTE(review): no existence check is done here; confirm how treelib's
        # `get_node` behaves for an unknown name before relying on the error.
        return self.tree.get_node(name).data

    def update_parent_sizes(self, verbose: bool = False) -> int:
        """
        Method to update parent node sizes after setting the sizes for each leaf node

        Parameters
        ----------
        verbose : bool
            Passed through to `_tree_product_reduction`, which prints a message
            when it changes a node size

        Returns
        -------
        int
            Size of the root node
        """
        return _tree_product_reduction(self.tree, self.root_id, verbose=verbose)

    def leaf_groups(self) -> List[str]:
        """
        Get a list of all leaf group names

        Returns
        -------
        List[str]
            List of all leaf node names
        """
        return [n.identifier for n in self.tree.leaves()]

    def set_leaf_group_sizes(
        self, group_sizes: Dict[str, int], update_parent_sizes: bool = True
    ):
        """
        Set process group sizes for all leaf groups

        Parameters
        ----------
        group_sizes : Dict[str, int]
            Dictionary with a mapping of each leaf group name to its size
        update_parent_sizes : bool
            Update all parent group sizes based on the leaf group if True
            If False, only set the leaf group sizes.

        Raises
        ------
        AssertionError
            If a name in `group_sizes` is not in the config or does not refer
            to a leaf group. No size is modified in that case.
        """
        # Validate every entry before mutating anything, so that a bad entry
        # cannot leave the config with only part of the sizes applied.
        leaf_nodes = {}
        for name in group_sizes:
            if not self.tree.contains(name):
                raise AssertionError(
                    f"Process group {name} is not in this process group config"
                )
            tree_node = self.tree.get_node(name)
            if not tree_node.is_leaf():
                raise AssertionError(f"Process group {name} is not a leaf group")
            leaf_nodes[name] = tree_node

        for name, size in group_sizes.items():
            leaf_nodes[name].data.size = size

        # Inside this method the parameter (a bool) shadows the method of the
        # same name; the call below goes through `self`, so it is the method.
        if update_parent_sizes:
            self.update_parent_sizes()


def _tree_product_reduction(tree, node_id, verbose=False):
    """
    Function to traverse a tree and compute the product reduction of
    the sub-tree for each node starting from `node_id`

    Parameters
    ----------
    tree
        Tree object providing `children(node_id)` and `get_node(node_id)`.
        Each tree node is expected to expose `identifier` and a `data` object
        with `name` and `size` attributes.
    node_id
        Identifier of the node at which the reduction starts
    verbose : bool
        If True, print a message for every node in the sub-tree whose size
        gets changed

    Returns
    -------
    int
        Size of the node `node_id`: its own size for a leaf, otherwise the
        product of the sizes of its children

    Raises
    ------
    AssertionError
        If a leaf node in the sub-tree has no size set
    """
    children = tree.children(node_id)
    node = tree.get_node(node_id)
    if not children:
        # Leaves are the inputs of the reduction and must already carry a size
        if node.data.size is None:
            raise AssertionError("Leaf nodes should have a valid size set")
        return node.data.size

    product = 1

    for child in children:
        # Forward `verbose` so that updates deeper in the tree are reported
        # too, not only the update of the node the caller started from.
        product *= _tree_product_reduction(tree, child.identifier, verbose=verbose)

    # Only write (and report) when the stored size actually differs
    if node.data.size != product:
        if verbose:
            print(
                "Updating size of node "
                f"{node.data.name} from {node.data.size} to {product}"
            )
        node.data.size = product

    return product