potato commited on
Commit
9dfc87e
Β·
1 Parent(s): 93461e3

Add model.py

Browse files
Files changed (1) hide show
  1. model.py +279 -0
model.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import functools
5
+ from torchvision import models
6
+ from torch.autograd import Variable
7
+ import numpy as np
8
+ import math
9
+
10
+ norm_layer = nn.InstanceNorm2d
11
+
12
+ class ResidualBlock(nn.Module):
13
+ def __init__(self, in_features):
14
+ super(ResidualBlock, self).__init__()
15
+
16
+ conv_block = [ nn.ReflectionPad2d(1),
17
+ nn.Conv2d(in_features, in_features, 3),
18
+ norm_layer(in_features),
19
+ nn.ReLU(inplace=True),
20
+ nn.ReflectionPad2d(1),
21
+ nn.Conv2d(in_features, in_features, 3),
22
+ norm_layer(in_features)
23
+ ]
24
+
25
+ self.conv_block = nn.Sequential(*conv_block)
26
+
27
+ def forward(self, x):
28
+ return x + self.conv_block(x)
29
+
30
+
31
+ class Generator(nn.Module):
32
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
33
+ super(Generator, self).__init__()
34
+
35
+ # Initial convolution block
36
+ model0 = [ nn.ReflectionPad2d(3),
37
+ nn.Conv2d(input_nc, 64, 7),
38
+ norm_layer(64),
39
+ nn.ReLU(inplace=True) ]
40
+ self.model0 = nn.Sequential(*model0)
41
+
42
+ # Downsampling
43
+ model1 = []
44
+ in_features = 64
45
+ out_features = in_features*2
46
+ for _ in range(2):
47
+ model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
48
+ norm_layer(out_features),
49
+ nn.ReLU(inplace=True) ]
50
+ in_features = out_features
51
+ out_features = in_features*2
52
+ self.model1 = nn.Sequential(*model1)
53
+
54
+ model2 = []
55
+ # Residual blocks
56
+ for _ in range(n_residual_blocks):
57
+ model2 += [ResidualBlock(in_features)]
58
+ self.model2 = nn.Sequential(*model2)
59
+
60
+ # Upsampling
61
+ model3 = []
62
+ out_features = in_features//2
63
+ for _ in range(2):
64
+ model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
65
+ norm_layer(out_features),
66
+ nn.ReLU(inplace=True) ]
67
+ in_features = out_features
68
+ out_features = in_features//2
69
+ self.model3 = nn.Sequential(*model3)
70
+
71
+ # Output layer
72
+ model4 = [ nn.ReflectionPad2d(3),
73
+ nn.Conv2d(64, output_nc, 7)]
74
+ if sigmoid:
75
+ model4 += [nn.Sigmoid()]
76
+
77
+ self.model4 = nn.Sequential(*model4)
78
+
79
+ def forward(self, x, cond=None):
80
+ out = self.model0(x)
81
+ out = self.model1(out)
82
+ out = self.model2(out)
83
+ out = self.model3(out)
84
+ out = self.model4(out)
85
+
86
+ return out
87
+
88
+ # Define a resnet block
89
+ class ResnetBlock(nn.Module):
90
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
91
+ super(ResnetBlock, self).__init__()
92
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
93
+
94
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
95
+ conv_block = []
96
+ p = 0
97
+ if padding_type == 'reflect':
98
+ conv_block += [nn.ReflectionPad2d(1)]
99
+ elif padding_type == 'replicate':
100
+ conv_block += [nn.ReplicationPad2d(1)]
101
+ elif padding_type == 'zero':
102
+ p = 1
103
+ else:
104
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
105
+
106
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
107
+ norm_layer(dim),
108
+ activation]
109
+ if use_dropout:
110
+ conv_block += [nn.Dropout(0.5)]
111
+
112
+ p = 0
113
+ if padding_type == 'reflect':
114
+ conv_block += [nn.ReflectionPad2d(1)]
115
+ elif padding_type == 'replicate':
116
+ conv_block += [nn.ReplicationPad2d(1)]
117
+ elif padding_type == 'zero':
118
+ p = 1
119
+ else:
120
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
121
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
122
+ norm_layer(dim)]
123
+
124
+ return nn.Sequential(*conv_block)
125
+
126
+ def forward(self, x):
127
+ out = x + self.conv_block(x)
128
+ return out
129
+
130
+ class GlobalGenerator2(nn.Module):
131
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
132
+ padding_type='reflect', use_sig=False, n_UPsampling=0):
133
+ assert(n_blocks >= 0)
134
+ super(GlobalGenerator2, self).__init__()
135
+ activation = nn.ReLU(True)
136
+
137
+ mult = 8
138
+ model = [nn.ReflectionPad2d(4), nn.Conv2d(input_nc, ngf*mult, kernel_size=7, padding=0), norm_layer(ngf*mult), activation]
139
+
140
+ ### downsample
141
+ for i in range(n_downsampling):
142
+ model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=4, stride=2, padding=1),
143
+ norm_layer(ngf * mult // 2), activation]
144
+ mult = mult // 2
145
+
146
+ if n_UPsampling <= 0:
147
+ n_UPsampling = n_downsampling
148
+
149
+ ### resnet blocks
150
+ for i in range(n_blocks):
151
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
152
+
153
+ ### upsample
154
+ for i in range(n_UPsampling):
155
+ next_mult = mult // 2
156
+ if next_mult == 0:
157
+ next_mult = 1
158
+ mult = 1
159
+
160
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * next_mult), kernel_size=3, stride=2, padding=1, output_padding=1),
161
+ norm_layer(int(ngf * next_mult)), activation]
162
+ mult = next_mult
163
+
164
+ if use_sig:
165
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Sigmoid()]
166
+ else:
167
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
168
+ self.model = nn.Sequential(*model)
169
+
170
+ def forward(self, input, cond=None):
171
+ return self.model(input)
172
+
173
+
174
+ class InceptionV3(nn.Module): #avg pool
175
+ def __init__(self, num_classes, isTrain, use_aux=True, pretrain=False, freeze=True, every_feat=False):
176
+ super(InceptionV3, self).__init__()
177
+ """ Inception v3 expects (299,299) sized images for training and has auxiliary output
178
+ """
179
+
180
+ self.every_feat = every_feat
181
+
182
+ self.model_ft = models.inception_v3(pretrained=pretrain)
183
+ stop = 0
184
+ if freeze and pretrain:
185
+ for child in self.model_ft.children():
186
+ if stop < 17:
187
+ for param in child.parameters():
188
+ param.requires_grad = False
189
+ stop += 1
190
+
191
+ num_ftrs = self.model_ft.AuxLogits.fc.in_features #768
192
+ self.model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
193
+
194
+ # Handle the primary net
195
+ num_ftrs = self.model_ft.fc.in_features #2048
196
+ self.model_ft.fc = nn.Linear(num_ftrs,num_classes)
197
+
198
+ self.model_ft.input_size = 299
199
+
200
+ self.isTrain = isTrain
201
+ self.use_aux = use_aux
202
+
203
+ if self.isTrain:
204
+ self.model_ft.train()
205
+ else:
206
+ self.model_ft.eval()
207
+
208
+
209
+ def forward(self, x, cond=None, catch_gates=False):
210
+ # N x 3 x 299 x 299
211
+ x = self.model_ft.Conv2d_1a_3x3(x)
212
+
213
+ # N x 32 x 149 x 149
214
+ x = self.model_ft.Conv2d_2a_3x3(x)
215
+ # N x 32 x 147 x 147
216
+ x = self.model_ft.Conv2d_2b_3x3(x)
217
+ # N x 64 x 147 x 147
218
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
219
+ # N x 64 x 73 x 73
220
+ x = self.model_ft.Conv2d_3b_1x1(x)
221
+ # N x 80 x 73 x 73
222
+ x = self.model_ft.Conv2d_4a_3x3(x)
223
+
224
+ # N x 192 x 71 x 71
225
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
226
+ # N x 192 x 35 x 35
227
+ x = self.model_ft.Mixed_5b(x)
228
+ feat1 = x
229
+ # N x 256 x 35 x 35
230
+ x = self.model_ft.Mixed_5c(x)
231
+ feat11 = x
232
+ # N x 288 x 35 x 35
233
+ x = self.model_ft.Mixed_5d(x)
234
+ feat12 = x
235
+ # N x 288 x 35 x 35
236
+ x = self.model_ft.Mixed_6a(x)
237
+ feat2 = x
238
+ # N x 768 x 17 x 17
239
+ x = self.model_ft.Mixed_6b(x)
240
+ feat21 = x
241
+ # N x 768 x 17 x 17
242
+ x = self.model_ft.Mixed_6c(x)
243
+ feat22 = x
244
+ # N x 768 x 17 x 17
245
+ x = self.model_ft.Mixed_6d(x)
246
+ feat23 = x
247
+ # N x 768 x 17 x 17
248
+ x = self.model_ft.Mixed_6e(x)
249
+
250
+ feat3 = x
251
+
252
+ # N x 768 x 17 x 17
253
+ aux_defined = self.isTrain and self.use_aux
254
+ if aux_defined:
255
+ aux = self.model_ft.AuxLogits(x)
256
+ else:
257
+ aux = None
258
+ # N x 768 x 17 x 17
259
+ x = self.model_ft.Mixed_7a(x)
260
+ # N x 1280 x 8 x 8
261
+ x = self.model_ft.Mixed_7b(x)
262
+ # N x 2048 x 8 x 8
263
+ x = self.model_ft.Mixed_7c(x)
264
+ # N x 2048 x 8 x 8
265
+ # Adaptive average pooling
266
+ x = F.adaptive_avg_pool2d(x, (1, 1))
267
+ # N x 2048 x 1 x 1
268
+ feats = F.dropout(x, training=self.isTrain)
269
+ # N x 2048 x 1 x 1
270
+ x = torch.flatten(feats, 1)
271
+ # N x 2048
272
+ x = self.model_ft.fc(x)
273
+ # N x 1000 (num_classes)
274
+
275
+ if self.every_feat:
276
+ # return feat21, feats, x
277
+ return x, feat21
278
+
279
+ return x, aux