# Original upload metadata (non-code text, commented out so the module parses):
#   easylearning's picture — Upload 205 files — fa26901 verified
import torch
from torch import nn
import torch.nn.functional as F
import torch.fft
import numpy as np
import torch.optim as optimizer
from functools import partial
from collections import OrderedDict
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torch.utils.checkpoint import checkpoint_sequential
from torch import nn
class BasicConv2d(nn.Module):
    """Conv2d / ConvTranspose2d with an optional GroupNorm + LeakyReLU head.

    When ``act_norm`` is False the module is a bare (transposed) convolution;
    otherwise the output is normalized (GroupNorm with 2 groups) and passed
    through LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if transpose:
            # output_padding=stride//2 restores the spatial size lost by striding.
            self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                           stride=stride, padding=padding, output_padding=stride // 2)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=padding)
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        if not self.act_norm:
            return out
        return self.act(self.norm(out))
class ConvSC(nn.Module):
    """Single 3x3 BasicConv2d step used by the spatial encoder/decoder."""

    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        # A stride-1 transposed conv would be pointless, so force a plain conv.
        use_transpose = transpose and stride != 1
        self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride,
                                padding=1, transpose=use_transpose, act_norm=act_norm)

    def forward(self, x):
        return self.conv(x)
class GroupConv2d(nn.Module):
    """Grouped 2D convolution with an optional GroupNorm + LeakyReLU head."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        # Fall back to an ungrouped conv when the channels don't divide evenly.
        effective_groups = groups if in_channels % groups == 0 else 1
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=effective_groups)
        self.norm = nn.GroupNorm(effective_groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        if self.act_norm:
            out = self.activate(self.norm(out))
        return out
class Inception(nn.Module):
    """Inception-style mixer: a 1x1 bottleneck followed by parallel grouped
    convolutions with different kernel sizes, summed into one output."""

    def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
        # padding=k//2 keeps every branch's spatial size equal so they can be summed.
        branches = [
            GroupConv2d(C_hid, C_out, kernel_size=k, stride=1, padding=k // 2,
                        groups=groups, act_norm=True)
            for k in incep_ker
        ]
        self.layers = nn.Sequential(*branches)

    def forward(self, x):
        hidden = self.conv1(x)
        return sum(branch(hidden) for branch in self.layers)
class Mlp(nn.Module):
    """Token MLP: Linear -> activation -> Dropout -> AdaptiveAvgPool1d -> Dropout.

    NOTE(review): ``self.fc2`` is constructed but never used by ``forward``;
    the hidden features are reduced to ``out_features`` by adaptive average
    pooling (``fc3``) instead. It is kept here so parameter/state-dict layout
    stays identical — confirm whether bypassing fc2 is intentional.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super(Mlp, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)  # unused in forward (see class note)
        self.fc3 = nn.AdaptiveAvgPool1d(out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc3(hidden))
class AdativeFourierNeuralOperator(nn.Module):
    """Adaptive Fourier Neural Operator (AFNO) token mixer.

    Applies a block-diagonal two-layer complex MLP to the 2D FFT of the token
    grid and transforms back. The token sequence of length N is reshaped to an
    (h, w) grid, so callers must ensure N == h * w.
    """

    def __init__(self, dim, h=16, w=16, is_fno_bias=True):
        super(AdativeFourierNeuralOperator, self).__init__()
        self.hidden_size = dim
        self.h = h
        self.w = w
        self.num_blocks = 2
        self.block_size = self.hidden_size // self.num_blocks
        assert self.hidden_size % self.num_blocks == 0
        self.scale = 0.02
        # Complex weights stored as [real, imag] stacked along dim 0.
        self.w1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size))
        self.b1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
        self.w2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size))
        self.b2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
        self.relu = nn.ReLU()
        self.is_fno_bias = is_fno_bias
        if self.is_fno_bias:
            # Spatial-domain residual path: a per-token channel mixing (1x1 Conv1d).
            self.bias = nn.Conv1d(self.hidden_size, self.hidden_size, 1)
        else:
            self.bias = None
        # 0.0 disables the soft-shrinkage sparsification in forward().
        self.softshrink = 0.00

    def multiply(self, input, weights):
        # Block-diagonal matmul over the last (blocks, block_size) dims.
        return torch.einsum('...bd, bdk->...bk', input, weights)

    def forward(self, x):
        """x: (B, N, C) with N == self.h * self.w; returns the same shape."""
        B, N, C = x.shape
        if self.bias is not None:
            bias = self.bias(x.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            bias = torch.zeros(x.shape, device=x.device)
        x = x.reshape(B, self.h, self.w, C)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        x = x.reshape(B, x.shape[1], x.shape[2], self.num_blocks, self.block_size)
        # First complex layer: (a+bi)(W_r+W_i i) + b, then ReLU on each part.
        x_real = F.relu(self.multiply(x.real, self.w1[0]) - self.multiply(x.imag, self.w1[1]) + self.b1[0],
                        inplace=True)
        x_imag = F.relu(self.multiply(x.real, self.w1[1]) + self.multiply(x.imag, self.w1[0]) + self.b1[1],
                        inplace=True)
        # BUG FIX: the original reassigned x_real first and then computed the
        # imaginary part from the *updated* x_real, corrupting the complex
        # multiply. Compute both outputs from the same (x_real, x_imag) pair.
        out_real = self.multiply(x_real, self.w2[0]) - self.multiply(x_imag, self.w2[1]) + self.b2[0]
        out_imag = self.multiply(x_real, self.w2[1]) + self.multiply(x_imag, self.w2[0]) + self.b2[1]
        x = torch.stack([out_real, out_imag], dim=-1)
        x = F.softshrink(x, lambd=self.softshrink) if self.softshrink else x
        x = torch.view_as_complex(x)
        x = x.reshape(B, x.shape[1], x.shape[2], self.hidden_size)
        x = torch.fft.irfft2(x, s=(self.h, self.w), dim=(1, 2), norm='ortho')
        x = x.reshape(B, N, C)
        return x + bias
class FourierNetBlock(nn.Module):
    """Transformer-style block: AFNO token mixing then an MLP, each wrapped in
    pre-LayerNorm, stochastic depth (DropPath) and a residual connection."""

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 h=16,
                 w=16):
        super(FourierNetBlock, self).__init__()
        self.normlayer1 = norm_layer(dim)
        self.filter = AdativeFourierNeuralOperator(dim, h=h, w=w)
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.normlayer2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)
        # Flag only — forward unconditionally applies both residual branches.
        self.double_skip = True

    def forward(self, x):
        residual = x
        x = residual + self.drop_path(self.filter(self.normlayer1(x)))
        residual = x
        x = residual + self.drop_path(self.mlp(self.normlayer2(x)))
        return x