Efesasa0 commited on
Commit
b6b6742
·
1 Parent(s): 10bce0b
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Sprite Generation
3
- emoji: 🌖
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
@@ -10,4 +10,4 @@ pinned: false
10
  short_description: generation of game character sprites from trained weights
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Sprite Generation
3
+ emoji: 👾
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
 
10
  short_description: generation of game character sprites from trained weights
11
  ---
12
 
13
+ Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
import numpy as np
import os
from src import *

# device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# diffusion constants
T = 3000              # number of diffusion steps; must match the trained checkpoints
beta_end = 0.02
beta_start = 1e-3
# Linear noise schedule from beta_start to beta_end. T+1 entries so that
# betas[t] can be indexed directly with timestep t in [1, T]; betas[0] is unused.
betas = (beta_end - beta_start) * torch.linspace(0, 1, T+1, device=device) + beta_start
alphas = 1 - betas
# Cumulative product of alphas, computed in log space: exp(cumsum(log a)) == cumprod(a).
alphas_hat = torch.cumsum(alphas.log(), dim=0).exp()
alphas_hat[0] = 1  # by convention no noise has been applied at t = 0
18
+
19
# -----------------------------
# Diffusion model wrapper
# -----------------------------
class Diffusion:
    """Wraps a trained ContextUnet and exposes DDPM / DDIM sampling.

    Relies on the module-level schedule tensors (T, betas, alphas, alphas_hat)
    and `device` defined above in this file.
    """

    def __init__(self, weights_path):
        # Architecture hyperparameters; must match the checkpoint being loaded.
        context_features = 5
        features = 256
        self.image_size = (16, 16)
        self.model = ContextUnet(
            in_channels=3,
            features=features,
            context_features=context_features,
            image_size=self.image_size
        ).to(device)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted .pth files (consider weights_only=True on newer torch).
        self.model.load_state_dict(torch.load(weights_path, map_location=device))
        self.model.eval()

    def denoise_add_noise(self, x, t, pred_noise, z=None):
        # One reverse DDPM step at integer timestep t: remove the predicted
        # noise from x, then re-inject fresh noise scaled by sqrt(beta_t).
        if z is None:
            z = torch.randn_like(x)
        noise = betas.sqrt()[t] * z
        mean = (x - pred_noise * ((1 - alphas[t]) / (1 - alphas_hat[t]).sqrt())) / alphas[t].sqrt()
        return mean + noise

    @torch.no_grad()
    def sample_ddpm(self, n_sample, context):
        # Full ancestral DDPM sampling: T model evaluations starting from noise.
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        for i in range(T, 0, -1):
            # Normalized timestep broadcast to (1, 1, 1, 1) for the model.
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            # z = 0 disables the extra noise on the very last step.
            z = torch.randn_like(samples) if i > 1 else 0
            eps = self.model(samples, t, c=context)
            samples = self.denoise_add_noise(samples, i, eps, z)
        return samples

    def denoise_ddim(self, x, t, t_prev, pred_noise):
        # Deterministic DDIM update (eta = 0): estimate x0, then move it back
        # to noise level t_prev along the predicted-noise direction.
        ab = alphas_hat[t]
        ab_prev = alphas_hat[t_prev]
        x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
        dir_xt = (1 - ab_prev).sqrt() * pred_noise
        return x0_pred + dir_xt

    @torch.no_grad()
    def sample_ddim(self, n_sample, context, n=20):
        # Accelerated sampling: only n model evaluations instead of T.
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        step_size = T // n
        for i in range(T, 0, -step_size):
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            eps = self.model(samples, t, c=context)
            prev_i = max(i - step_size, 1)  # never step past the final timestep
            samples = self.denoise_ddim(samples, i, prev_i, eps)
        return samples

    def generate(self, context, mode="ddim"):
        """Generate one (1, 3, 16, 16) sample for a one-hot context vector.

        mode: "ddpm" for full ancestral sampling, anything else uses DDIM.
        """
        ctx = torch.tensor(context).float().unsqueeze(0).to(device)
        if mode == "ddpm":
            return self.sample_ddpm(1, ctx)
        else:
            return self.sample_ddim(1, ctx, n=25)
77
+
78
# -----------------------------
# Gradio Interface
# -----------------------------
import torch.nn.functional as F  # used by run_inference below for upscaling

# Discover available checkpoint files. Sorted so the dropdown order is
# deterministic across platforms (os.listdir order is arbitrary).
weights_folder = "weights"
os.makedirs(weights_folder, exist_ok=True)
available_weights = sorted(f for f in os.listdir(weights_folder) if f.endswith(".pth"))
87
+
88
def run_inference(weights_name, mode, context_choice):
    """Generate one sprite and return it as an (H, W, C) float array in [0, 1].

    Args:
        weights_name: checkpoint filename inside `weights_folder`.
        mode: "ddpm" or "ddim" sampling.
        context_choice: one of the context_map keys below.

    Returns:
        A (256, 256, 3) numpy array suitable for gr.Image display.
    """
    weights_path = os.path.join(weights_folder, weights_name)
    # NOTE: the model is re-created (and weights re-loaded) on every call;
    # acceptable for a demo, cache per weights_name if this becomes hot.
    diffusion = Diffusion(weights_path)

    # One-hot context vectors matching the training label layout.
    context_map = {
        "hero": [1,0,0,0,0],
        "non-hero": [0,1,0,0,0],
        "food": [0,0,1,0,0],
        "spell": [0,0,0,1,0],
        "side-facing": [0,0,0,0,1],
    }
    context = context_map[context_choice]

    samples = diffusion.generate(context=context, mode=mode)

    # Take the first (only) sample, keeping a batch dim for interpolate.
    img = samples[0].unsqueeze(0)  # shape (1, 3, 16, 16)

    # Upscale to 256x256; 'nearest' preserves the blocky pixel-art look.
    img_up = F.interpolate(img, size=(256, 256), mode="nearest")

    img_np = img_up[0].detach().cpu().numpy()
    # Min-max normalize to [0, 1]; guard against a constant image, which
    # would otherwise divide by zero and return NaNs.
    span = img_np.max() - img_np.min()
    if span > 0:
        img_np = (img_np - img_np.min()) / span
    else:
        img_np = np.zeros_like(img_np)
    img_np = np.transpose(img_np, (1,2,0))  # (H, W, C) for display

    return img_np
114
+
115
+
116
# Build and launch the demo UI. Fixes the user-facing typo "algorihm".
with gr.Blocks() as demo:
    gr.Markdown("## Sprite Diffusion Generator 👾")
    gr.Markdown("Note: DDPM algorithm may take around 1-2 minutes.")

    with gr.Row():
        weights_name = gr.Dropdown(available_weights, label="Select weights file")
        mode = gr.Radio(["ddpm", "ddim"], value="ddim", label="Generation Mode")
        context_choice = gr.Dropdown(["hero","non-hero","food","spell","side-facing"], value="hero", label="Context")

    run_btn = gr.Button("Generate")
    output = gr.Image(label="Generated Image")

    # Wire the button to the inference function defined above.
    run_btn.click(run_inference, inputs=[weights_name, mode, context_choice], outputs=output)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ matplotlib
5
+ gradio
src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .custom_dataset import SpritesDataset, sprites_transform
2
+ from .model import ContextUnet
3
+ from .model import *
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (334 Bytes). View file
 
src/__pycache__/custom_dataset.cpython-312.pyc ADDED
Binary file (2.39 kB). View file
 
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (4.15 kB). View file
 
src/__pycache__/model_parts.cpython-312.pyc ADDED
Binary file (4.48 kB). View file
 
src/custom_dataset.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ import torchvision.transforms as transforms
5
+
6
class SpritesDataset(Dataset):
    """Sprite images with class labels, both loaded from .npy files.

    When `null_context` is true every item gets label 0 (unconditional
    training); otherwise the stored label for that index is returned.
    """

    def __init__(self, images_path, labels_path, transform, null_context):
        self.images = np.load(images_path, allow_pickle=False)
        self.labels = np.load(labels_path, allow_pickle=False)

        # Shapes are captured once so __getshape__ stays cheap.
        self.images_shape = self.images.shape
        self.labels_shape = self.labels.shape

        self.transform = transform
        self.null_context = null_context

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.transform(self.images[idx])
        raw_label = 0 if self.null_context else self.labels[idx]
        return image, torch.tensor(raw_label).to(torch.int64)

    def __getshape__(self):
        # Non-standard helper: (images shape, labels shape).
        return self.images_shape, self.labels_shape
32
+
33
# Standard sprite preprocessing: HWC uint8 -> CHW float in [0, 1] via
# ToTensor, then rescaled channel-wise to [-1, 1] by Normalize(0.5, 0.5).
sprites_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5)),
])
38
+
src/generators.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
def denoise_add_noise(x, t, pred_noise, z=None, betas=None, alphas=None, alphas_hat=None):
    """One reverse DDPM step at integer timestep t.

    Removes the model-predicted noise from x, then re-injects fresh noise
    scaled by sqrt(beta_t).

    Args:
        x: current noisy sample tensor.
        t: integer timestep index into the schedule tensors.
        pred_noise: noise predicted by the model for (x, t).
        z: optional noise tensor; defaults to fresh standard Gaussian noise.
        betas, alphas, alphas_hat: diffusion schedule tensors. The original
            code read these (and `torch`) from globals this module never
            defined — calling it always raised NameError — so they are now
            explicit keyword parameters.

    Returns:
        The partially denoised sample for timestep t-1.

    Raises:
        ValueError: if any schedule tensor is missing.
    """
    import torch  # this module has no top-level imports; keep the fix local

    if betas is None or alphas is None or alphas_hat is None:
        raise ValueError("betas, alphas and alphas_hat schedule tensors are required")
    if z is None:
        z = torch.randn_like(x)
    noise = betas.sqrt()[t] * z
    mean = (x - pred_noise * ((1 - alphas[t]) / (1 - alphas_hat[t]).sqrt())) / alphas[t].sqrt()
    return mean + noise
7
+
src/model.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.model_parts import ResidualDoubleConv, UpSample, DownSample, EmbedFC
2
+ import torch.nn as nn
3
+ import torch
4
+
5
class ContextUnet(nn.Module):
    """Context-conditioned U-Net noise predictor for diffusion.

    NOTE: the submodule attribute names below (init_conv, down1, ...) are
    load-bearing — they define the state_dict keys, so renaming any of them
    would break the shipped checkpoints.
    """

    def __init__(self, in_channels, features=256, context_features=10, image_size=(16, 16)):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.features = features
        self.context_features = context_features
        self.height, self.width = image_size

        # Stem: keeps spatial size, lifts input to `features` channels.
        self.init_conv = ResidualDoubleConv(in_channels, features, is_residual=True)

        # Two encoder stages (each DownSample halves spatial resolution).
        self.down1 = DownSample(features, features)
        self.down2 = DownSample(features, 2*features)

        # Collapse the remaining map to a 1x1 latent vector
        # (assumes 16x16 inputs -> 4x4 after two downsamples — TODO confirm
        # for other image_size values).
        self.to_vec = nn.Sequential(
            nn.AvgPool2d((4)),
            nn.GELU(),
        )

        # Embeddings for the scalar timestep and the context vector,
        # one pair per decoder stage.
        self.timeembed1 = EmbedFC(1, 2*features)
        self.timeembed2 = EmbedFC(1, 1*features)
        self.contextembed1 = EmbedFC(context_features, 2*features)
        self.contextembed2 = EmbedFC(context_features, 1*features)

        # Expand the 1x1 latent back up to (height/4 x height/4).
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2*features, 2*features, self.height//4, self.height//4),
            nn.GroupNorm(8, 2*features),
            nn.ReLU(),
        )
        self.up1 = UpSample(4*features, features)
        self.up2 = UpSample(2*features, features)

        # Head: fuse decoder output with the stem features, map back to input channels.
        self.out = nn.Sequential(
            nn.Conv2d(2*features, features, 3, 1, 1),
            nn.GroupNorm(8, features),
            nn.ReLU(),
            nn.Conv2d(features, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t, c=None):
        # x: noisy image batch; t: normalized timestep; c: optional context
        # of shape (batch, context_features). Returns predicted noise with
        # the same shape as x.

        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)

        hiddenvec = self.to_vec(down2)

        # Missing context becomes an all-zero vector (unconditional pass).
        if c is None:
            c = torch.zeros(x.shape[0], self.context_features).to(x)

        # Context modulates multiplicatively, time shifts additively (below).
        cemb1 = self.contextembed1(c).view(-1, self.features*2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.features*2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.features, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.features, 1, 1)

        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1*up1 + temb1, down2)
        up3 = self.up2(cemb2*up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out
66
+
src/model_parts.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class ResidualDoubleConv(nn.Module):
    """Two 3x3 conv layers (each Conv -> BatchNorm -> GELU), optionally
    wrapped in a residual connection.

    Attribute names (`conv`, `shortcut`) are kept — they are state_dict keys.
    """

    def __init__(self, in_channels, out_channels, is_residual=False):
        super().__init__()

        # Main path; 3x3 convs with padding 1 preserve spatial size.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        self.is_same_channels = in_channels == out_channels
        self.is_residual = is_residual

        # A 1x1 projection is only needed when the skip path must change width.
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
            if is_residual and not self.is_same_channels
            else None
        )

    def forward(self, x):
        out = self.conv(x)
        if not self.is_residual:
            return out
        skip = x if self.is_same_channels else self.shortcut(x)
        # Divide by sqrt(2) to keep the variance of the summed paths stable.
        return (out + skip) / np.sqrt(2)
39
+
40
class UpSample(nn.Module):
    """Decoder stage: concatenate a skip connection, 2x transpose-conv
    upsample, then two ResidualDoubleConv refinements.

    The combined channels of `x` and `skip` must equal `in_channels`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            ResidualDoubleConv(out_channels, out_channels),
            ResidualDoubleConv(out_channels, out_channels),
        )

    def forward(self, x, skip):
        merged = torch.cat((x, skip), 1)  # fuse decoder input with encoder skip
        return self.conv(merged)
56
+
57
class DownSample(nn.Module):
    """Encoder stage: two ResidualDoubleConv blocks, then a 2x max-pool."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Residual handling lives inside ResidualDoubleConv itself.
        self.conv = nn.Sequential(
            ResidualDoubleConv(in_channels, out_channels),
            ResidualDoubleConv(out_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        return self.conv(x)
71
+
72
class EmbedFC(nn.Module):
    """Embed a flat `input_dim` vector into `embed_dim` via a two-layer MLP."""

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.input_dim = input_dim
        self.fc = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        # Accepts any shape whose total size is a multiple of input_dim.
        flat = x.view(-1, self.input_dim)
        return self.fc(flat)
weights/sprites_model_100.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4955a5d39b625a60f0bd825b15fd2e9fae44b4643211054fa40418395e9a3cd
3
+ size 94376581
weights/sprites_model_150.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db83ef40e458b5079179f4d681e8b18c69b08cd07fca817fc06d88dfbd231349
3
+ size 94376581
weights/sprites_model_199.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19bdd26e7e07e95bcc230298b26564828d6cc3fb2f4ef4fa379b5bdde5a12347
3
+ size 94376581
weights/sprites_model_50.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49cb4b55fb0371b40df841b1be6e2294ccdcf8c2489770b5aed08a0caa2aaf3d
3
+ size 94376422