# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from physicsnemo.utils.profiling import profile
from physicsnemo.utils.version_check import check_module_requirements

check_module_requirements("physicsnemo.distributed.shard_tensor")

from torch.distributed.tensor.placement_types import (  # noqa: E402
    Shard,
)

from physicsnemo.distributed import ShardTensor  # noqa: E402
from physicsnemo.distributed.shard_utils.patch_core import (  # noqa: E402
    MissingShardPatch,
)


def compute_local_padding_and_output_shape(
    input_tensor_shape: tuple[int, ...],
    pad: tuple[int, ...],
    mesh_coords: tuple[int, ...],
    mesh_sizes: tuple[int, ...],
    tensor_sharding_map: dict[int, int],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """
    Compute the local output shape and padding for one shard of a padded tensor.

    Padding is only applied at the *global* edges of a sharded axis: a shard
    in the interior of a sharded dimension receives no padding on the side(s)
    that touch neighboring shards.

    Args:
        input_tensor_shape: The shape of the local input tensor (one shard)
        pad: The padding size(s), in ``torch.nn.functional.pad`` order
            (pairs of ``(left, right)`` starting from the last dimension)
        mesh_coords: The coordinates of this shard in the mesh
        mesh_sizes: The size of the mesh along each mesh dimension
        tensor_sharding_map: A map from sharded tensor dimension to mesh dimension

    Returns:
        tuple of ``(local_output_shape, local_padding)``.
    """
    tensor_rank = len(input_tensor_shape)
    pad_dims = len(pad) // 2

    output_padding = []
    local_output_shape = list(input_tensor_shape)

    # F.pad lists padding pairs starting from the *last* tensor dimension,
    # so we have to loop over the dimensions backwards:
    for dim_from_last in range(pad_dims):
        tensor_dim = tensor_rank - 1 - dim_from_last
        left = pad[2 * dim_from_last]
        right = pad[2 * dim_from_last + 1]

        if tensor_dim not in tensor_sharding_map:
            # This axis of the tensor is not sharded: keep the padding as-is.
            output_padding.append(left)
            output_padding.append(right)
            local_output_shape[tensor_dim] += left + right
        else:
            # The tensor is sharded on this dim, so determine whether this
            # shard sits on a global edge.  Four cases:
            # - leftmost only, rightmost only, both (single shard), neither.
            mesh_dim = tensor_sharding_map[tensor_dim]
            is_left = mesh_coords[mesh_dim] == 0
            is_right = mesh_coords[mesh_dim] == mesh_sizes[mesh_dim] - 1

            if is_left and not is_right:
                output_padding.append(left)
                output_padding.append(0)
                local_output_shape[tensor_dim] += left
            elif not is_left and is_right:
                output_padding.append(0)
                output_padding.append(right)
                local_output_shape[tensor_dim] += right
            elif is_left and is_right:
                # Only one shard along this mesh dim: pad both sides.
                output_padding.append(left)
                output_padding.append(right)
                local_output_shape[tensor_dim] += left + right
            else:
                # Interior shard: no padding on either side.
                output_padding.append(0)
                output_padding.append(0)

    return tuple(local_output_shape), tuple(output_padding)


def generic_pad_nd_wrapper(func: callable, types: tuple, args: tuple, kwargs: dict):
    """Wrapper function for N-dimensional padding operations supporting shardtensors.

    Args:
        func: The padding function to be wrapped
        types: tuple of input types (unused)
        args: Positional arguments to the padding function
        kwargs: Keyword arguments to the padding function

    Returns:
        The result of the padding operation, as a ``ShardTensor``

    Raises:
        MissingShardPatch: If ``mode == "circular"`` (not implemented).
        ValueError: If ``pad`` has odd length or names more dims than the tensor has.
    """
    # Padding is a no-communication operation unless it's circular padding.
    # Circular padding is not implemented yet, and probably won't get implemented
    # until it's requested.
    inputs, pad, mode, value = repackage_pad_args(*args, **kwargs)

    if mode == "circular":
        raise MissingShardPatch(
            "Circular padding is not implemented yet. Please open an issue at https://github.com/NVIDIA/PhysicsNemo/issues if you need this functionality."
        )

    # Now, get the local tensor:
    local_input = inputs.to_local()

    # We have to update the padding values based on where each shard sits in
    # the mesh (interior shards along a sharded axis must not be padded on
    # sides that touch a neighbor).  Because we don't want to communicate
    # across GPUs unless it's needed, we compute the output shape for every
    # shard in the shard spec locally.

    # Sanity checks
    if len(pad) % 2 != 0:
        raise ValueError("Sharded Padding requires len(pad) to be divisible by 2.")

    pad_dims = len(pad) // 2
    if pad_dims > len(inputs.shape):
        raise ValueError(
            f"Sharded Padding specified for {pad_dims} but tensor has only {len(inputs.shape)} dimensions."
        )

    # This maps tensor dim to mesh dim but ONLY if it's sharded:
    tensor_sharding_map = {}
    spec = inputs._spec

    # Loop over the mesh spec and extract sharding vs tensor dim:
    for mesh_dim, placement in enumerate(spec.placements):
        if isinstance(placement, Shard):
            tensor_sharding_map[placement.dim] = mesh_dim

    # Record the size of *every* mesh dimension: mesh_sizes is indexed by
    # mesh_dim below, so it must not skip non-Shard placements.
    mesh_sizes = [spec.mesh.size(m) for m in range(spec.mesh.ndim)]

    # If nothing is sharded, just do a purely local computation and be done:
    if len(tensor_sharding_map) == 0:
        local_output = func(local_input, pad, mode, value)
        return ShardTensor.from_local(local_output, spec.mesh, spec.placements)

    # At this point, at least one dimension is sharded.  Maybe more.
    # Loop over the mesh sharding shapes and compute the local output
    # shape and padding for each chunk:
    output_shapes = {}
    self_mesh_coords = [spec.mesh.get_local_rank(m) for m in range(spec.mesh.ndim)]
    self_padding = None

    for mesh_dim, sharding_shapes in spec.sharding_shapes().items():
        output_shapes[mesh_dim] = []
        for i, local_shape in enumerate(sharding_shapes):
            # Update the mesh sharding coords for shard i along this mesh dim:
            mesh_coords = list(self_mesh_coords)
            mesh_coords[mesh_dim] = i

            # Use that shard's own input shape (shards can be uneven), not
            # this rank's local shape:
            output_shape, local_padding = compute_local_padding_and_output_shape(
                local_shape, pad, mesh_coords, mesh_sizes, tensor_sharding_map
            )

            # Catch and cache the padding that applies to this rank:
            if mesh_coords == self_mesh_coords:
                self_padding = local_padding

            output_shapes[mesh_dim].append(output_shape)

    # From here, apply the local padding to this rank's tensor:
    local_output = func(local_input, self_padding, mode, value)

    # Now, convert back to a shard tensor.  We already have all the output
    # shapes, so no communication is needed to infer them:
    return ShardTensor.from_local(
        local_output, spec.mesh, spec.placements, sharding_shapes=output_shapes
    )


@profile
def repackage_pad_args(
    inputs: ShardTensor,
    pad: int | tuple[int, ...] = 0,
    mode: str = "constant",
    value: float | None = None,
    *args,
    **kwargs,
) -> tuple[
    ShardTensor,
    tuple[int, ...],
    str,
    float | None,
]:
    """Repackages pad arguments into standard format.

    Takes the full set of arguments that could be passed to a pad operation
    and normalizes them into the core tuple ``(inputs, pad, mode, value)``,
    regardless of whether the caller used positional or keyword arguments.

    Args:
        inputs: Input tensor to pad
        pad: Padding size(s)
        mode: Padding mode
        value: Padding value
        *args: Additional positional args (unused)
        **kwargs: Additional keyword args (unused)

    Returns:
        tuple containing:
        - Input tensor
        - Padding size(s)
        - Padding mode
        - Padding value
    """
    return inputs, pad, mode, value


ShardTensor.register_function_handler(torch.nn.functional.pad, generic_pad_nd_wrapper)