Spaces:
Sleeping
Sleeping
| # Adapted from | |
| # https://github.com/arnaghosh/Auto-Encoder/blob/master/resnet.py | |
| import torch | |
| from torch.autograd import Variable | |
| import torchvision | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import datasets, models,transforms | |
| import torch.optim as optim | |
| from torch.optim import lr_scheduler | |
| import numpy as np | |
| import os | |
| import matplotlib.pyplot as plt | |
| from torch.autograd import Function | |
| from collections import OrderedDict | |
| import torch.nn as nn | |
| import math | |
| import torchvision.models as models | |
# Length of the latent code vector: the Decoder's first linear layer
# (``dfc3``) takes ``zsize`` input features.
zsize=48
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (defaults to 1).

    Returns:
        The freshly constructed ``nn.Conv2d`` layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The block input (optionally passed through ``downsample``) is added to the
    output of the second batch norm before the final ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to build the shortcut;
            when ``None`` the input itself is used.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # First conv/bn stage; the stride is only applied here.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Second conv/bn stage keeps the channel count.
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)
class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The first 1x1 convolution maps ``inplanes`` to ``planes`` channels, the 3x3
    convolution carries the stride, and the last 1x1 convolution widens the
    result to ``planes * 4`` channels. The block input (optionally passed
    through ``downsample``) is added before the final ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Channel count of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to build the shortcut;
            when ``None`` the input itself is used.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * 4

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)
class Encoder(nn.Module):
    """ResNet-style convolutional encoder.

    A 7x7 stem convolution and max-pool are followed by four residual stages
    (64, 128, 256 and 512 base channels), global average pooling and a linear
    head with 1000 outputs.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(inplanes, planes, stride, downsample)``.
        layers: Sequence of four ints, the number of blocks in each stage.
        num_classes: Unused. Kept only so existing call sites keep working;
            the head is fixed at 1000 outputs (the module-level code loads a
            ResNet-50 checkpoint and then replaces ``fc``).
    """

    def __init__(self, block, layers, num_classes=23):
        super().__init__()
        # Running channel count, updated by ``_make_layer`` as stages are built.
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Adaptive pooling always yields a 1x1 map, so the flattened feature
        # width matches ``fc`` for any input resolution. The previous fixed
        # 7x7 window only lined up when layer4 produced a 7x7 map. The layer
        # has no parameters, so checkpoints load exactly as before.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, 1000)

        self._init_weights()

    def _init_weights(self):
        """Initialise conv weights as N(0, sqrt(2 / fan_out)); reset batch norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks into one ``nn.Sequential`` stage.

        Only the first block receives ``stride``. It also gets a 1x1
        conv + batch-norm shortcut whenever the stride is not 1 or the channel
        count changes, so the shortcut matches the main path's output.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Encode an image batch into the ``fc`` output, one row per sample."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        return self.fc(x)
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the encoder with Bottleneck blocks in a [3, 4, 6, 3] layout and load
# the ResNet-50 checkpoint into it.
encoder = Encoder(Bottleneck, [3, 4, 6, 3])
encoder_state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')
encoder.load_state_dict(encoder_state_dict)

# Swap the 1000-way head for a projection onto the latent code. The sizes are
# derived from the existing head and from ``zsize`` (which also sizes the
# Decoder input) so the two halves of the autoencoder cannot drift apart.
encoder.fc = nn.Linear(encoder.fc.in_features, zsize)
encoder = encoder.to(device)
class Binary(Function):
    """Hard binarisation with a straight-through gradient.

    Use it through ``Binary.apply(tensor)``. ``forward`` and ``backward`` are
    static methods, as the autograd ``Function`` API expects.
    """

    @staticmethod
    def forward(ctx, input):
        """Map strictly positive entries to 1 and all other entries to 0.

        ``sign`` yields -1/0/+1 and the ReLU then clamps -1 to 0, so the result
        contains only 0s and 1s, in the dtype of ``input``.
        """
        return F.relu(input.sign())

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the incoming gradient through unchanged."""
        return grad_output
class Decoder(nn.Module):
    """Decode a latent code into a 3-channel image with values in (0, 1).

    Three fully connected layers (each followed by batch norm and ReLU) expand
    the code to a ``256 x 6 x 6`` feature map. That map is then grown by
    alternating x2 upsampling and transposed convolutions:
    6 -> 12 -> 14 -> 14 -> 14 -> 28 -> 28 -> 56 -> 224 pixels per side.

    Args:
        latent_dim: Length of the input code. When ``None`` (the default) the
            module-level ``zsize`` is used, which matches the previous
            hard-wired behaviour.
    """

    def __init__(self, latent_dim=None):
        super().__init__()
        if latent_dim is None:
            # Resolved at construction time so the default follows ``zsize``.
            latent_dim = zsize

        # Fully connected expansion of the latent code.
        self.dfc3 = nn.Linear(latent_dim, 4096)
        self.bn3 = nn.BatchNorm1d(4096)
        self.dfc2 = nn.Linear(4096, 4096)
        self.bn2 = nn.BatchNorm1d(4096)
        self.dfc1 = nn.Linear(4096, 256 * 6 * 6)
        self.bn1 = nn.BatchNorm1d(256 * 6 * 6)

        # One parameter-free x2 upsampler, reused at three points in forward().
        self.upsample1 = nn.Upsample(scale_factor=2)

        # Transposed-convolution stack, applied from dconv5 down to dconv1.
        self.dconv5 = nn.ConvTranspose2d(256, 256, 3, padding=0)
        self.dconv4 = nn.ConvTranspose2d(256, 384, 3, padding=1)
        self.dconv3 = nn.ConvTranspose2d(384, 192, 3, padding=1)
        self.dconv2 = nn.ConvTranspose2d(192, 64, 5, padding=2)
        self.dconv1 = nn.ConvTranspose2d(64, 3, 12, stride=4, padding=4)

    def forward(self, x):
        """Map a ``(batch, latent_dim)`` code to a ``(batch, 3, 224, 224)`` image."""
        x = F.relu(self.bn3(self.dfc3(x)))
        x = F.relu(self.bn2(self.dfc2(x)))
        x = F.relu(self.bn1(self.dfc1(x)))

        # Reinterpret the flat features as a 256-channel 6x6 map.
        x = x.view(x.shape[0], 256, 6, 6)

        x = self.upsample1(x)
        x = F.relu(self.dconv5(x))
        x = F.relu(self.dconv4(x))
        x = F.relu(self.dconv3(x))

        x = self.upsample1(x)
        x = F.relu(self.dconv2(x))

        x = self.upsample1(x)
        x = self.dconv1(x)

        # Sigmoid squashes the output into the open interval (0, 1).
        return torch.sigmoid(x)
class Autoencoder(nn.Module):
    """Encoder -> binarisation -> decoder pipeline.

    The module-level ``encoder`` is reused as the encoder, so every
    ``Autoencoder`` instance shares that one encoder object.
    """

    def __init__(self):
        super().__init__()
        self.encoder = encoder
        # Keep a reference to the Function *class*. Autograd functions are
        # invoked through the ``apply`` classmethod, and PyTorch deprecates
        # instantiating them. ``self.binary.apply(...)`` keeps working.
        self.binary = Binary
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, binarise the code, and decode it back to an image."""
        code = self.encoder(x)
        bits = self.binary.apply(code)
        return self.decoder(bits)