"""Sample a grid of cat images from a pretrained DCGAN-style generator.

Run as a script to load ``catgen_v2_generator_only.pth`` and write a grid of
generated images to ``generated_cats.png``. Importing this module only defines
``Generator`` and the constants; it performs no I/O.
"""

import torch
import torch.nn as nn

Z_DIM = 128  # length of the latent noise vector fed to the generator
FEATURES_G = 256  # base channel width; must match the checkpoint being loaded


class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to an image.

    Six transposed convolutions upsample a ``(N, z_dim, 1, 1)`` noise tensor
    to a ``(N, channels, 128, 128)`` image. The first layer produces a 4x4
    map and each later stride-2 layer doubles the spatial size. The final
    ``Tanh`` bounds every output value to ``[-1, 1]``.

    Args:
        z_dim: Number of channels of the input noise tensor.
        channels: Number of channels of the generated image (e.g. 3 for RGB).
        features_g: Base width; hidden layers use 16x, 8x, 4x, 2x and 1x this.
    """

    def __init__(self, z_dim: int, channels: int, features_g: int) -> None:
        super().__init__()
        # Keep this layer order exactly as-is: the ``net.<index>`` state-dict
        # keys depend on it, and saved checkpoints must keep loading.
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, features_g * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 16, features_g * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate images from noise ``x`` of shape ``(N, z_dim, 1, 1)``."""
        return self.net(x)


def main(
    checkpoint_path: str = "catgen_v2_generator_only.pth",
    output_path: str = "generated_cats.png",
    num_images: int = 16,
    nrow: int = 4,
) -> None:
    """Load generator weights, sample images and save them as one grid.

    Args:
        checkpoint_path: Path to a saved ``Generator`` state dict.
        output_path: Where to write the image grid.
        num_images: Number of images to sample.
        nrow: Number of images per row in the saved grid.
    """
    # Imported lazily so ``Generator`` can be imported (e.g. by training code
    # or tests) without torchvision installed.
    import torchvision.utils as vutils

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Generator(Z_DIM, 3, FEATURES_G).to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pth file cannot execute arbitrary code.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # eval() makes BatchNorm use its stored running statistics rather than
    # statistics of the sampled batch.
    model.eval()

    with torch.no_grad():
        noise = torch.randn(num_images, Z_DIM, 1, 1, device=device)
        # No .detach() needed: nothing is tracked under no_grad().
        fake_images = model(noise).cpu()

    # normalize=True min-max rescales the batch into [0, 1] before saving.
    vutils.save_image(fake_images, output_path, normalize=True, nrow=nrow)
    print(f"{num_images} new cats generated in {output_path}!")


if __name__ == "__main__":
    main()