"""Gradio demo: generate a Pokémon with StyleGAN-XL and edit it with type-class features.

At import time this script installs a CUDA toolkit, loads the ``pokemon1024``
StyleGAN-XL generator, runs one warm-up generation and then launches the
Gradio UI.
"""
# NOTE(review): `spaces` is imported before torch — presumably required by the
# Hugging Face ZeroGPU runtime that provides `@spaces.GPU`; keep it first.
import spaces
import os
import subprocess
import time


def install_cuda_toolkit():
    """Download and silently install the CUDA 12.3.2 toolkit, then export its paths.

    Sets ``CUDA_HOME``, prepends the toolkit to ``PATH`` / ``LD_LIBRARY_PATH`` and
    pins ``TORCH_CUDA_ARCH_LIST``. The exit codes of the download/install
    commands are not checked, so a failed step is not reported here.
    """
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
        os.environ["CUDA_HOME"],
        os.environ.get("LD_LIBRARY_PATH", ""),
    )
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"


install_time = time.time()
print('Install CUDA toolkit...')
install_cuda_toolkit()
print(f'CUDA toolkit installed in {time.time() - install_time:.2f} seconds.')

# NOTE(review): these imports are kept after install_cuda_toolkit(), presumably so
# the CUDA environment variables set above exist before torch / the StyleGAN-XL
# modules are imported — do not hoist them.
import random
import gradio as gr
from PIL import Image
import numpy as np
import torch

print(f'torch version : {torch.__version__}')
print(f'cuda available : {torch.cuda.is_available()}')
print(f'cuda version : {torch.version.cuda}')


def _nvcc_version():
    """Return the text printed by ``nvcc --version``, or why it could not be obtained."""
    try:
        result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
    except OSError as err:  # e.g. nvcc is not on PATH
        return f'unavailable ({err})'
    if result.returncode != 0:
        return f'unavailable (exit status {result.returncode})'
    return result.stdout.strip()


# os.system() only returns the exit status, so log the captured output instead.
print(f'nvcc version : {_nvcc_version()}')

import sys
sys.path.append('./stylegan-xl')  # makes the StyleGAN-XL checkout importable
import dnnlib
import legacy
from torch_utils import gen_utils

NETWORK = "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon1024.pkl"
# Order matters: apply_class_feature's per-type arguments and the UI sliders follow it.
CLASS_NAMES = ["Normal", "Fire", "Water", "Grass", "Electric", "Ice", "Fighting", "Poison", "Ground", "Flying", "Psychic", "Bug", "Rock", "Ghost", "Dragon", "Dark", "Steel", "Fairy"]
# One precomputed feature vector per type, loaded from class_vectors/<Name>.npy.
CLASS_FEATURES = {cname: np.load(f'class_vectors/{cname}.npy') for cname in CLASS_NAMES}

device = "cuda" if torch.cuda.is_available() else "cpu"
with dnnlib.util.open_url(NETWORK) as f:
    # NOTE(review): load_network_pkl presumably unpickles the download; acceptable
    # only because NETWORK is a fixed, trusted URL.
    G = legacy.load_network_pkl(f)['G_ema']
    G = G.eval().requires_grad_(False)
    G = G.to(device)


def reset_values(*args):
    """Return one ``0`` per argument (used to reset every class slider)."""
    return [0] * len(args)


@spaces.GPU
def generate_base_pokemon():
    """Sample a random latent from a random seed and render it.

    Returns:
        ``(state, image, gallery)`` where ``state`` is ``(base_w, image)`` with
        ``base_w`` moved to the CPU, ``image`` is the rendered ``PIL.Image`` and
        ``gallery`` is a fresh one-element list containing ``image``.
    """
    random_seed = random.randint(0, 2**32 - 1)
    base_w = gen_utils.get_w_from_seed(G, batch_sz=1, device=device, seed=random_seed)
    pixels = gen_utils.w_to_img(G, base_w, to_np=True)[0]
    image = Image.fromarray(pixels, 'RGB')
    # Store the latent on the CPU in the session state; apply_class_feature
    # moves it back to `device`.
    base_w = base_w.to('cpu')
    return (base_w, image), image, [image]


init_time = time.time()
print(f'Initial run to generate a base Pokémon... (device is {device})')
generate_base_pokemon()
print(f'Initial run completed in {time.time() - init_time:.2f} seconds.')


@spaces.GPU
def apply_class_feature(base_state, base_ratio, normal, fire, water, grass, electric, ice, fighting, poison, ground, flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy):
    """Blend the base latent with a weighted sum of per-type class vectors.

    Args:
        base_state: ``(base_w, base_image)`` as produced by generate_base_pokemon.
        base_ratio: weight of the base latent in the blend; the class mix gets
            ``1 - base_ratio``.
        normal ... fairy: per-type weights, in CLASS_NAMES order.

    Returns:
        The edited ``PIL.Image``, or the stored base image when every weight is 0.

    Raises:
        gr.Error: if no base Pokémon has been generated yet.
    """
    if base_state is None:
        raise gr.Error('Base Pokémon is required.')

    class_values = [normal, fire, water, grass, electric, ice, fighting, poison, ground, flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy]
    # Short-circuit only when every weight is exactly 0. A `sum(...) == 0` test
    # would also skip real edits whose weights cancel (e.g. Fire=+0.5, Water=-0.5).
    if not any(class_values):
        return base_state[-1]

    # Weighted sum of the selected class vectors. Each term is a new array, so
    # the cached CLASS_FEATURES entries are never modified in place.
    target_feature = sum(
        CLASS_FEATURES[name] * value
        for name, value in zip(CLASS_NAMES, class_values)
        if value != 0
    )

    base_w = base_state[0].to(device)
    # Match the latent's dtype so a float64 .npy cannot promote edit_w
    # (assumes the generator expects base_w's dtype — TODO confirm).
    target_feature = torch.from_numpy(target_feature).to(device=device, dtype=base_w.dtype)
    # Broadcast the class vector to every w slot the generator consumes.
    target_feature = target_feature.repeat(1, G.num_ws, 1)
    edit_w = (base_w * base_ratio) + (target_feature * (1 - base_ratio))
    output_edit = gen_utils.w_to_img(G, edit_w, to_np=True)[0]
    return Image.fromarray(output_edit, 'RGB')


def add_image_to_gallery(gallery, img):
    """Append ``img`` to the gallery's current value and return the updated list."""
    if gallery is None:  # the Gallery value can be None when it is empty
        gallery = []
    gallery.append(img)
    return gallery


with gr.Blocks() as demo:
    gr.Markdown('# Pokémon Generator')
    gr.Markdown('Generate a base Pokémon. Then, apply class features to edit the base Pokémon.')
    with gr.Row():
        with gr.Column():
            btn_generate = gr.Button('Generate a base Pokémon', variant='primary')
            base_state = gr.State()
            with gr.Row():
                # One slider per type; labels and order come from CLASS_NAMES.
                class_sliders = [
                    gr.Slider(label=name, minimum=-1, maximum=1, value=0, step=0.01)
                    for name in CLASS_NAMES
                ]
            (normal, fire, water, grass, electric, ice, fighting, poison, ground,
             flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy) = class_sliders
            with gr.Row():
                btn_zero = gr.Button('Reset to 0', variant='secondary')
                btn_edit = gr.Button('Apply class feature', variant='primary')
            # NOTE(review): the flattened source does not show whether this slider
            # sat inside the button row or below it; placed below — confirm layout.
            base_ratio = gr.Slider(label='Base ratio', minimum=0, maximum=1, value=0.5, step=0.01, info="The parameter that determines how much the base Pokémon's form is retained.")
        with gr.Column():
            output_img = gr.Image(label='Output', image_mode='RGB', type='pil', interactive=False)
            output_gallery = gr.Gallery(label='Gallery', columns=5, interactive=False)

    btn_zero.click(
        fn=reset_values,
        inputs=class_sliders,
        outputs=class_sliders,
    )
    # Both GPU handlers share one concurrency group.
    btn_generate.click(
        fn=generate_base_pokemon,
        inputs=None,
        outputs=[base_state, output_img, output_gallery],
        concurrency_id='gpu',
    )
    btn_edit.click(
        fn=apply_class_feature,
        inputs=[base_state, base_ratio, *class_sliders],
        outputs=output_img,
        concurrency_id='gpu',
    ).success(  # only record the image when the edit did not raise
        fn=add_image_to_gallery,
        inputs=[output_gallery, output_img],
        outputs=output_gallery,
    )

demo.title = 'Pokémon Generator'
demo.launch(show_api=False)