Xsmos commited on
Commit
b8fcc70
·
verified ·
1 Parent(s): 84eaa74

0810-170449

Browse files
context_unet.py CHANGED
@@ -32,12 +32,15 @@ class GroupNorm32(nn.GroupNorm):
32
  self.swish = swish
33
 
34
  def forward(self, x):
35
- # print("GroupNorm32, x.dtype =", x.dtype)
36
- y = super().forward(x.float()).to(x.dtype)
 
 
37
  if self.swish == 1.0:
38
  y = F.silu(y)
39
  elif self.swish:
40
  y = y * F.sigmoid(y * float(self.swish))
 
41
  return y
42
 
43
  def normalization(channels, swish=0.0):
@@ -284,7 +287,7 @@ def timestep_embedding(timesteps, dim, max_period=10000):
284
  :param max_period: controls the minimum frequency of the embeddings.
285
  :return: an [N x dim] Tensor of positional embeddings.
286
  """
287
- #print (timesteps.shape)
288
  half = dim // 2
289
  freqs = torch.exp(
290
  -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
@@ -294,6 +297,7 @@ def timestep_embedding(timesteps, dim, max_period=10000):
294
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
295
  if dim % 2:
296
  embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 
297
  return embedding
298
 
299
  class ContextUnet(nn.Module):
@@ -522,32 +526,34 @@ class ContextUnet(nn.Module):
522
  def forward(self, x, timesteps, y=None):
523
  hs = []
524
  # print("device of timesteps, self.model_channels:", timesteps.device, self.model_channels)
525
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
 
526
  if y != None:
527
- text_outputs = self.token_embedding(y.float())
 
528
  emb = emb + text_outputs.to(emb)
529
 
530
- # print("forward, h = x.type(self.dtype), self.dtype =", self.dtype)
531
  h = x.type(self.dtype)
532
- # print("0,h.shape =", h.shape)
533
  for module in self.input_blocks:
534
  h = module(h, emb)
535
  hs.append(h)
536
- # print("module encoder, h.shape =", h.shape)
537
  # print("2,h.shape =", h.shape)
538
  h = self.middle_block(h, emb)
539
- # print("middle block, h.shape =", h.shape)
540
  # print("2,h.shape =", h.shape)
541
  for module in self.output_blocks:
542
- # print("for module in self.output_blocks, h.shape =", h.shape)
543
  # print("len(hs) =", len(hs), ", hs[-1].shape =", hs[-1].shape)
544
  h = torch.cat([h, hs.pop()], dim=1)
545
  h = module(h, emb)
546
  # print("module decoder, h.shape =", h.shape)
547
 
548
- # print("h = h.type(x.dtype), x.dtype =", x.dtype)
549
  h = h.type(x.dtype)
550
  h = self.out(h)
551
- # print("self.out(h)", "h.shape =", h.shape)
552
 
553
  return h
 
32
  self.swish = swish
33
 
34
  def forward(self, x):
35
+ #print(f"GroupNorm32, x.dtype = {x.dtype}, x.float().dtype = {x.float().dtype}, swish = {self.swish}")
36
+ #y = super().forward(x.float()).to(x.dtype)
37
+ y = super().forward(x)
38
+ #print(f"swish == {self.swish}, {y.dtype}")
39
  if self.swish == 1.0:
40
  y = F.silu(y)
41
  elif self.swish:
42
  y = y * F.sigmoid(y * float(self.swish))
43
+ #print(f"swish == {self.swish}, {y.dtype}")
44
  return y
45
 
46
  def normalization(channels, swish=0.0):
 
287
  :param max_period: controls the minimum frequency of the embeddings.
288
  :return: an [N x dim] Tensor of positional embeddings.
289
  """
290
+ #print(f"timestep_embedding is running")
291
  half = dim // 2
292
  freqs = torch.exp(
293
  -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
 
297
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
298
  if dim % 2:
299
  embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
300
+ #print(f"timestep_embedding is ending")
301
  return embedding
302
 
303
  class ContextUnet(nn.Module):
 
526
  def forward(self, x, timesteps, y=None):
527
  hs = []
528
  # print("device of timesteps, self.model_channels:", timesteps.device, self.model_channels)
529
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels).to(self.dtype))
530
+ #print(f"forward after emb")
531
  if y != None:
532
+ #text_outputs = self.token_embedding(y.float())
533
+ text_outputs = self.token_embedding(y.to(self.dtype))
534
  emb = emb + text_outputs.to(emb)
535
 
536
+ #print("forward, h = x.type(self.dtype), self.dtype =", self.dtype)
537
  h = x.type(self.dtype)
538
+ #print("0,h.shape =", h.shape)
539
  for module in self.input_blocks:
540
  h = module(h, emb)
541
  hs.append(h)
542
+ #print("module encoder, h.shape =", h.shape)
543
  # print("2,h.shape =", h.shape)
544
  h = self.middle_block(h, emb)
545
+ #print("middle block, h.shape =", h.shape)
546
  # print("2,h.shape =", h.shape)
547
  for module in self.output_blocks:
548
+ #print("for module in self.output_blocks, h.shape =", h.shape)
549
  # print("len(hs) =", len(hs), ", hs[-1].shape =", hs[-1].shape)
550
  h = torch.cat([h, hs.pop()], dim=1)
551
  h = module(h, emb)
552
  # print("module decoder, h.shape =", h.shape)
553
 
554
+ #print("h = h.type(x.dtype), x.dtype =", x.dtype)
555
  h = h.type(x.dtype)
556
  h = self.out(h)
557
+ #print("self.out(h)", "h.shape =", h.shape)
558
 
559
  return h
diffusion.py CHANGED
@@ -115,8 +115,9 @@ def ddp_setup(rank: int, world_size: int, master_addr, master_port):
115
 
116
  # %%
117
  class DDPMScheduler(nn.Module):
118
- def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu', dtype=torch.float32, config=None):
119
  super().__init__()
 
120
 
121
  beta_1, beta_T = betas
122
  assert 0 < beta_1 <= beta_T <= 1, "ensure 0 < beta_1 <= beta_T <= 1"
@@ -124,6 +125,7 @@ class DDPMScheduler(nn.Module):
124
  self.num_timesteps = num_timesteps
125
  self.img_shape = img_shape
126
  self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1
 
127
  self.beta_t = self.beta_t.to(self.device)
128
 
129
  # self.drop_prob = drop_prob
@@ -132,7 +134,6 @@ class DDPMScheduler(nn.Module):
132
  # self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
133
  self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
134
  # self.use_fp16 = use_fp16
135
- self.dtype = dtype#torch.float16 if self.use_fp16 else torch.float32
136
  self.config = config
137
 
138
  def add_noise(self, clean_images):
@@ -157,15 +158,18 @@ class DDPMScheduler(nn.Module):
157
  def sample(self, nn_model, params, device, guide_w = 0):
158
  n_sample = len(params) #params.shape[0]
159
  # print("params.shape[0], len(params)", params.shape[0], len(params))
160
- x_i = torch.randn(n_sample, *self.img_shape).to(device)
 
 
161
  # print("x_i.shape =", x_i.shape)
162
  # print("x_i.shape =", x_i.shape)
163
  if guide_w != -1:
164
  c_i = params
165
- uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)
166
  # uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)
167
  # uncond_tokens = uncond_tokens.repeat(int(n_sample),1)
168
- c_i = torch.cat((c_i, uncond_tokens), 0)
 
169
 
170
  x_i_entire = [] # keep track of generated steps in case want to plot something
171
  # print("self.num_timesteps =", self.num_timesteps)
@@ -177,8 +181,10 @@ class DDPMScheduler(nn.Module):
177
  # print(f'sampling timestep {i:4d}',end='\r')
178
  t_is = torch.tensor([i]).to(device)
179
  t_is = t_is.repeat(n_sample)
 
180
 
181
- z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else 0
 
182
 
183
  if guide_w == -1:
184
  # eps = nn_model(x_i, t_is, return_dict=False)[0]
@@ -186,22 +192,26 @@ class DDPMScheduler(nn.Module):
186
  # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
187
  else:
188
  # double batch
189
- x_i = x_i.repeat(2, *torch.ones(len(self.img_shape), dtype=int).tolist())
190
- t_is = t_is.repeat(2)
 
191
 
192
  # split predictions and compute weighting
193
  # print("nn_model input shape", x_i.shape, t_is.shape, c_i.shape)
 
194
  eps = nn_model(x_i, t_is, c_i)
195
- eps1 = eps[:n_sample]
196
- eps2 = eps[n_sample:]
197
- eps = eps1 + guide_w*(eps1 - eps2)
198
  # eps = (1+guide_w)*eps1 - guide_w*eps2
199
- x_i = x_i[:n_sample]
200
  # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
201
 
202
  # print("x_i.shape =", x_i.shape)
 
203
  x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
204
-
 
205
  pbar_sample.update(1)
206
 
207
  # store only part of the intermediate steps
@@ -257,12 +267,12 @@ class TrainConfig:
257
 
258
  # dim = 2
259
  dim = 3#2
260
- stride = (2,4) if dim == 2 else (2,2,4)
261
- num_image = 480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
262
- batch_size = 1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
263
  n_epoch = 50#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
264
  HII_DIM = 64
265
- num_redshift = 512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
266
  channel = 1
267
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
268
 
@@ -396,6 +406,7 @@ class DDPM21CM:
396
  # self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
397
  # print(f"resumed nn_model from {config.resume}")
398
  self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
 
399
  print(f"{config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} resumed nn_model from {config.resume} with {sum(x.numel() for x in self.nn_model.parameters())} parameters".center(120,'-'))
400
  else:
401
  print(f"{config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} initialized nn_model randomly with {sum(x.numel() for x in self.nn_model.parameters())} parameters".center(120,'-'))
@@ -511,8 +522,9 @@ class DDPM21CM:
511
  # print("x = x.to(self.config.device), x.dtype =", x.dtype)
512
  # x = x.to(self.config.dtype)
513
  # print("x = x.to(self.dtype), x.dtype =", x.dtype)
 
514
  xt, noise, ts = self.ddpm.add_noise(x)
515
-
516
  if self.config.guide_w == -1:
517
  noise_pred = self.nn_model(xt, ts)
518
  else:
 
115
 
116
  # %%
117
  class DDPMScheduler(nn.Module):
118
+ def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu', dtype=torch.float16, config=None):
119
  super().__init__()
120
+ self.dtype = dtype#torch.float16 if self.use_fp16 else torch.float32
121
 
122
  beta_1, beta_T = betas
123
  assert 0 < beta_1 <= beta_T <= 1, "ensure 0 < beta_1 <= beta_T <= 1"
 
125
  self.num_timesteps = num_timesteps
126
  self.img_shape = img_shape
127
  self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1
128
+ self.beta_t = self.beta_t.to(self.dtype)
129
  self.beta_t = self.beta_t.to(self.device)
130
 
131
  # self.drop_prob = drop_prob
 
134
  # self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
135
  self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
136
  # self.use_fp16 = use_fp16
 
137
  self.config = config
138
 
139
  def add_noise(self, clean_images):
 
158
  def sample(self, nn_model, params, device, guide_w = 0):
159
  n_sample = len(params) #params.shape[0]
160
  # print("params.shape[0], len(params)", params.shape[0], len(params))
161
+ x_i = torch.randn(n_sample, *self.img_shape).to(self.dtype)
162
+ x_i = x_i.to(device)
163
+ #print(f"#1 x_i.device = {x_i.device}")
164
  # print("x_i.shape =", x_i.shape)
165
  # print("x_i.shape =", x_i.shape)
166
  if guide_w != -1:
167
  c_i = params
168
+ #uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)
169
  # uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)
170
  # uncond_tokens = uncond_tokens.repeat(int(n_sample),1)
171
+ #c_i = torch.cat((c_i, uncond_tokens), 0)
172
+ c_i = c_i.to(self.dtype)
173
 
174
  x_i_entire = [] # keep track of generated steps in case want to plot something
175
  # print("self.num_timesteps =", self.num_timesteps)
 
181
  # print(f'sampling timestep {i:4d}',end='\r')
182
  t_is = torch.tensor([i]).to(device)
183
  t_is = t_is.repeat(n_sample)
184
+ t_is = t_is.to(self.dtype)
185
 
186
+ z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else torch.tensor(0.)
187
+ z = z.to(self.dtype)
188
 
189
  if guide_w == -1:
190
  # eps = nn_model(x_i, t_is, return_dict=False)[0]
 
192
  # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
193
  else:
194
  # double batch
195
+ #print(f"#2 x_i.device = {x_i.device}")
196
+ #x_i = x_i.repeat(2, *torch.ones(len(self.img_shape), dtype=int).tolist())
197
+ #t_is = t_is.repeat(2)
198
 
199
  # split predictions and compute weighting
200
  # print("nn_model input shape", x_i.shape, t_is.shape, c_i.shape)
201
+ #print(f"sample, i = {i}, x_i.dtype = {x_i.dtype}, c_i.dtype = {c_i.dtype}")
202
  eps = nn_model(x_i, t_is, c_i)
203
+ #eps1 = eps[:n_sample]
204
+ #eps2 = eps[n_sample:]
205
+ #eps = eps1 + guide_w*(eps1 - eps2)
206
  # eps = (1+guide_w)*eps1 - guide_w*eps2
207
+ #x_i = x_i[:n_sample]
208
  # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
209
 
210
  # print("x_i.shape =", x_i.shape)
211
+ #print(f"before, x_i.dtype = {x_i.dtype}, beta_t.dtype = {self.beta_t.dtype}, eps.dtype = {eps.dtype}, alpha_t.dtype = {self.alpha_t.dtype}, z.dtype = {z.dtype}")
212
  x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
213
+ #print(f"after, x_i.dtype = {x_i.dtype}, beta_t.dtype = {self.beta_t.dtype}, eps.dtype = {eps.dtype}, alpha_t.dtype = {self.alpha_t.dtype}, z.dtype = {z.dtype}")
214
+
215
  pbar_sample.update(1)
216
 
217
  # store only part of the intermediate steps
 
267
 
268
  # dim = 2
269
  dim = 3#2
270
+ stride = (2,4) if dim == 2 else (2,2,2)
271
+ num_image = 30#00#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
272
+ batch_size = 5#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
273
  n_epoch = 50#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
274
  HII_DIM = 64
275
+ num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
276
  channel = 1
277
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
278
 
 
406
  # self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
407
  # print(f"resumed nn_model from {config.resume}")
408
  self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
409
+ self.nn_model.module.to(config.dtype)
410
  print(f"{config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} resumed nn_model from {config.resume} with {sum(x.numel() for x in self.nn_model.parameters())} parameters".center(120,'-'))
411
  else:
412
  print(f"{config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} initialized nn_model randomly with {sum(x.numel() for x in self.nn_model.parameters())} parameters".center(120,'-'))
 
522
  # print("x = x.to(self.config.device), x.dtype =", x.dtype)
523
  # x = x.to(self.config.dtype)
524
  # print("x = x.to(self.dtype), x.dtype =", x.dtype)
525
+ #print(f"ddpm.add_noise(x), x.dtype = {x.dtype}")
526
  xt, noise, ts = self.ddpm.add_noise(x)
527
+ #print(f"ddpm.add_noise(x), xt.dtype = {xt.dtype}")
528
  if self.config.guide_w == -1:
529
  noise_pred = self.nn_model(xt, ts)
530
  else:
frontera_generate_dataset.sbatch ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #----------------------------------------------------
3
+ # Sample Slurm job script
4
+ # for TACC Frontera CLX nodes
5
+ #
6
+ # *** MPI Job in Normal Queue ***
7
+ #
8
+ # Last revised: 20 May 2019
9
+ #
10
+ # Notes:
11
+ #
12
+ # -- Launch this script by executing
13
+ # "sbatch clx.mpi.slurm" on a Frontera login node.
14
+ #
15
+ # -- Use ibrun to launch MPI codes on TACC systems.
16
+ # Do NOT use mpirun or mpiexec.
17
+ #
18
+ # -- Max recommended MPI ranks per CLX node: 56
19
+ # (start small, increase gradually).
20
+ #
21
+ # -- If you're running out of memory, try running
22
+ # fewer tasks per node to give each task more memory.
23
+ #
24
+ #----------------------------------------------------
25
+
26
+ #SBATCH -J datasets # Job name
27
+ #SBATCH -o Report-%j # Name of stdout output file
28
+ #SBATCH -p normal # Queue (partition) name
29
+ #SBATCH -N 12 # 50 # Total # of nodes
30
+ #SBATCH -t 2-00:00:00 # Run time (hh:mm:ss)
31
+ #SBATCH --mail-type=all # Send email at begin and end of job
32
+ #SBATCH --mail-user=xiabin@gatech.edu
33
+ #SBATCH --ntasks-per-node=1
34
+
35
+ # Any other commands must follow all #SBATCH directives...
36
+ ############# #SBATCH -c 56 # Total # of mpi tasks
37
+
38
+ #----------------------------------------------------
39
+ cat $0
40
+ date
41
+ pwd
42
+ module list
43
+ conda env list
44
+
45
+ srun python generate_dataset.py \
46
+ --save_direc $SCRATCH \
47
+ --num_images 25600 \
48
+ --BOX_LEN 128 \
49
+ --HII_DIM 64 \
50
+ --NON_CUBIC_FACTOR 16 \
51
+ --cpus_per_node 38 \
52
+ #----------------------------------------------------
53
+
generate_dataset.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
generate_dataset.py CHANGED
@@ -19,6 +19,7 @@ import fcntl
19
  import time
20
  from time import sleep
21
  from pathlib import Path
 
22
 
23
  # Parallize
24
  try:
@@ -120,6 +121,7 @@ class Generator():
120
  BOX_LEN = 150,
121
  HII_DIM = 60,
122
  USE_INTERPOLATION_TABLES = True,
 
123
 
124
  # cosmo_params of py21cmfast.run_coeval():
125
  SIGMA_8 = 0.810,
@@ -201,6 +203,7 @@ class Generator():
201
  user_params = kwargs_params_cpu,
202
  cosmo_params = p21c.CosmoParams(kwargs_params_cpu),
203
  astro_params = p21c.AstroParams(kwargs_params_cpu),
 
204
  random_seed = random_seed,
205
  write = kwargs_params_cpu['write'],
206
  )
@@ -210,11 +213,13 @@ class Generator():
210
  elif self.kwargs['p21c_run'] == 'lightcone':
211
  lightcone_cpu = p21c.run_lightcone(
212
  redshift = kwargs_params_cpu['redshift'][0],
213
- max_redshift = kwargs_params_cpu['redshift'][-1],
 
214
  lightcone_quantities = kwargs_params_cpu['fields'],
215
  user_params = kwargs_params_cpu,
216
  cosmo_params = p21c.CosmoParams(kwargs_params_cpu),
217
  astro_params = p21c.AstroParams(kwargs_params_cpu),
 
218
  random_seed = random_seed,
219
  write = kwargs_params_cpu['write'],
220
  )
@@ -362,14 +367,16 @@ class Generator():
362
  # break
363
  except IOError or BlockingIOError:
364
  if try_time > 30:
365
- print(f"{rank}-{multiprocessing.current_process().pid}, try_time = {try_time:.2f} sec")
366
- sleep(10)
367
  else:
368
  sleep(0.1)
369
 
370
  # Save as hdf5
371
  def save(self, images_node, params_seeds):
372
- max_num_images = None # self.kwargs['num_images']
 
 
373
  with h5py.File(self.kwargs['save_direc_name'], 'a') as f:
374
  if 'kwargs' not in f.keys():
375
  keys = list(self.kwargs)
@@ -436,23 +443,25 @@ if __name__ == '__main__':
436
  args = parser.parse_args()
437
 
438
  params_ranges = dict(
439
- ION_Tvir_MIN = 4.4, #[4,6],
440
- HII_EFF_FACTOR = 131.341, #[10, 250],
441
  )
442
 
443
  kwargs = dict(
444
  num_images=args.num_images,#2400,#30000,
445
  fields = ['brightness_temp', 'density', 'xH_box'],
446
- BOX_LEN=args.BOX_LEN,#128,
447
- HII_DIM=args.HII_DIM,
448
- verbose=3, redshift=[7.51, 11.93],
 
449
  NON_CUBIC_FACTOR = args.NON_CUBIC_FACTOR,
450
  write = True,
451
  cpus_per_node = args.cpus_per_node,#10,#112,#20,
452
  cache_rmdir = False,
453
  )
454
 
455
- save_name = f"LEN{kwargs['BOX_LEN']}-DIM{kwargs['HII_DIM']}-CUB{kwargs['NON_CUBIC_FACTOR']}-{params_ranges['ION_Tvir_MIN']}-{params_ranges['HII_EFF_FACTOR']}.h5"
 
456
  kwargs['save_direc_name'] = os.path.join(args.save_direc, save_name)
457
 
458
  generator = Generator(params_ranges, **kwargs)
 
19
  import time
20
  from time import sleep
21
  from pathlib import Path
22
+ import datetime
23
 
24
  # Parallize
25
  try:
 
121
  BOX_LEN = 150,
122
  HII_DIM = 60,
123
  USE_INTERPOLATION_TABLES = True,
124
+ USE_TS_FLUCT = True,
125
 
126
  # cosmo_params of py21cmfast.run_coeval():
127
  SIGMA_8 = 0.810,
 
203
  user_params = kwargs_params_cpu,
204
  cosmo_params = p21c.CosmoParams(kwargs_params_cpu),
205
  astro_params = p21c.AstroParams(kwargs_params_cpu),
206
+ flag_options = p21c.FlagOptions(kwargs_params_cpu),
207
  random_seed = random_seed,
208
  write = kwargs_params_cpu['write'],
209
  )
 
213
  elif self.kwargs['p21c_run'] == 'lightcone':
214
  lightcone_cpu = p21c.run_lightcone(
215
  redshift = kwargs_params_cpu['redshift'][0],
216
+ #max_redshift = kwargs_params_cpu['redshift'][-1],
217
+ z_heat_max = kwargs_params_cpu['redshift'][-1],
218
  lightcone_quantities = kwargs_params_cpu['fields'],
219
  user_params = kwargs_params_cpu,
220
  cosmo_params = p21c.CosmoParams(kwargs_params_cpu),
221
  astro_params = p21c.AstroParams(kwargs_params_cpu),
222
+ flag_options = p21c.FlagOptions(kwargs_params_cpu),
223
  random_seed = random_seed,
224
  write = kwargs_params_cpu['write'],
225
  )
 
367
  # break
368
  except IOError or BlockingIOError:
369
  if try_time > 30:
370
+ print(f"cpu {multiprocessing.current_process().pid}-{rank}, try_time = {try_time:.2f} sec")
371
+ sleep(5)
372
  else:
373
  sleep(0.1)
374
 
375
  # Save as hdf5
376
  def save(self, images_node, params_seeds):
377
+ #max_num_images = None # self.kwargs['num_images']
378
+ max_num_images = self.kwargs['num_images']
379
+ #print(f"max_num_images = {max_num_images}")
380
  with h5py.File(self.kwargs['save_direc_name'], 'a') as f:
381
  if 'kwargs' not in f.keys():
382
  keys = list(self.kwargs)
 
443
  args = parser.parse_args()
444
 
445
  params_ranges = dict(
446
+ ION_Tvir_MIN = [4,6],#4.8,#5.477,#4.699,#5.6,#4.4, #[4,6],
447
+ HII_EFF_FACTOR = [10,250],#131.341,#200,#30,#19.037,#131.341, #[10, 250],
448
  )
449
 
450
  kwargs = dict(
451
  num_images=args.num_images,#2400,#30000,
452
  fields = ['brightness_temp', 'density', 'xH_box'],
453
+ BOX_LEN = args.BOX_LEN,#128,
454
+ HII_DIM = args.HII_DIM,
455
+ verbose = 3,
456
+ redshift = [7.51, 21.02],#11.93],
457
  NON_CUBIC_FACTOR = args.NON_CUBIC_FACTOR,
458
  write = True,
459
  cpus_per_node = args.cpus_per_node,#10,#112,#20,
460
  cache_rmdir = False,
461
  )
462
 
463
+ now = datetime.datetime.now().strftime("%m%d-%H%M%S")
464
+ save_name = f"LEN{kwargs['BOX_LEN']}-DIM{kwargs['HII_DIM']}-CUB{kwargs['NON_CUBIC_FACTOR']}-Tvir{params_ranges['ION_Tvir_MIN']}-zeta{params_ranges['HII_EFF_FACTOR']}-{now}.h5"
465
  kwargs['save_direc_name'] = os.path.join(args.save_direc, save_name)
466
 
467
  generator = Generator(params_ranges, **kwargs)
phoenix_diffusion.sbatch CHANGED
@@ -2,10 +2,10 @@
2
  #SBATCH -J diffusion # Job name
3
  #SBATCH -A gts-jw254-coda20
4
  #SBATCH -qembers
5
- #SBATCH -N4 --gpus-per-node=V100:2 -C V100-32GB # Number of nodes and cores per node required
6
  #SBATCH --ntasks-per-node=1
7
  #SBATCH --mem-per-gpu=16G # Memory per core
8
- #SBATCH -t 08:00:00 # Duration of the job (Ex: 15 mins)
9
  #SBATCH -oReport-%j # Combined output and error messages file
10
  #SBATCH --error=error-%j
11
  #SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
@@ -30,9 +30,9 @@ export MASTER_PORT=$MASTER_PORT
30
 
31
  srun python diffusion.py \
32
  --train 1 \
33
- --resume outputs/model_state-N3000-device_count1-node8-epoch49-172.27.149.181 \
34
  --num_new_img_per_gpu 50 \
35
- --max_num_img_per_gpu 10 \
36
 
37
  ######################################################################################
38
 
 
2
  #SBATCH -J diffusion # Job name
3
  #SBATCH -A gts-jw254-coda20
4
  #SBATCH -qembers
5
+ #SBATCH -N1 --gpus-per-node=V100:1 -C V100-32GB # Number of nodes and cores per node required
6
  #SBATCH --ntasks-per-node=1
7
  #SBATCH --mem-per-gpu=16G # Memory per core
8
+ #SBATCH -t 00:10:00 # Duration of the job (Ex: 15 mins)
9
  #SBATCH -oReport-%j # Combined output and error messages file
10
  #SBATCH --error=error-%j
11
  #SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
 
30
 
31
  srun python diffusion.py \
32
  --train 1 \
33
+ --resume outputs/model_state-N480-device_count1-node4-epoch49-172.27.149.66 \
34
  --num_new_img_per_gpu 50 \
35
+ --max_num_img_per_gpu 2 \
36
 
37
  ######################################################################################
38
 
quantify_results.ipynb CHANGED
The diff for this file is too large to render. See raw diff