Baaz / code /model.py
SrinivasMudiraj's picture
Upload 34 files
9c7aa86 verified
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict
class NetG(nn.Module):
def __init__(self, ngf=64, nz=100):
super(NetG, self).__init__()
self.ngf = ngf
self.fc = nn.Linear(nz, ngf*8*4*4)
self.block0 = G_Block(ngf * 8, ngf * 8)#4x4
self.block1 = G_Block(ngf * 8, ngf * 8)#8x8
self.block2 = G_Block(ngf * 8, ngf * 8)#16x16
self.block3 = G_Block(ngf * 8, ngf * 8)#32x32
self.block4 = G_Block(ngf * 8, ngf * 4)#64x64
self.block5 = G_Block(ngf * 4, ngf * 2)#128x128
self.block6 = G_Block(ngf * 2, ngf * 1)#256x256
self.conv_img = nn.Sequential(
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ngf, 3, 3, 1, 1),
nn.Tanh(),
)
def forward(self, x, c):
out = self.fc(x)
out = out.view(x.size(0), 8*self.ngf, 4, 4)
out = self.block0(out,c)
out = F.interpolate(out, scale_factor=2)
out = self.block1(out,c)
out = F.interpolate(out, scale_factor=2)
out = self.block2(out,c)
out = F.interpolate(out, scale_factor=2)
out = self.block3(out,c)
out = F.interpolate(out, scale_factor=2)
out = self.block4(out,c)
out = F.interpolate(out, scale_factor=2)
out = self.block5(out,c)
out = F.interpolate(out, scale_factor=2)
out = self.block6(out,c)
out = self.conv_img(out)
return out
class G_Block(nn.Module):
def __init__(self, in_ch, out_ch):
super(G_Block, self).__init__()
self.learnable_sc = in_ch != out_ch
self.c1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
self.c2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
self.affine0 = affine(in_ch)
self.affine1 = affine(in_ch)
self.affine2 = affine(out_ch)
self.affine3 = affine(out_ch)
self.gamma = nn.Parameter(torch.zeros(1))
if self.learnable_sc:
self.c_sc = nn.Conv2d(in_ch,out_ch, 1, stride=1, padding=0)
def forward(self, x, y=None):
return self.shortcut(x) + self.gamma * self.residual(x, y)
def shortcut(self, x):
if self.learnable_sc:
x = self.c_sc(x)
return x
def residual(self, x, y=None):
h = self.affine0(x, y)
h = nn.LeakyReLU(0.2,inplace=True)(h)
h = self.affine1(h, y)
h = nn.LeakyReLU(0.2,inplace=True)(h)
h = self.c1(h)
h = self.affine2(h, y)
h = nn.LeakyReLU(0.2,inplace=True)(h)
h = self.affine3(h, y)
h = nn.LeakyReLU(0.2,inplace=True)(h)
return self.c2(h)
class affine(nn.Module):
def __init__(self, num_features):
super(affine, self).__init__()
self.fc_gamma = nn.Sequential(OrderedDict([
('linear1',nn.Linear(256, 256)),
('relu1',nn.ReLU(inplace=True)),
('linear2',nn.Linear(256, num_features)),
]))
self.fc_beta = nn.Sequential(OrderedDict([
('linear1',nn.Linear(256, 256)),
('relu1',nn.ReLU(inplace=True)),
('linear2',nn.Linear(256, num_features)),
]))
self._initialize()
def _initialize(self):
nn.init.zeros_(self.fc_gamma.linear2.weight.data)
nn.init.ones_(self.fc_gamma.linear2.bias.data)
nn.init.zeros_(self.fc_beta.linear2.weight.data)
nn.init.zeros_(self.fc_beta.linear2.bias.data)
def forward(self, x, y=None):
weight = self.fc_gamma(y)
bias = self.fc_beta(y)
if weight.dim() == 1:
weight = weight.unsqueeze(0)
if bias.dim() == 1:
bias = bias.unsqueeze(0)
size = x.size()
weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
return weight * x + bias
class D_GET_LOGITS(nn.Module):
def __init__(self, ndf):
super(D_GET_LOGITS, self).__init__()
self.df_dim = ndf
self.joint_conv = nn.Sequential(
nn.Conv2d(ndf * 16+256, ndf * 2, 3, 1, 1, bias=False),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ndf * 2, 1, 4, 1, 0, bias=False),
)
def forward(self, out, y):
y = y.view(-1, 256, 1, 1)
y = y.repeat(1, 1, 4, 4)
h_c_code = torch.cat((out, y), 1)
out = self.joint_conv(h_c_code)
return out
class NetD(nn.Module):
def __init__(self, ndf):
super(NetD, self).__init__()
self.conv_img = nn.Conv2d(3, ndf, 3, 1, 1)#256*256s
self.block0 = resD(ndf * 1, ndf * 2)#128*128
self.block1 = resD(ndf * 2, ndf * 4)#64*64
self.block2 = resD(ndf * 4, ndf * 8)#32*32
self.block3 = resD(ndf * 8, ndf * 16)#16*16
self.block4 = resD(ndf * 16, ndf * 16)#8*8
self.block5 = resD(ndf * 16, ndf * 16)#4*4
self.COND_DNET = D_GET_LOGITS(ndf)
def forward(self,x):
out = self.conv_img(x)
out = self.block0(out)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.block4(out)
out = self.block5(out)
return out
class resD(nn.Module):
def __init__(self, fin, fout, downsample=True):
super().__init__()
self.downsample = downsample
self.learned_shortcut = (fin != fout)
self.conv_r = nn.Sequential(
nn.Conv2d(fin, fout, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(fout, fout, 3, 1, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
)
self.conv_s = nn.Conv2d(fin,fout, 1, stride=1, padding=0)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x, c=None):
return self.shortcut(x) + self.gamma * self.residual(x)
def shortcut(self, x):
if self.learned_shortcut:
x = self.conv_s(x)
if self.downsample:
return F.avg_pool2d(x, 2)
return x
def residual(self, x):
return self.conv_r(x)