Spaces: Runtime error
| import os | |
| import sys | |
| import re | |
| from typing import List, Optional, Tuple, Union | |
| import random | |
| sys.path.append('stylegan3-fun') # change this to the path where dnnlib is located | |
| import numpy as np | |
| import PIL.Image | |
| import torch | |
| import streamlit as st | |
| import dnnlib | |
| import legacy | |
def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma-separated list of numbers or inclusive ranges into a list of ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10].

    Whitespace around each comma-separated item is ignored, so '1, 5-7'
    returns [1, 5, 6, 7]. If `s` is already a list it is returned unchanged
    (the same object, not a copy).

    Raises:
        ValueError: if an item is neither an integer nor an 'a-b' range.
    '''
    if isinstance(s, list):
        return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        # The range regex is anchored, so surrounding spaces would otherwise
        # make ' 3-4' fall through to int() and raise.
        p = p.strip()
        m = range_re.match(p)
        if m:
            # Ranges include their upper bound, hence the +1.
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges
def make_transform(translate: Tuple[float,float], angle: float):
    """Build a 3x3 homogeneous 2-D transform from a translation and a rotation.

    Args:
        translate: (tx, ty); only the first two items are read. They are placed
            in the last column of the first two rows.
        angle: rotation angle in degrees.

    Returns:
        A float64 numpy array of shape (3, 3) laid out as
        [[cos, sin, tx], [-sin, cos, ty], [0, 0, 1]].
    """
    theta = angle / 360.0 * np.pi * 2  # degrees -> radians
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    tx, ty = translate[0], translate[1]
    return np.array(
        [
            [cos_t, sin_t, tx],
            [-sin_t, cos_t, ty],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float64,
    )
# Loaded generators keyed by (pickle path, device string). Repeated calls, such
# as one per Streamlit button press, then reuse the network instead of
# re-reading the pickle from disk every time.
_GENERATOR_CACHE = {}


def _load_generator(network_pkl: str, device: torch.device):
    """Return the 'G_ema' network from `network_pkl` on `device`, loading it once per (path, device)."""
    key = (network_pkl, str(device))
    G = _GENERATOR_CACHE.get(key)
    if G is None:
        print('Loading networks from "%s"...' % network_pkl)
        # NOTE(review): load_network_pkl unpickles the file, which can execute
        # arbitrary code. Only pass trusted, locally-controlled paths.
        with open(network_pkl, 'rb') as f:
            G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
        _GENERATOR_CACHE[key] = G
    return G


def generate_image(network_pkl: str, seed: int, truncation_psi: float, noise_mode: str, translate: Tuple[float,float], rotate: float, class_idx: Optional[int]):
    """Generate one RGB image from the 'G_ema' network stored in `network_pkl`.

    Args:
        network_pkl: path to the network pickle (opened with open(..., 'rb')).
        seed: seed for numpy's RandomState, used to draw the latent vector z.
        truncation_psi: forwarded to G as `truncation_psi`.
        noise_mode: forwarded to G as `noise_mode` (main passes 'const').
        translate: (tx, ty). Used only when G.synthesis has an `input` attribute.
        rotate: rotation in degrees. Used only when G.synthesis has an `input`
            attribute.
        class_idx: class to condition on. Required when G.c_dim != 0.

    Returns:
        A PIL.Image in 'RGB' mode.

    Raises:
        FileNotFoundError / OSError: if `network_pkl` cannot be opened.
        ValueError: if the network is conditional and `class_idx` is None.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The generator is cached across calls. Its synthesis.input.transform
    # buffer (if present) is overwritten below on every call, so no stale
    # transform carries over. Concurrent calls with different transforms
    # could race on it, though.
    G = _load_generator(network_pkl, device)

    # Labels: one-hot vector for conditional networks, empty for unconditional ones.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise ValueError('Must specify class label when using a conditional network')
        label[:, class_idx] = 1

    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

    with torch.no_grad():  # inference only; avoid building an autograd graph
        if hasattr(G.synthesis, 'input'):
            # The buffer receives the inverse of the user transform.
            m = np.linalg.inv(make_transform(translate, rotate))
            G.synthesis.input.transform.copy_(torch.from_numpy(m))
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)

    # NCHW -> NHWC. Assumes G outputs values roughly in [-1, 1] (TODO confirm).
    # 127.5 scale plus 128 offset (127.5 + 0.5) makes the uint8 truncation
    # behave like rounding.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
def main():
    """Render the Streamlit page: a title, a 'Generate' button and the generated image.

    Each press of the button draws a fresh random seed and generates one image
    from the local 'kpopGG.pkl' network with fixed settings. The seed is shown
    as the image caption so the result can be reproduced. A missing model file
    is reported in the page instead of crashing the app.
    """
    st.title('Kpop Face Generator')
    st.write('Press the button below to generate a new image:')
    if not st.button('Generate'):
        return

    network_pkl = 'kpopGG.pkl'
    seed = random.randint(0, 99999)
    truncation_psi = 0.45
    noise_mode = 'const'
    translate = (0.0, 0.0)
    rotate = 0.0
    class_idx = None

    try:
        with st.spinner('Generating...'):
            image = generate_image(network_pkl, seed, truncation_psi, noise_mode, translate, rotate, class_idx)
    except FileNotFoundError:
        st.error('Model file "%s" was not found.' % network_pkl)
        return
    st.image(image, caption='Seed: %d' % seed)


if __name__ == "__main__":
    main()