# CatGen-v2 / inference.py
# Committed by LH-Tech-AI: "Create inference.py" (c5d3855, verified)
import torch
import torch.nn as nn
import torchvision.utils as vutils
import matplotlib.pyplot as plt
class Generator(nn.Module):
    """DCGAN-style generator that upsamples a latent tensor into an image.

    The network is a stack of transposed convolutions:

    * The first one (kernel 4, stride 1, padding 0) turns a 1x1 latent map
      into a 4x4 map.
    * Each of the following five (kernel 4, stride 2, padding 1) doubles the
      spatial size.

    So an input of shape ``(N, z_dim, 1, 1)`` yields an output of shape
    ``(N, channels, 128, 128)``. Every transposed convolution except the last
    is followed by BatchNorm + ReLU. The last one is followed by Tanh, so
    outputs lie in [-1, 1].

    Args:
        z_dim: number of channels of the latent input.
        channels: number of channels of the generated image.
        features_g: base width; the hidden maps use 16x, 8x, 4x, 2x and 1x
            this many channels.

    Note:
        All layers live in ``self.net`` in a fixed order (indices 0..16).
        The state_dict keys (``net.<i>.weight`` etc.) depend on that order,
        so reordering layers breaks loading of existing checkpoints.
    """

    # Channel multipliers of the hidden feature maps, widest first.
    _WIDTH_MULTIPLIERS = (16, 8, 4, 2, 1)

    def __init__(self, z_dim, channels, features_g):
        super().__init__()
        widths = [features_g * mult for mult in self._WIDTH_MULTIPLIERS]

        # Project the 1x1 latent map up to 4x4.
        layers = self._up_block(z_dim, widths[0], stride=1, padding=0)
        # Four doubling stages on the hidden maps: 4 -> 8 -> 16 -> 32 -> 64.
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.extend(self._up_block(in_ch, out_ch, stride=2, padding=1))
        # Final doubling (64 -> 128) straight to image channels. It uses Tanh
        # instead of BatchNorm + ReLU so the output is bounded to [-1, 1].
        layers.append(
            nn.ConvTranspose2d(
                widths[-1], channels, kernel_size=4, stride=2, padding=1, bias=False
            )
        )
        layers.append(nn.Tanh())

        self.net = nn.Sequential(*layers)

    @staticmethod
    def _up_block(in_ch, out_ch, *, stride, padding):
        """Return [ConvTranspose2d(kernel 4, no bias), BatchNorm2d, in-place ReLU]."""
        return [
            nn.ConvTranspose2d(
                in_ch, out_ch, kernel_size=4, stride=stride, padding=padding, bias=False
            ),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def forward(self, x):
        """Pass the latent batch ``x`` through the generator and return images."""
        return self.net(x)
# --- Inference: sample a grid of cats from the pretrained generator. ---
# The architecture hyperparameters below must match the checkpoint.
# load_state_dict (strict by default) raises on mismatched keys or shapes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Z_DIM = 128
FEATURES_G = 256
CHANNELS = 3  # colour channels of the generated images
CHECKPOINT_PATH = "catgen_v2_generator_only.pth"
OUTPUT_PATH = "generated_cats.png"
NUM_SAMPLES = 16
GRID_NROW = 4  # images per row in the saved grid

model = Generator(Z_DIM, CHANNELS, FEATURES_G).to(device)
# torch.load unpickles the file. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint downloaded from the
# internet cannot execute arbitrary code. A state_dict loads fine this way.
# Requires PyTorch >= 1.13.
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
)
model.eval()  # BatchNorm uses its running statistics, not per-batch ones

with torch.no_grad():
    # Latent input is a 1x1 spatial map per sample.
    noise = torch.randn(NUM_SAMPLES, Z_DIM, 1, 1, device=device)
    # No autograd graph is built under no_grad, so no .detach() is needed.
    fake_images = model(noise).cpu()

# The generator ends in Tanh, so pixels lie in [-1, 1]. Map that fixed range
# to [0, 1] instead of stretching by this batch's own min/max, which would
# make each image's colours depend on the other samples in the grid.
vutils.save_image(
    fake_images,
    OUTPUT_PATH,
    normalize=True,
    value_range=(-1, 1),
    nrow=GRID_NROW,
)
print(f"{NUM_SAMPLES} new cats generated in {OUTPUT_PATH}!")