# Hugging Face upload by srijaydeshpande: "Upload 28 files" (commit e7610f7, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence
def weights_init(m):
    """Xavier-initialize the weight (and zero the bias) of conv-like modules.

    Meant to be used with ``nn.Module.apply``. A module is selected when its
    class name contains ``'Conv'``; every other module is left untouched.

    Args:
        m: The module to (possibly) initialize in place.

    Notes:
        * A selected module without a usable ``weight`` (for example a
          container whose class name merely contains "Conv") is reported on
          stdout and skipped.
        * A selected module without a bias (``bias=False``) still gets its
          weight initialized; only the bias fill is omitted.
    """
    classname = m.__class__.__name__
    if 'Conv' not in classname:
        return

    # Keep the try body minimal: only the attribute lookup may legitimately
    # fail for modules that match by name but own no weight tensor.
    try:
        weight = m.weight.data
    except AttributeError:
        print("Skipping initialization of ", classname)
        return

    nn.init.xavier_uniform_(weight)

    # Bias is optional; its absence must not be reported as a skipped init
    # because the weight above has already been initialized.
    bias = getattr(m, 'bias', None)
    if bias is not None:
        bias.data.fill_(0)
class GatedActivation(nn.Module):
    """Gated activation unit: ``tanh(a) * sigmoid(b)``.

    The input is split into two chunks ``a`` and ``b`` along dimension 1;
    the output therefore has half as many entries along that dimension as
    the input (for an even-sized dimension 1).
    """

    def forward(self, x):
        """Apply the gate.

        Args:
            x: Tensor that is chunked in two along dim 1.

        Returns:
            ``tanh(first_chunk) * sigmoid(second_chunk)``.
        """
        signal, gate = x.chunk(2, dim=1)
        # torch.tanh / torch.sigmoid replace the deprecated
        # torch.nn.functional aliases and compute the same values.
        return torch.tanh(signal) * torch.sigmoid(gate)
class GatedMaskedConv2d(nn.Module):
    """One gated, class-conditioned layer with a vertical and a horizontal stack.

    The vertical stack uses a ``(kernel // 2 + 1, kernel)`` convolution and the
    horizontal stack a ``(1, kernel // 2 + 1)`` convolution. Both are padded by
    ``kernel // 2`` and their outputs are cropped back to the input's spatial
    size, so each output position only reads inputs above / to the left of it
    (plus the current row / column).

    For ``mask_type == 'A'`` the last kernel row of the vertical stack and the
    last kernel column of the horizontal stack are zeroed before every forward
    pass, which additionally removes the current row / column from the
    receptive field.

    Args:
        mask_type: ``'A'`` to apply the extra masking described above; any
            other value (``'B'`` in this file) leaves the kernels unmasked.
        dim: Number of feature channels of both stacks.
        kernel: Odd kernel size.
        residual: If true, the horizontal input is added to the horizontal
            output.
        n_classes: Number of rows of the class-conditioning embedding.

    Raises:
        AssertionError: If ``kernel`` is even.
    """

    def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10):
        super().__init__()
        # The message must be a string; the previous ``print(...)`` expression
        # evaluated to None and left the AssertionError without a message.
        assert kernel % 2 == 1, "Kernel size must be odd"
        self.mask_type = mask_type
        self.residual = residual

        # One conditioning vector per class, matching the 2 * dim pre-gate
        # channels of both stacks.
        self.class_cond_embedding = nn.Embedding(n_classes, 2 * dim)

        half = kernel // 2

        # Vertical stack: kernel covers ceil(kernel / 2) rows and all columns.
        self.vert_stack = nn.Conv2d(
            dim, dim * 2,
            (half + 1, kernel), 1, (half, half)
        )

        # 1x1 link feeding the vertical pre-activation into the horizontal one.
        self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)

        # Horizontal stack: a single row covering ceil(kernel / 2) columns.
        self.horiz_stack = nn.Conv2d(
            dim, dim * 2,
            (1, half + 1), 1, (0, half)
        )

        self.horiz_resid = nn.Conv2d(dim, dim, 1)

        self.gate = GatedActivation()

    def make_causal(self):
        """Zero the last kernel row (vertical) and last kernel column (horizontal).

        Writes through ``.data`` so the masking is not recorded by autograd.
        """
        self.vert_stack.weight.data[:, :, -1].zero_()  # Mask final row
        self.horiz_stack.weight.data[:, :, :, -1].zero_()  # Mask final column

    def forward(self, x_v, x_h, h):
        """Run both stacks.

        Args:
            x_v: Vertical-stack input; assumed ``(B, dim, H, W)``.
            x_h: Horizontal-stack input with the same layout as ``x_v``.
            h: Integer class indices looked up in the conditioning embedding;
                assumed shape ``(B,)``.

        Returns:
            Tuple ``(out_v, out_h)`` with the same spatial size as the inputs.
        """
        if self.mask_type == 'A':
            # Re-applied on every call because an optimizer step may have
            # written non-zero values back into the masked taps.
            self.make_causal()

        # (B, 2 * dim) -> (B, 2 * dim, 1, 1) so it broadcasts over H and W.
        cond = self.class_cond_embedding(h)[:, :, None, None]

        # The padded vertical conv yields H + kernel // 2 rows; keep the first
        # H. Height is dim -2 (the previous code used the width, which is only
        # correct for square inputs).
        h_vert = self.vert_stack(x_v)
        h_vert = h_vert[:, :, :x_v.size(-2), :]
        out_v = self.gate(h_vert + cond)

        # The padded horizontal conv yields W + kernel // 2 columns; keep the
        # first W. Width is dim -1.
        h_horiz = self.horiz_stack(x_h)
        h_horiz = h_horiz[:, :, :, :x_h.size(-1)]

        v2h = self.vert_to_horiz(h_vert)
        out = self.gate(v2h + h_horiz + cond)

        out_h = self.horiz_resid(out)
        if self.residual:
            out_h = out_h + x_h

        return out_v, out_h
class GatedPixelCNN(nn.Module):
    """Class-conditional Gated PixelCNN over a grid of discrete codes.

    Args:
        input_dim: Number of distinct input codes; also the number of output
            logits per position.
        dim: Channel width of the embedding and of every gated layer.
        n_layers: Number of gated layers. The first one uses mask ``'A'``, a
            kernel of 7 and no residual; the rest use mask ``'B'``, a kernel
            of 3 and a residual connection.
        n_classes: Number of conditioning classes.
    """

    def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10):
        super().__init__()
        self.dim = dim

        # Create embedding layer to embed input
        self.embedding = nn.Embedding(input_dim, dim)

        # Building the PixelCNN layer by layer:
        # initial block with Mask-A convolution, rest with Mask-B convolutions.
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            is_first = i == 0
            mask_type = 'A' if is_first else 'B'
            kernel = 7 if is_first else 3
            residual = not is_first

            self.layers.append(
                GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes)
            )

        # Add the output layer
        self.output_conv = nn.Sequential(
            nn.Conv2d(dim, 512, 1),
            nn.ReLU(True),
            nn.Conv2d(512, input_dim, 1)
        )

        self.apply(weights_init)

    def forward(self, x, label):
        """Compute per-position logits.

        Args:
            x: Integer code indices; assumed shape ``(B, H, W)``.
            label: Class indices passed to every gated layer.

        Returns:
            The output of ``output_conv`` applied to the final horizontal
            stack; ``(B, input_dim, H, W)`` under the shape assumption above.
        """
        # nn.Embedding accepts index tensors of any shape, so no flattening
        # is needed; this also works for non-contiguous inputs.
        x = self.embedding(x)  # (B, H, W, C)
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)

        x_v, x_h = x, x
        for layer in self.layers:
            x_v, x_h = layer(x_v, x_h, label)

        return self.output_conv(x_h)

    def generate(self, label, shape=(8, 8), batch_size=64):
        """Sample a grid of codes one position at a time, row by row.

        Args:
            label: Class indices used for conditioning.
            shape: ``(height, width)`` of the grid to sample.
            batch_size: Number of grids to sample.

        Returns:
            An int64 tensor of shape ``(batch_size, *shape)`` on the same
            device as the model parameters.
        """
        param = next(self.parameters())
        x = torch.zeros(
            (batch_size, *shape),
            dtype=torch.int64, device=param.device
        )

        # Sampling never needs gradients; disabling autograd avoids building
        # a graph for each of the height * width forward passes.
        with torch.no_grad():
            for i in range(shape[0]):
                for j in range(shape[1]):
                    logits = self.forward(x, label)
                    probs = F.softmax(logits[:, :, i, j], -1)
                    # squeeze(-1) drops only the sample dim, keeping the
                    # batch dim even when batch_size == 1.
                    x[:, i, j] = probs.multinomial(1).squeeze(-1)
        return x