Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| import warnings | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from physicsnemo.distributed.manager import DistributedManager | |
| from physicsnemo.distributed.mappings import ( | |
| copy_to_parallel_region, | |
| gather_from_parallel_region, | |
| reduce_from_parallel_region, | |
| scatter_to_parallel_region, | |
| ) | |
| from physicsnemo.distributed.utils import compute_split_shapes | |
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Adapted from PyTorch's ``nn.init.trunc_normal_``. Samples are drawn
    uniformly between the (rescaled) normal CDF values of the two cut-offs,
    mapped through ``erfinv`` to a truncated standard normal, then scaled by
    ``std`` and shifted by ``mean``. A final clamp keeps round-off inside
    ``[a, b]``. Everything runs under ``torch.no_grad()`` so the function can be
    applied directly to ``nn.Parameter`` objects.

    Method reference:
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor to fill in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cut-off.
        b: upper cut-off.

    Returns:
        The same ``tensor`` object, now filled.
    """
    sqrt2 = math.sqrt(2.0)

    def norm_cdf(value):
        # Standard normal cumulative distribution function.
        return (1.0 + math.erf(value / sqrt2)) / 2.0

    mean_far_outside = (mean < a - 2 * std) or (mean > b + 2 * std)
    if mean_far_outside:
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        cdf_lo = norm_cdf((a - mean) / std)
        cdf_hi = norm_cdf((b - mean) / std)
        # Uniform samples in [2*cdf_lo - 1, 2*cdf_hi - 1]; erfinv maps them onto
        # the truncated standard normal.
        tensor.uniform_(2 * cdf_lo - 1, 2 * cdf_hi - 1)
        tensor.erfinv_()
        # Rescale to the requested std, then shift to the requested mean.
        tensor.mul_(std * sqrt2).add_(mean)
        # Guard against round-off pushing values just outside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fill ``tensor`` in place with values from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. They are generated by inverse-CDF sampling followed by a
    clamp to :math:`[a, b]`, rather than by literally redrawing out-of-range
    samples. The method works best when :math:`a \leq \text{mean} \leq b`; a
    warning is emitted when the mean lies more than two standard deviations
    outside ``[a, b]``. The fill runs under ``torch.no_grad()``, so it can be
    applied directly to ``nn.Parameter`` objects.

    Args:
        tensor: an n-dimensional ``torch.Tensor`` to fill in place.
        mean: the mean of the normal distribution.
        std: the standard deviation of the normal distribution.
        a: the minimum cutoff value.
        b: the maximum cutoff value.

    Returns:
        The same ``tensor`` object, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> o = trunc_normal_(w, std=0.02)
        >>> o is w
        True
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    This is the same as the DropConnect impl for EfficientNet, etc. networks;
    however, the original name is misleading, as 'Drop Connect' is a different
    form of dropout in a separate paper.
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    Opted for changing the layer and argument names to 'drop path' rather than
    mix DropConnect as a layer name and use 'survival rate' as the argument.

    Args:
        x: input tensor. Dimension 0 is treated as the sample axis; each sample
            is kept or zeroed as a whole.
        drop_prob: probability of dropping a sample's path. ``0.0`` or ``None``
            disables dropping.
        training: dropping is applied only when ``True``.

    Returns:
        ``x`` itself when dropping is disabled. Otherwise, a tensor in which each
        sample is either zeroed or scaled by ``1 / (1 - drop_prob)``. With
        ``drop_prob == 1.0`` every sample is zeroed rather than turned into NaN
        by a division by zero.
    """
    # ``not drop_prob`` covers both 0.0 and None (DropPath's default); None used
    # to crash on ``1.0 - None`` during training.
    if not drop_prob or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One random value per sample, broadcast over all remaining dims, so this
    # works for tensors of any rank, not just 2D ConvNet activations.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob == 0.0:
        # Every path is dropped (random_tensor is all zeros); skip the division
        # by zero, which would turn the zeros into NaNs.
        return x * random_tensor
    return x.div(keep_prob) * random_tensor
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    Module wrapper around ``drop_path``. Dropping is active only while the
    module is in training mode (``self.training``).

    Args:
        drop_prob: probability of dropping a sample's path. ``None`` (the
            default) is treated as ``0.0``, i.e. the module acts as identity.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Map the ``None`` default to 0.0. Passing None through would fail on
        # ``1.0 - drop_prob`` inside ``drop_path`` while training.
        drop_prob = 0.0 if self.drop_prob is None else self.drop_prob
        return drop_path(x, drop_prob, self.training)

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"
class DistributedMLP(nn.Module):
    """Two-layer pointwise MLP with its hidden dimension split over ``"model_parallel"``.

    Both layers are 1x1 convolutions (``F.conv2d`` with ``(out, in, 1, 1)``
    weights) acting on dim 1 of the input. Each rank owns
    ``hidden_features // group_size`` hidden channels. The second layer's
    per-rank partial outputs go through ``reduce_from_parallel_region`` before
    the output bias is added.

    Args:
        in_features: number of input channels.
        hidden_features: global number of hidden channels (defaults to
            ``in_features``); must be divisible by the model-parallel size.
        out_features: number of output channels (defaults to ``in_features``).
        act_layer: activation module class, instantiated with no arguments.
        drop: dropout probability; ``0.0`` uses ``nn.Identity`` instead.
        input_is_matmul_parallel: if ``True``, the input arrives split along
            dim 1 and is gathered first.
        output_is_matmul_parallel: if ``True``, the output is scattered along
            dim 1 before being returned.

    Raises:
        ValueError: if ``hidden_features`` is not divisible by the
            model-parallel group size.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
        input_is_matmul_parallel=False,
        output_is_matmul_parallel=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.input_is_matmul_parallel = input_is_matmul_parallel
        self.output_is_matmul_parallel = output_is_matmul_parallel

        mp_size = DistributedManager().group_size("model_parallel")
        if hidden_features % mp_size != 0:
            raise ValueError(
                "Error, hidden_features needs to be divisible by matmul_parallel_size"
            )
        local_hidden = hidden_features // mp_size

        # First 1x1 projection: in_features -> this rank's slice of hidden channels.
        self.w1 = nn.Parameter(torch.ones(local_hidden, in_features, 1, 1))
        self.b1 = nn.Parameter(torch.zeros(local_hidden))
        # Second 1x1 projection: local hidden slice -> all out_features
        # (per-rank partial results, reduced in forward).
        self.w2 = nn.Parameter(torch.ones(out_features, local_hidden, 1, 1))
        self.b2 = nn.Parameter(torch.zeros(out_features))

        self.act = act_layer()
        if drop > 0.0:
            self.drop = nn.Dropout(drop)
        else:
            self.drop = nn.Identity()

        if input_is_matmul_parallel:
            # Per-rank channel counts used to reassemble the split input.
            self.gather_shapes = compute_split_shapes(in_features, mp_size)

        self._init_weights()

    def _init_weights(self):
        # Truncated-normal weights (std 0.02) and zero biases, in the order
        # w1, b1, w2, b2.
        for weight, bias in ((self.w1, self.b1), (self.w2, self.b2)):
            trunc_normal_(weight, std=0.02)
            nn.init.constant_(bias, 0.0)

    def forward(self, x):
        if self.input_is_matmul_parallel:
            # Reassemble the full channel dimension from the per-rank slices.
            x = gather_from_parallel_region(
                x, dim=1, shapes=self.gather_shapes, group="model_parallel"
            )

        # NOTE(review): presumably identity in forward with a gradient
        # all-reduce in backward (Megatron-style) — confirm in
        # physicsnemo.distributed.mappings.
        replicated = copy_to_parallel_region(x, group="model_parallel")

        hidden = F.conv2d(replicated, self.w1, bias=self.b1)
        hidden = self.drop(self.act(hidden))

        partial = F.conv2d(hidden, self.w2, bias=None)
        out = reduce_from_parallel_region(partial, group="model_parallel")
        # Bias is added after the cross-rank reduction, not inside each rank's
        # partial product.
        out = self.drop(out + self.b2.reshape(1, -1, 1, 1))

        if self.output_is_matmul_parallel:
            out = scatter_to_parallel_region(out, dim=1, group="model_parallel")
        return out
class DistributedPatchEmbed(nn.Module):
    """Patch embedding via a strided convolution, with optional model parallelism.

    A ``nn.Conv2d`` with ``kernel_size == stride == patch_size`` maps the input
    image to per-patch embeddings. The result is flattened over the spatial
    dims to shape ``(B, out_channels, num_patches_h * num_patches_w)``.

    Args:
        inp_shape: expected input ``(H, W)``; enforced in ``forward``.
        patch_size: ``(ph, pw)`` patch size (also the conv stride).
        in_chans: global number of input channels.
        embed_dim: global embedding dimension.
        input_is_matmul_parallel: if ``True``, the input arrives split along
            dim 1 and is gathered first; ``in_chans`` must then be divisible by
            the model-parallel size.
        output_is_matmul_parallel: if ``True``, each rank produces only
            ``embed_dim // group_size`` output channels; ``embed_dim`` must
            then be divisible by the model-parallel size.

    Raises:
        ValueError: on a non-divisible channel count (constructor) or an input
            whose spatial size differs from ``inp_shape`` (``forward``).
    """

    def __init__(
        self,
        inp_shape=(224, 224),
        patch_size=(16, 16),
        in_chans=3,
        embed_dim=768,
        input_is_matmul_parallel=False,
        output_is_matmul_parallel=True,
    ):
        super(DistributedPatchEmbed, self).__init__()

        # store params
        self.input_parallel = input_is_matmul_parallel
        self.output_parallel = output_is_matmul_parallel

        # get comm sizes (queried once and reused below)
        matmul_comm_size = DistributedManager().group_size("model_parallel")

        # compute parameters
        num_patches = (inp_shape[1] // patch_size[1]) * (inp_shape[0] // patch_size[0])
        self.inp_shape = (inp_shape[0], inp_shape[1])
        self.patch_size = patch_size
        self.num_patches = num_patches

        if self.input_parallel:
            if in_chans % matmul_comm_size != 0:
                raise ValueError(
                    "Error, the in_chans needs to be divisible by matmul_parallel_size"
                )
            # Per-rank channel counts used to gather the split input.
            self.in_shapes = compute_split_shapes(in_chans, matmul_comm_size)

        # get effective embedding size:
        if self.output_parallel:
            if embed_dim % matmul_comm_size != 0:
                raise ValueError(
                    "Error, the embed_dim needs to be divisible by matmul_parallel_size"
                )
            out_chans_local = embed_dim // matmul_comm_size
        else:
            out_chans_local = embed_dim

        # the weights of this layer are shared across spatial parallel ranks
        self.proj = nn.Conv2d(
            in_chans, out_chans_local, kernel_size=patch_size, stride=patch_size
        )
        # Flag the parameters so their gradients get reduced across ranks
        # (presumably handled elsewhere by checking ``is_shared_spatial``).
        self.proj.weight.is_shared_spatial = True
        self.proj.bias.is_shared_spatial = True

    def forward(self, x):
        if self.input_parallel:
            x = gather_from_parallel_region(
                x, dim=1, shapes=self.in_shapes, group="model_parallel"
            )

        if self.output_parallel:
            x = copy_to_parallel_region(x, group="model_parallel")

        # Unpacking four dims also enforces a 4D (B, C, H, W) input.
        _, _, H, W = x.shape
        if not (H == self.inp_shape[0] and W == self.inp_shape[1]):
            raise ValueError(
                f"Input image size ({H}*{W}) doesn't match model ({self.inp_shape[0]}*{self.inp_shape[1]})."
            )
        # new: B, C, H*W
        return self.proj(x).flatten(2)
def compl_mul_add_fwd(
    a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
) -> torch.Tensor:
    """Block-wise complex matmul plus bias, using real tensors with a trailing (re, im) dim.

    ``a`` has layout ``(b, k, i, x, y, 2)`` and ``b`` has layout
    ``(k, i, o, 2)``, as fixed by the einsum subscripts. The contraction runs
    over ``i`` independently for each block ``k``. ``c`` is added to the
    ``(b, k, o, x, y, 2)`` result via broadcasting.
    """
    # prod[s, t] holds (component s of a) x (component t of b); 0 = re, 1 = im.
    prod = torch.einsum("bkixys,kiot->stbkoxy", a, b)
    real_part = prod[0, 0] - prod[1, 1]
    imag_part = prod[1, 0] + prod[0, 1]
    return torch.stack([real_part, imag_part], dim=-1) + c
def compl_mul_add_fwd_c(
    a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
) -> torch.Tensor:
    """Same contraction as ``compl_mul_add_fwd``, computed with native complex tensors.

    Each input is a real tensor whose trailing dim of size 2 is viewed as
    (real, imag). The complex result is returned viewed back as real with a
    trailing dim of size 2.
    """
    a_c, b_c, c_c = (torch.view_as_complex(t) for t in (a, b, c))
    summed = torch.einsum("bkixy,kio->bkoxy", a_c, b_c) + c_c
    return torch.view_as_real(summed)
class DistributedAFNO2D(nn.Module):
    """Block-diagonal spectral mixing layer (AFNO) with blocks split over ``"model_parallel"``.

    Forward pass:
      1. Apply an orthonormal 2D real FFT over the last two dims.
      2. Pass each channel block's retained modes through a two-layer complex
         MLP: ReLU after the first layer, block-diagonal weights.
      3. Apply soft-shrinkage and transform back.
      4. Add the block input as a residual.

    The ``num_blocks`` blocks are divided across the model-parallel group, so
    each rank owns ``num_blocks // group_size`` of them.

    Args:
        hidden_size: global number of channels.
        num_blocks: number of channel blocks (block-diagonal weight groups).
        sparsity_threshold: ``lambd`` of the ``F.softshrink`` applied in the
            spectral domain.
        hard_thresholding_fraction: fraction of ``H // 2 + 1`` used to size the
            retained-mode window.
        hidden_size_factor: width multiplier of each block's hidden layer.
        input_is_matmul_parallel: if ``True``, the input already arrives split
            along dim 1; otherwise it is scattered first.
        output_is_matmul_parallel: if ``True``, the output stays split along
            dim 1; otherwise it is gathered back to ``hidden_size`` channels.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_blocks``, or
            ``num_blocks`` is not divisible by the model-parallel group size.
    """

    def __init__(
        self,
        hidden_size,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1,
        hidden_size_factor=1,
        input_is_matmul_parallel=False,
        output_is_matmul_parallel=False,
    ):
        super(DistributedAFNO2D, self).__init__()
        if not (hidden_size % num_blocks == 0):
            raise ValueError(
                f"hidden_size {hidden_size} should be divisible by num_blocks {num_blocks}"
            )

        # get comm sizes:
        matmul_comm_size = DistributedManager().group_size("model_parallel")

        self.fft_handle = torch.fft.rfft2
        self.ifft_handle = torch.fft.irfft2

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        if not (self.num_blocks % matmul_comm_size == 0):
            raise ValueError(
                "Error, num_blocks needs to be divisible by matmul_parallel_size"
            )
        self.num_blocks_local = self.num_blocks // matmul_comm_size
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        use_complex_mult = False
        self.mult_handle = (
            compl_mul_add_fwd_c if use_complex_mult else compl_mul_add_fwd
        )

        # model parallelism
        self.input_is_matmul_parallel = input_is_matmul_parallel
        self.output_is_matmul_parallel = output_is_matmul_parallel

        # These weights need to be synced across all spatial ranks. The trailing
        # dim of size 2 holds (real, imag).
        self.w1 = nn.Parameter(
            self.scale
            * torch.randn(
                self.num_blocks_local,
                self.block_size,
                self.block_size * self.hidden_size_factor,
                2,
            )
        )
        self.b1 = nn.Parameter(
            self.scale
            * torch.randn(
                self.num_blocks_local,
                self.block_size * self.hidden_size_factor,
                1,
                1,
                2,
            )
        )
        self.w2 = nn.Parameter(
            self.scale
            * torch.randn(
                self.num_blocks_local,
                self.block_size * self.hidden_size_factor,
                self.block_size,
                2,
            )
        )
        self.b2 = nn.Parameter(
            self.scale * torch.randn(self.num_blocks_local, self.block_size, 1, 1, 2)
        )

        # make sure we reduce them across rank
        self.w1.is_shared_spatial = True
        self.b1.is_shared_spatial = True
        self.w2.is_shared_spatial = True
        self.b2.is_shared_spatial = True

    def forward(self, x):
        if not self.input_is_matmul_parallel:
            # distribute data: split the channel dim across the model-parallel group
            x = scatter_to_parallel_region(x, dim=1, group="model_parallel")

        # residual branch, kept in the caller's dtype
        bias = x

        # spectral processing runs in float32 regardless of the input dtype
        dtype = x.dtype
        x = x.float()
        B, C, H, W = x.shape
        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)
        # Retained-mode window, shared by the input slice and the output write.
        # NOTE(review): both window bounds derive from H, but the second slice
        # is applied to the half-spectrum width axis (size W // 2 + 1); for
        # W > H some width modes are dropped even at fraction 1. Confirm this
        # is intended before changing it (it affects trained checkpoints).
        modes = (
            slice(None),
            slice(None),
            slice(None),
            slice(total_modes - kept_modes, total_modes + kept_modes),
            slice(None, kept_modes),
            slice(None),
        )

        x = self.fft_handle(x, (H, W), (-2, -1), "ortho")
        # split local channels into (local blocks, block_size)
        x = x.view(B, self.num_blocks_local, self.block_size, H, W // 2 + 1)
        x = torch.view_as_real(x)

        o1 = F.relu(self.mult_handle(x[modes], self.w1, self.b1))
        # Modes outside the window stay zero. zeros_like keeps o2 in x's dtype
        # instead of the global default dtype.
        o2 = torch.zeros_like(x)
        o2[modes] = self.mult_handle(o1, self.w2, self.b2)

        # finalize
        x = F.softshrink(o2, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, C, H, W // 2 + 1)
        x = self.ifft_handle(x, (H, W), (-2, -1), "ortho")
        x = x.type(dtype) + bias

        if not self.output_is_matmul_parallel:
            # The global channel count is hidden_size: the view above requires
            # C == num_blocks_local * block_size == hidden_size // group_size.
            # Using hidden_size (instead of the input's channel count, which was
            # only recorded when the input was not already split) also makes
            # input-parallel / output-gathered configurations work.
            gather_shapes = compute_split_shapes(
                self.hidden_size, DistributedManager().group_size("model_parallel")
            )
            x = gather_from_parallel_region(
                x, dim=1, shapes=gather_shapes, group="model_parallel"
            )

        return x