# DISCLAIMER: Code written by AI
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
import os
from src import *  # provides ContextUnet

# device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# diffusion constants
T = 3000
beta_end = 0.02
beta_start = 1e-3
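# linear noise schedule: T+1 values from beta_start to beta_end, indexed by t in [0, T]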
betas = (beta_end - beta_start) * torch.linspace(0, 1, T+1, device=device) + beta_start
alphas = 1 - betas
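# alphas_hat[t] = prod_{s<=t} alphas[s]; exp(cumsum(log)) is a numerically stable cumprod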
alphas_hat = torch.cumsum(alphas.log(), dim=0).exp()
alphas_hat[0] = 1

# -----------------------------
# Diffusion model wrapper
# -----------------------------
class Diffusion:
    def __init__(self, weights_path):
        context_features = 5
        features = 256
        self.image_size = (16, 16)
        self.model = ContextUnet(
            in_channels=3,
            features=features,
            context_features=context_features,
            image_size=self.image_size
        ).to(device)
        self.model.load_state_dict(torch.load(weights_path, map_location=device))
        self.model.eval()

    def denoise_add_noise(self, x, t, pred_noise, z=None):
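        # One DDPM reverse step:
        #   x_{t-1} = (x_t - eps * (1 - alpha_t) / sqrt(1 - alpha_hat_t)) / sqrt(alpha_t) + sqrt(beta_t) * z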
        if z is None:
            z = torch.randn_like(x)
        noise = betas.sqrt()[t] * z
        mean = (x - pred_noise * ((1 - alphas[t]) / (1 - alphas_hat[t]).sqrt())) / alphas[t].sqrt()
        return mean + noise

    @torch.no_grad()
    def sample_ddpm(self, n_sample, context):
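        # Ancestral sampling: start from pure Gaussian noise and run all T reverse steps.
        # The timestep is passed to the model normalized to (0, 1], with shape (1, 1, 1, 1).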
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        for i in range(T, 0, -1):
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            z = torch.randn_like(samples) if i > 1 else 0
            eps = self.model(samples, t, c=context)
            samples = self.denoise_add_noise(samples, i, eps, z)
        return samples

    def denoise_ddim(self, x, t, t_prev, pred_noise):
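        # Deterministic DDIM step (eta = 0): estimate the clean image from the
        # predicted noise, then project it to the noise level of t_prev.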
        ab = alphas_hat[t]
        ab_prev = alphas_hat[t_prev]
        x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
        dir_xt = (1 - ab_prev).sqrt() * pred_noise
        return x0_pred + dir_xt

    @torch.no_grad()
    def sample_ddim(self, n_sample, context, n=20):
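        # DDIM sampling over n evenly spaced timesteps instead of all T steps.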
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        step_size = T // n
        for i in range(T, 0, -step_size):
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            eps = self.model(samples, t, c=context)
            prev_i = max(i - step_size, 1)
            samples = self.denoise_ddim(samples, i, prev_i, eps)
        return samples

    def generate(self, context, mode="ddim"):
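        # context: one-hot vector over the 5 sprite classes; DDIM uses 25 steps here.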
        ctx = torch.tensor(context).float().unsqueeze(0).to(device)
        if mode == "ddpm":
            return self.sample_ddpm(1, ctx)
        else:
            return self.sample_ddim(1, ctx, n=25)

# -----------------------------
# Gradio Interface
# -----------------------------
# list weights in folder
weights_folder = "weights"
os.makedirs(weights_folder, exist_ok=True)
available_weights = [f for f in os.listdir(weights_folder) if f.endswith(".pth")]

def run_inference(weights_name, mode, context_choice):
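    # Load the selected checkpoint, sample one sprite, and upscale it for display.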
    weights_path = os.path.join(weights_folder, weights_name)
    diffusion = Diffusion(weights_path)

    context_map = {
        "hero": [1,0,0,0,0],
        "non-hero": [0,1,0,0,0],
        "food": [0,0,1,0,0],
        "spell": [0,0,0,1,0],
        "side-facing": [0,0,0,0,1],
    }
    context = context_map[context_choice]

    samples = diffusion.generate(context=context, mode=mode)

    # keep the first sample, with a batch dimension for F.interpolate
    img = samples[0].unsqueeze(0)  # shape (1, 3, 16, 16)

    # upscale to 256×256 (use 'nearest' to keep blocky pixel-art style)
    img_up = F.interpolate(img, size=(256, 256), mode="nearest")

    img_np = img_up[0].detach().cpu().numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)  # normalize to [0, 1]
    img_np = np.transpose(img_np, (1,2,0))  # (H,W,C) for display

    return img_np


with gr.Blocks() as demo:
    gr.Markdown("## Sprite Diffusion Generator 👾")
    gr.Markdown("Note: DDPM algorihm may take around 1-2 minutes.")

    with gr.Row():
        weights_name = gr.Dropdown(available_weights, label="Select weights file")
        mode = gr.Radio(["ddpm", "ddim"], value="ddim", label="Generation Mode")
        context_choice = gr.Dropdown(["hero","non-hero","food","spell","side-facing"], value="hero", label="Context")

    run_btn = gr.Button("Generate")
    output = gr.Image(label="Generated Image")

    run_btn.click(run_inference, inputs=[weights_name, mode, context_choice], outputs=output)

demo.launch()