JBlitzar commited on
Commit
9f5a022
·
1 Parent(s): fc9acd0

ahahahaha it works

Browse files
__pycache__/bert_vectorize.cpython-311.pyc ADDED
Binary file (2.04 kB). View file
 
__pycache__/factories.cpython-311.pyc ADDED
Binary file (20.1 kB). View file
 
__pycache__/logger.cpython-311.pyc ADDED
Binary file (2.14 kB). View file
 
__pycache__/wrapper.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
bert_vectorize.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import BertTokenizer, BertModel, DistilBertTokenizer, DistilBertModel
import torch

# Load DistilBERT once at import time so repeated vectorize calls reuse the
# same tokenizer/model instead of re-downloading or re-initialising them.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
# output_hidden_states=True makes the forward pass return activations for
# every layer, not just the final output.
model = DistilBertModel.from_pretrained('distilbert-base-uncased', output_hidden_states=True)
model.eval()  # inference only; disables dropout etc.

# Prefer the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

model = model.to(device)
10
def vectorize_text_with_bert(text):  # from hf docs
    """Encode *text* into a single fixed-size vector.

    Tokenizes with the module-level DistilBERT tokenizer, runs the model
    without gradients, and mean-pools the last hidden layer over the
    sequence dimension. Returns a 1-D tensor (hidden_size,) on `device`.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        output = model(**encoded)
    # hidden_states[-1] is the final transformer layer: (batch, seq, hidden)
    last_layer = output.hidden_states[-1]
    # Average over tokens, then drop the batch axis.
    return last_layer.mean(dim=1).squeeze(0)
20
+
21
if __name__ == "__main__":
    # Quick smoke test: vectorize one caption and show the result.
    text = "A man walking down the street with a dog holding a balloon in one hand."
    text_representation = vectorize_text_with_bert(text)
    print("Vectorized representation:", text_representation)
    print(text_representation.shape)
factories.py CHANGED
@@ -78,7 +78,6 @@ class CrossAttention(nn.Module):
78
 
79
  # Reshape and permute x for multi-head attention
80
  batch_size, channels, height, width = x.size()
81
-
82
  x = x.view(-1, self.channels, self.size * self.size).swapaxes(1,2)
83
  x_ln = self.ln(x)
84
 
@@ -124,7 +123,7 @@ class DoubleConv(nn.Module):
124
 
125
 
126
  class Down(nn.Module):
127
- def __init__(self, in_channels, out_channels, emb_dim=1024):
128
  super().__init__()
129
  self.maxpool_conv = nn.Sequential(
130
  nn.MaxPool2d(2),
@@ -147,7 +146,7 @@ class Down(nn.Module):
147
 
148
 
149
  class Up(nn.Module):
150
- def __init__(self, in_channels, out_channels, emb_dim=1024):
151
  super().__init__()
152
 
153
  self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
@@ -172,63 +171,30 @@ class Up(nn.Module):
172
  return x + emb
173
 
174
 
175
-
176
- class UNet_conditional_large(nn.Module):
177
- def __init__(self, c_in=3, c_out=3, time_dim=1024, num_classes=1024, context_dim=None, device="mps"):
178
  super().__init__()
179
-
180
- if context_dim is None:
181
- context_dim = num_classes
182
  self.device = device
183
  self.time_dim = time_dim
184
-
185
-
186
- start_depth = 128
187
-
188
-
189
- xa_amt_depth = 64 # dont change
190
-
191
- self.inc = DoubleConv(c_in, start_depth)
192
- self.down1 = Down(start_depth, start_depth * 2)
193
-
194
- self.xa1 = CrossAttention(start_depth * 2, xa_amt_depth // 2, context_dim)
195
-
196
- self.down2 = Down(start_depth * 2, start_depth * 4)
197
- self.xa2 = CrossAttention(start_depth * 4, xa_amt_depth // 4, context_dim)
198
-
199
- self.down3 = Down(start_depth * 4, start_depth * 8)
200
- self.xa3 = CrossAttention(start_depth * 8, xa_amt_depth // 8, context_dim)
201
-
202
- self.down4 = Down(start_depth * 8, start_depth * 8)
203
- self.xa4 = CrossAttention(start_depth * 8, xa_amt_depth // 16, context_dim)
204
-
205
- self.bot1 = DoubleConv(start_depth * 8, start_depth * 16)
206
- self.bot2 = DoubleConv(start_depth * 16, start_depth * 16)
207
- self.bot3 = DoubleConv(start_depth * 16, start_depth * 8)
208
-
209
- self.up1 = Up(start_depth * 16, start_depth * 4)
210
- self.xa5 = CrossAttention(start_depth * 4, xa_amt_depth // 8, context_dim)
211
-
212
- self.up2 = Up(start_depth * 8, start_depth * 2)
213
- self.xa6 = CrossAttention(start_depth * 2, xa_amt_depth // 4, context_dim)
214
-
215
- self.up3 = Up(start_depth * 4, start_depth)
216
- self.xa7 = CrossAttention(start_depth, xa_amt_depth // 2, context_dim)
217
-
218
- self.up4 = Up(start_depth * 2, start_depth)
219
- self.xa8 = CrossAttention(start_depth, xa_amt_depth, context_dim)
220
-
221
- self.outc = nn.Conv2d(start_depth, c_out, kernel_size=1)
222
-
223
- if num_classes is not None:
224
- self.label_emb = nn.Linear(num_classes, time_dim)#Embedding(num_classes, time_dim)
225
- self.num_classes = num_classes
226
- if context_dim is None:
227
- context_dim = num_classes
228
-
229
- self.context_dim = context_dim
230
-
231
- self.label_crossattn_emb = nn.Linear(num_classes, context_dim)
232
 
233
  def pos_encoding(self, t, channels):
234
  inv_freq = 1.0 / (
@@ -240,201 +206,34 @@ class UNet_conditional_large(nn.Module):
240
  pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
241
  return pos_enc
242
 
243
- def forward(self, x, t, y):
244
  t = t.unsqueeze(-1).type(torch.float)
245
  t = self.pos_encoding(t, self.time_dim)
246
 
247
- if y is not None:
248
-
249
- attn_y = y[:,:self.num_classes]
250
- attn_y = self.label_crossattn_emb(attn_y)
251
-
252
- # y = y[:,:self.num_classes]
253
-
254
- # y = self.label_emb(y)
255
-
256
-
257
- # t += y
258
-
259
  x1 = self.inc(x)
260
-
261
  x2 = self.down1(x1, t)
262
- x2 = self.xa1(x2, attn_y)
263
-
264
-
265
  x3 = self.down2(x2, t)
266
- x3 = self.xa2(x3, attn_y)
267
-
268
-
269
  x4 = self.down3(x3, t)
 
270
 
271
- x4 = self.xa3(x4, attn_y)
272
-
273
-
274
- x5 = self.down4(x4, t)
275
-
276
- x5 = self.xa4(x5, attn_y)
277
-
278
-
279
-
280
- x5 = self.bot1(x5)
281
- x5 = self.bot2(x5)
282
- x5 = self.bot3(x5)
283
-
284
-
285
-
286
- x = self.up1(x5, x4, t)
287
- x = self.xa5(x,attn_y)
288
-
289
-
290
- x = self.up2(x, x3, t)
291
- x = self.xa6(x,attn_y)
292
-
293
- x = self.up3(x, x2, t)
294
- x = self.xa7(x, attn_y)
295
-
296
-
297
- x = self.up4(x, x1, t)
298
- x = self.xa8(x, attn_y)
299
 
 
 
 
 
 
 
300
  output = self.outc(x)
301
  return output
302
 
303
- class UNet_conditional_efficient(nn.Module):
304
- def __init__(self, c_in=3, c_out=3, time_dim=1024, num_classes=1024, context_dim=None, device="mps"):
305
- super().__init__()
306
-
307
- if context_dim is None:
308
- context_dim = num_classes
309
- self.device = device
310
- self.time_dim = time_dim
311
-
312
-
313
- start_depth = 128
314
-
315
-
316
- xa_amt_depth = 64 # dont change
317
-
318
- self.inc = DoubleConv(c_in, start_depth * 2)
319
-
320
- self.downsample = nn.MaxPool2d(2)
321
-
322
-
323
- self.down2 = Down(start_depth * 2, start_depth * 4)
324
- self.xa2 = CrossAttention(start_depth * 4, xa_amt_depth // 4, context_dim)
325
-
326
- self.down3 = Down(start_depth * 4, start_depth * 8)
327
- self.xa3 = CrossAttention(start_depth * 8, xa_amt_depth // 8, context_dim)
328
-
329
- self.down4 = Down(start_depth * 8, start_depth * 8)
330
- self.xa4 = CrossAttention(start_depth * 8, xa_amt_depth // 16, context_dim)
331
-
332
- self.bot1 = DoubleConv(start_depth * 8, start_depth * 16)
333
- self.bot2 = DoubleConv(start_depth * 16, start_depth * 16)
334
- self.bot3 = DoubleConv(start_depth * 16, start_depth * 8)
335
-
336
- self.up1 = Up(start_depth * 16, start_depth * 4)
337
- self.xa5 = CrossAttention(start_depth * 4, xa_amt_depth // 8, context_dim)
338
-
339
- self.up2 = Up(start_depth * 8, start_depth * 2)
340
- self.xa6 = CrossAttention(start_depth * 2, xa_amt_depth // 4, context_dim)
341
 
342
- self.up3 = Up(start_depth * 4, start_depth)
343
- self.xa7 = CrossAttention(start_depth, xa_amt_depth // 2, context_dim)
344
-
345
- self.up4 = Up(start_depth * 2, start_depth)
346
- self.xa8 = CrossAttention(start_depth, xa_amt_depth, context_dim)
347
-
348
- self.upsample = nn.Upsample(scale_factor=2, mode="bilinear")
349
-
350
- self.outc = nn.Conv2d(start_depth, c_out, kernel_size=1)
351
-
352
- if num_classes is not None:
353
- self.label_emb = nn.Linear(num_classes, time_dim)#Embedding(num_classes, time_dim)
354
- self.num_classes = num_classes
355
- if context_dim is None:
356
- context_dim = num_classes
357
-
358
- self.context_dim = context_dim
359
-
360
- self.label_crossattn_emb = nn.Linear(num_classes, context_dim)
361
-
362
- def pos_encoding(self, t, channels):
363
- inv_freq = 1.0 / (
364
- 10000
365
- ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
366
- )
367
- pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
368
- pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
369
- pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
370
- return pos_enc
371
-
372
- def forward(self, x, t, y):
373
- t = t.unsqueeze(-1).type(torch.float)
374
- t = self.pos_encoding(t, self.time_dim)
375
-
376
- if y is not None:
377
-
378
- attn_y = y[:,:self.num_classes]
379
- attn_y = self.label_crossattn_emb(attn_y)
380
-
381
- # y = y[:,:self.num_classes]
382
-
383
- # y = self.label_emb(y)
384
-
385
-
386
- # t += y
387
-
388
- x1 = self.inc(x)
389
-
390
- x2 = self.downsample(x1)
391
-
392
-
393
-
394
-
395
-
396
-
397
- x3 = self.down2(x2, t)
398
- x3 = self.xa2(x3, attn_y)
399
-
400
-
401
- x4 = self.down3(x3, t)
402
-
403
- x4 = self.xa3(x4, attn_y)
404
-
405
-
406
- x5 = self.down4(x4, t)
407
-
408
- x5 = self.xa4(x5, attn_y)
409
-
410
-
411
-
412
- x5 = self.bot1(x5)
413
- x5 = self.bot2(x5)
414
- x5 = self.bot3(x5)
415
-
416
-
417
-
418
- x = self.up1(x5, x4, t)
419
- x = self.xa5(x,attn_y)
420
-
421
-
422
- x = self.up2(x, x3, t)
423
- x = self.xa6(x,attn_y)
424
-
425
- x = self.up3(x, x2, t)
426
- x = self.xa7(x, attn_y)
427
-
428
-
429
-
430
-
431
-
432
- x = self.upsample(x)
433
- output = self.outc(x)
434
- return output
435
-
436
- class UNet_conditional_start_depth(nn.Module):
437
- def __init__(self, c_in=3, c_out=3, time_dim=1024, num_classes=None, context_dim=None, device="mps"):
438
  super().__init__()
439
 
440
  if context_dim is None:
@@ -443,36 +242,31 @@ class UNet_conditional_start_depth(nn.Module):
443
  self.time_dim = time_dim
444
 
445
 
446
- start_depth = 128
447
- xa_amt_depth = 64
448
-
449
- self.inc = DoubleConv(c_in, start_depth)
450
-
451
- self.down1 = Down(start_depth, start_depth * 2)
452
- self.xa1 = CrossAttention(start_depth * 2, xa_amt_depth // 2, context_dim)
453
-
454
- self.down2 = Down(start_depth * 2, start_depth * 4)
455
- self.xa2 = CrossAttention(start_depth * 4, xa_amt_depth // 4, context_dim)
456
-
457
- self.down3 = Down(start_depth * 4, start_depth * 4)
458
- self.xa3 = CrossAttention(start_depth * 4, xa_amt_depth // 8, context_dim)
459
-
460
-
461
- self.bot1 = DoubleConv(start_depth * 4, start_depth * 8)
462
- self.bot2 = DoubleConv(start_depth * 8, start_depth * 8)
463
- self.bot3 = DoubleConv(start_depth * 8, start_depth * 4)
464
-
465
- self.up1 = Up(start_depth * 8, start_depth * 2)
466
- self.xa4 = CrossAttention(start_depth * 2, xa_amt_depth // 4, context_dim)
467
-
468
- self.up2 = Up(start_depth * 4, start_depth)
469
- self.xa5 = CrossAttention(start_depth, xa_amt_depth // 2, context_dim)
470
-
471
- self.up3 = Up(start_depth * 2, start_depth)
472
- self.xa6 = CrossAttention(start_depth, xa_amt_depth, context_dim)
473
-
474
- self.outc = nn.Conv2d(start_depth, c_out, kernel_size=1)
475
-
476
 
477
  if num_classes is not None:
478
  self.label_emb = nn.Linear(num_classes, time_dim)#Embedding(num_classes, time_dim)
@@ -547,19 +341,3 @@ class UNet_conditional_start_depth(nn.Module):
547
 
548
  #output = F.sigmoid(x)
549
  return output
550
-
551
-
552
- if __name__ == "__main__":
553
- net = UNet_conditional_start_depth(num_classes=1024).to("mps")
554
-
555
- def count_parameters(model):
556
- return torch.tensor([p.numel() for p in model.parameters() if p.requires_grad]).sum().item()
557
- print(f"Parameters: {count_parameters(net)}")
558
-
559
- minibatch = torch.randn((1,3,64,64)).to("mps")
560
-
561
- o = net(minibatch, torch.randint(low=1, high=1000, size=(1,)).to("mps"), torch.randn((1,1024)).to("mps"))
562
-
563
- print(o.size())
564
-
565
-
 
78
 
79
  # Reshape and permute x for multi-head attention
80
  batch_size, channels, height, width = x.size()
 
81
  x = x.view(-1, self.channels, self.size * self.size).swapaxes(1,2)
82
  x_ln = self.ln(x)
83
 
 
123
 
124
 
125
  class Down(nn.Module):
126
+ def __init__(self, in_channels, out_channels, emb_dim=256):
127
  super().__init__()
128
  self.maxpool_conv = nn.Sequential(
129
  nn.MaxPool2d(2),
 
146
 
147
 
148
  class Up(nn.Module):
149
+ def __init__(self, in_channels, out_channels, emb_dim=256):
150
  super().__init__()
151
 
152
  self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
 
171
  return x + emb
172
 
173
 
174
+ class Dome_UNet(nn.Module):
175
+ def __init__(self, c_in=3, c_out=3, time_dim=256, device="mps"):
 
176
  super().__init__()
 
 
 
177
  self.device = device
178
  self.time_dim = time_dim
179
+ self.inc = DoubleConv(c_in, 64)
180
+ self.down1 = Down(64, 128)
181
+ self.sa1 = SelfAttention(128, 32)
182
+ self.down2 = Down(128, 256)
183
+ self.sa2 = SelfAttention(256, 16)
184
+ self.down3 = Down(256, 256)
185
+ self.sa3 = SelfAttention(256, 8)
186
+
187
+ self.bot1 = DoubleConv(256, 512)
188
+ self.bot2 = DoubleConv(512, 512)
189
+ self.bot3 = DoubleConv(512, 256)
190
+
191
+ self.up1 = Up(512, 128)
192
+ self.sa4 = SelfAttention(128, 16)
193
+ self.up2 = Up(256, 64)
194
+ self.sa5 = SelfAttention(64, 32)
195
+ self.up3 = Up(128, 64)
196
+ self.sa6 = SelfAttention(64, 64)
197
+ self.outc = nn.Conv2d(64, c_out, kernel_size=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  def pos_encoding(self, t, channels):
200
  inv_freq = 1.0 / (
 
206
  pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
207
  return pos_enc
208
 
209
+ def forward(self, x, t):
210
  t = t.unsqueeze(-1).type(torch.float)
211
  t = self.pos_encoding(t, self.time_dim)
212
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  x1 = self.inc(x)
 
214
  x2 = self.down1(x1, t)
215
+ x2 = self.sa1(x2)
 
 
216
  x3 = self.down2(x2, t)
217
+ x3 = self.sa2(x3)
 
 
218
  x4 = self.down3(x3, t)
219
+ x4 = self.sa3(x4)
220
 
221
+ x4 = self.bot1(x4)
222
+ x4 = self.bot2(x4)
223
+ x4 = self.bot3(x4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
+ x = self.up1(x4, x3, t)
226
+ x = self.sa4(x)
227
+ x = self.up2(x, x2, t)
228
+ x = self.sa5(x)
229
+ x = self.up3(x, x1, t)
230
+ x = self.sa6(x)
231
  output = self.outc(x)
232
  return output
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
+ class UNet_conditional(nn.Module):
236
+ def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, context_dim=None, device="mps"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  super().__init__()
238
 
239
  if context_dim is None:
 
242
  self.time_dim = time_dim
243
 
244
 
245
+ self.inc = DoubleConv(c_in, 64)
246
+ self.down1 = Down(64, 128)
247
+ self.sa1 = SelfAttention(128, 32)
248
+ self.xa1 = CrossAttention(128, 32, context_dim)
249
+ self.down2 = Down(128, 256)
250
+ self.xa2 = CrossAttention(256, 16, context_dim)
251
+ self.sa2 = SelfAttention(256, 16)
252
+ self.down3 = Down(256, 256)
253
+ self.xa3 = CrossAttention(256, 8, context_dim)
254
+ self.sa3 = SelfAttention(256, 8)
255
+
256
+ self.bot1 = DoubleConv(256, 512)
257
+ self.bot2 = DoubleConv(512, 512)
258
+ self.bot3 = DoubleConv(512, 256)
259
+
260
+ self.up1 = Up(512, 128)
261
+ self.xa4 = CrossAttention(128, 16, context_dim)
262
+ self.sa4 = SelfAttention(128, 16)
263
+ self.up2 = Up(256, 64)
264
+ self.xa5 = CrossAttention(64, 32, context_dim)
265
+ self.sa5 = SelfAttention(64, 32)
266
+ self.up3 = Up(128, 64)
267
+ self.xa6 = CrossAttention(64, 64, context_dim)
268
+ self.sa6 = SelfAttention(64, 64)
269
+ self.outc = nn.Conv2d(64, c_out, kernel_size=1)
 
 
 
 
 
270
 
271
  if num_classes is not None:
272
  self.label_emb = nn.Linear(num_classes, time_dim)#Embedding(num_classes, time_dim)
 
341
 
342
  #output = F.sigmoid(x)
343
  return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
logger.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

# Module-level TensorBoard writer shared by the helpers below.
# Stays None until init_logger() is called.
writer = None
6
def log_data(data, i):
    """Log every key/value pair in *data* as a TensorBoard scalar at step *i*.

    Requires init_logger() to have been called first.
    """
    for tag, value in data.items():
        writer.add_scalar(tag, value, i)
11
+
12
def log_img(img, name):
    """Log a single image tensor to TensorBoard under *name*.

    Requires init_logger() to have been called first.
    """
    writer.add_image(name, img)
14
+
15
+
16
def save_grid_with_label(img_grid, label, out_file):
    """Render a CHW image-grid tensor with *label* as its title and save to *out_file*.

    The input is permuted to HWC for matplotlib. All figures are closed
    afterwards so repeated calls do not leak memory.
    """
    grid_hwc = img_grid.permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(grid_hwc)
    ax.set_title(label, fontsize=20)
    ax.axis('off')

    # Leave headroom so the title is not clipped.
    plt.subplots_adjust(top=0.85)

    plt.savefig(out_file, bbox_inches='tight', pad_inches=0.1)

    # Close this figure and any strays left by earlier calls.
    plt.close(fig)
    plt.close("all")
32
+
33
+
34
+
35
+
36
def init_logger(dir="runs"):
    """Create the module-level SummaryWriter once; subsequent calls are no-ops."""
    global writer
    if writer is None:
        writer = SummaryWriter(dir)
runs/run_3_jxa/ckpt/latest.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cd39e8429ea0ace24bb40d4bd404baebb8aae471385987b898a966eb79dcc5f
3
+ size 103503678
wrapper.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from enum import Enum
4
+ from tqdm import trange
5
+
6
+
7
+
8
+
9
+
10
# Noise-schedule variants supported by DiffusionManager.
Schedule = Enum('Schedule', ['LINEAR', 'COSINE'])

class DiffusionManager(nn.Module):
    """DDPM-style diffusion wrapper around a noise-prediction model.

    Owns the beta noise schedule and drives the forward (noising) process,
    the reverse (sampling) process, and single training steps. The wrapped
    model is called as ``model(x, t, condition)`` and must return a tensor
    shaped like ``x`` (the predicted noise).
    """

    def __init__(self, model: nn.Module, noise_steps=1000, start=0.0001, end=0.02, device="cpu", **kwargs) -> None:
        """
        Args:
            model: noise-prediction network, called as model(x, t, condition).
            noise_steps: number of diffusion timesteps T.
            start, end: endpoints of the linear beta schedule. `start` is
                also reused as the offset `s` in the cosine schedule.
            device: device on which the schedule and all tensors live.
        """
        super().__init__(**kwargs)

        self.model = model
        self.noise_steps = noise_steps
        self.start = start
        self.end = end
        self.device = device

        self.schedule = None
        self.set_schedule()

    def _get_schedule(self, schedule_type: Schedule = Schedule.LINEAR):
        """Return the per-step beta tensor for the requested schedule type."""
        if schedule_type == Schedule.LINEAR:
            return torch.linspace(self.start, self.end, self.noise_steps)
        elif schedule_type == Schedule.COSINE:
            # https://arxiv.org/pdf/2102.09672 page 4
            # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
            # line 18
            def get_alphahat_at(t):
                def f(t):
                    s = self.start
                    return torch.cos((t / self.noise_steps + s) / (1 + s) * torch.pi / 2) ** 2
                return f(t) / f(torch.zeros_like(t))

            t = torch.Tensor(range(self.noise_steps))
            t = 1 - (get_alphahat_at(t + 1) / get_alphahat_at(t))
            # "In practice, we clip β_t to be no larger than 0.999 to prevent
            # singularities at the end of the diffusion process"
            t = torch.minimum(t, torch.ones_like(t) * 0.999)
            return t

    def set_schedule(self, schedule: Schedule = Schedule.LINEAR):
        """(Re)build beta and cache alpha / cumulative alpha on the target device.

        Caching here removes the per-query torch.cumprod that get_schedule_at
        previously recomputed on every call.
        """
        self.schedule = self._get_schedule(schedule).to(self.device)
        self._alpha = 1 - self.schedule
        self._alpha_hat = torch.cumprod(self._alpha, dim=0)

    def get_schedule_at(self, step):
        """Return (beta, alpha, alpha_hat) at *step* (int or index tensor).

        Each value is reshaped to (-1, 1, 1, 1) so it broadcasts over an
        image batch.
        """
        beta = self.schedule
        return (
            self._unsqueezify(beta.data[step]),
            self._unsqueezify(self._alpha.data[step]),
            self._unsqueezify(self._alpha_hat.data[step]),
        )

    @staticmethod
    def _unsqueezify(value):
        # Per-step scalars -> (N, 1, 1, 1) for broadcasting over NCHW batches.
        return value.view(-1, 1, 1, 1)

    def noise_image(self, image, step):
        """Closed-form forward process q(x_t | x_0).

        Returns (noised image, the epsilon that was injected) — the pair the
        training loss compares against the model's prediction.
        """
        image = image.to(self.device)
        beta, alpha, alpha_hat = self.get_schedule_at(step)
        epsilon = torch.randn_like(image)
        noised_img = torch.sqrt(alpha_hat) * image + torch.sqrt(1 - alpha_hat) * epsilon
        return noised_img, epsilon

    def random_timesteps(self, amt=1):
        """Sample *amt* uniform timesteps in [1, noise_steps)."""
        return torch.randint(low=1, high=self.noise_steps, size=(amt,))

    def sample(self, img_size, condition, amt=5, use_tqdm=True):
        """Generate *amt* images by running the reverse process from pure noise.

        A single condition row is broadcast to the whole batch. The model is
        put in eval mode for sampling and restored to train mode afterwards.
        """
        if tuple(condition.shape)[0] < amt:
            condition = condition.repeat(amt, 1)

        self.model.eval()
        condition = condition.to(self.device)

        my_trange = lambda x, y, z: trange(x, y, z, leave=False, dynamic_ncols=True)
        fn = my_trange if use_tqdm else range
        with torch.no_grad():
            cur_img = torch.randn((amt, 3, img_size, img_size)).to(self.device)
            for i in fn(self.noise_steps - 1, 0, -1):
                timestep = torch.ones(amt) * (i)
                timestep = timestep.to(self.device)

                predicted_noise = self.model(cur_img, timestep, condition)
                beta, alpha, alpha_hat = self.get_schedule_at(i)

                cur_img = (1 / torch.sqrt(alpha)) * (cur_img - (beta / torch.sqrt(1 - alpha_hat)) * predicted_noise)
                if i > 1:
                    # Re-inject noise at every step except the last.
                    cur_img = cur_img + torch.sqrt(beta) * torch.randn_like(cur_img)

        self.model.train()
        return cur_img

    def sample_multicond(self, img_size, condition, use_tqdm=True):
        """Reverse-process sampling producing exactly one image per condition row."""
        num_conditions = condition.shape[0]
        amt = num_conditions

        self.model.eval()
        condition = condition.to(self.device)

        my_trange = lambda x, y, z: trange(x, y, z, leave=False, dynamic_ncols=True)
        fn = my_trange if use_tqdm else range

        with torch.no_grad():
            cur_img = torch.randn((amt, 3, img_size, img_size)).to(self.device)
            for i in fn(self.noise_steps - 1, 0, -1):
                timestep = torch.ones(amt) * i
                timestep = timestep.to(self.device)

                predicted_noise = self.model(cur_img, timestep, condition)
                beta, alpha, alpha_hat = self.get_schedule_at(i)

                cur_img = (1 / torch.sqrt(alpha)) * (cur_img - (beta / torch.sqrt(1 - alpha_hat)) * predicted_noise)
                if i > 1:
                    cur_img = cur_img + torch.sqrt(beta) * torch.randn_like(cur_img)

        self.model.train()

        # Return images sampled for each condition
        return cur_img

    def training_loop_iteration(self, optimizer, batch, label, criterion):
        """One training step: noise a batch, predict the noise, step the optimizer.

        Returns the scalar loss value. NaN checks print loudly instead of
        raising so a long run is not killed by one bad batch.
        """
        def print_(string):
            # Repeat so the warning is hard to miss in a busy training log.
            for i in range(10):
                print(string)

        batch = batch.to(self.device)

        # label = label.long()  # uncomment for nn.Embedding
        label = label.to(self.device)

        timesteps = self.random_timesteps(batch.shape[0]).to(self.device)

        noisy_batch, real_noise = self.noise_image(batch, timesteps)

        if torch.isnan(noisy_batch).any() or torch.isnan(real_noise).any():
            print_("NaNs detected in the noisy batch or real noise")

        pred_noise = self.model(noisy_batch, timesteps, label)

        if torch.isnan(pred_noise).any():
            print_("NaNs detected in the predicted noise")

        loss = criterion(real_noise, pred_noise)

        if torch.isnan(loss).any():
            print_("NaNs detected in the loss")

        # BUG FIX: gradients were never cleared, so they accumulated across
        # iterations and corrupted every update after the first.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()