0726-1429
- diffusion.py +5 -5
- load_h5.py +1 -1
- phoenix_diffusion.sbatch +1 -0
diffusion.py
CHANGED
@@ -651,12 +651,12 @@ class DDPM21CM:
 
 #num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
 def train(rank, world_size, local_world_size, master_addr, master_port):
-    print("before ddp_setup")
+    #print("before ddp_setup")
     ddp_setup(rank, world_size, master_addr, master_port)
-    print("after ddp_setup")
+    #print("after ddp_setup")
     local_rank = rank % local_world_size
     torch.cuda.set_device(local_rank)
-    print("after set device")
+    #print("after set device")
 
     config = TrainConfig()
     config.device = f"cuda:{local_rank}"
@@ -668,11 +668,11 @@ def train(rank, world_size, local_world_size, master_addr, master_port):
     # config.world_size = world_size
     # print("ddpm21cm = DDPM21CM(config)")
     # print(f"config.device, torch.cuda.current_device() = {config.device}, {torch.cuda.current_device()}")
-    print("before dppm21cm")
+    #print("before dppm21cm")
     ddpm21cm = DDPM21CM(config)
     # print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
     # print(f"run_name = {ddpm21cm.config.run_name}")
-    print(f"run_name {ddpm21cm.config.run_name}, global_rank {rank}, local_rank {local_rank}, current_device {torch.cuda.current_device()}, local_world_size {local_world_size}, world_size {world_size}")
+    #print(f"run_name {ddpm21cm.config.run_name}, global_rank {rank}, local_rank {local_rank}, current_device {torch.cuda.current_device()}, local_world_size {local_world_size}, world_size {world_size}")
     ddpm21cm.train()
     destroy_process_group()
 # %%
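Note: train() relies on a ddp_setup() helper and on destroy_process_group() from torch.distributed, neither of which appears in this diff. A minimal sketch of what such a helper usually looks like, assuming the NCCL backend and an explicit rendezvous address; the actual implementation in diffusion.py may differ:

import os
from torch.distributed import init_process_group

def ddp_setup(rank, world_size, master_addr, master_port):
    # Point every process at the same rendezvous endpoint.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    # NCCL is the usual backend for multi-GPU training.
    init_process_group(backend="nccl", rank=rank, world_size=world_size)

With a helper like this in place, the commented-out prints around ddp_setup() and torch.cuda.set_device() were only startup diagnostics.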
load_h5.py
CHANGED
@@ -168,7 +168,7 @@ class Dataset4h5(Dataset):
         param_start = time()
         params = f['params']['values'][idx]
         param_end = time()
-        print(f"
+        print(f"{socket.gethostbyname(socket.gethostname())}, cuda:{torch.cuda.current_device()}, CPU-pid {cpu_num}-{pid}: images {images.shape} & params {params.shape} loaded after {images_end-images_start:.3f}s & {param_end-param_start:.3f}s")
 
         return images, params
 
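Note: the new log line assumes socket and torch are imported in load_h5.py and that cpu_num, pid, images_start, and images_end are defined earlier in the method. A hedged sketch of how the host and process identifiers might be obtained; the names and the psutil dependency here are illustrative, not taken from the repo:

import os
import socket
import psutil  # assumption: psutil is available for querying the current CPU core

pid = os.getpid()                                      # dataloader worker process id
cpu_num = psutil.Process(pid).cpu_num()                # CPU core the worker is on (Linux only)
host_ip = socket.gethostbyname(socket.gethostname())   # node IP shown at the start of the log line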
phoenix_diffusion.sbatch
CHANGED
@@ -3,6 +3,7 @@
 #SBATCH -A gts-jw254-coda20
 #SBATCH -qembers
 #SBATCH -N2 --gpus-per-node=RTX_6000:3   # -C A100-80GB # Number of nodes and cores per node required
+#SBATCH --ntasks-per-gpu=1
 #SBATCH --mem-per-gpu=32G                # Memory per core
 #SBATCH -t 10:00                         # Duration of the job (Ex: 15 mins)
 #SBATCH -oReport-%j                      # Combined output and error messages file
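Note: with -N2, --gpus-per-node=RTX_6000:3, and the new --ntasks-per-gpu=1, Slurm launches 6 tasks in total, one per GPU, so the DDP world size is 6 with 3 local ranks per node. A hedged sketch of how those numbers might be read from Slurm's environment on the Python side; the variable names are illustrative and master_addr/master_port are assumed to come from elsewhere (e.g. the first node in the allocation):

import os

rank = int(os.environ.get("SLURM_PROCID", 0))         # global rank of this task, 0..world_size-1
world_size = int(os.environ.get("SLURM_NTASKS", 1))   # 2 nodes x 3 GPUs x 1 task = 6
local_world_size = world_size // int(os.environ.get("SLURM_NNODES", 1))  # tasks (GPUs) per node = 3

# These values would then feed train(rank, world_size, local_world_size, master_addr, master_port).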