0810-170449
Browse files- context_unet.py +18 -12
- diffusion.py +30 -18
- frontera_generate_dataset.sbatch +53 -0
- generate_dataset.ipynb +0 -0
- generate_dataset.py +19 -10
- phoenix_diffusion.sbatch +4 -4
- quantify_results.ipynb +0 -0
context_unet.py
CHANGED
|
@@ -32,12 +32,15 @@ class GroupNorm32(nn.GroupNorm):
|
|
| 32 |
self.swish = swish
|
| 33 |
|
| 34 |
def forward(self, x):
|
| 35 |
-
#
|
| 36 |
-
y = super().forward(x.float()).to(x.dtype)
|
|
|
|
|
|
|
| 37 |
if self.swish == 1.0:
|
| 38 |
y = F.silu(y)
|
| 39 |
elif self.swish:
|
| 40 |
y = y * F.sigmoid(y * float(self.swish))
|
|
|
|
| 41 |
return y
|
| 42 |
|
| 43 |
def normalization(channels, swish=0.0):
|
|
@@ -284,7 +287,7 @@ def timestep_embedding(timesteps, dim, max_period=10000):
|
|
| 284 |
:param max_period: controls the minimum frequency of the embeddings.
|
| 285 |
:return: an [N x dim] Tensor of positional embeddings.
|
| 286 |
"""
|
| 287 |
-
#print
|
| 288 |
half = dim // 2
|
| 289 |
freqs = torch.exp(
|
| 290 |
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
|
@@ -294,6 +297,7 @@ def timestep_embedding(timesteps, dim, max_period=10000):
|
|
| 294 |
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 295 |
if dim % 2:
|
| 296 |
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
|
|
|
| 297 |
return embedding
|
| 298 |
|
| 299 |
class ContextUnet(nn.Module):
|
|
@@ -522,32 +526,34 @@ class ContextUnet(nn.Module):
|
|
| 522 |
def forward(self, x, timesteps, y=None):
|
| 523 |
hs = []
|
| 524 |
# print("device of timesteps, self.model_channels:", timesteps.device, self.model_channels)
|
| 525 |
-
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
|
|
|
| 526 |
if y != None:
|
| 527 |
-
text_outputs = self.token_embedding(y.float())
|
|
|
|
| 528 |
emb = emb + text_outputs.to(emb)
|
| 529 |
|
| 530 |
-
#
|
| 531 |
h = x.type(self.dtype)
|
| 532 |
-
#
|
| 533 |
for module in self.input_blocks:
|
| 534 |
h = module(h, emb)
|
| 535 |
hs.append(h)
|
| 536 |
-
#
|
| 537 |
# print("2,h.shape =", h.shape)
|
| 538 |
h = self.middle_block(h, emb)
|
| 539 |
-
#
|
| 540 |
# print("2,h.shape =", h.shape)
|
| 541 |
for module in self.output_blocks:
|
| 542 |
-
#
|
| 543 |
# print("len(hs) =", len(hs), ", hs[-1].shape =", hs[-1].shape)
|
| 544 |
h = torch.cat([h, hs.pop()], dim=1)
|
| 545 |
h = module(h, emb)
|
| 546 |
# print("module decoder, h.shape =", h.shape)
|
| 547 |
|
| 548 |
-
#
|
| 549 |
h = h.type(x.dtype)
|
| 550 |
h = self.out(h)
|
| 551 |
-
#
|
| 552 |
|
| 553 |
return h
|
|
|
|
| 32 |
self.swish = swish
|
| 33 |
|
| 34 |
def forward(self, x):
|
| 35 |
+
#print(f"GroupNorm32, x.dtype = {x.dtype}, x.float().dtype = {x.float().dtype}, swish = {self.swish}")
|
| 36 |
+
#y = super().forward(x.float()).to(x.dtype)
|
| 37 |
+
y = super().forward(x)
|
| 38 |
+
#print(f"swish == {self.swish}, {y.dtype}")
|
| 39 |
if self.swish == 1.0:
|
| 40 |
y = F.silu(y)
|
| 41 |
elif self.swish:
|
| 42 |
y = y * F.sigmoid(y * float(self.swish))
|
| 43 |
+
#print(f"swish == {self.swish}, {y.dtype}")
|
| 44 |
return y
|
| 45 |
|
| 46 |
def normalization(channels, swish=0.0):
|
|
|
|
| 287 |
:param max_period: controls the minimum frequency of the embeddings.
|
| 288 |
:return: an [N x dim] Tensor of positional embeddings.
|
| 289 |
"""
|
| 290 |
+
#print(f"timestep_embedding is running")
|
| 291 |
half = dim // 2
|
| 292 |
freqs = torch.exp(
|
| 293 |
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
|
|
|
| 297 |
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 298 |
if dim % 2:
|
| 299 |
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 300 |
+
#print(f"timestep_embedding is ending")
|
| 301 |
return embedding
|
| 302 |
|
| 303 |
class ContextUnet(nn.Module):
|
|
|
|
| 526 |
def forward(self, x, timesteps, y=None):
|
| 527 |
hs = []
|
| 528 |
# print("device of timesteps, self.model_channels:", timesteps.device, self.model_channels)
|
| 529 |
+
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels).to(self.dtype))
|
| 530 |
+
#print(f"forward after emb")
|
| 531 |
if y != None:
|
| 532 |
+
#text_outputs = self.token_embedding(y.float())
|
| 533 |
+
text_outputs = self.token_embedding(y.to(self.dtype))
|
| 534 |
emb = emb + text_outputs.to(emb)
|
| 535 |
|
| 536 |
+
#print("forward, h = x.type(self.dtype), self.dtype =", self.dtype)
|
| 537 |
h = x.type(self.dtype)
|
| 538 |
+
#print("0,h.shape =", h.shape)
|
| 539 |
for module in self.input_blocks:
|
| 540 |
h = module(h, emb)
|
| 541 |
hs.append(h)
|
| 542 |
+
#print("module encoder, h.shape =", h.shape)
|
| 543 |
# print("2,h.shape =", h.shape)
|
| 544 |
h = self.middle_block(h, emb)
|
| 545 |
+
#print("middle block, h.shape =", h.shape)
|
| 546 |
# print("2,h.shape =", h.shape)
|
| 547 |
for module in self.output_blocks:
|
| 548 |
+
#print("for module in self.output_blocks, h.shape =", h.shape)
|
| 549 |
# print("len(hs) =", len(hs), ", hs[-1].shape =", hs[-1].shape)
|
| 550 |
h = torch.cat([h, hs.pop()], dim=1)
|
| 551 |
h = module(h, emb)
|
| 552 |
# print("module decoder, h.shape =", h.shape)
|
| 553 |
|
| 554 |
+
#print("h = h.type(x.dtype), x.dtype =", x.dtype)
|
| 555 |
h = h.type(x.dtype)
|
| 556 |
h = self.out(h)
|
| 557 |
+
#print("self.out(h)", "h.shape =", h.shape)
|
| 558 |
|
| 559 |
return h
|
diffusion.py
CHANGED
|
@@ -115,8 +115,9 @@ def ddp_setup(rank: int, world_size: int, master_addr, master_port):
|
|
| 115 |
|
| 116 |
# %%
|
| 117 |
class DDPMScheduler(nn.Module):
|
| 118 |
-
def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu', dtype=torch.
|
| 119 |
super().__init__()
|
|
|
|
| 120 |
|
| 121 |
beta_1, beta_T = betas
|
| 122 |
assert 0 < beta_1 <= beta_T <= 1, "ensure 0 < beta_1 <= beta_T <= 1"
|
|
@@ -124,6 +125,7 @@ class DDPMScheduler(nn.Module):
|
|
| 124 |
self.num_timesteps = num_timesteps
|
| 125 |
self.img_shape = img_shape
|
| 126 |
self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1
|
|
|
|
| 127 |
self.beta_t = self.beta_t.to(self.device)
|
| 128 |
|
| 129 |
# self.drop_prob = drop_prob
|
|
@@ -132,7 +134,6 @@ class DDPMScheduler(nn.Module):
|
|
| 132 |
# self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
|
| 133 |
self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
|
| 134 |
# self.use_fp16 = use_fp16
|
| 135 |
-
self.dtype = dtype#torch.float16 if self.use_fp16 else torch.float32
|
| 136 |
self.config = config
|
| 137 |
|
| 138 |
def add_noise(self, clean_images):
|
|
@@ -157,15 +158,18 @@ class DDPMScheduler(nn.Module):
|
|
| 157 |
def sample(self, nn_model, params, device, guide_w = 0):
|
| 158 |
n_sample = len(params) #params.shape[0]
|
| 159 |
# print("params.shape[0], len(params)", params.shape[0], len(params))
|
| 160 |
-
x_i = torch.randn(n_sample, *self.img_shape).to(
|
|
|
|
|
|
|
| 161 |
# print("x_i.shape =", x_i.shape)
|
| 162 |
# print("x_i.shape =", x_i.shape)
|
| 163 |
if guide_w != -1:
|
| 164 |
c_i = params
|
| 165 |
-
uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)
|
| 166 |
# uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)
|
| 167 |
# uncond_tokens = uncond_tokens.repeat(int(n_sample),1)
|
| 168 |
-
c_i = torch.cat((c_i, uncond_tokens), 0)
|
|
|
|
| 169 |
|
| 170 |
x_i_entire = [] # keep track of generated steps in case want to plot something
|
| 171 |
# print("self.num_timesteps =", self.num_timesteps)
|
|
@@ -177,8 +181,10 @@ class DDPMScheduler(nn.Module):
|
|
| 177 |
# print(f'sampling timestep {i:4d}',end='\r')
|
| 178 |
t_is = torch.tensor([i]).to(device)
|
| 179 |
t_is = t_is.repeat(n_sample)
|
|
|
|
| 180 |
|
| 181 |
-
z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else 0
|
|
|
|
| 182 |
|
| 183 |
if guide_w == -1:
|
| 184 |
# eps = nn_model(x_i, t_is, return_dict=False)[0]
|
|
@@ -186,22 +192,26 @@ class DDPMScheduler(nn.Module):
|
|
| 186 |
# x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
|
| 187 |
else:
|
| 188 |
# double batch
|
| 189 |
-
|
| 190 |
-
|
|
|
|
| 191 |
|
| 192 |
# split predictions and compute weighting
|
| 193 |
# print("nn_model input shape", x_i.shape, t_is.shape, c_i.shape)
|
|
|
|
| 194 |
eps = nn_model(x_i, t_is, c_i)
|
| 195 |
-
eps1 = eps[:n_sample]
|
| 196 |
-
eps2 = eps[n_sample:]
|
| 197 |
-
eps = eps1 + guide_w*(eps1 - eps2)
|
| 198 |
# eps = (1+guide_w)*eps1 - guide_w*eps2
|
| 199 |
-
x_i = x_i[:n_sample]
|
| 200 |
# x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
|
| 201 |
|
| 202 |
# print("x_i.shape =", x_i.shape)
|
|
|
|
| 203 |
x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
|
| 204 |
-
|
|
|
|
| 205 |
pbar_sample.update(1)
|
| 206 |
|
| 207 |
# store only part of the intermediate steps
|
|
@@ -257,12 +267,12 @@ class TrainConfig:
|
|
| 257 |
|
| 258 |
# dim = 2
|
| 259 |
dim = 3#2
|
| 260 |
-
stride = (2,4) if dim == 2 else (2,2,
|
| 261 |
-
num_image = 480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 262 |
-
batch_size = 1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
|
| 263 |
n_epoch = 50#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
|
| 264 |
HII_DIM = 64
|
| 265 |
-
num_redshift = 512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
|
| 266 |
channel = 1
|
| 267 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 268 |
|
|
@@ -396,6 +406,7 @@ class DDPM21CM:
|
|
| 396 |
# self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
|
| 397 |
# print(f"resumed nn_model from {config.resume}")
|
| 398 |
self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
|
|
|
|
| 399 |
print(f"{config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} resumed nn_model from {config.resume} with {sum(x.numel() for x in self.nn_model.parameters())} parameters".center(120,'-'))
|
| 400 |
else:
|
| 401 |
print(f"{config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} initialized nn_model randomly with {sum(x.numel() for x in self.nn_model.parameters())} parameters".center(120,'-'))
|
|
@@ -511,8 +522,9 @@ class DDPM21CM:
|
|
| 511 |
# print("x = x.to(self.config.device), x.dtype =", x.dtype)
|
| 512 |
# x = x.to(self.config.dtype)
|
| 513 |
# print("x = x.to(self.dtype), x.dtype =", x.dtype)
|
|
|
|
| 514 |
xt, noise, ts = self.ddpm.add_noise(x)
|
| 515 |
-
|
| 516 |
if self.config.guide_w == -1:
|
| 517 |
noise_pred = self.nn_model(xt, ts)
|
| 518 |
else:
|
|
|
|
| 115 |
|
| 116 |
# %%
|
| 117 |
class DDPMScheduler(nn.Module):
|
| 118 |
+
def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu', dtype=torch.float16, config=None):
|
| 119 |
super().__init__()
|
| 120 |
+
self.dtype = dtype#torch.float16 if self.use_fp16 else torch.float32
|
| 121 |
|
| 122 |
beta_1, beta_T = betas
|
| 123 |
assert 0 < beta_1 <= beta_T <= 1, "ensure 0 < beta_1 <= beta_T <= 1"
|
|
|
|
| 125 |
self.num_timesteps = num_timesteps
|
| 126 |
self.img_shape = img_shape
|
| 127 |
self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1
|
| 128 |
+
self.beta_t = self.beta_t.to(self.dtype)
|
| 129 |
self.beta_t = self.beta_t.to(self.device)
|
| 130 |
|
| 131 |
# self.drop_prob = drop_prob
|
|
|
|
| 134 |
# self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
|
| 135 |
self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
|
| 136 |
# self.use_fp16 = use_fp16
|
|
|
|
| 137 |
self.config = config
|
| 138 |
|
| 139 |
def add_noise(self, clean_images):
|
|
|
|
| 158 |
def sample(self, nn_model, params, device, guide_w = 0):
|
| 159 |
n_sample = len(params) #params.shape[0]
|
| 160 |
# print("params.shape[0], len(params)", params.shape[0], len(params))
|
| 161 |
+
x_i = torch.randn(n_sample, *self.img_shape).to(self.dtype)
|
| 162 |
+
x_i = x_i.to(device)
|
| 163 |
+
#print(f"#1 x_i.device = {x_i.device}")
|
| 164 |
# print("x_i.shape =", x_i.shape)
|
| 165 |
# print("x_i.shape =", x_i.shape)
|
| 166 |
if guide_w != -1:
|
| 167 |
c_i = params
|
| 168 |
+
#uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)
|
| 169 |
# uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)
|
| 170 |
# uncond_tokens = uncond_tokens.repeat(int(n_sample),1)
|
| 171 |
+
#c_i = torch.cat((c_i, uncond_tokens), 0)
|
| 172 |
+
c_i = c_i.to(self.dtype)
|
| 173 |
|
| 174 |
x_i_entire = [] # keep track of generated steps in case want to plot something
|
| 175 |
# print("self.num_timesteps =", self.num_timesteps)
|
|
|
|
| 181 |
# print(f'sampling timestep {i:4d}',end='\r')
|
| 182 |
t_is = torch.tensor([i]).to(device)
|
| 183 |
t_is = t_is.repeat(n_sample)
|
| 184 |
+
t_is = t_is.to(self.dtype)
|
| 185 |
|
| 186 |
+
z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else torch.tensor(0.)
|
| 187 |
+
z = z.to(self.dtype)
|
| 188 |
|
| 189 |
if guide_w == -1:
|
| 190 |
# eps = nn_model(x_i, t_is, return_dict=False)[0]
|
|
|
|
| 192 |
# x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
|
| 193 |
else:
|
| 194 |
# double batch
|
| 195 |
+
#print(f"#2 x_i.device = {x_i.device}")
|
| 196 |
+
#x_i = x_i.repeat(2, *torch.ones(len(self.img_shape), dtype=int).tolist())
|
| 197 |
+
#t_is = t_is.repeat(2)
|
| 198 |
|
| 199 |
# split predictions and compute weighting
|
| 200 |
# print("nn_model input shape", x_i.shape, t_is.shape, c_i.shape)
|
| 201 |
+
#print(f"sample, i = {i}, x_i.dtype = {x_i.dtype}, c_i.dtype = {c_i.dtype}")
|
| 202 |
eps = nn_model(x_i, t_is, c_i)
|
| 203 |
+
#eps1 = eps[:n_sample]
|
| 204 |
+
#eps2 = eps[n_sample:]
|
| 205 |
+
#eps = eps1 + guide_w*(eps1 - eps2)
|
| 206 |
# eps = (1+guide_w)*eps1 - guide_w*eps2
|
| 207 |
+
#x_i = x_i[:n_sample]
|
| 208 |
# x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
|
| 209 |
|
| 210 |
# print("x_i.shape =", x_i.shape)
|
| 211 |
+
#print(f"before, x_i.dtype = {x_i.dtype}, beta_t.dtype = {self.beta_t.dtype}, eps.dtype = {eps.dtype}, alpha_t.dtype = {self.alpha_t.dtype}, z.dtype = {z.dtype}")
|
| 212 |
x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
|
| 213 |
+
#print(f"after, x_i.dtype = {x_i.dtype}, beta_t.dtype = {self.beta_t.dtype}, eps.dtype = {eps.dtype}, alpha_t.dtype = {self.alpha_t.dtype}, z.dtype = {z.dtype}")
|
| 214 |
+
|
| 215 |
pbar_sample.update(1)
|
| 216 |
|
| 217 |
# store only part of the intermediate steps
|
|
|
|
| 267 |
|
| 268 |
# dim = 2
|
| 269 |
dim = 3#2
|
| 270 |
+
stride = (2,4) if dim == 2 else (2,2,2)
|
| 271 |
+
num_image = 30#00#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 272 |
+
batch_size = 5#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
|
| 273 |
n_epoch = 50#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
|
| 274 |
HII_DIM = 64
|
| 275 |
+
num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
|
| 276 |
channel = 1
|
| 277 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 278 |
|
|
|
|
| 406 |
# self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
|
| 407 |
# print(f"resumed nn_model from {config.resume}")
|
| 408 |
self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
|
| 409 |
+
self.nn_model.module.to(config.dtype)
|
| 410 |
print(f"{config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} resumed nn_model from {config.resume} with {sum(x.numel() for x in self.nn_model.parameters())} parameters".center(120,'-'))
|
| 411 |
else:
|
| 412 |
print(f"{config.run_name} {socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} initialized nn_model randomly with {sum(x.numel() for x in self.nn_model.parameters())} parameters".center(120,'-'))
|
|
|
|
| 522 |
# print("x = x.to(self.config.device), x.dtype =", x.dtype)
|
| 523 |
# x = x.to(self.config.dtype)
|
| 524 |
# print("x = x.to(self.dtype), x.dtype =", x.dtype)
|
| 525 |
+
#print(f"ddpm.add_noise(x), x.dtype = {x.dtype}")
|
| 526 |
xt, noise, ts = self.ddpm.add_noise(x)
|
| 527 |
+
#print(f"ddpm.add_noise(x), xt.dtype = {xt.dtype}")
|
| 528 |
if self.config.guide_w == -1:
|
| 529 |
noise_pred = self.nn_model(xt, ts)
|
| 530 |
else:
|
frontera_generate_dataset.sbatch
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#----------------------------------------------------
|
| 3 |
+
# Sample Slurm job script
|
| 4 |
+
# for TACC Frontera CLX nodes
|
| 5 |
+
#
|
| 6 |
+
# *** MPI Job in Normal Queue ***
|
| 7 |
+
#
|
| 8 |
+
# Last revised: 20 May 2019
|
| 9 |
+
#
|
| 10 |
+
# Notes:
|
| 11 |
+
#
|
| 12 |
+
# -- Launch this script by executing
|
| 13 |
+
# "sbatch frontera_generate_dataset.sbatch" on a Frontera login node.
|
| 14 |
+
#
|
| 15 |
+
# -- Use ibrun to launch MPI codes on TACC systems.
|
| 16 |
+
# Do NOT use mpirun or mpiexec.
|
| 17 |
+
#
|
| 18 |
+
# -- Max recommended MPI ranks per CLX node: 56
|
| 19 |
+
# (start small, increase gradually).
|
| 20 |
+
#
|
| 21 |
+
# -- If you're running out of memory, try running
|
| 22 |
+
# fewer tasks per node to give each task more memory.
|
| 23 |
+
#
|
| 24 |
+
#----------------------------------------------------
|
| 25 |
+
|
| 26 |
+
#SBATCH -J datasets # Job name
|
| 27 |
+
#SBATCH -o Report-%j # Name of stdout output file
|
| 28 |
+
#SBATCH -p normal # Queue (partition) name
|
| 29 |
+
#SBATCH -N 12 # 50 # Total # of nodes
|
| 30 |
+
#SBATCH -t 2-00:00:00 # Run time (d-hh:mm:ss)
|
| 31 |
+
#SBATCH --mail-type=all # Send email on all job events (begin, end, fail, ...)
|
| 32 |
+
#SBATCH --mail-user=xiabin@gatech.edu
|
| 33 |
+
#SBATCH --ntasks-per-node=1
|
| 34 |
+
|
| 35 |
+
# Any other commands must follow all #SBATCH directives...
|
| 36 |
+
############# #SBATCH -c 56 # CPUs per task (disabled)
|
| 37 |
+
|
| 38 |
+
#----------------------------------------------------
|
| 39 |
+
cat $0
|
| 40 |
+
date
|
| 41 |
+
pwd
|
| 42 |
+
module list
|
| 43 |
+
conda env list
|
| 44 |
+
|
| 45 |
+
srun python generate_dataset.py \
|
| 46 |
+
--save_direc $SCRATCH \
|
| 47 |
+
--num_images 25600 \
|
| 48 |
+
--BOX_LEN 128 \
|
| 49 |
+
--HII_DIM 64 \
|
| 50 |
+
--NON_CUBIC_FACTOR 16 \
|
| 51 |
+
--cpus_per_node 38 \
|
| 52 |
+
#----------------------------------------------------
|
| 53 |
+
|
generate_dataset.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generate_dataset.py
CHANGED
|
@@ -19,6 +19,7 @@ import fcntl
|
|
| 19 |
import time
|
| 20 |
from time import sleep
|
| 21 |
from pathlib import Path
|
|
|
|
| 22 |
|
| 23 |
# Parallelize
|
| 24 |
try:
|
|
@@ -120,6 +121,7 @@ class Generator():
|
|
| 120 |
BOX_LEN = 150,
|
| 121 |
HII_DIM = 60,
|
| 122 |
USE_INTERPOLATION_TABLES = True,
|
|
|
|
| 123 |
|
| 124 |
# cosmo_params of py21cmfast.run_coeval():
|
| 125 |
SIGMA_8 = 0.810,
|
|
@@ -201,6 +203,7 @@ class Generator():
|
|
| 201 |
user_params = kwargs_params_cpu,
|
| 202 |
cosmo_params = p21c.CosmoParams(kwargs_params_cpu),
|
| 203 |
astro_params = p21c.AstroParams(kwargs_params_cpu),
|
|
|
|
| 204 |
random_seed = random_seed,
|
| 205 |
write = kwargs_params_cpu['write'],
|
| 206 |
)
|
|
@@ -210,11 +213,13 @@ class Generator():
|
|
| 210 |
elif self.kwargs['p21c_run'] == 'lightcone':
|
| 211 |
lightcone_cpu = p21c.run_lightcone(
|
| 212 |
redshift = kwargs_params_cpu['redshift'][0],
|
| 213 |
-
max_redshift = kwargs_params_cpu['redshift'][-1],
|
|
|
|
| 214 |
lightcone_quantities = kwargs_params_cpu['fields'],
|
| 215 |
user_params = kwargs_params_cpu,
|
| 216 |
cosmo_params = p21c.CosmoParams(kwargs_params_cpu),
|
| 217 |
astro_params = p21c.AstroParams(kwargs_params_cpu),
|
|
|
|
| 218 |
random_seed = random_seed,
|
| 219 |
write = kwargs_params_cpu['write'],
|
| 220 |
)
|
|
@@ -362,14 +367,16 @@ class Generator():
|
|
| 362 |
# break
|
| 363 |
except IOError or BlockingIOError:
|
| 364 |
if try_time > 30:
|
| 365 |
-
print(f"{
|
| 366 |
-
sleep(
|
| 367 |
else:
|
| 368 |
sleep(0.1)
|
| 369 |
|
| 370 |
# Save as hdf5
|
| 371 |
def save(self, images_node, params_seeds):
|
| 372 |
-
max_num_images = None # self.kwargs['num_images']
|
|
|
|
|
|
|
| 373 |
with h5py.File(self.kwargs['save_direc_name'], 'a') as f:
|
| 374 |
if 'kwargs' not in f.keys():
|
| 375 |
keys = list(self.kwargs)
|
|
@@ -436,23 +443,25 @@ if __name__ == '__main__':
|
|
| 436 |
args = parser.parse_args()
|
| 437 |
|
| 438 |
params_ranges = dict(
|
| 439 |
-
ION_Tvir_MIN = 4.4, #[4,6],
|
| 440 |
-
HII_EFF_FACTOR = 131.341, #[10, 250],
|
| 441 |
)
|
| 442 |
|
| 443 |
kwargs = dict(
|
| 444 |
num_images=args.num_images,#2400,#30000,
|
| 445 |
fields = ['brightness_temp', 'density', 'xH_box'],
|
| 446 |
-
BOX_LEN=args.BOX_LEN,#128,
|
| 447 |
-
HII_DIM=args.HII_DIM,
|
| 448 |
-
verbose=3,
|
|
|
|
| 449 |
NON_CUBIC_FACTOR = args.NON_CUBIC_FACTOR,
|
| 450 |
write = True,
|
| 451 |
cpus_per_node = args.cpus_per_node,#10,#112,#20,
|
| 452 |
cache_rmdir = False,
|
| 453 |
)
|
| 454 |
|
| 455 |
-
|
|
|
|
| 456 |
kwargs['save_direc_name'] = os.path.join(args.save_direc, save_name)
|
| 457 |
|
| 458 |
generator = Generator(params_ranges, **kwargs)
|
|
|
|
| 19 |
import time
|
| 20 |
from time import sleep
|
| 21 |
from pathlib import Path
|
| 22 |
+
import datetime
|
| 23 |
|
| 24 |
# Parallelize
|
| 25 |
try:
|
|
|
|
| 121 |
BOX_LEN = 150,
|
| 122 |
HII_DIM = 60,
|
| 123 |
USE_INTERPOLATION_TABLES = True,
|
| 124 |
+
USE_TS_FLUCT = True,
|
| 125 |
|
| 126 |
# cosmo_params of py21cmfast.run_coeval():
|
| 127 |
SIGMA_8 = 0.810,
|
|
|
|
| 203 |
user_params = kwargs_params_cpu,
|
| 204 |
cosmo_params = p21c.CosmoParams(kwargs_params_cpu),
|
| 205 |
astro_params = p21c.AstroParams(kwargs_params_cpu),
|
| 206 |
+
flag_options = p21c.FlagOptions(kwargs_params_cpu),
|
| 207 |
random_seed = random_seed,
|
| 208 |
write = kwargs_params_cpu['write'],
|
| 209 |
)
|
|
|
|
| 213 |
elif self.kwargs['p21c_run'] == 'lightcone':
|
| 214 |
lightcone_cpu = p21c.run_lightcone(
|
| 215 |
redshift = kwargs_params_cpu['redshift'][0],
|
| 216 |
+
#max_redshift = kwargs_params_cpu['redshift'][-1],
|
| 217 |
+
z_heat_max = kwargs_params_cpu['redshift'][-1],
|
| 218 |
lightcone_quantities = kwargs_params_cpu['fields'],
|
| 219 |
user_params = kwargs_params_cpu,
|
| 220 |
cosmo_params = p21c.CosmoParams(kwargs_params_cpu),
|
| 221 |
astro_params = p21c.AstroParams(kwargs_params_cpu),
|
| 222 |
+
flag_options = p21c.FlagOptions(kwargs_params_cpu),
|
| 223 |
random_seed = random_seed,
|
| 224 |
write = kwargs_params_cpu['write'],
|
| 225 |
)
|
|
|
|
| 367 |
# break
|
| 368 |
except IOError or BlockingIOError:
|
| 369 |
if try_time > 30:
|
| 370 |
+
print(f"cpu {multiprocessing.current_process().pid}-{rank}, try_time = {try_time:.2f} sec")
|
| 371 |
+
sleep(5)
|
| 372 |
else:
|
| 373 |
sleep(0.1)
|
| 374 |
|
| 375 |
# Save as hdf5
|
| 376 |
def save(self, images_node, params_seeds):
|
| 377 |
+
#max_num_images = None # self.kwargs['num_images']
|
| 378 |
+
max_num_images = self.kwargs['num_images']
|
| 379 |
+
#print(f"max_num_images = {max_num_images}")
|
| 380 |
with h5py.File(self.kwargs['save_direc_name'], 'a') as f:
|
| 381 |
if 'kwargs' not in f.keys():
|
| 382 |
keys = list(self.kwargs)
|
|
|
|
| 443 |
args = parser.parse_args()
|
| 444 |
|
| 445 |
params_ranges = dict(
|
| 446 |
+
ION_Tvir_MIN = [4,6],#4.8,#5.477,#4.699,#5.6,#4.4, #[4,6],
|
| 447 |
+
HII_EFF_FACTOR = [10,250],#131.341,#200,#30,#19.037,#131.341, #[10, 250],
|
| 448 |
)
|
| 449 |
|
| 450 |
kwargs = dict(
|
| 451 |
num_images=args.num_images,#2400,#30000,
|
| 452 |
fields = ['brightness_temp', 'density', 'xH_box'],
|
| 453 |
+
BOX_LEN = args.BOX_LEN,#128,
|
| 454 |
+
HII_DIM = args.HII_DIM,
|
| 455 |
+
verbose = 3,
|
| 456 |
+
redshift = [7.51, 21.02],#11.93],
|
| 457 |
NON_CUBIC_FACTOR = args.NON_CUBIC_FACTOR,
|
| 458 |
write = True,
|
| 459 |
cpus_per_node = args.cpus_per_node,#10,#112,#20,
|
| 460 |
cache_rmdir = False,
|
| 461 |
)
|
| 462 |
|
| 463 |
+
now = datetime.datetime.now().strftime("%m%d-%H%M%S")
|
| 464 |
+
save_name = f"LEN{kwargs['BOX_LEN']}-DIM{kwargs['HII_DIM']}-CUB{kwargs['NON_CUBIC_FACTOR']}-Tvir{params_ranges['ION_Tvir_MIN']}-zeta{params_ranges['HII_EFF_FACTOR']}-{now}.h5"
|
| 465 |
kwargs['save_direc_name'] = os.path.join(args.save_direc, save_name)
|
| 466 |
|
| 467 |
generator = Generator(params_ranges, **kwargs)
|
phoenix_diffusion.sbatch
CHANGED
|
@@ -2,10 +2,10 @@
|
|
| 2 |
#SBATCH -J diffusion # Job name
|
| 3 |
#SBATCH -A gts-jw254-coda20
|
| 4 |
#SBATCH -qembers
|
| 5 |
-
#SBATCH -
|
| 6 |
#SBATCH --ntasks-per-node=1
|
| 7 |
#SBATCH --mem-per-gpu=16G # Memory per GPU
|
| 8 |
-
#SBATCH -t
|
| 9 |
#SBATCH -oReport-%j # Combined output and error messages file
|
| 10 |
#SBATCH --error=error-%j
|
| 11 |
#SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
|
|
@@ -30,9 +30,9 @@ export MASTER_PORT=$MASTER_PORT
|
|
| 30 |
|
| 31 |
srun python diffusion.py \
|
| 32 |
--train 1 \
|
| 33 |
-
--resume outputs/model_state-
|
| 34 |
--num_new_img_per_gpu 50 \
|
| 35 |
-
--max_num_img_per_gpu
|
| 36 |
|
| 37 |
######################################################################################
|
| 38 |
|
|
|
|
| 2 |
#SBATCH -J diffusion # Job name
|
| 3 |
#SBATCH -A gts-jw254-coda20
|
| 4 |
#SBATCH -qembers
|
| 5 |
+
#SBATCH -N1 --gpus-per-node=V100:1 -C V100-32GB # Number of nodes and GPUs per node required
|
| 6 |
#SBATCH --ntasks-per-node=1
|
| 7 |
#SBATCH --mem-per-gpu=16G # Memory per GPU
|
| 8 |
+
#SBATCH -t 00:10:00 # Duration of the job (Ex: 15 mins)
|
| 9 |
#SBATCH -oReport-%j # Combined output and error messages file
|
| 10 |
#SBATCH --error=error-%j
|
| 11 |
#SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
|
|
|
|
| 30 |
|
| 31 |
srun python diffusion.py \
|
| 32 |
--train 1 \
|
| 33 |
+
--resume outputs/model_state-N480-device_count1-node4-epoch49-172.27.149.66 \
|
| 34 |
--num_new_img_per_gpu 50 \
|
| 35 |
+
--max_num_img_per_gpu 2 \
|
| 36 |
|
| 37 |
######################################################################################
|
| 38 |
|
quantify_results.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|