# Ref-Band / src / model / ewt.py
# Traver's picture
# Upload 180 files
# 2147e2e verified
from model import common
# import common
import torch
import torch.nn as nn
import torch.nn.functional as F
# import scipy.io as sio
# from model.masknet import MaskBlock
# from masknet import MaskBlock
from model.MFAM import MFAM
# from MFAM import MFAM
# from model.newarch import NEWARCH
# from newarch import NEWARCH
# from MFAM import MFAM
# from model.swinir import SwinIR
# from swinir import SwinIR
# from model.uformer import Uformer
# from uformer import Uformer
# from model.mwt import MWT
# from mwt import MWT
# from model.restormer import Restormer
# from restormer import Restormer
import os
def make_model(args, parent=False):
    """Factory hook: build and return a fresh ``EWT`` network.

    ``args`` is forwarded positionally to ``EWT``. Because the only
    constructor parameter of ``EWT`` is ``conv``, ``args`` ends up bound to
    ``conv``, which the constructor never reads. ``parent`` is accepted but
    unused here.
    """
    model = EWT(args)
    return model
class EWT(nn.Module):
    """Network that runs an ``MFAM`` backbone between ``common.DWT`` and ``common.IWT``.

    The forward pass pads the input to even spatial size, applies ``DWT``,
    the ``MFAM`` backbone and ``IWT`` in that order, and finally crops the
    result back to the input's original height and width.

    NOTE(review): ``DWT`` / ``IWT`` are presumably a discrete wavelet
    transform and its inverse, judging by their names -- confirm in
    ``model.common``.
    """

    def __init__(self, conv=common.default_conv):
        """Create the transform layers and the backbone.

        ``conv`` is accepted for interface compatibility but is not used
        by this constructor.
        """
        super().__init__()
        print("EWT")
        self.scale_idx = 0

        self.DWT = common.DWT()
        self.IWT = common.IWT()

        # Active backbone configuration, labelled "gray-5" in the original
        # source. The earlier variants recorded there were:
        #   "gray-4": in_chans=12, embed_dim=96
        #   "gray-1": in_chans=12, embed_dim=48
        # (all other arguments identical). NEWARCH, SwinIR, Uformer,
        # Restormer, MWT and MaskBlock backbones were also tried in this
        # slot and are currently disabled.
        self.trans = MFAM(
            upscale=1,
            img_size=(32, 32),
            in_chans=24,
            window_size=8,
            img_range=1.,
            depths=[2, 2, 4],
            embed_dim=96,
            num_heads=[6, 6, 6],
            mlp_ratio=2,
            upsampler='',
        )

    def _padding(self, x, scale):
        """Reflect-pad dims 2 and 3 of ``x`` up to the next multiple of ``scale``.

        Padding is appended at the end of each dimension only (bottom for
        dim 2, right for dim 3).

        Returns:
            A tuple ``(padded_x, delta_H, delta_W)`` where the deltas are the
            number of rows / columns that were added (0 when none).
        """
        # (-n) % scale is the distance from n up to the next multiple of
        # scale, and 0 when n is already a multiple.
        delta_H = (-x.shape[2]) % scale
        if delta_H:
            x = F.pad(x, (0, 0, 0, delta_H), 'reflect')

        delta_W = (-x.shape[3]) % scale
        if delta_W:
            x = F.pad(x, (0, delta_W, 0, 0), 'reflect')

        return x, delta_H, delta_W

    def _padding_2(self, x):
        """Reflect-pad the shorter of the last two dims so that ``x`` becomes square.

        Expects a 4-D tensor. Currently not called by ``forward``.
        """
        _, _, height, width = x.shape
        if height == width:
            return x

        gap = abs(height - width)
        # F.pad takes (left, right, top, bottom) for the last two dims.
        if height < width:
            pad_spec = (0, 0, 0, gap)
        else:
            pad_spec = (0, gap, 0, 0)
        return F.pad(x, pad_spec, 'reflect')

    def forward(self, x):
        """Run pad -> DWT -> backbone -> IWT -> crop on a 4-D input tensor.

        The output is cropped to the input's original last two dimensions.
        """
        _, _, orig_h, orig_w = x.shape

        # Make both spatial dims even before the transform; the returned pad
        # amounts are not needed because the crop uses the saved input size.
        padded, _, _ = self._padding(x, 2)

        coeffs = self.DWT(padded)
        features = self.trans(coeffs)
        restored = self.IWT(features)

        return restored[:, :, :orig_h, :orig_w]
if __name__ == "__main__":
    # Smoke test: instantiate the model and print its parameter counts.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    # 4-D dummy input; EWT.forward reads the last two dims as height/width.
    x = torch.rand(2, 3, 64, 64)

    # Move the tensor to the GPU only when one is present, so the script
    # still runs (and reports parameter counts) on CPU-only machines instead
    # of crashing on an unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = x.to(device)

    model = EWT()
    # The forward pass is intentionally left disabled; re-enable to check
    # output shapes.
    # model = model.to(device)
    # y = model(x)
    # print(y.shape)

    def get_parameter_number(model):
        """Return the total and trainable parameter counts of ``model``.

        Returns:
            dict with keys ``'Total'`` (all parameters) and ``'Trainable'``
            (parameters whose ``requires_grad`` is true).
        """
        total_num = 0
        trainable_num = 0
        # Single pass over the parameters instead of iterating twice.
        for p in model.parameters():
            count = p.numel()
            total_num += count
            if p.requires_grad:
                trainable_num += count
        return {'Total': total_num, 'Trainable': trainable_num}

    print(get_parameter_number(model))