File size: 7,657 Bytes
9f26e5b
d5363ee
b9551eb
a850437
b9551eb
 
a850437
b9551eb
 
 
 
 
 
 
 
 
 
 
 
 
 
a850437
 
b9551eb
a850437
d5363ee
1beb7b0
7dea3f2
1beb7b0
7dea3f2
 
b9551eb
 
 
e3ab213
7dea3f2
1beb7b0
 
 
 
 
 
 
 
7dea3f2
1beb7b0
 
 
5e3044e
 
1beb7b0
 
 
 
47715ee
1beb7b0
 
 
 
 
721b0ce
 
1beb7b0
 
9f26e5b
 
 
 
9a6586b
47715ee
1beb7b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721b0ce
1beb7b0
 
 
 
 
 
 
 
 
 
9a6586b
1beb7b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f26e5b
 
1beb7b0
 
 
 
9f26e5b
 
1beb7b0
 
 
 
7dea3f2
 
1beb7b0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import spaces
import os
import subprocess
import time

def install_cuda_toolkit():
    """Download and silently install the CUDA 12.3.2 toolkit, then point the
    current process environment at it.

    Runs the NVIDIA ``.run`` installer with ``--silent --toolkit`` and then
    updates ``CUDA_HOME``, ``PATH``, ``LD_LIBRARY_PATH`` and
    ``TORCH_CUDA_ARCH_LIST`` in ``os.environ``.

    The shell steps are best-effort: their exit codes are not checked, so a
    failed download or install does not stop execution here.
    """
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

    # The installer is only needed for the step above; remove it so it does
    # not keep occupying /tmp. Missing file (e.g. failed download) is fine.
    try:
        os.remove(CUDA_TOOLKIT_FILE)
    except FileNotFoundError:
        pass

    def _prepend(var, entry):
        """Prepend *entry* to the ``os.pathsep``-separated env var *var*.

        If *var* is unset or empty, set it to *entry* alone. Leaving a
        trailing separator would add an empty element, which PATH and
        LD_LIBRARY_PATH treat as the current working directory.
        """
        current = os.environ.get(var, "")
        os.environ[var] = f"{entry}{os.pathsep}{current}" if current else entry

    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    _prepend("PATH", "%s/bin" % os.environ["CUDA_HOME"])
    _prepend("LD_LIBRARY_PATH", "%s/lib" % os.environ["CUDA_HOME"])
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    # (an explicit arch list avoids that error in torch's extension builder).
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"

# Install the CUDA toolkit (and set its environment variables) before torch
# and the StyleGAN-XL modules are imported further down, timing the install.
install_time = time.time()
print('Install CUDA toolkit...')
install_cuda_toolkit()
print(f'CUDA toolkit installed in {time.time() - install_time:.2f} seconds.')

import random
import gradio as gr
from PIL import Image
import numpy as np
import torch
print(f'torch version : {torch.__version__}')
print(f'cuda available : {torch.cuda.is_available()}')
print(f'cuda version : {torch.version.cuda}')
# os.system() would return nvcc's exit status (not its output) and let nvcc
# print on its own, so capture the output and put it in the log line.
try:
    _nvcc = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
    _nvcc_version = _nvcc.stdout.strip() or _nvcc.stderr.strip()
except OSError:
    # nvcc missing from PATH (e.g. the toolkit install failed).
    _nvcc_version = 'nvcc not found'
print(f'nvcc version : {_nvcc_version}')

import sys
# Make the StyleGAN-XL sources importable (dnnlib, legacy, torch_utils).
# NOTE(review): the path is relative, so this assumes the app is started from
# the directory that contains ./stylegan-xl — TODO confirm for deployment.
sys.path.append('./stylegan-xl')
import dnnlib, legacy
from torch_utils import gen_utils

# Pretrained StyleGAN-XL Pokémon generator checkpoint.
NETWORK = "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon1024.pkl"

# Pokémon types; each has a feature vector stored in class_vectors/<name>.npy.
CLASS_NAMES = [
    "Normal", "Fire", "Water", "Grass", "Electric", "Ice",
    "Fighting", "Poison", "Ground", "Flying", "Psychic", "Bug",
    "Rock", "Ghost", "Dragon", "Dark", "Steel", "Fairy",
]
CLASS_FEATURES = {name: np.load(f'class_vectors/{name}.npy') for name in CLASS_NAMES}

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): legacy.load_network_pkl presumably unpickles the downloaded
# file, so NETWORK must only ever point at a trusted source.
with dnnlib.util.open_url(NETWORK) as f:
    G = legacy.load_network_pkl(f)['G_ema']
# Inference only: eval mode, no gradients, moved to the selected device.
G = G.eval().requires_grad_(False).to(device)

def reset_values(*args):
    """Return a list containing one integer ``0`` per positional argument.

    The argument values are ignored; only their count matters.
    """
    return [0] * len(args)

@spaces.GPU
def generate_base_pokemon():
    """Sample a random latent and render a new base Pokémon.

    Returns:
        A 3-tuple ``(state, image, gallery)``:
        ``state`` is ``(w, image)`` with the latent ``w`` moved to the CPU,
        ``image`` is the rendered PIL RGB image, and
        ``gallery`` is a new one-element list containing that image.
    """
    seed = random.randint(0, 2**32 - 1)
    w = gen_utils.get_w_from_seed(G, batch_sz=1, device=device, seed=seed)
    pixels = gen_utils.w_to_img(G, w, to_np=True)[0]
    image = Image.fromarray(pixels, 'RGB')
    # The latent is stored on the CPU (presumably so state does not pin
    # device memory between requests).
    state = (w.to('cpu'), image)
    return state, image, [image]

# Warm-up: generate one base Pokémon at start-up (the result is discarded)
# and report how long it took.
init_time = time.time()
print(f'Initial run to generate a base Pokémon... (device is {device})')
generate_base_pokemon()
print(f'Initial run completed in {time.time() - init_time:.2f} seconds.')

@spaces.GPU
def apply_class_feature(base_state, base_ratio, normal, fire, water, grass, electric, ice, fighting, poison, ground, flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy) :
    """Blend the base Pokémon's latent with a weighted mix of type features.

    Args:
        base_state: ``(w, image)`` tuple from ``generate_base_pokemon``, or
            ``None`` if no base Pokémon has been generated yet.
        base_ratio: Weight of the base latent in the blend; the combined type
            feature gets weight ``1 - base_ratio``.
        normal ... fairy: Per-type weights, in the same order as
            ``CLASS_NAMES``.

    Returns:
        PIL RGB image rendered from the blended latent, or the stored base
        image when every type weight is zero.

    Raises:
        gr.Error: If ``base_state`` is ``None``.
    """
    if base_state is None :
        raise gr.Error('Base Pokémon is required.')

    class_values = [normal, fire, water, grass, electric, ice, fighting, poison, ground, flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy]
    # Require every weight to be zero, not just their sum: e.g. Fire=0.5 and
    # Water=-0.5 sum to 0 but still describe a real edit.
    if all(value == 0 for value in class_values) :
        return base_state[-1]

    # Weighted sum of the selected type vectors. The first term is a fresh
    # array (a product), so the in-place += never mutates CLASS_FEATURES.
    target_feature = None
    for class_name, class_value in zip(CLASS_NAMES, class_values) :
        if class_value == 0 :
            continue
        weighted = CLASS_FEATURES[class_name] * class_value
        if target_feature is None :
            target_feature = weighted
        else :
            target_feature += weighted

    base_w = base_state[0].to(device)
    # Match the latent's dtype so the blend below is not promoted if the
    # stored .npy vectors use a different float type.
    target_feature = torch.from_numpy(target_feature).to(device=device, dtype=base_w.dtype)
    # Tile along dim 1, G.num_ws times, to line up with base_w — presumably
    # shaped (1, num_ws, w_dim); TODO confirm.
    target_feature = target_feature.repeat(1, G.num_ws, 1)

    edit_w = (base_w * base_ratio) + (target_feature * (1 - base_ratio))
    output_edit = gen_utils.w_to_img(G, edit_w, to_np=True)[0]
    return Image.fromarray(output_edit, 'RGB')

def add_image_to_gallery(gallery, img) :
    """Return a new gallery list with *img* appended at the end.

    Args:
        gallery: Current gallery items; ``None`` is treated as an empty
            gallery.
        img: Item to append.

    Returns:
        A new list. The incoming *gallery* object is not modified.
    """
    # Build a fresh list rather than appending in place, and tolerate a
    # missing gallery instead of raising AttributeError on None.
    return [*(gallery or []), img]

with gr.Blocks() as demo:
    gr.Markdown('# Pokémon Generator')
    gr.Markdown('Generate a base Pokémon. Then, apply class features to edit the base Pokémon.')
    with gr.Row():
        # Left column: generation button, per-type sliders and edit controls.
        with gr.Column():
            btn_generate = gr.Button('Generate a base Pokémon', variant='primary')
            # Holds the (w, image) tuple returned by generate_base_pokemon.
            base_state = gr.State()

            # One slider per type, created in CLASS_NAMES order and labelled
            # with the type name; this order matches apply_class_feature's
            # per-type parameters.
            with gr.Row():
                class_sliders = [
                    gr.Slider(label=class_name, minimum=-1, maximum=1, value=0, step=0.01)
                    for class_name in CLASS_NAMES
                ]
            (normal, fire, water, grass, electric, ice, fighting, poison, ground,
             flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy) = class_sliders

            with gr.Row():
                btn_zero = gr.Button('Reset to 0', variant='secondary')
                btn_edit = gr.Button('Apply class feature', variant='primary')

            base_ratio = gr.Slider(
                label='Base ratio', minimum=0, maximum=1, value=0.5, step=0.01,
                info="The parameter that determines how much the base Pokémon's form is retained.",
            )

        # Right column: latest output image plus the accumulated gallery.
        with gr.Column():
            output_img = gr.Image(label='Output', image_mode='RGB', type='pil', interactive=False)
            output_gallery = gr.Gallery(label='Gallery', columns=5, interactive=False)

    # Reset every type slider back to 0.
    btn_zero.click(fn=reset_values, inputs=class_sliders, outputs=class_sliders)

    # Both GPU-backed handlers use concurrency_id='gpu'.
    btn_generate.click(
        fn=generate_base_pokemon,
        inputs=None,
        outputs=[base_state, output_img, output_gallery],
        concurrency_id='gpu',
    )

    # Apply the weighted type features; append the result to the gallery
    # only when the edit handler succeeds.
    btn_edit.click(
        fn=apply_class_feature,
        inputs=[base_state, base_ratio, *class_sliders],
        outputs=output_img,
        concurrency_id='gpu',
    ).success(
        fn=add_image_to_gallery,
        inputs=[output_gallery, output_img],
        outputs=output_gallery,
    )

# Set the app's page title, then start the Gradio app (show_api=False turns
# off the API usage page link).
demo.title = 'Pokémon Generator'
demo.launch(show_api=False)