Yash Nagraj committed on
Commit
35839a1
·
1 Parent(s): 2c4de69

Add train files

Browse files
Files changed (4) hide show
  1. model.ipynb +490 -0
  2. models.py +151 -0
  3. train.py +63 -0
  4. utils.py +96 -0
model.ipynb ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 20,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import torch.nn as nn\n",
11
+ "import numpy as np\n",
12
+ "from torchvision.utils import save_image, make_grid\n",
13
+ "import matplotlib.pyplot as plt\n",
14
+ "from matplotlib.animation import FuncAnimation, PillowWriter\n",
15
+ "import os\n",
16
+ "import torchvision.transforms as transforms\n",
17
+ "from torch.utils.data import Dataset\n",
18
+ "from PIL import Image\n",
19
+ "from torch.utils.data import DataLoader\n",
20
+ "from tqdm.auto import tqdm\n",
21
+ "import torch.nn.functional as F"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 3,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "class ResidualBlock(nn.Module):\n",
31
+ " def __init__(self, in_channels: int, out_channels: int,is_res: bool = False) -> None:\n",
32
+ " super(ResidualBlock,self).__init__()\n",
33
+ "\n",
34
+ " self.same_channesls = in_channels == out_channels\n",
35
+ "\n",
36
+ " self.is_res = is_res\n",
37
+ "\n",
38
+ " self.conv1 = nn.Sequential(\n",
39
+ " nn.Conv2d(in_channels,out_channels,3,1,1),\n",
40
+ " nn.BatchNorm2d(out_channels),\n",
41
+ " nn.GELU(),\n",
42
+ " )\n",
43
+ "\n",
44
+ " self.conv2 = nn.Sequential(\n",
45
+ " nn.Conv2d(out_channels,out_channels,3,1,1),\n",
46
+ " nn.BatchNorm2d(out_channels),\n",
47
+ " nn.GELU(),\n",
48
+ " )\n",
49
+ "\n",
50
+ " def forward(self,x): \n",
51
+ " if self.is_res:\n",
52
+ " x1 = self.conv1(x)\n",
53
+ "\n",
54
+ " x2 = self.conv2(x1)\n",
55
+ "\n",
56
+ " if self.same_channesls:\n",
57
+ " out = x1 + x2\n",
58
+ " else:\n",
59
+ " shortcut = nn.Conv2d(x.shape[1],x2.shape[1],1,1,0).to(x.device)\n",
60
+ " out = shortcut(x) + x2\n",
61
+ "\n",
62
+ " return out / 1.414\n",
63
+ " \n",
64
+ " else:\n",
65
+ " x1 = self.conv1(x)\n",
66
+ " x2 = self.conv2(x1)\n",
67
+ " return x2\n"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 4,
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "class UnetUp(nn.Module):\n",
77
+ " def __init__(self, in_channels, out_channels) -> None:\n",
78
+ " super(UnetUp,self).__init__()\n",
79
+ "\n",
80
+ " self.model = nn.Sequential(\n",
81
+ " nn.ConvTranspose2d(in_channels,out_channels,2,2),\n",
82
+ " ResidualBlock(out_channels,out_channels),\n",
83
+ " ResidualBlock(out_channels,out_channels),\n",
84
+ " )\n",
85
+ "\n",
86
+ " def forward(self, x, skip):\n",
87
+ " x = torch.cat([x,skip],1)\n",
88
+ "\n",
89
+ " x = self.model(x)\n",
90
+ " return x\n",
91
+ " \n",
92
+ "class UnetDown(nn.Module):\n",
93
+ " def __init__(self, input_channels, out_channels) -> None:\n",
94
+ " super(UnetDown,self).__init__()\n",
95
+ "\n",
96
+ " self.model = nn.Sequential(\n",
97
+ " ResidualBlock(input_channels,out_channels),\n",
98
+ " ResidualBlock(out_channels,out_channels),\n",
99
+ " nn.MaxPool2d(2)\n",
100
+ " )\n",
101
+ "\n",
102
+ " def forward(self,x):\n",
103
+ " return self.model(x)\n",
104
+ " \n",
105
+ "\n",
106
+ "class EmbedFC(nn.Module):\n",
107
+ " def __init__(self, input_dim,embed_dm) -> None:\n",
108
+ " super(EmbedFC,self).__init__()\n",
109
+ "\n",
110
+ " self.input_dim = input_dim\n",
111
+ " \n",
112
+ " self.model = nn.Sequential(\n",
113
+ " nn.Linear(input_dim,embed_dm),\n",
114
+ " nn.GELU(),\n",
115
+ " nn.Linear(embed_dm,embed_dm),\n",
116
+ " )\n",
117
+ "\n",
118
+ " def forward(self,x):\n",
119
+ " x = x.view(-1,self.input_dim)\n",
120
+ " return self.model(x)\n"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 5,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "def unorm(x):\n",
130
+ " # unity norm. results in range of [0,1]\n",
131
+ " # assume x (h,w,3)\n",
132
+ " xmax = x.max((0,1))\n",
133
+ " xmin = x.min((0,1))\n",
134
+ " return(x - xmin)/(xmax - xmin)\n",
135
+ "\n",
136
+ "def norm_all(store, n_t, n_s):\n",
137
+ " # runs unity norm on all timesteps of all samples\n",
138
+ " nstore = np.zeros_like(store)\n",
139
+ " for t in range(n_t):\n",
140
+ " for s in range(n_s):\n",
141
+ " nstore[t,s] = unorm(store[t,s])\n",
142
+ " return nstore\n",
143
+ "\n",
144
+ "def norm_torch(x_all):\n",
145
+ " # runs unity norm on all timesteps of all samples\n",
146
+ " # input is (n_samples, 3,h,w), the torch image format\n",
147
+ " x = x_all.cpu().numpy()\n",
148
+ " xmax = x.max((2,3))\n",
149
+ " xmin = x.min((2,3))\n",
150
+ " xmax = np.expand_dims(xmax,(2,3)) \n",
151
+ " xmin = np.expand_dims(xmin,(2,3))\n",
152
+ " nstore = (x - xmin)/(xmax - xmin)\n",
153
+ " return torch.from_numpy(nstore)\n"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": 6,
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "def plot_grid(x,n_sample,n_rows,save_dir,w):\n",
163
+ " # x:(n_sample, 3, h, w)\n",
164
+ " ncols = n_sample//n_rows\n",
165
+ " grid = make_grid(norm_torch(x), nrow=ncols) # curiously, nrow is number of columns.. or number of items in the row.\n",
166
+ " save_image(grid, save_dir + f\"run_image_w{w}.png\")\n",
167
+ " print('saved image at ' + save_dir + f\"run_image_w{w}.png\")\n",
168
+ " return grid\n",
169
+ "\n",
170
+ "def plot_sample(x_gen_store,n_sample,nrows,save_dir, fn, w, save=False):\n",
171
+ " ncols = n_sample//nrows\n",
172
+ " sx_gen_store = np.moveaxis(x_gen_store,2,4) \n",
173
+ " nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample) \n",
174
+ " fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,figsize=(ncols,nrows))\n",
175
+ " def animate_diff(i, store):\n",
176
+ " print(f'gif animating frame {i} of {store.shape[0]}', end='\\r')\n",
177
+ " plots = []\n",
178
+ " for row in range(nrows):\n",
179
+ " for col in range(ncols):\n",
180
+ " axs[row, col].clear()\n",
181
+ " axs[row, col].set_xticks([])\n",
182
+ " axs[row, col].set_yticks([])\n",
183
+ " plots.append(axs[row, col].imshow(store[i,(row*ncols)+col]))\n",
184
+ " return plots\n",
185
+ " ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200, blit=False, repeat=True, frames=nsx_gen_store.shape[0]) \n",
186
+ " plt.close()\n",
187
+ " if save:\n",
188
+ " ani.save(save_dir + f\"{fn}_w{w}.gif\", dpi=100, writer=PillowWriter(fps=5))\n",
189
+ " print('saved gif at ' + save_dir + f\"{fn}_w{w}.gif\")\n",
190
+ " return ani\n"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": 7,
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "transform = transforms.Compose([\n",
200
+ " transforms.ToTensor(), # from [0,255] to range [0.0,1.0]\n",
201
+ " transforms.Normalize((0.5,), (0.5,)) # range [-1,1]\n",
202
+ "\n",
203
+ "])\n",
204
+ "\n",
205
+ "class CustomDataset(Dataset):\n",
206
+ " def __init__(self, sfilename, lfilename, transform, null_context=False):\n",
207
+ " self.sprites = np.load(sfilename)\n",
208
+ " self.slabels = np.load(lfilename)\n",
209
+ " print(f\"sprite shape: {self.sprites.shape}\")\n",
210
+ " print(f\"labels shape: {self.slabels.shape}\")\n",
211
+ " self.transform = transform\n",
212
+ " self.null_context = null_context\n",
213
+ " self.sprites_shape = self.sprites.shape\n",
214
+ " self.slabel_shape = self.slabels.shape\n",
215
+ " \n",
216
+ " def __len__(self):\n",
217
+ " return len(self.sprites)\n",
218
+ " \n",
219
+ " def __getitem__(self, idx):\n",
220
+ " if self.transform:\n",
221
+ " image = self.transform(self.sprites[idx])\n",
222
+ " if self.null_context:\n",
223
+ " label = torch.tensor(0).to(torch.int64)\n",
224
+ " else:\n",
225
+ " label = torch.tensor(self.slabels[idx]).to(torch.int64)\n",
226
+ " return (image, label)\n"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": 8,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "class ContextUnet(nn.Module):\n",
236
+ " def __init__(self,in_channels, n_feat = 256,n_cfeat = 10, height = 28) -> None:\n",
237
+ " super(ContextUnet,self).__init__()\n",
238
+ "\n",
239
+ " self.in_channels = in_channels\n",
240
+ " self.n_feat = n_feat\n",
241
+ " self.n_cfeat = n_cfeat\n",
242
+ " self.h = height\n",
243
+ "\n",
244
+ " self.init_conv = ResidualBlock(in_channels,n_feat,is_res=True)\n",
245
+ "\n",
246
+ " self.down1 = UnetDown(n_feat,n_feat)\n",
247
+ " self.down2 = UnetDown(n_feat,n_feat * 2)\n",
248
+ "\n",
249
+ " self.to_vec = nn.Sequential(nn.AvgPool2d((4)),nn.GELU())\n",
250
+ "\n",
251
+ " self.timeembed1 = EmbedFC(1, 2 *n_feat)\n",
252
+ " self.timeembed2 = EmbedFC(1,n_feat)\n",
253
+ " self.contextembed1 = EmbedFC(n_cfeat,2 * n_feat)\n",
254
+ " self.contextembed2 = EmbedFC(n_cfeat,n_feat)\n",
255
+ "\n",
256
+ " self.up0 = nn.Sequential(\n",
257
+ " nn.ConvTranspose2d(2 * n_feat,2*n_feat,self.h // 4,self.h // 4),\n",
258
+ " nn.GroupNorm(8, 2*n_feat),\n",
259
+ " nn.ReLU(),\n",
260
+ " )\n",
261
+ "\n",
262
+ " self.up1 = UnetUp(4 * n_feat,n_feat)\n",
263
+ " self.up2 = UnetUp(2 * n_feat,n_feat)\n",
264
+ "\n",
265
+ " self.out = nn.Sequential(\n",
266
+ " nn.Conv2d(2 * n_feat, n_feat,3,1,1),\n",
267
+ " nn.GroupNorm(8,n_feat),\n",
268
+ " nn.ReLU(),\n",
269
+ " nn.Conv2d(n_feat,self.in_channels,3,1,1)\n",
270
+ " )\n",
271
+ "\n",
272
+ " def forward(self,x,t,c=None):\n",
273
+ " x = self.init_conv(x)\n",
274
+ "\n",
275
+ " down1 = self.down1(x)\n",
276
+ " down2 = self.down2(down1)\n",
277
+ "\n",
278
+ " hidden_vec = self.to_vec(down2)\n",
279
+ "\n",
280
+ " if c is None:\n",
281
+ " c = torch.zeros(x.shape[0],self.n_cfeat).to(x)\n",
282
+ " \n",
283
+ " cemb1 = self.contextembed1(c).view(-1,self.n_cfeat*2,1,1)\n",
284
+ " temb1 = self.timeembed1(t).view(-1,self.n_cfeat * 2,1,1)\n",
285
+ " cemb2 = self.contextembed2(c).view(-1,self.n_cfeat,1,1)\n",
286
+ " temb2 = self.timeembed2(t).view(-1,self.n_cfeat,1,1)\n",
287
+ "\n",
288
+ " up0 = self.up0(hidden_vec)\n",
289
+ " up1 =self.up1(up0*cemb1 + temb1,down2)\n",
290
+ " up2 = self.up2(up1*cemb2+temb2,down1)\n",
291
+ "\n",
292
+ " out = self.out(torch.cat((up2,x),1))\n",
293
+ "\n",
294
+ " return out"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": 14,
300
+ "metadata": {},
301
+ "outputs": [],
302
+ "source": [
303
+ "# Hyperparameters\n",
304
+ "\n",
305
+ "timesteps = 500\n",
306
+ "beta1 = 1e-4\n",
307
+ "beta2 = 0.02\n",
308
+ "\n",
309
+ "device = \"cuda\"\n",
310
+ "n_feat = 64\n",
311
+ "n_cfeat = 5\n",
312
+ "height = 16\n",
313
+ "save_dir=\"./checkpoints\"\n",
314
+ "\n",
315
+ "batch_size = 100\n",
316
+ "n_epoch = 40\n",
317
+ "lrate = 1e-3"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": 12,
323
+ "metadata": {},
324
+ "outputs": [
325
+ {
326
+ "name": "stdout",
327
+ "output_type": "stream",
328
+ "text": [
329
+ "torch.Size([501])\n",
330
+ "torch.Size([501])\n",
331
+ "torch.Size([501])\n"
332
+ ]
333
+ }
334
+ ],
335
+ "source": [
336
+ "b_t = (beta2 - beta1) * torch.linspace(0,1,timesteps+1,device=device) + beta1\n",
337
+ "a_t = 1 - b_t\n",
338
+ "a_bt = torch.cumsum(a_t.log(),0).exp()\n",
339
+ "a_bt[0] = 1"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "dataset = CustomDataset(\"./sprites_1788_16x16.npy\", \"./sprite_labels_nc_1788_16x16.npy\", transform, null_context=False)\n",
349
+ "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": 17,
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": [
358
+ "nn_model = ContextUnet(3,n_feat,n_cfeat,height)\n",
359
+ "optim = torch.optim.Adam(nn_model.parameters(),lrate)"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": 16,
365
+ "metadata": {},
366
+ "outputs": [],
367
+ "source": [
368
+ "def perturb_input(x, t, noise):\n",
369
+ " return a_bt.sqrt()[t, None, None, None] * x + (1 - a_bt[t, None, None, None]) * noise"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "execution_count": null,
375
+ "metadata": {},
376
+ "outputs": [],
377
+ "source": [
378
+ "nn_model.train()\n",
379
+ "\n",
380
+ "for epoch in range(n_epoch):\n",
381
+ "\n",
382
+ " optim.param_groups[0]['lr'] = lrate * (1-epoch/n_epoch)\n",
383
+ " for x,_ in tqdm(dataloader):\n",
384
+ " optim.zero_grad()\n",
385
+ "\n",
386
+ " x = x.to(device)\n",
387
+ "\n",
388
+ " t = torch.randint(1,timesteps+1,x.shape[0]).to(device)\n",
389
+ " noise = torch.randn_like(x)\n",
390
+ " x_pert = perturb_input(x,t,noise)\n",
391
+ "\n",
392
+ " pred = nn_model(x_pert,t / timesteps)\n",
393
+ "\n",
394
+ " loss = F.mse_loss(pred,noise)\n",
395
+ " loss.backward()\n",
396
+ " optim.step()\n",
397
+ "\n",
398
+ " if epoch % 1 == 0 and epoch >0:\n",
399
+ " if not os.path.exists(save_dir):\n",
400
+ " os.mkdir(save_dir)\n",
401
+ " torch.save(nn_model,save_dir + f\"model_Epoch{epoch}.pth\")\n",
402
+ " print(\"Saved model\")\n"
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": 22,
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "def denoise_add_noise(x,t,pred_noise,z=None):\n",
412
+ " if z is None:\n",
413
+ " z = torch.randn_like(x)\n",
414
+ " noise = b_t.sqrt()[t]\n",
415
+ " mean = x - (pred_noise * ((1-a_t[t]) / (1-a_bt[t]).sqrt())) / a_t[t].sqrt()\n",
416
+ " return mean + noise\n"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": null,
422
+ "metadata": {},
423
+ "outputs": [],
424
+ "source": [
425
+ "@torch.no_grad()\n",
426
+ "def sample_ddpm(n_sample,save_rate=20):\n",
427
+ " # x_T ~ N(0, 1), sample initial noise\n",
428
+ " samples = torch.randn(n_sample,3,height,height)\n",
429
+ "\n",
430
+ " intermediate = []\n",
431
+ " for i in range(timesteps,0,-1):\n",
432
+ " print(f\"Sampling timestep: {i}\")\n",
433
+ "\n",
434
+ " t = torch.tensor([i/timesteps])[:,None,None,None].to(device)\n",
435
+ "\n",
436
+ " z = torch.randn_like(samples)\n",
437
+ "\n",
438
+ " pred = nn_model(samples,t)\n",
439
+ " samples = denoise_add_noise(samples,t,pred,z)\n",
440
+ " if i % save_rate ==0 or i==timesteps or i<8:\n",
441
+ " intermediate.append(samples.detach().cpu().numpy())\n",
442
+ "\n",
443
+ " intermediate = np.stack(intermediate)\n",
444
+ " return samples,intermediate\n"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": null,
450
+ "metadata": {},
451
+ "outputs": [],
452
+ "source": [
453
+ "model = torch.load(f\"{save_dir}/model_Epoch_35\")\n",
454
+ "model.eval()\n",
455
+ "print(\"Loaded model\")"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "execution_count": null,
461
+ "metadata": {},
462
+ "outputs": [],
463
+ "source": [
464
+ "plt.clf()\n",
465
+ "samples, intermediate = sample_ddpm(32)\n"
466
+ ]
467
+ }
468
+ ],
469
+ "metadata": {
470
+ "kernelspec": {
471
+ "display_name": "Python 3",
472
+ "language": "python",
473
+ "name": "python3"
474
+ },
475
+ "language_info": {
476
+ "codemirror_mode": {
477
+ "name": "ipython",
478
+ "version": 3
479
+ },
480
+ "file_extension": ".py",
481
+ "mimetype": "text/x-python",
482
+ "name": "python",
483
+ "nbconvert_exporter": "python",
484
+ "pygments_lexer": "ipython3",
485
+ "version": "3.12.3"
486
+ }
487
+ },
488
+ "nbformat": 4,
489
+ "nbformat_minor": 2
490
+ }
models.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
class ResidualBlock(nn.Module):
    """Two 3x3 conv stages (Conv -> BatchNorm -> GELU) with an optional residual path.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        is_res: when True, add a residual connection (scaled by 1/sqrt(2)).
    """

    def __init__(self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
        super(ResidualBlock, self).__init__()

        self.same_channels = in_channels == out_channels
        self.is_res = is_res

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        # Bug fix: the 1x1 projection used to be constructed inside forward() on
        # every call, so it had fresh random (and untrained) weights each pass.
        # Register it once here so it is a trainable part of the module.
        if is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)

        if not self.is_res:
            return x2

        if self.same_channels:
            out = x1 + x2
        else:
            out = self.shortcut(x) + x2

        # Scale by 1/sqrt(2) to keep the variance of the sum roughly constant.
        return out / 1.414
42
+
43
+
44
+
45
class UnetUp(nn.Module):
    """Upsampling stage: concatenate the skip connection with the input, then
    apply a 2x transpose-conv followed by two residual blocks."""

    def __init__(self, in_channels, out_channels) -> None:
        super(UnetUp, self).__init__()

        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualBlock(out_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        # Channel-wise concat of the upsampled path and the skip connection.
        merged = torch.cat([x, skip], 1)
        return self.model(merged)
60
+
61
class UnetDown(nn.Module):
    """Downsampling stage: two residual blocks followed by 2x max-pooling."""

    def __init__(self, input_channels, out_channels) -> None:
        super(UnetDown, self).__init__()

        blocks = [
            ResidualBlock(input_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)
73
+
74
+
75
class EmbedFC(nn.Module):
    """Two-layer MLP that maps a flattened input of size `input_dim` to an
    `embed_dm`-dimensional embedding."""

    def __init__(self, input_dim, embed_dm) -> None:
        super(EmbedFC, self).__init__()

        self.input_dim = input_dim

        self.model = nn.Sequential(
            nn.Linear(input_dim, embed_dm),
            nn.GELU(),
            nn.Linear(embed_dm, embed_dm),
        )

    def forward(self, x):
        # Flatten everything but the batch dimension before the MLP.
        flat = x.view(-1, self.input_dim)
        return self.model(flat)
90
+
91
+
92
class ContextUnet(nn.Module):
    """Context-conditioned U-Net that predicts the diffusion noise.

    Args:
        in_channels: channels of the input image (e.g. 3 for RGB).
        n_feat: base feature width of the U-Net.
        n_cfeat: size of the context (label) vector.
        height: input height/width in pixels (assumed square, divisible by 4).
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28) -> None:
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height

        self.init_conv = ResidualBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, n_feat * 2)

        # Collapse the (h/4 x h/4) feature map to a 1x1 bottleneck vector.
        self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())

        # Time and context embeddings for each decoder level.
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, n_feat)

        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t, c=None):
        """Predict noise for image batch `x` at normalized timestep `t`,
        optionally conditioned on context vectors `c` of shape (batch, n_cfeat)."""
        x = self.init_conv(x)

        down1 = self.down1(x)
        down2 = self.down2(down1)

        hidden_vec = self.to_vec(down2)

        # Unconditional: use an all-zero context vector.
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # Bug fix: the embeddings output 2*n_feat / n_feat features (see
        # EmbedFC sizes above), so the views must use n_feat, not n_cfeat.
        # With n_feat=64, n_cfeat=5 the original view(-1, n_cfeat*2, 1, 1)
        # fails at runtime.
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up0 = self.up0(hidden_vec)
        # Condition each decoder level multiplicatively (context) and
        # additively (time), as in the DDPM course architecture.
        up1 = self.up1(up0 * cemb1 + temb1, down2)
        up2 = self.up2(up1 * cemb2 + temb2, down1)

        out = self.out(torch.cat((up2, x), 1))

        return out
train.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from utils import *
3
+ from torch.utils.data import DataLoader
4
+ from models import *
5
+ from tqdm.auto import tqdm
6
+
7
# Bug fix: `os` and `F` were used below but never imported anywhere
# (utils does not re-export them), which raised NameError at runtime.
import os

import torch.nn.functional as F

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
timesteps = 500          # number of diffusion steps T
beta1 = 1e-4             # noise-schedule start value
beta2 = 0.02             # noise-schedule end value

# Robustness: fall back to CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
n_feat = 64              # base U-Net feature width
n_cfeat = 5              # context (label) vector size
height = 16              # sprite height/width in pixels
save_dir = "./checkpoints"

batch_size = 100
n_epoch = 40
lrate = 1e-3


# Linear beta schedule over timesteps+1 points; a_bt is alpha-bar, the
# cumulative product of (1 - beta) computed stably in log space.
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
a_bt = torch.cumsum(a_t.log(), 0).exp()
a_bt[0] = 1


dataset = CustomDataset("./sprites_1788_16x16.npy", "./sprite_labels_nc_1788_16x16.npy", transform, null_context=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)


# Bug fix: the model must live on the same device as the batches, which are
# moved to `device` inside the loop.
nn_model = ContextUnet(3, n_feat, n_cfeat, height).to(device)
optim = torch.optim.Adam(nn_model.parameters(), lrate)


def perturb_input(x, t, noise):
    """Forward-diffuse x to timestep t: sqrt(a_bar)*x + sqrt(1 - a_bar)*noise.

    Bug fix: the noise coefficient is sqrt(1 - a_bar); the original used
    (1 - a_bar) without the square root.
    """
    return a_bt.sqrt()[t, None, None, None] * x + (1 - a_bt[t, None, None, None]).sqrt() * noise


nn_model.train()

for epoch in range(n_epoch):

    # Linearly decay the learning rate over training.
    optim.param_groups[0]['lr'] = lrate * (1 - epoch / n_epoch)
    for x, _ in tqdm(dataloader):
        optim.zero_grad()

        x = x.to(device)

        # Bug fix: torch.randint takes a size *tuple*, not a bare int.
        t = torch.randint(1, timesteps + 1, (x.shape[0],), device=device)
        noise = torch.randn_like(x)
        x_pert = perturb_input(x, t, noise)

        # The model is conditioned on the normalized timestep t/T.
        pred = nn_model(x_pert, t / timesteps)

        loss = F.mse_loss(pred, noise)
        loss.backward()
        optim.step()

    if epoch % 1 == 0 and epoch > 0:
        os.makedirs(save_dir, exist_ok=True)
        # Bug fix: join with a separator so checkpoints land *inside*
        # save_dir instead of "./checkpointsmodel_Epoch...".
        torch.save(nn_model, os.path.join(save_dir, f"model_Epoch{epoch}.pth"))
        print("Saved model")
utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torchvision.utils import save_image, make_grid
4
+ import matplotlib.pyplot as plt
5
+ from matplotlib.animation import FuncAnimation, PillowWriter
6
+ from torchvision import transforms
7
+ from torch.utils.data import Dataset
8
+
9
+
10
+
11
def unorm(x):
    """Unit-normalize an (h, w, 3) array so each channel spans [0, 1]."""
    lo = x.min((0, 1))
    hi = x.max((0, 1))
    return (x - lo) / (hi - lo)
17
+
18
def norm_all(store, n_t, n_s):
    """Apply unorm to every (timestep, sample) image in `store`.

    store: array of shape (n_t, n_s, h, w, 3); returns the same shape.
    """
    normalized = np.zeros_like(store)
    for t_idx in range(n_t):
        for s_idx in range(n_s):
            normalized[t_idx, s_idx] = unorm(store[t_idx, s_idx])
    return normalized
25
+
26
def norm_torch(x_all):
    """Unit-normalize each channel of each sample of a (n, 3, h, w) tensor
    to [0, 1], returning a new torch tensor."""
    arr = x_all.cpu().numpy()
    # Per-sample, per-channel extrema, broadcast back over the spatial dims.
    lo = np.expand_dims(arr.min((2, 3)), (2, 3))
    hi = np.expand_dims(arr.max((2, 3)), (2, 3))
    return torch.from_numpy((arr - lo) / (hi - lo))
36
+
37
+
38
def plot_grid(x, n_sample, n_rows, save_dir, w):
    """Save a grid image of generated samples and return the grid tensor.

    Args:
        x: samples of shape (n_sample, 3, h, w).
        n_sample: number of samples in x.
        n_rows: rows in the output grid; columns = n_sample // n_rows.
        save_dir: directory to write the png into.
        w: guidance-weight tag embedded in the filename.
    """
    import os  # local import: join paths safely regardless of trailing '/'

    ncols = n_sample // n_rows
    grid = make_grid(norm_torch(x), nrow=ncols)  # nrow = items per row, i.e. columns
    # Bug fix: string concatenation dropped the path separator, writing
    # "checkpointsrun_image..." next to the directory instead of inside it.
    path = os.path.join(save_dir, f"run_image_w{w}.png")
    save_image(grid, path)
    print('saved image at ' + path)
    return grid
45
+
46
def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    """Animate intermediate diffusion samples as a grid; optionally save a gif.

    Args:
        x_gen_store: stored frames, shape (n_frames, n_sample, 3, h, w).
        n_sample: samples per frame; columns = n_sample // nrows.
        nrows: rows in the grid.
        save_dir: directory to write the gif into (when save=True).
        fn: filename stem for the gif.
        w: guidance-weight tag embedded in the filename.
        save: when True, render and save the gif with PillowWriter.

    Returns the FuncAnimation object.
    """
    import os  # local import: join paths safely regardless of trailing '/'

    ncols = n_sample // nrows
    # (frames, samples, 3, h, w) -> (frames, samples, h, w, 3) for imshow.
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, figsize=(ncols, nrows))

    def animate_diff(i, store):
        # Redraw every cell of the grid for frame i.
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                axs[row, col].clear()
                axs[row, col].set_xticks([])
                axs[row, col].set_yticks([])
                plots.append(axs[row, col].imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200, blit=False, repeat=True, frames=nsx_gen_store.shape[0])
    plt.close()
    if save:
        # Bug fix: string concatenation dropped the path separator, saving the
        # gif next to the directory instead of inside it.
        path = os.path.join(save_dir, f"{fn}_w{w}.gif")
        ani.save(path, dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + path)
    return ani
67
+
68
+
69
# Default sprite preprocessing: uint8 [0, 255] -> float [0.0, 1.0] -> [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),                 # [0, 255] -> [0.0, 1.0]
        transforms.Normalize((0.5,), (0.5,)),  # [0.0, 1.0] -> [-1, 1]
    ]
)
74
+
75
class CustomDataset(Dataset):
    """Sprite dataset backed by two .npy files (images and labels).

    Args:
        sfilename: path to the sprite images array (N, h, w, 3).
        lfilename: path to the labels array.
        transform: callable applied to each raw sprite (may be None/falsy).
        null_context: when True, every label is the constant 0.
    """

    def __init__(self, sfilename, lfilename, transform, null_context=False):
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        # Cached copies of the array shapes for convenience.
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    def __len__(self):
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return (image, label) for index idx; label is int64."""
        # Bug fix: the original only assigned `image` inside the transform
        # branch, raising UnboundLocalError when transform was falsy. Fall
        # back to the raw sprite in that case.
        image = self.sprites[idx]
        if self.transform:
            image = self.transform(image)
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)