Spaces:
Sleeping
Sleeping
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
def get_noise(n_samples, z_dim, device='cpu'):
    """Sample a batch of standard-normal noise vectors.

    Args:
        n_samples: number of noise vectors to draw.
        z_dim: dimensionality of each noise vector.
        device: torch device on which to allocate the tensor.

    Returns:
        A ``(n_samples, z_dim)`` float tensor of N(0, 1) samples.
    """
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)
def get_random_labels(n_samples, device='cpu'):
    """Draw a batch of uniformly random class labels in ``{0, ..., 9}``.

    Args:
        n_samples: number of labels to draw.
        device: torch device on which to allocate the tensor.

    Returns:
        A ``(n_samples,)`` long tensor of class indices.
    """
    labels = torch.randint(low=0, high=10, size=(n_samples,), device=device)
    return labels.type(torch.long)
def get_generator_block(input_dim, output_dim):
    """Build one generator layer: Linear -> BatchNorm1d -> ReLU.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.

    Returns:
        An ``nn.Sequential`` applying the three stages in order.
    """
    linear = nn.Linear(input_dim, output_dim)
    norm = nn.BatchNorm1d(output_dim)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(linear, norm, activation)
class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) pairs to flattened images.

    The class label is one-hot encoded and concatenated onto the noise
    vector, so the first linear layer sees ``z_dim + n_classes`` inputs.

    Args:
        z_dim: size of the noise vector per sample.
        im_dim: number of pixels in the flattened output image
            (784 = 28*28, MNIST-sized).
        hidden_dim: width of the first hidden layer; subsequent layers
            double it (128 -> 256 -> 512 -> 1024 with the default).
        n_classes: number of condition classes. Defaults to 10, which
            preserves the original hard-coded behavior.
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128, n_classes=10):
        super().__init__()
        self.n_classes = n_classes
        # Input is of shape (batch_size, z_dim + n_classes).
        self.gen = nn.Sequential(
            get_generator_block(z_dim + n_classes, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim * 2),
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid(),  # pixel values in (0, 1)
        )

    def forward(self, noise, classes):
        """Generate a batch of images.

        Args:
            noise: ``(batch_size, z_dim)`` float tensor, one noise vector
                per image.
            classes: ``(batch_size,)`` long tensor of condition class
                indices in ``[0, n_classes)``.

        Returns:
            ``(batch_size, im_dim)`` tensor with values in (0, 1).
        """
        # One-hot encode the condition class, e.g. 3 -> [0,0,0,1,0,0,0,0,0,0].
        one_hot_vec = F.one_hot(classes, num_classes=self.n_classes).type(torch.float32)
        # (batch_size, z_dim + n_classes)
        conditioned_noise = torch.cat((noise, one_hot_vec), dim=1)
        return self.gen(conditioned_noise)
def get_discriminator_block(input_dim, output_dim):
    """Build one discriminator layer: Linear -> LeakyReLU(0.2).

    Args:
        input_dim: number of input features.
        output_dim: number of output features.

    Returns:
        An ``nn.Sequential`` applying the two stages in order.
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)
class Discriminator(nn.Module):
    """Conditional discriminator: scores a (flattened image, one-hot label) input.

    The caller is expected to concatenate the one-hot class vector onto
    the flattened image before calling ``forward``, so the input width is
    ``im_dim + n_classes``. The output is a raw logit: pair this model
    with ``nn.BCEWithLogitsLoss`` — combining Sigmoid and BCE in one op
    uses the log-sum-exp trick and is more numerically stable than a
    plain Sigmoid followed by BCELoss.
    https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html

    Args:
        im_dim: number of pixels in the flattened input image.
        hidden_dim: width of the last hidden layer; earlier layers shrink
            toward it (512 -> 256 -> 128 with the default).
        n_classes: number of condition classes appended to the image.
            Defaults to 10, which preserves the original hard-coded
            behavior.
    """

    def __init__(self, im_dim=784, hidden_dim=128, n_classes=10):
        super().__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim + n_classes, hidden_dim * 4),
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),
            get_discriminator_block(hidden_dim * 2, hidden_dim),
            # No final Sigmoid here on purpose — see class docstring.
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, image_batch):
        """Score a batch of conditioned images.

        Args:
            image_batch: ``(batch_size, im_dim + n_classes)`` tensor of
                flattened images with one-hot labels concatenated.

        Returns:
            ``(batch_size, 1)`` tensor of real/fake logits.
        """
        return self.disc(image_batch)