JBlitzar commited on
Commit
fc9acd0
·
1 Parent(s): 17789ea
Files changed (4) hide show
  1. factories.py +565 -0
  2. infer.py +43 -0
  3. pipeline.py +364 -0
  4. runner.py +80 -0
factories.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
class EMA:
    """Exponential moving average (EMA) of a model's parameters.

    Maintains a decay-weighted running copy of a live model's weights; in
    diffusion training the EMA model is typically the one used for sampling.
    """

    def __init__(self, beta):
        # beta: decay factor in [0, 1); larger = slower-moving average.
        # (The original called super().__init__() on a plain class — a no-op.)
        self.beta = beta
        self.step = 0  # number of step_ema() calls so far

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ma_model toward current_model, in place."""
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        """Return beta-weighted average of old and new (new when no old yet)."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance one step: hard-copy during warmup, EMA-blend afterwards."""
        if self.step < step_start_ema:
            # Warmup: keep the EMA model pinned to the live model.
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1  # single increment (was duplicated in both branches)

    def reset_parameters(self, ema_model, model):
        """Hard-copy the live model's weights into the EMA model."""
        ema_model.load_state_dict(model.state_dict())
33
+
34
+
35
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Treats a (B, channels, size, size) map as a sequence of size*size tokens,
    applies 4-head attention with pre-LayerNorm, a residual, and a small
    feed-forward block, then restores the spatial layout.
    """

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size  # expected spatial edge length of the input
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, C) token sequence.
        tokens = x.view(-1, self.channels, self.size * self.size).transpose(1, 2)
        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        attended = attended + tokens                    # residual 1
        attended = self.ff_self(attended) + attended    # residual 2
        # Back to (B, C, H, W).
        return attended.transpose(2, 1).view(-1, self.channels, self.size, self.size)
56
+
57
+
58
class CrossAttention(nn.Module):
    """Multi-head cross-attention between a feature map and a context vector.

    The (B, C, H, W) feature map attends to a single context vector that is
    projected into channel space and broadcast to every spatial position.
    """

    def __init__(self, channels, size, context_dim):
        super(CrossAttention, self).__init__()
        self.channels = channels
        self.size = size  # expected spatial edge length of the input
        self.context_dim = context_dim
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm(channels)
        self.context_ln = nn.LayerNorm(channels)
        self.ff_self = nn.Sequential(
            nn.LayerNorm(channels),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )
        # Maps the raw context vector into the channel dimension.
        self.context_proj = nn.Linear(context_dim, channels)

    def forward(self, x, context):
        batch_size, channels, height, width = x.size()

        # (B, C, H, W) -> (B, H*W, C) token sequence.
        tokens = x.view(-1, self.channels, self.size * self.size).transpose(1, 2)
        tokens_ln = self.ln(tokens)

        # Project the context and replicate it across all spatial tokens so
        # query and key/value sequence lengths line up.
        ctx = self.context_proj(context)
        ctx = ctx.unsqueeze(1).expand(-1, tokens_ln.size(1), -1)
        ctx_ln = self.context_ln(ctx)

        # Cross-attention: image tokens query the context, plus residuals.
        attended, _ = self.mha(tokens_ln, ctx_ln, ctx_ln)
        attended = attended + tokens
        attended = self.ff_self(attended) + attended

        # (B, H*W, C) -> (B, C, H, W)
        return attended.permute(0, 2, 1).view(batch_size, channels, height, width)
103
+
104
+
105
+ class DoubleConv(nn.Module):
106
+ def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
107
+ super().__init__()
108
+ self.residual = residual
109
+ if not mid_channels:
110
+ mid_channels = out_channels
111
+ self.double_conv = nn.Sequential(
112
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
113
+ nn.GroupNorm(1, mid_channels),
114
+ nn.GELU(),
115
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
116
+ nn.GroupNorm(1, out_channels),
117
+ )
118
+
119
+ def forward(self, x):
120
+ if self.residual:
121
+ return F.gelu(x + self.double_conv(x))
122
+ else:
123
+ return self.double_conv(x)
124
+
125
+
126
class Down(nn.Module):
    """Encoder stage: 2x max-pool then two DoubleConv blocks, with the
    timestep embedding added per-channel afterwards."""

    def __init__(self, in_channels, out_channels, emb_dim=1024):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        # Projects the time embedding to one scalar per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        features = self.maxpool_conv(x)
        # Broadcast the per-channel embedding over the spatial grid.
        t_emb = self.emb_layer(t)[:, :, None, None]
        t_emb = t_emb.repeat(1, 1, features.shape[-2], features.shape[-1])
        return features + t_emb
147
+
148
+
149
class Up(nn.Module):
    """Decoder stage: bilinear 2x upsample, concat the skip connection, two
    DoubleConv blocks, with the timestep embedding added per-channel."""

    def __init__(self, in_channels, out_channels, emb_dim=1024):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # in_channels counts skip + upsampled channels combined.
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        # Projects the time embedding to one scalar per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        upsampled = self.up(x)
        merged = torch.cat([skip_x, upsampled], dim=1)
        merged = self.conv(merged)
        t_emb = self.emb_layer(t)[:, :, None, None]
        t_emb = t_emb.repeat(1, 1, merged.shape[-2], merged.shape[-1])
        return merged + t_emb
173
+
174
+
175
+
176
class UNet_conditional_large(nn.Module):
    """Large conditional U-Net (base width 128) for 64x64 diffusion.

    The conditioning vector `y` is projected by `label_crossattn_emb` and fed
    to a CrossAttention block after every Down/Up stage; timesteps enter as
    sinusoidal embeddings added inside Down/Up.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=1024, num_classes=1024, context_dim=None, device="mps"):
        super().__init__()

        # Cross-attention context width defaults to the label width.
        if context_dim is None:
            context_dim = num_classes
        self.device = device          # only used by pos_encoding's arange
        self.time_dim = time_dim

        # Base channel width of the first stage.
        start_depth = 128

        # Spatial edge length at full resolution; the //2 ... //16 divisions
        # below hard-code a 64x64 input.
        xa_amt_depth = 64 # dont change

        self.inc = DoubleConv(c_in, start_depth)
        self.down1 = Down(start_depth, start_depth * 2)

        self.xa1 = CrossAttention(start_depth * 2, xa_amt_depth // 2, context_dim)

        self.down2 = Down(start_depth * 2, start_depth * 4)
        self.xa2 = CrossAttention(start_depth * 4, xa_amt_depth // 4, context_dim)

        self.down3 = Down(start_depth * 4, start_depth * 8)
        self.xa3 = CrossAttention(start_depth * 8, xa_amt_depth // 8, context_dim)

        self.down4 = Down(start_depth * 8, start_depth * 8)
        self.xa4 = CrossAttention(start_depth * 8, xa_amt_depth // 16, context_dim)

        # Bottleneck (4x4 resolution for a 64x64 input).
        self.bot1 = DoubleConv(start_depth * 8, start_depth * 16)
        self.bot2 = DoubleConv(start_depth * 16, start_depth * 16)
        self.bot3 = DoubleConv(start_depth * 16, start_depth * 8)

        # Up stages take skip + upsampled channels, hence the doubled in-widths.
        self.up1 = Up(start_depth * 16, start_depth * 4)
        self.xa5 = CrossAttention(start_depth * 4, xa_amt_depth // 8, context_dim)

        self.up2 = Up(start_depth * 8, start_depth * 2)
        self.xa6 = CrossAttention(start_depth * 2, xa_amt_depth // 4, context_dim)

        self.up3 = Up(start_depth * 4, start_depth)
        self.xa7 = CrossAttention(start_depth, xa_amt_depth // 2, context_dim)

        self.up4 = Up(start_depth * 2, start_depth)
        self.xa8 = CrossAttention(start_depth, xa_amt_depth, context_dim)

        self.outc = nn.Conv2d(start_depth, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Linear(num_classes, time_dim)  # unused in forward; kept for checkpoint compatibility
            self.num_classes = num_classes
            # NOTE(review): context_dim was already defaulted above, so this
            # inner check can never fire.
            if context_dim is None:
                context_dim = num_classes

            self.context_dim = context_dim

            # Projects the raw label vector into the cross-attention context space.
            self.label_crossattn_emb = nn.Linear(num_classes, context_dim)

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of timestep t (shape (B, 1)) into `channels` dims."""
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y):
        # Embed the integer timestep.
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            # Keep only the first num_classes features, then project for cross-attn.
            attn_y = y[:,:self.num_classes]
            attn_y = self.label_crossattn_emb(attn_y)
        # NOTE(review): attn_y is never bound when y is None, so every xa* call
        # below raises NameError — confirm callers always pass a conditioning vector.

        x1 = self.inc(x)

        x2 = self.down1(x1, t)
        x2 = self.xa1(x2, attn_y)

        x3 = self.down2(x2, t)
        x3 = self.xa2(x3, attn_y)

        x4 = self.down3(x3, t)
        x4 = self.xa3(x4, attn_y)

        x5 = self.down4(x4, t)
        x5 = self.xa4(x5, attn_y)

        x5 = self.bot1(x5)
        x5 = self.bot2(x5)
        x5 = self.bot3(x5)

        x = self.up1(x5, x4, t)
        x = self.xa5(x,attn_y)

        x = self.up2(x, x3, t)
        x = self.xa6(x,attn_y)

        x = self.up3(x, x2, t)
        x = self.xa7(x, attn_y)

        x = self.up4(x, x1, t)
        x = self.xa8(x, attn_y)

        output = self.outc(x)
        return output
302
+
303
class UNet_conditional_efficient(nn.Module):
    """Variant of UNet_conditional_large that trades the first learned Down
    stage for a plain max-pool (and a plain upsample at the end).

    The first downsample carries no time conditioning and no skip-refinement;
    `up4` and `xa8` are constructed but never used in forward (dead modules
    that still consume parameters/checkpoint keys).
    """

    def __init__(self, c_in=3, c_out=3, time_dim=1024, num_classes=1024, context_dim=None, device="mps"):
        super().__init__()

        # Cross-attention context width defaults to the label width.
        if context_dim is None:
            context_dim = num_classes
        self.device = device          # only used by pos_encoding's arange
        self.time_dim = time_dim

        # Base channel width.
        start_depth = 128

        # Spatial edge length at full resolution; divisions below assume 64x64 input.
        xa_amt_depth = 64 # dont change

        # inc goes straight to 2x width since the first Down stage is replaced
        # by a parameter-free max-pool.
        self.inc = DoubleConv(c_in, start_depth * 2)

        self.downsample = nn.MaxPool2d(2)

        self.down2 = Down(start_depth * 2, start_depth * 4)
        self.xa2 = CrossAttention(start_depth * 4, xa_amt_depth // 4, context_dim)

        self.down3 = Down(start_depth * 4, start_depth * 8)
        self.xa3 = CrossAttention(start_depth * 8, xa_amt_depth // 8, context_dim)

        self.down4 = Down(start_depth * 8, start_depth * 8)
        self.xa4 = CrossAttention(start_depth * 8, xa_amt_depth // 16, context_dim)

        # Bottleneck.
        self.bot1 = DoubleConv(start_depth * 8, start_depth * 16)
        self.bot2 = DoubleConv(start_depth * 16, start_depth * 16)
        self.bot3 = DoubleConv(start_depth * 16, start_depth * 8)

        self.up1 = Up(start_depth * 16, start_depth * 4)
        self.xa5 = CrossAttention(start_depth * 4, xa_amt_depth // 8, context_dim)

        self.up2 = Up(start_depth * 8, start_depth * 2)
        self.xa6 = CrossAttention(start_depth * 2, xa_amt_depth // 4, context_dim)

        self.up3 = Up(start_depth * 4, start_depth)
        self.xa7 = CrossAttention(start_depth, xa_amt_depth // 2, context_dim)

        # NOTE(review): up4/xa8 are never called in forward — dead parameters.
        self.up4 = Up(start_depth * 2, start_depth)
        self.xa8 = CrossAttention(start_depth, xa_amt_depth, context_dim)

        # Parameter-free restore to full resolution before the output conv.
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear")

        self.outc = nn.Conv2d(start_depth, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Linear(num_classes, time_dim)  # unused in forward; kept for checkpoint compatibility
            self.num_classes = num_classes
            # NOTE(review): context_dim was already defaulted above; dead branch.
            if context_dim is None:
                context_dim = num_classes

            self.context_dim = context_dim

            self.label_crossattn_emb = nn.Linear(num_classes, context_dim)

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of timestep t (shape (B, 1)) into `channels` dims."""
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            attn_y = y[:,:self.num_classes]
            attn_y = self.label_crossattn_emb(attn_y)
        # NOTE(review): attn_y is never bound when y is None, so every xa* call
        # below raises NameError — confirm callers always pass a conditioning vector.

        x1 = self.inc(x)

        # First stage: plain pooling, no time embedding, no attention.
        x2 = self.downsample(x1)

        x3 = self.down2(x2, t)
        x3 = self.xa2(x3, attn_y)

        x4 = self.down3(x3, t)
        x4 = self.xa3(x4, attn_y)

        x5 = self.down4(x4, t)
        x5 = self.xa4(x5, attn_y)

        x5 = self.bot1(x5)
        x5 = self.bot2(x5)
        x5 = self.bot3(x5)

        x = self.up1(x5, x4, t)
        x = self.xa5(x,attn_y)

        x = self.up2(x, x3, t)
        x = self.xa6(x,attn_y)

        x = self.up3(x, x2, t)
        x = self.xa7(x, attn_y)

        # Parameter-free return to full resolution (mirrors self.downsample).
        x = self.upsample(x)
        output = self.outc(x)
        return output
435
+
436
class UNet_conditional_start_depth(nn.Module):
    """Three-stage conditional U-Net (base width 128) with cross-attention
    after every Down/Up stage; shallower than UNet_conditional_large."""

    def __init__(self, c_in=3, c_out=3, time_dim=1024, num_classes=None, context_dim=None, device="mps"):
        super().__init__()

        # Cross-attention context width defaults to the label width.
        if context_dim is None:
            context_dim = num_classes
        self.device = device          # only used by pos_encoding's arange
        self.time_dim = time_dim

        # Base channel width; attention sizes below assume a 64x64 input.
        start_depth = 128
        xa_amt_depth = 64

        self.inc = DoubleConv(c_in, start_depth)

        self.down1 = Down(start_depth, start_depth * 2)
        self.xa1 = CrossAttention(start_depth * 2, xa_amt_depth // 2, context_dim)

        self.down2 = Down(start_depth * 2, start_depth * 4)
        self.xa2 = CrossAttention(start_depth * 4, xa_amt_depth // 4, context_dim)

        self.down3 = Down(start_depth * 4, start_depth * 4)
        self.xa3 = CrossAttention(start_depth * 4, xa_amt_depth // 8, context_dim)

        # Bottleneck (8x8 resolution for a 64x64 input).
        self.bot1 = DoubleConv(start_depth * 4, start_depth * 8)
        self.bot2 = DoubleConv(start_depth * 8, start_depth * 8)
        self.bot3 = DoubleConv(start_depth * 8, start_depth * 4)

        # Up stages take skip + upsampled channels, hence the doubled in-widths.
        self.up1 = Up(start_depth * 8, start_depth * 2)
        self.xa4 = CrossAttention(start_depth * 2, xa_amt_depth // 4, context_dim)

        self.up2 = Up(start_depth * 4, start_depth)
        self.xa5 = CrossAttention(start_depth, xa_amt_depth // 2, context_dim)

        self.up3 = Up(start_depth * 2, start_depth)
        self.xa6 = CrossAttention(start_depth, xa_amt_depth, context_dim)

        self.outc = nn.Conv2d(start_depth, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Linear(num_classes, time_dim)  # unused in forward; kept for checkpoint compatibility
            self.num_classes = num_classes
            # NOTE(review): context_dim was already defaulted above; dead branch.
            if context_dim is None:
                context_dim = num_classes

            self.context_dim = context_dim

            self.label_crossattn_emb = nn.Linear(num_classes, context_dim)
        # NOTE(review): num_classes defaults to None here; in that case
        # self.num_classes is never set and forward crashes on the y-slice —
        # confirm callers always pass num_classes.

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of timestep t (shape (B, 1)) into `channels` dims."""
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            attn_y = y[:,:self.num_classes]
            attn_y = self.label_crossattn_emb(attn_y)
        # NOTE(review): attn_y is never bound when y is None, so every xa* call
        # below raises NameError — confirm callers always pass a conditioning vector.

        x1 = self.inc(x)

        x2 = self.down1(x1, t)
        x2 = self.xa1(x2, attn_y)

        x3 = self.down2(x2, t)
        x3 = self.xa2(x3, attn_y)

        x4 = self.down3(x3, t)
        x4 = self.xa3(x4, attn_y)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.xa4(x,attn_y)

        x = self.up2(x, x2, t)
        x = self.xa5(x, attn_y)

        x = self.up3(x, x1, t)
        x = self.xa6(x, attn_y)
        output = self.outc(x)

        return output
550
+
551
+
552
if __name__ == "__main__":
    # Smoke test: build the model, report its size, and run one forward pass.
    # Fall back to CPU when MPS is unavailable (the original hard-coded "mps"
    # and crashed on non-Apple hardware); pass the device into the model too,
    # since pos_encoding allocates on self.device.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    net = UNet_conditional_start_depth(num_classes=1024, device=device).to(device)

    def count_parameters(model):
        """Total number of trainable parameters."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Parameters: {count_parameters(net)}")

    minibatch = torch.randn((1, 3, 64, 64)).to(device)

    o = net(
        minibatch,
        torch.randint(low=1, high=1000, size=(1,)).to(device),
        torch.randn((1, 1024)).to(device),
    )

    print(o.size())
564
+
565
+
infer.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
import os
import re
import torch
from bert_vectorize import vectorize_text_with_bert
import time
import torchvision
from logger import save_grid_with_label

# NOTE(review): the factories.py committed alongside this file defines
# UNet_conditional_large / _efficient / _start_depth but no UNet_conditional —
# confirm this import target actually exists.

EXPERIMENT_DIRECTORY = "runs/run_3_jxa"
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Create the output directory, tolerating only "already exists". The original
# bare `except:` also swallowed unrelated errors (permissions, missing parent).
os.makedirs(os.path.join(EXPERIMENT_DIRECTORY, "inferred"), exist_ok=True)

net = UNet_conditional(num_classes=768)  # 768 = BERT hidden size
net.to(device)
net.load_state_dict(torch.load(os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt")))


wrapper = DiffusionManager(net, device=device, noise_steps=1000)
wrapper.set_schedule(Schedule.LINEAR)


def generate_sample_save_images(prompt, amt=1):
    """Sample `amt` 64x64 images conditioned on `prompt` and save a labeled grid."""
    # Sanitize the prompt into a filesystem-safe, timestamped file name.
    safe_name = re.sub(r'[^a-zA-Z\s]', '', prompt).replace(" ", "_")
    path = os.path.join(
        EXPERIMENT_DIRECTORY, "inferred", safe_name + str(int(time.time())) + ".png"
    )

    # BERT-embed the prompt and add a batch dimension.
    vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

    generated = wrapper.sample(64, vprompt, amt=amt).detach().cpu()

    save_grid_with_label(torchvision.utils.make_grid(generated), prompt, path)


if __name__ == "__main__":
    generate_sample_save_images(input("Prompt? "), 8)
pipeline.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pipeline.py
2
+ import torch
3
+ from transformers import Pipeline
4
+
5
+
6
+
7
class TextToImagePipeline(Pipeline):
    """Minimal text-to-image wrapper around a transformers `Pipeline`.

    NOTE(review): `transformers.Pipeline` subclasses normally implement
    `_sanitize_parameters` / `preprocess` / `_forward` / `postprocess`;
    overriding `__call__` directly bypasses that machinery — confirm this is
    intentional and that `Pipeline.__init__` accepts only model/tokenizer here.
    """

    def __init__(self, model, tokenizer):
        super().__init__(model=model, tokenizer=tokenizer)

    def __call__(self, inputs):
        # Tokenize the prompt(s) into PyTorch tensors.
        text_inputs = self.tokenizer(inputs, return_tensors="pt")

        with torch.no_grad():
            # Assumes the model maps token ids directly to an image tensor —
            # TODO confirm against the actual model's forward signature.
            image = self.model(text_inputs['input_ids'])

        # Return as a NumPy array for downstream consumers.
        image = image.cpu().numpy()

        return image
22
+
23
+
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+
29
+
30
class EMA:
    """Maintains an exponential moving average copy of a model's weights,
    typically used to produce a smoothed model for diffusion sampling."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta   # decay factor; closer to 1 = slower average
        self.step = 0      # count of step_ema() calls

    def update_model_average(self, ma_model, current_model):
        """In-place blend of every ma_model parameter toward current_model."""
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for live_p, avg_p in pairs:
            avg_p.data = self.update_average(avg_p.data, live_p.data)

    def update_average(self, old, new):
        """beta-weighted average; returns new unchanged when old is None."""
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """One EMA step: hard-copy while warming up, then blend."""
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Copy the live model's state verbatim into the EMA model."""
        ema_model.load_state_dict(model.state_dict())
56
+
57
+
58
class SelfAttention(nn.Module):
    """Multi-head self-attention over spatial positions of a (B, C, size, size)
    feature map. Duplicate of the definition in factories.py — kept
    byte-identical here so checkpoints load the same modules."""

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size  # expected spatial edge length of the input
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, C) token sequence.
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x                       # residual 1
        attention_value = self.ff_self(attention_value) + attention_value  # residual 2
        # Back to (B, C, H, W).
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
79
+
80
+
81
class CrossAttention(nn.Module):
    """Cross-attention between a (B, C, size, size) feature map and a single
    context vector. Duplicate of the definition in factories.py — kept
    byte-identical here so checkpoints load the same modules."""

    def __init__(self, channels, size, context_dim):
        super(CrossAttention, self).__init__()
        self.channels = channels
        self.size = size  # expected spatial edge length of the input
        self.context_dim = context_dim
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm(channels)
        self.context_ln = nn.LayerNorm(channels)
        self.ff_self = nn.Sequential(
            nn.LayerNorm(channels),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

        # Maps the raw context vector into channel space.
        self.context_proj = nn.Linear(context_dim, channels)

    def forward(self, x, context):
        # Reshape and permute x for multi-head attention
        batch_size, channels, height, width = x.size()
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1,2)
        x_ln = self.ln(x)

        # Expand context to match the sequence length of x
        context = self.context_proj(context)

        context = context.unsqueeze(1).expand(-1, x_ln.size(1), -1)

        context_ln = self.context_ln(context)

        # Apply cross-attention
        attention_value, _ = self.mha(x_ln, context_ln, context_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value

        # Reshape and permute back to the original format
        return attention_value.permute(0, 2, 1).view(batch_size, channels, height, width)
125
+
126
+
127
class DoubleConv(nn.Module):
    """Two 3x3 conv + GroupNorm blocks with GELU in between; optionally a
    GELU-activated residual (requires in_channels == out_channels).
    Duplicate of the factories.py definition — kept byte-identical."""

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        # Hidden width defaults to the output width.
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        if self.residual:
            return F.gelu(x + self.double_conv(x))
        else:
            return self.double_conv(x)
146
+
147
+
148
class Down(nn.Module):
    """Encoder stage: 2x max-pool + two DoubleConv blocks, with an additive
    per-channel time embedding. Note: emb_dim defaults to 256 here, versus
    1024 in the factories.py copy — the two files diverge on this default."""

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        # Projects the time embedding to one scalar per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_dim,
                out_channels
            ),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        # Broadcast the per-channel embedding over the spatial grid.
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb
169
+
170
+
171
class Up(nn.Module):
    """Decoder stage: bilinear 2x upsample, concat skip connection, two
    DoubleConv blocks, plus an additive per-channel time embedding.
    Note: emb_dim defaults to 256 here, versus 1024 in factories.py."""

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # in_channels counts skip + upsampled channels combined.
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        # Projects the time embedding to one scalar per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_dim,
                out_channels
            ),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb
195
+
196
+
197
class Dome_UNet(nn.Module):
    """Unconditional U-Net for 64x64 diffusion with self-attention at the
    32/16/8-pixel resolutions (the SelfAttention sizes hard-code 64x64 input)."""

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="mps"):
        super().__init__()
        self.device = device      # only used by pos_encoding's arange
        self.time_dim = time_dim  # width of the sinusoidal timestep embedding
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, 32)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, 8)

        # Bottleneck at 8x8.
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # Up stages take skip + upsampled channels, hence the doubled in-widths.
        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of timestep t (shape (B, 1)) into `channels` dims."""
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t):
        # Embed the integer timestep.
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        # Encoder with self-attention at each scale.
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        # Bottleneck.
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        # Decoder with skip connections.
        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output
256
+
257
+
258
class UNet_conditional(nn.Module):
    """Conditional U-Net for 64x64 diffusion: Dome_UNet topology plus a
    CrossAttention block after every Down/Up stage.

    `y` is a conditioning vector (e.g. a BERT sentence embedding); its first
    `num_classes` features are projected by `label_crossattn_emb` and attended
    to at every scale.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, context_dim=None, device="mps"):
        super().__init__()

        # Cross-attention context width defaults to the label width.
        if context_dim is None:
            context_dim = num_classes
        self.device = device          # only used by pos_encoding's arange
        self.time_dim = time_dim

        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        # NOTE(review): sa1..sa5 are constructed but never called in forward
        # (only sa6 is); they are kept so checkpoint keys remain unchanged.
        self.sa1 = SelfAttention(128, 32)
        self.xa1 = CrossAttention(128, 32, context_dim)
        self.down2 = Down(128, 256)
        self.xa2 = CrossAttention(256, 16, context_dim)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256)
        self.xa3 = CrossAttention(256, 8, context_dim)
        self.sa3 = SelfAttention(256, 8)

        # Bottleneck at 8x8.
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # Up stages take skip + upsampled channels, hence the doubled in-widths.
        self.up1 = Up(512, 128)
        self.xa4 = CrossAttention(128, 16, context_dim)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64)
        self.xa5 = CrossAttention(64, 32, context_dim)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64)
        self.xa6 = CrossAttention(64, 64, context_dim)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Linear(num_classes, time_dim)  # unused in forward; kept for checkpoint compatibility
            self.num_classes = num_classes
            # context_dim was already defaulted above; this branch is dead but harmless.
            if context_dim is None:
                context_dim = num_classes

            self.context_dim = context_dim

            # Projects the raw label vector into the cross-attention context space.
            self.label_crossattn_emb = nn.Linear(num_classes, context_dim)

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of timestep t (shape (B, 1)) into `channels` dims."""
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            # Keep only the first num_classes features, then project for cross-attn.
            attn_y = y[:, :self.num_classes]
            attn_y = self.label_crossattn_emb(attn_y)
        # NOTE(review): attn_y is unbound when y is None, so the xa* calls below
        # raise NameError — confirm callers always pass a conditioning vector.

        x1 = self.inc(x)

        x2 = self.down1(x1, t)
        x2 = self.xa1(x2, attn_y)

        x3 = self.down2(x2, t)
        x3 = self.xa2(x3, attn_y)

        x4 = self.down3(x3, t)
        x4 = self.xa3(x4, attn_y)

        # BUG FIX: the original called only bot1, leaving x4 at 512 channels
        # where up1's convs expect 256 before the skip concat (768 != 512 ->
        # runtime crash). bot2/bot3 were already constructed — run the full
        # bottleneck, matching Dome_UNet.
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.xa4(x, attn_y)

        x = self.up2(x, x2, t)
        x = self.xa5(x, attn_y)

        x = self.up3(x, x1, t)
        x = self.xa6(x, attn_y)
        x = self.sa6(x)

        output = self.outc(x)

        return output
364
+
runner.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
import os
import re
import torch
from bert_vectorize import vectorize_text_with_bert, cleanup
import time
import torchvision
from logger import save_grid_with_label
from clip_score import select_top_n_images
from torchinfo import summary

# NOTE(review): the factories.py committed alongside this file defines
# UNet_conditional_large / _efficient / _start_depth but no UNet_conditional —
# confirm this import target actually exists.

EXPERIMENT_DIRECTORY = "runs/run_3_jxa_resumed"
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Create the output directory, tolerating only "already exists". The original
# bare `except:` also swallowed unrelated errors (permissions, missing parent).
os.makedirs(os.path.join(EXPERIMENT_DIRECTORY, "inferred"), exist_ok=True)

net = UNet_conditional(num_classes=768)  # 768 = BERT hidden size
net.to(device)
net.load_state_dict(torch.load(os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt")))


def count_parameters(model):
    """Total number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f"Parameters: {count_parameters(net)}")


wrapper = DiffusionManager(net, device=device, noise_steps=1000)
wrapper.set_schedule(Schedule.LINEAR)


def infer(prompt, amt=1, topn=8):
    """Sample `amt` images for `prompt`, keep the CLIP-best `topn`, save a grid."""
    safe_name = re.sub(r'[^a-zA-Z\s]', '', prompt).replace(" ", "_")
    path = os.path.join(
        EXPERIMENT_DIRECTORY, "inferred", safe_name + str(int(time.time())) + ".png"
    )

    vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

    generated = wrapper.sample(64, vprompt, amt=amt).detach().cpu()

    generated, _ = select_top_n_images(generated, prompt, n=topn)

    save_grid_with_label(torchvision.utils.make_grid(generated), prompt + f"({topn} best of {amt})", path)


def run_jobs():
    """Drain inference_jobs.txt, running every prompt not yet processed."""
    topn = 8   # images kept per prompt
    pool = 32  # images generated per prompt
    # BUG FIX: the original called infer(task, n, bestof) with n=8, bestof=32,
    # i.e. it asked for the 32 "best" out of only 8 generated images. Generate
    # the large pool and keep the small top-n instead.
    print(f"using best {topn} of {pool}")
    processed_tasks = set()

    def read_jobs():
        """Return the current job list, one prompt per line; [] if absent."""
        try:
            with open("inference_jobs.txt", 'r') as file:
                tasks = file.readlines()
                return [task.strip() for task in tasks]
        except FileNotFoundError:
            return []

    tasks = read_jobs()
    new_tasks = [task for task in tasks if task not in processed_tasks]
    while new_tasks:
        for task in new_tasks:
            infer(task, pool, topn)
            processed_tasks.add(task)
        # Re-read the file in case jobs were appended while we worked.
        tasks = read_jobs()
        new_tasks = [task for task in tasks if task not in processed_tasks]

    cleanup()


if __name__ == "__main__":
    # infer(input("Prompt? "), 8)
    run_jobs()