File size: 4,109 Bytes
2147e2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from model import common
# import common
import torch
import torch.nn as nn
import torch.nn.functional as F
# import scipy.io as sio
# from model.masknet import MaskBlock
# from masknet import MaskBlock
from model.MFAM import MFAM
# from MFAM import MFAM
# from model.newarch import NEWARCH
# from newarch import NEWARCH
# from MFAM import MFAM
# from model.swinir import SwinIR
# from swinir import SwinIR
# from model.uformer import Uformer
# from uformer import Uformer
# from model.mwt import MWT
# from mwt import MWT
# from model.restormer import Restormer
# from restormer import Restormer
import os

def make_model(args, parent=False):
    """Factory hook returning a freshly constructed :class:`EWT` model.

    ``args`` and ``parent`` are kept in the signature for interface
    compatibility but are not used. ``EWT.__init__`` no longer accepts
    ``args``, so forwarding it positionally would silently bind it to the
    ``conv`` parameter instead.
    """
    return EWT()

class EWT(nn.Module):
    """Wavelet-domain network wrapping an MFAM backbone.

    ``forward`` pads the input so H and W are divisible by 2, applies
    ``common.DWT``, runs the ``MFAM`` backbone, applies ``common.IWT`` and
    crops the result back to the input's original H x W.
    """

    # def __init__(self, args, conv=common.default_conv):
    def __init__(self, conv=common.default_conv):
        # ``conv`` is accepted for compatibility with the older constructor
        # signature; it is not used by this model.
        super(EWT, self).__init__()
        print("EWT")
        self.scale_idx = 0

        self.DWT = common.DWT()
        self.IWT = common.IWT()

        # Alternative MFAM configurations from earlier experiments (all with
        # window_size=8, depths=[2, 2, 4], num_heads=[6, 6, 6], mlp_ratio=2):
        #   gray-4: in_chans=12, embed_dim=96
        #   gray-1: in_chans=12, embed_dim=48
        # Active configuration: gray-5.
        # NOTE(review): ``in_chans`` must match the channel count that
        # common.DWT produces for the expected input -- confirm.
        self.trans = MFAM(upscale=1, img_size=(32, 32), in_chans=24,
                          window_size=8, img_range=1., depths=[2, 2, 4],
                          embed_dim=96, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='')

    @staticmethod
    def _pad_mode(pad, size):
        """Return the ``F.pad`` mode for padding a dimension of ``size`` by ``pad``.

        Reflect padding requires ``pad < size`` and raises otherwise, so fall
        back to replicate padding for tiny or very elongated inputs.
        """
        return 'reflect' if pad < size else 'replicate'

    def _padding(self, x, scale):
        """Pad the bottom/right of ``x`` so that H and W are multiples of ``scale``.

        Args:
            x: 4-D tensor indexed as (N, C, H, W).
            scale: positive integer divisor required for both spatial dims.

        Returns:
            ``(x_padded, delta_H, delta_W)``, where the deltas are the number
            of rows / columns appended (0 when no padding was needed).
        """
        # (-n) % scale == scale - n % scale when n is not a multiple, else 0.
        delta_H = (-x.shape[2]) % scale
        delta_W = (-x.shape[3]) % scale
        if delta_H:
            x = F.pad(x, (0, 0, 0, delta_H), self._pad_mode(delta_H, x.shape[2]))
        if delta_W:
            x = F.pad(x, (0, delta_W, 0, 0), self._pad_mode(delta_W, x.shape[3]))
        return x, delta_H, delta_W

    def _padding_2(self, x):
        """Pad the shorter spatial side of ``x`` (bottom or right) to make it square."""
        _, _, H, W = x.shape
        if H < W:
            delta = W - H
            x = F.pad(x, (0, 0, 0, delta), self._pad_mode(delta, H))
        elif H > W:
            delta = H - W
            x = F.pad(x, (0, delta, 0, 0), self._pad_mode(delta, W))
        return x

    def forward(self, x):
        """Run the DWT -> MFAM -> IWT pipeline and crop to the input's H x W.

        NOTE(review): assumes common.IWT restores the spatial size halved by
        common.DWT, so the reconstruction is at least H x W before cropping.
        """
        _, _, H, W = x.shape
        x, _, _ = self._padding(x, 2)
        x = self.DWT(x)
        x = self.trans(x)
        x = self.IWT(x)
        # Drop the rows/columns appended by _padding.
        return x[:, :, :H, :W]

def get_parameter_number(model):
    """Count the parameters of ``model``.

    Returns:
        dict with ``'Total'`` (all parameter elements) and ``'Trainable'``
        (elements of parameters with ``requires_grad=True``).
    """
    params = list(model.parameters())
    total_num = sum(p.numel() for p in params)
    trainable_num = sum(p.numel() for p in params if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}


if __name__ == "__main__":
    # Only effective if CUDA has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Dummy image batch: (batch_size, channels, height, width).
    x = torch.rand(2, 3, 64, 64)
    if torch.cuda.is_available():
        x = x.cuda()
    model = EWT()
    # To smoke-test a forward pass, move the model to x's device first:
    #     y = model.to(x.device)(x); print(y.shape)

    print(get_parameter_number(model))