# NOTE: page-scrape artifact from the hosting site ("Spaces: Sleeping") —
# kept here as a comment so the file remains valid Python.
import torch
import torch.nn as nn

# Hyperparameters for the conditional GAN generator (MNIST-sized output).
num_classes = 10  # number of condition labels
z_dim = 100       # dimensionality of the latent noise vector z
img_size = 28     # generated images are 1 x img_size x img_size
class Generator(nn.Module):
    """Conditional DCGAN-style generator.

    Concatenates a latent vector with a label embedding, projects the result
    to a small 128-channel feature map with a linear layer, then upsamples
    twice (4x total) through conv blocks to a single-channel
    image_size x image_size image in [-1, 1] (Tanh output).

    Args:
        n_classes: number of condition labels; defaults to the module-level
            ``num_classes`` for backward compatibility.
        latent_dim: size of the noise vector z; defaults to ``z_dim``.
        image_size: output spatial size (must be divisible by 4, since there
            are two 2x upsampling stages); defaults to ``img_size``.
    """

    def __init__(self, n_classes=None, latent_dim=None, image_size=None):
        super().__init__()
        # Fall back to the module-level hyperparameters so that the original
        # zero-argument construction keeps working unchanged.
        n_classes = num_classes if n_classes is None else n_classes
        latent_dim = z_dim if latent_dim is None else latent_dim
        image_size = img_size if image_size is None else image_size

        # One n_classes-dim embedding per label, concatenated with z.
        self.label_emb = nn.Embedding(n_classes, n_classes)
        # Two 2x upsampling stages => start at image_size // 4.
        self.init_size = image_size // 4
        self.fc = nn.Linear(latent_dim + n_classes, 128 * self.init_size ** 2)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, 1, 1),
            # BUG FIX: the original passed 0.8 positionally, which set
            # eps=0.8 (a huge, numerically wrong epsilon). The intended
            # setting in this well-known GAN recipe is momentum=0.8.
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, 1, 1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, z, labels):
        """Generate a batch of images.

        Args:
            z: (batch, latent_dim) noise tensor.
            labels: (batch,) long tensor of class indices in [0, n_classes).

        Returns:
            (batch, 1, image_size, image_size) tensor in [-1, 1].
        """
        label_input = self.label_emb(labels)
        gen_input = torch.cat((z, label_input), dim=1)
        out = self.fc(gen_input)
        out = out.view(out.size(0), 128, self.init_size, self.init_size)
        return self.conv_blocks(out)
def load_generator(device="cpu", path="generator.pth"):
    """Load trained generator weights and return the model in eval mode.

    Args:
        device: torch device string/object to place the model and weights on.
        path: checkpoint file containing the generator's state_dict
            (defaults to the original hard-coded "generator.pth").

    Returns:
        A ``Generator`` on ``device``, in eval mode (inference-ready).

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        RuntimeError: if the checkpoint does not match the architecture.
    """
    model = Generator().to(device)
    # map_location keeps CUDA-trained checkpoints loadable on CPU-only hosts.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.eval()
    return model