Work committed on
Commit
b8f5ef0
·
1 Parent(s): c4b1a94

refactor app to use (newly added) autoencoder class

Browse files
Files changed (2) hide show
  1. app.py +4 -3
  2. autoencoder.py +252 -0
app.py CHANGED
@@ -3,11 +3,12 @@ import gradio as gr
3
  import torch
4
  from torchvision.transforms import Resize, ToTensor
5
 
6
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
7
 
8
- model = torch.nn.Module()
9
 
10
- model.load_state_dict(torch.load('model.pt', map_location=device))
 
11
 
12
  resize = Resize((224))
13
  to_tensor = ToTensor()
 
3
  import torch
4
  from torchvision.transforms import Resize, ToTensor
5
 
6
+ from autoencoder import Autoencoder
7
 
8
# Pick the GPU when one is available; checkpoint tensors are remapped here.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = Autoencoder()
# BUG FIX: load_state_dict() takes a state dict, not a file path, and has no
# map_location parameter -- the checkpoint must be deserialized with
# torch.load() first.
model.load_state_dict(torch.load('model.pt', map_location=device))
model.to(device)
# eval() is required for inference: the autoencoder contains BatchNorm layers,
# and BatchNorm1d in train mode fails on single-sample batches.
# NOTE(review): only a chunk of app.py is visible -- confirm eval() is not
# already called later in the file (calling it twice is harmless).
model.eval()

resize = Resize(224)  # scales the shorter image side to 224 px
to_tensor = ToTensor()
autoencoder.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from
2
+ # https://github.com/arnaghosh/Auto-Encoder/blob/master/resnet.py
3
+
4
+ import torch
5
+ from torch.autograd import Variable
6
+ import torchvision
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchvision import datasets, models,transforms
10
+ import torch.optim as optim
11
+ from torch.optim import lr_scheduler
12
+ import numpy as np
13
+ import os
14
+ import matplotlib.pyplot as plt
15
+ from torch.autograd import Function
16
+ from collections import OrderedDict
17
+ import torch.nn as nn
18
+ import math
19
+
20
+ import torchvision.models as models
21
+
22
# Dimensionality of the latent code: the encoder's replaced fc head projects
# 2048 -> 48 (see encoder setup below) and Decoder.dfc3 consumes zsize inputs.
zsize=48
23
+
24
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 (ResNet building block)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
28
+
29
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (ResNet 'basic' variant)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Main path: conv-bn-relu-conv-bn; the stride acts only on the first conv.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional projection that reshapes the shortcut to match the main path.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch: identity, or the projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        y += shortcut
        return self.relu(y)
59
+
60
+
61
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 squeeze, 3x3 spatial, 1x1 expand (x4)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Squeeze to `planes` channels, convolve spatially, expand to planes * 4.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection that reshapes the shortcut to match the main path.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch: identity, or the projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)
99
+
100
+
101
class Encoder(nn.Module):
    """ResNet-style feature extractor ending in a 1000-way linear head.

    `block` is a residual block class (BasicBlock or Bottleneck) and
    `layers` gives the block count for each of the four stages.
    Attribute names mirror torchvision's ResNet, so a pretrained
    torchvision state dict loads directly.
    """

    def __init__(self, block, layers, num_classes=23):
        self.inplanes = 64
        super(Encoder, self).__init__()
        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; resolution halves from stage 2 onward.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        # NOTE(review): the head is always 1000-way (ImageNet) -- num_classes
        # is accepted but unused; callers replace `fc` after loading weights.
        self.fc = nn.Linear(512 * block.expansion, 1000)
        # He-style normal init for convs, constant init for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; the first may stride/project."""
        # Project the shortcut whenever the stride or channel count changes.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
161
+
162
# Device for the module-level encoder singleton below.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build a ResNet-50-shaped encoder and warm-start it from torchvision's
# pretrained ImageNet weights; the attribute names match torchvision's
# resnet50, so the downloaded state dict loads with strict matching.
# NOTE(review): this downloads weights over the network at import time --
# consider moving this setup into a factory function.
encoder = Encoder(Bottleneck, [3, 4, 6, 3])
encoder_state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')
encoder.load_state_dict(encoder_state_dict)
# Replace the 1000-way classifier head with a 2048 -> 48 projection so the
# encoder emits a zsize-dim latent code for the decoder.
encoder.fc = nn.Linear(2048, 48)
encoder=encoder.to(device)
169
+
170
+
171
class Binary(Function):
    """Straight-through binarization.

    Forward maps each element to 1.0 where the input is > 0 and to 0.0
    otherwise (relu of sign). Backward passes the incoming gradient
    through unchanged (straight-through estimator), so the step
    nonlinearity stays trainable.
    """

    @staticmethod
    def forward(ctx, input):
        # Modernization: Variable(...) and .data are deprecated since
        # torch 0.4 and redundant inside a static Function.forward --
        # relu(sign(x)) alone yields the identical 0/1 tensor.
        return F.relu(input.sign())

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat the op as identity for gradients.
        return grad_output
180
+
181
+
182
class Decoder(nn.Module):
    """Expand a `zsize`-dim latent vector back into a 3-channel image.

    An AlexNet-style mirror: three fully connected layers grow the code
    to a 256x6x6 feature map, then transposed convolutions interleaved
    with 2x upsampling rebuild the image; a final sigmoid keeps outputs
    in [0, 1].
    """

    def __init__(self):
        super(Decoder, self).__init__()
        # Fully connected expansion: zsize -> 4096 -> 4096 -> 256*6*6.
        self.dfc3 = nn.Linear(zsize, 4096)
        self.bn3 = nn.BatchNorm1d(4096)
        self.dfc2 = nn.Linear(4096, 4096)
        self.bn2 = nn.BatchNorm1d(4096)
        self.dfc1 = nn.Linear(4096, 256 * 6 * 6)
        self.bn1 = nn.BatchNorm1d(256 * 6 * 6)
        # One shared 2x upsampler, reused between the transposed convs.
        self.upsample1 = nn.Upsample(scale_factor=2)
        # Transposed convolutions progressively shrink channel depth.
        self.dconv5 = nn.ConvTranspose2d(256, 256, 3, padding=0)
        self.dconv4 = nn.ConvTranspose2d(256, 384, 3, padding=1)
        self.dconv3 = nn.ConvTranspose2d(384, 192, 3, padding=1)
        self.dconv2 = nn.ConvTranspose2d(192, 64, 5, padding=2)
        self.dconv1 = nn.ConvTranspose2d(64, 3, 12, stride=4, padding=4)

    def forward(self, x):
        # FC stack: each linear layer is followed by batch norm + ReLU.
        h = F.relu(self.bn3(self.dfc3(x)))
        h = F.relu(self.bn2(self.dfc2(h)))
        h = F.relu(self.bn1(self.dfc1(h)))

        # Fold the flat features into a 256x6x6 spatial map.
        h = h.view(h.shape[0], 256, 6, 6)

        h = self.upsample1(h)
        h = F.relu(self.dconv5(h))
        h = F.relu(self.dconv4(h))
        h = F.relu(self.dconv3(h))
        h = self.upsample1(h)
        h = F.relu(self.dconv2(h))
        h = self.upsample1(h)
        # Sigmoid maps the 3-channel reconstruction into [0, 1].
        return torch.sigmoid(self.dconv1(h))
235
+
236
+
237
class Autoencoder(nn.Module):
    """Image autoencoder: pretrained ResNet-50 encoder -> 48-dim binary
    latent code -> convolutional decoder.

    Parameters
    ----------
    encoder_module : nn.Module, optional
        Encoder to use. Defaults to the module-level pretrained `encoder`
        (backward compatible with the original no-argument constructor).
    """

    def __init__(self, encoder_module=None):
        super(Autoencoder, self).__init__()
        # The module-level `encoder` global is only touched when no
        # encoder is supplied, keeping Autoencoder() behavior unchanged.
        self.encoder = encoder_module if encoder_module is not None else encoder
        # BUG FIX: the original stored Binary() -- instantiating a
        # new-style autograd.Function is deprecated. Keep the class and
        # invoke it via .apply (self.binary.apply still works for callers).
        self.binary = Binary
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        # Binarize the latent code: 1 where > 0, else 0 (straight-through
        # gradient in backward).
        x = self.binary.apply(x)
        x = self.decoder(x)
        return x