# cgan-digit / cgan.py
# Author: Amogha Y A
# Commit: "Initial Streamlit cGAN app" (deb8c06)
import torch
import torch.nn as nn
# Hyperparameters shared by the Generator defined in this module.
# Number of distinct label classes; sizes the generator's label embedding.
num_classes = 10
# Length of the noise vector z that is concatenated with the label embedding.
z_dim = 100
# Side length of the square image; the generator starts at img_size // 4
# and upsamples by 2 twice.
img_size = 28
class Generator(nn.Module):
    """Conditional GAN generator mapping (noise, class label) to an image.

    The label is embedded into a ``num_classes``-dimensional vector,
    concatenated with the noise vector, projected by a linear layer to a
    128-channel feature map of side ``img_size // 4``, and then upsampled
    by a factor of 2 twice before a final convolution and ``tanh``.

    NOTE: attribute names and the order of layers inside ``conv_blocks``
    determine the state-dict keys, so they must not change if previously
    saved weights are to remain loadable.
    """

    @staticmethod
    def _upsample_stage(in_channels, out_channels):
        """Return the layers of one 2x upsampling stage, in order."""
        return [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            # NOTE(review): the second positional argument of BatchNorm2d is
            # ``eps``, not ``momentum``. 0.8 is an unusually large epsilon and
            # was possibly meant as momentum, but it is kept as-is because
            # weights trained with this value depend on it.
            nn.BatchNorm2d(out_channels, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
        ]

    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Spatial side of the feature map before the two 2x upsamplings.
        self.init_size = img_size // 4
        projected_features = 128 * self.init_size ** 2
        self.fc = nn.Linear(z_dim + num_classes, projected_features)

        layers = [nn.BatchNorm2d(128)]
        layers.extend(self._upsample_stage(128, 128))
        layers.extend(self._upsample_stage(128, 64))
        layers.append(nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1))
        layers.append(nn.Tanh())  # squashes pixel values into [-1, 1]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Generate a batch of single-channel images.

        Args:
            z: Noise tensor of shape ``(batch, z_dim)``.
            labels: Integer class indices of shape ``(batch,)``, used to
                index the label embedding.

        Returns:
            Tensor of shape ``(batch, 1, 4 * init_size, 4 * init_size)``
            with values in ``[-1, 1]``.
        """
        conditioned = torch.cat([z, self.label_emb(labels)], dim=1)
        features = self.fc(conditioned)
        side = self.init_size
        feature_map = features.view(features.shape[0], 128, side, side)
        return self.conv_blocks(feature_map)
def load_generator(device="cpu", weights_path="generator.pth"):
    """Build a Generator, load saved weights into it, and return it in eval mode.

    Args:
        device: Device the model is moved to and the checkpoint tensors are
            mapped onto (anything accepted by ``Module.to`` and
            ``torch.load(map_location=...)``). Defaults to ``"cpu"``.
        weights_path: Path of the saved state dict. Defaults to
            ``"generator.pth"``, which is resolved relative to the current
            working directory; pass an explicit path when the process may
            be started from a different directory.

    Returns:
        The Generator with the loaded weights, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the checkpoint keys or shapes do not match the
            Generator's parameters (raised by ``load_state_dict``).
    """
    model = Generator().to(device)
    # NOTE(review): torch.load relies on pickle unless weights_only=True is
    # in effect, so only load checkpoints from a trusted source. Consider
    # passing weights_only=True once the minimum supported torch version
    # is known to accept it.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    # Evaluation mode makes BatchNorm use its stored running statistics
    # instead of per-batch statistics.
    model.eval()
    return model