| import functools |
| import pickle |
| import random |
| from typing import List |
|
|
| import numpy as np |
| import streamlit as st |
| import torch |
|
|
| from huggingface_hub import hf_hub_url, cached_download |
|
|
# Icon types in display order (this order is what the "Choose icon type"
# selectbox shows, since it iterates the mapping's keys).
_ICON_CLASS_NAMES = (
    "Fire",
    "Magic",
    "Nature",
    "Lightning",
    "Ice",
    "Shadow",
    "Unholy",
    "Battle",
    "Holy",
)

# Class-label index used for one-hot conditioning of the generator. Indices
# run in reverse display order: "Fire" -> 8 down to "Holy" -> 0.
ICON_CLASS_MAPPING = {
    name: len(_ICON_CLASS_NAMES) - 1 - position
    for position, name in enumerate(_ICON_CLASS_NAMES)
}

# Upper bound (inclusive) for randomly drawn generation seeds.
MAX_SEED = 100_000_000
|
|
st.title("RPG Icon Generator")

# Fetch the model file through huggingface_hub's download cache and keep the
# "G_ema" entry of the unpickled dict as the generator.
# NOTE(review): pickle.load can execute arbitrary code stored in the file, so
# this fully trusts the "gylleus/rpg-icongen" Hub repository. Unpickling also
# presumably needs the modules the pickled classes live in to be importable
# (the force_fp32 kwarg suggests StyleGAN2-ADA code). Confirm.
with open(
    cached_download(hf_hub_url("gylleus/rpg-icongen", "icongen-model.pkl")), "rb"
) as model_file:
    G = pickle.load(model_file)["G_ema"]

# Prefer the GPU. On the CPU path, every forward call gets force_fp32=True
# (presumably disables half-precision layers, which CPUs handle poorly).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    G = G.to(device)
else:
    G.forward = functools.partial(G.forward, force_fp32=True)
|
|
|
|
# Seed for the images shown in the current script run. Initialised to 0 here,
# then immediately replaced by the randomize_seed() call below.
random_seed = 0


def randomize_seed() -> int:
    """Draw a new seed, store it in the module-level ``random_seed``, and return it.

    The seed is drawn with ``random.randint(0, MAX_SEED)``, so both bounds are
    inclusive.

    Returns:
        The newly drawn seed. This matches the ``-> int`` annotation; the
        function previously returned ``None``.
    """
    global random_seed
    random_seed = random.randint(0, MAX_SEED)
    return random_seed


randomize_seed()
|
|
|
|
def get_class_id(class_name: str) -> int:
    """Return the generator's class-label index for ``class_name``.

    Names missing from ``ICON_CLASS_MAPPING`` fall back to the "Fire" index
    instead of raising.
    """
    try:
        return ICON_CLASS_MAPPING[class_name]
    except KeyError:
        return ICON_CLASS_MAPPING["Fire"]
|
|
|
|
def generate(
    seed: int,
    class_name: str,
    truncation_psi: float = 1,
    noise_mode: str = "const",
) -> np.ndarray:
    """Render one icon of type ``class_name`` from the latent drawn with ``seed``.

    Args:
        seed: Seed for ``np.random.RandomState``. The same seed and class always
            produce the same latent ``z``.
        class_name: Icon type, resolved via ``get_class_id``. Unknown names fall
            back to the "Fire" class.
        truncation_psi: Forwarded to the generator. Defaults to ``1``, the value
            this app has always used.
        noise_mode: Forwarded to the generator. Defaults to ``"const"``.

    Returns:
        A ``uint8`` array with values clamped to ``[0, 255]`` and channels moved
        last. Its shape is presumably ``(1, H, W, C)``, assuming the generator
        returns ``(N, C, H, W)``. TODO confirm against the model.
    """
    # One-hot class-conditioning vector for a batch of one.
    label = torch.zeros([1, G.c_dim], device=device)
    label[:, get_class_id(class_name)] = 1

    # Deterministic latent: one row of z_dim standard-normal samples.
    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

    # Inference only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)

    # Rescale (presumably from about [-1, 1]) to [0, 255] and permute to
    # channels-last for display.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()
|
|
|
|
def generate_images(seed: int, amount: int, class_name: str) -> List[np.ndarray]:
    """Generate ``amount`` icons of ``class_name`` using consecutive seeds.

    The i-th image (0-based) is produced with seed ``seed + i``. A
    non-positive ``amount`` yields an empty list.
    """
    return [generate(seed + offset, class_name) for offset in range(amount)]
|
|
|
|
# Pass the callback itself, not its result. The original
# `on_click=randomize_seed()` ran the function while the widget was being
# built and registered `None` as the callback. Note that Streamlit re-runs
# the whole script on each interaction, and the module-level
# `randomize_seed()` call above already draws a fresh seed on every run.
st.button("Generate", on_click=randomize_seed)

chosen_class = st.selectbox("Choose icon type", tuple(ICON_CLASS_MAPPING.keys()))

image_amount = st.slider("Images to generate", 1, 9, 3)

columns = st.columns(3)

# Fill the columns left to right, wrapping to the first column after the last.
for column_index, img in enumerate(
    generate_images(random_seed, image_amount, chosen_class)
):
    columns[column_index % len(columns)].image(img)
|
|