Spaces:
Runtime error
Runtime error
| import sys | |
| import os | |
| import base64 | |
| import torch | |
| from PIL import Image | |
| import dnnlib | |
| import legacy | |
def load_stylegan2(model_path, device, num_ws=None):
    """
    Loads the stylegan2 generator.

    Arguments:
        model_path (str): Path (or URL) to a network pickle; opened with
            dnnlib.util.open_url and parsed with legacy.load_network_pkl.
        device (str): Device to load model on.
        num_ws (int, optional): Number of times the average style vector is
            tiled. If None (default), uses the generator's own
            ``G.mapping.num_ws`` (e.g. 14 for a 256x256 generator), falling
            back to 14 -- the previously hard-coded value -- if the mapping
            network does not expose it.

    Returns:
        G (nn.Module): Stylegan generator (the "G_ema" entry of the pickle),
            moved to ``device``.
        w_avg (Tensor): The average style vector in W space, repeated
            ``num_ws`` times along a new leading axis and placed on ``device``.
    """
    with dnnlib.util.open_url(model_path) as f:
        G = legacy.load_network_pkl(f)["G_ema"]
    G = G.to(device)
    if num_ws is None:
        # The synthesis network expects exactly mapping.num_ws W vectors;
        # a hard-coded 14 only matches 256x256 generators.
        num_ws = getattr(G.mapping, "num_ws", None) or 14
    # NOTE(review): assumes mapping.w_avg is a 1-D (w_dim,) buffer, so the
    # repeat yields (num_ws, w_dim) -- confirm against the network pickle.
    w_avg = G.mapping.w_avg.repeat(num_ws, 1).to(device)
    return G, w_avg
def tensor2im(var):
    """
    Converts a tensor image to PIL Image. Adapted from the stylegan2-ada-pytorch repo.

    Arguments:
        var (Tensor): Image tensor laid out as (C, H, W) with C == 3 and values
            nominally in [-1, 1] (values outside are clamped). A leading batch
            axis of size 1, i.e. (1, C, H, W), is also accepted.

    Returns:
        image (PIL.Image): RGB image with pixel values in [0, 255].
    """
    if var.dim() == 4 and var.shape[0] == 1:
        # Generators emit (N, C, H, W); allow a single-image batch directly.
        var = var[0]
    # Map [-1, 1] -> [0, 255]. Round before the uint8 cast: a bare cast
    # truncates, biasing every pixel down by ~0.5 (0.0 would become 127).
    var = (var.permute(1, 2, 0) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
    return Image.fromarray(var.cpu().numpy(), "RGB")