gibiee's picture
apply ZeroGPU
9f26e5b
import spaces
import os
import subprocess
import time
def install_cuda_toolkit():
CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run"
CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
os.environ["CUDA_HOME"] = "/usr/local/cuda"
os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
os.environ["CUDA_HOME"],
"" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
)
# Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
install_time = time.time()
print('Install CUDA toolkit...')
install_cuda_toolkit()
print(f'CUDA toolkit installed in {time.time() - install_time:.2f} seconds.')
import random
import gradio as gr
from PIL import Image
import numpy as np
import torch
print(f'torch version : {torch.__version__}')
print(f'cuda available : {torch.cuda.is_available()}')
print(f'cuda version : {torch.version.cuda}')
print(f'nvcc version : {os.system("nvcc --version")}')
import sys
sys.path.append('./stylegan-xl')
import dnnlib, legacy
from torch_utils import gen_utils
NETWORK = "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon1024.pkl"
CLASS_NAMES = ["Normal", "Fire", "Water", "Grass", "Electric", "Ice", "Fighting", "Poison", "Ground", "Flying", "Psychic", "Bug", "Rock", "Ghost", "Dragon", "Dark", "Steel", "Fairy"]
CLASS_FEATURES = {cname: np.load(f'class_vectors/{cname}.npy') for cname in CLASS_NAMES}
device = "cuda" if torch.cuda.is_available() else "cpu"
with dnnlib.util.open_url(NETWORK) as f:
G = legacy.load_network_pkl(f)['G_ema']
G = G.eval().requires_grad_(False)
G = G.to(device)
def reset_values(*args) :
return [0 for _ in range(len(args))]
@spaces.GPU
def generate_base_pokemon() :
random_seed = random.randint(0, 2**32 - 1)
base_w = gen_utils.get_w_from_seed(G, batch_sz=1, device=device, seed=random_seed)
output = gen_utils.w_to_img(G, base_w, to_np=True)[0]
output_img = Image.fromarray(output, 'RGB')
base_w = base_w.to('cpu')
return (base_w, output_img), output_img, [output_img]
init_time = time.time()
print(f'Inital run to generate a base Pokémon... (device is {device})')
generate_base_pokemon()
print(f'Inital run completed in {time.time() - init_time:.2f} seconds.')
@spaces.GPU
def apply_class_feature(base_state, base_ratio, normal, fire, water, grass, electric, ice, fighting, poison, ground, flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy) :
if base_state is None :
raise gr.Error('Base Pokémon is required.')
class_values = [normal, fire, water, grass, electric, ice, fighting, poison, ground, flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy]
if sum(class_values) == 0 :
return base_state[-1]
target_feature = None
for i, class_value in enumerate(class_values) :
if class_value == 0 : continue
class_name = CLASS_NAMES[i]
class_feature = CLASS_FEATURES[class_name]
if target_feature is None :
target_feature = class_feature * class_value
else :
target_feature += class_feature * class_value
target_feature = torch.from_numpy(target_feature).to(device)
target_feature = target_feature.repeat(1, G.num_ws, 1)
base_w = base_state[0].to(device)
edit_w = (base_w * base_ratio) + (target_feature * (1 - base_ratio))
output_edit = gen_utils.w_to_img(G, edit_w, to_np=True)[0]
return Image.fromarray(output_edit, 'RGB')
def add_image_to_gallery(gallery, img) :
gallery.append(img)
return gallery
with gr.Blocks() as demo :
gr.Markdown('# Pokémon Generator')
gr.Markdown('Generate a base Pokémon. Then, apply class features to edit the base Pokémon.')
with gr.Row() :
with gr.Column() :
btn_generate = gr.Button('Generate a base Pokémon', variant='primary')
base_state = gr.State()
with gr.Row() :
normal = gr.Slider(label='Normal', minimum=-1, maximum=1, value=0, step=0.01)
fire = gr.Slider(label='Fire', minimum=-1, maximum=1, value=0, step=0.01)
water = gr.Slider(label='Water', minimum=-1, maximum=1, value=0, step=0.01)
grass = gr.Slider(label='Grass', minimum=-1, maximum=1, value=0, step=0.01)
electric = gr.Slider(label='Electric', minimum=-1, maximum=1, value=0, step=0.01)
ice = gr.Slider(label='Ice', minimum=-1, maximum=1, value=0, step=0.01)
fighting = gr.Slider(label='Fighting', minimum=-1, maximum=1, value=0, step=0.01)
poison = gr.Slider(label='Poison', minimum=-1, maximum=1, value=0, step=0.01)
ground = gr.Slider(label='Ground', minimum=-1, maximum=1, value=0, step=0.01)
flying = gr.Slider(label='Flying', minimum=-1, maximum=1, value=0, step=0.01)
psychic = gr.Slider(label='Psychic', minimum=-1, maximum=1, value=0, step=0.01)
bug = gr.Slider(label='Bug', minimum=-1, maximum=1, value=0, step=0.01)
rock = gr.Slider(label='Rock', minimum=-1, maximum=1, value=0, step=0.01)
ghost = gr.Slider(label='Ghost', minimum=-1, maximum=1, value=0, step=0.01)
dragon = gr.Slider(label='Dragon', minimum=-1, maximum=1, value=0, step=0.01)
dark = gr.Slider(label='Dark', minimum=-1, maximum=1, value=0, step=0.01)
steel = gr.Slider(label='Steel', minimum=-1, maximum=1, value=0, step=0.01)
fairy = gr.Slider(label='Fairy', minimum=-1, maximum=1, value=0, step=0.01)
with gr.Row() :
btn_zero = gr.Button('Reset to 0', variant='secondary')
btn_edit = gr.Button('Apply class feature', variant='primary')
base_ratio = gr.Slider(label='Base ratio', minimum=0, maximum=1, value=0.5, step=0.01,
info="The parameter that determines how much the base Pokémon's form is retained.")
with gr.Column() :
output_img = gr.Image(label='Output', image_mode='RGB', type='pil', interactive=False)
output_gallery = gr.Gallery(label='Gallery', columns=5, interactive=False)
btn_zero.click(
fn=reset_values,
inputs=[normal, fire, water, grass, electric, ice, fighting, poison, ground, flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy],
outputs=[normal, fire, water, grass, electric, ice, fighting, poison, ground, flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy]
)
btn_generate.click(
fn=generate_base_pokemon,
inputs=None,
outputs=[base_state, output_img, output_gallery],
concurrency_id='gpu')
btn_edit.click(
fn=apply_class_feature,
inputs=[base_state, base_ratio, normal, fire, water, grass, electric, ice, fighting, poison, ground, flying, psychic, bug, rock, ghost, dragon, dark, steel, fairy],
outputs=output_img,
concurrency_id='gpu'
).success(
fn=add_image_to_gallery,
inputs=[output_gallery, output_img],
outputs=output_gallery
)
demo.title = 'Pokémon Generator'
demo.launch(show_api=False)