| import torch |
| import torch.nn as nn |
| import torchvision.utils as vutils |
| import matplotlib.pyplot as plt |
|
|
class Generator(nn.Module):
    """DCGAN-style generator: transposed-conv stack from latent vector to image.

    For an input of shape ``(N, z_dim, 1, 1)`` the first transposed conv
    (kernel 4, stride 1, padding 0) produces a 4x4 map with
    ``features_g * 16`` channels. Each following transposed conv
    (kernel 4, stride 2, padding 1) doubles the spatial size while the
    channel count steps 16x -> 8x -> 4x -> 2x -> 1x ``features_g`` and finally
    ``channels``. The result is 4 -> 8 -> 16 -> 32 -> 64 -> 128 pixels. The last
    layer is a Tanh, so outputs lie in [-1, 1].

    Args:
        z_dim: Number of latent input channels.
        channels: Number of output image channels.
        features_g: Base width multiplier for the hidden layers.
    """

    def __init__(self, z_dim, channels, features_g):
        super().__init__()

        def up_block(in_ch, out_ch, stride, padding):
            # Transposed conv without bias (the BatchNorm that follows has its
            # own learnable shift), then BatchNorm and an in-place ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # 1x1 latent -> 4x4 feature map.
        layers = up_block(z_dim, features_g * 16, 1, 0)

        # Four upsampling stages, each halving the channel multiplier.
        for wide, narrow in ((16, 8), (8, 4), (4, 2), (2, 1)):
            layers += up_block(features_g * wide, features_g * narrow, 2, 1)

        # Final projection to image channels, squashed into [-1, 1].
        layers += [
            nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        # Keep a single flat Sequential named `net`, in exactly this order, so
        # state-dict keys (net.0.*, net.1.*, ...) stay compatible with saved
        # checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run latent batch ``x`` through the generator stack and return images."""
        return self.net(x)
|
|
# Inference script: load trained generator weights and write a grid of samples.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture hyper-parameters. These must match the checkpoint being loaded,
# otherwise load_state_dict raises a size-mismatch error.
Z_DIM = 128
FEATURES_G = 256

NUM_SAMPLES = 16   # images generated per run
GRID_COLUMNS = 4   # images per row in the saved grid
CHECKPOINT_PATH = "catgen_v2_generator_only.pth"
OUTPUT_PATH = "generated_cats.png"

model = Generator(Z_DIM, 3, FEATURES_G).to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code; a bare state_dict loads
# fine this way. (Requires torch >= 1.13.)
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
)
# eval(): BatchNorm layers use their stored running statistics rather than
# statistics of the current batch.
model.eval()

with torch.no_grad():
    noise = torch.randn(NUM_SAMPLES, Z_DIM, 1, 1, device=device)
    # No .detach() needed: autograd records no graph under no_grad().
    fake_images = model(noise).cpu()

# normalize=True rescales using the min/max over the whole batch.
# NOTE(review): value_range=(-1, 1) would map the Tanh output range directly
# without contrast-stretching; consider if faithful colours matter.
vutils.save_image(fake_images, OUTPUT_PATH, normalize=True, nrow=GRID_COLUMNS)
print(f"{NUM_SAMPLES} new cats generated in {OUTPUT_PATH}!")