Xsmos committed
Commit 14c70a0 · verified · 1 Parent(s): e564941
context_unet.py CHANGED
@@ -20,6 +20,7 @@ import copy
 # from diffusers import DDPMScheduler
 # from diffusers.utils import make_image_grid
 import datetime
+import torch.utils.checkpoint as checkpoint
 # from pathlib import Path
 # from diffusers.optimization import get_cosine_schedule_with_warmup
 # from accelerate import notebook_launcher, Accelerator
@@ -132,10 +133,12 @@ class ResBlock(TimestepBlock):
     def __init__(
         self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_checkpoint=False, use_scale_shift_norm=False, up=False, down=False, dim=2, stride=(2,2),
     ):
+        #print(f"Resblock, use_checkpoint = {use_checkpoint}")
         super().__init__()
         self.out_channels = out_channels or channels
         self.use_scale_shift_norm = use_scale_shift_norm
         self.stride = stride
+        self.use_checkpoint = use_checkpoint
 
         self.in_layers = nn.Sequential(
             # nn.BatchNorm2d(channels), # normalize to standard gaussian
@@ -177,8 +180,13 @@ class ResBlock(TimestepBlock):
         else:
             self.skip_connection = Conv[dim](channels, self.out_channels, 1)
 
-
     def forward(self, x, emb):
+        if self.use_checkpoint:
+            return checkpoint.checkpoint(self._forward_impl, x, emb, use_reentrant=False)
+        else:
+            return self._forward_impl(x, emb)
+
+    def _forward_impl(self, x, emb):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -239,6 +247,7 @@ class AttentionBlock(nn.Module):
         use_checkpoint=False,
         encoder_channels=None,
     ):
+        #print(f"AttentionBlock, use_checkpoint = {use_checkpoint}")
         super().__init__()
         self.channels = channels
         if num_head_channels == -1:
@@ -260,6 +269,12 @@ class AttentionBlock(nn.Module):
         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 
     def forward(self, x, encoder_out=None):
+        if self.use_checkpoint:
+            return checkpoint.checkpoint(self._forward_impl, x, encoder_out, use_reentrant=False)
+        else:
+            return self._forward_impl(x, encoder_out)
+
+    def _forward_impl(self, x, encoder_out=None):
         b, c, *spatial = x.shape
         qkv = self.qkv(self.norm(x).view(b, c, -1))
         if encoder_out is not None:
@@ -533,11 +548,12 @@ class ContextUnet(nn.Module):
         #print("0,h.shape =", h.shape)
         for module in self.input_blocks:
             h = module(h, emb)
+            #print(f"in for loop, h.shape = {h.shape}")
             hs.append(h)
             #print("module encoder, h.shape =", h.shape)
-        # print("2,h.shape =", h.shape)
+        #print("before middle block, h.shape =", h.shape)
         h = self.middle_block(h, emb)
-        #print("middle block, h.shape =", h.shape)
+        #print("after middle block, h.shape =", h.shape)
         #print("2, h.dtype =", h.dtype)
         for module in self.output_blocks:
             #print("for module in self.output_blocks, h.shape =", h.shape)
 
diffusion.py CHANGED
@@ -1,31 +1,3 @@
-# %% [markdown]
-# ## Adapt ContextUnet and the related code so that it first works for the 2D case; compare it against diffusers.Unet2DModel and optimize; finally rewrite it for the 3D case.
-# - Trying diffusers' Unet2DModel, the loss dropped from 0.3 to 0.2 but was still high, which means there are problems outside Unet2DModel that can be optimized
-# - After switching to diffusers' DDPMScheduler and DDPMPipeline, the loss fell below 0.1, sometimes as low as 0.004, so the problem in my code is mainly in the DDPM part. The DDPMScheduler part is short and seems fine, so the issue should be some piece of the DDPMPipeline code that mine is missing.
-# - I had a typo in my DDPMScheduler that kept beta_t very small; after fixing it the loss drops from 0.2 to 0.02 and stays below 0.1
-# - diffusers' DDPMScheduler seems to work somewhat better; its loss is always a bit lower than my DDPMScheduler's. At epoch 19 the former's loss is about 0.02, the latter's about 0.07. The former also supports noising 3D images, so I might as well reuse the existing wheel, but I want to know why my loss is higher.
-# - I realized their DDPMScheduler's sample function does not accept conditioning inputs, so in the end I still need my own DDPMScheduler. I can use theirs first to debug my ContextUnet.
-# - I need to extend my ContextUnet to handle images of different dimensionality; after all I need to finish comparing with the original literature before extending to 3D
-# - I converted my ContextUnet to 2D mode: diffusers.Unet2DModel reaches loss=0.037 while my Unet reaches loss=0.07, and the images my Unet generates look strange, so my Unet has problems too. I need to roll the code back to the original Unet and track down the issue.
-# - I limited the number of pixels along the redshift direction to 64 to compare the two Unets. Comparison:\
-#   Unet2DModel loss: 0.03, 0.0655, 0.05, 0.02, 0.05\
-#   ContextUnet loss: 0.1, 0.16, 0.1, 0.2186, 0.06
-# - I rolled ContextUnet back to the original author's version: loss=0.05 and the output images look good. My main change was restoring his original normalization function, which also has a swish parameter; when I have time I can investigate exactly what affected training. I also found that for clean, separate tensorboard curves, runs need to go in different folders
-# - Verified: GroupNorm works better than BatchNorm
-# - Extended to accept inputs of different dimensionality
-# - Merged the cond, guide_w, drop_out parameters
-# - The generated 21cm images are not dark enough where they should be dark; with MNIST digit images the problem seems to go away
-# - Generating MNIST digits with the diffusion model, I found that although the generated data also contains negative values such as -0.1, the plotted images are the expected black. The data distribution is not very different from the 21cm results, so I now plan to roll the code back to the 21cm case
-# - I unified the ddpm21cm module so it handles both training and sampling, but there is a bug: sampling always runs CUDA out of memory, while resuming the model separately and sampling does not.
-# - Solved; the problem was that I forgot to write with torch.no_grad():
-# - Next: generate 800 lightcones, and meanwhile study how to compute the global signal and the power spectrum
-# - Once the number of training images reaches 5000, the generated images are very similar to the test data
-# - it takes 62 mins to generate 8 images of shape (64,64,64), which is even slower than the simulation (~5 mins per image). Besides, the batch_size during training and the number of images to be generated are limited to 2 and 8, respectively.
-# - the slowness can be solved with multiple GPUs, and the limit on the number of images with mixed precision and multiple GPUs.
-# - In addition, DDPM's output can look better compared to the computation-intensive simulations.
-#   1 GPU, batch_size = 10, num_image = 3200, 50s for each epoch
-#   4 GPU, batch_size = 10, num_image = 3200,
-
 # %%
 import logging
 #logging.getLogger("torch").setLevel(logging.ERROR)
@@ -66,7 +38,7 @@ from context_unet import ContextUnet
 from huggingface_hub import notebook_login
 
 import torch.multiprocessing as mp
-from torch.utils.data.distributed import DistributedSampler
+#from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 import torch.distributed as dist
@@ -271,13 +243,14 @@ class TrainConfig:
 
     dim = 2
     #dim = 3#2
-    stride = (2,2) if dim == 2 else (2,2,2)
+    stride = (2,4) if dim == 2 else (2,2,4)
     num_image = 32#0#0#640#320#6400#3000#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
     batch_size = 1#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
     n_epoch = 100#30#50#20#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
     HII_DIM = 64
-    num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
-    startat = 512-num_redshift
+    num_redshift = 1024#512#256#1024#64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
+    startat = 0#512-num_redshift
+
     channel = 1
     img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
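For orientation on the stride and num_redshift changes: lightcones are now 64 pixels across and 1024 along redshift, and the anisotropic stride (2,4) shrinks the long redshift axis four times faster than the transverse axis. Assuming ContextUnet downsamples once per channel_mult level after the first (four times for the five-level default), both axes meet at the same bottleneck size; a quick check:

```python
# Assumes one downsampling per channel_mult level after the first; this is an
# inference about ContextUnet's architecture, not taken from its source.
HII_DIM, num_redshift = 64, 1024
stride = (2, 4)
levels = 5  # len(channel_mult) = len((1,2,2,2,4))
shape = (HII_DIM, num_redshift)
for _ in range(levels - 1):
    shape = (shape[0] // stride[0], shape[1] // stride[1])
print(shape)  # (4, 4): the 64 x 1024 input becomes square at the bottleneck
```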
@@ -324,9 +297,11 @@ class TrainConfig:
     gradient_accumulation_steps = 1
 
     pbar_update_step = 20
+
+    channel_mult = (1,2,2,2,4)
     # date = datetime.datetime.now().strftime("%m%d-%H%M")
     # run_name = f'{date}' # the unique name of each experiment
-
+    str_len = 140
     # config = TrainConfig()
     # print("device =", config.device)
@@ -372,53 +347,40 @@ class TrainConfig:
 # if rank == 0 and all_gradients_consistent:
 #     print("All model gradients are consistent across GPUs.")
 #     return all_gradients_consistent
+def get_gpu_info(device):
+    total_memory = torch.cuda.get_device_properties(device).total_memory
+    reserved_memory = torch.cuda.memory_reserved(device)
+    allocated_memory = torch.cuda.memory_allocated(device)
+    free_memory = reserved_memory - allocated_memory
+    return {
+        'total': int(total_memory / 1024**2),
+        'used': int(allocated_memory / 1024**2),
+        'free': int(free_memory / 1024**2),
+    }
 
 class DDPM21CM:
     def __init__(self, config):
-        # print(
-        #     "torch.cuda.is_available() =", torch.cuda.is_available(),
-        #     "torch.cuda.device_count() =", torch.cuda.device_count(),
-        #     "torch.cuda.is_initialized() =", torch.cuda.is_initialized(),
-        #     "torch.cuda.current_device() =", torch.cuda.current_device()
-        # )
-        # config = TrainConfig()
-        # date = datetime.datetime.now().strftime("%m%d-%H%M")
         config.run_name = datetime.datetime.now().strftime("%d%H%M%S") # the unique name of each experiment
         self.config = config
-        # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)
-        # # self.shape_loaded = dataset.images.shape
-        # # print("shape_loaded =", self.shape_loaded)
-        # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
-        # del dataset
-        # print("self.ddpm = DDPMScheduler")
         self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, config=config,)#, dtype=config.dtype
 
-        # print("self.nn_model = ContextUnet")
         # initialize the unet
-        self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)#, dtype=config.dtype)
+        self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, channel_mult=config.channel_mult, use_checkpoint=config.use_checkpoint)#, dtype=config.dtype)
 
-        # print("self.nn_model.train()")
-        # nn_model = ContextUnet(n_param=1, image_size=28)
         self.nn_model.train()
-        # print("self.ddpm.device =", self.ddpm.device)
         self.nn_model.to(self.ddpm.device)
-        # print("before, nn_model.device =", self.ddpm.device)
         self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])
-        # print("after, nn_model.device =", self.ddpm.device)
-        # number of parameters to be trained
 
+        gpu_info = get_gpu_info(config.device)
         if config.resume and os.path.exists(config.resume):
            # resume_file = os.path.join(config.output_dir, f"{config.resume}")
            # self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
            # print(f"resumed nn_model from {config.resume}")
            self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
            #self.nn_model.module.to(config.dtype)
-            print(f" {config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} resumed nn_model from {config.resume} with {sum(x.numel() for x in self.nn_model.parameters())} parameters ".center(120,'+'))
+            print(f"{config.run_name} cuda:{torch.cuda.current_device()}/{self.config.global_rank} resumed nn_model from {config.resume} with {sum(x.numel() for x in self.nn_model.parameters())} parameters, gpu:{gpu_info} MB".center(self.config.str_len,'+'))
         else:
-            print(f" {config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} initialized nn_model randomly with {sum(x.numel() for x in self.nn_model.parameters())} parameters ".center(120,'+'))
-
-        # self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
-        # print(f" Number of parameters for nn_model: {self.number_of_params} ".center(120,'-'))
+            print(f"{config.run_name} cuda:{torch.cuda.current_device()}/{self.config.global_rank} initialized nn_model randomly with {sum(x.numel() for x in self.nn_model.parameters())} parameters, gpu:{gpu_info} MB".center(self.config.str_len,'+'))
 
         # whether to use ema
         if config.ema:
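A caveat on the new get_gpu_info helper: its 'free' field is reserved minus allocated, i.e. memory that is free inside PyTorch's caching allocator, not free memory on the device as nvidia-smi reports it. If a device-level view is wanted, torch.cuda.mem_get_info is the standard API; a hedged sketch for cross-checking the two views:

```python
import torch

def gpu_memory_mb(device=0):
    # driver-level view: free/total bytes on the whole device
    free_b, total_b = torch.cuda.mem_get_info(device)
    # allocator-level view: what this process has cached and handed out
    return {
        'driver_free': free_b // 1024**2,
        'driver_total': total_b // 1024**2,
        'reserved': torch.cuda.memory_reserved(device) // 1024**2,
        'allocated': torch.cuda.memory_allocated(device) // 1024**2,
    }
```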
@@ -452,6 +414,7 @@ class DDPM21CM:
             dim=self.config.dim,
             ranges_dict=self.ranges_dict,
             num_workers=min(8,len(os.sched_getaffinity(0))//self.config.world_size),
+            str_len = self.config.str_len,
         )
         # self.shape_loaded = dataset.images.shape
         # print("shape_loaded =", self.shape_loaded)
@@ -520,29 +483,11 @@ class DDPM21CM:
         else:
             print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} torch.distributed.is_initialized False!!!!!!!!!!!!!!!")
 
-        #print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank}; nn_model.device = {self.nn_model.device}")
-        #acc_prep_start = time()
-        #self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \
-        #    self.accelerator.prepare(
-        #        self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler
-        #    )
-        #self.nn_model = self.accelerator.prepare(self.nn_model)
-        #self.optimizer = self.accelerator.prepare(self.optimizer)
-        #self.dataloader = self.accelerator.prepare(self.dataloader)
-        #self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler)
-        #acc_prep_end = time()
-        #print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} accelerate.prepare cost {acc_prep_end-acc_prep_start:.3f}s")
-        # self.nn_model, self.optimizer, self.lr_scheduler = \
-        #     self.accelerator.prepare(
-        #         self.nn_model, self.optimizer, self.lr_scheduler
-        #     )
-
-        # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.sampler =", self.dataloader.sampler)
-        # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.batch_sampler =", self.dataloader.batch_sampler)
-        # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.DistributedSampler =", self.dataloader.DistributedSampler)
-        #train_start = time()
         global_step = 0
         for ep in range(self.config.n_epoch):
+            #torch.cuda.empty_cache()
+            #print(torch.cuda.memory_summary())#abbreviated=True))
+            #print(f"before for loop device{self.config.device} {get_gpu_info(self.config.device)}")
             self.ddpm.train()
             # self.dataloader.sampler.set_epoch(ep)
             pbar_train = tqdm(total=len(self.dataloader), file=sys.stderr)#, disable=self.config.global_rank!=0)#, mininterval=self.config.pbar_update_step)#, disable=True)#not self.accelerator.is_local_main_process)
@@ -550,20 +495,11 @@ class DDPM21CM:
             #train_end = time()
             #print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} ddpm.train costs {train_end-train_start:.3f}s")
             for i, (x, c) in enumerate(self.dataloader):
-                #if i == 0:
-                #    train_end = time()
-                #    print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} ddpm.train costs {train_end-train_start:.3f}s")
-
                 # print(f"cuda:{torch.cuda.current_device()}, x[:,0,:2,0,0] =", x[:,0,:2,0,0])
                 #with self.accelerator.accumulate(self.nn_model):
                 x = x.to(self.config.device)#.to(self.config.dtype)
-                # print("x = x.to(self.config.device), x.dtype =", x.dtype)
-                # print("x = x.to(self.dtype), x.dtype =", x.dtype)
-                # print(f"ddpm.add_noise(x), x.dtype = {x.dtype}")
-                # print(f"ddpm.add_noise(x), xt.dtype = {xt.dtype}")
-
                 # autocast forward propogation
-                with autocast():
+                with autocast(enabled=self.config.autocast):
                     xt, noise, ts = self.ddpm.add_noise(x)
 
                     if self.config.guide_w == -1:
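With autocast(enabled=self.config.autocast), mixed precision becomes a command-line switch: enabled=False turns the context manager into a no-op and the step runs in full precision. For reference, a minimal sketch of the autocast + GradScaler pairing this loop relies on (toy model and data, not the repo's loop):

```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

device = "cuda"
model = torch.nn.Linear(16, 16).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler()  # plays the role of self.scaler above

x, target = torch.randn(4, 16, device=device), torch.randn(4, 16, device=device)
optimizer.zero_grad()
with autocast(enabled=True):       # forward pass in fp16 where safe
    loss = F.mse_loss(model(x), target)
scaler.scale(loss).backward()      # scale loss so fp16 grads don't underflow
scaler.step(optimizer)             # unscales grads; skips the step on inf/nan
scaler.update()                    # adjusts the scale factor for the next step
```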
@@ -574,7 +510,9 @@ class DDPM21CM:
 
                     loss = F.mse_loss(noise, noise_pred)
                     loss = loss / self.config.gradient_accumulation_steps
-
+                    #print(f"within autocast #{i}-device{self.config.device} {get_gpu_info(self.config.device)}")
+                    #print(f"within autocast #{i}-device{self.config.device} t-r-a: {torch.cuda.get_device_properties(self.config.device).total_memory/1024**2}-{torch.cuda.memory_reserved(self.config.device)/1024**2}-{torch.cuda.memory_allocated(self.config.device)/1024**2}")
+
                 # scaler backward propogation
                 self.scaler.scale(loss).backward()
                 #loss.backward()
@@ -610,11 +548,13 @@ class DDPM21CM:
                 global_step += 1
 
             if (i+1) % self.config.gradient_accumulation_steps != 0:
-                print(f"(i+1)%self.config.gradient_accumulation_steps = {(i+1)%self.config.gradient_accumulation_steps}, i = {i}, scg = {self.config.gradient_accumulation_steps}".center(120,'-'))
-                torch.nn.utils.clip_grad_norm_(self.nn_model.parameters(), max_norm=1.0)
-                self.optimizer.step()
-                self.lr_scheduler.step()
-                self.optimizer.zero_grad()
+                print(f"(i+1)%self.config.gradient_accumulation_steps = {(i+1)%self.config.gradient_accumulation_steps}, i = {i}, scg = {self.config.gradient_accumulation_steps}".center(self.config.str_len,'-'))
+                #torch.nn.utils.clip_grad_norm_(self.nn_model.parameters(), max_norm=1.0)
+                #self.optimizer.step()
+                #self.lr_scheduler.step()
+                #self.optimizer.zero_grad()
+                #print(f"after autocast #{i}-device{self.config.device} {get_gpu_info(self.config.device)}")
+                #print(f"after autocast #{i}-device{self.config.device} t-r-a: {torch.cuda.get_device_properties(self.config.device).total_memory/1024**2}-{torch.cuda.memory_reserved(self.config.device)/1024**2}-{torch.cuda.memory_allocated(self.config.device)/1024**2}")
 
 
             # if ep == config.n_epoch-1 or (ep+1)*config.save_period==1:
@@ -631,7 +571,7 @@ class DDPM21CM:
         del self.nn_model
         if self.config.ema:
             del self.ema_model
-        torch.cuda.empty_cache()
+        #torch.cuda.empty_cache()
 
     def save(self, ep):
         # save model
@@ -720,20 +660,21 @@ class DDPM21CM:
         #     self.nn_model, self.optimizer, self.lr_scheduler
         # )
 
-        self.nn_model.eval()
-
         # self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)
         # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f"{config.resume}"))['ema_unet_state_dict'])
         # print(f"resumed ema_model from {config.resume}")
 
+        self.nn_model.eval()
         with torch.no_grad():
-            x_last, x_entire = self.ddpm.sample(
-                nn_model=self.nn_model,
-                params=params_normalized.to(self.config.device),
-                device=self.config.device,
-                guide_w=self.config.guide_w
-            )
-
+            with autocast(enabled=self.config.autocast):
+            #with autocast():
+                x_last, x_entire = self.ddpm.sample(
+                    nn_model=self.nn_model,
+                    params=params_normalized.to(self.config.device),
+                    device=self.config.device,
+                    guide_w=self.config.guide_w
+                )
+                #print(f"x_last.dtype = {x_last.dtype}")
         if save:
             # np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
             savetime = datetime.datetime.now().strftime("%d%H%M%S")
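The sampling change pairs three pieces: eval() puts normalization/dropout layers in inference mode, no_grad() skips building the autograd graph (the OOM fix mentioned in the removed dev log), and the same autocast flag as training keeps sampling at matching precision. The generic shape of that pattern, sketched with a placeholder model:

```python
import torch
from torch.cuda.amp import autocast

def generate(model, x, use_autocast=True):
    model.eval()                          # inference-mode normalization/dropout
    with torch.no_grad():                 # no autograd graph: far less memory
        with autocast(enabled=use_autocast):
            return model(x)
```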
@@ -809,6 +750,9 @@ if __name__ == "__main__":
     parser.add_argument("--num_image", type=int, required=False, default=32)
     parser.add_argument("--n_epoch", type=int, required=False, default=50)
     parser.add_argument("--batch_size", type=int, required=False, default=2)
+    parser.add_argument("--channel_mult", type=float, nargs="+", required=False, default=(1,2,2,2,4))
+    parser.add_argument("--autocast", type=int, required=False, default=False)
+    parser.add_argument("--use_checkpoint", type=int, required=False, default=False)
 
     args = parser.parse_args()
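The new flags take type=int rather than type=bool on purpose: argparse applies type to the raw string, and bool("False") is True (any non-empty string is truthy), so a type=bool flag cannot be switched off from the shell. Passing 0/1 and converting with bool(), as the assignments in the next hunk do, sidesteps this. A quick demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--naive", type=bool, default=False)  # the pitfall
parser.add_argument("--autocast", type=int, default=0)    # the workaround

args = parser.parse_args(["--naive", "False", "--autocast", "0"])
print(args.naive)           # True: bool("False") is True
print(bool(args.autocast))  # False: int("0") -> 0 -> False, as intended
```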
@@ -823,10 +767,14 @@ if __name__ == "__main__":
     config.num_image = args.num_image
     config.n_epoch = args.n_epoch
     config.batch_size = args.batch_size
+    config.channel_mult = args.channel_mult
+    config.autocast = bool(args.autocast)
+    config.use_checkpoint = bool(args.use_checkpoint)
+
     ############################ training ################################
     if args.train:
         config.dataset_name = args.train
-        print(f" training, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size} ".center(120,'#'))
+        print(f" training, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size} ".center(config.str_len,'#'))
         mp.spawn(
             train,
             args=(world_size, local_world_size, master_addr, master_port, config),
@@ -856,7 +804,7 @@ if __name__ == "__main__":
     ]
 
     for params in params_pairs:
-        print(f"sampling for {params}, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size}".center(120,'-'))
+        print(f"sampling for {params}, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size}".center(config.str_len,'-'))
         mp.spawn(
             generate_samples,
             args=(world_size, local_world_size, master_addr, master_port, config, num_new_img_per_gpu, max_num_img_per_gpu, torch.tensor(params)),
load_h5.py CHANGED
@@ -46,6 +46,7 @@ class Dataset4h5(Dataset):
         num_workers=1,#len(os.sched_getaffinity(0))//torch.cuda.device_count(),
         startat=0,
         # shuffle=False,
+        str_len = 120,
     ):
         super().__init__()
 
@@ -61,6 +62,7 @@ class Dataset4h5(Dataset):
         self.transform = transform
         self.num_workers = num_workers
         self.startat = startat
+        self.str_len = str_len
 
         self.load_h5()
         if rescale:
@@ -114,7 +116,7 @@ class Dataset4h5(Dataset):
         concurrent_init_start = time()
         with concurrent.futures.ProcessPoolExecutor(max_workers=self.num_workers) as executor:
             concurrent_init_end = time()
-            print(f" {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}, concurrently loading by {self.num_workers}/{len(os.sched_getaffinity(0))} workers, initialized after {concurrent_init_end-concurrent_init_start:.3f}s ".center(120, '-'))
+            print(f" {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}, concurrently loading by {self.num_workers}/{len(os.sched_getaffinity(0))} workers, initialized after {concurrent_init_end-concurrent_init_start:.3f}s ".center(self.str_len, '-'))
             futures = [None] * self.num_workers
             for i, idx in enumerate(np.array_split(self.idx, self.num_workers)):
                 executor_start = time()
@@ -129,7 +131,7 @@ class Dataset4h5(Dataset):
             self.params[start_idx:start_idx+batch_size] = params
             start_idx += batch_size
         concurrent_end = time()
-        print(f" {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}, {start_idx} images {self.images.shape} & params {self.params.shape} loaded after {concurrent_start-concurrent_init_start:.3f}/{concurrent_end-concurrent_start:.3f}s ".center(120, '-'))
+        print(f" {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}, {start_idx} images {self.images.shape} & params {self.params.shape} loaded after {concurrent_start-concurrent_init_start:.3f}/{concurrent_end-concurrent_start:.3f}s ".center(self.str_len, '-'))
 
         transform_start = time()
         if self.transform:
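For reference, the loading pattern these prints instrument: the requested indices are split across worker processes with np.array_split and each chunk is read in parallel, then reassembled in submission order. A stripped-down sketch with a hypothetical read_chunk standing in for the dataset's real HDF5 reader:

```python
import concurrent.futures
import numpy as np

def read_chunk(idx):
    # stand-in for the real HDF5 read; returns (images, params) for these indices
    return np.zeros((len(idx), 64, 64)), np.zeros((len(idx), 2))

def load_parallel(indices, num_workers=4):
    images, params = [], []
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(read_chunk, chunk)
                   for chunk in np.array_split(indices, num_workers)]
        for future in futures:  # iterate in submission order, as the loop above does
            img, par = future.result()
            images.append(img)
            params.append(par)
    return np.concatenate(images), np.concatenate(params)

if __name__ == "__main__":
    imgs, pars = load_parallel(np.arange(32))
    print(imgs.shape, pars.shape)  # (32, 64, 64) (32, 2)
```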
 
perlmutter_diffusion.sbatch CHANGED
@@ -1,13 +1,13 @@
 #!/bin/bash
 #SBATCH -A m4717
 #SBATCH -J diffusion
-#SBATCH -C gpu
+#SBATCH -C gpu&hbm80g
 #SBATCH -q shared #regular
 #SBATCH -N1
 #SBATCH --gpus-per-node=1
-#SBATCH -t 0:59:00
+#SBATCH -t 3:00:00
 #SBATCH --ntasks-per-node=1
-#SBATCH -oReport-%j
+#SBATCH -o%j
 #SBATCH --mail-type=BEGIN,END,FAIL
 #SBATCH --gpu-bind=none
 
@@ -33,15 +33,19 @@ export MASTER_PORT=$MASTER_PORT
 #export NCCL_DEBUG=INFO
 #export NCCL_DEBUG_SUBSYS=ALL
 cat $0
+#nvidia-smi
 
 srun python diffusion.py \
-    --train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
-    --num_image 3200 \
-    --batch_size 32 \
-    --n_epoch 100 \
+    --num_image 6400 \
+    --batch_size 64 \
+    --n_epoch 50 \
     --gradient_accumulation_steps 1 \
-    --num_new_img_per_gpu 800 \
-    --max_num_img_per_gpu 80 \
-    #--resume outputs/model-N3200-device_count1-node1-epoch99-16103542 \
+    --num_new_img_per_gpu 200 \
+    --max_num_img_per_gpu 20 \
+    --channel_mult 1 1 2 2 4 \
+    --autocast 1 \
+    --use_checkpoint 1 \
+    --train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
+    #--resume outputs/model-N3200-device_count1-node1-epoch99-17160118 \
 
 date
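Batch-script notes, as far as the diff shows: -C gpu&hbm80g asks for Perlmutter's 80 GB-HBM A100 nodes (the plain gpu constraint may land on 40 GB cards), which together with --autocast 1 and --use_checkpoint 1 is presumably what lets the longer 64x1024 lightcones train at batch_size 64; the walltime rises accordingly from 0:59:00 to 3:00:00, and -o%j names the log file after the job ID alone.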
 
quantify_results.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
tensorboard.ipynb CHANGED
@@ -23,13 +23,13 @@
 "data": {
 "text/html": [
 "\n",
-" <iframe id=\"tensorboard-frame-262245829087dd6a\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
+" <iframe id=\"tensorboard-frame-13f025ce79187ae\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
 " </iframe>\n",
 " <script>\n",
 " (function() {\n",
-" const frame = document.getElementById(\"tensorboard-frame-262245829087dd6a\");\n",
+" const frame = document.getElementById(\"tensorboard-frame-13f025ce79187ae\");\n",
 " const url = new URL(\"/\", window.location);\n",
-" const port = 45355;\n",
+" const port = 41355;\n",
 " if (port) {\n",
 " url.port = port;\n",
 " }\n",
@@ -59,7 +59,7 @@
 {
 "data": {
 "text/html": [
-"<a href=\"https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/45355/\">https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/45355/</a>"
+"<a href=\"https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/41355/\">https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/41355/</a>"
 ],
 "text/plain": [
 "<IPython.core.display.HTML object>"
@@ -72,6 +72,14 @@
 "source": [
 "nersc_tensorboard_helper.tb_address()"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "8ca783fe-501c-4e12-b769-f037b4671ef0",
+"metadata": {},
+"outputs": [],
+"source": []
 }
 ],
 "metadata": {
 