File size: 8,733 Bytes
c3d0544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from physicsnemo.utils.profiling import profile
from physicsnemo.utils.version_check import check_module_requirements

check_module_requirements("physicsnemo.distributed.shard_tensor")

from torch.distributed.tensor.placement_types import (  # noqa: E402
    Shard,
)

from physicsnemo.distributed import ShardTensor  # noqa: E402
from physicsnemo.distributed.shard_utils.patch_core import (  # noqa: E402
    MissingShardPatch,
)


def compute_local_padding_and_output_shape(
    input_tensor_shape: tuple[int, ...],
    pad: tuple[int, ...],
    mesh_coords: tuple[int, ...],
    mesh_sizes: tuple[int, ...],
    tensor_sharding_map: dict[int, int],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """
    Compute the local output shape and local padding for one shard of a padded tensor.

    ``pad`` follows the ``torch.nn.functional.pad`` convention: pairs of
    ``(left, right)`` amounts ordered from the *last* tensor dimension inward.
    For a dimension that is sharded across the mesh, edge padding only applies
    on the shards that actually sit on the corresponding edge of that mesh
    dimension; interior shards get zero padding there.

    Args:
        input_tensor_shape: The shape of the (local) input tensor shard
        pad: The padding size(s), ``(left, right)`` pairs from the last dim inward
        mesh_coords: The coordinates of the shard in the mesh
        mesh_sizes: The sizes of the mesh along each mesh dimension
        tensor_sharding_map: A map from sharded tensor dimension to mesh
            dimension; tensor dims absent from the map are unsharded

    Returns:
        Tuple of (local output shape, local padding); the local padding uses
        the same ordering convention as ``pad``.
    """

    tensor_rank = len(input_tensor_shape)
    pad_dims = len(pad) // 2

    output_padding = []
    local_output_shape = list(input_tensor_shape)

    # `pad` addresses tensor dimensions from the last one backwards:
    for dim_from_last in range(pad_dims):
        tensor_dim = tensor_rank - 1 - dim_from_last

        left = pad[2 * dim_from_last]
        right = pad[2 * dim_from_last + 1]

        if tensor_dim in tensor_sharding_map:
            # This tensor dim is sharded: keep the left (right) padding only
            # on the shard at the first (last) mesh coordinate of that dim.
            mesh_dim = tensor_sharding_map[tensor_dim]
            if mesh_coords[mesh_dim] != 0:
                left = 0
            if mesh_coords[mesh_dim] != mesh_sizes[mesh_dim] - 1:
                right = 0

        output_padding.append(left)
        output_padding.append(right)
        local_output_shape[tensor_dim] += left + right

    return tuple(local_output_shape), tuple(output_padding)


def generic_pad_nd_wrapper(func: callable, types: tuple, args: tuple, kwargs: dict):
    """Wrapper function for N-dimensional padding operations supporting shardtensors.

    Padding is a no-communication operation (except circular mode, which is
    unsupported): each rank pads only its own local shard, and edge padding is
    applied only on ranks that sit on the corresponding edge of the mesh.

    Args:
        func: The padding function to be wrapped (e.g. ``torch.nn.functional.pad``)
        types: tuple of input types (unused)
        args: Positional arguments to the padding function
        kwargs: Keyword arguments to the padding function

    Returns:
        The result of the padding operation, as a ShardTensor

    Raises:
        MissingShardPatch: If circular padding mode is requested.
        ValueError: If ``pad`` has odd length, or addresses more dimensions
            than the input tensor has.
    """

    inputs, pad, mode, value = repackage_pad_args(*args, **kwargs)

    # Circular padding would require halo communication between ranks; it is
    # not implemented and probably won't be until it's requested.
    if mode == "circular":
        raise MissingShardPatch(
            "Circular padding is not implemented yet.  Please open an issue at https://github.com/NVIDIA/PhysicsNemo/issues if you need this functionality."
        )

    local_input = inputs.to_local()

    # Sanity checks on the padding spec:
    if len(pad) % 2 != 0:
        raise ValueError("Sharded Padding requires len(pad) to be divisible by 2.")

    pad_dims = len(pad) // 2

    if pad_dims > len(inputs.shape):
        raise ValueError(
            f"Sharded Padding specified for {pad_dims} but tensor has only {len(inputs.shape)} dimensions."
        )

    # Map each sharded tensor dim to its mesh dim (unsharded dims are absent),
    # and record the size of every mesh dimension:
    tensor_sharding_map = {}
    mesh_sizes = []
    spec = inputs._spec

    for mesh_dim, placement in enumerate(spec.placements):
        if isinstance(placement, Shard):
            tensor_sharding_map[placement.dim] = mesh_dim
        mesh_sizes.append(spec.mesh.size(mesh_dim))

    # No dimension is sharded: a purely local computation suffices.
    if len(tensor_sharding_map) == 0:
        local_output = func(local_input, pad, mode, value)
        return ShardTensor.from_local(local_output, spec.mesh, spec.placements)

    # At least one dimension is sharded.  To rebuild the ShardTensor without
    # communication, compute the post-padding shape of *every* chunk along
    # every sharded mesh dim, and cache the padding for this rank's own chunk.
    output_shapes = {}
    self_mesh_coords = [spec.mesh.get_local_rank(m) for m in range(spec.mesh.ndim)]
    self_padding = None
    for mesh_dim, sharding_shapes in spec.sharding_shapes().items():
        output_shapes[mesh_dim] = []
        for i, local_shape in enumerate(sharding_shapes):
            # This chunk's mesh coordinates differ from ours only on mesh_dim:
            mesh_coords = list(self_mesh_coords)
            mesh_coords[mesh_dim] = i
            # BUGFIX: use each chunk's own shape (`local_shape`) rather than
            # this rank's `local_input.shape` — shard shapes can differ along
            # the sharded dimension, so the previous code computed wrong
            # output shapes for the other chunks.
            output_shape, local_padding = compute_local_padding_and_output_shape(
                local_shape, pad, mesh_coords, mesh_sizes, tensor_sharding_map
            )

            # Catch and cache the padding that applies to this rank:
            if mesh_coords == self_mesh_coords:
                self_padding = local_padding

            output_shapes[mesh_dim].append(output_shape)

    # Apply this rank's padding to its local shard:
    local_output = func(local_input, self_padding, mode, value)

    # Convert back to a ShardTensor; the output shapes are already known,
    # so no communication is needed.
    return ShardTensor.from_local(
        local_output, spec.mesh, spec.placements, sharding_shapes=output_shapes
    )


@profile
def repackage_pad_args(
    inputs: ShardTensor,
    pad: int | tuple[int, ...] = 0,
    mode: str = "constant",
    value: float | None = None,
    *args,
    **kwargs,
) -> tuple[
    ShardTensor,
    int | tuple[int, ...],
    str,
    float | None,
]:
    """Repackages pad arguments into standard format.

    Normalizes the full set of arguments that could be passed to a pad
    operation — positionally or by keyword — into the canonical
    ``(inputs, pad, mode, value)`` tuple.

    Args:
        inputs: Input tensor to pad
        pad: Padding size(s)
        mode: Padding mode
        value: Padding value (used by "constant" mode)
        *args: Additional positional args (ignored)
        **kwargs: Additional keyword args (ignored)
    Returns:
        tuple containing:
        - Input tensor
        - Padding size(s)
        - Padding mode
        - Padding value
    """

    return inputs, pad, mode, value


# Route torch.nn.functional.pad calls on ShardTensor inputs through the
# sharded padding wrapper defined above.
ShardTensor.register_function_handler(torch.nn.functional.pad, generic_pad_nd_wrapper)