# Adapted from
# https://github.com/arnaghosh/Auto-Encoder/blob/master/resnet.py
"""ResNet-50 encoder + binarised bottleneck + deconvolutional decoder.

Importing this module builds a module-level ``encoder`` (ResNet-50 layout),
loads the torchvision ResNet-50 checkpoint into it, replaces its final
fully-connected layer with a ``zsize``-wide head and moves it to ``device``.
"""
import math
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torch.autograd import Function
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

# Width of the latent code: output size of the encoder head and input size
# of the decoder's first linear layer.
zsize = 48


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 conv/BN layers with a residual connection.

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input before it is added
            back as the residual (needed when shape or channels change).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block; output has ``planes * expansion`` channels.

    Args:
        inplanes: number of input channels.
        planes: channel width of the two inner convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input before it is added
            back as the residual.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class Encoder(nn.Module):
    """ResNet-style feature extractor ending in a linear layer.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); must
            expose an ``expansion`` class attribute.
        layers: sequence of four ints, the number of blocks in each stage.
        num_classes: accepted for interface compatibility but not used; the
            final layer is always created with 1000 outputs.
    """

    def __init__(self, block, layers, num_classes=23):
        super().__init__()
        # Running channel count consumed/updated by _make_layer.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Fixed 7x7 window: the flattened result only matches ``fc`` when
        # layer4 yields a 7x7 map (total stride is 32, i.e. 224x224 inputs).
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, 1000)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kh * kw * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the spatial size
        or channel count changes, a 1x1 conv + BN projection shortcut.
        Updates ``self.inplanes`` to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem, the four stages, pooling and the linear head."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Module-level encoder: built with the ResNet-50 layout so the torchvision
# checkpoint loads as-is, after which the 1000-way head is swapped for a
# freshly initialised zsize-wide one.
encoder = Encoder(Bottleneck, [3, 4, 6, 3])
encoder_state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')
encoder.load_state_dict(encoder_state_dict)
encoder.fc = nn.Linear(512 * Bottleneck.expansion, zsize)
encoder = encoder.to(device)


class Binary(Function):
    """Binarise activations with a pass-through gradient.

    Use via ``Binary.apply(x)``; the class is not meant to be instantiated.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``relu(sign(input))``: 1 where ``input > 0``, else 0."""
        # Autograd is already disabled inside Function.forward, so no
        # Variable wrapper or ``.data`` detaching is needed.
        return F.relu(input.sign())

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the incoming gradient through unchanged."""
        return grad_output


class Decoder(nn.Module):
    """Map a ``zsize``-wide code to a 3-channel image with values in (0, 1).

    Three linear + BatchNorm1d + ReLU layers produce a 256x6x6 map, which is
    grown by nearest-neighbour upsampling and transposed convolutions
    (6 -> 12 -> 14 -> 28 -> 56 -> 224 pixels per side).
    """

    def __init__(self):
        super().__init__()
        self.dfc3 = nn.Linear(zsize, 4096)
        self.bn3 = nn.BatchNorm1d(4096)
        self.dfc2 = nn.Linear(4096, 4096)
        self.bn2 = nn.BatchNorm1d(4096)
        self.dfc1 = nn.Linear(4096, 256 * 6 * 6)
        self.bn1 = nn.BatchNorm1d(256 * 6 * 6)
        # Parameter-free, so the same module is reused for every 2x upsample.
        self.upsample1 = nn.Upsample(scale_factor=2)
        self.dconv5 = nn.ConvTranspose2d(256, 256, 3, padding=0)
        self.dconv4 = nn.ConvTranspose2d(256, 384, 3, padding=1)
        self.dconv3 = nn.ConvTranspose2d(384, 192, 3, padding=1)
        self.dconv2 = nn.ConvTranspose2d(192, 64, 5, padding=2)
        self.dconv1 = nn.ConvTranspose2d(64, 3, 12, stride=4, padding=4)

    def forward(self, x):
        """Decode a ``(batch, zsize)`` code into a ``(batch, 3, H, W)`` image."""
        x = F.relu(self.bn3(self.dfc3(x)))
        x = F.relu(self.bn2(self.dfc2(x)))
        x = F.relu(self.bn1(self.dfc1(x)))

        x = x.view(x.shape[0], 256, 6, 6)

        x = self.upsample1(x)
        x = F.relu(self.dconv5(x))
        x = F.relu(self.dconv4(x))
        x = F.relu(self.dconv3(x))

        x = self.upsample1(x)
        x = F.relu(self.dconv2(x))

        x = self.upsample1(x)
        x = self.dconv1(x)
        return torch.sigmoid(x)


class Autoencoder(nn.Module):
    """Encoder -> binarisation -> decoder.

    Note: every instance wraps the same module-level ``encoder`` object, so
    encoder weights are shared between instances.
    """

    def __init__(self):
        super().__init__()
        self.encoder = encoder
        # Keep a reference to the Function class rather than an instance:
        # autograd Functions with static methods must not be instantiated
        # (PyTorch warns and plans to make it an error). ``.apply`` is a
        # classmethod, so ``self.binary.apply(x)`` keeps working.
        self.binary = Binary
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, binarise the code, and decode it back to an image."""
        x = self.encoder(x)
        x = self.binary.apply(x)
        x = self.decoder(x)
        return x