# DISCLAIMER: Code written by AI
import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F

from src import *  # provides ContextUnet
# device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# diffusion constants: linear beta schedule over T steps (index 0 corresponds to t=0)
T = 3000
beta_start = 1e-3
beta_end = 0.02
betas = (beta_end - beta_start) * torch.linspace(0, 1, T + 1, device=device) + beta_start
alphas = 1 - betas
alphas_hat = torch.cumsum(alphas.log(), dim=0).exp()  # cumulative product of alphas
alphas_hat[0] = 1
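
# With this schedule the forward process has the closed form
#   q(x_t | x_0) = N(sqrt(alphas_hat[t]) * x_0, (1 - alphas_hat[t]) * I),
# so a noised sample at step t would be:
#   x_t = alphas_hat[t].sqrt() * x0 + (1 - alphas_hat[t]).sqrt() * torch.randn_like(x0)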
# -----------------------------
# Diffusion model wrapper
# -----------------------------
class Diffusion:
    def __init__(self, weights_path):
        context_features = 5
        features = 256
        self.image_size = (16, 16)
        self.model = ContextUnet(
            in_channels=3,
            features=features,
            context_features=context_features,
            image_size=self.image_size,
        ).to(device)
        self.model.load_state_dict(torch.load(weights_path, map_location=device))
        self.model.eval()
    def denoise_add_noise(self, x, t, pred_noise, z=None):
        # One reverse DDPM step: remove the scaled predicted noise, then
        # re-inject fresh noise scaled by sqrt(beta_t) (pass z=0 at the last step).
        if z is None:
            z = torch.randn_like(x)
        noise = betas.sqrt()[t] * z
        mean = (x - pred_noise * ((1 - alphas[t]) / (1 - alphas_hat[t]).sqrt())) / alphas[t].sqrt()
        return mean + noise
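
    # For reference, the step above is the DDPM ancestral update (Ho et al., 2020):
    #   x_{t-1} = (x_t - eps * (1 - alphas[t]) / sqrt(1 - alphas_hat[t])) / sqrt(alphas[t]) + sqrt(betas[t]) * z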
    @torch.no_grad()
    def sample_ddpm(self, n_sample, context):
        # Full ancestral sampling: all T steps, noise re-injected except at the last one.
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        for i in range(T, 0, -1):
            # the model expects the timestep normalized to (0, 1]
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            z = torch.randn_like(samples) if i > 1 else 0
            eps = self.model(samples, t, c=context)
            samples = self.denoise_add_noise(samples, i, eps, z)
        return samples
    def denoise_ddim(self, x, t, t_prev, pred_noise):
        # Deterministic DDIM step (eta = 0): estimate x0, then move it back to t_prev.
        ab = alphas_hat[t]
        ab_prev = alphas_hat[t_prev]
        x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
        dir_xt = (1 - ab_prev).sqrt() * pred_noise
        return x0_pred + dir_xt
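
    # For reference, this is the eta = 0 case of DDIM (Song et al., 2020):
    #   x_{t_prev} = sqrt(ab_prev) * x0_pred + sqrt(1 - ab_prev) * eps,
    # where x0_pred = (x_t - sqrt(1 - ab) * eps) / sqrt(ab); the sqrt(ab_prev)
    # factor is folded into x0_pred in the code above.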
    @torch.no_grad()
    def sample_ddim(self, n_sample, context, n=20):
        # Strided sampling: n evenly spaced steps instead of all T, much faster than DDPM.
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        step_size = T // n
        for i in range(T, 0, -step_size):
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            eps = self.model(samples, t, c=context)
            prev_i = max(i - step_size, 1)
            samples = self.denoise_ddim(samples, i, prev_i, eps)
        return samples
    def generate(self, context, mode="ddim"):
        # context is a one-hot list of length context_features (5)
        ctx = torch.tensor(context).float().unsqueeze(0).to(device)
        if mode == "ddpm":
            return self.sample_ddpm(1, ctx)
        else:
            return self.sample_ddim(1, ctx, n=25)
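
# Usage sketch (the weights filename is a placeholder):
#   diffusion = Diffusion("weights/model.pth")
#   sprite = diffusion.generate([1, 0, 0, 0, 0], mode="ddim")  # one "hero" sprite, shape (1, 3, 16, 16)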
# -----------------------------
# Gradio Interface
# -----------------------------
# list weight files in the weights folder
weights_folder = "weights"
os.makedirs(weights_folder, exist_ok=True)
available_weights = [f for f in os.listdir(weights_folder) if f.endswith(".pth")]
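
# Note: if no .pth files are present yet, the dropdown below starts empty;
# weight files must be placed in ./weights for the app to produce output.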
def run_inference(weights_name, mode, context_choice):
    weights_path = os.path.join(weights_folder, weights_name)
    diffusion = Diffusion(weights_path)
    # one-hot context vectors, matching context_features = 5
    context_map = {
        "hero": [1, 0, 0, 0, 0],
        "non-hero": [0, 1, 0, 0, 0],
        "food": [0, 0, 1, 0, 0],
        "spell": [0, 0, 0, 1, 0],
        "side-facing": [0, 0, 0, 0, 1],
    }
    context = context_map[context_choice]
    samples = diffusion.generate(context=context, mode=mode)
    # take the first (and only) sample
    img = samples[0].unsqueeze(0)  # shape (1, 3, 16, 16)
    # upscale to 256×256 ('nearest' keeps the blocky pixel-art style)
    img_up = F.interpolate(img, size=(256, 256), mode="nearest")
    img_np = img_up[0].detach().cpu().numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())  # min-max normalize to [0, 1]
    img_np = np.transpose(img_np, (1, 2, 0))  # (H, W, C) for display
    return img_np
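
# Note: the min-max normalization above is per-image, so absolute brightness is
# not comparable across generations; it only maps each output to [0, 1] for display.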
with gr.Blocks() as demo:
    gr.Markdown("## Sprite Diffusion Generator 👾")
    gr.Markdown("Note: the DDPM algorithm may take around 1-2 minutes.")
    with gr.Row():
        weights_name = gr.Dropdown(available_weights, label="Select weights file")
        mode = gr.Radio(["ddpm", "ddim"], value="ddim", label="Generation Mode")
        context_choice = gr.Dropdown(["hero", "non-hero", "food", "spell", "side-facing"], value="hero", label="Context")
    run_btn = gr.Button("Generate")
    output = gr.Image(label="Generated Image")
    run_btn.click(run_inference, inputs=[weights_name, mode, context_choice], outputs=output)

demo.launch()