Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import List, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| class SpectralConv1d(nn.Module): | |
| """1D Fourier layer. It does FFT, linear transform, and Inverse FFT. | |
| Parameters | |
| ---------- | |
| in_channels : int | |
| Number of input channels | |
| out_channels : int | |
| Number of output channels | |
| modes1 : int | |
| Number of Fourier modes to multiply, at most floor(N/2) + 1 | |
| """ | |
| def __init__(self, in_channels: int, out_channels: int, modes1: int): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| self.modes1 = ( | |
| modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 | |
| ) | |
| self.scale = 1 / (in_channels * out_channels) | |
| self.weights1 = nn.Parameter( | |
| torch.empty(in_channels, out_channels, self.modes1, 2) | |
| ) | |
| self.reset_parameters() | |
| def compl_mul1d( | |
| self, | |
| input: Tensor, | |
| weights: Tensor, | |
| ) -> Tensor: | |
| """Complex multiplication | |
| Parameters | |
| ---------- | |
| input : Tensor | |
| Input tensor | |
| weights : Tensor | |
| Weights tensor | |
| Returns | |
| ------- | |
| Tensor | |
| Product of complex multiplication | |
| """ | |
| # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) | |
| cweights = torch.view_as_complex(weights) | |
| return torch.einsum("bix,iox->box", input, cweights) | |
| def forward(self, x: Tensor) -> Tensor: | |
| bsize = x.shape[0] | |
| # Compute Fourier coeffcients up to factor of e^(- something constant) | |
| x_ft = torch.fft.rfft(x) | |
| # Multiply relevant Fourier modes | |
| out_ft = torch.zeros( | |
| bsize, | |
| self.out_channels, | |
| x.size(-1) // 2 + 1, | |
| device=x.device, | |
| dtype=torch.cfloat, | |
| ) | |
| out_ft[:, :, : self.modes1] = self.compl_mul1d( | |
| x_ft[:, :, : self.modes1], | |
| self.weights1, | |
| ) | |
| # Return to physical space | |
| x = torch.fft.irfft(out_ft, n=x.size(-1)) | |
| return x | |
| def reset_parameters(self): | |
| """Reset spectral weights with distribution scale*U(0,1)""" | |
| self.weights1.data = self.scale * torch.rand(self.weights1.data.shape) | |
| class SpectralConv2d(nn.Module): | |
| """2D Fourier layer. It does FFT, linear transform, and Inverse FFT. | |
| Parameters | |
| ---------- | |
| in_channels : int | |
| Number of input channels | |
| out_channels : int | |
| Number of output channels | |
| modes1 : int | |
| Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1 | |
| modes2 : int | |
| Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1 | |
| """ | |
| def __init__(self, in_channels: int, out_channels: int, modes1: int, modes2: int): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| self.modes1 = ( | |
| modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 | |
| ) | |
| self.modes2 = modes2 | |
| self.scale = 1 / (in_channels * out_channels) | |
| self.weights1 = nn.Parameter( | |
| torch.empty(in_channels, out_channels, self.modes1, self.modes2, 2) | |
| ) | |
| self.weights2 = nn.Parameter( | |
| torch.empty(in_channels, out_channels, self.modes1, self.modes2, 2) | |
| ) | |
| self.reset_parameters() | |
| def compl_mul2d(self, input: Tensor, weights: Tensor) -> Tensor: | |
| """Complex multiplication | |
| Parameters | |
| ---------- | |
| input : Tensor | |
| Input tensor | |
| weights : Tensor | |
| Weights tensor | |
| Returns | |
| ------- | |
| Tensor | |
| Product of complex multiplication | |
| """ | |
| # (batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y) | |
| cweights = torch.view_as_complex(weights) | |
| return torch.einsum("bixy,ioxy->boxy", input, cweights) | |
| def forward(self, x: Tensor) -> Tensor: | |
| batchsize = x.shape[0] | |
| # Compute Fourier coeffcients up to factor of e^(- something constant) | |
| x_ft = torch.fft.rfft2(x) | |
| # Multiply relevant Fourier modes | |
| out_ft = torch.zeros( | |
| batchsize, | |
| self.out_channels, | |
| x.size(-2), | |
| x.size(-1) // 2 + 1, | |
| dtype=torch.cfloat, | |
| device=x.device, | |
| ) | |
| out_ft[:, :, : self.modes1, : self.modes2] = self.compl_mul2d( | |
| x_ft[:, :, : self.modes1, : self.modes2], | |
| self.weights1, | |
| ) | |
| out_ft[:, :, -self.modes1 :, : self.modes2] = self.compl_mul2d( | |
| x_ft[:, :, -self.modes1 :, : self.modes2], | |
| self.weights2, | |
| ) | |
| # Return to physical space | |
| x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) | |
| return x | |
| def reset_parameters(self): | |
| """Reset spectral weights with distribution scale*U(0,1)""" | |
| self.weights1.data = self.scale * torch.rand(self.weights1.data.shape) | |
| self.weights2.data = self.scale * torch.rand(self.weights2.data.shape) | |
| class SpectralConv3d(nn.Module): | |
| """3D Fourier layer. It does FFT, linear transform, and Inverse FFT. | |
| Parameters | |
| ---------- | |
| in_channels : int | |
| Number of input channels | |
| out_channels : int | |
| Number of output channels | |
| modes1 : int | |
| Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1 | |
| modes2 : int | |
| Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1 | |
| modes3 : int | |
| Number of Fourier modes to multiply in third dimension, at most floor(N/2) + 1 | |
| """ | |
| def __init__( | |
| self, in_channels: int, out_channels: int, modes1: int, modes2: int, modes3: int | |
| ): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| self.modes1 = ( | |
| modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 | |
| ) | |
| self.modes2 = modes2 | |
| self.modes3 = modes3 | |
| self.scale = 1 / (in_channels * out_channels) | |
| self.weights1 = nn.Parameter( | |
| torch.empty( | |
| in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2 | |
| ) | |
| ) | |
| self.weights2 = nn.Parameter( | |
| torch.empty( | |
| in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2 | |
| ) | |
| ) | |
| self.weights3 = nn.Parameter( | |
| torch.empty( | |
| in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2 | |
| ) | |
| ) | |
| self.weights4 = nn.Parameter( | |
| torch.empty( | |
| in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2 | |
| ) | |
| ) | |
| self.reset_parameters() | |
| def compl_mul3d( | |
| self, | |
| input: Tensor, | |
| weights: Tensor, | |
| ) -> Tensor: | |
| """Complex multiplication | |
| Parameters | |
| ---------- | |
| input : Tensor | |
| Input tensor | |
| weights : Tensor | |
| Weights tensor | |
| Returns | |
| ------- | |
| Tensor | |
| Product of complex multiplication | |
| """ | |
| # (batch, in_channel, x, y, z), (in_channel, out_channel, x, y, z) -> (batch, out_channel, x, y, z) | |
| cweights = torch.view_as_complex(weights) | |
| return torch.einsum("bixyz,ioxyz->boxyz", input, cweights) | |
| def forward(self, x: Tensor) -> Tensor: | |
| batchsize = x.shape[0] | |
| # Compute Fourier coeffcients up to factor of e^(- something constant) | |
| x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1]) | |
| # Multiply relevant Fourier modes | |
| out_ft = torch.zeros( | |
| batchsize, | |
| self.out_channels, | |
| x.size(-3), | |
| x.size(-2), | |
| x.size(-1) // 2 + 1, | |
| dtype=torch.cfloat, | |
| device=x.device, | |
| ) | |
| out_ft[:, :, : self.modes1, : self.modes2, : self.modes3] = self.compl_mul3d( | |
| x_ft[:, :, : self.modes1, : self.modes2, : self.modes3], self.weights1 | |
| ) | |
| out_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3] = self.compl_mul3d( | |
| x_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3], self.weights2 | |
| ) | |
| out_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3] = self.compl_mul3d( | |
| x_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3], self.weights3 | |
| ) | |
| out_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3] = self.compl_mul3d( | |
| x_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3], self.weights4 | |
| ) | |
| # Return to physical space | |
| x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1))) | |
| return x | |
| def reset_parameters(self): | |
| """Reset spectral weights with distribution scale*U(0,1)""" | |
| self.weights1.data = self.scale * torch.rand(self.weights1.data.shape) | |
| self.weights2.data = self.scale * torch.rand(self.weights2.data.shape) | |
| self.weights3.data = self.scale * torch.rand(self.weights3.data.shape) | |
| self.weights4.data = self.scale * torch.rand(self.weights4.data.shape) | |
| class SpectralConv4d(nn.Module): | |
| """4D Fourier layer. It does FFT, linear transform, and Inverse FFT. | |
| Parameters | |
| ---------- | |
| in_channels : int | |
| Number of input channels | |
| out_channels : int | |
| Number of output channels | |
| modes1 : int | |
| Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1 | |
| modes2 : int | |
| Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1 | |
| modes3 : int | |
| Number of Fourier modes to multiply in third dimension, at most floor(N/2) + 1 | |
| """ | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| modes1: int, | |
| modes2: int, | |
| modes3: int, | |
| modes4: int, | |
| ): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| # Number of Fourier modes to multiply, at most floor(N/2) + 1 | |
| self.modes1 = modes1 | |
| self.modes2 = modes2 | |
| self.modes3 = modes3 | |
| self.modes4 = modes4 | |
| self.scale = 1 / (in_channels * out_channels) | |
| self.weights1 = nn.Parameter( | |
| torch.empty( | |
| in_channels, | |
| out_channels, | |
| self.modes1, | |
| self.modes2, | |
| self.modes3, | |
| self.modes4, | |
| 2, | |
| ) | |
| ) | |
| self.weights2 = nn.Parameter( | |
| torch.empty( | |
| in_channels, | |
| out_channels, | |
| self.modes1, | |
| self.modes2, | |
| self.modes3, | |
| self.modes4, | |
| 2, | |
| ) | |
| ) | |
| self.weights3 = nn.Parameter( | |
| torch.empty( | |
| in_channels, | |
| out_channels, | |
| self.modes1, | |
| self.modes2, | |
| self.modes3, | |
| self.modes4, | |
| 2, | |
| ) | |
| ) | |
| self.weights4 = nn.Parameter( | |
| torch.empty( | |
| in_channels, | |
| out_channels, | |
| self.modes1, | |
| self.modes2, | |
| self.modes3, | |
| self.modes4, | |
| 2, | |
| ) | |
| ) | |
| self.weights5 = nn.Parameter( | |
| torch.empty( | |
| in_channels, | |
| out_channels, | |
| self.modes1, | |
| self.modes2, | |
| self.modes3, | |
| self.modes4, | |
| 2, | |
| ) | |
| ) | |
| self.weights6 = nn.Parameter( | |
| torch.empty( | |
| in_channels, | |
| out_channels, | |
| self.modes1, | |
| self.modes2, | |
| self.modes3, | |
| self.modes4, | |
| 2, | |
| ) | |
| ) | |
| self.weights7 = nn.Parameter( | |
| torch.empty( | |
| in_channels, | |
| out_channels, | |
| self.modes1, | |
| self.modes2, | |
| self.modes3, | |
| self.modes4, | |
| 2, | |
| ) | |
| ) | |
| self.weights8 = nn.Parameter( | |
| torch.empty( | |
| in_channels, | |
| out_channels, | |
| self.modes1, | |
| self.modes2, | |
| self.modes3, | |
| self.modes4, | |
| 2, | |
| ) | |
| ) | |
| self.reset_parameters() | |
| def compl_mul4d( | |
| self, | |
| input: Tensor, | |
| weights: Tensor, | |
| ) -> Tensor: | |
| """Complex multiplication | |
| Parameters | |
| ---------- | |
| input : Tensor | |
| Input tensor | |
| weights : Tensor | |
| Weights tensor | |
| Returns | |
| ------- | |
| Tensor | |
| Product of complex multiplication | |
| """ | |
| # (batch, in_channel, x, y, z), (in_channel, out_channel, x, y, z) -> (batch, out_channel, x, y, z) | |
| cweights = torch.view_as_complex(weights) | |
| return torch.einsum("bixyzt,ioxyzt->boxyzt", input, cweights) | |
| def forward(self, x: Tensor) -> Tensor: | |
| batchsize = x.shape[0] | |
| # Compute Fourier coeffcients up to factor of e^(- something constant) | |
| x_ft = torch.fft.rfftn(x, dim=[-4, -3, -2, -1]) | |
| # Multiply relevant Fourier modes | |
| out_ft = torch.zeros( | |
| batchsize, | |
| self.out_channels, | |
| x.size(-4), | |
| x.size(-3), | |
| x.size(-2), | |
| x.size(-1) // 2 + 1, | |
| dtype=torch.cfloat, | |
| device=x.device, | |
| ) | |
| # print(f'mod: size x: {x_ft.size()}, out: {out_ft.size()}') | |
| # print(f'mod: x_ft[weight4]: {x_ft[:, :, self.modes1 :, self.modes2 :, : -self.modes3, :self.modes4].size()} weight4: {self.weights4.size()}') | |
| out_ft[:, :, : self.modes1, : self.modes2, : self.modes3, : self.modes4] = ( | |
| self.compl_mul4d( | |
| x_ft[:, :, : self.modes1, : self.modes2, : self.modes3, : self.modes4], | |
| self.weights1, | |
| ) | |
| ) | |
| out_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3, : self.modes4] = ( | |
| self.compl_mul4d( | |
| x_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3, : self.modes4], | |
| self.weights2, | |
| ) | |
| ) | |
| out_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3, : self.modes4] = ( | |
| self.compl_mul4d( | |
| x_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3, : self.modes4], | |
| self.weights3, | |
| ) | |
| ) | |
| out_ft[:, :, : self.modes1, : self.modes2, -self.modes3 :, : self.modes4] = ( | |
| self.compl_mul4d( | |
| x_ft[:, :, : self.modes1, : self.modes2, -self.modes3 :, : self.modes4], | |
| self.weights4, | |
| ) | |
| ) | |
| out_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3, : self.modes4] = ( | |
| self.compl_mul4d( | |
| x_ft[ | |
| :, :, -self.modes1 :, -self.modes2 :, : self.modes3, : self.modes4 | |
| ], | |
| self.weights5, | |
| ) | |
| ) | |
| out_ft[:, :, -self.modes1 :, : self.modes2, -self.modes3 :, : self.modes4] = ( | |
| self.compl_mul4d( | |
| x_ft[ | |
| :, :, -self.modes1 :, : self.modes2, -self.modes3 :, : self.modes4 | |
| ], | |
| self.weights6, | |
| ) | |
| ) | |
| out_ft[:, :, : self.modes1, -self.modes2 :, -self.modes3 :, : self.modes4] = ( | |
| self.compl_mul4d( | |
| x_ft[ | |
| :, :, : self.modes1, -self.modes2 :, -self.modes3 :, : self.modes4 | |
| ], | |
| self.weights7, | |
| ) | |
| ) | |
| out_ft[:, :, -self.modes1 :, -self.modes2 :, -self.modes3 :, : self.modes4] = ( | |
| self.compl_mul4d( | |
| x_ft[ | |
| :, :, -self.modes1 :, -self.modes2 :, -self.modes3 :, : self.modes4 | |
| ], | |
| self.weights8, | |
| ) | |
| ) | |
| # Return to physical space | |
| x = torch.fft.irfftn(out_ft, s=(x.size(-4), x.size(-3), x.size(-2), x.size(-1))) | |
| return x | |
| def reset_parameters(self): | |
| """Reset spectral weights with distribution scale*U(0,1)""" | |
| self.weights1.data = self.scale * torch.rand(self.weights1.data.shape) | |
| self.weights2.data = self.scale * torch.rand(self.weights2.data.shape) | |
| self.weights3.data = self.scale * torch.rand(self.weights3.data.shape) | |
| self.weights4.data = self.scale * torch.rand(self.weights4.data.shape) | |
| self.weights5.data = self.scale * torch.rand(self.weights5.data.shape) | |
| self.weights6.data = self.scale * torch.rand(self.weights6.data.shape) | |
| self.weights7.data = self.scale * torch.rand(self.weights7.data.shape) | |
| self.weights8.data = self.scale * torch.rand(self.weights8.data.shape) | |
| # ========================================== | |
| # Utils for PINO exact gradients | |
| # ========================================== | |
| def fourier_derivatives(x: Tensor, ell: List[float]) -> Tuple[Tensor, Tensor]: | |
| """ | |
| Fourier derivative function for PINO | |
| """ | |
| # check that input shape maches domain length | |
| if len(x.shape) - 2 != len(ell): | |
| raise ValueError("input shape doesn't match domain dims") | |
| # set pi from numpy | |
| pi = float(np.pi) | |
| # get needed dims | |
| n = x.shape[2:] | |
| dim = len(ell) | |
| # get device | |
| device = x.device | |
| # compute fourier transform | |
| x_h = torch.fft.fftn(x, dim=list(range(2, dim + 2))) | |
| # make wavenumbers | |
| k_x = [] | |
| for i, nx in enumerate(n): | |
| k_x.append( | |
| torch.cat( | |
| ( | |
| torch.arange(start=0, end=nx // 2, step=1, device=device), | |
| torch.arange(start=-nx // 2, end=0, step=1, device=device), | |
| ), | |
| 0, | |
| ).reshape((i + 2) * [1] + [nx] + (dim - i - 1) * [1]) | |
| ) | |
| # compute laplacian in fourier space | |
| j = torch.complex( | |
| torch.tensor([0.0], device=device), torch.tensor([1.0], device=device) | |
| ) # Cuda graphs does not work here | |
| wx_h = [j * k_x_i * x_h * (2 * pi / ell[i]) for i, k_x_i in enumerate(k_x)] | |
| wxx_h = [ | |
| j * k_x_i * wx_h_i * (2 * pi / ell[i]) | |
| for i, (wx_h_i, k_x_i) in enumerate(zip(wx_h, k_x)) | |
| ] | |
| # inverse fourier transform out | |
| wx = torch.cat( | |
| [torch.fft.ifftn(wx_h_i, dim=list(range(2, dim + 2))).real for wx_h_i in wx_h], | |
| dim=1, | |
| ) | |
| wxx = torch.cat( | |
| [ | |
| torch.fft.ifftn(wxx_h_i, dim=list(range(2, dim + 2))).real | |
| for wxx_h_i in wxx_h | |
| ], | |
| dim=1, | |
| ) | |
| return (wx, wxx) | |
| def calc_latent_derivatives( | |
| x: Tensor, domain_length: List[int] = 2 | |
| ) -> Tuple[List[Tensor], List[Tensor]]: | |
| """ | |
| Compute first and second order derivatives of latent variables | |
| """ | |
| dim = len(x.shape) - 2 | |
| # Compute derivatives of latent variables via fourier methods | |
| # Padd domain by factor of 2 for non-periodic domains | |
| padd = [(i - 1) // 2 for i in list(x.shape[2:])] | |
| # Scale domain length by padding amount | |
| domain_length = [ | |
| domain_length[i] * (2 * padd[i] + x.shape[i + 2]) / x.shape[i + 2] | |
| for i in range(dim) | |
| ] | |
| padding = padd + padd | |
| x_p = F.pad(x, padding, mode="replicate") | |
| dx, ddx = fourier_derivatives(x_p, domain_length) | |
| # Trim padded domain | |
| if len(x.shape) == 3: | |
| dx = dx[..., padd[0] : -padd[0]] | |
| ddx = ddx[..., padd[0] : -padd[0]] | |
| dx_list = torch.split(dx, x.shape[1], dim=1) | |
| ddx_list = torch.split(ddx, x.shape[1], dim=1) | |
| elif len(x.shape) == 4: | |
| dx = dx[..., padd[0] : -padd[0], padd[1] : -padd[1]] | |
| ddx = ddx[..., padd[0] : -padd[0], padd[1] : -padd[1]] | |
| dx_list = torch.split(dx, x.shape[1], dim=1) | |
| ddx_list = torch.split(ddx, x.shape[1], dim=1) | |
| else: | |
| dx = dx[..., padd[0] : -padd[0], padd[1] : -padd[1], padd[2] : -padd[2]] | |
| ddx = ddx[..., padd[0] : -padd[0], padd[1] : -padd[1], padd[2] : -padd[2]] | |
| dx_list = torch.split(dx, x.shape[1], dim=1) | |
| ddx_list = torch.split(ddx, x.shape[1], dim=1) | |
| return dx_list, ddx_list | |
| def first_order_pino_grads( | |
| u: Tensor, | |
| ux: List[Tensor], | |
| weights_1: Tensor, | |
| weights_2: Tensor, | |
| bias_1: Tensor, | |
| ) -> Tuple[Tensor]: # pragma: no cover | |
| """ | |
| Compute first order derivatives of output variables | |
| """ | |
| # dim for einsum | |
| dim = len(u.shape) - 2 | |
| dim_str = "xyz"[:dim] | |
| # compute first order derivatives of input | |
| # compute first layer | |
| if dim == 1: | |
| u_hidden = F.conv1d(u, weights_1, bias_1) | |
| elif dim == 2: | |
| weights_1 = weights_1.unsqueeze(-1) | |
| weights_2 = weights_2.unsqueeze(-1) | |
| u_hidden = F.conv2d(u, weights_1, bias_1) | |
| elif dim == 3: | |
| weights_1 = weights_1.unsqueeze(-1).unsqueeze(-1) | |
| weights_2 = weights_2.unsqueeze(-1).unsqueeze(-1) | |
| u_hidden = F.conv3d(u, weights_1, bias_1) | |
| # compute derivative hidden layer | |
| diff_tanh = 1 / torch.cosh(u_hidden) ** 2 | |
| # compute diff(f(g)) | |
| diff_fg = torch.einsum( | |
| "mi" + dim_str + ",bm" + dim_str + ",km" + dim_str + "->bi" + dim_str, | |
| weights_1, | |
| diff_tanh, | |
| weights_2, | |
| ) | |
| # compute diff(f(g)) * diff(g) | |
| vx = [ | |
| torch.einsum("bi" + dim_str + ",bi" + dim_str + "->b" + dim_str, diff_fg, w) | |
| for w in ux | |
| ] | |
| vx = [torch.unsqueeze(w, dim=1) for w in vx] | |
| return vx | |
| def second_order_pino_grads( | |
| u: Tensor, | |
| ux: Tensor, | |
| uxx: Tensor, | |
| weights_1: Tensor, | |
| weights_2: Tensor, | |
| bias_1: Tensor, | |
| ) -> Tuple[Tensor]: # pragma: no cover | |
| """ | |
| Compute second order derivatives of output variables | |
| """ | |
| # dim for einsum | |
| dim = len(u.shape) - 2 | |
| dim_str = "xyz"[:dim] | |
| # compute first order derivatives of input | |
| # compute first layer | |
| if dim == 1: | |
| u_hidden = F.conv1d(u, weights_1, bias_1) | |
| elif dim == 2: | |
| weights_1 = weights_1.unsqueeze(-1) | |
| weights_2 = weights_2.unsqueeze(-1) | |
| u_hidden = F.conv2d(u, weights_1, bias_1) | |
| elif dim == 3: | |
| weights_1 = weights_1.unsqueeze(-1).unsqueeze(-1) | |
| weights_2 = weights_2.unsqueeze(-1).unsqueeze(-1) | |
| u_hidden = F.conv3d(u, weights_1, bias_1) | |
| # compute derivative hidden layer | |
| diff_tanh = 1 / torch.cosh(u_hidden) ** 2 | |
| # compute diff(f(g)) | |
| diff_fg = torch.einsum( | |
| "mi" + dim_str + ",bm" + dim_str + ",km" + dim_str + "->bi" + dim_str, | |
| weights_1, | |
| diff_tanh, | |
| weights_2, | |
| ) | |
| # compute diagonal of hessian | |
| # double derivative of hidden layer | |
| diff_diff_tanh = -2 * diff_tanh * torch.tanh(u_hidden) | |
| # compute diff(g) * hessian(f) * diff(g) | |
| vxx1 = [ | |
| torch.einsum( | |
| "bi" | |
| + dim_str | |
| + ",mi" | |
| + dim_str | |
| + ",bm" | |
| + dim_str | |
| + ",mj" | |
| + dim_str | |
| + ",bj" | |
| + dim_str | |
| + "->b" | |
| + dim_str, | |
| w, | |
| weights_1, | |
| weights_2 * diff_diff_tanh, | |
| weights_1, | |
| w, | |
| ) | |
| for w in ux | |
| ] # (b,x,y,t) | |
| # compute diff(f) * hessian(g) | |
| vxx2 = [ | |
| torch.einsum("bi" + dim_str + ",bi" + dim_str + "->b" + dim_str, diff_fg, w) | |
| for w in uxx | |
| ] | |
| vxx = [torch.unsqueeze(a + b, dim=1) for a, b in zip(vxx1, vxx2)] | |
| return vxx | |