Xsmos committed
Commit 84a4faa · verified · 1 Parent(s): 713c506
Files changed (3)
  1. context_unet.py +543 -0
  2. diffusion.ipynb +623 -711
  3. load_h5.py +101 -12
context_unet.py ADDED
@@ -0,0 +1,543 @@
1
+ # from dataclasses import dataclass
2
+ # import h5py
3
+ import torch
4
+ import torch.nn as nn
5
+ # from torch.utils.data import DataLoader, Dataset
6
+ # from datasets import Dataset
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import random
10
+ from abc import ABC, abstractmethod
11
+ import torch.nn.functional as F
12
+ import math
13
+ # from PIL import Image
14
+ import os
15
+ # from torch.utils.tensorboard import SummaryWriter
16
+ import copy
17
+ # from tqdm.auto import tqdm
18
+ # from torchvision import transforms
19
+ # from diffusers import UNet2DModel#, UNet3DConditionModel
20
+ # from diffusers import DDPMScheduler
21
+ # from diffusers.utils import make_image_grid
22
+ import datetime
23
+ # from pathlib import Path
24
+ # from diffusers.optimization import get_cosine_schedule_with_warmup
25
+ # from accelerate import notebook_launcher, Accelerator
26
+ # from huggingface_hub import create_repo, upload_folder
27
+ # from load_h5 import Dataset4h5
28
+
29
+ class GroupNorm32(nn.GroupNorm):
30
+ def __init__(self, num_groups, num_channels, swish, eps=1e-5):
31
+ super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
32
+ self.swish = swish
33
+
34
+ def forward(self, x):
35
+ y = super().forward(x.float()).to(x.dtype)
36
+ if self.swish == 1.0:
37
+ y = F.silu(y)
38
+ elif self.swish:
39
+ y = y * F.sigmoid(y * float(self.swish))
40
+ return y
41
+
42
+ def normalization(channels, swish=0.0):
43
+ """
44
+ Make a standard normalization layer, with an optional swish activation.
45
+
46
+ :param channels: number of input channels.
47
+ :return: an nn.Module for normalization.
48
+ """
49
+ #print (channels)
50
+ return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
51
+
52
+ Conv = {
53
+ 1: nn.Conv1d,
54
+ 2: nn.Conv2d,
55
+ 3: nn.Conv3d,
56
+ }
57
+
58
+ AvgPool = {
59
+ 1: nn.AvgPool1d,
60
+ 2: nn.AvgPool2d,
61
+ 3: nn.AvgPool3d
62
+ }
63
+
64
+ class Downsample(nn.Module):
65
+ def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):
66
+ super().__init__()
67
+ self.channels = channels
68
+ self.out_channels = out_channels or channels
69
+ # stride = config.stride
70
+ if use_conv:
71
+ # print("conv")
72
+ self.op = Conv[dim](channels, self.out_channels, 3, stride=stride, padding=1)
73
+ else:
74
+ # print("pool")
75
+ assert channels == self.out_channels
76
+ self.op = AvgPool[dim](kernel_size=stride, stride=stride)
77
+
78
+ def forward(self, x):
79
+ assert x.shape[1] == self.channels
80
+ return self.op(x)
81
+
82
+ class Upsample(nn.Module):
83
+ def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):
84
+ super().__init__()
85
+ self.channels = channels
86
+ self.out_channels = out_channels or channels
87
+ self.use_conv = use_conv
88
+ self.stride = stride
89
+ if self.use_conv:
90
+ self.conv = Conv[dim](self.channels, self.out_channels, 3, padding=1)
91
+
92
+ def forward(self, x):
93
+ assert x.shape[1] == self.channels
94
+ # stride = config.stride
95
+ # print(torch.tensor(x.shape[2:]))
96
+ # print(torch.tensor(stride))
97
+ shape = torch.tensor(x.shape[2:]) * torch.tensor(self.stride)
98
+ shape = tuple(shape.detach().numpy())
99
+ # print(shape)
100
+ x = F.interpolate(x, shape, mode='nearest')
101
+ if self.use_conv:
102
+ x = self.conv(x)
103
+ return x
104
+
105
+ def zero_module(module):
106
+ """
107
+ Zero out the parameters of a module and return it (used to initialize output layers).
108
+ """
109
+ for p in module.parameters():
110
+ p.detach().zero_()
111
+ return module
112
+
113
+ class TimestepBlock(ABC, nn.Module):
114
+ @abstractmethod
115
+ def forward(self, x, emb):
116
+ """
117
+ Apply the module to x, conditioned on the timestep embedding emb.
118
+ """
119
+
120
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
121
+ def forward(self, x, emb, encoder_out=None):
122
+ for layer in self:
123
+ if isinstance(layer, TimestepBlock):
124
+ x = layer(x, emb)
125
+ elif isinstance(layer, AttentionBlock):
126
+ x = layer(x, encoder_out)
127
+ else:
128
+ x = layer(x)
129
+ return x
130
+
131
+ class ResBlock(TimestepBlock):
132
+ def __init__(
133
+ self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_checkpoint=False, use_scale_shift_norm=False, up=False, down=False, dim=2, stride=(2,2),
134
+ ):
135
+ super().__init__()
136
+ self.out_channels = out_channels or channels
137
+ self.use_scale_shift_norm = use_scale_shift_norm
138
+ self.stride = stride
139
+
140
+ self.in_layers = nn.Sequential(
141
+ # nn.BatchNorm2d(channels), # normalize to standard gaussian
142
+ normalization(channels, swish=1.0),
143
+ nn.Identity(),
144
+ Conv[dim](channels, self.out_channels, 3, padding=1),
145
+ )
146
+
147
+ self.updown = up or down
148
+ if up:
149
+ self.h_updown = Upsample(channels, False, dim=dim, stride=stride)
150
+ self.x_updown = Upsample(channels, False, dim=dim, stride=stride)
151
+ elif down:
152
+ self.h_updown = Downsample(channels, False, dim=dim, stride=stride)
153
+ self.x_updown = Downsample(channels, False, dim=dim, stride=stride)
154
+ else:
155
+ self.h_updown = self.x_updown = nn.Identity()
156
+
157
+ self.emb_layers = nn.Sequential(
158
+ nn.SiLU(),
159
+ nn.Linear(
160
+ emb_channels,
161
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
162
+ ),
163
+ )
164
+
165
+ self.out_layers = nn.Sequential(
166
+ # nn.BatchNorm2d(self.out_channels),
167
+ normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
168
+ nn.SiLU() if use_scale_shift_norm else nn.Identity(),
169
+ nn.Dropout(p=dropout),
170
+ zero_module(Conv[dim](self.out_channels, self.out_channels, 3, padding=1)),
171
+ )
172
+
173
+ if self.out_channels == channels:
174
+ self.skip_connection = nn.Identity()
175
+ elif use_conv:
176
+ self.skip_connection = Conv[dim](channels, self.out_channels, 3, padding=1)
177
+ else:
178
+ self.skip_connection = Conv[dim](channels, self.out_channels, 1)
179
+
180
+
181
+ def forward(self, x, emb):
182
+ if self.updown:
183
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
184
+ h = in_rest(x)
185
+ h = self.h_updown(h)
186
+ x = self.x_updown(x)
187
+ h = in_conv(h)
188
+ else:
189
+ h = self.in_layers(x)
190
+ emb_out = self.emb_layers(emb).type(h.dtype)
191
+
192
+ while len(emb_out.shape) < len(h.shape):
193
+ emb_out = emb_out[..., None]
194
+
195
+ if self.use_scale_shift_norm:
196
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
197
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
198
+ h = out_norm(h) * (1+scale) + shift
199
+ h = out_rest(h)
200
+ else:
201
+ h += emb_out
202
+ h = self.out_layers(h)
203
+ # print("ResBlock, torch.unique(h).shape =", torch.unique(h).shape)
204
+ return self.skip_connection(x) + h
205
+
206
+ class QKVAttention(nn.Module):
207
+ def __init__(self, n_heads):
208
+ super().__init__()
209
+ self.n_heads = n_heads
210
+ # print("QKVAttention, self.n_heads =", self.n_heads)
211
+
212
+ def forward(self, qkv, encoder_kv=None):
213
+ bs, width, length = qkv.shape
214
+ assert width % (3*self.n_heads) == 0
215
+ ch = width // (3*self.n_heads)
216
+
217
+ # print("QKVAttention", bs, self.n_heads, ch, length)
218
+ q, k, v = qkv.reshape(bs*self.n_heads, ch*3, length).split(ch, dim=1)
219
+ if encoder_kv is not None:
220
+ assert encoder_kv.shape[1] == self.n_heads * ch * 2
221
+ ek, ev = encoder_kv.reshape(bs*self.n_heads, ch*2, -1).split(ch, dim=1)
222
+ k = torch.cat([ek,k], dim=-1)
223
+ v = torch.cat([ev,v], dim=-1)
224
+
225
+ scale = 1 / math.sqrt(math.sqrt(ch))
226
+ weight = torch.einsum("bct,bcs->bts", q*scale, k*scale)
227
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
228
+
229
+ a = torch.einsum("bts,bcs->bct", weight, v)
230
+ return a.reshape(bs, -1, length)
231
+
232
+ class AttentionBlock(nn.Module):
233
+ def __init__(
234
+ self,
235
+ channels,
236
+ num_heads=1,
237
+ num_head_channels=-1,
238
+ use_checkpoint=False,
239
+ encoder_channels=None,
240
+ ):
241
+ super().__init__()
242
+ self.channels = channels
243
+ if num_head_channels == -1:
244
+ self.num_heads = num_heads
245
+ else:
246
+ assert channels % num_head_channels == 0,\
247
+ f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
248
+ self.num_heads = channels // num_head_channels
249
+
250
+ self.use_checkpoint = use_checkpoint
251
+ # self.norm = nn.BatchNorm2d(channels)
252
+ self.norm = normalization(channels, swish=0.0)
253
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
254
+
255
+ self.attention = QKVAttention(self.num_heads)
256
+
257
+ if encoder_channels is not None:
258
+ self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
259
+ self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
260
+
261
+ def forward(self, x, encoder_out=None):
262
+ b, c, *spatial = x.shape
263
+ qkv = self.qkv(self.norm(x).view(b, c, -1))
264
+ if encoder_out is not None:
265
+ encoder_out = self.encoder_kv(encoder_out)
266
+ h = self.attention(qkv, encoder_out)
267
+ else:
268
+ h = self.attention(qkv)
269
+ # print("AttentionBlock, before proj_out, torch.unique(h).shape =", torch.unique(h).shape)
270
+ h = self.proj_out(h)
271
+ # print("AttentionBlock, after proj_out, torch.unique(h).shape =", torch.unique(h).shape)
272
+ return x + h.reshape(b, c, *spatial)
273
+
274
+ def timestep_embedding(timesteps, dim, max_period=10000):
275
+ """
276
+ Create sinusoidal timestep embeddings.
277
+
278
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
279
+ These may be fractional.
280
+ :param dim: the dimension of the output.
281
+ :param max_period: controls the minimum frequency of the embeddings.
282
+ :return: an [N x dim] Tensor of positional embeddings.
283
+ """
284
+ #print (timesteps.shape)
285
+ half = dim // 2
286
+ freqs = torch.exp(
287
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
288
+ ).to(device=timesteps.device)
289
+ #print (timesteps[:, None].float().shape,freqs[None].shape)
290
+ args = timesteps[:, None].float() * freqs[None]
291
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
292
+ if dim % 2:
293
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
294
+ return embedding
295
+
296
+ class ContextUnet(nn.Module):
297
+ def __init__(
298
+ self,
299
+ n_param=2,
300
+ image_size=64,
301
+ in_channels=1,
302
+ model_channels=128,
303
+ out_channels = 1,
304
+ channel_mult = None,
305
+ num_res_blocks = 2,
306
+ dropout = 0,
307
+ use_checkpoint = False,
308
+ use_scale_shift_norm = False,
309
+ attention_resolutions = (16, 8),
310
+ num_heads = 4,
311
+ num_head_channels = -1,
312
+ num_heads_upsample = -1,
313
+ resblock_updown = False,
314
+ conv_resample = True,
315
+ encoder_channels = None,
316
+ dim = 2,
317
+ stride = (2,2)
318
+ ):
319
+ super().__init__()
320
+
321
+ if channel_mult is None:
322
+ if image_size == 512:
323
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
324
+ elif image_size == 256:
325
+ channel_mult = (1, 1, 2, 2, 4, 4)
326
+ elif image_size == 128:
327
+ channel_mult = (1, 1, 2, 3, 4)
328
+ elif image_size == 64:
329
+ channel_mult = (1, 1, 2, 2, 4, 4)#(1, 2, 3, 4)
330
+ elif image_size == 28:
331
+ channel_mult = (1, 2)#(1, 2, 3, 4)
332
+ else:
333
+ raise ValueError(f"unsupported image size: {image_size}")
334
+ # else:
335
+ # channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
336
+
337
+ attention_ds = []
338
+ for res in attention_resolutions:
339
+ attention_ds.append(image_size // int(res))
340
+
341
+ # print("before, ContextUnet, num_heads_upsample =", num_heads_upsample, "num_heads =", num_heads)
342
+ if num_heads_upsample == -1:
343
+ num_heads_upsample = num_heads
344
+ # print("after, ContextUnet, num_heads_upsample =", num_heads_upsample, "num_heads =", num_heads)
345
+
346
+ # self.n_param = n_param
347
+ self.model_channels = model_channels
348
+ self.dtype = torch.float32
349
+
350
+ self.token_embedding = nn.Linear(n_param, model_channels * 4)
351
+
352
+ time_embed_dim = model_channels * 4
353
+ self.time_embed = nn.Sequential(
354
+ nn.Linear(model_channels, time_embed_dim),
355
+ nn.SiLU(),
356
+ nn.Linear(time_embed_dim, time_embed_dim),
357
+ )
358
+
359
+ ch = input_ch = int(channel_mult[0] * model_channels)
360
+
361
+ ###################### input_blocks ######################
362
+ self.input_blocks = nn.ModuleList(
363
+ [TimestepEmbedSequential(Conv[dim](in_channels, ch, 3, padding=1))]
364
+ )
365
+ self._feature_size = ch
366
+ input_block_chans = [ch]
367
+ ds = 1
368
+
369
+ for level, mult in enumerate(channel_mult):
370
+ for _ in range(num_res_blocks):
371
+ layers = [
372
+ ResBlock(
373
+ ch,
374
+ time_embed_dim,
375
+ dropout,
376
+ out_channels = int(mult * model_channels),
377
+ use_checkpoint = use_checkpoint,
378
+ use_scale_shift_norm = use_scale_shift_norm,
379
+ dim = dim,
380
+ stride = stride,
381
+ )
382
+ ]
383
+ ch = int(mult * model_channels)
384
+ if ds in attention_ds:
385
+ layers.append(
386
+ AttentionBlock(
387
+ ch,
388
+ use_checkpoint=use_checkpoint,
389
+ num_heads = num_heads,
390
+ num_head_channels = num_head_channels,
391
+ encoder_channels = encoder_channels,
392
+ )
393
+ )
394
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
395
+ self._feature_size += ch
396
+ input_block_chans.append(ch)
397
+
398
+ if level != len(channel_mult) - 1:
399
+ out_ch = ch
400
+ self.input_blocks.append(
401
+ TimestepEmbedSequential(
402
+ ResBlock(
403
+ ch,
404
+ time_embed_dim,
405
+ dropout,
406
+ out_channels=out_ch,
407
+ # dims=dims,
408
+ use_checkpoint=use_checkpoint,
409
+ use_scale_shift_norm=use_scale_shift_norm,
410
+ down=True,
411
+ dim = dim,
412
+ stride = stride,
413
+ )
414
+ if resblock_updown
415
+ else Downsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)
416
+ )
417
+ )
418
+ ch = out_ch
419
+ input_block_chans.append(ch)
420
+ ds *= 2
421
+ self._feature_size += ch
422
+
423
+
424
+ ###################### middle_blocks ######################
425
+ self.middle_block = TimestepEmbedSequential(
426
+ ResBlock(
427
+ ch,
428
+ time_embed_dim,
429
+ dropout,
430
+ use_checkpoint=use_checkpoint,
431
+ use_scale_shift_norm=use_scale_shift_norm,
432
+ dim = dim,
433
+ stride = stride,
434
+ ),
435
+ AttentionBlock(
436
+ ch,
437
+ use_checkpoint=use_checkpoint,
438
+ num_heads=num_heads,
439
+ num_head_channels=num_head_channels,
440
+ encoder_channels=encoder_channels,
441
+ ),
442
+ ResBlock(
443
+ ch,
444
+ time_embed_dim,
445
+ dropout,
446
+ use_checkpoint=use_checkpoint,
447
+ use_scale_shift_norm=use_scale_shift_norm,
448
+ dim = dim,
449
+ stride = stride,
450
+ ),
451
+ )
452
+ self._feature_size += ch
453
+
454
+
455
+ ###################### output_blocks ######################
456
+ self.output_blocks = nn.ModuleList([])
457
+ for level, mult in list(enumerate(channel_mult))[::-1]:
458
+ for i in range(num_res_blocks + 1):
459
+ ich = input_block_chans.pop()
460
+ layers = [
461
+ ResBlock(
462
+ ch + ich,
463
+ time_embed_dim,
464
+ dropout,
465
+ out_channels=int(model_channels * mult),
466
+ # dims=dims,
467
+ use_checkpoint=use_checkpoint,
468
+ use_scale_shift_norm=use_scale_shift_norm,
469
+ dim = dim,
470
+ stride = stride,
471
+ )
472
+ ]
473
+ ch = int(model_channels * mult)
474
+ if ds in attention_ds:
475
+ # print("ds in attention_resolutions, num_heads=", num_heads_upsample)
476
+ layers.append(
477
+ AttentionBlock(
478
+ ch,
479
+ use_checkpoint=use_checkpoint,
480
+ num_heads=num_heads_upsample,
481
+ num_head_channels=num_head_channels,
482
+ encoder_channels=encoder_channels,
483
+ )
484
+ )
485
+ if level and i == num_res_blocks:
486
+ out_ch = ch
487
+ layers.append(
488
+ ResBlock(
489
+ ch,
490
+ time_embed_dim,
491
+ dropout,
492
+ out_channels=out_ch,
493
+ # dims=dims,
494
+ use_checkpoint=use_checkpoint,
495
+ use_scale_shift_norm=use_scale_shift_norm,
496
+ up=True,
497
+ dim = dim,
498
+ stride = stride,
499
+ )
500
+ if resblock_updown
501
+ else Upsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)
502
+ )
503
+ ds //= 2
504
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
505
+ self._feature_size += ch
506
+
507
+ self.out = nn.Sequential(
508
+ # nn.BatchNorm2d(ch),
509
+ normalization(ch, swish=1.0),
510
+ nn.Identity(),
511
+ zero_module(Conv[dim](input_ch, out_channels, 3, padding=1)),
512
+ )
513
+ # self.use_fp16 = use_fp16
514
+
515
+ def forward(self, x, timesteps, y=None):
516
+ hs = []
517
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
518
+ if y is not None:
519
+ text_outputs = self.token_embedding(y.float())
520
+ emb = emb + text_outputs.to(emb)
521
+
522
+ h = x.type(self.dtype)
523
+ # print("0,h.shape =", h.shape)
524
+ for module in self.input_blocks:
525
+ h = module(h, emb)
526
+ hs.append(h)
527
+ # print("module encoder, h.shape =", h.shape)
528
+ # print("2,h.shape =", h.shape)
529
+ h = self.middle_block(h, emb)
530
+ # print("middle block, h.shape =", h.shape)
531
+ # print("2,h.shape =", h.shape)
532
+ for module in self.output_blocks:
533
+ # print("for module in self.output_blocks, h.shape =", h.shape)
534
+ # print("len(hs) =", len(hs), ", hs[-1].shape =", hs[-1].shape)
535
+ h = torch.cat([h, hs.pop()], dim=1)
536
+ h = module(h, emb)
537
+ # print("module decoder, h.shape =", h.shape)
538
+
539
+ h = h.type(x.dtype)
540
+ h = self.out(h)
541
+ # print("self.out(h)", "h.shape =", h.shape)
542
+
543
+ return h
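For reference, a minimal smoke test of the ContextUnet added above could look like the sketch below. It is not part of the commit: the batch size, image_size=28 and model_channels=32 are illustrative assumptions chosen so the model builds quickly (model_channels must stay divisible by 32 because GroupNorm32 uses 32 groups), and y stands in for the two rescaled astrophysical parameters (ION_Tvir_MIN, HII_EFF_FACTOR) used in diffusion.ipynb.

# hypothetical smoke test for context_unet.py (not part of this commit)
import torch
from context_unet import ContextUnet

model = ContextUnet(
    n_param=2,          # two conditioning parameters (ION_Tvir_MIN, HII_EFF_FACTOR)
    image_size=28,      # selects channel_mult = (1, 2) internally
    in_channels=1,
    model_channels=32,  # kept divisible by 32 for GroupNorm32
    out_channels=1,
    dim=2,
    stride=(2, 2),
)

x = torch.randn(4, 1, 28, 28)     # batch of noisy 2D lightcone slices
t = torch.randint(0, 1000, (4,))  # one diffusion timestep per sample
y = torch.rand(4, 2)              # parameters already rescaled to [0, 1]

eps = model(x, t, y)              # predicted noise, same shape as x
assert eps.shape == x.shape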
diffusion.ipynb CHANGED
@@ -33,7 +33,7 @@
33
  {
34
  "data": {
35
  "application/vnd.jupyter.widget-view+json": {
36
- "model_id": "4f2bbf6f5e904828bc65afc7ad97df36",
37
  "version_major": 2,
38
  "version_minor": 0
39
  },
@@ -81,7 +81,10 @@
81
  "from pathlib import Path\n",
82
  "from diffusers.optimization import get_cosine_schedule_with_warmup\n",
83
  "from accelerate import notebook_launcher, Accelerator\n",
84
- "from huggingface_hub import create_repo, upload_folder"
 
 
 
85
  ]
86
  },
87
  {
@@ -99,95 +102,95 @@
99
  "metadata": {},
100
  "outputs": [],
101
  "source": [
102
- "class Dataset4h5(Dataset):\n",
103
- " def __init__(self, dir_name, num_image=10, field='brightness_temp', shuffle=True, idx=None, num_redshift=32, HII_DIM=32, rescale=True, drop_prob = 0, dim=2):\n",
104
- " super().__init__()\n",
105
  " \n",
106
- " self.dir_name = dir_name\n",
107
- " self.num_image = num_image\n",
108
- " self.field = field\n",
109
- " self.shuffle = shuffle\n",
110
- " self.idx = idx\n",
111
- " self.num_redshift = num_redshift\n",
112
- " self.HII_DIM = HII_DIM\n",
113
- " self.drop_prob = drop_prob\n",
114
- " self.dim = dim\n",
115
- "\n",
116
- " self.load_h5()\n",
117
- " if rescale:\n",
118
- " self.images = self.rescale(self.images, to=[-1,1])\n",
119
- " self.params = self.rescale(self.params, to=[0,1])\n",
120
- "\n",
121
- " self.len = len(self.params)\n",
122
- " self.images = torch.from_numpy(self.images)\n",
123
- " print(f\"images rescaled to [{self.images.min()}, {self.images.max()}]\")\n",
124
- "\n",
125
- " cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()\n",
126
- " self.params = torch.from_numpy(self.params*cond_filter)\n",
127
- " print(f\"params rescaled to [{self.params.min()}, {self.params.max()}]\")\n",
128
- "\n",
129
- " def load_h5(self):\n",
130
- " with h5py.File(self.dir_name, 'r') as f:\n",
131
- " print(f\"dataset content: {f.keys()}\")\n",
132
- " max_num_image = len(f['brightness_temp'])#.shape[0]\n",
133
- " print(f\"{max_num_image} images can be loaded\")\n",
134
- " field_shape = f['brightness_temp'].shape[1:]\n",
135
- " print(f\"field.shape = {field_shape}\")\n",
136
- " self.params_keys = list(f['params']['keys'])\n",
137
- " print(f\"params keys = {self.params_keys}\")\n",
138
- "\n",
139
- " if self.idx is None:\n",
140
- " if self.shuffle:\n",
141
- " self.idx = np.sort(random.sample(range(max_num_image), self.num_image))\n",
142
- " print(f\"loading {self.num_image} images randomly\")\n",
143
- " # print(self.idx)\n",
144
- " else:\n",
145
- " self.idx = range(self.num_image)\n",
146
- " print(f\"loading {len(self.idx)} images with idx = {self.idx}\")\n",
147
- " else:\n",
148
- " print(f\"loading {len(self.idx)} images with idx = {self.idx}\")\n",
149
- "\n",
150
- " if self.dim == 2:\n",
151
- " self.images = f[self.field][self.idx,0,:self.HII_DIM,-self.num_redshift:][:,None]\n",
152
- " elif self.dim == 3:\n",
153
- " self.images = f[self.field][self.idx,:self.HII_DIM,:self.HII_DIM,-self.num_redshift:][:,None]\n",
154
- " print(f\"images loaded:\", self.images.shape)\n",
155
- "\n",
156
- " self.params = f['params']['values'][self.idx]\n",
157
- " print(\"params loaded:\", self.params.shape)\n",
158
  " \n",
159
- " # plt.imshow(self.images[0,0,0])\n",
160
- " # plt.show()\n",
161
- "\n",
162
- " def rescale(self, value, to: list):\n",
163
- " # print(np.ndim(value))\n",
164
- " if np.ndim(value)==2:\n",
165
- " # print(f\"rescale params of shape {value.shape}\")\n",
166
- " ranges = \\\n",
167
- " {\n",
168
- " 0: [4, 6], # ION_Tvir_MIN\n",
169
- " 1: [10, 250], # HII_EFF_FACTOR\n",
170
- " }\n",
171
- " # elif np.ndim(value)==5: \n",
172
- " else: \n",
173
- " # value = np.array(value)\n",
174
- " # print(f\"rescale images of shape {np.shape(value)}\")\n",
175
- " ranges = \\\n",
176
- " {\n",
177
- " 0: [0, 80], # brightness_temp\n",
178
- " }\n",
179
- " # print(f\"value.min = {value.min()}, value.max = {value.max()}\")\n",
180
- " for i in range(np.shape(value)[1]):\n",
181
- " value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])\n",
182
- " # print(f\"value.min = {value.min()}, value.max = {value.max()}\")\n",
183
- " value = value * (to[1]-to[0]) + to[0]\n",
184
- " return value \n",
185
- "\n",
186
- " def __getitem__(self, index):\n",
187
- " return self.images[index], self.params[index]\n",
188
- "\n",
189
- " def __len__(self):\n",
190
- " return self.len"
191
  ]
192
  },
193
  {
@@ -346,596 +349,526 @@
346
  "metadata": {},
347
  "outputs": [],
348
  "source": [
349
- "class GroupNorm32(nn.GroupNorm):\n",
350
- " def __init__(self, num_groups, num_channels, swish, eps=1e-5):\n",
351
- " super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)\n",
352
- " self.swish = swish\n",
353
- "\n",
354
- " def forward(self, x):\n",
355
- " y = super().forward(x.float()).to(x.dtype)\n",
356
- " if self.swish == 1.0:\n",
357
- " y = F.silu(y)\n",
358
- " elif self.swish:\n",
359
- " y = y * F.sigmoid(y * float(self.swish))\n",
360
- " return y\n",
361
- "\n",
362
- "def normalization(channels, swish=0.0):\n",
363
- " \"\"\"\n",
364
- " Make a standard normalization layer, with an optional swish activation.\n",
365
- "\n",
366
- " :param channels: number of input channels.\n",
367
- " :return: an nn.Module for normalization.\n",
368
- " \"\"\"\n",
369
- " #print (channels)\n",
370
- " return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)\n",
371
- "\n",
372
- "Conv = {\n",
373
- " 1: nn.Conv1d,\n",
374
- " 2: nn.Conv2d,\n",
375
- " 3: nn.Conv3d,\n",
376
- "}\n",
377
- "\n",
378
- "AvgPool = {\n",
379
- " 1: nn.AvgPool1d,\n",
380
- " 2: nn.AvgPool2d,\n",
381
- " 3: nn.AvgPool3d\n",
382
- "}"
383
- ]
384
- },
385
- {
386
- "cell_type": "code",
387
- "execution_count": 9,
388
- "metadata": {},
389
- "outputs": [],
390
- "source": [
391
- "class Downsample(nn.Module):\n",
392
- " def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):\n",
393
- " super().__init__()\n",
394
- " self.channels = channels\n",
395
- " self.out_channels = out_channels or channels\n",
396
- " # stride = config.stride\n",
397
- " if use_conv:\n",
398
- " # print(\"conv\")\n",
399
- " self.op = Conv[dim](channels, self.out_channels, 3, stride=stride, padding=1)\n",
400
- " else:\n",
401
- " # print(\"pool\")\n",
402
- " assert channels == self.out_channels\n",
403
- " self.op = AvgPool[dim](kernel_size=stride, stride=stride)\n",
404
- "\n",
405
- " def forward(self, x):\n",
406
- " assert x.shape[1] == self.channels\n",
407
- " return self.op(x)"
408
- ]
409
- },
410
- {
411
- "cell_type": "code",
412
- "execution_count": 10,
413
- "metadata": {},
414
- "outputs": [],
415
- "source": [
416
- "class Upsample(nn.Module):\n",
417
- " def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):\n",
418
- " super().__init__()\n",
419
- " self.channels = channels\n",
420
- " self.out_channels = out_channels\n",
421
- " self.use_conv = use_conv\n",
422
- " self.stride = stride\n",
423
- " if self.use_conv:\n",
424
- " self.conv = Conv[dim](self.channels, self.out_channels, 3, padding=1)\n",
425
- "\n",
426
- " def forward(self, x):\n",
427
- " assert x.shape[1] == self.channels\n",
428
- " # stride = config.stride\n",
429
- " # print(torch.tensor(x.shape[2:]))\n",
430
- " # print(torch.tensor(stride))\n",
431
- " shape = torch.tensor(x.shape[2:]) * torch.tensor(self.stride)\n",
432
- " shape = tuple(shape.detach().numpy())\n",
433
- " # print(shape)\n",
434
- " x = F.interpolate(x, shape, mode='nearest')\n",
435
- " if self.use_conv:\n",
436
- " x = self.conv(x)\n",
437
- " return x"
438
- ]
439
- },
440
- {
441
- "cell_type": "code",
442
- "execution_count": 11,
443
- "metadata": {},
444
- "outputs": [],
445
- "source": [
446
- "def zero_module(module):\n",
447
- " \"\"\"\n",
448
- " clean gradient of parameters of the module\n",
449
- " \"\"\"\n",
450
- " for p in module.parameters():\n",
451
- " p.detach().zero_()\n",
452
- " return module"
453
- ]
454
- },
455
- {
456
- "cell_type": "code",
457
- "execution_count": 12,
458
- "metadata": {},
459
- "outputs": [],
460
- "source": [
461
- "class TimestepBlock(ABC, nn.Module):\n",
462
- " @abstractmethod\n",
463
- " def forward(self, x, emb):\n",
464
- " \"\"\"\n",
465
- " test\n",
466
- " \"\"\""
467
- ]
468
- },
469
- {
470
- "cell_type": "code",
471
- "execution_count": 13,
472
- "metadata": {},
473
- "outputs": [],
474
- "source": [
475
- "class TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n",
476
- " def forward(self, x, emb, encoder_out=None):\n",
477
- " for layer in self:\n",
478
- " if isinstance(layer, TimestepBlock):\n",
479
- " x = layer(x, emb)\n",
480
- " elif isinstance(layer, AttentionBlock):\n",
481
- " x = layer(x, encoder_out)\n",
482
- " else:\n",
483
- " x = layer(x)\n",
484
- " return x"
485
- ]
486
- },
487
- {
488
- "cell_type": "code",
489
- "execution_count": 14,
490
- "metadata": {},
491
- "outputs": [],
492
- "source": [
493
- "class ResBlock(TimestepBlock):\n",
494
- " def __init__(\n",
495
- " self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_checkpoint=False, use_scale_shift_norm=False, up=False, down=False, dim=2, stride=(2,2),\n",
496
- " ):\n",
497
- " super().__init__()\n",
498
- " self.out_channels = out_channels or channels\n",
499
- " self.use_scale_shift_norm = use_scale_shift_norm\n",
500
- " self.stride = stride\n",
501
- "\n",
502
- " self.in_layers = nn.Sequential(\n",
503
- " # nn.BatchNorm2d(channels), # normalize to standard gaussian\n",
504
- " normalization(channels, swish=1.0),\n",
505
- " nn.Identity(),\n",
506
- " Conv[dim](channels, self.out_channels, 3, padding=1),\n",
507
- " )\n",
508
- "\n",
509
- " self.updown = up or down\n",
510
- " if up:\n",
511
- " self.h_updown = Upsample(channels, False, dim=dim, stride=stride)\n",
512
- " self.x_updown = Upsample(channels, False, dim=dim, stride=stride)\n",
513
- " elif down:\n",
514
- " self.h_updown = Downsample(channels, False, dim=dim, stride=stride)\n",
515
- " self.x_updown = Downsample(channels, False, dim=dim, stride=stride)\n",
516
- " else:\n",
517
- " self.h_updown = self.x_updown = nn.Identity()\n",
518
- "\n",
519
- " self.emb_layers = nn.Sequential(\n",
520
- " nn.SiLU(),\n",
521
- " nn.Linear(\n",
522
- " emb_channels,\n",
523
- " 2 * self.out_channels if use_scale_shift_norm else self.out_channels,\n",
524
- " ),\n",
525
- " )\n",
526
- "\n",
527
- " self.out_layers = nn.Sequential(\n",
528
- " # nn.BatchNorm2d(self.out_channels),\n",
529
- " normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),\n",
530
- " nn.SiLU() if use_scale_shift_norm else nn.Identity(),\n",
531
- " nn.Dropout(p=dropout),\n",
532
- " zero_module(Conv[dim](self.out_channels, self.out_channels, 3, padding=1)),\n",
533
- " )\n",
534
  "\n",
535
- " if self.out_channels == channels:\n",
536
- " self.skip_connection = nn.Identity()\n",
537
- " elif use_conv:\n",
538
- " self.skip_connection = Conv[dim](channels, self.out_channels, 3, padding=1)\n",
539
- " else:\n",
540
- " self.skip_connection = Conv[dim](channels, self.out_channels, 1)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
  " \n",
542
  "\n",
543
- " def forward(self, x, emb):\n",
544
- " if self.updown:\n",
545
- " in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n",
546
- " h = in_rest(x)\n",
547
- " h = self.h_updown(h)\n",
548
- " x = self.x_updown(x)\n",
549
- " h = in_conv(h)\n",
550
- " else:\n",
551
- " h = self.in_layers(x)\n",
552
- " emb_out = self.emb_layers(emb).type(h.dtype)\n",
553
  "\n",
554
- " while len(emb_out.shape) < len(h.shape):\n",
555
- " emb_out = emb_out[..., None]\n",
556
  "\n",
557
- " if self.use_scale_shift_norm:\n",
558
- " out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n",
559
- " scale, shift = torch.chunk(emb_out, 2, dim=1)\n",
560
- " h = out_norm(h) * (1+scale) + shift\n",
561
- " h = out_rest(h)\n",
562
- " else:\n",
563
- " h += emb_out\n",
564
- " h = self.out_layers(h)\n",
565
- " # print(\"ResBlock, torch.unique(h).shape =\", torch.unique(h).shape)\n",
566
- " return self.skip_connection(x) + h"
567
- ]
568
- },
569
- {
570
- "cell_type": "code",
571
- "execution_count": 15,
572
- "metadata": {},
573
- "outputs": [],
574
- "source": [
575
- "class QKVAttention(nn.Module):\n",
576
- " def __init__(self, n_heads):\n",
577
- " super().__init__()\n",
578
- " self.n_heads = n_heads\n",
579
- " # print(\"QKVAttention, self.n_heads =\", self.n_heads)\n",
580
  " \n",
581
- " def forward(self, qkv, encoder_kv=None):\n",
582
- " bs, width, length = qkv.shape\n",
583
- " assert width % (3*self.n_heads) == 0\n",
584
- " ch = width // (3*self.n_heads)\n",
585
- "\n",
586
- " # print(\"QKVAttention\", bs, self.n_heads, ch, length)\n",
587
- " q, k, v = qkv.reshape(bs*self.n_heads, ch*3, length).split(ch, dim=1)\n",
588
- " if encoder_kv is not None:\n",
589
- " assert encoder_kv.shape[1] == self.n_heads * ch * 2\n",
590
- " ek, ev = encoder_kv.reshape(bs*self.n_heads, ch*2, -1).split(ch, dim=1)\n",
591
- " k = torch.cat([ek,k], dim=-1)\n",
592
- " v = torch.cat([ev,v], dim=-1)\n",
593
- "\n",
594
- " scale = 1 / math.sqrt(math.sqrt(ch))\n",
595
- " weight = torch.einsum(\"bct,bcs->bts\", q*scale, k*scale)\n",
596
- " weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n",
597
- "\n",
598
- " a = torch.einsum(\"bts,bcs->bct\", weight, v)\n",
599
- " return a.reshape(bs, -1, length)"
600
- ]
601
- },
602
- {
603
- "cell_type": "code",
604
- "execution_count": 16,
605
- "metadata": {},
606
- "outputs": [],
607
- "source": [
608
- "class AttentionBlock(nn.Module):\n",
609
- " def __init__(\n",
610
- " self,\n",
611
- " channels,\n",
612
- " num_heads=1,\n",
613
- " num_head_channels=-1,\n",
614
- " use_checkpoint=False,\n",
615
- " encoder_channels=None,\n",
616
- " ):\n",
617
- " super().__init__()\n",
618
- " self.channels = channels\n",
619
- " if num_head_channels == -1:\n",
620
- " self.num_heads = num_heads\n",
621
- " else:\n",
622
- " assert channels % num_head_channels == 0,\\\n",
623
- " f\"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}\"\n",
624
- " self.num_heads = channels // num_head_channels\n",
625
- "\n",
626
- " self.use_checkpoint = use_checkpoint\n",
627
- " # self.norm = nn.BatchNorm2d(channels)\n",
628
- " self.norm = normalization(channels, swish=0.0)\n",
629
- " self.qkv = nn.Conv1d(channels, channels * 3, 1)\n",
630
  " \n",
631
- " self.attention = QKVAttention(self.num_heads)\n",
632
- "\n",
633
- " if encoder_channels is not None:\n",
634
- " self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)\n",
635
- " self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))\n",
636
- "\n",
637
- " def forward(self, x, encoder_out=None):\n",
638
- " b, c, *spatial = x.shape\n",
639
- " qkv = self.qkv(self.norm(x).view(b, c, -1))\n",
640
- " if encoder_out is not None:\n",
641
- " encoder_out = self.encoder_kv(encoder_out)\n",
642
- " h = self.attention(qkv, encoder_out)\n",
643
- " else:\n",
644
- " h = self.attention(qkv)\n",
645
- " # print(\"AttentionBlock, before proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\n",
646
- " h = self.proj_out(h)\n",
647
- " # print(\"AttentionBlock, after proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\n",
648
- " return x + h.reshape(b, c, *spatial)"
649
- ]
650
- },
651
- {
652
- "cell_type": "code",
653
- "execution_count": 17,
654
- "metadata": {},
655
- "outputs": [],
656
- "source": [
657
- "def timestep_embedding(timesteps, dim, max_period=10000):\n",
658
- " \"\"\"\n",
659
- " Create sinusoidal timestep embeddings.\n",
660
- "\n",
661
- " :param timesteps: a 1-D Tensor of N indices, one per batch element.\n",
662
- " These may be fractional.\n",
663
- " :param dim: the dimension of the output.\n",
664
- " :param max_period: controls the minimum frequency of the embeddings.\n",
665
- " :return: an [N x dim] Tensor of positional embeddings.\n",
666
- " \"\"\"\n",
667
- " #print (timesteps.shape)\n",
668
- " half = dim // 2\n",
669
- " freqs = torch.exp(\n",
670
- " -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n",
671
- " ).to(device=timesteps.device)\n",
672
- " #print (timesteps[:, None].float().shape,freqs[None].shape)\n",
673
- " args = timesteps[:, None].float() * freqs[None]\n",
674
- " embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n",
675
- " if dim % 2:\n",
676
- " embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n",
677
- " return embedding"
678
- ]
679
- },
680
- {
681
- "cell_type": "code",
682
- "execution_count": 18,
683
- "metadata": {},
684
- "outputs": [],
685
- "source": [
686
- "class ContextUnet(nn.Module):\n",
687
- " def __init__(\n",
688
- " self,\n",
689
- " n_param=2,\n",
690
- " image_size=64,\n",
691
- " in_channels=1,\n",
692
- " model_channels=128,\n",
693
- " out_channels = 1,\n",
694
- " channel_mult = None,\n",
695
- " num_res_blocks = 2,\n",
696
- " dropout = 0,\n",
697
- " use_checkpoint = False,\n",
698
- " use_scale_shift_norm = False,\n",
699
- " attention_resolutions = (16, 8),\n",
700
- " num_heads = 4,\n",
701
- " num_head_channels = -1,\n",
702
- " num_heads_upsample = -1,\n",
703
- " resblock_updown = False,\n",
704
- " conv_resample = True,\n",
705
- " encoder_channels = None,\n",
706
- " dim = 2,\n",
707
- " stride = (2,2)\n",
708
- " ):\n",
709
- " super().__init__()\n",
710
- "\n",
711
- " if channel_mult == None:\n",
712
- " if image_size == 512:\n",
713
- " channel_mult = (0.5, 1, 1, 2, 2, 4, 4)\n",
714
- " elif image_size == 256:\n",
715
- " channel_mult = (1, 1, 2, 2, 4, 4)\n",
716
- " elif image_size == 128:\n",
717
- " channel_mult = (1, 1, 2, 3, 4)\n",
718
- " elif image_size == 64:\n",
719
- " channel_mult = (1, 1, 2, 2, 4, 4)#(1, 2, 3, 4)\n",
720
- " elif image_size == 28:\n",
721
- " channel_mult = (1, 2)#(1, 2, 3, 4)\n",
722
- " else:\n",
723
- " raise ValueError(f\"unsupported image size: {image_size}\")\n",
724
- " # else:\n",
725
- " # channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(\",\"))\n",
726
  " \n",
727
- " attention_ds = []\n",
728
- " for res in attention_resolutions:\n",
729
- " attention_ds.append(image_size // int(res))\n",
730
- "\n",
731
- " # print(\"before, ContextUnet, num_heads_upsample =\", num_heads_upsample, \"num_heads =\", num_heads)\n",
732
- " if num_heads_upsample == -1:\n",
733
- " num_heads_upsample = num_heads\n",
734
- " # print(\"after, ContextUnet, num_heads_upsample =\", num_heads_upsample, \"num_heads =\", num_heads)\n",
735
- "\n",
736
- " # self.n_param = n_param\n",
737
- " self.model_channels = model_channels\n",
738
- " self.dtype = torch.float32\n",
739
- "\n",
740
- " self.token_embedding = nn.Linear(n_param, model_channels * 4)\n",
741
- "\n",
742
- " time_embed_dim = model_channels * 4\n",
743
- " self.time_embed = nn.Sequential(\n",
744
- " nn.Linear(model_channels, time_embed_dim),\n",
745
- " nn.SiLU(),\n",
746
- " nn.Linear(time_embed_dim, time_embed_dim),\n",
747
- " )\n",
748
- "\n",
749
- " ch = input_ch = int(channel_mult[0] * model_channels)\n",
750
- "\n",
751
- " ###################### input_blocks ######################\n",
752
- " self.input_blocks = nn.ModuleList(\n",
753
- " [TimestepEmbedSequential(Conv[dim](in_channels, ch, 3, padding=1))]\n",
754
- " )\n",
755
- " self._feature_size = ch\n",
756
- " input_block_chans = [ch]\n",
757
- " ds = 1\n",
758
- "\n",
759
- " for level, mult in enumerate(channel_mult):\n",
760
- " for _ in range(num_res_blocks):\n",
761
- " layers = [\n",
762
- " ResBlock(\n",
763
- " ch,\n",
764
- " time_embed_dim,\n",
765
- " dropout,\n",
766
- " out_channels = int(mult * model_channels),\n",
767
- " use_checkpoint = use_checkpoint,\n",
768
- " use_scale_shift_norm = use_scale_shift_norm,\n",
769
- " dim = dim,\n",
770
- " stride = stride,\n",
771
- " )\n",
772
- " ]\n",
773
- " ch = int(mult * model_channels)\n",
774
- " if ds in attention_ds:\n",
775
- " layers.append(\n",
776
- " AttentionBlock(\n",
777
- " ch,\n",
778
- " use_checkpoint=use_checkpoint,\n",
779
- " num_heads = num_heads,\n",
780
- " num_head_channels = num_head_channels,\n",
781
- " encoder_channels = encoder_channels,\n",
782
- " )\n",
783
- " )\n",
784
- " self.input_blocks.append(TimestepEmbedSequential(*layers))\n",
785
- " self._feature_size += ch\n",
786
- " input_block_chans.append(ch)\n",
787
- "\n",
788
- " if level != len(channel_mult) - 1:\n",
789
- " out_ch = ch\n",
790
- " self.input_blocks.append(\n",
791
- " TimestepEmbedSequential(\n",
792
- " ResBlock(\n",
793
- " ch,\n",
794
- " time_embed_dim,\n",
795
- " dropout,\n",
796
- " out_channels=out_ch,\n",
797
- " # dims=dims,\n",
798
- " use_checkpoint=use_checkpoint,\n",
799
- " use_scale_shift_norm=use_scale_shift_norm,\n",
800
- " down=True,\n",
801
- " dim = dim,\n",
802
- " stride = stride,\n",
803
- " )\n",
804
- " if resblock_updown\n",
805
- " else Downsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)\n",
806
- " )\n",
807
- " )\n",
808
- " ch = out_ch\n",
809
- " input_block_chans.append(ch)\n",
810
- " ds *= 2\n",
811
- " self._feature_size += ch\n",
812
- "\n",
813
- "\n",
814
- " ###################### middle_blocks ######################\n",
815
- " self.middle_block = TimestepEmbedSequential(\n",
816
- " ResBlock(\n",
817
- " ch,\n",
818
- " time_embed_dim,\n",
819
- " dropout,\n",
820
- " use_checkpoint=use_checkpoint,\n",
821
- " use_scale_shift_norm=use_scale_shift_norm,\n",
822
- " dim = dim,\n",
823
- " stride = stride,\n",
824
- " ),\n",
825
- " AttentionBlock(\n",
826
- " ch,\n",
827
- " use_checkpoint=use_checkpoint,\n",
828
- " num_heads=num_heads,\n",
829
- " num_head_channels=num_head_channels,\n",
830
- " encoder_channels=encoder_channels,\n",
831
- " ),\n",
832
- " ResBlock(\n",
833
- " ch,\n",
834
- " time_embed_dim,\n",
835
- " dropout,\n",
836
- " use_checkpoint=use_checkpoint,\n",
837
- " use_scale_shift_norm=use_scale_shift_norm,\n",
838
- " dim = dim,\n",
839
- " stride = stride,\n",
840
- " ),\n",
841
- " )\n",
842
- " self._feature_size += ch\n",
843
- "\n",
844
- "\n",
845
- " ###################### output_blocks ######################\n",
846
- " self.output_blocks = nn.ModuleList([])\n",
847
- " for level, mult in list(enumerate(channel_mult))[::-1]:\n",
848
- " for i in range(num_res_blocks + 1):\n",
849
- " ich = input_block_chans.pop()\n",
850
- " layers = [\n",
851
- " ResBlock(\n",
852
- " ch + ich,\n",
853
- " time_embed_dim,\n",
854
- " dropout,\n",
855
- " out_channels=int(model_channels * mult),\n",
856
- " # dims=dims,\n",
857
- " use_checkpoint=use_checkpoint,\n",
858
- " use_scale_shift_norm=use_scale_shift_norm,\n",
859
- " dim = dim,\n",
860
- " stride = stride,\n",
861
- " )\n",
862
- " ]\n",
863
- " ch = int(model_channels * mult)\n",
864
- " if ds in attention_ds:\n",
865
- " # print(\"ds in attention_resolutions, num_heads=\", num_heads_upsample)\n",
866
- " layers.append(\n",
867
- " AttentionBlock(\n",
868
- " ch,\n",
869
- " use_checkpoint=use_checkpoint,\n",
870
- " num_heads=num_heads_upsample,\n",
871
- " num_head_channels=num_head_channels,\n",
872
- " encoder_channels=encoder_channels,\n",
873
- " )\n",
874
- " )\n",
875
- " if level and i == num_res_blocks:\n",
876
- " out_ch = ch\n",
877
- " layers.append(\n",
878
- " ResBlock(\n",
879
- " ch,\n",
880
- " time_embed_dim,\n",
881
- " dropout,\n",
882
- " out_channels=out_ch,\n",
883
- " # dims=dims,\n",
884
- " use_checkpoint=use_checkpoint,\n",
885
- " use_scale_shift_norm=use_scale_shift_norm,\n",
886
- " up=True,\n",
887
- " dim = dim,\n",
888
- " stride = stride,\n",
889
- " )\n",
890
- " if resblock_updown\n",
891
- " else Upsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)\n",
892
- " )\n",
893
- " ds //= 2\n",
894
- " self.output_blocks.append(TimestepEmbedSequential(*layers))\n",
895
- " self._feature_size += ch\n",
896
- "\n",
897
- " self.out = nn.Sequential(\n",
898
- " # nn.BatchNorm2d(ch),\n",
899
- " normalization(ch, swish=1.0),\n",
900
- " nn.Identity(),\n",
901
- " zero_module(Conv[dim](input_ch, out_channels, 3, padding=1)),\n",
902
- " )\n",
903
- " # self.use_fp16 = use_fp16\n",
904
- "\n",
905
- " def forward(self, x, timesteps, y=None):\n",
906
- " hs = []\n",
907
- " emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))\n",
908
- " if y != None:\n",
909
- " text_outputs = self.token_embedding(y.float())\n",
910
- " emb = emb + text_outputs.to(emb)\n",
911
- "\n",
912
- " h = x.type(self.dtype)\n",
913
- " # print(\"0,h.shape =\", h.shape)\n",
914
- " for module in self.input_blocks:\n",
915
- " h = module(h, emb)\n",
916
- " hs.append(h)\n",
917
- " # print(\"module encoder, h.shape =\", h.shape)\n",
918
- " # print(\"2,h.shape =\", h.shape)\n",
919
- " h = self.middle_block(h, emb)\n",
920
- " # print(\"middle block, h.shape =\", h.shape)\n",
921
- " # print(\"2,h.shape =\", h.shape)\n",
922
- " for module in self.output_blocks:\n",
923
- " # print(\"for module in self.output_blocks, h.shape =\", h.shape)\n",
924
- " # print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\n",
925
- " h = torch.cat([h, hs.pop()], dim=1)\n",
926
- " h = module(h, emb)\n",
927
- " # print(\"module decoder, h.shape =\", h.shape)\n",
928
- "\n",
929
- " h = h.type(x.dtype)\n",
930
- " h = self.out(h)\n",
931
- " # print(\"self.out(h)\", \"h.shape =\", h.shape)\n",
932
- "\n",
933
- " return h "
934
  ]
935
  },
936
  {
937
  "cell_type": "code",
938
- "execution_count": 19,
939
  "metadata": {},
940
  "outputs": [],
941
  "source": [
@@ -960,12 +893,13 @@
960
  " self.step += 1\n",
961
  "\n",
962
  " def reset_parameters(self, ema_model, model):\n",
963
- " ema_model.load_state_dict(model.state_dict())"
 
964
  ]
965
  },
966
  {
967
  "cell_type": "code",
968
- "execution_count": 20,
969
  "metadata": {},
970
  "outputs": [],
971
  "source": [
@@ -1031,7 +965,7 @@
1031
  },
1032
  {
1033
  "cell_type": "code",
1034
- "execution_count": 21,
1035
  "metadata": {},
1036
  "outputs": [],
1037
  "source": [
@@ -1041,7 +975,7 @@
1041
  },
1042
  {
1043
  "cell_type": "code",
1044
- "execution_count": 22,
1045
  "metadata": {},
1046
  "outputs": [],
1047
  "source": [
@@ -1050,7 +984,7 @@
1050
  },
1051
  {
1052
  "cell_type": "code",
1053
- "execution_count": 23,
1054
  "metadata": {},
1055
  "outputs": [],
1056
  "source": [
@@ -1074,7 +1008,7 @@
1074
  },
1075
  {
1076
  "cell_type": "code",
1077
- "execution_count": 24,
1078
  "metadata": {},
1079
  "outputs": [],
1080
  "source": [
@@ -1272,7 +1206,7 @@
1272
  },
1273
  {
1274
  "cell_type": "code",
1275
- "execution_count": 25,
1276
  "metadata": {},
1277
  "outputs": [
1278
  {
@@ -1482,7 +1416,7 @@
1482
  },
1483
  {
1484
  "cell_type": "code",
1485
- "execution_count": 26,
1486
  "metadata": {},
1487
  "outputs": [
1488
  {
@@ -1509,14 +1443,14 @@
1509
  "output_type": "stream",
1510
  "text": [
1511
  "params loaded: (200, 2)\n",
1512
- "images rescaled to [-1.0, 1.064338207244873]\n",
1513
- "params rescaled to [0.0, 0.9988593502151616]\n"
1514
  ]
1515
  },
1516
  {
1517
  "data": {
1518
  "application/vnd.jupyter.widget-view+json": {
1519
- "model_id": "2e0b629831714bc2b32e25d44a72f4b3",
1520
  "version_major": 2,
1521
  "version_minor": 0
1522
  },
@@ -1530,7 +1464,7 @@
1530
  {
1531
  "data": {
1532
  "application/vnd.jupyter.widget-view+json": {
1533
- "model_id": "c634a180ede04f3cb09ab74daf0401c6",
1534
  "version_major": 2,
1535
  "version_minor": 0
1536
  },
@@ -1544,7 +1478,7 @@
1544
  {
1545
  "data": {
1546
  "application/vnd.jupyter.widget-view+json": {
1547
- "model_id": "6f3a0791c42b4d7e958f2a9d57f64de8",
1548
  "version_major": 2,
1549
  "version_minor": 0
1550
  },
@@ -1558,7 +1492,7 @@
1558
  {
1559
  "data": {
1560
  "application/vnd.jupyter.widget-view+json": {
1561
- "model_id": "9dce2de3e8a14aee83e2b182dc06608f",
1562
  "version_major": 2,
1563
  "version_minor": 0
1564
  },
@@ -1572,7 +1506,7 @@
1572
  {
1573
  "data": {
1574
  "application/vnd.jupyter.widget-view+json": {
1575
- "model_id": "d4596bdc71cc4d4cb780442b97849883",
1576
  "version_major": 2,
1577
  "version_minor": 0
1578
  },
@@ -1586,7 +1520,7 @@
1586
  {
1587
  "data": {
1588
  "application/vnd.jupyter.widget-view+json": {
1589
- "model_id": "6e68847216504241b81ebcb71c48f687",
1590
  "version_major": 2,
1591
  "version_minor": 0
1592
  },
@@ -1600,7 +1534,7 @@
1600
  {
1601
  "data": {
1602
  "application/vnd.jupyter.widget-view+json": {
1603
- "model_id": "830c25eb902a47e7997dcdb40099c5a4",
1604
  "version_major": 2,
1605
  "version_minor": 0
1606
  },
@@ -1614,7 +1548,7 @@
1614
  {
1615
  "data": {
1616
  "application/vnd.jupyter.widget-view+json": {
1617
- "model_id": "87fdac7b595c4d0ea7258ee8bb35de17",
1618
  "version_major": 2,
1619
  "version_minor": 0
1620
  },
@@ -1628,7 +1562,7 @@
1628
  {
1629
  "data": {
1630
  "application/vnd.jupyter.widget-view+json": {
1631
- "model_id": "b9f6be95f4bd403d85f6df34756e7b8d",
1632
  "version_major": 2,
1633
  "version_minor": 0
1634
  },
@@ -1642,7 +1576,7 @@
1642
  {
1643
  "data": {
1644
  "application/vnd.jupyter.widget-view+json": {
1645
- "model_id": "28ec5d881b37440ba5f4c863fc552c17",
1646
  "version_major": 2,
1647
  "version_minor": 0
1648
  },
@@ -1660,7 +1594,7 @@
1660
  },
1661
  {
1662
  "cell_type": "code",
1663
- "execution_count": null,
1664
  "metadata": {},
1665
  "outputs": [
1666
  {
@@ -1678,7 +1612,7 @@
1678
  {
1679
  "data": {
1680
  "application/vnd.jupyter.widget-view+json": {
1681
- "model_id": "58944c3b1e4f42bb8771f776c35a90a7",
1682
  "version_major": 2,
1683
  "version_minor": 0
1684
  },
@@ -1688,28 +1622,6 @@
1688
  },
1689
  "metadata": {},
1690
  "output_type": "display_data"
1691
- },
1692
- {
1693
- "ename": "RuntimeError",
1694
- "evalue": "CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.65 GiB already allocated; 432.50 MiB free; 22.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
1695
- "output_type": "error",
1696
- "traceback": [
1697
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1698
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1699
- "Cell \u001b[0;32mIn[26], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1700
- "Cell \u001b[0;32mIn[25], line 177\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 171\u001b[0m nn_model\u001b[39m.\u001b[39meval()\n\u001b[1;32m 173\u001b[0m \u001b[39m# self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\u001b[39;00m\n\u001b[1;32m 174\u001b[0m \u001b[39m# self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\u001b[39;00m\n\u001b[1;32m 175\u001b[0m \u001b[39m# print(f\"resumed ema_model from {config.resume}\")\u001b[39;00m\n\u001b[0;32m--> 177\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 178\u001b[0m nn_model\u001b[39m=\u001b[39;49mnn_model, \n\u001b[1;32m 179\u001b[0m params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), \n\u001b[1;32m 180\u001b[0m device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, \n\u001b[1;32m 181\u001b[0m guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w\n\u001b[1;32m 182\u001b[0m )\n\u001b[1;32m 184\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 186\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1701
- "Cell \u001b[0;32mIn[7], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 71\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39m# print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
1702
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1703
- "Cell \u001b[0;32mIn[18], line 241\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_blocks:\n\u001b[1;32m 238\u001b[0m \u001b[39m# print(\"for module in self.output_blocks, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 239\u001b[0m \u001b[39m# print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\u001b[39;00m\n\u001b[1;32m 240\u001b[0m h \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([h, hs\u001b[39m.\u001b[39mpop()], dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m--> 241\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 242\u001b[0m \u001b[39m# print(\"module decoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 244\u001b[0m h \u001b[39m=\u001b[39m h\u001b[39m.\u001b[39mtype(x\u001b[39m.\u001b[39mdtype)\n",
1704
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1705
- "Cell \u001b[0;32mIn[13], line 7\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[0;32m----> 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n\u001b[1;32m 8\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m x \u001b[39m=\u001b[39m layer(x)\n",
1706
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1707
- "Cell \u001b[0;32mIn[16], line 37\u001b[0m, in \u001b[0;36mAttentionBlock.forward\u001b[0;34m(self, x, encoder_out)\u001b[0m\n\u001b[1;32m 35\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mattention(qkv, encoder_out)\n\u001b[1;32m 36\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 37\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattention(qkv)\n\u001b[1;32m 38\u001b[0m \u001b[39m# print(\"AttentionBlock, before proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 39\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mproj_out(h)\n",
1708
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1709
- "Cell \u001b[0;32mIn[15], line 21\u001b[0m, in \u001b[0;36mQKVAttention.forward\u001b[0;34m(self, qkv, encoder_kv)\u001b[0m\n\u001b[1;32m 18\u001b[0m v \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([ev,v], dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 20\u001b[0m scale \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m \u001b[39m/\u001b[39m math\u001b[39m.\u001b[39msqrt(math\u001b[39m.\u001b[39msqrt(ch))\n\u001b[0;32m---> 21\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49meinsum(\u001b[39m\"\u001b[39;49m\u001b[39mbct,bcs->bts\u001b[39;49m\u001b[39m\"\u001b[39;49m, q\u001b[39m*\u001b[39;49mscale, k\u001b[39m*\u001b[39;49mscale)\n\u001b[1;32m 22\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39msoftmax(weight\u001b[39m.\u001b[39mfloat(), dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\u001b[39m.\u001b[39mtype(weight\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 24\u001b[0m a \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39meinsum(\u001b[39m\"\u001b[39m\u001b[39mbts,bcs->bct\u001b[39m\u001b[39m\"\u001b[39m, weight, v)\n",
1710
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/functional.py:360\u001b[0m, in \u001b[0;36meinsum\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[39m# recurse incase operands contains value that has torch function\u001b[39;00m\n\u001b[1;32m 357\u001b[0m \u001b[39m# in the original implementation this line is omitted\u001b[39;00m\n\u001b[1;32m 358\u001b[0m \u001b[39mreturn\u001b[39;00m einsum(equation, \u001b[39m*\u001b[39m_operands)\n\u001b[0;32m--> 360\u001b[0m \u001b[39mreturn\u001b[39;00m _VF\u001b[39m.\u001b[39;49meinsum(equation, operands)\n",
1711
- "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.65 GiB already allocated; 432.50 MiB free; 22.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
1712
- ]
1713
  }
1714
  ],
1715
  "source": [
 
33
  {
34
  "data": {
35
  "application/vnd.jupyter.widget-view+json": {
36
+ "model_id": "0e2d634b9f734693a5e1eace447bd2e1",
37
  "version_major": 2,
38
  "version_minor": 0
39
  },
 
81
  "from pathlib import Path\n",
82
  "from diffusers.optimization import get_cosine_schedule_with_warmup\n",
83
  "from accelerate import notebook_launcher, Accelerator\n",
84
+ "from huggingface_hub import create_repo, upload_folder\n",
85
+ "\n",
86
+ "from load_h5 import Dataset4h5\n",
87
+ "from context_unet import ContextUnet"
88
  ]
89
  },
90
  {
 
102
  "metadata": {},
103
  "outputs": [],
104
  "source": [
105
+ "# class Dataset4h5(Dataset):\n",
106
+ "# def __init__(self, dir_name, num_image=10, field='brightness_temp', shuffle=True, idx=None, num_redshift=32, HII_DIM=32, rescale=True, drop_prob = 0, dim=2):\n",
107
+ "# super().__init__()\n",
108
  " \n",
109
+ "# self.dir_name = dir_name\n",
110
+ "# self.num_image = num_image\n",
111
+ "# self.field = field\n",
112
+ "# self.shuffle = shuffle\n",
113
+ "# self.idx = idx\n",
114
+ "# self.num_redshift = num_redshift\n",
115
+ "# self.HII_DIM = HII_DIM\n",
116
+ "# self.drop_prob = drop_prob\n",
117
+ "# self.dim = dim\n",
118
+ "\n",
119
+ "# self.load_h5()\n",
120
+ "# if rescale:\n",
121
+ "# self.images = self.rescale(self.images, to=[-1,1])\n",
122
+ "# self.params = self.rescale(self.params, to=[0,1])\n",
123
+ "\n",
124
+ "# self.len = len(self.params)\n",
125
+ "# self.images = torch.from_numpy(self.images)\n",
126
+ "# print(f\"images rescaled to [{self.images.min()}, {self.images.max()}]\")\n",
127
+ "\n",
128
+ "# cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()\n",
129
+ "# self.params = torch.from_numpy(self.params*cond_filter)\n",
130
+ "# print(f\"params rescaled to [{self.params.min()}, {self.params.max()}]\")\n",
131
+ "\n",
132
+ "# def load_h5(self):\n",
133
+ "# with h5py.File(self.dir_name, 'r') as f:\n",
134
+ "# print(f\"dataset content: {f.keys()}\")\n",
135
+ "# max_num_image = len(f['brightness_temp'])#.shape[0]\n",
136
+ "# print(f\"{max_num_image} images can be loaded\")\n",
137
+ "# field_shape = f['brightness_temp'].shape[1:]\n",
138
+ "# print(f\"field.shape = {field_shape}\")\n",
139
+ "# self.params_keys = list(f['params']['keys'])\n",
140
+ "# print(f\"params keys = {self.params_keys}\")\n",
141
+ "\n",
142
+ "# if self.idx is None:\n",
143
+ "# if self.shuffle:\n",
144
+ "# self.idx = np.sort(random.sample(range(max_num_image), self.num_image))\n",
145
+ "# print(f\"loading {self.num_image} images randomly\")\n",
146
+ "# # print(self.idx)\n",
147
+ "# else:\n",
148
+ "# self.idx = range(self.num_image)\n",
149
+ "# print(f\"loading {len(self.idx)} images with idx = {self.idx}\")\n",
150
+ "# else:\n",
151
+ "# print(f\"loading {len(self.idx)} images with idx = {self.idx}\")\n",
152
+ "\n",
153
+ "# if self.dim == 2:\n",
154
+ "# self.images = f[self.field][self.idx,0,:self.HII_DIM,-self.num_redshift:][:,None]\n",
155
+ "# elif self.dim == 3:\n",
156
+ "# self.images = f[self.field][self.idx,:self.HII_DIM,:self.HII_DIM,-self.num_redshift:][:,None]\n",
157
+ "# print(f\"images loaded:\", self.images.shape)\n",
158
+ "\n",
159
+ "# self.params = f['params']['values'][self.idx]\n",
160
+ "# print(\"params loaded:\", self.params.shape)\n",
161
  " \n",
162
+ "# # plt.imshow(self.images[0,0,0])\n",
163
+ "# # plt.show()\n",
164
+ "\n",
165
+ "# def rescale(self, value, to: list):\n",
166
+ "# # print(np.ndim(value))\n",
167
+ "# if np.ndim(value)==2:\n",
168
+ "# # print(f\"rescale params of shape {value.shape}\")\n",
169
+ "# ranges = \\\n",
170
+ "# {\n",
171
+ "# 0: [4, 6], # ION_Tvir_MIN\n",
172
+ "# 1: [10, 250], # HII_EFF_FACTOR\n",
173
+ "# }\n",
174
+ "# # elif np.ndim(value)==5: \n",
175
+ "# else: \n",
176
+ "# # value = np.array(value)\n",
177
+ "# # print(f\"rescale images of shape {np.shape(value)}\")\n",
178
+ "# ranges = \\\n",
179
+ "# {\n",
180
+ "# 0: [0, 80], # brightness_temp\n",
181
+ "# }\n",
182
+ "# # print(f\"value.min = {value.min()}, value.max = {value.max()}\")\n",
183
+ "# for i in range(np.shape(value)[1]):\n",
184
+ "# value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])\n",
185
+ "# # print(f\"value.min = {value.min()}, value.max = {value.max()}\")\n",
186
+ "# value = value * (to[1]-to[0]) + to[0]\n",
187
+ "# return value \n",
188
+ "\n",
189
+ "# def __getitem__(self, index):\n",
190
+ "# return self.images[index], self.params[index]\n",
191
+ "\n",
192
+ "# def __len__(self):\n",
193
+ "# return self.len"
194
  ]
195
  },
196
  {
 
349
  "metadata": {},
350
  "outputs": [],
351
  "source": [
352
+ "# class GroupNorm32(nn.GroupNorm):\n",
353
+ "# def __init__(self, num_groups, num_channels, swish, eps=1e-5):\n",
354
+ "# super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)\n",
355
+ "# self.swish = swish\n",
356
+ "\n",
357
+ "# def forward(self, x):\n",
358
+ "# y = super().forward(x.float()).to(x.dtype)\n",
359
+ "# if self.swish == 1.0:\n",
360
+ "# y = F.silu(y)\n",
361
+ "# elif self.swish:\n",
362
+ "# y = y * F.sigmoid(y * float(self.swish))\n",
363
+ "# return y\n",
364
+ "\n",
365
+ "# def normalization(channels, swish=0.0):\n",
366
+ "# \"\"\"\n",
367
+ "# Make a standard normalization layer, with an optional swish activation.\n",
368
+ "\n",
369
+ "# :param channels: number of input channels.\n",
370
+ "# :return: an nn.Module for normalization.\n",
371
+ "# \"\"\"\n",
372
+ "# #print (channels)\n",
373
+ "# return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)\n",
374
+ "\n",
375
+ "# Conv = {\n",
376
+ "# 1: nn.Conv1d,\n",
377
+ "# 2: nn.Conv2d,\n",
378
+ "# 3: nn.Conv3d,\n",
379
+ "# }\n",
380
+ "\n",
381
+ "# AvgPool = {\n",
382
+ "# 1: nn.AvgPool1d,\n",
383
+ "# 2: nn.AvgPool2d,\n",
384
+ "# 3: nn.AvgPool3d\n",
385
+ "# }\n",
386
+ "\n",
387
+ "# class Downsample(nn.Module):\n",
388
+ "# def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):\n",
389
+ "# super().__init__()\n",
390
+ "# self.channels = channels\n",
391
+ "# self.out_channels = out_channels or channels\n",
392
+ "# # stride = config.stride\n",
393
+ "# if use_conv:\n",
394
+ "# # print(\"conv\")\n",
395
+ "# self.op = Conv[dim](channels, self.out_channels, 3, stride=stride, padding=1)\n",
396
+ "# else:\n",
397
+ "# # print(\"pool\")\n",
398
+ "# assert channels == self.out_channels\n",
399
+ "# self.op = AvgPool[dim](kernel_size=stride, stride=stride)\n",
400
+ "\n",
401
+ "# def forward(self, x):\n",
402
+ "# assert x.shape[1] == self.channels\n",
403
+ "# return self.op(x)\n",
404
+ "\n",
405
+ "# class Upsample(nn.Module):\n",
406
+ "# def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):\n",
407
+ "# super().__init__()\n",
408
+ "# self.channels = channels\n",
409
+ "# self.out_channels = out_channels\n",
410
+ "# self.use_conv = use_conv\n",
411
+ "# self.stride = stride\n",
412
+ "# if self.use_conv:\n",
413
+ "# self.conv = Conv[dim](self.channels, self.out_channels, 3, padding=1)\n",
414
+ "\n",
415
+ "# def forward(self, x):\n",
416
+ "# assert x.shape[1] == self.channels\n",
417
+ "# # stride = config.stride\n",
418
+ "# # print(torch.tensor(x.shape[2:]))\n",
419
+ "# # print(torch.tensor(stride))\n",
420
+ "# shape = torch.tensor(x.shape[2:]) * torch.tensor(self.stride)\n",
421
+ "# shape = tuple(shape.detach().numpy())\n",
422
+ "# # print(shape)\n",
423
+ "# x = F.interpolate(x, shape, mode='nearest')\n",
424
+ "# if self.use_conv:\n",
425
+ "# x = self.conv(x)\n",
426
+ "# return x\n",
427
+ "\n",
428
+ "# def zero_module(module):\n",
429
+ "# \"\"\"\n",
430
+ "# clean gradient of parameters of the module\n",
431
+ "# \"\"\"\n",
432
+ "# for p in module.parameters():\n",
433
+ "# p.detach().zero_()\n",
434
+ "# return module\n",
435
+ "\n",
436
+ "# class TimestepBlock(ABC, nn.Module):\n",
437
+ "# @abstractmethod\n",
438
+ "# def forward(self, x, emb):\n",
439
+ "# \"\"\"\n",
440
+ "# test\n",
441
+ "# \"\"\"\n",
442
+ "\n",
443
+ "# class TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n",
444
+ "# def forward(self, x, emb, encoder_out=None):\n",
445
+ "# for layer in self:\n",
446
+ "# if isinstance(layer, TimestepBlock):\n",
447
+ "# x = layer(x, emb)\n",
448
+ "# elif isinstance(layer, AttentionBlock):\n",
449
+ "# x = layer(x, encoder_out)\n",
450
+ "# else:\n",
451
+ "# x = layer(x)\n",
452
+ "# return x\n",
453
+ "\n",
454
+ "# class ResBlock(TimestepBlock):\n",
455
+ "# def __init__(\n",
456
+ "# self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_checkpoint=False, use_scale_shift_norm=False, up=False, down=False, dim=2, stride=(2,2),\n",
457
+ "# ):\n",
458
+ "# super().__init__()\n",
459
+ "# self.out_channels = out_channels or channels\n",
460
+ "# self.use_scale_shift_norm = use_scale_shift_norm\n",
461
+ "# self.stride = stride\n",
462
+ "\n",
463
+ "# self.in_layers = nn.Sequential(\n",
464
+ "# # nn.BatchNorm2d(channels), # normalize to standard gaussian\n",
465
+ "# normalization(channels, swish=1.0),\n",
466
+ "# nn.Identity(),\n",
467
+ "# Conv[dim](channels, self.out_channels, 3, padding=1),\n",
468
+ "# )\n",
469
  "\n",
470
+ "# self.updown = up or down\n",
471
+ "# if up:\n",
472
+ "# self.h_updown = Upsample(channels, False, dim=dim, stride=stride)\n",
473
+ "# self.x_updown = Upsample(channels, False, dim=dim, stride=stride)\n",
474
+ "# elif down:\n",
475
+ "# self.h_updown = Downsample(channels, False, dim=dim, stride=stride)\n",
476
+ "# self.x_updown = Downsample(channels, False, dim=dim, stride=stride)\n",
477
+ "# else:\n",
478
+ "# self.h_updown = self.x_updown = nn.Identity()\n",
479
+ "\n",
480
+ "# self.emb_layers = nn.Sequential(\n",
481
+ "# nn.SiLU(),\n",
482
+ "# nn.Linear(\n",
483
+ "# emb_channels,\n",
484
+ "# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,\n",
485
+ "# ),\n",
486
+ "# )\n",
487
+ "\n",
488
+ "# self.out_layers = nn.Sequential(\n",
489
+ "# # nn.BatchNorm2d(self.out_channels),\n",
490
+ "# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),\n",
491
+ "# nn.SiLU() if use_scale_shift_norm else nn.Identity(),\n",
492
+ "# nn.Dropout(p=dropout),\n",
493
+ "# zero_module(Conv[dim](self.out_channels, self.out_channels, 3, padding=1)),\n",
494
+ "# )\n",
495
+ "\n",
496
+ "# if self.out_channels == channels:\n",
497
+ "# self.skip_connection = nn.Identity()\n",
498
+ "# elif use_conv:\n",
499
+ "# self.skip_connection = Conv[dim](channels, self.out_channels, 3, padding=1)\n",
500
+ "# else:\n",
501
+ "# self.skip_connection = Conv[dim](channels, self.out_channels, 1)\n",
502
  " \n",
503
  "\n",
504
+ "# def forward(self, x, emb):\n",
505
+ "# if self.updown:\n",
506
+ "# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n",
507
+ "# h = in_rest(x)\n",
508
+ "# h = self.h_updown(h)\n",
509
+ "# x = self.x_updown(x)\n",
510
+ "# h = in_conv(h)\n",
511
+ "# else:\n",
512
+ "# h = self.in_layers(x)\n",
513
+ "# emb_out = self.emb_layers(emb).type(h.dtype)\n",
514
  "\n",
515
+ "# while len(emb_out.shape) < len(h.shape):\n",
516
+ "# emb_out = emb_out[..., None]\n",
517
  "\n",
518
+ "# if self.use_scale_shift_norm:\n",
519
+ "# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n",
520
+ "# scale, shift = torch.chunk(emb_out, 2, dim=1)\n",
521
+ "# h = out_norm(h) * (1+scale) + shift\n",
522
+ "# h = out_rest(h)\n",
523
+ "# else:\n",
524
+ "# h += emb_out\n",
525
+ "# h = self.out_layers(h)\n",
526
+ "# # print(\"ResBlock, torch.unique(h).shape =\", torch.unique(h).shape)\n",
527
+ "# return self.skip_connection(x) + h\n",
528
+ "\n",
529
+ "# class QKVAttention(nn.Module):\n",
530
+ "# def __init__(self, n_heads):\n",
531
+ "# super().__init__()\n",
532
+ "# self.n_heads = n_heads\n",
533
+ "# # print(\"QKVAttention, self.n_heads =\", self.n_heads)\n",
534
  " \n",
535
+ "# def forward(self, qkv, encoder_kv=None):\n",
536
+ "# bs, width, length = qkv.shape\n",
537
+ "# assert width % (3*self.n_heads) == 0\n",
538
+ "# ch = width // (3*self.n_heads)\n",
539
+ "\n",
540
+ "# # print(\"QKVAttention\", bs, self.n_heads, ch, length)\n",
541
+ "# q, k, v = qkv.reshape(bs*self.n_heads, ch*3, length).split(ch, dim=1)\n",
542
+ "# if encoder_kv is not None:\n",
543
+ "# assert encoder_kv.shape[1] == self.n_heads * ch * 2\n",
544
+ "# ek, ev = encoder_kv.reshape(bs*self.n_heads, ch*2, -1).split(ch, dim=1)\n",
545
+ "# k = torch.cat([ek,k], dim=-1)\n",
546
+ "# v = torch.cat([ev,v], dim=-1)\n",
547
+ "\n",
548
+ "# scale = 1 / math.sqrt(math.sqrt(ch))\n",
549
+ "# weight = torch.einsum(\"bct,bcs->bts\", q*scale, k*scale)\n",
550
+ "# weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n",
551
+ "\n",
552
+ "# a = torch.einsum(\"bts,bcs->bct\", weight, v)\n",
553
+ "# return a.reshape(bs, -1, length)\n",
554
+ "\n",
555
+ "# class AttentionBlock(nn.Module):\n",
556
+ "# def __init__(\n",
557
+ "# self,\n",
558
+ "# channels,\n",
559
+ "# num_heads=1,\n",
560
+ "# num_head_channels=-1,\n",
561
+ "# use_checkpoint=False,\n",
562
+ "# encoder_channels=None,\n",
563
+ "# ):\n",
564
+ "# super().__init__()\n",
565
+ "# self.channels = channels\n",
566
+ "# if num_head_channels == -1:\n",
567
+ "# self.num_heads = num_heads\n",
568
+ "# else:\n",
569
+ "# assert channels % num_head_channels == 0,\\\n",
570
+ "# f\"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}\"\n",
571
+ "# self.num_heads = channels // num_head_channels\n",
572
+ "\n",
573
+ "# self.use_checkpoint = use_checkpoint\n",
574
+ "# # self.norm = nn.BatchNorm2d(channels)\n",
575
+ "# self.norm = normalization(channels, swish=0.0)\n",
576
+ "# self.qkv = nn.Conv1d(channels, channels * 3, 1)\n",
 
577
  " \n",
578
+ "# self.attention = QKVAttention(self.num_heads)\n",
579
+ "\n",
580
+ "# if encoder_channels is not None:\n",
581
+ "# self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)\n",
582
+ "# self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))\n",
583
+ "\n",
584
+ "# def forward(self, x, encoder_out=None):\n",
585
+ "# b, c, *spatial = x.shape\n",
586
+ "# qkv = self.qkv(self.norm(x).view(b, c, -1))\n",
587
+ "# if encoder_out is not None:\n",
588
+ "# encoder_out = self.encoder_kv(encoder_out)\n",
589
+ "# h = self.attention(qkv, encoder_out)\n",
590
+ "# else:\n",
591
+ "# h = self.attention(qkv)\n",
592
+ "# # print(\"AttentionBlock, before proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\n",
593
+ "# h = self.proj_out(h)\n",
594
+ "# # print(\"AttentionBlock, after proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\n",
595
+ "# return x + h.reshape(b, c, *spatial)\n",
596
+ "\n",
597
+ "# def timestep_embedding(timesteps, dim, max_period=10000):\n",
598
+ "# \"\"\"\n",
599
+ "# Create sinusoidal timestep embeddings.\n",
600
+ "\n",
601
+ "# :param timesteps: a 1-D Tensor of N indices, one per batch element.\n",
602
+ "# These may be fractional.\n",
603
+ "# :param dim: the dimension of the output.\n",
604
+ "# :param max_period: controls the minimum frequency of the embeddings.\n",
605
+ "# :return: an [N x dim] Tensor of positional embeddings.\n",
606
+ "# \"\"\"\n",
607
+ "# #print (timesteps.shape)\n",
608
+ "# half = dim // 2\n",
609
+ "# freqs = torch.exp(\n",
610
+ "# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n",
611
+ "# ).to(device=timesteps.device)\n",
612
+ "# #print (timesteps[:, None].float().shape,freqs[None].shape)\n",
613
+ "# args = timesteps[:, None].float() * freqs[None]\n",
614
+ "# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n",
615
+ "# if dim % 2:\n",
616
+ "# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n",
617
+ "# return embedding\n",
618
+ "\n",
619
+ "# class ContextUnet(nn.Module):\n",
620
+ "# def __init__(\n",
621
+ "# self,\n",
622
+ "# n_param=2,\n",
623
+ "# image_size=64,\n",
624
+ "# in_channels=1,\n",
625
+ "# model_channels=128,\n",
626
+ "# out_channels = 1,\n",
627
+ "# channel_mult = None,\n",
628
+ "# num_res_blocks = 2,\n",
629
+ "# dropout = 0,\n",
630
+ "# use_checkpoint = False,\n",
631
+ "# use_scale_shift_norm = False,\n",
632
+ "# attention_resolutions = (16, 8),\n",
633
+ "# num_heads = 4,\n",
634
+ "# num_head_channels = -1,\n",
635
+ "# num_heads_upsample = -1,\n",
636
+ "# resblock_updown = False,\n",
637
+ "# conv_resample = True,\n",
638
+ "# encoder_channels = None,\n",
639
+ "# dim = 2,\n",
640
+ "# stride = (2,2)\n",
641
+ "# ):\n",
642
+ "# super().__init__()\n",
643
+ "\n",
644
+ "# if channel_mult == None:\n",
645
+ "# if image_size == 512:\n",
646
+ "# channel_mult = (0.5, 1, 1, 2, 2, 4, 4)\n",
647
+ "# elif image_size == 256:\n",
648
+ "# channel_mult = (1, 1, 2, 2, 4, 4)\n",
649
+ "# elif image_size == 128:\n",
650
+ "# channel_mult = (1, 1, 2, 3, 4)\n",
651
+ "# elif image_size == 64:\n",
652
+ "# channel_mult = (1, 1, 2, 2, 4, 4)#(1, 2, 3, 4)\n",
653
+ "# elif image_size == 28:\n",
654
+ "# channel_mult = (1, 2)#(1, 2, 3, 4)\n",
655
+ "# else:\n",
656
+ "# raise ValueError(f\"unsupported image size: {image_size}\")\n",
657
+ "# # else:\n",
658
+ "# # channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(\",\"))\n",
659
  " \n",
660
+ "# attention_ds = []\n",
661
+ "# for res in attention_resolutions:\n",
662
+ "# attention_ds.append(image_size // int(res))\n",
663
+ "\n",
664
+ "# # print(\"before, ContextUnet, num_heads_upsample =\", num_heads_upsample, \"num_heads =\", num_heads)\n",
665
+ "# if num_heads_upsample == -1:\n",
666
+ "# num_heads_upsample = num_heads\n",
667
+ "# # print(\"after, ContextUnet, num_heads_upsample =\", num_heads_upsample, \"num_heads =\", num_heads)\n",
668
+ "\n",
669
+ "# # self.n_param = n_param\n",
670
+ "# self.model_channels = model_channels\n",
671
+ "# self.dtype = torch.float32\n",
672
+ "\n",
673
+ "# self.token_embedding = nn.Linear(n_param, model_channels * 4)\n",
674
+ "\n",
675
+ "# time_embed_dim = model_channels * 4\n",
676
+ "# self.time_embed = nn.Sequential(\n",
677
+ "# nn.Linear(model_channels, time_embed_dim),\n",
678
+ "# nn.SiLU(),\n",
679
+ "# nn.Linear(time_embed_dim, time_embed_dim),\n",
680
+ "# )\n",
681
+ "\n",
682
+ "# ch = input_ch = int(channel_mult[0] * model_channels)\n",
683
+ "\n",
684
+ "# ###################### input_blocks ######################\n",
685
+ "# self.input_blocks = nn.ModuleList(\n",
686
+ "# [TimestepEmbedSequential(Conv[dim](in_channels, ch, 3, padding=1))]\n",
687
+ "# )\n",
688
+ "# self._feature_size = ch\n",
689
+ "# input_block_chans = [ch]\n",
690
+ "# ds = 1\n",
691
+ "\n",
692
+ "# for level, mult in enumerate(channel_mult):\n",
693
+ "# for _ in range(num_res_blocks):\n",
694
+ "# layers = [\n",
695
+ "# ResBlock(\n",
696
+ "# ch,\n",
697
+ "# time_embed_dim,\n",
698
+ "# dropout,\n",
699
+ "# out_channels = int(mult * model_channels),\n",
700
+ "# use_checkpoint = use_checkpoint,\n",
701
+ "# use_scale_shift_norm = use_scale_shift_norm,\n",
702
+ "# dim = dim,\n",
703
+ "# stride = stride,\n",
704
+ "# )\n",
705
+ "# ]\n",
706
+ "# ch = int(mult * model_channels)\n",
707
+ "# if ds in attention_ds:\n",
708
+ "# layers.append(\n",
709
+ "# AttentionBlock(\n",
710
+ "# ch,\n",
711
+ "# use_checkpoint=use_checkpoint,\n",
712
+ "# num_heads = num_heads,\n",
713
+ "# num_head_channels = num_head_channels,\n",
714
+ "# encoder_channels = encoder_channels,\n",
715
+ "# )\n",
716
+ "# )\n",
717
+ "# self.input_blocks.append(TimestepEmbedSequential(*layers))\n",
718
+ "# self._feature_size += ch\n",
719
+ "# input_block_chans.append(ch)\n",
720
+ "\n",
721
+ "# if level != len(channel_mult) - 1:\n",
722
+ "# out_ch = ch\n",
723
+ "# self.input_blocks.append(\n",
724
+ "# TimestepEmbedSequential(\n",
725
+ "# ResBlock(\n",
726
+ "# ch,\n",
727
+ "# time_embed_dim,\n",
728
+ "# dropout,\n",
729
+ "# out_channels=out_ch,\n",
730
+ "# # dims=dims,\n",
731
+ "# use_checkpoint=use_checkpoint,\n",
732
+ "# use_scale_shift_norm=use_scale_shift_norm,\n",
733
+ "# down=True,\n",
734
+ "# dim = dim,\n",
735
+ "# stride = stride,\n",
736
+ "# )\n",
737
+ "# if resblock_updown\n",
738
+ "# else Downsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)\n",
739
+ "# )\n",
740
+ "# )\n",
741
+ "# ch = out_ch\n",
742
+ "# input_block_chans.append(ch)\n",
743
+ "# ds *= 2\n",
744
+ "# self._feature_size += ch\n",
745
+ "\n",
746
+ "\n",
747
+ "# ###################### middle_blocks ######################\n",
748
+ "# self.middle_block = TimestepEmbedSequential(\n",
749
+ "# ResBlock(\n",
750
+ "# ch,\n",
751
+ "# time_embed_dim,\n",
752
+ "# dropout,\n",
753
+ "# use_checkpoint=use_checkpoint,\n",
754
+ "# use_scale_shift_norm=use_scale_shift_norm,\n",
755
+ "# dim = dim,\n",
756
+ "# stride = stride,\n",
757
+ "# ),\n",
758
+ "# AttentionBlock(\n",
759
+ "# ch,\n",
760
+ "# use_checkpoint=use_checkpoint,\n",
761
+ "# num_heads=num_heads,\n",
762
+ "# num_head_channels=num_head_channels,\n",
763
+ "# encoder_channels=encoder_channels,\n",
764
+ "# ),\n",
765
+ "# ResBlock(\n",
766
+ "# ch,\n",
767
+ "# time_embed_dim,\n",
768
+ "# dropout,\n",
769
+ "# use_checkpoint=use_checkpoint,\n",
770
+ "# use_scale_shift_norm=use_scale_shift_norm,\n",
771
+ "# dim = dim,\n",
772
+ "# stride = stride,\n",
773
+ "# ),\n",
774
+ "# )\n",
775
+ "# self._feature_size += ch\n",
776
+ "\n",
777
+ "\n",
778
+ "# ###################### output_blocks ######################\n",
779
+ "# self.output_blocks = nn.ModuleList([])\n",
780
+ "# for level, mult in list(enumerate(channel_mult))[::-1]:\n",
781
+ "# for i in range(num_res_blocks + 1):\n",
782
+ "# ich = input_block_chans.pop()\n",
783
+ "# layers = [\n",
784
+ "# ResBlock(\n",
785
+ "# ch + ich,\n",
786
+ "# time_embed_dim,\n",
787
+ "# dropout,\n",
788
+ "# out_channels=int(model_channels * mult),\n",
789
+ "# # dims=dims,\n",
790
+ "# use_checkpoint=use_checkpoint,\n",
791
+ "# use_scale_shift_norm=use_scale_shift_norm,\n",
792
+ "# dim = dim,\n",
793
+ "# stride = stride,\n",
794
+ "# )\n",
795
+ "# ]\n",
796
+ "# ch = int(model_channels * mult)\n",
797
+ "# if ds in attention_ds:\n",
798
+ "# # print(\"ds in attention_resolutions, num_heads=\", num_heads_upsample)\n",
799
+ "# layers.append(\n",
800
+ "# AttentionBlock(\n",
801
+ "# ch,\n",
802
+ "# use_checkpoint=use_checkpoint,\n",
803
+ "# num_heads=num_heads_upsample,\n",
804
+ "# num_head_channels=num_head_channels,\n",
805
+ "# encoder_channels=encoder_channels,\n",
806
+ "# )\n",
807
+ "# )\n",
808
+ "# if level and i == num_res_blocks:\n",
809
+ "# out_ch = ch\n",
810
+ "# layers.append(\n",
811
+ "# ResBlock(\n",
812
+ "# ch,\n",
813
+ "# time_embed_dim,\n",
814
+ "# dropout,\n",
815
+ "# out_channels=out_ch,\n",
816
+ "# # dims=dims,\n",
817
+ "# use_checkpoint=use_checkpoint,\n",
818
+ "# use_scale_shift_norm=use_scale_shift_norm,\n",
819
+ "# up=True,\n",
820
+ "# dim = dim,\n",
821
+ "# stride = stride,\n",
822
+ "# )\n",
823
+ "# if resblock_updown\n",
824
+ "# else Upsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)\n",
825
+ "# )\n",
826
+ "# ds //= 2\n",
827
+ "# self.output_blocks.append(TimestepEmbedSequential(*layers))\n",
828
+ "# self._feature_size += ch\n",
829
+ "\n",
830
+ "# self.out = nn.Sequential(\n",
831
+ "# # nn.BatchNorm2d(ch),\n",
832
+ "# normalization(ch, swish=1.0),\n",
833
+ "# nn.Identity(),\n",
834
+ "# zero_module(Conv[dim](input_ch, out_channels, 3, padding=1)),\n",
835
+ "# )\n",
836
+ "# # self.use_fp16 = use_fp16\n",
837
+ "\n",
838
+ "# def forward(self, x, timesteps, y=None):\n",
839
+ "# hs = []\n",
840
+ "# emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))\n",
841
+ "# if y != None:\n",
842
+ "# text_outputs = self.token_embedding(y.float())\n",
843
+ "# emb = emb + text_outputs.to(emb)\n",
844
+ "\n",
845
+ "# h = x.type(self.dtype)\n",
846
+ "# # print(\"0,h.shape =\", h.shape)\n",
847
+ "# for module in self.input_blocks:\n",
848
+ "# h = module(h, emb)\n",
849
+ "# hs.append(h)\n",
850
+ "# # print(\"module encoder, h.shape =\", h.shape)\n",
851
+ "# # print(\"2,h.shape =\", h.shape)\n",
852
+ "# h = self.middle_block(h, emb)\n",
853
+ "# # print(\"middle block, h.shape =\", h.shape)\n",
854
+ "# # print(\"2,h.shape =\", h.shape)\n",
855
+ "# for module in self.output_blocks:\n",
856
+ "# # print(\"for module in self.output_blocks, h.shape =\", h.shape)\n",
857
+ "# # print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\n",
858
+ "# h = torch.cat([h, hs.pop()], dim=1)\n",
859
+ "# h = module(h, emb)\n",
860
+ "# # print(\"module decoder, h.shape =\", h.shape)\n",
861
+ "\n",
862
+ "# h = h.type(x.dtype)\n",
863
+ "# h = self.out(h)\n",
864
+ "# # print(\"self.out(h)\", \"h.shape =\", h.shape)\n",
865
+ "\n",
866
+ "# return h "
867
  ]
868
  },
869
  {
870
  "cell_type": "code",
871
+ "execution_count": 9,
872
  "metadata": {},
873
  "outputs": [],
874
  "source": [
 
893
  " self.step += 1\n",
894
  "\n",
895
  " def reset_parameters(self, ema_model, model):\n",
896
+ " ema_model.load_state_dict(model.state_dict())\n",
897
+ " "
898
  ]
899
  },
900
  {
901
  "cell_type": "code",
902
+ "execution_count": 10,
903
  "metadata": {},
904
  "outputs": [],
905
  "source": [
 
965
  },
966
  {
967
  "cell_type": "code",
968
+ "execution_count": 11,
969
  "metadata": {},
970
  "outputs": [],
971
  "source": [
 
975
  },
976
  {
977
  "cell_type": "code",
978
+ "execution_count": 12,
979
  "metadata": {},
980
  "outputs": [],
981
  "source": [
 
984
  },
985
  {
986
  "cell_type": "code",
987
+ "execution_count": 13,
988
  "metadata": {},
989
  "outputs": [],
990
  "source": [
 
1008
  },
1009
  {
1010
  "cell_type": "code",
1011
+ "execution_count": 14,
1012
  "metadata": {},
1013
  "outputs": [],
1014
  "source": [
 
1206
  },
1207
  {
1208
  "cell_type": "code",
1209
+ "execution_count": 15,
1210
  "metadata": {},
1211
  "outputs": [
1212
  {
 
1416
  },
1417
  {
1418
  "cell_type": "code",
1419
+ "execution_count": 16,
1420
  "metadata": {},
1421
  "outputs": [
1422
  {
 
1443
  "output_type": "stream",
1444
  "text": [
1445
  "params loaded: (200, 2)\n",
1446
+ "images rescaled to [-1.0, 1.056351900100708]\n",
1447
+ "params rescaled to [0.0, 0.999164249684298]\n"
1448
  ]
1449
  },
1450
  {
1451
  "data": {
1452
  "application/vnd.jupyter.widget-view+json": {
1453
+ "model_id": "fec693362692472581efafa594095278",
1454
  "version_major": 2,
1455
  "version_minor": 0
1456
  },
 
1464
  {
1465
  "data": {
1466
  "application/vnd.jupyter.widget-view+json": {
1467
+ "model_id": "929a642531414269ae5516eb9d9a9ba2",
1468
  "version_major": 2,
1469
  "version_minor": 0
1470
  },
 
1478
  {
1479
  "data": {
1480
  "application/vnd.jupyter.widget-view+json": {
1481
+ "model_id": "2fb5460387ad4a3798499bbae31d301e",
1482
  "version_major": 2,
1483
  "version_minor": 0
1484
  },
 
1492
  {
1493
  "data": {
1494
  "application/vnd.jupyter.widget-view+json": {
1495
+ "model_id": "f7213d3285cd46ad9f2604f88b45725b",
1496
  "version_major": 2,
1497
  "version_minor": 0
1498
  },
 
1506
  {
1507
  "data": {
1508
  "application/vnd.jupyter.widget-view+json": {
1509
+ "model_id": "5ec52d75f5b54fe7a8d912baf75686c6",
1510
  "version_major": 2,
1511
  "version_minor": 0
1512
  },
 
1520
  {
1521
  "data": {
1522
  "application/vnd.jupyter.widget-view+json": {
1523
+ "model_id": "c68ccbc52fdb4c0fbeec1932bd8f74d5",
1524
  "version_major": 2,
1525
  "version_minor": 0
1526
  },
 
1534
  {
1535
  "data": {
1536
  "application/vnd.jupyter.widget-view+json": {
1537
+ "model_id": "5dc2869a9e694a0388336d2ec71818f5",
1538
  "version_major": 2,
1539
  "version_minor": 0
1540
  },
 
1548
  {
1549
  "data": {
1550
  "application/vnd.jupyter.widget-view+json": {
1551
+ "model_id": "9ed1309b7afb46d59b568e212ee2ac0a",
1552
  "version_major": 2,
1553
  "version_minor": 0
1554
  },
 
1562
  {
1563
  "data": {
1564
  "application/vnd.jupyter.widget-view+json": {
1565
+ "model_id": "f3ee8347673c47759bc4b419e363f39a",
1566
  "version_major": 2,
1567
  "version_minor": 0
1568
  },
 
1576
  {
1577
  "data": {
1578
  "application/vnd.jupyter.widget-view+json": {
1579
+ "model_id": "cb4824a035494e97a647d0c185645318",
1580
  "version_major": 2,
1581
  "version_minor": 0
1582
  },
 
1594
  },
1595
  {
1596
  "cell_type": "code",
1597
+ "execution_count": 27,
1598
  "metadata": {},
1599
  "outputs": [
1600
  {
 
1612
  {
1613
  "data": {
1614
  "application/vnd.jupyter.widget-view+json": {
1615
+ "model_id": "402d3818dd8a45cdaf774a7a1c19a4f4",
1616
  "version_major": 2,
1617
  "version_minor": 0
1618
  },
 
1622
  },
1623
  "metadata": {},
1624
  "output_type": "display_data"
1625
  }
1626
  ],
1627
  "source": [
load_h5.py CHANGED
@@ -1,27 +1,116 @@
1
  from dataclasses import dataclass
2
  import h5py
3
  import torch
4
- import torch.nn as nn
5
  from torch.utils.data import DataLoader, Dataset
6
  # from datasets import Dataset
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import random
10
- from abc import ABC, abstractmethod
11
- import torch.nn.functional as F
12
  import math
13
- from PIL import Image
14
  import os
15
- from torch.utils.tensorboard import SummaryWriter
16
- import copy
17
- from tqdm.auto import tqdm
18
  # from torchvision import transforms
19
  # from diffusers import UNet2DModel#, UNet3DConditionModel
20
  # from diffusers import DDPMScheduler
21
- from diffusers.utils import make_image_grid
22
  import datetime
23
- from pathlib import Path
24
- from diffusers.optimization import get_cosine_schedule_with_warmup
25
- from accelerate import notebook_launcher, Accelerator
26
- from huggingface_hub import create_repo, upload_folder
27
 
1
  from dataclasses import dataclass
2
  import h5py
3
  import torch
4
+ # import torch.nn as nn
5
  from torch.utils.data import DataLoader, Dataset
6
  # from datasets import Dataset
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import random
10
+ # from abc import ABC, abstractmethod
11
+ # import torch.nn.functional as F
12
  import math
13
+ # from PIL import Image
14
  import os
15
+ # from torch.utils.tensorboard import SummaryWriter
16
+ # import copy
17
+ # from tqdm.auto import tqdm
18
  # from torchvision import transforms
19
  # from diffusers import UNet2DModel#, UNet3DConditionModel
20
  # from diffusers import DDPMScheduler
21
+ # from diffusers.utils import make_image_grid
22
  import datetime
23
+ # from pathlib import Path
24
+ # from diffusers.optimization import get_cosine_schedule_with_warmup
25
+ # from accelerate import notebook_launcher, Accelerator
26
+ # from huggingface_hub import create_repo, upload_folder
27
 
28
+ class Dataset4h5(Dataset):
29
+ def __init__(self, dir_name, num_image=10, field='brightness_temp', shuffle=True, idx=None, num_redshift=32, HII_DIM=32, rescale=True, drop_prob = 0, dim=2):
30
+ super().__init__()
31
+
32
+ self.dir_name = dir_name
33
+ self.num_image = num_image
34
+ self.field = field
35
+ self.shuffle = shuffle
36
+ self.idx = idx
37
+ self.num_redshift = num_redshift
38
+ self.HII_DIM = HII_DIM
39
+ self.drop_prob = drop_prob
40
+ self.dim = dim
41
+
42
+ self.load_h5()
43
+ if rescale:
44
+ self.images = self.rescale(self.images, to=[-1,1])
45
+ self.params = self.rescale(self.params, to=[0,1])
46
+
47
+ self.len = len(self.params)
48
+ self.images = torch.from_numpy(self.images)
49
+ print(f"images rescaled to [{self.images.min()}, {self.images.max()}]")
50
+
51
+ cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()
52
+ self.params = torch.from_numpy(self.params*cond_filter)
53
+ print(f"params rescaled to [{self.params.min()}, {self.params.max()}]")
54
+
55
+ def load_h5(self):
56
+ with h5py.File(self.dir_name, 'r') as f:
57
+ print(f"dataset content: {f.keys()}")
58
+ max_num_image = len(f['brightness_temp'])#.shape[0]
59
+ print(f"{max_num_image} images can be loaded")
60
+ field_shape = f['brightness_temp'].shape[1:]
61
+ print(f"field.shape = {field_shape}")
62
+ self.params_keys = list(f['params']['keys'])
63
+ print(f"params keys = {self.params_keys}")
64
+
65
+ if self.idx is None:
66
+ if self.shuffle:
67
+ self.idx = np.sort(random.sample(range(max_num_image), self.num_image))
68
+ print(f"loading {self.num_image} images randomly")
69
+ # print(self.idx)
70
+ else:
71
+ self.idx = range(self.num_image)
72
+ print(f"loading {len(self.idx)} images with idx = {self.idx}")
73
+ else:
74
+ print(f"loading {len(self.idx)} images with idx = {self.idx}")
75
+
76
+ if self.dim == 2:
77
+ self.images = f[self.field][self.idx,0,:self.HII_DIM,-self.num_redshift:][:,None]
78
+ elif self.dim == 3:
79
+ self.images = f[self.field][self.idx,:self.HII_DIM,:self.HII_DIM,-self.num_redshift:][:,None]
80
+ print(f"images loaded:", self.images.shape)
81
+
82
+ self.params = f['params']['values'][self.idx]
83
+ print("params loaded:", self.params.shape)
84
+
85
+ # plt.imshow(self.images[0,0,0])
86
+ # plt.show()
87
+
88
+ def rescale(self, value, to: list):
89
+ # print(np.ndim(value))
90
+ if np.ndim(value)==2:
91
+ # print(f"rescale params of shape {value.shape}")
92
+ ranges = \
93
+ {
94
+ 0: [4, 6], # ION_Tvir_MIN
95
+ 1: [10, 250], # HII_EFF_FACTOR
96
+ }
97
+ # elif np.ndim(value)==5:
98
+ else:
99
+ # value = np.array(value)
100
+ # print(f"rescale images of shape {np.shape(value)}")
101
+ ranges = \
102
+ {
103
+ 0: [0, 80], # brightness_temp
104
+ }
105
+ # print(f"value.min = {value.min()}, value.max = {value.max()}")
106
+ for i in range(np.shape(value)[1]):
107
+ value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])
108
+ # print(f"value.min = {value.min()}, value.max = {value.max()}")
109
+ value = value * (to[1]-to[0]) + to[0]
110
+ return value
111
+
112
+ def __getitem__(self, index):
113
+ return self.images[index], self.params[index]
114
+
115
+ def __len__(self):
116
+ return self.len
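A minimal usage sketch for Dataset4h5 as defined above; the HDF5 path and batch size are illustrative assumptions:

from torch.utils.data import DataLoader

dataset = Dataset4h5("lightcones.h5", num_image=200, HII_DIM=32, num_redshift=32, dim=2)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

for images, params in loader:
    # images: (16, 1, 32, 32), rescaled to roughly [-1, 1]
    # params: (16, 2), rescaled to [0, 1]
    break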