# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from physicsnemo.utils.profiling import profile
from physicsnemo.utils.version_check import check_module_requirements
check_module_requirements("physicsnemo.distributed.shard_tensor")
from torch.distributed.tensor.placement_types import ( # noqa: E402
Shard,
)
from physicsnemo.distributed import ShardTensor # noqa: E402
from physicsnemo.distributed.shard_utils.patch_core import ( # noqa: E402
MissingShardPatch,
)
def compute_local_padding_and_output_shape(
    input_tensor_shape: tuple[int, ...],
    pad: tuple[int, ...],
    mesh_coords: tuple[int, ...],
    mesh_sizes: tuple[int, ...],
    tensor_sharding_map: dict[int, int],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """
    Compute the local output shape and local padding for one shard of a padded tensor.

    Padding dims follow the ``torch.nn.functional.pad`` convention: ``pad`` is
    ordered from the LAST tensor dimension outward, as
    ``(left_N, right_N, left_N-1, right_N-1, ...)``.  For a sharded dimension,
    only the shard at the mesh edge receives the corresponding outer padding;
    interior shards receive none.

    Args:
        input_tensor_shape: The shape of this shard's local tensor
        pad: The padding size(s), torch pad ordering (last dim first)
        mesh_coords: The coordinates of this shard in the mesh
        mesh_sizes: The size of each mesh dimension
        tensor_sharding_map: A map from tensor dimension to mesh dimension,
            containing only the sharded tensor dimensions

    Returns:
        Tuple of (local output shape, local padding), where the local padding
        uses the same last-dim-first ordering as ``pad``.
    """
    tensor_rank = len(input_tensor_shape)
    pad_dims = len(pad) // 2

    output_padding = []
    local_output_shape = list(input_tensor_shape)

    # pad is ordered from the last tensor dim backwards, so walk dims in
    # that same order:
    for dim_from_last in range(pad_dims):
        tensor_dim = tensor_rank - 1 - dim_from_last
        left = pad[2 * dim_from_last]
        right = pad[2 * dim_from_last + 1]

        if tensor_dim not in tensor_sharding_map:
            # Unsharded axis: padding applies unchanged on every rank.
            output_padding.extend((left, right))
            local_output_shape[tensor_dim] += left + right
        else:
            # Sharded axis: apply the left/right padding only on the shard
            # sitting at the corresponding edge of the mesh dimension.
            mesh_dim = tensor_sharding_map[tensor_dim]
            is_left = mesh_coords[mesh_dim] == 0
            is_right = mesh_coords[mesh_dim] == mesh_sizes[mesh_dim] - 1

            local_left = left if is_left else 0
            local_right = right if is_right else 0
            output_padding.extend((local_left, local_right))
            local_output_shape[tensor_dim] += local_left + local_right

    return tuple(local_output_shape), tuple(output_padding)
def generic_pad_nd_wrapper(func: callable, types: tuple, args: tuple, kwargs: dict):
    """Wrapper function for N-dimensional padding operations supporting ShardTensors.

    Padding is a no-communication operation: each rank pads its local shard,
    and only the shards at the edge of a sharded mesh dimension apply the
    corresponding outer padding.

    Args:
        func: The padding function to be wrapped (e.g. ``torch.nn.functional.pad``)
        types: tuple of input types (unused)
        args: Positional arguments to the padding function
        kwargs: Keyword arguments to the padding function

    Returns:
        The result of the padding operation, as a ShardTensor

    Raises:
        MissingShardPatch: If circular padding is requested (unimplemented).
        ValueError: If the padding specification is malformed.
    """
    inputs, pad, mode, value = repackage_pad_args(*args, **kwargs)

    # Circular padding would require communication between edge ranks;
    # it is not implemented and probably won't be until requested.
    if mode == "circular":
        raise MissingShardPatch(
            "Circular padding is not implemented yet. Please open an issue at https://github.com/NVIDIA/PhysicsNemo/issues if you need this functionality."
        )

    local_input = inputs.to_local()

    # Sanity checks on the padding spec:
    if len(pad) % 2 != 0:
        raise ValueError("Sharded Padding requires len(pad) to be divisible by 2.")
    pad_dims = len(pad) // 2
    if pad_dims > len(inputs.shape):
        raise ValueError(
            f"Sharded Padding specified for {pad_dims} but tensor has only {len(inputs.shape)} dimensions."
        )

    # Map tensor dim -> mesh dim, but ONLY for sharded placements.
    # mesh_sizes is collected for every mesh dim, sharded or not.
    tensor_sharding_map = {}
    mesh_sizes = []
    spec = inputs._spec
    for mesh_dim, placement in enumerate(spec.placements):
        if isinstance(placement, Shard):
            tensor_sharding_map[placement.dim] = mesh_dim
        mesh_sizes.append(spec.mesh.size(mesh_dim))

    # No dimension is sharded: a purely local computation suffices.
    if len(tensor_sharding_map) == 0:
        local_output = func(local_input, pad, mode, value)
        return ShardTensor.from_local(local_output, spec.mesh, spec.placements)

    # At least one dimension is sharded.  To avoid any communication, compute
    # the padded output shape for every chunk along every sharded mesh dim,
    # and cache the padding that applies to this rank's own chunk.
    output_shapes = {}
    self_mesh_coords = [spec.mesh.get_local_rank(m) for m in range(spec.mesh.ndim)]
    self_padding = None
    for mesh_dim, sharding_shapes in spec.sharding_shapes().items():
        output_shapes[mesh_dim] = []
        for i, local_shape in enumerate(sharding_shapes):
            # Coordinates of chunk i along this mesh dim (other dims unchanged):
            mesh_coords = list(self_mesh_coords)
            mesh_coords[mesh_dim] = i
            # BUGFIX: use chunk i's own shape (local_shape), not this rank's
            # local shape — chunks may be unevenly sized across the mesh.
            output_shape, local_padding = compute_local_padding_and_output_shape(
                local_shape, pad, mesh_coords, mesh_sizes, tensor_sharding_map
            )
            # Catch and cache the padding that applies to this rank:
            if mesh_coords == self_mesh_coords:
                self_padding = local_padding
            output_shapes[mesh_dim].append(output_shape)

    # Apply this rank's local padding:
    local_output = func(local_input, self_padding, mode, value)

    # Convert back to a ShardTensor; the output shard shapes are already known,
    # so no communication is needed to infer them.
    return ShardTensor.from_local(
        local_output, spec.mesh, spec.placements, sharding_shapes=output_shapes
    )
@profile
def repackage_pad_args(
    inputs: ShardTensor,
    pad: int | tuple[int, ...] = 0,
    mode: str = "constant",
    value: float | None = None,
    *args,
    **kwargs,
) -> tuple[
    ShardTensor,
    int | tuple[int, ...],
    str,
    float | None,
]:
    """Repackages pad arguments into standard format.

    Takes the full set of arguments that could be passed to a pad operation
    (positionally or by keyword) and normalizes them into the core tuple
    (inputs, pad, mode, value).

    Args:
        inputs: Input tensor to pad
        pad: Padding size(s)
        mode: Padding mode
        value: Padding value (for "constant" mode)
        *args: Additional positional args (unused)
        **kwargs: Additional keyword args (unused)

    Returns:
        tuple containing:
            - Input tensor
            - Padding size(s)
            - Padding mode
            - Padding value
    """
    return inputs, pad, mode, value
# Route torch.nn.functional.pad calls on ShardTensor inputs through the sharded wrapper:
ShardTensor.register_function_handler(torch.nn.functional.pad, generic_pad_nd_wrapper)