# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import ( Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union, ) import torch from physicsnemo.utils.version_check import check_module_requirements check_module_requirements("physicsnemo.distributed.shard_tensor") from torch.distributed.tensor.placement_types import ( # noqa: E402 Partial, Shard, ) # noqa: E402 from physicsnemo.distributed.shard_tensor import ShardTensor # noqa: E402 aten = torch.ops.aten # Type variable for dimension parameter DimT = TypeVar("DimT", None, int, Iterable[int]) def normalize_dim( dim: DimT, tensor_ndim: int, as_set: bool = False, handle_negatives: bool = True ) -> Union[Optional[Tuple[int, ...]], Set[int]]: """ Normalize dimension argument to a consistent form. Args: dim: The dimension(s) to normalize. Can be None, int, or iterable of ints. tensor_ndim: Number of dimensions in the tensor. as_set: If True, return a set of dimensions instead of a tuple. handle_negatives: If True, convert negative dimensions to positive ones. Returns: - None if dim is None and as_set is False - A set of all dimensions if dim is None and as_set is True - A tuple of dimensions (or set if as_set is True) """ if dim is None: if as_set: return set(range(tensor_ndim)) return None # Convert to tuple if iterable if isinstance(dim, Iterable) and not isinstance(dim, torch.Tensor): dims = tuple(dim) else: dims = (dim,) # Handle negative dimensions if handle_negatives: dims = tuple(d % tensor_ndim for d in dims) # Return as set or tuple based on as_set flag if as_set: return set(dims) return dims def is_full_reduction(dim: DimT, tensor_ndim: int) -> bool: """ Determine if this is a full reduction. Args: dim: The dimension(s) to check. Can be None, int, or iterable of ints. tensor_ndim: Number of dimensions in the tensor. Returns: bool: True if all dimensions are being reduced, False otherwise. """ if dim is None: return True if isinstance(dim, Iterable) and len(dim) == tensor_ndim: return True return False def compute_result_placements( tensor: ShardTensor, dim: DimT, reduction_name: str, keepdim: bool = False ) -> List[Union[Partial, Shard]]: """ Compute placement info for reduction result. Args: tensor: The input ShardTensor being reduced. dim: The dimension(s) to reduce. Can be None, int, or iterable of ints. reduction_name: Type of reduction operation ("sum", "avg", etc.). keepdim: Whether to preserve reduced dimensions with size 1. Returns: List[Union[Partial, Shard]]: Placement specifications for the result tensor. """ if is_full_reduction(dim, tensor.ndim): return [ p if p.is_replicate() else Partial("sum" if reduction_name != "avg" else "avg") for p in tensor._spec.placements ] # Use enhanced normalize_dim to get dimensions as a set dims = normalize_dim(dim, tensor.ndim, as_set=True) placements = [] for p in tensor._spec.placements: if isinstance(p, Shard): shard_dim = p.dim # Count how many reduction dims are less than this shard dim num_lower = sum(1 for d in dims if d < shard_dim) # If this sharded dim is being reduced, it becomes Partial if shard_dim in dims: placements.append(Partial(reduction_name)) else: # If keepdim is False, dims to the left are removed, so shift left new_dim = shard_dim - num_lower if not keepdim else shard_dim placements.append(Shard(new_dim)) else: placements.append(p) return placements def reduction_shape( S: torch.Size, dim: DimT = None, keepdim: bool = False ) -> torch.Size: """ Calculate the resulting shape after a reduction operation. Args: S: Original shape of the tensor. dim: The dimension(s) to reduce. Can be None, int, or iterable of ints. keepdim: Whether to preserve reduced dimensions with size 1. Returns: torch.Size: The shape after reduction. """ shape = list(S) if dim is None: return torch.Size([1] * len(shape)) if keepdim else torch.Size([]) # Use enhanced normalize_dim to handle iterable and negative dims dim = normalize_dim(dim, len(shape), handle_negatives=True) if keepdim: for d in dim: shape[d] = 1 else: for d in sorted(dim, reverse=True): del shape[d] return torch.Size(shape) def compute_result_sharding_shapes( tensor: ShardTensor, dim: DimT, keepdim: bool ) -> Dict[int, List[torch.Size]]: """ Compute sharding sizes for the result of a reduction operation. Args: tensor: The input ShardTensor being reduced. dim: The dimension(s) to reduce. Can be None, int, or iterable of ints. keepdim: Whether to preserve reduced dimensions with size 1. Returns: Dict[int, List[torch.Size]]: Mapping of mesh dimensions to sharding shapes. """ if is_full_reduction(dim, tensor.ndim): return {} else: # Create a dictionary to store sharding sizes for dimensions that remain in the output result_sharding_shapes = {} # Get the original sharding sizes original_sharding_shapes = tensor._spec.sharding_shapes() # Use normalize_dim directly normalized_dim = normalize_dim(dim, tensor.ndim) for mesh_dim, sharding_shapes in original_sharding_shapes.items(): result_sharding_shapes[mesh_dim] = [ reduction_shape(shape, normalized_dim, keepdim) for shape in sharding_shapes ] return result_sharding_shapes def create_sharded_grad_input( local_grad_input: torch.Tensor, original_spec: Any ) -> ShardTensor: """ Create a ShardTensor from local gradient input. Args: local_grad_input: The local gradient tensor. original_spec: The original ShardTensor's spec to use for placement. Returns: ShardTensor: A distributed tensor with the same sharding as the original input. """ return ShardTensor.from_local( local_grad_input, device_mesh=original_spec.mesh, placements=original_spec.placements, sharding_shapes=original_spec.sharding_shapes(), ) # Base class for sharded reductions class ShardedReductionBase(torch.autograd.Function): """Base class for implementing custom autograd functions for sharded tensor reductions.""" @staticmethod def setup_ctx( ctx: Any, tensor: ShardTensor, dim: DimT, keepdim: bool ) -> Tuple[Optional[Tuple[int, ...]], bool]: """ Save common context information for backward pass. Args: ctx: The autograd context object. tensor: The input ShardTensor being reduced. dim: The dimension(s) to reduce. keepdim: Whether to preserve reduced dimensions with size 1. Returns: Tuple[Optional[Tuple[int, ...]], bool]: Normalized dimension and keepdim flag. """ ctx.original_spec = tensor._spec ctx.output_requires_grad = tensor.requires_grad # Normalize dim to tuple form dim = normalize_dim(dim, tensor.ndim) # Ensure keepdim is a boolean keepdim = bool(keepdim) ctx.dim = dim ctx.keepdim = keepdim ctx.is_full_reduction = is_full_reduction(dim, tensor.ndim) # Save the shape of the local tensor ctx.local_grad_shape = tensor._local_tensor.shape return dim, keepdim # Specific reduction implementations class ShardedSum(ShardedReductionBase): """ Custom autograd function for sum reduction of sharded tensors. Handles both forward and backward passes with proper gradient computation. """ @staticmethod def forward( ctx: Any, tensor: ShardTensor, dim: DimT = None, keepdim: bool = False, dtype: Optional[torch.dtype] = None, ) -> ShardTensor: """ Forward pass for sum reduction on ShardTensor. Args: ctx: The autograd context object. tensor: The input ShardTensor to be reduced. dim: The dimension(s) to reduce. keepdim: Whether to preserve reduced dimensions with size 1. dtype: Output data type (optional). Returns: ShardTensor: The result of sum reduction. """ dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) # Get local tensor local_tensor = tensor._local_tensor # Perform local sum local_result = aten.sum(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) # Compute placements for the result placements = compute_result_placements(tensor, dim, "sum") output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) # Create result ShardTensor result = ShardTensor.from_local( local_result, tensor.device_mesh, placements, sharding_shapes=output_sharding_shapes, ) return result @staticmethod def backward( ctx: Any, grad_output: ShardTensor ) -> Tuple[ShardTensor, None, None, None]: """ Backward pass for sum reduction. Args: ctx: The autograd context object. grad_output: Gradient of the loss with respect to the output. Returns: Tuple containing gradients for each input in the forward pass. """ original_spec = ctx.original_spec dim = ctx.dim is_full_reduction = ctx.is_full_reduction keepdim = ctx.keepdim local_grad_shape = ctx.local_grad_shape # Get local grad output local_grad_output = grad_output._local_tensor if is_full_reduction: # For full reduction, broadcast to original size grad_input = local_grad_output.expand(local_grad_shape) else: # For dimension-specific reduction if keepdim: # Just expand along reduced dimensions expand_shape = list(local_grad_shape) grad_input = local_grad_output.expand(expand_shape) else: # Need to unsqueeze first grad_shape = list(local_grad_output.shape) for d in sorted(dim): if d < 0: d += original_spec.tensor_meta.ndim grad_shape.insert(d, 1) grad_expanded = local_grad_output.reshape(grad_shape) expand_shape = list(local_grad_shape) grad_input = grad_expanded.expand(expand_shape) # Create ShardTensor from local grad grad_input = create_sharded_grad_input(grad_input, original_spec) # Return gradients for all inputs return grad_input, None, None, None class ShardedMean(ShardedReductionBase): """ Custom autograd function for mean reduction of sharded tensors. Handles both forward and backward passes with proper gradient computation and scaling. """ @staticmethod def forward( ctx: Any, tensor: ShardTensor, dim: DimT = None, keepdim: bool = False, dtype: Optional[torch.dtype] = None, ) -> ShardTensor: """ Forward pass for mean reduction on ShardTensor. Args: ctx: The autograd context object. tensor: The input ShardTensor to be reduced. dim: The dimension(s) to reduce. keepdim: Whether to preserve reduced dimensions with size 1. dtype: Output data type (optional). Returns: ShardTensor: The result of mean reduction. """ dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) # Get local tensor local_tensor = tensor._local_tensor # Compute proper weighting for mean weight = 1.0 # Normalize dimensions for consistent handling if is_full_reduction(dim, tensor.ndim): # For full reduction, use all dimensions reduction_dims = set(range(tensor.ndim)) else: # Only use the normalized dimensions for partial reduction reduction_dims = dim # Calculate weight based on local vs global shape ratio for reduction dimensions local_shape = local_tensor.shape global_shape = tensor.shape for d in reduction_dims: weight *= local_shape[d] / global_shape[d] # Perform local mean local_result = aten.mean(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) # Apply weighting local_result = local_result * weight placements = compute_result_placements(tensor, dim, "sum") output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) # Create result ShardTensor result = ShardTensor.from_local( local_result, tensor.device_mesh, placements, sharding_shapes=output_sharding_shapes, ) return result @staticmethod def backward( ctx: Any, grad_output: ShardTensor ) -> Tuple[ShardTensor, None, None, None]: """ Backward pass for mean reduction. Args: ctx: The autograd context object. grad_output: Gradient of the loss with respect to the output. Returns: Tuple containing gradients for each input in the forward pass. """ original_spec = ctx.original_spec dim = ctx.dim is_full_reduction = ctx.is_full_reduction keepdim = ctx.keepdim local_grad_shape = ctx.local_grad_shape global_shape = original_spec.tensor_meta.shape # Get local grad output local_grad_output = grad_output._local_tensor if is_full_reduction: # For full reduction, broadcast to original size with scaling factor = 1.0 / torch.prod(torch.tensor(global_shape)) grad_input = local_grad_output.expand(local_grad_shape) * factor else: # For dimension-specific reduction if keepdim: # Just expand along reduced dimensions expand_shape = list(local_grad_shape) grad_input = local_grad_output.expand(expand_shape) else: # Need to unsqueeze first grad_shape = list(local_grad_output.shape) for d in sorted(dim): if d < 0: d += original_spec.tensor_meta.ndim grad_shape.insert(d, 1) grad_expanded = local_grad_output.reshape(grad_shape) expand_shape = list(local_grad_shape) grad_input = grad_expanded.expand(expand_shape) # Apply scaling factor for mean factor = 1.0 for d in dim: if d < 0: d += original_spec.tensor_meta.ndim factor /= global_shape[d] grad_input = grad_input * factor # Create ShardTensor from local grad grad_input = create_sharded_grad_input(grad_input, original_spec) # Return gradients for all inputs return grad_input, None, None, None def sum_wrapper( func: Callable, types: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> ShardTensor: """ Wrapper function for ShardTensor sum reduction. In Args and kwargs: tensor: Input ShardTensor to reduce. dim: The dimension(s) to reduce. keepdim: Whether to preserve reduced dimensions with size 1. *args: Additional positional arguments. **kwargs: Additional keyword arguments. Returns: ShardTensor: Result of sum reduction. """ tensor, dim, keepdim, extra_args, extra_kwargs = unpack_args(*args, **kwargs) return ShardedSum.apply(tensor, dim, keepdim, *extra_args, **extra_kwargs) # TODO - accept func, types, args, kwargs instead def mean_wrapper( func: Callable, types: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> ShardTensor: """ Wrapper function for ShardTensor mean reduction. Args: tensor: Input ShardTensor to reduce. dim: The dimension(s) to reduce. keepdim: Whether to preserve reduced dimensions with size 1. *args: Additional positional arguments. **kwargs: Additional keyword arguments. Returns: ShardTensor: Result of mean reduction. """ tensor, dim, keepdim, extra_args, extra_kwargs = unpack_args(*args, **kwargs) return ShardedMean.apply(tensor, dim, keepdim, *extra_args, **extra_kwargs) def unpack_args( tensor: ShardTensor, dim: DimT = None, keepdim: bool = False, *args: Any, **kwargs: Any, ) -> Tuple[ShardTensor, DimT, bool, Tuple[Any, ...], Dict[str, Any]]: """ Unpack arguments for reduction functions. Maps default args from torch. Returns: tensor: Input ShardTensor to reduce. dim: The dimension(s) to reduce. """ return tensor, dim, keepdim, args, kwargs # Map the reduction ops to their handlers reduction_mapping: Dict[str, Callable] = { "sum": sum_wrapper, "avg": mean_wrapper, } # Register handlers for standalone functions and methods ShardTensor.register_function_handler(torch.mean, mean_wrapper) ShardTensor.register_function_handler(torch.Tensor.mean, mean_wrapper) ShardTensor.register_function_handler(torch.sum, sum_wrapper) ShardTensor.register_function_handler(torch.Tensor.sum, sum_wrapper)