Xsmos committed on
Commit
9f265ee
·
verified ·
1 Parent(s): 34e5bce
Files changed (2) hide show
  1. diffusion.py +9 -9
  2. learn_multi_node.py +47 -0
diffusion.py CHANGED
@@ -235,11 +235,11 @@ class TrainConfig:
235
  # repeat = 2
236
 
237
  # dim = 2
238
- dim = 3
239
  stride = (2,2) if dim == 2 else (2,2,4)
240
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
- batch_size = 1#2#50#20#2#100 # 10
242
- n_epoch = 2#4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
  num_redshift = 512#128#64#512#256#256#64#512#128
245
  channel = 1
@@ -499,7 +499,7 @@ class DDPM21CM:
499
  'unet_state_dict': self.nn_model.module.state_dict(),
500
  # 'ema_unet_state_dict': self.ema_model.state_dict(),
501
  }
502
- save_name = self.config.save_name+f"-N{self.config.num_image}-epoch{ep}-device{torch.cuda.current_device()}"
503
  torch.save(model_state, save_name)
504
  print(f'device {torch.cuda.current_device()} saved model at ' + save_name)
505
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
@@ -586,7 +586,7 @@ def train(rank, world_size):
586
 
587
  ddp_setup(rank, world_size)
588
 
589
- num_train_image_list = [3200]#[3200]#[200]#[1600,3200,6400,12800,25600]
590
  for i, num_image in enumerate(num_train_image_list):
591
  config.num_image = num_image
592
  # config.world_size = world_size
@@ -677,9 +677,9 @@ if __name__ == "__main__":
677
  world_size = torch.cuda.device_count()
678
  print(f" sampling, world_size = {world_size} ".center(100,'-'))
679
  # num_train_image_list = [1600,3200,6400,12800,25600]
680
- num_train_image_list = [3200]
681
- num_new_img_per_gpu = 9
682
- max_num_img_per_gpu = 1
683
 
684
  params = torch.tensor([4.4, 131.341])
685
 
@@ -690,7 +690,7 @@ if __name__ == "__main__":
690
 
691
  for num_image in num_train_image_list:
692
  config.num_image = num_image
693
- config.resume = f"./outputs/model_state-N{num_image}-epoch6-device0"
694
 
695
  # print("ddpm21cm = DDPM21CM(config)")
696
  manager = mp.Manager()
 
235
  # repeat = 2
236
 
237
  # dim = 2
238
+ dim = 2
239
  stride = (2,2) if dim == 2 else (2,2,4)
240
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
+ batch_size = 10#1#2#50#20#2#100 # 10
242
+ n_epoch = 5#4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
  num_redshift = 512#128#64#512#256#256#64#512#128
245
  channel = 1
 
499
  'unet_state_dict': self.nn_model.module.state_dict(),
500
  # 'ema_unet_state_dict': self.ema_model.state_dict(),
501
  }
502
+ save_name = self.config.save_name+f"-N{self.config.num_image}-epoch{ep}-device_count{torch.cuda.device_count()}"
503
  torch.save(model_state, save_name)
504
  print(f'device {torch.cuda.current_device()} saved model at ' + save_name)
505
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
 
586
 
587
  ddp_setup(rank, world_size)
588
 
589
+ num_train_image_list = [400]#[3200]#[200]#[1600,3200,6400,12800,25600]
590
  for i, num_image in enumerate(num_train_image_list):
591
  config.num_image = num_image
592
  # config.world_size = world_size
 
677
  world_size = torch.cuda.device_count()
678
  print(f" sampling, world_size = {world_size} ".center(100,'-'))
679
  # num_train_image_list = [1600,3200,6400,12800,25600]
680
+ num_train_image_list = [400]
681
+ num_new_img_per_gpu = 40
682
+ max_num_img_per_gpu = 20
683
 
684
  params = torch.tensor([4.4, 131.341])
685
 
 
690
 
691
  for num_image in num_train_image_list:
692
  config.num_image = num_image
693
+ config.resume = f"./outputs/model_state-N{num_image}-epoch{config.n_epoch-1}-device_count{torch.cuda.device_count()}"
694
 
695
  # print("ddpm21cm = DDPM21CM(config)")
696
  manager = mp.Manager()
learn_multi_node.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.distributed as dist
4
+ import torch.multiprocessing as mp
5
+
6
def setup(rank, world_size):
    """Join the NCCL process group as `rank` of `world_size` processes.

    NOTE(review): MASTER_ADDR is hard-coded to localhost, so this only works
    single-node as written — replace with the master node's IP for a real
    multi-node launch.
    """
    os.environ["MASTER_ADDR"] = 'localhost'  # Replace with master node's IP
    os.environ["MASTER_PORT"] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
10
+
11
def cleanup():
    """Tear down the distributed process group created by setup()."""
    dist.destroy_process_group()
13
+
14
class MyDiffusionModel(torch.nn.Module):
    """Small CNN used to smoke-test the DDP setup.

    Expects input of shape (N, 3, 32, 32):
    conv1 (3->16, k3) -> pool/2 -> conv2 (16->32, k3) -> pool/2 gives a
    32 x 6 x 6 feature map, which is flattened into fc1 -> fc2 (10 logits).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 3, 1)
        # conv2 and fc2 were commented out in the original, but forward()
        # calls self.fc2 (AttributeError) and fc1's input size (32 * 6 * 6)
        # only matches the feature map produced by conv2 — restore both so
        # the model is actually runnable.
        self.conv2 = torch.nn.Conv2d(16, 32, 3, 1)
        self.fc1 = torch.nn.Linear(32 * 6 * 6, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape (N, 10) for x of shape (N, 3, 32, 32)."""
        x = torch.nn.functional.relu(self.conv1(x))
        x = torch.nn.functional.max_pool2d(x, 2)
        x = torch.nn.functional.relu(self.conv2(x))
        x = torch.nn.functional.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = torch.nn.functional.relu(self.fc1(x))
        return self.fc2(x)
31
+
32
def main(rank, world_size):
    """Per-process entry point: join the process group, report the GPUs this
    rank can see, then tear the group down."""
    setup(rank, world_size)

    if not torch.cuda.is_available():
        print(f"Rank {rank}, No GPUs available")
    else:
        num_gpus = torch.cuda.device_count()
        print(f"Rank {rank}, Number of GPUs available: {num_gpus}")
        for i in range(num_gpus):
            print(f"Rank {rank}, GPU {i}: {torch.cuda.get_device_name(i)}")

    cleanup()
44
+
45
if __name__ == "__main__":
    # BUG FIX: the original set world_size = 1 ("number of nodes") while
    # spawning torch.cuda.device_count() processes. Each spawned rank then
    # called init_process_group with world_size=1, which is inconsistent
    # (and hangs/errors) whenever more than one GPU is present. For a
    # single-node run, world_size must equal the number of spawned processes.
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)