# DISCLAIMER: Code written by AI
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
import os
from src import *  # local package; provides ContextUnet
# device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
# diffusion constants: linear beta schedule (T + 1 entries, so index T is valid)
T = 3000
beta_start = 1e-3
beta_end = 0.02
betas = (beta_end - beta_start) * torch.linspace(0, 1, T + 1, device=device) + beta_start
alphas = 1 - betas
# cumulative product of alphas (alpha-bar), computed stably in log space
alphas_hat = torch.cumsum(alphas.log(), dim=0).exp()
alphas_hat[0] = 1
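# For reference, these constants parameterize the standard DDPM forward process,
#   x_t = sqrt(alphas_hat[t]) * x_0 + sqrt(1 - alphas_hat[t]) * eps,  eps ~ N(0, I).
# A minimal sketch of that noising step (illustrative only; the app never calls it):
def perturb_input(x0, t, noise):
    # t: tensor of integer timestep indices into the schedule, shape (batch,)
    ab = alphas_hat[t, None, None, None]
    return ab.sqrt() * x0 + (1 - ab).sqrt() * noise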
# -----------------------------
# Diffusion model wrapper
# -----------------------------
class Diffusion:
    def __init__(self, weights_path):
        # architecture hyperparameters must match those used during training
        context_features = 5
        features = 256
        self.image_size = (16, 16)
        self.model = ContextUnet(
            in_channels=3,
            features=features,
            context_features=context_features,
            image_size=self.image_size,
        ).to(device)
        self.model.load_state_dict(torch.load(weights_path, map_location=device))
        self.model.eval()
    def denoise_add_noise(self, x, t, pred_noise, z=None):
        # one DDPM reverse step: posterior mean from the noise prediction,
        # plus fresh noise scaled by sqrt(beta_t) (t is an integer index)
        if z is None:
            z = torch.randn_like(x)
        noise = betas.sqrt()[t] * z
        mean = (x - pred_noise * ((1 - alphas[t]) / (1 - alphas_hat[t]).sqrt())) / alphas[t].sqrt()
        return mean + noise
    @torch.no_grad()
    def sample_ddpm(self, n_sample, context):
        # start from pure Gaussian noise and denoise through all T steps
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        for i in range(T, 0, -1):
            # normalized timestep in (0, 1] for the model's time embedding
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            z = torch.randn_like(samples) if i > 1 else 0  # no extra noise on the final step
            eps = self.model(samples, t, c=context)
            samples = self.denoise_add_noise(samples, i, eps, z)
        return samples
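    # Note: sample_ddpm calls the model once per timestep (T = 3000 forward
    # passes), which is why the UI below warns that DDPM mode can take minutes.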
    def denoise_ddim(self, x, t, t_prev, pred_noise):
        # deterministic DDIM update (eta = 0): estimate x0 from the noise
        # prediction, then step it back to timestep t_prev
        ab = alphas_hat[t]
        ab_prev = alphas_hat[t_prev]
        x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
        dir_xt = (1 - ab_prev).sqrt() * pred_noise
        return x0_pred + dir_xt
    @torch.no_grad()
    def sample_ddim(self, n_sample, context, n=20):
        # DDIM sampling: only n evenly spaced steps instead of all T
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        step_size = T // n
        for i in range(T, 0, -step_size):
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            eps = self.model(samples, t, c=context)
            prev_i = max(i - step_size, 1)
            samples = self.denoise_ddim(samples, i, prev_i, eps)
        return samples
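    # Because the DDIM update is deterministic, n = 20-25 model calls suffice
    # here, versus 3000 for sample_ddpm above (roughly 120x fewer evaluations).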
    def generate(self, context, mode="ddim"):
        # context: one-hot list over the five sprite classes
        ctx = torch.tensor(context).float().unsqueeze(0).to(device)
        if mode == "ddpm":
            return self.sample_ddpm(1, ctx)
        else:
            return self.sample_ddim(1, ctx, n=25)
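# Example of using the wrapper directly, outside Gradio (the checkpoint name
# below is hypothetical; any weights trained for the ContextUnet above would do):
#   diff = Diffusion("weights/sprites.pth")
#   imgs = diff.generate(context=[1, 0, 0, 0, 0], mode="ddim")  # "hero" one-hot
#   print(imgs.shape)  # torch.Size([1, 3, 16, 16])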
# -----------------------------
# Gradio Interface
# -----------------------------
# list weights in folder
weights_folder = "weights"
os.makedirs(weights_folder, exist_ok=True)
available_weights = [f for f in os.listdir(weights_folder) if f.endswith(".pth")]
def run_inference(weights_name, mode, context_choice):
    weights_path = os.path.join(weights_folder, weights_name)
    diffusion = Diffusion(weights_path)
    # one-hot context vectors for the five sprite classes
    context_map = {
        "hero":        [1, 0, 0, 0, 0],
        "non-hero":    [0, 1, 0, 0, 0],
        "food":        [0, 0, 1, 0, 0],
        "spell":       [0, 0, 0, 1, 0],
        "side-facing": [0, 0, 0, 0, 1],
    }
    context = context_map[context_choice]
    samples = diffusion.generate(context=context, mode=mode)
    # take the first sample
    img = samples[0].unsqueeze(0)  # shape (1, 3, 16, 16)
    # upscale to 256x256 ('nearest' keeps the blocky pixel-art style)
    img_up = F.interpolate(img, size=(256, 256), mode="nearest")
    img_np = img_up[0].detach().cpu().numpy()
    # min-max normalize to [0, 1]; the epsilon guards against a constant image
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
    img_np = np.transpose(img_np, (1, 2, 0))  # (H, W, C) for display
    return img_np
with gr.Blocks() as demo:
    gr.Markdown("## Sprite Diffusion Generator 👾")
    gr.Markdown("Note: the DDPM algorithm may take around 1-2 minutes.")
    with gr.Row():
        weights_name = gr.Dropdown(available_weights, label="Select weights file")
        mode = gr.Radio(["ddpm", "ddim"], value="ddim", label="Generation Mode")
        context_choice = gr.Dropdown(["hero", "non-hero", "food", "spell", "side-facing"], value="hero", label="Context")
    run_btn = gr.Button("Generate")
    output = gr.Image(label="Generated Image")
    run_btn.click(run_inference, inputs=[weights_name, mode, context_choice], outputs=output)
demo.launch()
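# demo.launch() serves the app locally by default; Gradio can also create a
# temporary public link via demo.launch(share=True) if needed.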