"""Gradio demo that samples images from a trained StyleGAN2 generator.

The generator snapshot is loaded once at import time; each button press maps a
(seed, truncation) pair to a single RGB image.
"""

import os
import sys

ROOT = os.path.dirname(os.path.abspath(__file__))
# Make modules inside the vendored ``stylegan2`` directory importable as
# top-level names (presumably required by ``legacy`` / the pickled network —
# confirm before removing).
sys.path.insert(0, os.path.join(ROOT, "stylegan2"))

import torch
import numpy as np
import gradio as gr
from PIL import Image

from stylegan2 import legacy

# ---------------- CONFIG ----------------
_MODEL_FILE = "network-snapshot-000700.pkl"
# Prefer the snapshot in the current working directory (original behaviour),
# but fall back to the script's own directory so the app also starts when it
# is launched from elsewhere — consistent with how ROOT anchors the import path.
MODEL_PATH = (
    _MODEL_FILE
    if os.path.exists(_MODEL_FILE)
    else os.path.join(ROOT, _MODEL_FILE)
)
DEVICE = "cpu"

# ---------------- LOAD MODEL ----------------
# NOTE(security): load_network_pkl presumably unpickles the file, which can
# execute arbitrary code — only point MODEL_PATH at trusted snapshots.
with open(MODEL_PATH, "rb") as f:
    G = legacy.load_network_pkl(f)["G_ema"].to(DEVICE)
G.eval()


# ---------------- GENERATION ----------------
@torch.no_grad()
def generate_image(seed, truncation):
    """Generate one image from the loaded generator ``G``.

    Args:
        seed: Seed for NumPy's ``RandomState``; the same seed always yields
            the same latent vector. Coerced with ``int``.
        truncation: Truncation psi forwarded to ``G``. Coerced with ``float``.

    Returns:
        A ``PIL.Image.Image`` built from the first (only) image of the batch,
        in ``"RGB"`` mode.
    """
    seed = int(seed)
    truncation = float(truncation)

    print(f"Generating image for seed {seed} | trunc={truncation}")

    # NumPy produces float64; cast explicitly instead of relying on the
    # generator to convert its input.
    z = torch.from_numpy(
        np.random.RandomState(seed).randn(1, G.z_dim)
    ).to(device=DEVICE, dtype=torch.float32)

    img = G(
        z,
        None,
        truncation_psi=truncation,
        noise_mode="const",
    )

    # NCHW -> NHWC, then rescale (presumably from roughly [-1, 1]) to
    # [0, 255] and quantise to 8-bit for PIL.
    img = img.permute(0, 2, 3, 1) * 127.5 + 128
    img = img.clamp(0, 255).to(torch.uint8)

    return Image.fromarray(img[0].cpu().numpy(), "RGB")


# ---------------- GRADIO UI ----------------
with gr.Blocks() as demo:
    gr.Markdown("## STYLEGAN2 Anime Image Generator")

    seed = gr.Slider(
        minimum=0,
        maximum=10000,
        value=0,
        step=1,
        label="Seed",
    )

    truncation = gr.Slider(
        minimum=0.3,
        maximum=1.0,
        value=0.7,
        step=0.05,
        label="Truncation (ψ)",
    )

    generate_btn = gr.Button("Generate Image")
    output = gr.Image(type="pil", label="Generated Image")

    generate_btn.click(
        fn=generate_image,
        inputs=[seed, truncation],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()