Xsmos commited on
Commit
aa037e1
·
verified ·
1 Parent(s): 959d161

0815-184556

Browse files
context_unet.py CHANGED
@@ -533,7 +533,7 @@ class ContextUnet(nn.Module):
533
  text_outputs = self.token_embedding(y.to(self.dtype))
534
  emb = emb + text_outputs.to(emb)
535
 
536
- print("forward, h = x.type(self.dtype), self.dtype =", self.dtype)
537
  h = x.type(self.dtype)
538
  #print("0,h.shape =", h.shape)
539
  for module in self.input_blocks:
@@ -543,7 +543,7 @@ class ContextUnet(nn.Module):
543
  # print("2,h.shape =", h.shape)
544
  h = self.middle_block(h, emb)
545
  #print("middle block, h.shape =", h.shape)
546
- print("2, h.dtype =", h.dtype)
547
  for module in self.output_blocks:
548
  #print("for module in self.output_blocks, h.shape =", h.shape)
549
  # print("len(hs) =", len(hs), ", hs[-1].shape =", hs[-1].shape)
@@ -551,9 +551,9 @@ class ContextUnet(nn.Module):
551
  h = module(h, emb)
552
  # print("module decoder, h.shape =", h.shape)
553
 
554
- print("h = h.type(x.dtype), x.dtype =", x.dtype, h.dtype)
555
  h = h.type(x.dtype)
556
  h = self.out(h)
557
- print("self.out(h)", "h.dtype =", h.dtype)
558
 
559
  return h
 
533
  text_outputs = self.token_embedding(y.to(self.dtype))
534
  emb = emb + text_outputs.to(emb)
535
 
536
+ #print("forward, h = x.type(self.dtype), self.dtype =", self.dtype)
537
  h = x.type(self.dtype)
538
  #print("0,h.shape =", h.shape)
539
  for module in self.input_blocks:
 
543
  # print("2,h.shape =", h.shape)
544
  h = self.middle_block(h, emb)
545
  #print("middle block, h.shape =", h.shape)
546
+ #print("2, h.dtype =", h.dtype)
547
  for module in self.output_blocks:
548
  #print("for module in self.output_blocks, h.shape =", h.shape)
549
  # print("len(hs) =", len(hs), ", hs[-1].shape =", hs[-1].shape)
 
551
  h = module(h, emb)
552
  # print("module decoder, h.shape =", h.shape)
553
 
554
+ #print("h = h.type(x.dtype), x.dtype =", x.dtype, h.dtype)
555
  h = h.type(x.dtype)
556
  h = self.out(h)
557
+ #print("self.out(h)", "h.dtype =", h.dtype)
558
 
559
  return h
diffusion.py CHANGED
@@ -33,8 +33,9 @@ import warnings
33
  #warnings.filterwarnings("ignore", message=r"^Detected kernel version")
34
 
35
  from dataclasses import dataclass
36
- import h5py
37
  import torch
 
38
  import torch.nn as nn
39
  from torch.utils.data import DataLoader, Dataset
40
  # from datasets import Dataset
@@ -208,9 +209,9 @@ class DDPMScheduler(nn.Module):
208
  # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
209
 
210
  # print("x_i.shape =", x_i.shape)
211
- print(f"before, x_i.dtype = {x_i.dtype}, beta_t.dtype = {self.beta_t.dtype}, eps.dtype = {eps.dtype}, alpha_t.dtype = {self.alpha_t.dtype}, z.dtype = {z.dtype}")
212
  x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
213
- print(f"after, x_i.dtype = {x_i.dtype}, beta_t.dtype = {self.beta_t.dtype}, eps.dtype = {eps.dtype}, alpha_t.dtype = {self.alpha_t.dtype}, z.dtype = {z.dtype}")
214
 
215
  pbar_sample.update(1)
216
 
@@ -268,8 +269,8 @@ class TrainConfig:
268
  # dim = 2
269
  dim = 3#2
270
  stride = (2,4) if dim == 2 else (2,2,2)
271
- num_image = 30#00#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
272
- batch_size = 5#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
273
  n_epoch = 50#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
274
  HII_DIM = 64
275
  num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
@@ -522,19 +523,19 @@ class DDPM21CM:
522
  # print("x = x.to(self.config.device), x.dtype =", x.dtype)
523
  x = x.to(self.config.dtype)
524
  # print("x = x.to(self.dtype), x.dtype =", x.dtype)
525
- print(f"ddpm.add_noise(x), x.dtype = {x.dtype}")
526
  xt, noise, ts = self.ddpm.add_noise(x)
527
- print(f"ddpm.add_noise(x), xt.dtype = {xt.dtype}")
528
  if self.config.guide_w == -1:
529
  noise_pred = self.nn_model(xt, ts).to(x.dtype)
530
  else:
531
  c = c.to(self.config.device)
532
  noise_pred = self.nn_model(xt, ts, c).to(x.dtype)
533
 
534
- print("noise_pred = self.nn_model(xt, ts, c), noise_pred.dtype =", noise_pred.dtype, noise.dtype)
535
 
536
  loss = F.mse_loss(noise, noise_pred)
537
- print(f"loss.dtype =", loss.dtype)
538
  self.accelerator.backward(loss)
539
  self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)
540
  self.optimizer.step()
@@ -742,6 +743,7 @@ if __name__ == "__main__":
742
  parser.add_argument("--resume", type=str, required=False, help="filename of the model to resume", default=False)
743
  parser.add_argument("--num_new_img_per_gpu", type=int, required=False, default=4)
744
  parser.add_argument("--max_num_img_per_gpu", type=int, required=False, default=2)
 
745
 
746
  args = parser.parse_args()
747
 
@@ -766,8 +768,9 @@ if __name__ == "__main__":
766
  max_num_img_per_gpu = args.max_num_img_per_gpu#40#2#20
767
  config = TrainConfig()
768
  #config.world_size = world_size
769
- # config.dtype = torch.float32
770
  config.resume = args.resume
 
771
  # config.resume = f"./outputs/model_state-N30-device_count3-epoch4-172.27.149.181"
772
  # config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
773
  # config.resume = f"./outputs/model_state-N{config.num_image}-device_count1-epoch{config.n_epoch-1}"
 
33
  #warnings.filterwarnings("ignore", message=r"^Detected kernel version")
34
 
35
  from dataclasses import dataclass
36
+ #import h5py
37
  import torch
38
+ #print(f"starting, torch.__path__ = {torch.__path__}, torch.cuda.device_count() = {torch.cuda.device_count()}, torch.cuda.is_available() = {torch.cuda.is_available()}")
39
  import torch.nn as nn
40
  from torch.utils.data import DataLoader, Dataset
41
  # from datasets import Dataset
 
209
  # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
210
 
211
  # print("x_i.shape =", x_i.shape)
212
+ #print(f"before, x_i.dtype = {x_i.dtype}, beta_t.dtype = {self.beta_t.dtype}, eps.dtype = {eps.dtype}, alpha_t.dtype = {self.alpha_t.dtype}, z.dtype = {z.dtype}")
213
  x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
214
+ #print(f"after, x_i.dtype = {x_i.dtype}, beta_t.dtype = {self.beta_t.dtype}, eps.dtype = {eps.dtype}, alpha_t.dtype = {self.alpha_t.dtype}, z.dtype = {z.dtype}")
215
 
216
  pbar_sample.update(1)
217
 
 
269
  # dim = 2
270
  dim = 3#2
271
  stride = (2,4) if dim == 2 else (2,2,2)
272
+ num_image = 300#0#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
273
+ batch_size = 1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
274
  n_epoch = 50#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
275
  HII_DIM = 64
276
  num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
 
523
  # print("x = x.to(self.config.device), x.dtype =", x.dtype)
524
  x = x.to(self.config.dtype)
525
  # print("x = x.to(self.dtype), x.dtype =", x.dtype)
526
+ # print(f"ddpm.add_noise(x), x.dtype = {x.dtype}")
527
  xt, noise, ts = self.ddpm.add_noise(x)
528
+ # print(f"ddpm.add_noise(x), xt.dtype = {xt.dtype}")
529
  if self.config.guide_w == -1:
530
  noise_pred = self.nn_model(xt, ts).to(x.dtype)
531
  else:
532
  c = c.to(self.config.device)
533
  noise_pred = self.nn_model(xt, ts, c).to(x.dtype)
534
 
535
+ # print("noise_pred = self.nn_model(xt, ts, c), noise_pred.dtype =", noise_pred.dtype, noise.dtype)
536
 
537
  loss = F.mse_loss(noise, noise_pred)
538
+ #print(f"loss.dtype =", loss.dtype)
539
  self.accelerator.backward(loss)
540
  self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)
541
  self.optimizer.step()
 
743
  parser.add_argument("--resume", type=str, required=False, help="filename of the model to resume", default=False)
744
  parser.add_argument("--num_new_img_per_gpu", type=int, required=False, default=4)
745
  parser.add_argument("--max_num_img_per_gpu", type=int, required=False, default=2)
746
+ parser.add_argument("--gradient_accumulation_steps", type=int, required=False, default=1)
747
 
748
  args = parser.parse_args()
749
 
 
768
  max_num_img_per_gpu = args.max_num_img_per_gpu#40#2#20
769
  config = TrainConfig()
770
  #config.world_size = world_size
771
+ config.dtype = torch.float32
772
  config.resume = args.resume
773
+ config.gradient_accumulation_steps = args.gradient_accumulation_steps
774
  # config.resume = f"./outputs/model_state-N30-device_count3-epoch4-172.27.149.181"
775
  # config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
776
  # config.resume = f"./outputs/model_state-N{config.num_image}-device_count1-epoch{config.n_epoch-1}"
frontera_generate_dataset.sbatch CHANGED
@@ -25,9 +25,9 @@
25
 
26
  #SBATCH -J datasets # Job name
27
  #SBATCH -o Report-%j # Name of stdout output file
28
- #SBATCH -p normal # Queue (partition) name
29
- #SBATCH -N 12 # 50 # Total # of nodes
30
- #SBATCH -t 2-00:00:00 # Run time (hh:mm:ss)
31
  #SBATCH --mail-type=all # Send email at begin and end of job
32
  #SBATCH --mail-user=xiabin@gatech.edu
33
  #SBATCH --ntasks-per-node=1
@@ -44,7 +44,7 @@ conda env list
44
 
45
  srun python generate_dataset.py \
46
  --save_direc $SCRATCH \
47
- --num_images 25600 \
48
  --BOX_LEN 128 \
49
  --HII_DIM 64 \
50
  --NON_CUBIC_FACTOR 16 \
 
25
 
26
  #SBATCH -J datasets # Job name
27
  #SBATCH -o Report-%j # Name of stdout output file
28
+ #SBATCH -p small # Queue (partition) name
29
+ #SBATCH -N 2 # 50 # Total # of nodes
30
+ #SBATCH -t 09:00:00 # Run time (hh:mm:ss)
31
  #SBATCH --mail-type=all # Send email at begin and end of job
32
  #SBATCH --mail-user=xiabin@gatech.edu
33
  #SBATCH --ntasks-per-node=1
 
44
 
45
  srun python generate_dataset.py \
46
  --save_direc $SCRATCH \
47
+ --num_images 800\
48
  --BOX_LEN 128 \
49
  --HII_DIM 64 \
50
  --NON_CUBIC_FACTOR 16 \
generate_dataset.py CHANGED
@@ -443,8 +443,8 @@ if __name__ == '__main__':
443
  args = parser.parse_args()
444
 
445
  params_ranges = dict(
446
- ION_Tvir_MIN = [4,6],#4.8,#5.477,#4.699,#5.6,#4.4, #[4,6],
447
- HII_EFF_FACTOR = [10,250],#131.341,#200,#30,#19.037,#131.341, #[10, 250],
448
  )
449
 
450
  kwargs = dict(
 
443
  args = parser.parse_args()
444
 
445
  params_ranges = dict(
446
+ ION_Tvir_MIN = 4.8,#5.477,#4.699,#5.6,#4.4, #[4,6],
447
+ HII_EFF_FACTOR = 131.341,#200,#30,#19.037,#131.341, #[10, 250],
448
  )
449
 
450
  kwargs = dict(
phoenix_diffusion.sbatch CHANGED
@@ -2,10 +2,10 @@
2
  #SBATCH -J diffusion # Job name
3
  #SBATCH -A gts-jw254-coda20
4
  #SBATCH -qembers
5
- #SBATCH -N1 --gpus-per-node=V100:1 -C V100-32GB # Number of nodes and cores per node required
6
  #SBATCH --ntasks-per-node=1
7
  #SBATCH --mem-per-gpu=16G # Memory per core
8
- #SBATCH -t 00:30:00 # Duration of the job (Ex: 15 mins)
9
  #SBATCH -oReport-%j # Combined output and error messages file
10
  #SBATCH --error=error-%j
11
  #SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
@@ -17,9 +17,10 @@ pwd
17
  date
18
  module load anaconda3/2022.05 # Load module dependencies
19
  module load pytorch
20
- module list
21
  conda activate diffusers
22
  conda env list
 
23
  cat $0
24
 
25
  MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
@@ -33,6 +34,6 @@ srun python diffusion.py \
33
  --resume outputs/model_state-N3000-device_count1-node2-epoch49-172.27.149.67 \
34
  --num_new_img_per_gpu 50 \
35
  --max_num_img_per_gpu 5 \
36
-
37
  ######################################################################################
38
 
 
2
  #SBATCH -J diffusion # Job name
3
  #SBATCH -A gts-jw254-coda20
4
  #SBATCH -qembers
5
+ #SBATCH -N1 --gpus-per-node=V100:1 -C V100-16GB # Number of nodes and cores per node required
6
  #SBATCH --ntasks-per-node=1
7
  #SBATCH --mem-per-gpu=16G # Memory per core
8
+ #SBATCH -t 01:00:00 # Duration of the job (Ex: 15 mins)
9
  #SBATCH -oReport-%j # Combined output and error messages file
10
  #SBATCH --error=error-%j
11
  #SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
 
17
  date
18
  module load anaconda3/2022.05 # Load module dependencies
19
  module load pytorch
20
+
21
  conda activate diffusers
22
  conda env list
23
+ module list
24
  cat $0
25
 
26
  MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
 
34
  --resume outputs/model_state-N3000-device_count1-node2-epoch49-172.27.149.67 \
35
  --num_new_img_per_gpu 50 \
36
  --max_num_img_per_gpu 5 \
37
+ --gradient_accumulation_steps 60 \
38
  ######################################################################################
39
 
quantify_results.ipynb CHANGED
@@ -203,31 +203,82 @@
203
  },
204
  {
205
  "cell_type": "code",
206
- "execution_count": null,
207
  "metadata": {},
208
  "outputs": [],
209
- "source": []
 
 
 
210
  },
211
  {
212
  "cell_type": "code",
213
- "execution_count": null,
214
  "metadata": {},
215
  "outputs": [],
216
- "source": []
 
 
217
  },
218
  {
219
  "cell_type": "code",
220
- "execution_count": null,
221
  "metadata": {},
222
- "outputs": [],
223
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  },
225
  {
226
  "cell_type": "code",
227
- "execution_count": null,
228
  "metadata": {},
229
- "outputs": [],
230
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  },
232
  {
233
  "cell_type": "code",
@@ -2183,7 +2234,7 @@
2183
  "name": "python",
2184
  "nbconvert_exporter": "python",
2185
  "pygments_lexer": "ipython3",
2186
- "version": "3.9.19"
2187
  }
2188
  },
2189
  "nbformat": 4,
 
203
  },
204
  {
205
  "cell_type": "code",
206
+ "execution_count": 1,
207
  "metadata": {},
208
  "outputs": [],
209
+ "source": [
210
+ "!module load pytorch\n",
211
+ "# !module list"
212
+ ]
213
  },
214
  {
215
  "cell_type": "code",
216
+ "execution_count": 2,
217
  "metadata": {},
218
  "outputs": [],
219
+ "source": [
220
+ "import torch"
221
+ ]
222
  },
223
  {
224
  "cell_type": "code",
225
+ "execution_count": 3,
226
  "metadata": {},
227
+ "outputs": [
228
+ {
229
+ "data": {
230
+ "text/plain": [
231
+ "['/storage/home/hcoda1/3/bxia34/.conda/envs/rh9_diffusers/lib/python3.12/site-packages/torch']"
232
+ ]
233
+ },
234
+ "execution_count": 3,
235
+ "metadata": {},
236
+ "output_type": "execute_result"
237
+ }
238
+ ],
239
+ "source": [
240
+ "torch.__path__"
241
+ ]
242
  },
243
  {
244
  "cell_type": "code",
245
+ "execution_count": 4,
246
  "metadata": {},
247
+ "outputs": [
248
+ {
249
+ "data": {
250
+ "text/plain": [
251
+ "False"
252
+ ]
253
+ },
254
+ "execution_count": 4,
255
+ "metadata": {},
256
+ "output_type": "execute_result"
257
+ }
258
+ ],
259
+ "source": [
260
+ "torch.cuda.is_available()"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": 5,
266
+ "metadata": {},
267
+ "outputs": [
268
+ {
269
+ "data": {
270
+ "text/plain": [
271
+ "0"
272
+ ]
273
+ },
274
+ "execution_count": 5,
275
+ "metadata": {},
276
+ "output_type": "execute_result"
277
+ }
278
+ ],
279
+ "source": [
280
+ "torch.cuda.device_count()"
281
+ ]
282
  },
283
  {
284
  "cell_type": "code",
 
2234
  "name": "python",
2235
  "nbconvert_exporter": "python",
2236
  "pygments_lexer": "ipython3",
2237
+ "version": "3.12.5"
2238
  }
2239
  },
2240
  "nbformat": 4,