Xsmos committed on
Commit
ee36fac
·
verified ·
1 Parent(s): 89e85d4
Files changed (3) hide show
  1. diffusion.py +25 -15
  2. load_h5.py +35 -11
  3. quantify_results.ipynb +0 -0
diffusion.py CHANGED
@@ -241,9 +241,9 @@ class TrainConfig:
241
  stride = (2,2) if dim == 2 else (2,2,2)
242
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
243
  batch_size = 50#1#2#50#20#2#100 # 10
244
- n_epoch = 100#30#120#5#4# 10#50#20#20#2#5#25 # 120
245
  HII_DIM = 64
246
- num_redshift = 64#512#128#64#512#256#256#64#512#128
247
  channel = 1
248
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
249
 
@@ -366,18 +366,28 @@ class DDPM21CM:
366
  self.ranges_dict = config.ranges_dict
367
 
368
  def load(self):
369
- dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
 
 
 
 
 
 
 
 
 
 
370
  # self.shape_loaded = dataset.images.shape
371
  # print("shape_loaded =", self.shape_loaded)
372
  # print(f"load, current_device() = {torch.cuda.current_device()}")
373
  self.dataloader = DataLoader(
374
  dataset=dataset,
375
  batch_size=self.config.batch_size,
376
- shuffle=False,
377
- num_workers=1,#len(os.sched_getaffinity(0)),
378
  pin_memory=True,
379
  persistent_workers=True,
380
- sampler=DistributedSampler(dataset),
381
  )
382
 
383
  del dataset
@@ -414,14 +424,14 @@ class DDPM21CM:
414
 
415
 
416
  # print("!!!!!!!!!!!!!!!!, before prepare, self.dataloader.sampler =", self.dataloader.sampler)
417
- # self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \
418
- # self.accelerator.prepare(
419
- # self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler
420
- # )
421
- self.nn_model, self.optimizer, self.lr_scheduler = \
422
  self.accelerator.prepare(
423
- self.nn_model, self.optimizer, self.lr_scheduler
424
  )
 
 
 
 
425
 
426
  # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.sampler =", self.dataloader.sampler)
427
  # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.batch_sampler =", self.dataloader.batch_sampler)
@@ -430,7 +440,7 @@ class DDPM21CM:
430
  global_step = 0
431
  for ep in range(self.config.n_epoch):
432
  self.ddpm.train()
433
- self.dataloader.sampler.set_epoch(ep)
434
 
435
  pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
436
  pbar_train.set_description(f"device {torch.cuda.current_device()}, Epoch {ep}")
@@ -527,7 +537,7 @@ class DDPM21CM:
527
  # n_sample = params.shape[0]
528
  # file = self.config.resume
529
 
530
- print(f"device {torch.cuda.current_device()}, sample, params = {params}")
531
  if params is None:
532
  params = torch.tensor([4.4, 131.341])
533
  # params_backup = params.numpy().copy()
@@ -583,7 +593,7 @@ class DDPM21CM:
583
  return x_last
584
  # %%
585
 
586
- num_train_image_list = [8000]
587
 
588
  def train(rank, world_size):
589
  config = TrainConfig()
 
241
  stride = (2,2) if dim == 2 else (2,2,2)
242
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
243
  batch_size = 50#1#2#50#20#2#100 # 10
244
+ n_epoch = 50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
245
  HII_DIM = 64
246
+ num_redshift = 64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
247
  channel = 1
248
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
249
 
 
366
  self.ranges_dict = config.ranges_dict
367
 
368
  def load(self):
369
+ # rank = torch.cuda.current_device()
370
+ dataset = Dataset4h5(
371
+ self.config.dataset_name,
372
+ num_image=self.config.num_image,
373
+ idx = 'random',
374
+ HII_DIM=self.config.HII_DIM,
375
+ num_redshift=self.config.num_redshift,
376
+ drop_prob=self.config.drop_prob,
377
+ dim=self.config.dim,
378
+ ranges_dict=self.ranges_dict
379
+ )
380
  # self.shape_loaded = dataset.images.shape
381
  # print("shape_loaded =", self.shape_loaded)
382
  # print(f"load, current_device() = {torch.cuda.current_device()}")
383
  self.dataloader = DataLoader(
384
  dataset=dataset,
385
  batch_size=self.config.batch_size,
386
+ shuffle=True,#False,
387
+ num_workers=len(os.sched_getaffinity(0)),
388
  pin_memory=True,
389
  persistent_workers=True,
390
+ # sampler=DistributedSampler(dataset),
391
  )
392
 
393
  del dataset
 
424
 
425
 
426
  # print("!!!!!!!!!!!!!!!!, before prepare, self.dataloader.sampler =", self.dataloader.sampler)
427
+ self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \
 
 
 
 
428
  self.accelerator.prepare(
429
+ self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler
430
  )
431
+ # self.nn_model, self.optimizer, self.lr_scheduler = \
432
+ # self.accelerator.prepare(
433
+ # self.nn_model, self.optimizer, self.lr_scheduler
434
+ # )
435
 
436
  # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.sampler =", self.dataloader.sampler)
437
  # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.batch_sampler =", self.dataloader.batch_sampler)
 
440
  global_step = 0
441
  for ep in range(self.config.n_epoch):
442
  self.ddpm.train()
443
+ # self.dataloader.sampler.set_epoch(ep)
444
 
445
  pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
446
  pbar_train.set_description(f"device {torch.cuda.current_device()}, Epoch {ep}")
 
537
  # n_sample = params.shape[0]
538
  # file = self.config.resume
539
 
540
+ # print(f"device {torch.cuda.current_device()}, sample, params = {params}")
541
  if params is None:
542
  params = torch.tensor([4.4, 131.341])
543
  # params_backup = params.numpy().copy()
 
593
  return x_last
594
  # %%
595
 
596
+ num_train_image_list = [200]#[8000]
597
 
598
  def train(rank, world_size):
599
  config = TrainConfig()
load_h5.py CHANGED
@@ -26,14 +26,28 @@ import datetime
26
  # from huggingface_hub import create_repo, upload_folder
27
 
28
  class Dataset4h5(Dataset):
29
- def __init__(self, dir_name, num_image=10, field='brightness_temp', shuffle=False, idx=None, num_redshift=512, HII_DIM=64, rescale=True, drop_prob = 0, dim=2, transform=True, ranges_dict=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  super().__init__()
31
 
32
  self.dir_name = dir_name
33
  self.num_image = num_image
34
- self.field = field
35
- self.shuffle = shuffle
36
  self.idx = idx
 
 
37
  self.num_redshift = num_redshift
38
  self.HII_DIM = HII_DIM
39
  self.drop_prob = drop_prob
@@ -81,14 +95,24 @@ class Dataset4h5(Dataset):
81
  self.params_keys = list(f['params']['keys'])
82
  print(f"params keys = {self.params_keys}")
83
 
84
- if self.idx is None:
85
- if self.shuffle:
86
- self.idx = np.sort(random.sample(range(max_num_image), self.num_image))
87
- print(f"loading {self.num_image} images randomly")
88
- # print(self.idx)
89
- else:
90
- self.idx = range(self.num_image)
91
- print(f"loading {len(self.idx)} images with idx = {self.idx}")
 
 
 
 
 
 
 
 
 
 
92
  else:
93
  print(f"loading {len(self.idx)} images with idx = {self.idx}")
94
 
 
26
  # from huggingface_hub import create_repo, upload_folder
27
 
28
  class Dataset4h5(Dataset):
29
+ def __init__(
30
+ self,
31
+ dir_name,
32
+ num_image=10,
33
+ field='brightness_temp',
34
+ idx=None,
35
+ num_redshift=512,
36
+ HII_DIM=64,
37
+ rescale=True,
38
+ drop_prob = 0,
39
+ dim=2,
40
+ transform=True,
41
+ ranges_dict=None,
42
+ # shuffle=False,
43
+ ):
44
  super().__init__()
45
 
46
  self.dir_name = dir_name
47
  self.num_image = num_image
 
 
48
  self.idx = idx
49
+ self.field = field
50
+ # self.shuffle = shuffle
51
  self.num_redshift = num_redshift
52
  self.HII_DIM = HII_DIM
53
  self.drop_prob = drop_prob
 
95
  self.params_keys = list(f['params']['keys'])
96
  print(f"params keys = {self.params_keys}")
97
 
98
+ # if self.idx is None:
99
+ # if self.shuffle:
100
+ # self.idx = np.sort(random.sample(range(max_num_image), self.num_image))
101
+ # print(f"loading {self.num_image} images randomly")
102
+ # # print(self.idx)
103
+ # else:
104
+ # self.idx = range(self.num_image)
105
+ # print(f"loading {len(self.idx)} images with idx = {self.idx}")
106
+ if self.idx == "random":
107
+ self.idx = np.sort(random.sample(range(max_num_image), self.num_image))
108
+ print(f"loading {self.num_image} images randomly with idx = {self.idx}")
109
+ # print(self.idx)
110
+ elif self.idx == "range":
111
+ rank = torch.cuda.current_device()
112
+ self.idx = range(
113
+ rank*self.num_image, (rank+1)*self.num_image
114
+ )
115
+ print(f"loading {len(self.idx)} images with idx = {self.idx}")
116
  else:
117
  print(f"loading {len(self.idx)} images with idx = {self.idx}")
118
 
quantify_results.ipynb CHANGED
The diff for this file is too large to render. See raw diff