| from model import common | |
| # import common | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # import scipy.io as sio | |
| # from model.masknet import MaskBlock | |
| # from masknet import MaskBlock | |
| from model.MFAM import MFAM | |
| # from MFAM import MFAM | |
| # from model.newarch import NEWARCH | |
| # from newarch import NEWARCH | |
| # from MFAM import MFAM | |
| # from model.swinir import SwinIR | |
| # from swinir import SwinIR | |
| # from model.uformer import Uformer | |
| # from uformer import Uformer | |
| # from model.mwt import MWT | |
| # from mwt import MWT | |
| # from model.restormer import Restormer | |
| # from restormer import Restormer | |
| import os | |
def make_model(args, parent=False):
    """Build and return an ``EWT`` model.

    Args:
        args: Experiment options. Not used: ``EWT.__init__`` takes no
            options, and passing ``args`` positionally would silently bind it
            to the ``conv`` parameter.
        parent: Unused; kept for interface compatibility.

    Returns:
        A new ``EWT`` instance.
    """
    return EWT()
class EWT(nn.Module):
    """Wavelet-domain restoration model.

    ``forward`` pads the input to even height/width, applies ``common.DWT``,
    runs the ``MFAM`` backbone, applies ``common.IWT`` and crops the result
    back to the input's original height and width.
    """

    def __init__(self, conv=common.default_conv):
        """Build the submodules.

        Args:
            conv: Convolution factory. Accepted for interface compatibility;
                it is not used by this model.
        """
        super(EWT, self).__init__()
        print("EWT")
        # NOTE(review): scale_idx is never read in this class; presumably set
        # or read by an external training framework -- confirm before removing.
        self.scale_idx = 0
        self.DWT = common.DWT()
        self.IWT = common.IWT()
        # "gray-5" configuration. Earlier experiments used in_chans=12 with
        # embed_dim=96 ("gray-4") or embed_dim=48 ("gray-1"); other backbones
        # tried here were NEWARCH, SwinIR, Uformer, Restormer, MWT, MaskBlock.
        # NOTE(review): in_chans must equal the channel count that self.DWT
        # produces for the expected input -- verify against the data pipeline.
        self.trans = MFAM(upscale=1, img_size=(32, 32), in_chans=24,
                          window_size=8, img_range=1., depths=[2, 2, 4],
                          embed_dim=96, num_heads=[6, 6, 6], mlp_ratio=2,
                          upsampler='')

    @staticmethod
    def _pad_mode(pad, size):
        """Return the padding mode to use for ``pad`` extra pixels on a dim of ``size``.

        PyTorch's reflect padding requires the pad width to be strictly
        smaller than the padded dimension, so tiny inputs (e.g. a single row)
        fall back to 'replicate' instead of raising.
        """
        return 'reflect' if pad < size else 'replicate'

    def _padding(self, x, scale):
        """Pad the bottom/right of ``x`` so dims 2 and 3 are multiples of ``scale``.

        Args:
            x: 4-D tensor; dims 2 (height) and 3 (width) are padded.
            scale: Required divisor for both spatial dims.

        Returns:
            ``(x, delta_H, delta_W)``: the padded tensor and the number of
            rows / columns that were appended.
        """
        # (-n) % scale is the distance up to the next multiple (0 if aligned).
        delta_H = (-x.shape[2]) % scale
        delta_W = (-x.shape[3]) % scale
        if delta_H:
            # Pad tuple order is (left, right, top, bottom): pad the bottom.
            x = F.pad(x, (0, 0, 0, delta_H),
                      self._pad_mode(delta_H, x.shape[2]))
        if delta_W:
            # Pad the right edge.
            x = F.pad(x, (0, delta_W, 0, 0),
                      self._pad_mode(delta_W, x.shape[3]))
        return x, delta_H, delta_W

    def _padding_2(self, x):
        """Pad the shorter spatial side of ``x`` (bottom or right) to make it square.

        Not called by ``forward``.
        """
        _, _, H, W = x.shape
        delta = abs(H - W)
        if H < W:
            x = F.pad(x, (0, 0, 0, delta), self._pad_mode(delta, H))
        elif H > W:
            x = F.pad(x, (0, delta, 0, 0), self._pad_mode(delta, W))
        return x

    def forward(self, x):
        """Run the model on a 4-D batch; the output is cropped to the input's H x W."""
        _, _, H, W = x.shape
        # Make H and W even -- presumably required by the 2x2 transform in
        # common.DWT (confirm).
        x, _, _ = self._padding(x, 2)
        x = self.DWT(x)
        x = self.trans(x)
        x = self.IWT(x)
        # Drop the rows/columns appended by _padding.
        return x[:, :, :H, :W]
def get_parameter_number(model):
    """Count the parameter elements of ``model``.

    Returns:
        dict with ``'Total'`` (elements of all parameters) and ``'Trainable'``
        (elements of parameters with ``requires_grad=True``).
    """
    params = list(model.parameters())
    total_num = sum(p.numel() for p in params)
    trainable_num = sum(p.numel() for p in params if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}


if __name__ == "__main__":
    # Smoke script: build the model and report its parameter count.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Pick the device once so the script also runs without CUDA, and so the
    # input and the model always live on the same device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Dummy input batch of shape (batch, channels, height, width).
    x = torch.rand(2, 3, 64, 64).to(device)
    model = EWT().to(device)
    # Forward pass is disabled, as in the original script; to enable:
    # y = model(x)
    # print(y.shape)
    print(get_parameter_number(model))