import torch
from torch import nn
import torch.nn.functional as F
import torch.fft
import numpy as np
import torch.optim as optimizer
from functools import partial
from collections import OrderedDict
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torch.utils.checkpoint import checkpoint_sequential
from torch import nn
class BasicConv2d(nn.Module):
    """Conv2d / ConvTranspose2d wrapper with optional GroupNorm + LeakyReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, padding: forwarded to the convolution layer.
        transpose: if True use ConvTranspose2d (output_padding=stride//2 so
            a stride-2 layer exactly doubles the spatial size).
        act_norm: if True apply GroupNorm (2 groups) then LeakyReLU(0.2)
            after the convolution; otherwise return the raw conv output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if transpose:
            self.conv = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, output_padding=stride // 2)
        else:
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding)
        # Norm/activation are always constructed so the state dict is stable,
        # but they are only applied when act_norm is set.
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        if not self.act_norm:
            return out
        return self.act(self.norm(out))
class ConvSC(nn.Module):
    """Single spatial conv stage: a 3x3 BasicConv2d with padding 1.

    A stride of 1 neither down- nor up-samples, so the `transpose` flag is
    forced off in that case; otherwise it selects ConvTranspose2d
    (upsampling) vs Conv2d (downsampling) inside BasicConv2d.
    """

    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        # `transpose and stride != 1` reproduces the original rule:
        # transposed conv is only meaningful when the stride resamples.
        self.conv = BasicConv2d(
            C_in, C_out, kernel_size=3, stride=stride, padding=1,
            transpose=(transpose and stride != 1), act_norm=act_norm)

    def forward(self, x):
        return self.conv(x)
class GroupConv2d(nn.Module):
    """Grouped Conv2d with optional GroupNorm + LeakyReLU.

    Falls back to groups=1 when `in_channels` is not divisible by `groups`
    (nn.Conv2d would otherwise raise at construction time).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        groups = groups if in_channels % groups == 0 else 1
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        return self.activate(self.norm(out)) if self.act_norm else out
class Inception(nn.Module):
    """Inception-style multi-branch block.

    A 1x1 bottleneck conv maps C_in -> C_hid, then several parallel
    GroupConv2d branches (one per kernel size in `incep_ker`, each with
    "same" padding at stride 1) map C_hid -> C_out; the branch outputs
    are summed elementwise.
    """

    def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
        # nn.Sequential is used purely as a module container; branches are
        # applied in parallel in forward(), not chained.
        self.layers = nn.Sequential(*[
            GroupConv2d(C_hid, C_out, kernel_size=ker, stride=1,
                        padding=ker // 2, groups=groups, act_norm=True)
            for ker in incep_ker
        ])

    def forward(self, x):
        hidden = self.conv1(x)
        return sum(branch(hidden) for branch in self.layers)
class Mlp(nn.Module):
    """Token-wise MLP: Linear -> activation -> dropout -> adaptive avg-pool -> dropout.

    NOTE(review): `fc2` is registered but never used in forward(); the
    hidden -> out projection is performed by the parameter-free
    AdaptiveAvgPool1d `fc3` instead. `fc2` is kept here untouched so the
    state dict (and seeded parameter initialisation order) stays
    compatible with existing checkpoints — confirm whether it should
    replace `fc3` before removing it.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super(Mlp, self).__init__()
        out_features = out_features if out_features else in_features
        hidden_features = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)  # unused in forward()
        self.fc3 = nn.AdaptiveAvgPool1d(out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc3(hidden))
class AdativeFourierNeuralOperator(nn.Module):
    """Adaptive Fourier Neural Operator (AFNO) token mixer.

    Tokens (B, N, C) are reshaped onto an h x w grid, transformed with a
    2-D real FFT, mixed per frequency by a two-layer block-diagonal
    complex MLP, and transformed back. A 1x1 Conv1d over the token
    dimension provides an optional learned bias added to the output.

    NOTE(review): class name keeps the original "Adative" spelling
    (sic) because other code in this file references it.

    Args:
        dim: token embedding size; must be divisible by num_blocks (2).
        h, w: spatial grid; callers must pass x with N == h * w.
        is_fno_bias: if True, add the Conv1d bias term to the output.
    """

    def __init__(self, dim, h=16, w=16, is_fno_bias=True):
        super(AdativeFourierNeuralOperator, self).__init__()
        self.hidden_size = dim
        self.h = h
        self.w = w
        self.num_blocks = 2
        self.block_size = self.hidden_size // self.num_blocks
        assert self.hidden_size % self.num_blocks == 0
        self.scale = 0.02
        # Complex weights stored as real tensors: index 0 = real part,
        # index 1 = imaginary part; block-diagonal over num_blocks.
        self.w1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size))
        self.b1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
        self.w2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size))
        self.b2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
        self.relu = nn.ReLU()
        self.is_fno_bias = is_fno_bias
        if self.is_fno_bias:
            self.bias = nn.Conv1d(self.hidden_size, self.hidden_size, 1)
        else:
            self.bias = None
        # Soft-shrink threshold; 0.0 disables sparsification entirely.
        self.softshrink = 0.00

    def multiply(self, input, weights):
        """Per-block matmul over the channel dim: (..., b, d) x (b, d, k) -> (..., b, k)."""
        return torch.einsum('...bd, bdk->...bk', input, weights)

    def forward(self, x):
        B, N, C = x.shape
        if self.bias is not None:
            # Conv1d expects (B, C, N); permute in and out.
            bias = self.bias(x.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            bias = torch.zeros(x.shape, device=x.device)
        x = x.reshape(B, self.h, self.w, C)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        x = x.reshape(B, x.shape[1], x.shape[2], self.num_blocks, self.block_size)
        # First complex layer: (real + i*imag) @ (w1[0] + i*w1[1]) + bias, ReLU'd.
        x_real = F.relu(self.multiply(x.real, self.w1[0]) - self.multiply(x.imag, self.w1[1]) + self.b1[0],
                        inplace=True)
        x_imag = F.relu(self.multiply(x.real, self.w1[1]) + self.multiply(x.imag, self.w1[0]) + self.b1[1],
                        inplace=True)
        # BUGFIX: the second complex multiply must use the *pre-update* real
        # part for both components. The original code overwrote x_real before
        # computing the imaginary output, corrupting the complex product.
        out_real = self.multiply(x_real, self.w2[0]) - self.multiply(x_imag, self.w2[1]) + self.b2[0]
        out_imag = self.multiply(x_real, self.w2[1]) + self.multiply(x_imag, self.w2[0]) + self.b2[1]
        x = torch.stack([out_real, out_imag], dim=-1)
        x = F.softshrink(x, lambd=self.softshrink) if self.softshrink else x
        x = torch.view_as_complex(x)
        x = x.reshape(B, x.shape[1], x.shape[2], self.hidden_size)
        x = torch.fft.irfft2(x, s=(self.h, self.w), dim=(1, 2), norm='ortho')
        x = x.reshape(B, N, C)
        return x + bias
class FourierNetBlock(nn.Module):
    """Transformer-style block: AFNO token mixing followed by an MLP,
    each with pre-normalisation, a residual connection and stochastic
    depth (DropPath).

    Args:
        dim: embedding dimension.
        mlp_ratio: MLP hidden size = int(dim * mlp_ratio).
        drop: dropout rate inside the MLP.
        drop_path: stochastic-depth rate; 0 means identity.
        act_layer, norm_layer: activation / normalisation constructors.
        h, w: spatial grid passed to the Fourier filter (tokens N == h*w).
    """

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 h=16,
                 w=16):
        super(FourierNetBlock, self).__init__()
        self.normlayer1 = norm_layer(dim)
        self.filter = AdativeFourierNeuralOperator(dim, h=h, w=w)
        # Identity when no stochastic depth is requested.
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.normlayer2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)
        self.double_skip = True

    def forward(self, x):
        residual = x
        x = residual + self.drop_path(self.filter(self.normlayer1(x)))
        return x + self.drop_path(self.mlp(self.normlayer2(x)))