| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import torch.fft | |
| import numpy as np | |
| import torch.optim as optimizer | |
| from functools import partial | |
| from collections import OrderedDict | |
| from timm.models.layers import DropPath, to_2tuple, trunc_normal_ | |
| from torch.utils.checkpoint import checkpoint_sequential | |
| from torch import nn | |
class BasicConv2d(nn.Module):
    """3x3-style conv wrapper: Conv2d or ConvTranspose2d, optionally followed by GroupNorm + LeakyReLU.

    Args:
        in_channels / out_channels: channel counts of the wrapped convolution.
        kernel_size, stride, padding: forwarded to the convolution.
        transpose: use ConvTranspose2d (upsampling) instead of Conv2d.
        act_norm: when True, apply GroupNorm(2 groups) then LeakyReLU(0.2) to the output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if transpose:
            # output_padding=stride//2 makes the transposed conv exactly invert the
            # spatial downsampling of a same-config strided conv (for even strides).
            self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                           stride=stride, padding=padding, output_padding=stride // 2)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=padding)
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        # norm + activation only when requested; otherwise the raw conv output.
        return self.act(self.norm(out)) if self.act_norm else out
class ConvSC(nn.Module):
    """Fixed 3x3 / padding-1 BasicConv2d; a stride-1 'transposed' conv is demoted to a plain conv."""

    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        # With stride 1 there is nothing to upsample, so transposition is disabled.
        use_transpose = transpose and stride != 1
        self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride, padding=1,
                                transpose=use_transpose, act_norm=act_norm)

    def forward(self, x):
        return self.conv(x)
class GroupConv2d(nn.Module):
    """Grouped Conv2d with optional GroupNorm + LeakyReLU.

    If `in_channels` is not divisible by `groups`, the grouping silently falls
    back to 1 (an ordinary convolution) so construction never fails.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        if in_channels % groups != 0:
            groups = 1  # invalid grouping -> degrade to a plain conv
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        if not self.act_norm:
            return out
        return self.activate(self.norm(out))
class Inception(nn.Module):
    """Inception-style block: a 1x1 bottleneck conv followed by parallel grouped
    convolutions with different kernel sizes, whose outputs are summed.

    Args:
        C_in: input channels.
        C_hid: hidden channels after the 1x1 bottleneck.
        C_out: output channels of every parallel branch (and of the block).
        incep_ker: kernel sizes of the parallel branches. Changed from a list
            default to an equivalent tuple to avoid the mutable-default pitfall.
        groups: group count forwarded to each GroupConv2d branch.
    """

    def __init__(self, C_in, C_hid, C_out, incep_ker=(3, 5, 7, 11), groups=8):
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
        # padding=ker//2 keeps every branch's spatial size equal so the sum is valid.
        layers = [GroupConv2d(C_hid, C_out, kernel_size=ker, stride=1,
                              padding=ker // 2, groups=groups, act_norm=True)
                  for ker in incep_ker]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        y = 0
        for layer in self.layers:
            # out-of-place add: avoids in-place accumulation on an autograd tensor
            y = y + layer(x)
        return y
class Mlp(nn.Module):
    """Standard two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Bug fix: the original created `self.fc2` (a trained Linear projection) but
    never called it — `forward` instead ran a parameter-free
    `nn.AdaptiveAvgPool1d(out_features)` (`fc3`), so `hidden_features` was
    averaged away rather than projected and fc2's parameters were dead weight.
    `forward` now applies fc2 as in the conventional transformer MLP; the
    removed pooling layer held no parameters, so checkpoints are unaffected.

    Args:
        in_features: input feature dimension.
        hidden_features: hidden dimension (defaults to in_features).
        out_features: output dimension (defaults to in_features).
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability applied after the activation and after fc2.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super(Mlp, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class AdativeFourierNeuralOperator(nn.Module):
    """Adaptive Fourier Neural Operator (AFNO) token mixer.

    Tokens (B, N, C) are reshaped to an h x w grid (assumes N == h * w),
    mixed in the 2D Fourier domain by a two-layer block-diagonal complex MLP,
    transformed back, and combined with a learned 1x1-conv bias over tokens.
    (Class name keeps the original 'Adative' spelling for caller compatibility.)

    Bug fix: the second complex layer previously overwrote `x_real` and then
    used the *updated* value when computing the imaginary part, so the complex
    multiplication was wrong. Both outputs are now computed from the
    first-layer activations (o1), matching the AFNO reference formulation.
    """

    def __init__(self, dim, h=16, w=16, is_fno_bias=True):
        super(AdativeFourierNeuralOperator, self).__init__()
        self.hidden_size = dim
        self.h = h
        self.w = w
        self.num_blocks = 2
        if dim % self.num_blocks != 0:
            # raise (not assert) so the check survives python -O
            raise ValueError(f"dim ({dim}) must be divisible by num_blocks ({self.num_blocks})")
        self.block_size = self.hidden_size // self.num_blocks
        self.scale = 0.02
        # index 0 = real component, index 1 = imaginary component of each weight/bias
        self.w1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size))
        self.b1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
        self.w2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size))
        self.b2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
        self.relu = nn.ReLU()
        self.is_fno_bias = is_fno_bias
        # 1x1 conv over the channel dim acts as a learned per-token bias path
        self.bias = nn.Conv1d(self.hidden_size, self.hidden_size, 1) if self.is_fno_bias else None
        self.softshrink = 0.00  # 0 disables the sparsifying soft-shrinkage

    def multiply(self, input, weights):
        # Block-diagonal matmul: (..., block, d) x (block, d, k) -> (..., block, k)
        return torch.einsum('...bd,bdk->...bk', input, weights)

    def forward(self, x):
        B, N, C = x.shape
        if self.bias is not None:  # explicit None check instead of module truthiness
            bias = self.bias(x.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            bias = torch.zeros_like(x)
        x = x.reshape(B, self.h, self.w, C)  # assumes N == h * w — caller contract
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        x = x.reshape(B, x.shape[1], x.shape[2], self.num_blocks, self.block_size)
        # First complex layer with ReLU on both components.
        o1_real = F.relu(self.multiply(x.real, self.w1[0]) - self.multiply(x.imag, self.w1[1]) + self.b1[0],
                         inplace=True)
        o1_imag = F.relu(self.multiply(x.real, self.w1[1]) + self.multiply(x.imag, self.w1[0]) + self.b1[1],
                         inplace=True)
        # Second complex layer: both components computed from o1 (the fix).
        o2_real = self.multiply(o1_real, self.w2[0]) - self.multiply(o1_imag, self.w2[1]) + self.b2[0]
        o2_imag = self.multiply(o1_real, self.w2[1]) + self.multiply(o1_imag, self.w2[0]) + self.b2[1]
        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.softshrink) if self.softshrink else x
        x = torch.view_as_complex(x)
        x = x.reshape(B, x.shape[1], x.shape[2], self.hidden_size)
        x = torch.fft.irfft2(x, s=(self.h, self.w), dim=(1, 2), norm='ortho')
        x = x.reshape(B, N, C)
        return x + bias
class FourierNetBlock(nn.Module):
    """Transformer-style block: pre-norm AFNO token mixing, then a pre-norm MLP,
    each wrapped in a residual connection with optional stochastic depth.

    Args:
        dim: token feature dimension.
        mlp_ratio: hidden width of the MLP as a multiple of dim.
        drop: dropout rate inside the MLP.
        drop_path: stochastic-depth rate; 0 yields an Identity.
        act_layer / norm_layer: activation and normalization classes.
        h, w: spatial grid passed to the AFNO filter.
    """

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 h=16,
                 w=16):
        super(FourierNetBlock, self).__init__()
        self.normlayer1 = norm_layer(dim)
        self.filter = AdativeFourierNeuralOperator(dim, h=h, w=w)
        # Shared stochastic-depth module for both residual branches.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.normlayer2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)
        self.double_skip = True  # kept for compatibility; not read in forward

    def forward(self, x):
        mixed = self.drop_path(self.filter(self.normlayer1(x)))
        x = x + mixed
        fed_forward = self.drop_path(self.mlp(self.normlayer2(x)))
        return x + fed_forward