Spaces:
Running
Running
- README.md +2 -2
- app.py +130 -0
- requirements.txt +6 -0
- src/__init__.py +3 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/custom_dataset.cpython-312.pyc +0 -0
- src/__pycache__/model.cpython-312.pyc +0 -0
- src/__pycache__/model_parts.cpython-312.pyc +0 -0
- src/custom_dataset.py +38 -0
- src/generators.py +7 -0
- src/model.py +66 -0
- src/model_parts.py +89 -0
- weights/sprites_model_100.pth +3 -0
- weights/sprites_model_150.pth +3 -0
- weights/sprites_model_199.pth +3 -0
- weights/sprites_model_50.pth +3 -0
README.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
---
|
| 2 |
title: Sprite Generation
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: green
|
| 6 |
sdk: gradio
|
|
@@ -10,4 +10,4 @@ pinned: false
|
|
| 10 |
short_description: generation of game characther sprites from trained weights
|
| 11 |
---
|
| 12 |
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
title: Sprite Generation
|
| 3 |
+
emoji: 👾
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: green
|
| 6 |
sdk: gradio
|
|
|
|
| 10 |
short_description: generation of game character sprites from trained weights
|
| 11 |
---
|
| 12 |
|
| 13 |
+
Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
|
app.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
from src import *
|
| 6 |
+
|
| 7 |
+
# device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# diffusion constants
# T forward-diffusion steps with a linear beta (noise-variance) schedule.
T = 3000
beta_end = 0.02
beta_start = 1e-3
# Linear interpolation from beta_start to beta_end over T+1 points, so
# betas[i] is the noise variance used at integer timestep i (0..T).
betas = (beta_end - beta_start) * torch.linspace(0, 1, T+1, device=device) + beta_start
alphas = 1 - betas
# Cumulative product of alphas, computed as exp(cumsum(log)) — same value
# as torch.cumprod but done in log-space.
alphas_hat = torch.cumsum(alphas.log(), dim=0).exp()
# Anchor the t=0 entry exactly at 1 (no noise at step 0).
alphas_hat[0] = 1
|
| 18 |
+
|
| 19 |
+
# -----------------------------
|
| 20 |
+
# Diffusion model wrapper
|
| 21 |
+
# -----------------------------
|
| 22 |
+
class Diffusion:
    """Wraps a trained ContextUnet checkpoint and exposes DDPM / DDIM sampling.

    Relies on the module-level schedule tensors ``betas``, ``alphas``,
    ``alphas_hat`` and the constants ``T`` and ``device`` defined above.
    """

    def __init__(self, weights_path):
        # Architecture hyper-parameters must match the training run that
        # produced the checkpoint — TODO confirm against training config.
        context_features = 5
        features = 256
        self.image_size = (16, 16)
        self.model = ContextUnet(
            in_channels=3,
            features=features,
            context_features=context_features,
            image_size=self.image_size
        ).to(device)
        self.model.load_state_dict(torch.load(weights_path, map_location=device))
        # Inference only: freeze batch-norm/dropout behavior.
        self.model.eval()

    def denoise_add_noise(self, x, t, pred_noise, z=None):
        """One DDPM reverse step at integer timestep ``t``: compute the
        posterior mean from the predicted noise, then re-add scaled noise.

        ``z`` may be 0 (final step) or a tensor like ``x``; default is fresh
        Gaussian noise.
        """
        if z is None:
            z = torch.randn_like(x)
        noise = betas.sqrt()[t] * z
        # Standard DDPM posterior mean; (1 - alphas[t]) == betas[t].
        mean = (x - pred_noise * ((1 - alphas[t]) / (1 - alphas_hat[t]).sqrt())) / alphas[t].sqrt()
        return mean + noise

    @torch.no_grad()
    def sample_ddpm(self, n_sample, context):
        """Full T-step DDPM ancestral sampling; returns (n_sample, 3, H, W)."""
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        for i in range(T, 0, -1):
            # Model expects the timestep as a normalized scalar in (0, 1],
            # broadcastable over the image — presumably matching training.
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            # No extra noise on the very last step (i == 1).
            z = torch.randn_like(samples) if i > 1 else 0
            eps = self.model(samples, t, c=context)
            samples = self.denoise_add_noise(samples, i, eps, z)
        return samples

    def denoise_ddim(self, x, t, t_prev, pred_noise):
        """Deterministic DDIM (eta=0) step from integer timestep ``t`` to ``t_prev``."""
        ab = alphas_hat[t]
        ab_prev = alphas_hat[t_prev]
        # Predicted clean image, rescaled toward the t_prev noise level.
        x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
        # Direction pointing back toward x_t.
        dir_xt = (1 - ab_prev).sqrt() * pred_noise
        return x0_pred + dir_xt

    @torch.no_grad()
    def sample_ddim(self, n_sample, context, n=20):
        """DDIM sampling with ~n strided steps instead of the full T steps."""
        samples = torch.randn(n_sample, 3, self.image_size[0], self.image_size[1]).to(device)
        step_size = T // n
        for i in range(T, 0, -step_size):
            t = torch.tensor([i / T])[:, None, None, None].to(device)
            eps = self.model(samples, t, c=context)
            # Clamp so the final jump never goes below timestep 1.
            prev_i = max(i - step_size, 1)
            samples = self.denoise_ddim(samples, i, prev_i, eps)
        return samples

    def generate(self, context, mode="ddim"):
        """Generate one sample for a context vector (list of floats).

        ``mode`` is "ddpm" for full sampling; any other value uses DDIM.
        """
        ctx = torch.tensor(context).float().unsqueeze(0).to(device)
        if mode == "ddpm":
            return self.sample_ddpm(1, ctx)
        else:
            return self.sample_ddim(1, ctx, n=25)
|
| 77 |
+
|
| 78 |
+
# -----------------------------
|
| 79 |
+
# Gradio Interface
|
| 80 |
+
# -----------------------------
|
| 81 |
+
# list weights in folder
weights_folder = "weights"
os.makedirs(weights_folder, exist_ok=True)
# Sort the checkpoint names: os.listdir order is arbitrary and
# filesystem-dependent, which made the dropdown order nondeterministic.
available_weights = sorted(f for f in os.listdir(weights_folder) if f.endswith(".pth"))

# Used by run_inference for nearest-neighbor upscaling.
import torch.nn.functional as F
|
| 87 |
+
|
| 88 |
+
def run_inference(weights_name, mode, context_choice):
    """Generate one sprite for the Gradio UI.

    Args:
        weights_name: filename of a .pth checkpoint inside ``weights_folder``.
        mode: "ddpm" for the full sampling loop; anything else uses DDIM.
        context_choice: one of the keys of ``context_map`` below.

    Returns:
        A (256, 256, 3) float numpy array scaled to [0, 1] for gr.Image.
    """
    weights_path = os.path.join(weights_folder, weights_name)
    # NOTE: the model is rebuilt and the checkpoint reloaded on every click;
    # fine for a demo, but a cache keyed on weights_name would be faster.
    diffusion = Diffusion(weights_path)

    # One-hot vectors matching the 5 context features the model was trained with.
    context_map = {
        "hero": [1,0,0,0,0],
        "non-hero": [0,1,0,0,0],
        "food": [0,0,1,0,0],
        "spell": [0,0,0,1,0],
        "side-facing": [0,0,0,0,1],
    }
    context = context_map[context_choice]

    samples = diffusion.generate(context=context, mode=mode)

    # take the [0]th sample
    img = samples[0].unsqueeze(0)  # shape (1, 3, 16, 16)

    # upscale to 256×256 (use 'nearest' to keep blocky pixel-art style)
    img_up = F.interpolate(img, size=(256, 256), mode="nearest")

    img_np = img_up[0].detach().cpu().numpy()
    # Min-max normalize to [0, 1]. Guard the degenerate constant-image case,
    # which previously divided by zero and produced NaNs in the output.
    lo, hi = img_np.min(), img_np.max()
    if hi > lo:
        img_np = (img_np - lo) / (hi - lo)
    else:
        img_np = np.zeros_like(img_np)
    img_np = np.transpose(img_np, (1,2,0))  # (H,W,C) for display

    return img_np
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Build the Gradio UI: checkpoint / mode / context selectors, one button,
# one image output.
with gr.Blocks() as demo:
    gr.Markdown("## Sprite Diffusion Generator 👾")
    # Fixed user-facing typo: "algorihm" -> "algorithm".
    gr.Markdown("Note: DDPM algorithm may take around 1-2 minutes.")

    with gr.Row():
        weights_name = gr.Dropdown(available_weights, label="Select weights file")
        # DDIM is the default because it samples in ~25 strided steps.
        mode = gr.Radio(["ddpm", "ddim"], value="ddim", label="Generation Mode")
        context_choice = gr.Dropdown(["hero","non-hero","food","spell","side-facing"], value="hero", label="Context")

    run_btn = gr.Button("Generate")
    output = gr.Image(label="Generated Image")

    run_btn.click(run_inference, inputs=[weights_name, mode, context_choice], outputs=output)

demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
numpy
|
| 4 |
+
matplotlib
|
| 5 |
+
gradio
|
| 6 |
+
# (removed invalid entry "torch.nn" — torch.nn is a submodule of torch, not a pip package)
|
src/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .custom_dataset import SpritesDataset, sprites_transform
|
| 2 |
+
from .model import ContextUnet
|
| 3 |
+
from .model import *
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (334 Bytes). View file
|
|
|
src/__pycache__/custom_dataset.cpython-312.pyc
ADDED
|
Binary file (2.39 kB). View file
|
|
|
src/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (4.15 kB). View file
|
|
|
src/__pycache__/model_parts.cpython-312.pyc
ADDED
|
Binary file (4.48 kB). View file
|
|
|
src/custom_dataset.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
|
| 6 |
+
class SpritesDataset(Dataset):
    """Sprite image/label pairs backed by two .npy arrays on disk.

    Args:
        images_path: path of the .npy file holding the image array.
        labels_path: path of the .npy file holding the label array.
        transform: callable applied to each raw image before returning it.
        null_context: when True every item is given label 0 (context dropped).
    """

    def __init__(self, images_path, labels_path, transform, null_context):
        self.images = np.load(images_path, allow_pickle=False)
        self.labels = np.load(labels_path, allow_pickle=False)
        # Cache the raw array shapes so __getshape__ is cheap.
        self.images_shape, self.labels_shape = self.images.shape, self.labels.shape
        self.transform = transform
        self.null_context = null_context

    def __len__(self):
        # One item per image along the leading axis.
        return self.images.shape[0]

    def __getitem__(self, idx):
        raw_label = 0 if self.null_context else self.labels[idx]
        label = torch.tensor(raw_label).to(torch.int64)
        return self.transform(self.images[idx]), label

    def __getshape__(self):
        # Non-standard helper kept for backward compatibility with callers.
        return self.images_shape, self.labels_shape
|
| 32 |
+
|
| 33 |
+
# Default transform for sprite images: convert HxWxC uint8/float arrays to
# CxHxW tensors in [0, 1], then normalize each RGB channel with mean 0.5 and
# std 0.5, i.e. rescale pixel values to [-1, 1].
sprites_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),
                         (0.5,0.5,0.5))
])
|
| 38 |
+
|
src/generators.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def denoise_add_noise(x, t, pred_noise, z=None):
    """One DDPM reverse step: remove the predicted noise, re-add fresh noise.

    NOTE(review): this module has no imports, and ``torch``, ``betas``,
    ``alphas`` and ``alphas_hat`` are unresolved free names here — calling
    this function as-is raises NameError. It duplicates
    ``Diffusion.denoise_add_noise`` in app.py, where those schedule tensors
    are defined; confirm whether this copy is dead code or needs wiring up.

    Args:
        x: noisy sample at integer timestep ``t``.
        t: integer timestep used to index the schedule tensors.
        pred_noise: the model's noise prediction for ``x`` at ``t``.
        z: optional noise tensor (or 0 for the final step); defaults to
            fresh Gaussian noise shaped like ``x``.
    """
    if z is None:
        z = torch.randn_like(x)
    noise = betas.sqrt()[t] * z
    # DDPM posterior mean; (1 - alphas[t]) == betas[t].
    mean = (x - pred_noise * ((1 - alphas[t]) / (1 - alphas_hat[t]).sqrt())) / alphas[t].sqrt()
    return mean + noise
|
| 7 |
+
|
src/model.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.model_parts import ResidualDoubleConv, UpSample, DownSample, EmbedFC
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class ContextUnet(nn.Module):
    """Context-conditioned U-Net that predicts the noise in a sprite image.

    Submodule attribute names and their registration order are load-bearing:
    they define the checkpoint state-dict keys, so they must not be renamed.
    """

    def __init__(self, in_channels, features=256, context_features=10, image_size=(16, 16)):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.features = features
        self.context_features = context_features
        self.height, self.width = image_size

        # Lift the input image to `features` channels at full resolution.
        self.init_conv = ResidualDoubleConv(in_channels, features, is_residual=True)

        # Two downsampling stages: H -> H/2 -> H/4 (each DownSample pools by 2).
        self.down1 = DownSample(features, features)
        self.down2 = DownSample(features, 2*features)

        # Collapse the H/4 x W/4 map to a 1x1 bottleneck vector.
        # NOTE(review): the fixed pool size 4 assumes a 16x16 input — confirm.
        self.to_vec = nn.Sequential(
            nn.AvgPool2d((4)),
            nn.GELU(),
        )

        # Embeddings of the scalar timestep and the context vector, one pair
        # per decoder scale.
        self.timeembed1 = EmbedFC(1, 2*features)
        self.timeembed2 = EmbedFC(1, 1*features)
        self.contextembed1 = EmbedFC(context_features, 2*features)
        self.contextembed2 = EmbedFC(context_features, 1*features)

        # Re-inflate the 1x1 bottleneck back to an H/4 x W/4 map.
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2*features, 2*features, self.height//4, self.height//4),
            nn.GroupNorm(8, 2*features),
            nn.ReLU(),
        )
        self.up1 = UpSample(4*features, features)
        self.up2 = UpSample(2*features, features)

        # Output head: concatenated decoder + stem features back to image channels.
        self.out = nn.Sequential(
            nn.Conv2d(2*features, features, 3, 1, 1),
            nn.GroupNorm(8, features),
            nn.ReLU(),
            nn.Conv2d(features, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t, c=None):
        """Predict the noise in ``x`` at (normalized) timestep ``t``.

        Args:
            x: input batch of images, (B, in_channels, H, W).
            t: timestep tensor broadcastable to the batch — the samplers in
               app.py pass a normalized scalar in (0, 1].
            c: optional context batch, (B, context_features); None means
               unconditional (zero context).
        """
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)

        hiddenvec = self.to_vec(down2)

        # Unconditional generation: all-zero context vector on x's device/dtype.
        if c is None:
            c = torch.zeros(x.shape[0], self.context_features).to(x)

        # Per-scale embeddings reshaped to broadcast over spatial dims.
        cemb1 = self.contextembed1(c).view(-1, self.features*2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.features*2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.features, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.features, 1, 1)

        # Conditioning: multiply by the context embedding, add the time
        # embedding, then fuse with the matching encoder skip connection.
        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1*up1 + temb1, down2)
        up3 = self.up2(cemb2*up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out
|
| 66 |
+
|
src/model_parts.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
class ResidualDoubleConv(nn.Module):
|
| 6 |
+
def __init__(self, in_channels, out_channels, is_residual=False):
|
| 7 |
+
super().__init__()
|
| 8 |
+
|
| 9 |
+
self.conv = nn.Sequential(
|
| 10 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
|
| 11 |
+
nn.BatchNorm2d(out_channels),
|
| 12 |
+
nn.GELU(),
|
| 13 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
|
| 14 |
+
nn.BatchNorm2d(out_channels),
|
| 15 |
+
nn.GELU(),
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
self.is_same_channels = in_channels == out_channels
|
| 19 |
+
self.is_residual = is_residual
|
| 20 |
+
|
| 21 |
+
if is_residual and not self.is_same_channels:
|
| 22 |
+
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
|
| 23 |
+
else:
|
| 24 |
+
self.shortcut = None
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
|
| 28 |
+
out = self.conv(x)
|
| 29 |
+
|
| 30 |
+
if not self.is_residual:
|
| 31 |
+
return out
|
| 32 |
+
|
| 33 |
+
if self.is_same_channels:
|
| 34 |
+
out += x
|
| 35 |
+
else:
|
| 36 |
+
out += self.shortcut(x)
|
| 37 |
+
|
| 38 |
+
return out / np.sqrt(2) # Normalizing residual flow
|
| 39 |
+
|
| 40 |
+
class UpSample(nn.Module):
    """Decoder stage: concatenate with a skip connection, upsample 2x, refine."""

    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()

        # Transposed conv doubles H and W; two residual double-convs refine.
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            ResidualDoubleConv(out_channels, out_channels),
            ResidualDoubleConv(out_channels, out_channels),
        )

    def forward(self, x, skip):
        # Channel-wise concat, so in_channels must equal x + skip channels.
        merged = torch.cat((x, skip), 1)
        return self.conv(merged)
|
| 56 |
+
|
| 57 |
+
class DownSample(nn.Module):
    """Encoder stage: two residual double-convs followed by 2x max-pooling."""

    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()

        # Residual connections are handled inside ResidualDoubleConv itself.
        self.conv = nn.Sequential(
            ResidualDoubleConv(in_channels, out_channels),
            ResidualDoubleConv(out_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        pooled = self.conv(x)
        return pooled
|
| 71 |
+
|
| 72 |
+
class EmbedFC(nn.Module):
    """Two-layer MLP that embeds a flat ``input_dim`` vector into ``embed_dim``."""

    def __init__(self, input_dim, embed_dim):
        super(EmbedFC, self).__init__()

        self.input_dim = input_dim

        self.fc = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        # Flatten whatever shape arrives into (batch, input_dim) rows.
        flat = x.view(-1, self.input_dim)
        return self.fc(flat)
|
weights/sprites_model_100.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f4955a5d39b625a60f0bd825b15fd2e9fae44b4643211054fa40418395e9a3cd
|
| 3 |
+
size 94376581
|
weights/sprites_model_150.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:db83ef40e458b5079179f4d681e8b18c69b08cd07fca817fc06d88dfbd231349
|
| 3 |
+
size 94376581
|
weights/sprites_model_199.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:19bdd26e7e07e95bcc230298b26564828d6cc3fb2f4ef4fa379b5bdde5a12347
|
| 3 |
+
size 94376581
|
weights/sprites_model_50.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:49cb4b55fb0371b40df841b1be6e2294ccdcf8c2489770b5aed08a0caa2aaf3d
|
| 3 |
+
size 94376422
|