Yash Nagraj committed on
Commit ·
35839a1
1
Parent(s): 2c4de69
Add train files
Browse files
model.ipynb
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 20,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import torch\n",
|
| 10 |
+
"import torch.nn as nn\n",
|
| 11 |
+
"import numpy as np\n",
|
| 12 |
+
"from torchvision.utils import save_image, make_grid\n",
|
| 13 |
+
"import matplotlib.pyplot as plt\n",
|
| 14 |
+
"from matplotlib.animation import FuncAnimation, PillowWriter\n",
|
| 15 |
+
"import os\n",
|
| 16 |
+
"import torchvision.transforms as transforms\n",
|
| 17 |
+
"from torch.utils.data import Dataset\n",
|
| 18 |
+
"from PIL import Image\n",
|
| 19 |
+
"from torch.utils.data import DataLoader\n",
|
| 20 |
+
"from tqdm.auto import tqdm\n",
|
| 21 |
+
"import torch.nn.functional as F"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": 3,
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"class ResidualBlock(nn.Module):\n",
|
| 31 |
+
" def __init__(self, in_channels: int, out_channels: int,is_res: bool = False) -> None:\n",
|
| 32 |
+
" super(ResidualBlock,self).__init__()\n",
|
| 33 |
+
"\n",
|
| 34 |
+
" self.same_channesls = in_channels == out_channels\n",
|
| 35 |
+
"\n",
|
| 36 |
+
" self.is_res = is_res\n",
|
| 37 |
+
"\n",
|
| 38 |
+
" self.conv1 = nn.Sequential(\n",
|
| 39 |
+
" nn.Conv2d(in_channels,out_channels,3,1,1),\n",
|
| 40 |
+
" nn.BatchNorm2d(out_channels),\n",
|
| 41 |
+
" nn.GELU(),\n",
|
| 42 |
+
" )\n",
|
| 43 |
+
"\n",
|
| 44 |
+
" self.conv2 = nn.Sequential(\n",
|
| 45 |
+
" nn.Conv2d(out_channels,out_channels,3,1,1),\n",
|
| 46 |
+
" nn.BatchNorm2d(out_channels),\n",
|
| 47 |
+
" nn.GELU(),\n",
|
| 48 |
+
" )\n",
|
| 49 |
+
"\n",
|
| 50 |
+
" def forward(self,x): \n",
|
| 51 |
+
" if self.is_res:\n",
|
| 52 |
+
" x1 = self.conv1(x)\n",
|
| 53 |
+
"\n",
|
| 54 |
+
" x2 = self.conv2(x1)\n",
|
| 55 |
+
"\n",
|
| 56 |
+
" if self.same_channesls:\n",
|
| 57 |
+
" out = x1 + x2\n",
|
| 58 |
+
" else:\n",
|
| 59 |
+
" shortcut = nn.Conv2d(x.shape[1],x2.shape[1],1,1,0).to(x.device)\n",
|
| 60 |
+
" out = shortcut(x) + x2\n",
|
| 61 |
+
"\n",
|
| 62 |
+
" return out / 1.414\n",
|
| 63 |
+
" \n",
|
| 64 |
+
" else:\n",
|
| 65 |
+
" x1 = self.conv1(x)\n",
|
| 66 |
+
" x2 = self.conv2(x1)\n",
|
| 67 |
+
" return x2\n"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"cell_type": "code",
|
| 72 |
+
"execution_count": 4,
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"outputs": [],
|
| 75 |
+
"source": [
|
| 76 |
+
"class UnetUp(nn.Module):\n",
|
| 77 |
+
" def __init__(self, in_channels, out_channels) -> None:\n",
|
| 78 |
+
" super(UnetUp,self).__init__()\n",
|
| 79 |
+
"\n",
|
| 80 |
+
" self.model = nn.Sequential(\n",
|
| 81 |
+
" nn.ConvTranspose2d(in_channels,out_channels,2,2),\n",
|
| 82 |
+
" ResidualBlock(out_channels,out_channels),\n",
|
| 83 |
+
" ResidualBlock(out_channels,out_channels),\n",
|
| 84 |
+
" )\n",
|
| 85 |
+
"\n",
|
| 86 |
+
" def forward(self, x, skip):\n",
|
| 87 |
+
" x = torch.cat([x,skip],1)\n",
|
| 88 |
+
"\n",
|
| 89 |
+
" x = self.model(x)\n",
|
| 90 |
+
" return x\n",
|
| 91 |
+
" \n",
|
| 92 |
+
"class UnetDown(nn.Module):\n",
|
| 93 |
+
" def __init__(self, input_channels, out_channels) -> None:\n",
|
| 94 |
+
" super(UnetDown,self).__init__()\n",
|
| 95 |
+
"\n",
|
| 96 |
+
" self.model = nn.Sequential(\n",
|
| 97 |
+
" ResidualBlock(input_channels,out_channels),\n",
|
| 98 |
+
" ResidualBlock(out_channels,out_channels),\n",
|
| 99 |
+
" nn.MaxPool2d(2)\n",
|
| 100 |
+
" )\n",
|
| 101 |
+
"\n",
|
| 102 |
+
" def forward(self,x):\n",
|
| 103 |
+
" return self.model(x)\n",
|
| 104 |
+
" \n",
|
| 105 |
+
"\n",
|
| 106 |
+
"class EmbedFC(nn.Module):\n",
|
| 107 |
+
" def __init__(self, input_dim,embed_dm) -> None:\n",
|
| 108 |
+
" super(EmbedFC,self).__init__()\n",
|
| 109 |
+
"\n",
|
| 110 |
+
" self.input_dim = input_dim\n",
|
| 111 |
+
" \n",
|
| 112 |
+
" self.model = nn.Sequential(\n",
|
| 113 |
+
" nn.Linear(input_dim,embed_dm),\n",
|
| 114 |
+
" nn.GELU(),\n",
|
| 115 |
+
" nn.Linear(embed_dm,embed_dm),\n",
|
| 116 |
+
" )\n",
|
| 117 |
+
"\n",
|
| 118 |
+
" def forward(self,x):\n",
|
| 119 |
+
" x = x.view(-1,self.input_dim)\n",
|
| 120 |
+
" return self.model(x)\n"
|
| 121 |
+
]
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"cell_type": "code",
|
| 125 |
+
"execution_count": 5,
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"outputs": [],
|
| 128 |
+
"source": [
|
| 129 |
+
"def unorm(x):\n",
|
| 130 |
+
" # unity norm. results in range of [0,1]\n",
|
| 131 |
+
" # assume x (h,w,3)\n",
|
| 132 |
+
" xmax = x.max((0,1))\n",
|
| 133 |
+
" xmin = x.min((0,1))\n",
|
| 134 |
+
" return(x - xmin)/(xmax - xmin)\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"def norm_all(store, n_t, n_s):\n",
|
| 137 |
+
" # runs unity norm on all timesteps of all samples\n",
|
| 138 |
+
" nstore = np.zeros_like(store)\n",
|
| 139 |
+
" for t in range(n_t):\n",
|
| 140 |
+
" for s in range(n_s):\n",
|
| 141 |
+
" nstore[t,s] = unorm(store[t,s])\n",
|
| 142 |
+
" return nstore\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"def norm_torch(x_all):\n",
|
| 145 |
+
" # runs unity norm on all timesteps of all samples\n",
|
| 146 |
+
" # input is (n_samples, 3,h,w), the torch image format\n",
|
| 147 |
+
" x = x_all.cpu().numpy()\n",
|
| 148 |
+
" xmax = x.max((2,3))\n",
|
| 149 |
+
" xmin = x.min((2,3))\n",
|
| 150 |
+
" xmax = np.expand_dims(xmax,(2,3)) \n",
|
| 151 |
+
" xmin = np.expand_dims(xmin,(2,3))\n",
|
| 152 |
+
" nstore = (x - xmin)/(xmax - xmin)\n",
|
| 153 |
+
" return torch.from_numpy(nstore)\n"
|
| 154 |
+
]
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"cell_type": "code",
|
| 158 |
+
"execution_count": 6,
|
| 159 |
+
"metadata": {},
|
| 160 |
+
"outputs": [],
|
| 161 |
+
"source": [
|
| 162 |
+
"def plot_grid(x,n_sample,n_rows,save_dir,w):\n",
|
| 163 |
+
" # x:(n_sample, 3, h, w)\n",
|
| 164 |
+
" ncols = n_sample//n_rows\n",
|
| 165 |
+
" grid = make_grid(norm_torch(x), nrow=ncols) # curiously, nrow is number of columns.. or number of items in the row.\n",
|
| 166 |
+
" save_image(grid, save_dir + f\"run_image_w{w}.png\")\n",
|
| 167 |
+
" print('saved image at ' + save_dir + f\"run_image_w{w}.png\")\n",
|
| 168 |
+
" return grid\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"def plot_sample(x_gen_store,n_sample,nrows,save_dir, fn, w, save=False):\n",
|
| 171 |
+
" ncols = n_sample//nrows\n",
|
| 172 |
+
" sx_gen_store = np.moveaxis(x_gen_store,2,4) \n",
|
| 173 |
+
" nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample) \n",
|
| 174 |
+
" fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,figsize=(ncols,nrows))\n",
|
| 175 |
+
" def animate_diff(i, store):\n",
|
| 176 |
+
" print(f'gif animating frame {i} of {store.shape[0]}', end='\\r')\n",
|
| 177 |
+
" plots = []\n",
|
| 178 |
+
" for row in range(nrows):\n",
|
| 179 |
+
" for col in range(ncols):\n",
|
| 180 |
+
" axs[row, col].clear()\n",
|
| 181 |
+
" axs[row, col].set_xticks([])\n",
|
| 182 |
+
" axs[row, col].set_yticks([])\n",
|
| 183 |
+
" plots.append(axs[row, col].imshow(store[i,(row*ncols)+col]))\n",
|
| 184 |
+
" return plots\n",
|
| 185 |
+
" ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200, blit=False, repeat=True, frames=nsx_gen_store.shape[0]) \n",
|
| 186 |
+
" plt.close()\n",
|
| 187 |
+
" if save:\n",
|
| 188 |
+
" ani.save(save_dir + f\"{fn}_w{w}.gif\", dpi=100, writer=PillowWriter(fps=5))\n",
|
| 189 |
+
" print('saved gif at ' + save_dir + f\"{fn}_w{w}.gif\")\n",
|
| 190 |
+
" return ani\n"
|
| 191 |
+
]
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"cell_type": "code",
|
| 195 |
+
"execution_count": 7,
|
| 196 |
+
"metadata": {},
|
| 197 |
+
"outputs": [],
|
| 198 |
+
"source": [
|
| 199 |
+
"transform = transforms.Compose([\n",
|
| 200 |
+
" transforms.ToTensor(), # from [0,255] to range [0.0,1.0]\n",
|
| 201 |
+
" transforms.Normalize((0.5,), (0.5,)) # range [-1,1]\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"])\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"class CustomDataset(Dataset):\n",
|
| 206 |
+
" def __init__(self, sfilename, lfilename, transform, null_context=False):\n",
|
| 207 |
+
" self.sprites = np.load(sfilename)\n",
|
| 208 |
+
" self.slabels = np.load(lfilename)\n",
|
| 209 |
+
" print(f\"sprite shape: {self.sprites.shape}\")\n",
|
| 210 |
+
" print(f\"labels shape: {self.slabels.shape}\")\n",
|
| 211 |
+
" self.transform = transform\n",
|
| 212 |
+
" self.null_context = null_context\n",
|
| 213 |
+
" self.sprites_shape = self.sprites.shape\n",
|
| 214 |
+
" self.slabel_shape = self.slabels.shape\n",
|
| 215 |
+
" \n",
|
| 216 |
+
" def __len__(self):\n",
|
| 217 |
+
" return len(self.sprites)\n",
|
| 218 |
+
" \n",
|
| 219 |
+
" def __getitem__(self, idx):\n",
|
| 220 |
+
" if self.transform:\n",
|
| 221 |
+
" image = self.transform(self.sprites[idx])\n",
|
| 222 |
+
" if self.null_context:\n",
|
| 223 |
+
" label = torch.tensor(0).to(torch.int64)\n",
|
| 224 |
+
" else:\n",
|
| 225 |
+
" label = torch.tensor(self.slabels[idx]).to(torch.int64)\n",
|
| 226 |
+
" return (image, label)\n"
|
| 227 |
+
]
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"cell_type": "code",
|
| 231 |
+
"execution_count": 8,
|
| 232 |
+
"metadata": {},
|
| 233 |
+
"outputs": [],
|
| 234 |
+
"source": [
|
| 235 |
+
"class ContextUnet(nn.Module):\n",
|
| 236 |
+
" def __init__(self,in_channels, n_feat = 256,n_cfeat = 10, height = 28) -> None:\n",
|
| 237 |
+
" super(ContextUnet,self).__init__()\n",
|
| 238 |
+
"\n",
|
| 239 |
+
" self.in_channels = in_channels\n",
|
| 240 |
+
" self.n_feat = n_feat\n",
|
| 241 |
+
" self.n_cfeat = n_cfeat\n",
|
| 242 |
+
" self.h = height\n",
|
| 243 |
+
"\n",
|
| 244 |
+
" self.init_conv = ResidualBlock(in_channels,n_feat,is_res=True)\n",
|
| 245 |
+
"\n",
|
| 246 |
+
" self.down1 = UnetDown(n_feat,n_feat)\n",
|
| 247 |
+
" self.down2 = UnetDown(n_feat,n_feat * 2)\n",
|
| 248 |
+
"\n",
|
| 249 |
+
" self.to_vec = nn.Sequential(nn.AvgPool2d((4)),nn.GELU())\n",
|
| 250 |
+
"\n",
|
| 251 |
+
" self.timeembed1 = EmbedFC(1, 2 *n_feat)\n",
|
| 252 |
+
" self.timeembed2 = EmbedFC(1,n_feat)\n",
|
| 253 |
+
" self.contextembed1 = EmbedFC(n_cfeat,2 * n_feat)\n",
|
| 254 |
+
" self.contextembed2 = EmbedFC(n_cfeat,n_feat)\n",
|
| 255 |
+
"\n",
|
| 256 |
+
" self.up0 = nn.Sequential(\n",
|
| 257 |
+
" nn.ConvTranspose2d(2 * n_feat,2*n_feat,self.h // 4,self.h // 4),\n",
|
| 258 |
+
" nn.GroupNorm(8, 2*n_feat),\n",
|
| 259 |
+
" nn.ReLU(),\n",
|
| 260 |
+
" )\n",
|
| 261 |
+
"\n",
|
| 262 |
+
" self.up1 = UnetUp(4 * n_feat,n_feat)\n",
|
| 263 |
+
" self.up2 = UnetUp(2 * n_feat,n_feat)\n",
|
| 264 |
+
"\n",
|
| 265 |
+
" self.out = nn.Sequential(\n",
|
| 266 |
+
" nn.Conv2d(2 * n_feat, n_feat,3,1,1),\n",
|
| 267 |
+
" nn.GroupNorm(8,n_feat),\n",
|
| 268 |
+
" nn.ReLU(),\n",
|
| 269 |
+
" nn.Conv2d(n_feat,self.in_channels,3,1,1)\n",
|
| 270 |
+
" )\n",
|
| 271 |
+
"\n",
|
| 272 |
+
" def forward(self,x,t,c=None):\n",
|
| 273 |
+
" x = self.init_conv(x)\n",
|
| 274 |
+
"\n",
|
| 275 |
+
" down1 = self.down1(x)\n",
|
| 276 |
+
" down2 = self.down2(down1)\n",
|
| 277 |
+
"\n",
|
| 278 |
+
" hidden_vec = self.to_vec(down2)\n",
|
| 279 |
+
"\n",
|
| 280 |
+
" if c is None:\n",
|
| 281 |
+
" c = torch.zeros(x.shape[0],self.n_cfeat).to(x)\n",
|
| 282 |
+
" \n",
|
| 283 |
+
" cemb1 = self.contextembed1(c).view(-1,self.n_cfeat*2,1,1)\n",
|
| 284 |
+
" temb1 = self.timeembed1(t).view(-1,self.n_cfeat * 2,1,1)\n",
|
| 285 |
+
" cemb2 = self.contextembed2(c).view(-1,self.n_cfeat,1,1)\n",
|
| 286 |
+
" temb2 = self.timeembed2(t).view(-1,self.n_cfeat,1,1)\n",
|
| 287 |
+
"\n",
|
| 288 |
+
" up0 = self.up0(hidden_vec)\n",
|
| 289 |
+
" up1 =self.up1(up0*cemb1 + temb1,down2)\n",
|
| 290 |
+
" up2 = self.up2(up1*cemb2+temb2,down1)\n",
|
| 291 |
+
"\n",
|
| 292 |
+
" out = self.out(torch.cat((up2,x),1))\n",
|
| 293 |
+
"\n",
|
| 294 |
+
" return out"
|
| 295 |
+
]
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"cell_type": "code",
|
| 299 |
+
"execution_count": 14,
|
| 300 |
+
"metadata": {},
|
| 301 |
+
"outputs": [],
|
| 302 |
+
"source": [
|
| 303 |
+
"# Hyperparameters\n",
|
| 304 |
+
"\n",
|
| 305 |
+
"timesteps = 500\n",
|
| 306 |
+
"beta1 = 1e-4\n",
|
| 307 |
+
"beta2 = 0.02\n",
|
| 308 |
+
"\n",
|
| 309 |
+
"device = \"cuda\"\n",
|
| 310 |
+
"n_feat = 64\n",
|
| 311 |
+
"n_cfeat = 5\n",
|
| 312 |
+
"height = 16\n",
|
| 313 |
+
"save_dir=\"./checkpoints\"\n",
|
| 314 |
+
"\n",
|
| 315 |
+
"batch_size = 100\n",
|
| 316 |
+
"n_epoch = 40\n",
|
| 317 |
+
"lrate = 1e-3"
|
| 318 |
+
]
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"cell_type": "code",
|
| 322 |
+
"execution_count": 12,
|
| 323 |
+
"metadata": {},
|
| 324 |
+
"outputs": [
|
| 325 |
+
{
|
| 326 |
+
"name": "stdout",
|
| 327 |
+
"output_type": "stream",
|
| 328 |
+
"text": [
|
| 329 |
+
"torch.Size([501])\n",
|
| 330 |
+
"torch.Size([501])\n",
|
| 331 |
+
"torch.Size([501])\n"
|
| 332 |
+
]
|
| 333 |
+
}
|
| 334 |
+
],
|
| 335 |
+
"source": [
|
| 336 |
+
"b_t = (beta2 - beta1) * torch.linspace(0,1,timesteps+1,device=device) + beta1\n",
|
| 337 |
+
"a_t = 1 - b_t\n",
|
| 338 |
+
"a_bt = torch.cumsum(a_t.log(),0).exp()\n",
|
| 339 |
+
"a_bt[0] = 1"
|
| 340 |
+
]
|
| 341 |
+
},
|
| 342 |
+
{
|
| 343 |
+
"cell_type": "code",
|
| 344 |
+
"execution_count": null,
|
| 345 |
+
"metadata": {},
|
| 346 |
+
"outputs": [],
|
| 347 |
+
"source": [
|
| 348 |
+
"dataset = CustomDataset(\"./sprites_1788_16x16.npy\", \"./sprite_labels_nc_1788_16x16.npy\", transform, null_context=False)\n",
|
| 349 |
+
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)"
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"cell_type": "code",
|
| 354 |
+
"execution_count": 17,
|
| 355 |
+
"metadata": {},
|
| 356 |
+
"outputs": [],
|
| 357 |
+
"source": [
|
| 358 |
+
"nn_model = ContextUnet(3,n_feat,n_cfeat,height)\n",
|
| 359 |
+
"optim = torch.optim.Adam(nn_model.parameters(),lrate)"
|
| 360 |
+
]
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"cell_type": "code",
|
| 364 |
+
"execution_count": 16,
|
| 365 |
+
"metadata": {},
|
| 366 |
+
"outputs": [],
|
| 367 |
+
"source": [
|
| 368 |
+
"def perturb_input(x, t, noise):\n",
|
| 369 |
+
" return a_bt.sqrt()[t, None, None, None] * x + (1 - a_bt[t, None, None, None]) * noise"
|
| 370 |
+
]
|
| 371 |
+
},
|
| 372 |
+
{
|
| 373 |
+
"cell_type": "code",
|
| 374 |
+
"execution_count": null,
|
| 375 |
+
"metadata": {},
|
| 376 |
+
"outputs": [],
|
| 377 |
+
"source": [
|
| 378 |
+
"nn_model.train()\n",
|
| 379 |
+
"\n",
|
| 380 |
+
"for epoch in range(n_epoch):\n",
|
| 381 |
+
"\n",
|
| 382 |
+
" optim.param_groups[0]['lr'] = lrate * (1-epoch/n_epoch)\n",
|
| 383 |
+
" for x,_ in tqdm(dataloader):\n",
|
| 384 |
+
" optim.zero_grad()\n",
|
| 385 |
+
"\n",
|
| 386 |
+
" x = x.to(device)\n",
|
| 387 |
+
"\n",
|
| 388 |
+
" t = torch.randint(1,timesteps+1,x.shape[0]).to(device)\n",
|
| 389 |
+
" noise = torch.randn_like(x)\n",
|
| 390 |
+
" x_pert = perturb_input(x,t,noise)\n",
|
| 391 |
+
"\n",
|
| 392 |
+
" pred = nn_model(x_pert,t / timesteps)\n",
|
| 393 |
+
"\n",
|
| 394 |
+
" loss = F.mse_loss(pred,noise)\n",
|
| 395 |
+
" loss.backward()\n",
|
| 396 |
+
" optim.step()\n",
|
| 397 |
+
"\n",
|
| 398 |
+
" if epoch % 1 == 0 and epoch >0:\n",
|
| 399 |
+
" if not os.path.exists(save_dir):\n",
|
| 400 |
+
" os.mkdir(save_dir)\n",
|
| 401 |
+
" torch.save(nn_model,save_dir + f\"model_Epoch{epoch}.pth\")\n",
|
| 402 |
+
" print(\"Saved model\")\n"
|
| 403 |
+
]
|
| 404 |
+
},
|
| 405 |
+
{
|
| 406 |
+
"cell_type": "code",
|
| 407 |
+
"execution_count": 22,
|
| 408 |
+
"metadata": {},
|
| 409 |
+
"outputs": [],
|
| 410 |
+
"source": [
|
| 411 |
+
"def denoise_add_noise(x,t,pred_noise,z=None):\n",
|
| 412 |
+
" if z is None:\n",
|
| 413 |
+
" z = torch.randn_like(x)\n",
|
| 414 |
+
" noise = b_t.sqrt()[t]\n",
|
| 415 |
+
" mean = x - (pred_noise * ((1-a_t[t]) / (1-a_bt[t]).sqrt())) / a_t[t].sqrt()\n",
|
| 416 |
+
" return mean + noise\n"
|
| 417 |
+
]
|
| 418 |
+
},
|
| 419 |
+
{
|
| 420 |
+
"cell_type": "code",
|
| 421 |
+
"execution_count": null,
|
| 422 |
+
"metadata": {},
|
| 423 |
+
"outputs": [],
|
| 424 |
+
"source": [
|
| 425 |
+
"@torch.no_grad()\n",
|
| 426 |
+
"def sample_ddpm(n_sample,save_rate=20):\n",
|
| 427 |
+
" # x_T ~ N(0, 1), sample initial noise\n",
|
| 428 |
+
" samples = torch.randn(n_sample,3,height,height)\n",
|
| 429 |
+
"\n",
|
| 430 |
+
" intermediate = []\n",
|
| 431 |
+
" for i in range(timesteps,0,-1):\n",
|
| 432 |
+
" print(f\"Sampling timestep: {i}\")\n",
|
| 433 |
+
"\n",
|
| 434 |
+
" t = torch.tensor([i/timesteps])[:,None,None,None].to(device)\n",
|
| 435 |
+
"\n",
|
| 436 |
+
" z = torch.randn_like(samples)\n",
|
| 437 |
+
"\n",
|
| 438 |
+
" pred = nn_model(samples,t)\n",
|
| 439 |
+
" samples = denoise_add_noise(samples,t,pred,z)\n",
|
| 440 |
+
" if i % save_rate ==0 or i==timesteps or i<8:\n",
|
| 441 |
+
" intermediate.append(samples.detach().cpu().numpy())\n",
|
| 442 |
+
"\n",
|
| 443 |
+
" intermediate = np.stack(intermediate)\n",
|
| 444 |
+
" return samples,intermediate\n"
|
| 445 |
+
]
|
| 446 |
+
},
|
| 447 |
+
{
|
| 448 |
+
"cell_type": "code",
|
| 449 |
+
"execution_count": null,
|
| 450 |
+
"metadata": {},
|
| 451 |
+
"outputs": [],
|
| 452 |
+
"source": [
|
| 453 |
+
"model = torch.load(f\"{save_dir}/model_Epoch_35\")\n",
|
| 454 |
+
"model.eval()\n",
|
| 455 |
+
"print(\"Loaded model\")"
|
| 456 |
+
]
|
| 457 |
+
},
|
| 458 |
+
{
|
| 459 |
+
"cell_type": "code",
|
| 460 |
+
"execution_count": null,
|
| 461 |
+
"metadata": {},
|
| 462 |
+
"outputs": [],
|
| 463 |
+
"source": [
|
| 464 |
+
"plt.clf()\n",
|
| 465 |
+
"samples, intermediate = sample_ddpm(32)\n"
|
| 466 |
+
]
|
| 467 |
+
}
|
| 468 |
+
],
|
| 469 |
+
"metadata": {
|
| 470 |
+
"kernelspec": {
|
| 471 |
+
"display_name": "Python 3",
|
| 472 |
+
"language": "python",
|
| 473 |
+
"name": "python3"
|
| 474 |
+
},
|
| 475 |
+
"language_info": {
|
| 476 |
+
"codemirror_mode": {
|
| 477 |
+
"name": "ipython",
|
| 478 |
+
"version": 3
|
| 479 |
+
},
|
| 480 |
+
"file_extension": ".py",
|
| 481 |
+
"mimetype": "text/x-python",
|
| 482 |
+
"name": "python",
|
| 483 |
+
"nbconvert_exporter": "python",
|
| 484 |
+
"pygments_lexer": "ipython3",
|
| 485 |
+
"version": "3.12.3"
|
| 486 |
+
}
|
| 487 |
+
},
|
| 488 |
+
"nbformat": 4,
|
| 489 |
+
"nbformat_minor": 2
|
| 490 |
+
}
|
models.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class ResidualBlock(nn.Module):
    """Two 3x3 Conv -> BatchNorm -> GELU blocks with an optional residual path.

    When ``is_res`` is True the block output is a (scaled) sum of the two conv
    stages; if the channel counts differ, the input is projected through a
    learned 1x1 convolution before the addition.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        is_res: enable the residual connection.
    """

    def __init__(self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
        super(ResidualBlock, self).__init__()

        self.same_channels = in_channels == out_channels
        self.is_res = is_res

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        # BUG FIX: the original constructed a fresh nn.Conv2d inside forward()
        # on every call — its weights were random each invocation, were never
        # registered with the module, and therefore were never trained.
        # Register the projection once here so the optimizer sees it.
        if is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
        else:
            self.shortcut = None

    def forward(self, x):
        if self.is_res:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)

            if self.same_channels:
                out = x1 + x2
            else:
                out = self.shortcut(x) + x2

            # divide by ~sqrt(2) to keep the variance of the sum roughly unit
            return out / 1.414
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class UnetUp(nn.Module):
    """Decoder stage: concatenate with a skip connection, upsample 2x via
    transposed convolution, then refine with two residual blocks."""

    def __init__(self, in_channels, out_channels) -> None:
        super(UnetUp, self).__init__()

        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualBlock(out_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        # channel-wise merge with the encoder feature map before upsampling
        merged = torch.cat([x, skip], 1)
        return self.model(merged)
|
| 60 |
+
|
| 61 |
+
class UnetDown(nn.Module):
    """Encoder stage: two residual blocks followed by 2x max-pooling."""

    def __init__(self, input_channels, out_channels) -> None:
        super(UnetDown, self).__init__()

        stages = [
            ResidualBlock(input_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        out = self.model(x)
        return out
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class EmbedFC(nn.Module):
    """Embed a flat ``input_dim`` vector into ``embed_dm`` dimensions with a
    two-layer MLP (Linear -> GELU -> Linear)."""

    def __init__(self, input_dim, embed_dm) -> None:
        super(EmbedFC, self).__init__()

        self.input_dim = input_dim

        self.model = nn.Sequential(
            nn.Linear(input_dim, embed_dm),
            nn.GELU(),
            nn.Linear(embed_dm, embed_dm),
        )

    def forward(self, x):
        # flatten whatever shape arrives into (batch, input_dim)
        flat = x.view(-1, self.input_dim)
        return self.model(flat)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ContextUnet(nn.Module):
    """Context-conditioned U-Net noise predictor for DDPM training/sampling.

    Args:
        in_channels: channels of the input image (e.g. 3 for RGB).
        n_feat: base feature-map width.
        n_cfeat: dimensionality of the context (label) vector.
        height: input image height/width; must be divisible by 4 so that the
            bottleneck ConvTranspose2d in ``up0`` restores the h/4 grid.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28) -> None:
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height

        self.init_conv = ResidualBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)        # -> (B, n_feat, h/2, w/2)
        self.down2 = UnetDown(n_feat, n_feat * 2)    # -> (B, 2*n_feat, h/4, w/4)

        self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())

        # time / context embeddings for the two decoder stages
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, n_feat)

        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t, c=None):
        """Predict the noise in ``x`` at normalised timestep ``t`` given context ``c``.

        x: (B, in_channels, h, w) noised image batch.
        t: (B, 1) (or broadcastable) timesteps scaled to [0, 1].
        c: optional (B, n_cfeat) context vector; zeros are used when absent.
        """
        x = self.init_conv(x)

        down1 = self.down1(x)
        down2 = self.down2(down1)

        hidden_vec = self.to_vec(down2)

        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # BUG FIX: these embeddings have 2*n_feat / n_feat features, so they
        # must be reshaped with n_feat-based channel counts. The original used
        # self.n_cfeat*2 / self.n_cfeat, which raises a view error (or silently
        # corrupts the batch dimension) whenever n_cfeat != n_feat.
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up0 = self.up0(hidden_vec)
        up1 = self.up1(up0 * cemb1 + temb1, down2)   # FiLM-style scale + shift
        up2 = self.up2(up1 * cemb2 + temb2, down1)

        out = self.out(torch.cat((up2, x), 1))

        return out
|
train.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from models import *
from utils import *
|
| 6 |
+
|
| 7 |
+
# Diffusion schedule hyperparameters
timesteps = 500
beta1 = 1e-4   # beta at t = 0
beta2 = 0.02   # beta at t = T

# ROBUSTNESS: fall back to CPU when CUDA is unavailable instead of crashing
# at the first .to(device) call.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model / data hyperparameters
n_feat = 64     # base U-Net feature width
n_cfeat = 5     # context (label) vector size
height = 16     # sprite height/width in pixels
save_dir = "./checkpoints"

# Training hyperparameters
batch_size = 100
n_epoch = 40
lrate = 1e-3
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Linear beta schedule (DDPM): timesteps+1 entries so index t in
# [0, timesteps] can be used directly.
b_t = beta1 + (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device)
a_t = 1 - b_t
# cumulative product of alphas, computed in log-space for numerical stability
a_bt = torch.cumsum(a_t.log(), 0).exp()
a_bt[0] = 1
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Load the 16x16 sprite images and their class labels from .npy files.
# NOTE(review): paths are relative — assumes the .npy files sit next to the script.
dataset = CustomDataset("./sprites_1788_16x16.npy", "./sprite_labels_nc_1788_16x16.npy", transform, null_context=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# BUG FIX: the model must live on the same device as the training batches
# (x is moved to `device` inside the loop); without .to(device) training
# crashes with a CPU/CUDA device-mismatch error on the first forward pass.
nn_model = ContextUnet(3, n_feat, n_cfeat, height).to(device)
optim = torch.optim.Adam(nn_model.parameters(), lrate)
|
| 34 |
+
|
| 35 |
+
def perturb_input(x, t, noise):
    """Forward-diffuse x_0 to timestep t: sqrt(a_bar_t)*x + sqrt(1-a_bar_t)*noise.

    BUG FIX: the noise coefficient must be sqrt(1 - a_bt[t]); the original
    multiplied by (1 - a_bt[t]) without the square root, which does not match
    the DDPM forward process q(x_t | x_0) and skews the noise level the model
    is trained against.
    """
    return a_bt.sqrt()[t, None, None, None] * x + (1 - a_bt[t, None, None, None]).sqrt() * noise
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# --- Training loop ----------------------------------------------------------
nn_model.train()

for epoch in range(n_epoch):

    # linearly decay the learning rate over the course of training
    optim.param_groups[0]['lr'] = lrate * (1 - epoch / n_epoch)

    for x, _ in tqdm(dataloader):
        optim.zero_grad()

        x = x.to(device)

        # BUG FIX: torch.randint takes the output size as a tuple; the
        # original passed a bare int (x.shape[0]) which raises a TypeError.
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device)
        noise = torch.randn_like(x)
        x_pert = perturb_input(x, t, noise)

        # predict the injected noise, conditioning on the normalised timestep
        pred = nn_model(x_pert, t / timesteps)

        loss = F.mse_loss(pred, noise)
        loss.backward()
        optim.step()

    # checkpoint after every epoch except the first
    if epoch % 1 == 0 and epoch > 0:
        os.makedirs(save_dir, exist_ok=True)
        # BUG FIX: join with os.path.join — the original concatenated
        # "./checkpoints" + "model_Epoch..." without a separator, writing
        # "./checkpointsmodel_Epoch<N>.pth" into the parent directory.
        torch.save(nn_model, os.path.join(save_dir, f"model_Epoch{epoch}.pth"))
        print("Saved model")
|
utils.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision.utils import save_image, make_grid
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from matplotlib.animation import FuncAnimation, PillowWriter
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def unorm(x):
    """Unity-normalise an (h, w, 3) array channel-wise into [0, 1]."""
    lo = x.min((0, 1))
    hi = x.max((0, 1))
    return (x - lo) / (hi - lo)
|
| 17 |
+
|
| 18 |
+
def norm_all(store, n_t, n_s):
    """Apply unity norm to the first n_t timesteps of the first n_s samples."""
    normed = np.zeros_like(store)
    for ti in range(n_t):
        for si in range(n_s):
            normed[ti, si] = unorm(store[ti, si])
    return normed
|
| 25 |
+
|
| 26 |
+
def norm_torch(x_all):
    """Unity-normalise a torch image batch (n, 3, h, w) per image and channel.

    Returns a torch tensor with every (h, w) plane rescaled into [0, 1].
    """
    arr = x_all.cpu().numpy()
    # per-(sample, channel) extrema, kept broadcastable against (n, 3, h, w)
    hi = np.expand_dims(arr.max((2, 3)), (2, 3))
    lo = np.expand_dims(arr.min((2, 3)), (2, 3))
    return torch.from_numpy((arr - lo) / (hi - lo))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def plot_grid(x, n_sample, n_rows, save_dir, w):
    """Save a grid image of n_sample samples (n_sample, 3, h, w); return the grid."""
    ncols = n_sample // n_rows
    # curiously, make_grid's `nrow` is the number of items per row (columns)
    grid = make_grid(norm_torch(x), nrow=ncols)
    save_image(grid, save_dir + f"run_image_w{w}.png")
    print('saved image at ' + save_dir + f"run_image_w{w}.png")
    return grid
|
| 45 |
+
|
| 46 |
+
def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    """Animate a stored denoising trajectory as a nrows x ncols grid of samples.

    x_gen_store: (n_frames, n_sample, 3, h, w) intermediate samples —
        TODO confirm shape against the sampler that fills the store.
    Returns the matplotlib FuncAnimation; optionally saves it as a GIF.
    """
    ncols = n_sample//nrows
    # move the channel axis last -> (n_frames, n_sample, h, w, 3) for imshow
    sx_gen_store = np.moveaxis(x_gen_store,2,4)
    # unity-normalise every frame of every sample into [0, 1]
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,figsize=(ncols,nrows))
    def animate_diff(i, store):
        # closure over fig/axs: redraw every cell for animation frame i
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                axs[row, col].clear()
                axs[row, col].set_xticks([])
                axs[row, col].set_yticks([])
                plots.append(axs[row, col].imshow(store[i,(row*ncols)+col]))
        return plots
    ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200, blit=False, repeat=True, frames=nsx_gen_store.shape[0])
    # close the figure so it is not shown twice in notebook contexts
    plt.close()
    if save:
        ani.save(save_dir + f"{fn}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + save_dir + f"{fn}_w{w}.gif")
    return ani
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Sprite preprocessing: uint8 HWC image -> float CHW tensor, then shift to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(), # from [0,255] to range [0.0,1.0]
    transforms.Normalize((0.5,), (0.5,)) # range [-1,1]

])
|
| 74 |
+
|
| 75 |
+
class CustomDataset(Dataset):
    """Sprite dataset backed by two .npy files (images and aligned labels).

    Args:
        sfilename: path to the sprite image array, shape (N, h, w, 3).
        lfilename: path to the label array aligned with the sprites.
        transform: callable applied to each raw sprite (may be None/falsy).
        null_context: when True, every label is replaced with 0.
    """

    def __init__(self, sfilename, lfilename, transform, null_context=False):
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    def __len__(self):
        """Number of sprites in the dataset."""
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return the (image, label) pair at ``idx`` as int64 labels."""
        # BUG FIX: the original only assigned `image` inside `if self.transform:`,
        # raising UnboundLocalError whenever no transform was supplied; fall
        # back to the raw sprite array instead.
        if self.transform:
            image = self.transform(self.sprites[idx])
        else:
            image = self.sprites[idx]
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)
|