# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from fnmatch import fnmatch
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle


__all__ = ["parallelize_module"]


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None,
    *,
    src_data_rank: Optional[int] = 0,
) -> nn.Module:
    """

    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.



    We parallelize module or sub_modules based on a parallelize_plan. The parallelize_plan contains

    :class:`ParallelStyle`, which indicates how user wants the module or sub_module

    to be parallelized.



    User can also specify different parallel style per module fully qualified name (FQN).



    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`, if you have a 2-D or N-D :class:`DeviceMesh`,

    slice the DeviceMesh to a 1-D sub DeviceMesh first then pass to this API(i.e. ``device_mesh[\"tp\"]``)
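
    For example, a minimal sketch of slicing a 2-D mesh down to the 1-D sub-mesh this
    API expects (the ``"dp"``/``"tp"`` mesh dimension names are illustrative)::

        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
        >>> tp_mesh = mesh_2d["tp"]  # 1-D sub-mesh suitable for parallelize_module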

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`, optional):
            Object which describes the mesh topology of devices for the DTensor.
            If not specified, the call must be under a DeviceMesh context.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object which contains how we prepare
            input/output for Tensor Parallelism or it can be a dict of module
            FQN and its corresponding :class:`ParallelStyle` object. If not
            specified, the call will do nothing at the moment.
    Keyword args:
        src_data_rank (int, optional): the rank of the source data for the logical/global tensor. It is used by
            :meth:`distribute_tensor` to scatter/broadcast the shards/replicas to other ranks. By default,
            we use ``group_rank=0`` on each DeviceMesh dimension as the source data to preserve the single-device
            semantic. If ``None`` is passed explicitly, :meth:`parallelize_module` simply uses its local data instead
            of trying to preserve the single-device semantic via scatter/broadcast. Default: 0.
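
        For example, a minimal sketch of opting out of the scatter/broadcast when every
        rank already holds correctly initialized local data (``m`` and ``tp_mesh`` are
        illustrative, as in the example below)::

            >>> # xdoctest: +SKIP("distributed")
            >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}, src_data_rank=None)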

    Return:
        The parallelized :class:`nn.Module` object.


    Example::

        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

    .. note:: For complex module architectures like Attention and MLP layers, we recommend composing
        different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
        them as a parallelize_plan to achieve the desired sharding computation, as sketched below.
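
        A minimal sketch of such a composed plan for a two-layer feed-forward block
        (the ``"net1"`` and ``"net2"`` submodule names are illustrative)::

            >>> plan = {"net1": ColwiseParallel(), "net2": RowwiseParallel()}
            >>> m = parallelize_module(m, tp_mesh, plan)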

    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    _validate_tp_mesh_dim(device_mesh)

    if parallelize_plan is None:
        warnings.warn(
            "No parallelize_plan is provided and auto-parallel is not supported "
            "at the moment, so this parallelize_module call will do nothing."
        )
        return module

    # note: The RNG tracker will be initialized in distribute_tensor() call if it hasn't
    # been initialized.

    if isinstance(parallelize_plan, ParallelStyle):
        parallelize_plan.src_data_rank = src_data_rank
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            # str.split(".") returns [""] for an empty string, never an empty
            # list, so validate the raw path before splitting.
            if not module_path:
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            path_splits = module_path.split(".")
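            # Plan keys may contain fnmatch-style wildcards (e.g. "layers.*.w1"),
            # so each path atom below is matched against all child module names.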
            while path_splits:
                atom = path_splits.pop(0)
                matched_children = filter(
                    # `t[0]` is child name
                    lambda t: fnmatch(t[0], atom),
                    module.named_children(),
                )
                # apply the plan to all matched submodules
                for _, submodule in matched_children:
                    if path_splits:
                        # we haven't reached the leaf, apply in dict style
                        leaf_path = ".".join(
                            path_splits
                        )  # rest of the path after `atom`
                        parallelize_module(
                            submodule,
                            device_mesh,
                            {leaf_path: parallelize_style},
                            src_data_rank=src_data_rank,
                        )
                    else:
                        # otherwise, directly apply style to this submodule
                        parallelize_module(
                            submodule,
                            device_mesh,
                            parallelize_style,
                            src_data_rank=src_data_rank,
                        )
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )