Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import numpy as np | |
| import torch | |
| import pickle | |
| import types | |
| from huggingface_hub import hf_hub_url, cached_download | |
# Access token for the model repository; a missing TOKEN variable raises
# KeyError at import time.
TOKEN = os.environ['TOKEN']

# Resolve the checkpoint URL on the Hub, then fetch it through cached_download
# using the token above.
_checkpoint_url = hf_hub_url('mfrashad/stylegan2_emoji_512', 'stylegan2_emoji_512.pkl')
_checkpoint_path = cached_download(_checkpoint_url, use_auth_token=TOKEN)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# This is only acceptable while the checkpoint repository is fully trusted.
with open(_checkpoint_path, 'rb') as f:
    _checkpoint = pickle.load(f)
G = _checkpoint['G_ema']  # torch.nn.Module
# Prefer CUDA when it is available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    G = G.to(device)
else:
    # On the CPU path, every call into G and G.synthesis is forced to pass
    # force_fp32=True (presumably because reduced-precision layers are not
    # usable on CPU -- confirm against the StyleGAN2-ADA code).
    # NOTE(review): both patches are assumed to belong to this branch only.
    def _force_fp32(module):
        """Rebind ``module.forward`` so it always receives force_fp32=True."""
        wrapped_forward = module.forward

        def _forward_fp32(self, *args, **kwargs):
            kwargs["force_fp32"] = True
            # wrapped_forward is already bound, so `self` is not forwarded.
            return wrapped_forward(*args, **kwargs)

        module.forward = types.MethodType(_forward_fp32, module)

    _force_fp32(G)
    _force_fp32(G.synthesis)
def generate(num_images: int, interpolate: bool) -> np.ndarray:
    """Sample latent codes and render them with the module-level generator ``G``.

    Args:
        num_images: Number of images (latent codes) to produce.
        interpolate: When true, the latents are evenly spaced on the straight
            line between two random codes, the first image sitting on one
            endpoint and the last on the other. When false, every image gets
            an independent random code.

    Returns:
        A uint8 NumPy array of images. The generator output is permuted with
        (0, 2, 3, 1) before conversion (presumably NCHW -> NHWC), scaled by
        127.5, offset by 128 and clamped to [0, 255].
    """
    if interpolate:
        z1 = torch.randn([1, G.z_dim])  # start latent
        z2 = torch.randn([1, G.z_dim])  # end latent
        # For a single image, dividing by (num_images - 1) is 0/0 on a tensor,
        # which silently yields NaN latents instead of raising. Clamp the
        # divisor to 1 so a single image simply uses z1; results for
        # num_images >= 2 are unchanged.
        steps = max(num_images - 1, 1)
        zs = torch.cat(
            [z1 + (z2 - z1) * i / steps for i in range(num_images)], 0
        )
    else:
        zs = torch.randn([num_images, G.z_dim])  # independent latent codes

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        zs = zs.to(device)
        img = G(zs, None, force_fp32=True, noise_mode='const')
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()
# Root Blocks container; the page layout is attached to it below.
demo = gr.Blocks()


def infer(num_images, interpolate):
    """Gradio click handler: render a batch and return it as a list of images.

    ``num_images`` is rounded to an integer before being handed to
    ``generate``; the resulting array is split along its first axis so each
    list entry is one image.
    """
    batch = generate(round(num_images), interpolate)
    return list(batch)
# Build the page: a title, two inputs, a trigger button and an output gallery.
with demo:
    gr.Markdown(
        """
# EmojiGAN
Generate Emojis with AI (StyleGAN2-ADA). Made by [mfrashad](https://github.com/mfrashad)
""")
    # The legacy `gr.inputs.*` namespace and its `default=` keyword are
    # deprecated in Gradio 3 and removed in Gradio 4. The top-level
    # components with `value=` are the equivalent used alongside gr.Blocks.
    images_num = gr.Slider(minimum=1, maximum=16, step=1, value=1, label="Num Images")
    interpolate = gr.Checkbox(value=False, label="Interpolate")
    submit = gr.Button("Generate")
    out = gr.Gallery()
    # Clicking the button runs `infer` on the two inputs and fills the gallery.
    submit.click(
        fn=infer,
        inputs=[images_num, interpolate],
        outputs=out,
    )

demo.launch()