# DISCLAIMER: Code written by AI
import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F

from src import *  # provides ContextUnet

# device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# diffusion constants: linear beta schedule; alphas_hat is the cumulative
# product of alphas, computed in log space for numerical stability
T = 3000
beta_start = 1e-3
beta_end = 0.02
betas = (beta_end - beta_start) * torch.linspace(0, 1, T + 1, device=device) + beta_start
alphas = 1 - betas
alphas_hat = torch.cumsum(alphas.log(), dim=0).exp()
alphas_hat[0] = 1


# -----------------------------
# Diffusion model wrapper
# -----------------------------
class Diffusion:
    def __init__(self, weights_path):
        context_features = 5
        features = 256
        self.image_size = (16, 16)
        self.model = ContextUnet(
            in_channels=3,
            features=features,
            context_features=context_features,
            image_size=self.image_size,
        ).to(device)
        self.model.load_state_dict(torch.load(weights_path, map_location=device))
        self.model.eval()

    def denoise_add_noise(self, x, t, pred_noise, z=None):
        # One reverse DDPM step: posterior mean computed from the predicted
        # noise, plus fresh noise scaled by sqrt(beta_t)
        if z is None:
            z = torch.randn_like(x)
        noise = betas.sqrt()[t] * z
        mean = (x - pred_noise * ((1 - alphas[t]) / (1 - alphas_hat[t]).sqrt())) / alphas[t].sqrt()
        return mean + noise

    @torch.no_grad()
    def sample_ddpm(self, n_sample, context):
        # start from pure Gaussian noise and denoise through all T steps
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        for i in range(T, 0, -1):
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            z = torch.randn_like(samples) if i > 1 else 0  # no extra noise on the final step
            eps = self.model(samples, t, c=context)
            samples = self.denoise_add_noise(samples, i, eps, z)
        return samples

    def denoise_ddim(self, x, t, t_prev, pred_noise):
        # One deterministic DDIM step: estimate x0 from the predicted noise,
        # then move toward x at t_prev
        ab = alphas_hat[t]
        ab_prev = alphas_hat[t_prev]
        x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
        dir_xt = (1 - ab_prev).sqrt() * pred_noise
        return x0_pred + dir_xt

    @torch.no_grad()
    def sample_ddim(self, n_sample, context, n=20):
        # DDIM sampling: visit only n evenly spaced timesteps instead of all T
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        step_size = T // n
        for i in range(T, 0, -step_size):
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            eps = self.model(samples, t, c=context)
            prev_i = max(i - step_size, 1)
            samples = self.denoise_ddim(samples, i, prev_i, eps)
        return samples

    def generate(self, context, mode="ddim"):
        ctx = torch.tensor(context).float().unsqueeze(0).to(device)
        if mode == "ddpm":
            return self.sample_ddpm(1, ctx)
        else:
            return self.sample_ddim(1, ctx, n=25)


# -----------------------------
# Gradio Interface
# -----------------------------

# list weights in folder
weights_folder = "weights"
os.makedirs(weights_folder, exist_ok=True)
available_weights = [f for f in os.listdir(weights_folder) if f.endswith(".pth")]


def run_inference(weights_name, mode, context_choice):
    weights_path = os.path.join(weights_folder, weights_name)
    diffusion = Diffusion(weights_path)
    # one-hot class-conditioning vectors
    context_map = {
        "hero": [1, 0, 0, 0, 0],
        "non-hero": [0, 1, 0, 0, 0],
        "food": [0, 0, 1, 0, 0],
        "spell": [0, 0, 0, 1, 0],
        "side-facing": [0, 0, 0, 0, 1],
    }
    context = context_map[context_choice]
    samples = diffusion.generate(context=context, mode=mode)

    # take the [0]th sample
    img = samples[0].unsqueeze(0)  # shape (1, 3, 16, 16)
    # upscale to 256×256 (use 'nearest' to keep the blocky pixel-art style)
    img_up = F.interpolate(img, size=(256, 256), mode="nearest")
    img_np = img_up[0].detach().cpu().numpy()
    # normalize to [0, 1]; the epsilon guards against division by zero
    # when the sample is constant
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
    img_np = np.transpose(img_np, (1, 2, 0))  # (H, W, C) for display
    return img_np


with gr.Blocks() as demo:
    gr.Markdown("## Sprite Diffusion Generator 👾")
    gr.Markdown("Note: the DDPM algorithm may take around 1-2 minutes.")
    with gr.Row():
        weights_name = gr.Dropdown(available_weights, label="Select weights file")
        mode = gr.Radio(["ddpm", "ddim"], value="ddim", label="Generation Mode")
        context_choice = gr.Dropdown(
            ["hero", "non-hero", "food", "spell", "side-facing"],
            value="hero",
            label="Context",
        )
    run_btn = gr.Button("Generate")
    output = gr.Image(label="Generated Image")
    run_btn.click(run_inference, inputs=[weights_name, mode, context_choice], outputs=output)

demo.launch()
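
# A minimal sanity check of the Diffusion wrapper without the UI, assuming a
# trained checkpoint exists under weights/ (the filename "model.pth" below is
# only an example, not a file shipped with this repo):
#
#   diffusion = Diffusion("weights/model.pth")
#   sprite = diffusion.generate(context=[1, 0, 0, 0, 0], mode="ddim")  # "hero"
#   print(sprite.shape)  # expected: torch.Size([1, 3, 16, 16])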