0725-1750
Browse files- diffusion.py +61 -50
diffusion.py
CHANGED
|
@@ -65,17 +65,18 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|
| 65 |
from torch.distributed import init_process_group, destroy_process_group
|
| 66 |
import torch.distributed as dist
|
| 67 |
|
|
|
|
|
|
|
| 68 |
# %%
|
| 69 |
-
def ddp_setup(rank: int, world_size: int):
|
| 70 |
"""
|
| 71 |
Args:
|
| 72 |
rank: Unique identifier of each process
|
| 73 |
world_size: Total number of processes
|
| 74 |
"""
|
| 75 |
-
os.environ["MASTER_ADDR"] =
|
| 76 |
-
os.environ["MASTER_PORT"] =
|
| 77 |
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ddp_setup, rank =", rank)
|
| 78 |
-
torch.cuda.set_device(rank)
|
| 79 |
init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
| 80 |
|
| 81 |
# %%
|
|
@@ -240,7 +241,7 @@ class TrainConfig:
|
|
| 240 |
# dim = 2
|
| 241 |
dim = 2
|
| 242 |
stride = (2,4) if dim == 2 else (2,2,2)
|
| 243 |
-
num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 244 |
batch_size = 10#50#20#50#1#2#50#20#2#100 # 10
|
| 245 |
n_epoch = 50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
|
| 246 |
HII_DIM = 64
|
|
@@ -642,41 +643,30 @@ class DDPM21CM:
|
|
| 642 |
return x_last
|
| 643 |
# %%
|
| 644 |
|
| 645 |
-
num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
|
| 647 |
-
|
| 648 |
|
| 649 |
-
# print("before ddp_setup")
|
| 650 |
-
ddp_setup(rank, world_size)
|
| 651 |
-
# print("after ddp_setup")
|
| 652 |
-
# print("TrainConfig()")
|
| 653 |
config = TrainConfig()
|
| 654 |
-
config.device = f"cuda:{
|
| 655 |
-
# print("torch.cuda.current_device(), config.device =", torch.cuda.current_device(), config.device)
|
| 656 |
config.world_size = world_size
|
| 657 |
|
| 658 |
#[3200]#[200]#[1600,3200,6400,12800,25600]
|
| 659 |
-
for i, num_image in enumerate(num_train_image_list):
|
| 660 |
-
config.num_image = num_image
|
| 661 |
# config.world_size = world_size
|
| 662 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 663 |
# print(f"config.device, torch.cuda.current_device() = {config.device}, {torch.cuda.current_device()}")
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
if __name__ == "__main__":# and False:
|
| 671 |
-
world_size = torch.cuda.device_count()
|
| 672 |
-
print(f" training, world_size = {world_size} ".center(120,'-'))
|
| 673 |
-
# torch.multiprocessing.set_start_method("spawn")
|
| 674 |
-
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 675 |
-
|
| 676 |
-
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
|
| 677 |
-
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|
| 678 |
-
|
| 679 |
-
|
| 680 |
# %%
|
| 681 |
|
| 682 |
# def generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params):
|
|
@@ -705,8 +695,8 @@ if __name__ == "__main__":# and False:
|
|
| 705 |
# # else:
|
| 706 |
# # return None
|
| 707 |
|
| 708 |
-
def generate_samples(rank, world_size, config, num_new_img_per_gpu, max_num_img_per_gpu,
|
| 709 |
-
ddp_setup(rank, world_size)
|
| 710 |
ddpm21cm = DDPM21CM(config)
|
| 711 |
|
| 712 |
# generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params)
|
|
@@ -729,28 +719,43 @@ def generate_samples(rank, world_size, config, num_new_img_per_gpu, max_num_img_
|
|
| 729 |
|
| 730 |
|
| 731 |
if __name__ == "__main__":
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
num_new_img_per_gpu = 200
|
| 737 |
-
max_num_img_per_gpu = 20
|
| 738 |
|
| 739 |
-
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 745 |
|
| 746 |
-
|
| 747 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 748 |
config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
|
| 749 |
# config.resume = f"./outputs/model_state-N{config.num_image}-device_count1-epoch{config.n_epoch-1}"
|
| 750 |
|
| 751 |
-
#
|
| 752 |
-
|
| 753 |
-
return_dict = manager.dict()
|
| 754 |
|
| 755 |
params_pairs = [
|
| 756 |
(4.4, 131.341),
|
|
@@ -759,9 +764,15 @@ if __name__ == "__main__":
|
|
| 759 |
(5.477, 200),
|
| 760 |
(4.8, 131.341),
|
| 761 |
]
|
|
|
|
| 762 |
for params in params_pairs:
|
| 763 |
print(f" sampling for {params}, world_size = {world_size} ".center(120,'-'))
|
| 764 |
-
mp.spawn(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 765 |
|
| 766 |
# print("---"*30)
|
| 767 |
# print(f"cuda:{torch.cuda.current_device()}, keys = {return_dict.keys()}")
|
|
|
|
| 65 |
from torch.distributed import init_process_group, destroy_process_group
|
| 66 |
import torch.distributed as dist
|
| 67 |
|
| 68 |
+
import argparse
|
| 69 |
+
|
| 70 |
# %%
|
| 71 |
+
def ddp_setup(rank: int, world_size: int, master_addr, master_port):
|
| 72 |
"""
|
| 73 |
Args:
|
| 74 |
rank: Unique identifier of each process
|
| 75 |
world_size: Total number of processes
|
| 76 |
"""
|
| 77 |
+
os.environ["MASTER_ADDR"] = master_addr
|
| 78 |
+
os.environ["MASTER_PORT"] = master_port
|
| 79 |
# print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ddp_setup, rank =", rank)
|
|
|
|
| 80 |
init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
| 81 |
|
| 82 |
# %%
|
|
|
|
| 241 |
# dim = 2
|
| 242 |
dim = 2
|
| 243 |
stride = (2,4) if dim == 2 else (2,2,2)
|
| 244 |
+
num_image = 60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 245 |
batch_size = 10#50#20#50#1#2#50#20#2#100 # 10
|
| 246 |
n_epoch = 50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
|
| 247 |
HII_DIM = 64
|
|
|
|
| 643 |
return x_last
|
| 644 |
# %%
|
| 645 |
|
| 646 |
+
#num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
|
| 647 |
+
def train(rank, world_size, local_world_size, master_addr, master_port):
|
| 648 |
+
ddp_setup(rank, world_size, master_addr, master_port)
|
| 649 |
+
|
| 650 |
+
local_rank = rank % local_world_size
|
| 651 |
+
torch.cuda.set_device(local_rank)
|
| 652 |
|
| 653 |
+
print(f"Global rank {rank}, local rank {local_rank}, current_device {torch.cuda.current_device()}")
|
| 654 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
config = TrainConfig()
|
| 656 |
+
config.device = f"cuda:{local_rank}"
|
|
|
|
| 657 |
config.world_size = world_size
|
| 658 |
|
| 659 |
#[3200]#[200]#[1600,3200,6400,12800,25600]
|
| 660 |
+
#for i, num_image in enumerate(num_train_image_list):
|
| 661 |
+
#config.num_image = num_image
|
| 662 |
# config.world_size = world_size
|
| 663 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 664 |
# print(f"config.device, torch.cuda.current_device() = {config.device}, {torch.cuda.current_device()}")
|
| 665 |
+
ddpm21cm = DDPM21CM(config)
|
| 666 |
+
# print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
|
| 667 |
+
print(f"run_name = {ddpm21cm.config.run_name}")
|
| 668 |
+
ddpm21cm.train()
|
| 669 |
+
destroy_process_group()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 670 |
# %%
|
| 671 |
|
| 672 |
# def generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params):
|
|
|
|
| 695 |
# # else:
|
| 696 |
# # return None
|
| 697 |
|
| 698 |
+
def generate_samples(rank, world_size, local_world_size, master_addr, master_port, config, num_new_img_per_gpu, max_num_img_per_gpu, params):
|
| 699 |
+
ddp_setup(rank, world_size, master_addr, master_port)
|
| 700 |
ddpm21cm = DDPM21CM(config)
|
| 701 |
|
| 702 |
# generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params)
|
|
|
|
| 719 |
|
| 720 |
|
| 721 |
if __name__ == "__main__":
|
| 722 |
+
parser = argparse.ArgumentParser()
|
| 723 |
+
parser.add_argument("--train", type=int, required=False, help="whether to train the model", default=1)
|
| 724 |
+
parser.add_argument("--sample", type=int, required=False, help="whether to sample", default=0)
|
| 725 |
+
args = parser.parse_args()
|
|
|
|
|
|
|
| 726 |
|
| 727 |
+
master_addr = os.environ["SLURM_NODELIST"].split(",")[0]
|
| 728 |
+
master_port = "12355"
|
| 729 |
+
world_size = int(os.environ["SLURM_NTASKS"])
|
| 730 |
+
local_world_size = torch.cuda.device_count()
|
| 731 |
|
| 732 |
+
############################ training ################################
|
| 733 |
+
world_size = torch.cuda.device_count()
|
| 734 |
+
if args.train:
|
| 735 |
+
print(f" training, world_size = {world_size} ".center(120,'-'))
|
| 736 |
+
mp.spawn(
|
| 737 |
+
train,
|
| 738 |
+
args=(world_size, local_world_size, master_addr, master_port),
|
| 739 |
+
nprocs=local_world_size,
|
| 740 |
+
join=True
|
| 741 |
+
)
|
| 742 |
|
| 743 |
+
|
| 744 |
+
############################ sampling ################################
|
| 745 |
+
if args.sample:
|
| 746 |
+
num_new_img_per_gpu = 200
|
| 747 |
+
max_num_img_per_gpu = 20
|
| 748 |
+
config = TrainConfig()
|
| 749 |
+
config.world_size = world_size
|
| 750 |
+
# print("config.world_size = world_size")
|
| 751 |
+
|
| 752 |
+
#for num_image in num_train_image_list:
|
| 753 |
+
#config.num_image = num_image# // world_size
|
| 754 |
config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
|
| 755 |
# config.resume = f"./outputs/model_state-N{config.num_image}-device_count1-epoch{config.n_epoch-1}"
|
| 756 |
|
| 757 |
+
# manager = mp.Manager()
|
| 758 |
+
# return_dict = manager.dict()
|
|
|
|
| 759 |
|
| 760 |
params_pairs = [
|
| 761 |
(4.4, 131.341),
|
|
|
|
| 764 |
(5.477, 200),
|
| 765 |
(4.8, 131.341),
|
| 766 |
]
|
| 767 |
+
|
| 768 |
for params in params_pairs:
|
| 769 |
print(f" sampling for {params}, world_size = {world_size} ".center(120,'-'))
|
| 770 |
+
mp.spawn(
|
| 771 |
+
generate_samples,
|
| 772 |
+
args=(world_size, local_world_size, master_addr, master_port, config, num_new_img_per_gpu, max_num_img_per_gpu, torch.tensor(params)),
|
| 773 |
+
nprocs=local_world_size,
|
| 774 |
+
join=True
|
| 775 |
+
)
|
| 776 |
|
| 777 |
# print("---"*30)
|
| 778 |
# print(f"cuda:{torch.cuda.current_device()}, keys = {return_dict.keys()}")
|