# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import partial
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

import physicsnemo  # noqa: F401 for docs
import physicsnemo.models.layers.fft as fft

from ..meta import ModelMetaData
from ..module import Module

Tensor = torch.Tensor


class AFNOMlp(nn.Module):
    """Fully-connected multi-layer perceptron used inside AFNO

    Parameters
    ----------
    in_features : int
        Input feature size
    latent_features : int
        Latent feature size
    out_features : int
        Output feature size
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    drop : float, optional
        Dropout rate, by default 0.0
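
    Example
    -------
    A minimal usage sketch; the feature sizes below are arbitrary:

    >>> mlp = AFNOMlp(in_features=16, latent_features=64, out_features=16)
    >>> x = torch.randn(4, 16)  # applied over the last dimension
    >>> mlp(x).size()
    torch.Size([4, 16])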
| """ | |
| def __init__( | |
| self, | |
| in_features: int, | |
| latent_features: int, | |
| out_features: int, | |
| activation_fn: nn.Module = nn.GELU(), | |
| drop: float = 0.0, | |
| ): | |
| super().__init__() | |
| self.fc1 = nn.Linear(in_features, latent_features) | |
| self.act = activation_fn | |
| self.fc2 = nn.Linear(latent_features, out_features) | |
| self.drop = nn.Dropout(drop) | |
| def forward(self, x: Tensor) -> Tensor: | |
| x = self.fc1(x) | |
| x = self.act(x) | |
| x = self.drop(x) | |
| x = self.fc2(x) | |
| x = self.drop(x) | |
| return x | |


class AFNO2DLayer(nn.Module):
    """AFNO spectral convolution layer

    Parameters
    ----------
    hidden_size : int
        Feature dimensionality
    num_blocks : int, optional
        Number of blocks used in the block diagonal weight matrix, by default 8
    sparsity_threshold : float, optional
        Sparsity threshold (softshrink) of spectral features, by default 0.01
    hard_thresholding_fraction : float, optional
        Fraction of frequency modes to keep, in [0, 1], by default 1
    hidden_size_factor : int, optional
        Factor to increase spectral features by after weight multiplication, by default 1
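
    Example
    -------
    A minimal usage sketch; the input is channels-last and the shapes are arbitrary:

    >>> layer = AFNO2DLayer(hidden_size=16, num_blocks=2)
    >>> x = torch.randn(2, 8, 8, 16)  # (B, H, W, C) with C == hidden_size
    >>> layer(x).size()
    torch.Size([2, 8, 8, 16])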
| """ | |
| def __init__( | |
| self, | |
| hidden_size: int, | |
| num_blocks: int = 8, | |
| sparsity_threshold: float = 0.01, | |
| hard_thresholding_fraction: float = 1, | |
| hidden_size_factor: int = 1, | |
| ): | |
| super().__init__() | |
| if not (hidden_size % num_blocks == 0): | |
| raise ValueError( | |
| f"hidden_size {hidden_size} should be divisible by num_blocks {num_blocks}" | |
| ) | |
| self.hidden_size = hidden_size | |
| self.sparsity_threshold = sparsity_threshold | |
| self.num_blocks = num_blocks | |
| self.block_size = self.hidden_size // self.num_blocks | |
| self.hard_thresholding_fraction = hard_thresholding_fraction | |
| self.hidden_size_factor = hidden_size_factor | |
| self.scale = 0.02 | |
| self.w1 = nn.Parameter( | |
| self.scale | |
| * torch.randn( | |
| 2, | |
| self.num_blocks, | |
| self.block_size, | |
| self.block_size * self.hidden_size_factor, | |
| ) | |
| ) | |
| self.b1 = nn.Parameter( | |
| self.scale | |
| * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor) | |
| ) | |
| self.w2 = nn.Parameter( | |
| self.scale | |
| * torch.randn( | |
| 2, | |
| self.num_blocks, | |
| self.block_size * self.hidden_size_factor, | |
| self.block_size, | |
| ) | |
| ) | |
| self.b2 = nn.Parameter( | |
| self.scale * torch.randn(2, self.num_blocks, self.block_size) | |
| ) | |

    def forward(self, x: Tensor) -> Tensor:
        bias = x

        dtype = x.dtype
        x = x.float()
        B, H, W, C = x.shape

        # Using ONNX friendly FFT functions
        x = fft.rfft2(x, dim=(1, 2), norm="ortho")
        x_real, x_imag = fft.real(x), fft.imag(x)
        x_real = x_real.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)
        x_imag = x_imag.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)

        o1_real = torch.zeros(
            [
                B,
                H,
                W // 2 + 1,
                self.num_blocks,
                self.block_size * self.hidden_size_factor,
            ],
            device=x.device,
        )
        o1_imag = torch.zeros(
            [
                B,
                H,
                W // 2 + 1,
                self.num_blocks,
                self.block_size * self.hidden_size_factor,
            ],
            device=x.device,
        )
        o2 = torch.zeros(x_real.shape + (2,), device=x.device)

        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)
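
        # The paired einsums below implement a block-diagonal complex matrix
        # multiply on the real/imaginary parts,
        #     (a + ib)(c + id) = (ac - bd) + i(ad + bc),
        # with a ReLU applied separately to the resulting real and imaginary
        # parts. Only modes inside the hard-thresholding window are mixed;
        # all other spectral entries stay zero.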
        o1_real[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = (
            F.relu(
                torch.einsum(
                    "nyxbi,bio->nyxbo",
                    x_real[
                        :,
                        total_modes - kept_modes : total_modes + kept_modes,
                        :kept_modes,
                    ],
                    self.w1[0],
                )
                - torch.einsum(
                    "nyxbi,bio->nyxbo",
                    x_imag[
                        :,
                        total_modes - kept_modes : total_modes + kept_modes,
                        :kept_modes,
                    ],
                    self.w1[1],
                )
                + self.b1[0]
            )
        )

        o1_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = (
            F.relu(
                torch.einsum(
                    "nyxbi,bio->nyxbo",
                    x_imag[
                        :,
                        total_modes - kept_modes : total_modes + kept_modes,
                        :kept_modes,
                    ],
                    self.w1[0],
                )
                + torch.einsum(
                    "nyxbi,bio->nyxbo",
                    x_real[
                        :,
                        total_modes - kept_modes : total_modes + kept_modes,
                        :kept_modes,
                    ],
                    self.w1[1],
                )
                + self.b1[1]
            )
        )

        o2[
            :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes, ..., 0
        ] = (
            torch.einsum(
                "nyxbi,bio->nyxbo",
                o1_real[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w2[0],
            )
            - torch.einsum(
                "nyxbi,bio->nyxbo",
                o1_imag[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w2[1],
            )
            + self.b2[0]
        )

        o2[
            :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes, ..., 1
        ] = (
            torch.einsum(
                "nyxbi,bio->nyxbo",
                o1_imag[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w2[0],
            )
            + torch.einsum(
                "nyxbi,bio->nyxbo",
                o1_real[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w2[1],
            )
            + self.b2[1]
        )
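
        # Soft-shrinkage sparsifies the spectral features: entries with
        # magnitude below `sparsity_threshold` are zeroed, and the remainder
        # shrink toward zero by that amount.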
        x = F.softshrink(o2, lambd=self.sparsity_threshold)
        x = fft.view_as_complex(x)
        # TODO(akamenev): replace the following branching with
        # a one-liner, something like x.reshape(..., -1).squeeze(-1),
        # but this currently fails during ONNX export.
        if torch.onnx.is_in_onnx_export():
            x = x.reshape(B, H, W // 2 + 1, C, 2)
        else:
            x = x.reshape(B, H, W // 2 + 1, C)
        # Using ONNX friendly FFT functions
        x = fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")
        x = x.type(dtype)

        return x + bias


class Block(nn.Module):
    """AFNO block, spectral convolution and MLP

    Parameters
    ----------
    embed_dim : int
        Embedded feature dimensionality
    num_blocks : int, optional
        Number of blocks used in the block diagonal weight matrix, by default 8
    mlp_ratio : float, optional
        Ratio of MLP latent variable size to input feature size, by default 4.0
    drop : float, optional
        Dropout rate in MLP, by default 0.0
    activation_fn: nn.Module, optional
        Activation function used in MLP, by default nn.GELU
    norm_layer : nn.Module, optional
        Normalization function, by default nn.LayerNorm
    double_skip : bool, optional
        Apply a residual connection after both the filter and the MLP,
        rather than only once at the end, by default True
    sparsity_threshold : float, optional
        Sparsity threshold (softshrink) of spectral features, by default 0.01
    hard_thresholding_fraction : float, optional
        Fraction of frequency modes to keep, in [0, 1], by default 1
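
    Example
    -------
    A minimal usage sketch; the shapes are arbitrary:

    >>> block = Block(embed_dim=16, num_blocks=2)
    >>> x = torch.randn(2, 8, 8, 16)  # (B, H, W, embed_dim)
    >>> block(x).size()
    torch.Size([2, 8, 8, 16])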
| """ | |
| def __init__( | |
| self, | |
| embed_dim: int, | |
| num_blocks: int = 8, | |
| mlp_ratio: float = 4.0, | |
| drop: float = 0.0, | |
| activation_fn: nn.Module = nn.GELU(), | |
| norm_layer: nn.Module = nn.LayerNorm, | |
| double_skip: bool = True, | |
| sparsity_threshold: float = 0.01, | |
| hard_thresholding_fraction: float = 1.0, | |
| ): | |
| super().__init__() | |
| self.norm1 = norm_layer(embed_dim) | |
| self.filter = AFNO2DLayer( | |
| embed_dim, num_blocks, sparsity_threshold, hard_thresholding_fraction | |
| ) | |
| # self.drop_path = nn.Identity() | |
| self.norm2 = norm_layer(embed_dim) | |
| mlp_latent_dim = int(embed_dim * mlp_ratio) | |
| self.mlp = AFNOMlp( | |
| in_features=embed_dim, | |
| latent_features=mlp_latent_dim, | |
| out_features=embed_dim, | |
| activation_fn=activation_fn, | |
| drop=drop, | |
| ) | |
| self.double_skip = double_skip | |
| def forward(self, x: Tensor) -> Tensor: | |
| residual = x | |
| x = self.norm1(x) | |
| x = self.filter(x) | |
| if self.double_skip: | |
| x = x + residual | |
| residual = x | |
| x = self.norm2(x) | |
| x = self.mlp(x) | |
| x = x + residual | |
| return x | |


class PatchEmbed(nn.Module):
    """Patch embedding layer

    Converts 2D patch into a 1D vector for input to AFNO

    Parameters
    ----------
    inp_shape : List[int]
        Input image dimensions [height, width]
    in_channels : int
        Number of input channels
    patch_size : List[int], optional
        Size of image patches, by default [16, 16]
    embed_dim : int, optional
        Embedded channel size, by default 256
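
    Example
    -------
    A minimal usage sketch; the shapes are arbitrary:

    >>> embed = PatchEmbed(inp_shape=[32, 32], in_channels=2, patch_size=[8, 8], embed_dim=16)
    >>> x = torch.randn(4, 2, 32, 32)  # (B, C, H, W)
    >>> embed(x).size()  # (B, num_patches, embed_dim)
    torch.Size([4, 16, 16])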
| """ | |
| def __init__( | |
| self, | |
| inp_shape: List[int], | |
| in_channels: int, | |
| patch_size: List[int] = [16, 16], | |
| embed_dim: int = 256, | |
| ): | |
| super().__init__() | |
| if len(inp_shape) != 2: | |
| raise ValueError("inp_shape should be a list of length 2") | |
| if len(patch_size) != 2: | |
| raise ValueError("patch_size should be a list of length 2") | |
| num_patches = (inp_shape[1] // patch_size[1]) * (inp_shape[0] // patch_size[0]) | |
| self.inp_shape = inp_shape | |
| self.patch_size = patch_size | |
| self.num_patches = num_patches | |
| self.proj = nn.Conv2d( | |
| in_channels, embed_dim, kernel_size=patch_size, stride=patch_size | |
| ) | |
| def forward(self, x: Tensor) -> Tensor: | |
| B, C, H, W = x.shape | |
| if not (H == self.inp_shape[0] and W == self.inp_shape[1]): | |
| raise ValueError( | |
| f"Input image size ({H}*{W}) doesn't match model ({self.inp_shape[0]}*{self.inp_shape[1]})." | |
| ) | |
| x = self.proj(x).flatten(2).transpose(1, 2) | |
| return x | |


@dataclass
class MetaData(ModelMetaData):
    name: str = "AFNO"
    # Optimization
    jit: bool = False  # ONNX Ops Conflict
    cuda_graphs: bool = True
    amp: bool = True
    # Inference
    onnx_cpu: bool = False  # No FFT op on CPU
    onnx_gpu: bool = True
    onnx_runtime: bool = True
    # Physics informed
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False


class AFNO(Module):
    """Adaptive Fourier neural operator (AFNO) model.

    Note
    ----
    AFNO is designed for 2D images only.

    Parameters
    ----------
    inp_shape : List[int]
        Input image dimensions [height, width]
    in_channels : int
        Number of input channels
    out_channels: int
        Number of output channels
    patch_size : List[int], optional
        Size of image patches, by default [16, 16]
    embed_dim : int, optional
        Embedded channel size, by default 256
    depth : int, optional
        Number of AFNO layers, by default 4
    mlp_ratio : float, optional
        Ratio of layer MLP latent variable size to input feature size, by default 4.0
    drop_rate : float, optional
        Dropout rate in layer MLPs, by default 0.0
    num_blocks : int, optional
        Number of blocks in the block-diagonal frequency weight matrices, by default 16
    sparsity_threshold : float, optional
        Sparsity threshold (softshrink) of spectral features, by default 0.01
    hard_thresholding_fraction : float, optional
        Fraction of frequency modes to keep, in [0, 1], by default 1

    Example
    -------
    >>> model = physicsnemo.models.afno.AFNO(
    ...     inp_shape=[32, 32],
    ...     in_channels=2,
    ...     out_channels=1,
    ...     patch_size=(8, 8),
    ...     embed_dim=16,
    ...     depth=2,
    ...     num_blocks=2,
    ... )
    >>> input = torch.randn(32, 2, 32, 32)  # (N, C, H, W)
    >>> output = model(input)
    >>> output.size()
    torch.Size([32, 1, 32, 32])

    Note
    ----
    Reference: Guibas, John, et al. "Adaptive Fourier neural operators:
    Efficient token mixers for transformers." arXiv preprint arXiv:2111.13587 (2021).
    """

    def __init__(
        self,
        inp_shape: List[int],
        in_channels: int,
        out_channels: int,
        patch_size: List[int] = [16, 16],
        embed_dim: int = 256,
        depth: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        num_blocks: int = 16,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ) -> None:
        super().__init__(meta=MetaData())

        if len(inp_shape) != 2:
            raise ValueError("inp_shape should be a list of length 2")
        if len(patch_size) != 2:
            raise ValueError("patch_size should be a list of length 2")
        if not (
            inp_shape[0] % patch_size[0] == 0 and inp_shape[1] % patch_size[1] == 0
        ):
            raise ValueError(
                f"input shape {inp_shape} should be divisible by patch_size {patch_size}"
            )

        self.in_chans = in_channels
        self.out_chans = out_channels
        self.inp_shape = inp_shape
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            inp_shape=inp_shape,
            in_channels=self.in_chans,
            patch_size=self.patch_size,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.h = inp_shape[0] // self.patch_size[0]
        self.w = inp_shape[1] // self.patch_size[1]

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim=embed_dim,
                    num_blocks=self.num_blocks,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    norm_layer=norm_layer,
                    sparsity_threshold=sparsity_threshold,
                    hard_thresholding_fraction=hard_thresholding_fraction,
                )
                for _ in range(depth)
            ]
        )

        self.head = nn.Linear(
            embed_dim,
            self.out_chans * self.patch_size[0] * self.patch_size[1],
            bias=False,
        )

        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init model weights"""
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # Commented-out helper that would exclude these parameters from weight
    # decay (a timm/ViT convention); currently unused.
    # @torch.jit.ignore
    # def no_weight_decay(self):
    #     return {"pos_embed", "cls_token"}

    def forward_features(self, x: Tensor) -> Tensor:
        """Forward pass of core AFNO"""
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
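
        # Reshape the flat patch sequence back onto its 2D patch grid so the
        # AFNO blocks can apply 2D FFTs over the spatial patch dimensions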
        x = x.reshape(B, self.h, self.w, self.embed_dim)
        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        x = self.forward_features(x)
        x = self.head(x)

        # Correct tensor shape back into [B, C, H, W]
        # [b h w (p1 p2 c_out)]
        out = x.view(list(x.shape[:-1]) + [self.patch_size[0], self.patch_size[1], -1])
        # [b h w p1 p2 c_out]
        out = torch.permute(out, (0, 5, 1, 3, 2, 4))
        # [b c_out h p1 w p2]
        out = out.reshape(list(out.shape[:2]) + [self.inp_shape[0], self.inp_shape[1]])
        # [b c_out (h*p1) (w*p2)]
        return out