# ViT_Autoencoder / objectives.py
# Uploaded by detectivejoewest ("Upload 7 files", commit 582b238, verified)
import torch.nn as nn
from torchvision.models import vgg16, VGG16_Weights
class Discriminator(nn.Module):
    """PatchGAN-style convolutional discriminator.

    Downsamples the input with a stack of stride-2 convolutions
    (each stage halves H and W) and maps the final feature map to a
    1-channel grid of per-patch real/fake logits (no sigmoid applied).

    Args:
        img_shape: input image shape as (channels, height, width);
            only the channel count ``img_shape[0]`` is used.
        filters: output channel count of each downsampling stage.
            Default is the tuple ``(256, 512)`` — an immutable default,
            avoiding the shared-mutable-default-argument pitfall of the
            original ``[256, 512]``. Any non-empty sequence is accepted.
    """

    def __init__(self, img_shape, filters=(256, 512)):
        super().__init__()
        # First stage: image channels -> filters[0].
        layers = [
            nn.Conv2d(img_shape[0], filters[0], kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.LeakyReLU(0.2),
        ]
        # Remaining stages: filters[i-1] -> filters[i].
        for in_ch, out_ch in zip(filters[:-1], filters[1:]):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
        self.convs = nn.Sequential(*layers)
        # 1x1 conv producing one logit per spatial patch.
        self.mlp = nn.Sequential(nn.Conv2d(filters[-1], 1, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        """Return a (N, 1, H', W') map of patch logits for input ``x``."""
        x = self.convs(x)
        x = self.mlp(x)
        return x
class vgg_builder(nn.Module):
    """Frozen pretrained VGG16 feature extractor for perceptual losses.

    Splits the VGG16 ``features`` stack into 5 sequential slices at the
    layer indices (0, 4, 9, 16, 23, 30) and returns the activation after
    each slice. All parameters are frozen (``requires_grad = False``), so
    the network acts as a fixed loss/feature extractor.
    """

    # Slice boundaries within vgg16().features: slice i spans layer
    # indices [_BOUNDS[i], _BOUNDS[i+1]). Same cut points as the original
    # hand-unrolled loops.
    _BOUNDS = (0, 4, 9, 16, 23, 30)

    def __init__(self):
        super(vgg_builder, self).__init__()
        convs = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
        self.N_slices = len(self._BOUNDS) - 1  # == 5
        self.slices = nn.ModuleList()
        # Build each slice from its boundary pair instead of five
        # copy-pasted loops; submodule names (stringified layer indices)
        # match the original exactly.
        for start, stop in zip(self._BOUNDS[:-1], self._BOUNDS[1:]):
            block = nn.Sequential()
            for idx in range(start, stop):
                block.add_module(str(idx), convs[idx])
            self.slices.append(block)
        # Freeze everything: this module is never trained.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the list of activations after each of the 5 slices.

        Maps ``x`` from [-1, 1] to [0, 1] first. NOTE(review): pretrained
        VGG16 normally expects ImageNet mean/std normalization, which is
        not applied here — confirm this matches the training setup.
        """
        feat_map = []
        x = (x + 1) / 2
        for block in self.slices:
            x = block(x)
            feat_map.append(x)
        return feat_map