import torch
from torch import nn
import torch.nn.functional as F
import torch.fft
import numpy as np
import torch.optim as optimizer
from functools import partial
from collections import OrderedDict
from torch.utils.checkpoint import checkpoint_sequential

try:
    from timm.models.layers import DropPath, to_2tuple, trunc_normal_
except ImportError:
    # Lightweight fallbacks so the module stays importable without timm.
    # Behavior matches timm for the features actually used in this file.

    def to_2tuple(x):
        """Return *x* as a 2-tuple (timm fallback)."""
        return x if isinstance(x, tuple) else (x, x)

    def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
        """Approximate truncated-normal init (timm fallback; unused here)."""
        with torch.no_grad():
            tensor.normal_(mean, std).clamp_(a, b)
        return tensor

    class DropPath(nn.Module):
        """Per-sample stochastic depth (timm fallback).

        Identity in eval mode or when ``drop_prob == 0``; otherwise drops
        whole samples and rescales survivors by ``1 / keep_prob``.
        """

        def __init__(self, drop_prob=0.0):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0.0 or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = x.new_empty(shape).bernoulli_(keep_prob)
            return x * mask / keep_prob


class BasicConv2d(nn.Module):
    """Conv2d / ConvTranspose2d with optional GroupNorm + LeakyReLU.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size, stride, padding: forwarded to the convolution.
        transpose: use ``ConvTranspose2d`` (with ``output_padding=stride//2``)
            instead of ``Conv2d``.
        act_norm: if True, apply GroupNorm(2 groups) then LeakyReLU(0.2)
            after the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, transpose=False, act_norm=False):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if not transpose:
            self.conv = nn.Conv2d(in_channels, out_channels,
                                  kernel_size=kernel_size, stride=stride,
                                  padding=padding)
        else:
            # output_padding=stride//2 makes stride-2 upsampling exactly
            # double the spatial size for kernel_size=3, padding=1.
            self.conv = nn.ConvTranspose2d(in_channels, out_channels,
                                           kernel_size=kernel_size,
                                           stride=stride, padding=padding,
                                           output_padding=stride // 2)
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.act(self.norm(y))
        return y


class ConvSC(nn.Module):
    """3x3 BasicConv2d wrapper used for encoder/decoder stages.

    ``transpose`` is forced off when ``stride == 1`` (no resampling needed).
    """

    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        if stride == 1:
            transpose = False
        self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride,
                                padding=1, transpose=transpose,
                                act_norm=act_norm)

    def forward(self, x):
        y = self.conv(x)
        return y


class GroupConv2d(nn.Module):
    """Grouped Conv2d with optional GroupNorm + LeakyReLU.

    Falls back to ``groups=1`` when ``in_channels`` is not divisible by
    ``groups``. NOTE(review): ``out_channels`` must also be divisible by the
    (possibly adjusted) ``groups`` for both the conv and the GroupNorm —
    callers are responsible for that.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, groups, act_norm=False):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        if in_channels % groups != 0:
            groups = 1
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, groups=groups)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.activate(self.norm(y))
        return y


class Inception(nn.Module):
    """1x1 bottleneck followed by parallel grouped convs of several kernel
    sizes; branch outputs are summed (not concatenated).
    """

    def __init__(self, C_in, C_hid, C_out, incep_ker=(3, 5, 7, 11), groups=8):
        # Tuple default instead of a mutable list default; it is only
        # iterated, so list arguments from existing callers still work.
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1,
                               padding=0)
        layers = []
        for ker in incep_ker:
            layers.append(GroupConv2d(C_hid, C_out, kernel_size=ker,
                                      stride=1, padding=ker // 2,
                                      groups=groups, act_norm=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        y = 0
        for layer in self.layers:
            y += layer(x)
        return y


class Mlp(nn.Module):
    """Standard transformer MLP: fc1 -> act -> drop -> fc2 -> drop."""

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super(Mlp, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # Retained for attribute compatibility; parameter-free, so the
        # state-dict is unchanged. No longer used in forward (see below).
        self.fc3 = nn.AdaptiveAvgPool1d(out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        # BUG FIX: the original forward never called ``fc2`` (its parameters
        # were dead weight) and instead applied the parameter-free pooling
        # ``fc3``. Apply the learned output projection as in a standard MLP.
        x = self.fc2(x)
        x = self.drop(x)
        return x


class AdativeFourierNeuralOperator(nn.Module):
    """Adaptive Fourier Neural Operator token mixer.

    Tokens (B, N, C) are reshaped to an h x w grid (requires N == h * w —
    callers must guarantee this), mixed in the 2-D Fourier domain by a
    two-layer block-diagonal complex MLP, and transformed back. A 1x1 Conv1d
    over channels provides an optional residual bias term.

    (Class name spelling kept as-is for backward compatibility.)
    """

    def __init__(self, dim, h=16, w=16, is_fno_bias=True):
        super(AdativeFourierNeuralOperator, self).__init__()
        self.hidden_size = dim
        self.h = h
        self.w = w
        self.num_blocks = 2
        self.block_size = self.hidden_size // self.num_blocks
        assert self.hidden_size % self.num_blocks == 0
        self.scale = 0.02
        # Leading dimension 2 holds the real/imaginary parts of the weights.
        self.w1 = torch.nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size,
                                     self.block_size))
        self.b1 = torch.nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size))
        self.w2 = torch.nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size,
                                     self.block_size))
        self.b2 = torch.nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size))
        self.relu = nn.ReLU()
        self.is_fno_bias = is_fno_bias
        if self.is_fno_bias:
            self.bias = nn.Conv1d(self.hidden_size, self.hidden_size, 1)
        else:
            self.bias = None
        # 0.0 disables the soft-shrink sparsification below.
        self.softshrink = 0.00

    def multiply(self, input, weights):
        # Block-diagonal matmul: (..., block, d) x (block, d, k) -> (..., block, k)
        return torch.einsum('...bd, bdk->...bk', input, weights)

    def forward(self, x):
        B, N, C = x.shape
        if self.bias is not None:
            bias = self.bias(x.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            bias = torch.zeros(x.shape, device=x.device)

        x = x.reshape(B, self.h, self.w, C)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        x = x.reshape(B, x.shape[1], x.shape[2], self.num_blocks,
                      self.block_size)

        # Layer 1: complex block-diagonal linear + ReLU (on real/imag parts).
        x_real = F.relu(self.multiply(x.real, self.w1[0])
                        - self.multiply(x.imag, self.w1[1]) + self.b1[0],
                        inplace=True)
        x_imag = F.relu(self.multiply(x.real, self.w1[1])
                        + self.multiply(x.imag, self.w1[0]) + self.b1[1],
                        inplace=True)
        # Layer 2. BUG FIX: both outputs must be computed from the layer-1
        # activations; the original overwrote ``x_real`` first and then used
        # the *new* value when computing the imaginary part.
        o_real = (self.multiply(x_real, self.w2[0])
                  - self.multiply(x_imag, self.w2[1]) + self.b2[0])
        o_imag = (self.multiply(x_real, self.w2[1])
                  + self.multiply(x_imag, self.w2[0]) + self.b2[1])

        x = torch.stack([o_real, o_imag], dim=-1)
        x = F.softshrink(x, lambd=self.softshrink) if self.softshrink else x
        x = torch.view_as_complex(x)
        x = x.reshape(B, x.shape[1], x.shape[2], self.hidden_size)
        x = torch.fft.irfft2(x, s=(self.h, self.w), dim=(1, 2), norm='ortho')
        x = x.reshape(B, N, C)
        return x + bias


class FourierNetBlock(nn.Module):
    """Transformer-style block: AFNO token mixing + MLP, each with pre-norm,
    residual connection, and DropPath.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=16, w=16):
        super(FourierNetBlock, self).__init__()
        self.normlayer1 = norm_layer(dim)
        self.filter = AdativeFourierNeuralOperator(dim, h=h, w=w)
        self.drop_path = (DropPath(drop_path) if drop_path > 0.
                          else nn.Identity())
        self.normlayer2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
        # Kept for attribute compatibility; not read in forward.
        self.double_skip = True

    def forward(self, x):
        x = x + self.drop_path(self.filter(self.normlayer1(x)))
        x = x + self.drop_path(self.mlp(self.normlayer2(x)))
        return x