Spaces:
Runtime error
Runtime error
import argparse
import os

import torch
from torchvision import utils
from tqdm import tqdm

from model import Generator
def generate(args, g_ema, device, mean_latent, out_dir='sample'):
    """Sample ``args.pics`` image files from the generator and save them as PNGs.

    Each iteration draws ``args.sample`` latent vectors of length
    ``args.latent`` from ``torch.randn``. It feeds them through ``g_ema`` and
    writes the output as one grid with a single column, named
    ``<out_dir>/<iteration index zero-padded to 6 digits>.png``.

    Args:
        args: parsed CLI namespace; this function reads ``pics``, ``sample``,
            ``latent`` and ``truncation``.
        g_ema: generator module, called as
            ``g_ema([z], truncation=..., truncation_latent=...)`` and unpacked
            into ``(images, _)``.
        device: device on which the latent noise tensor is created.
        mean_latent: value forwarded as ``truncation_latent`` (the caller
            passes ``None`` when truncation is not used).
        out_dir: directory the PNGs are written to; created if it is missing.
            Defaults to ``'sample'`` to match the original behaviour.
    """
    # save_image does not create parent directories, so without this the very
    # first write raises FileNotFoundError when out_dir does not exist yet.
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        g_ema.eval()

        for i in tqdm(range(args.pics)):
            sample_z = torch.randn(args.sample, args.latent, device=device)

            sample, _ = g_ema(
                [sample_z], truncation=args.truncation, truncation_latent=mean_latent
            )

            utils.save_image(
                sample,
                os.path.join(out_dir, f'{str(i).zfill(6)}.png'),
                nrow=1,
                normalize=True,
                # torchvision deprecated ``range=`` in favour of ``value_range=``;
                # newer releases reject the old keyword with a TypeError.
                value_range=(-1, 1),
            )
if __name__ == '__main__':
    # Use the GPU when one is available and fall back to CPU otherwise;
    # a hard-coded 'cuda' crashes on CPU-only machines at ``.to(device)``.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--sample', type=int, default=1)
    parser.add_argument('--pics', type=int, default=20)
    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default="stylegan2-ffhq-config-f.pt")
    parser.add_argument('--channel_multiplier', type=int, default=2)

    args = parser.parse_args()

    # Latent size and mapping-network depth are fixed here rather than exposed
    # as flags; they must match the architecture stored in the checkpoint.
    args.latent = 512
    args.n_mlp = 8

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)

    # map_location remaps stored tensors onto the selected device, so a
    # checkpoint saved from a GPU can still be loaded on a CPU-only host.
    checkpoint = torch.load(args.ckpt, map_location=device)
    g_ema.load_state_dict(checkpoint['g_ema'])

    # The truncation centre is only needed when truncation is actually applied.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)