Xsmos commited on
Commit
4c7f352
·
verified ·
1 Parent(s): f8f5f14
Files changed (3) hide show
  1. context_unet.py +1 -1
  2. diffusion.py +8 -8
  3. quantify_results.ipynb +0 -0
context_unet.py CHANGED
@@ -330,7 +330,7 @@ class ContextUnet(nn.Module):
330
  elif image_size == 128:
331
  channel_mult = (1, 1, 2, 3, 4)
332
  elif image_size == 64:
333
- channel_mult = (1, 2, 2, 4, 8)#(1, 2, 3, 4)#(1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
334
  elif image_size == 32:
335
  channel_mult = (1, 2, 2, 4)
336
  elif image_size == 28:
 
330
  elif image_size == 128:
331
  channel_mult = (1, 1, 2, 3, 4)
332
  elif image_size == 64:
333
+ channel_mult = (1, 2, 4, 8, 16)#(1, 2, 3, 4)#(1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
334
  elif image_size == 32:
335
  channel_mult = (1, 2, 2, 4)
336
  elif image_size == 28:
diffusion.py CHANGED
@@ -231,7 +231,7 @@ class TrainConfig:
231
  push_to_hub = True
232
  hub_model_id = "Xsmos/ml21cm"
233
  hub_private_repo = False
234
- dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8-4.4-131.341.h5"
235
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
236
  world_size = torch.cuda.device_count()
237
  # repeat = 2
@@ -240,8 +240,8 @@ class TrainConfig:
240
  dim = 2
241
  stride = (2,2) if dim == 2 else (2,2,2)
242
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
243
- batch_size = 50#1#2#50#20#2#100 # 10
244
- n_epoch = 120#5#4# 10#50#20#20#2#5#25 # 120
245
  HII_DIM = 64
246
  num_redshift = 64#512#128#64#512#256#256#64#512#128
247
  channel = 1
@@ -583,7 +583,7 @@ class DDPM21CM:
583
  return x_last
584
  # %%
585
 
586
- num_train_image_list = [800]
587
 
588
  def train(rank, world_size):
589
  config = TrainConfig()
@@ -689,10 +689,10 @@ if __name__ == "__main__":
689
 
690
  params_pairs = [
691
  (4.4, 131.341),
692
- # (5.6, 19.037),
693
- # (4.699, 30),
694
- # (5.477, 200),
695
- # (4.8, 131.341),
696
  ]
697
  for params in params_pairs:
698
  print(f" sampling for {params}, world_size = {world_size} ".center(100,'-'))
 
231
  push_to_hub = True
232
  hub_model_id = "Xsmos/ml21cm"
233
  hub_private_repo = False
234
+ dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
235
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
236
  world_size = torch.cuda.device_count()
237
  # repeat = 2
 
240
  dim = 2
241
  stride = (2,2) if dim == 2 else (2,2,2)
242
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
243
+ batch_size = 20#1#2#50#20#2#100 # 10
244
+ n_epoch = 30#120#5#4# 10#50#20#20#2#5#25 # 120
245
  HII_DIM = 64
246
  num_redshift = 64#512#128#64#512#256#256#64#512#128
247
  channel = 1
 
583
  return x_last
584
  # %%
585
 
586
+ num_train_image_list = [5000]
587
 
588
  def train(rank, world_size):
589
  config = TrainConfig()
 
689
 
690
  params_pairs = [
691
  (4.4, 131.341),
692
+ (5.6, 19.037),
693
+ (4.699, 30),
694
+ (5.477, 200),
695
+ (4.8, 131.341),
696
  ]
697
  for params in params_pairs:
698
  print(f" sampling for {params}, world_size = {world_size} ".center(100,'-'))
quantify_results.ipynb CHANGED
The diff for this file is too large to render. See raw diff