luismidv's picture
New
373085f
import torch
import torch.nn as nn
class VGG16(torch.nn.Module):
    """VGG-style convolutional feature extractor returning per-stage features.

    The network has five stages. Stage ``N`` consists of:

    * ``convN_1`` -- a single 3x3 convolution wrapped in ``nn.Sequential``;
    * ``block_N`` -- ``ReLU`` followed by ``k`` x (3x3 convolution, ``ReLU``),
      with ``k = 1`` for stages 1-2 and ``k = 3`` for stages 3-5.

    Channel widths per stage are 64, 128, 256, 512, 512. Every convolution
    uses stride 1 and padding 1, and no pooling is applied (the original
    ``MaxPool2d`` layers were commented out), so all returned feature maps
    keep the spatial size of the input.

    NOTE(review): the per-stage convolution count is 2, 2, 4, 4, 4 (16 convs),
    which matches the VGG19 convolutional layout rather than VGG16's
    (2, 2, 3, 3, 3) -- confirm this is intended before relying on the name.

    A ``classifier`` head (Linear 512->4096, ReLU, Linear 4096->4096, ReLU,
    Linear 4096->``num_classes``) is built and initialised, but ``forward``
    does not apply it.

    Args:
        num_features: Accepted for interface compatibility but not used; the
            number of input channels is fixed at 3.
        num_classes: Output size of the classifier's final ``nn.Linear``.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        # Each stage is split into its first convolution (``convN_1``) and the
        # rest (``block_N``) so that ``forward`` can capture the first
        # convolution's pre-activation output; ``block_N`` therefore starts
        # with the ReLU for that convolution.
        # Construction order matches the original hand-written layers, so the
        # default-initialisation RNG draws happen in the same sequence.
        self.conv1_1 = nn.Sequential(self._conv3x3(3, 64))
        self.block_1 = self._relu_conv_stack(64, num_convs=1)
        self.conv2_1 = nn.Sequential(self._conv3x3(64, 128))
        self.block_2 = self._relu_conv_stack(128, num_convs=1)
        self.conv3_1 = nn.Sequential(self._conv3x3(128, 256))
        self.block_3 = self._relu_conv_stack(256, num_convs=3)
        self.conv4_1 = nn.Sequential(self._conv3x3(256, 512))
        self.block_4 = self._relu_conv_stack(512, num_convs=3)
        self.conv5_1 = nn.Sequential(self._conv3x3(512, 512))
        self.block_5 = self._relu_conv_stack(512, num_convs=3)

        # Not used by ``forward``; kept so the parameter set / state dict is
        # unchanged. Its 512 input features assume a 512-dim flattened input.
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, num_classes),
        )

        self._init_weights()

    @staticmethod
    def _conv3x3(in_channels, out_channels):
        """Return a 3x3 convolution with stride 1 and "same" padding.

        Same padding: (w - k + 2p)/s + 1 = w with k=3, s=1 gives p=1, so the
        spatial size is preserved.
        """
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=1,
        )

    @classmethod
    def _relu_conv_stack(cls, channels, num_convs):
        """Return ``ReLU`` then ``num_convs`` x (3x3 conv, ``ReLU``) at fixed width.

        The layer order keeps the convolutions at Sequential indices
        1, 3, 5, ..., matching the original hand-written blocks (and hence
        their state-dict keys).
        """
        # ReLU stays out-of-place (default ``inplace=False``): its input is the
        # ``convN_1`` output that ``forward`` also returns as a style feature.
        layers = [nn.ReLU()]
        for _ in range(num_convs):
            layers += [cls._conv3x3(channels, channels), nn.ReLU()]
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Draw every Conv2d/Linear weight from N(0, 0.05) and zero the biases."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # ``nn.init`` operates under ``torch.no_grad()``, avoiding the
                # in-place-on-detached-view pattern.
                nn.init.normal_(module.weight, mean=0.0, std=0.05)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the five stages and collect intermediate feature maps.

        Args:
            x: Input tensor with 3 channels; presumably (N, 3, H, W) -- confirm
                against callers.

        Returns:
            dict with two lists of five tensors each:
              * ``"style"``: pre-activation outputs of ``conv1_1`` ... ``conv5_1``;
              * ``"content"``: outputs of ``block_1`` ... ``block_5`` (these end
                in a ReLU, so they are non-negative).
        """
        stages = (
            (self.conv1_1, self.block_1),
            (self.conv2_1, self.block_2),
            (self.conv3_1, self.block_3),
            (self.conv4_1, self.block_4),
            (self.conv5_1, self.block_5),
        )
        style_feats = []
        content_feats = []
        out = x
        for first_conv, block in stages:
            pre_activation = first_conv(out)
            out = block(pre_activation)
            style_feats.append(pre_activation)
            content_feats.append(out)
        # NOTE: ``self.classifier`` is intentionally not applied here; class
        # logits are not part of the returned dict.
        return {"style": style_feats, "content": content_feats}