0718-2149
Browse files- diffusion.py +9 -9
- learn_multi_node.py +47 -0
diffusion.py
CHANGED
|
@@ -235,11 +235,11 @@ class TrainConfig:
|
|
| 235 |
# repeat = 2
|
| 236 |
|
| 237 |
# dim = 2
|
| 238 |
-
dim =
|
| 239 |
stride = (2,2) if dim == 2 else (2,2,4)
|
| 240 |
num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
-
batch_size = 1#2#50#20#2#100 # 10
|
| 242 |
-
n_epoch =
|
| 243 |
HII_DIM = 64
|
| 244 |
num_redshift = 512#128#64#512#256#256#64#512#128
|
| 245 |
channel = 1
|
|
@@ -499,7 +499,7 @@ class DDPM21CM:
|
|
| 499 |
'unet_state_dict': self.nn_model.module.state_dict(),
|
| 500 |
# 'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 501 |
}
|
| 502 |
-
save_name = self.config.save_name+f"-N{self.config.num_image}-epoch{ep}-
|
| 503 |
torch.save(model_state, save_name)
|
| 504 |
print(f'device {torch.cuda.current_device()} saved model at ' + save_name)
|
| 505 |
# print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
|
|
@@ -586,7 +586,7 @@ def train(rank, world_size):
|
|
| 586 |
|
| 587 |
ddp_setup(rank, world_size)
|
| 588 |
|
| 589 |
-
num_train_image_list = [
|
| 590 |
for i, num_image in enumerate(num_train_image_list):
|
| 591 |
config.num_image = num_image
|
| 592 |
# config.world_size = world_size
|
|
@@ -677,9 +677,9 @@ if __name__ == "__main__":
|
|
| 677 |
world_size = torch.cuda.device_count()
|
| 678 |
print(f" sampling, world_size = {world_size} ".center(100,'-'))
|
| 679 |
# num_train_image_list = [1600,3200,6400,12800,25600]
|
| 680 |
-
num_train_image_list = [
|
| 681 |
-
num_new_img_per_gpu =
|
| 682 |
-
max_num_img_per_gpu =
|
| 683 |
|
| 684 |
params = torch.tensor([4.4, 131.341])
|
| 685 |
|
|
@@ -690,7 +690,7 @@ if __name__ == "__main__":
|
|
| 690 |
|
| 691 |
for num_image in num_train_image_list:
|
| 692 |
config.num_image = num_image
|
| 693 |
-
config.resume = f"./outputs/model_state-N{num_image}-
|
| 694 |
|
| 695 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 696 |
manager = mp.Manager()
|
|
|
|
| 235 |
# repeat = 2
|
| 236 |
|
| 237 |
# dim = 2
|
| 238 |
+
dim = 2
|
| 239 |
stride = (2,2) if dim == 2 else (2,2,4)
|
| 240 |
num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
+
batch_size = 10#1#2#50#20#2#100 # 10
|
| 242 |
+
n_epoch = 5#4# 10#50#20#20#2#5#25 # 120
|
| 243 |
HII_DIM = 64
|
| 244 |
num_redshift = 512#128#64#512#256#256#64#512#128
|
| 245 |
channel = 1
|
|
|
|
| 499 |
'unet_state_dict': self.nn_model.module.state_dict(),
|
| 500 |
# 'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 501 |
}
|
| 502 |
+
save_name = self.config.save_name+f"-N{self.config.num_image}-epoch{ep}-device_count{torch.cuda.device_count()}"
|
| 503 |
torch.save(model_state, save_name)
|
| 504 |
print(f'device {torch.cuda.current_device()} saved model at ' + save_name)
|
| 505 |
# print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
|
|
|
|
| 586 |
|
| 587 |
ddp_setup(rank, world_size)
|
| 588 |
|
| 589 |
+
num_train_image_list = [400]#[3200]#[200]#[1600,3200,6400,12800,25600]
|
| 590 |
for i, num_image in enumerate(num_train_image_list):
|
| 591 |
config.num_image = num_image
|
| 592 |
# config.world_size = world_size
|
|
|
|
| 677 |
world_size = torch.cuda.device_count()
|
| 678 |
print(f" sampling, world_size = {world_size} ".center(100,'-'))
|
| 679 |
# num_train_image_list = [1600,3200,6400,12800,25600]
|
| 680 |
+
num_train_image_list = [400]
|
| 681 |
+
num_new_img_per_gpu = 40
|
| 682 |
+
max_num_img_per_gpu = 20
|
| 683 |
|
| 684 |
params = torch.tensor([4.4, 131.341])
|
| 685 |
|
|
|
|
| 690 |
|
| 691 |
for num_image in num_train_image_list:
|
| 692 |
config.num_image = num_image
|
| 693 |
+
config.resume = f"./outputs/model_state-N{num_image}-epoch{config.n_epoch-1}-device_count{torch.cuda.device_count()}"
|
| 694 |
|
| 695 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 696 |
manager = mp.Manager()
|
learn_multi_node.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
import torch.multiprocessing as mp
|
| 5 |
+
|
| 6 |
+
def setup(rank, world_size):
    """Join the default NCCL process group as `rank` of `world_size` processes.

    Rendezvous is via env vars; MASTER_ADDR is hard-coded to localhost and
    must be replaced with the master node's IP for a true multi-node run.
    """
    os.environ.update({
        'MASTER_ADDR': 'localhost',  # Replace with master node's IP
        'MASTER_PORT': '12355',
    })
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
| 10 |
+
|
| 11 |
+
def cleanup():
    """Tear down the default distributed process group for this process."""
    dist.destroy_process_group()
|
| 13 |
+
|
| 14 |
+
class MyDiffusionModel(torch.nn.Module):
    """Small conv/linear demo network used to smoke-test the DDP setup.

    BUG FIX: the original commented out ``conv2`` and ``fc2`` in ``__init__``
    while ``forward`` still called ``self.fc2(x)`` -> guaranteed
    AttributeError, and ``fc1``'s 32*6*6 input only matches the architecture
    with ``conv2`` present. Both layers are restored so the shapes line up.

    Expected input: float tensor of shape (batch, 3, 30, 30), which yields
    32 x 6 x 6 features at the flatten point. Output: (batch, 10) logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 3, 1)
        self.conv2 = torch.nn.Conv2d(16, 32, 3, 1)
        self.fc1 = torch.nn.Linear(32 * 6 * 6, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        # (B,3,30,30) -> (B,16,28,28) -> pool -> (B,16,14,14)
        x = torch.nn.functional.relu(self.conv1(x))
        x = torch.nn.functional.max_pool2d(x, 2)
        # (B,16,14,14) -> (B,32,12,12) -> pool -> (B,32,6,6)
        x = torch.nn.functional.relu(self.conv2(x))
        x = torch.nn.functional.max_pool2d(x, 2)
        x = torch.flatten(x, 1)  # (B, 32*6*6)
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x
|
| 31 |
+
|
| 32 |
+
def main(rank, world_size):
    """Per-process entry point: join the group, report visible GPUs, tear down."""
    setup(rank, world_size)

    if not torch.cuda.is_available():
        print(f"Rank {rank}, No GPUs available")
    else:
        num_gpus = torch.cuda.device_count()
        print(f"Rank {rank}, Number of GPUs available: {num_gpus}")
        for i in range(num_gpus):
            print(f"Rank {rank}, GPU {i}: {torch.cuda.get_device_name(i)}")

    cleanup()
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
    # BUG FIX: the original passed world_size = 1 ("Number of nodes") to every
    # worker while spawning nprocs = torch.cuda.device_count() processes.
    # torch.distributed requires world_size to equal the TOTAL number of
    # participating processes; with >1 GPU, ranks >= 1 exceeded the declared
    # world size and init_process_group would hang or fail. Spawn exactly one
    # process per local GPU and declare that same count as the world size.
    world_size = torch.cuda.device_count()  # single-node run: one rank per GPU
    mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
|