import torch
import torch.nn as nn

# Default length of the noise vector consumed by ``Generator``.
NOISE_DIM = 256


class Generator(nn.Module):
    """Map noise vectors to 3-channel 64x64 tensors with values in [-1, 1].

    A linear layer projects each noise vector to a 512-channel 4x4 feature
    map. Four ``Upsample(scale_factor=2)`` + 3x3 convolution stages then
    double the spatial size each time (4 -> 8 -> 16 -> 32 -> 64) while
    reducing the channel count (512 -> 256 -> 128 -> 64 -> 3). The final
    ``Tanh`` bounds the output to [-1, 1].

    The submodule names (``fc``, ``net``) and the flat layer order inside
    ``net`` are kept stable so ``state_dict`` keys such as ``net.14.weight``
    do not change.

    Args:
        noise_dim: Size of the last dimension of the noise input.
            Defaults to the module-level ``NOISE_DIM``.
    """

    # Shape of the feature map the linear projection is reshaped into.
    _BASE_CHANNELS = 512
    _BASE_SIZE = 4
    # (in_channels, out_channels) of the intermediate upsampling stages.
    _STAGE_CHANNELS = ((512, 256), (256, 128), (128, 64))
    _OUT_CHANNELS = 3

    def __init__(self, noise_dim: int = NOISE_DIM) -> None:
        super().__init__()
        self.noise_dim = noise_dim
        self.fc = nn.Linear(
            noise_dim,
            self._BASE_SIZE * self._BASE_SIZE * self._BASE_CHANNELS,
        )

        # Normalise the reshaped projection before the first upsampling.
        layers = [nn.BatchNorm2d(self._BASE_CHANNELS)]
        for in_channels, out_channels in self._STAGE_CHANNELS:
            layers += [
                nn.Upsample(scale_factor=2),
                # 3x3 kernel with padding=1 keeps the spatial size unchanged.
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            ]
        # Output stage: no normalisation, Tanh instead of ReLU.
        layers += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(
                self._STAGE_CHANNELS[-1][1],
                self._OUT_CHANNELS,
                3,
                padding=1,
            ),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """Generate a batch of output tensors from ``noise``.

        Args:
            noise: Tensor whose last dimension equals ``noise_dim``.

        Returns:
            Tensor of shape ``(N, 3, 64, 64)`` with values in [-1, 1].
            ``N`` is inferred from the total element count, so any leading
            dimensions of ``noise`` are flattened into the batch.
        """
        x = self.fc(noise)
        # -1 lets view() infer the batch size from the element count.
        x = x.view(-1, self._BASE_CHANNELS, self._BASE_SIZE, self._BASE_SIZE)
        return self.net(x)