dongjoaquin commited on
Commit
6d55688
·
verified ·
1 Parent(s): 5c85af8

Upload modules.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modules.py +225 -0
modules.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+
5
+ def convrelu(in_channels, out_channels, kernel, padding, pool):
6
+ return nn.Sequential(
7
+ nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
8
+ #In conv, the dimension of the output, if the input is H,W, is
9
+ # H+2*padding-kernel +1
10
+ nn.ReLU(inplace=True),
11
+ nn.MaxPool2d(pool, stride=pool, padding=0, dilation=1, return_indices=False, ceil_mode=False)
12
+ #pooling takes Height H and width W to (H-pool)/pool+1 = H/pool, and floor. Same for W.
13
+ #altogether, the output size is (H+2*padding-kernel +1)/pool.
14
+ )
15
+
16
+ def convreluT(in_channels, out_channels, kernel, padding):
17
+ return nn.Sequential(
18
+ nn.ConvTranspose2d(in_channels, out_channels, kernel, stride=2, padding=padding),
19
+ nn.ReLU(inplace=True)
20
+ #input is H X W, output is (H-1)*2 - 2*padding + kernel
21
+ )
22
+
23
+
24
+
25
+ class RadioWNet(nn.Module):
26
+
27
+ def __init__(self,inputs=2,phase="firstU"):
28
+ super().__init__()
29
+
30
+ self.inputs=inputs
31
+ self.phase=phase
32
+
33
+ if inputs<=3:
34
+ self.layer00 = convrelu(inputs, 6, 3, 1,1)
35
+ self.layer0 = convrelu(6, 40, 5, 2,2)
36
+ else:
37
+ self.layer00 = convrelu(inputs, 10, 3, 1,1)
38
+ self.layer0 = convrelu(10, 40, 5, 2,2)
39
+
40
+ self.layer1 = convrelu(40, 50, 5, 2,2)
41
+ self.layer10 = convrelu(50, 60, 5, 2,1)
42
+ self.layer2 = convrelu(60, 100, 5, 2,2)
43
+ self.layer20 = convrelu(100, 100, 3, 1,1)
44
+ self.layer3 = convrelu(100, 150, 5, 2,2)
45
+ self.layer4 =convrelu(150, 300, 5, 2,2)
46
+ self.layer5 =convrelu(300, 500, 5, 2,2)
47
+
48
+ self.conv_up5 =convreluT(500, 300, 4, 1)
49
+ self.conv_up4 = convreluT(300+300, 150, 4, 1)
50
+ self.conv_up3 = convreluT(150 + 150, 100, 4, 1)
51
+ self.conv_up20 = convrelu(100 + 100, 100, 3, 1, 1)
52
+ self.conv_up2 = convreluT(100 + 100, 60, 6, 2)
53
+ self.conv_up10 = convrelu(60 + 60, 50, 5, 2, 1)
54
+ self.conv_up1 = convreluT(50 + 50, 40, 6, 2)
55
+ self.conv_up0 = convreluT(40 + 40, 20, 6, 2)
56
+ if inputs<=3:
57
+ self.conv_up00 = convrelu(20+6+inputs, 20, 5, 2,1)
58
+
59
+ else:
60
+ self.conv_up00 = convrelu(20+10+inputs, 20, 5, 2,1)
61
+
62
+ self.conv_up000 = convrelu(20+inputs, 1, 5, 2,1)
63
+
64
+ self.Wlayer00 = convrelu(inputs+1, 20, 3, 1,1)
65
+ self.Wlayer0 = convrelu(20, 30, 5, 2,2)
66
+ self.Wlayer1 = convrelu(30, 40, 5, 2,2)
67
+ self.Wlayer10 = convrelu(40, 50, 5, 2,1)
68
+ self.Wlayer2 = convrelu(50, 60, 5, 2,2)
69
+ self.Wlayer20 = convrelu(60, 70, 3, 1,1)
70
+ self.Wlayer3 = convrelu(70, 90, 5, 2,2)
71
+ self.Wlayer4 =convrelu(90, 110, 5, 2,2)
72
+ self.Wlayer5 =convrelu(110, 150, 5, 2,2)
73
+
74
+ self.Wconv_up5 =convreluT(150, 110, 4, 1)
75
+ self.Wconv_up4 = convreluT(110+110, 90, 4, 1)
76
+ self.Wconv_up3 = convreluT(90 + 90, 70, 4, 1)
77
+ self.Wconv_up20 = convrelu(70 + 70, 60, 3, 1, 1)
78
+ self.Wconv_up2 = convreluT(60 + 60, 50, 6, 2)
79
+ self.Wconv_up10 = convrelu(50 + 50, 40, 5, 2, 1)
80
+ self.Wconv_up1 = convreluT(40 + 40, 30, 6, 2)
81
+ self.Wconv_up0 = convreluT(30 + 30, 20, 6, 2)
82
+ self.Wconv_up00 = convrelu(20+20+inputs+1, 20, 5, 2,1)
83
+ self.Wconv_up000 = convrelu(20+inputs+1, 1, 5, 2,1)
84
+
85
+ def forward(self, input):
86
+
87
+ input0=input[:,0:self.inputs,:,:]
88
+
89
+ if self.phase=="firstU":
90
+ layer00 = self.layer00(input0)
91
+ layer0 = self.layer0(layer00)
92
+ layer1 = self.layer1(layer0)
93
+ layer10 = self.layer10(layer1)
94
+ layer2 = self.layer2(layer10)
95
+ layer20 = self.layer20(layer2)
96
+ layer3 = self.layer3(layer20)
97
+ layer4 = self.layer4(layer3)
98
+ layer5 = self.layer5(layer4)
99
+
100
+ layer4u = self.conv_up5(layer5)
101
+ layer4u = torch.cat([layer4u, layer4], dim=1)
102
+ layer3u = self.conv_up4(layer4u)
103
+ layer3u = torch.cat([layer3u, layer3], dim=1)
104
+ layer20u = self.conv_up3(layer3u)
105
+ layer20u = torch.cat([layer20u, layer20], dim=1)
106
+ layer2u = self.conv_up20(layer20u)
107
+ layer2u = torch.cat([layer2u, layer2], dim=1)
108
+ layer10u = self.conv_up2(layer2u)
109
+ layer10u = torch.cat([layer10u, layer10], dim=1)
110
+ layer1u = self.conv_up10(layer10u)
111
+ layer1u = torch.cat([layer1u, layer1], dim=1)
112
+ layer0u = self.conv_up1(layer1u)
113
+ layer0u = torch.cat([layer0u, layer0], dim=1)
114
+ layer00u = self.conv_up0(layer0u)
115
+ layer00u = torch.cat([layer00u, layer00], dim=1)
116
+ layer00u = torch.cat([layer00u,input0], dim=1)
117
+ layer000u = self.conv_up00(layer00u)
118
+ layer000u = torch.cat([layer000u,input0], dim=1)
119
+ output1 = self.conv_up000(layer000u)
120
+
121
+ Winput=torch.cat([output1, input], dim=1).detach()
122
+
123
+ Wlayer00 = self.Wlayer00(Winput).detach()
124
+ Wlayer0 = self.Wlayer0(Wlayer00).detach()
125
+ Wlayer1 = self.Wlayer1(Wlayer0).detach()
126
+ Wlayer10 = self.Wlayer10(Wlayer1).detach()
127
+ Wlayer2 = self.Wlayer2(Wlayer10).detach()
128
+ Wlayer20 = self.Wlayer20(Wlayer2).detach()
129
+ Wlayer3 = self.Wlayer3(Wlayer20).detach()
130
+ Wlayer4 = self.Wlayer4(Wlayer3).detach()
131
+ Wlayer5 = self.Wlayer5(Wlayer4).detach()
132
+
133
+ Wlayer4u = self.Wconv_up5(Wlayer5).detach()
134
+ Wlayer4u = torch.cat([Wlayer4u, Wlayer4], dim=1).detach()
135
+ Wlayer3u = self.Wconv_up4(Wlayer4u).detach()
136
+ Wlayer3u = torch.cat([Wlayer3u, Wlayer3], dim=1).detach()
137
+ Wlayer20u = self.Wconv_up3(Wlayer3u).detach()
138
+ Wlayer20u = torch.cat([Wlayer20u, Wlayer20], dim=1).detach()
139
+ Wlayer2u = self.Wconv_up20(Wlayer20u).detach()
140
+ Wlayer2u = torch.cat([Wlayer2u, Wlayer2], dim=1).detach()
141
+ Wlayer10u = self.Wconv_up2(Wlayer2u).detach()
142
+ Wlayer10u = torch.cat([Wlayer10u, Wlayer10], dim=1).detach()
143
+ Wlayer1u = self.Wconv_up10(Wlayer10u).detach()
144
+ Wlayer1u = torch.cat([Wlayer1u, Wlayer1], dim=1).detach()
145
+ Wlayer0u = self.Wconv_up1(Wlayer1u).detach()
146
+ Wlayer0u = torch.cat([Wlayer0u, Wlayer0], dim=1).detach()
147
+ Wlayer00u = self.Wconv_up0(Wlayer0u).detach()
148
+ Wlayer00u = torch.cat([Wlayer00u, Wlayer00], dim=1).detach()
149
+ Wlayer00u = torch.cat([Wlayer00u,Winput], dim=1).detach()
150
+ Wlayer000u = self.Wconv_up00(Wlayer00u).detach()
151
+ Wlayer000u = torch.cat([Wlayer000u,Winput], dim=1).detach()
152
+ output2 = self.Wconv_up000(Wlayer000u).detach()
153
+
154
+ else:
155
+ layer00 = self.layer00(input0).detach()
156
+ layer0 = self.layer0(layer00).detach()
157
+ layer1 = self.layer1(layer0).detach()
158
+ layer10 = self.layer10(layer1).detach()
159
+ layer2 = self.layer2(layer10).detach()
160
+ layer20 = self.layer20(layer2).detach()
161
+ layer3 = self.layer3(layer20).detach()
162
+ layer4 = self.layer4(layer3).detach()
163
+ layer5 = self.layer5(layer4).detach()
164
+
165
+ layer4u = self.conv_up5(layer5).detach()
166
+ layer4u = torch.cat([layer4u, layer4], dim=1).detach()
167
+ layer3u = self.conv_up4(layer4u).detach()
168
+ layer3u = torch.cat([layer3u, layer3], dim=1).detach()
169
+ layer20u = self.conv_up3(layer3u).detach()
170
+ layer20u = torch.cat([layer20u, layer20], dim=1).detach()
171
+ layer2u = self.conv_up20(layer20u).detach()
172
+ layer2u = torch.cat([layer2u, layer2], dim=1).detach()
173
+ layer10u = self.conv_up2(layer2u).detach()
174
+ layer10u = torch.cat([layer10u, layer10], dim=1).detach()
175
+ layer1u = self.conv_up10(layer10u).detach()
176
+ layer1u = torch.cat([layer1u, layer1], dim=1).detach()
177
+ layer0u = self.conv_up1(layer1u).detach()
178
+ layer0u = torch.cat([layer0u, layer0], dim=1).detach()
179
+ layer00u = self.conv_up0(layer0u).detach()
180
+ layer00u = torch.cat([layer00u, layer00], dim=1).detach()
181
+ layer00u = torch.cat([layer00u,input0], dim=1).detach()
182
+ layer000u = self.conv_up00(layer00u).detach()
183
+ layer000u = torch.cat([layer000u,input0], dim=1).detach()
184
+ output1 = self.conv_up000(layer000u).detach()
185
+
186
+ Winput=torch.cat([output1, input], dim=1).detach()
187
+
188
+ Wlayer00 = self.Wlayer00(Winput)
189
+ Wlayer0 = self.Wlayer0(Wlayer00)
190
+ Wlayer1 = self.Wlayer1(Wlayer0)
191
+ Wlayer10 = self.Wlayer10(Wlayer1)
192
+ Wlayer2 = self.Wlayer2(Wlayer10)
193
+ Wlayer20 = self.Wlayer20(Wlayer2)
194
+ Wlayer3 = self.Wlayer3(Wlayer20)
195
+ Wlayer4 = self.Wlayer4(Wlayer3)
196
+ Wlayer5 = self.Wlayer5(Wlayer4)
197
+
198
+ Wlayer4u = self.Wconv_up5(Wlayer5)
199
+ Wlayer4u = torch.cat([Wlayer4u, Wlayer4], dim=1)
200
+ Wlayer3u = self.Wconv_up4(Wlayer4u)
201
+ Wlayer3u = torch.cat([Wlayer3u, Wlayer3], dim=1)
202
+ Wlayer20u = self.Wconv_up3(Wlayer3u)
203
+ Wlayer20u = torch.cat([Wlayer20u, Wlayer20], dim=1)
204
+ Wlayer2u = self.Wconv_up20(Wlayer20u)
205
+ Wlayer2u = torch.cat([Wlayer2u, Wlayer2], dim=1)
206
+ Wlayer10u = self.Wconv_up2(Wlayer2u)
207
+ Wlayer10u = torch.cat([Wlayer10u, Wlayer10], dim=1)
208
+ Wlayer1u = self.Wconv_up10(Wlayer10u)
209
+ Wlayer1u = torch.cat([Wlayer1u, Wlayer1], dim=1)
210
+ Wlayer0u = self.Wconv_up1(Wlayer1u)
211
+ Wlayer0u = torch.cat([Wlayer0u, Wlayer0], dim=1)
212
+ Wlayer00u = self.Wconv_up0(Wlayer0u)
213
+ Wlayer00u = torch.cat([Wlayer00u, Wlayer00], dim=1)
214
+ Wlayer00u = torch.cat([Wlayer00u,Winput], dim=1)
215
+ Wlayer000u = self.Wconv_up00(Wlayer00u)
216
+ Wlayer000u = torch.cat([Wlayer000u,Winput], dim=1)
217
+ output2 = self.Wconv_up000(Wlayer000u)
218
+
219
+ return [output1,output2]
220
+
221
+
222
+
223
+
224
+
225
+