0728-1339
Browse files
- diffusion.py +6 -6
- load_h5.py +2 -2
diffusion.py
CHANGED
|
@@ -81,10 +81,10 @@ def ddp_setup(rank: int, world_size: int, local_world_size, master_addr, master_
|
|
| 81 |
world_size: Total number of processes
|
| 82 |
"""
|
| 83 |
|
| 84 |
-
print("inside ddp_setup")
|
| 85 |
os.environ["MASTER_ADDR"] = master_addr
|
| 86 |
os.environ["MASTER_PORT"] = master_port
|
| 87 |
-
print("ddp_setup, rank =", rank)
|
| 88 |
init_process_group(
|
| 89 |
backend="nccl",
|
| 90 |
init_method=f"tcp://{master_addr}:{master_port}",
|
|
@@ -574,7 +574,7 @@ class DDPM21CM:
|
|
| 574 |
'unet_state_dict': self.nn_model.module.state_dict(),
|
| 575 |
# 'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 576 |
}
|
| 577 |
-
save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-epoch{ep}"
|
| 578 |
torch.save(model_state, save_name)
|
| 579 |
print(f'cuda:{torch.cuda.current_device()} saved model at ' + save_name)
|
| 580 |
# print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
|
|
@@ -658,12 +658,12 @@ class DDPM21CM:
|
|
| 658 |
|
| 659 |
#num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
|
| 660 |
def train(rank, world_size, local_world_size, master_addr, master_port):
|
| 661 |
-
print("before ddp_setup")
|
| 662 |
ddp_setup(rank, world_size, local_world_size, master_addr, master_port)
|
| 663 |
-
print("after ddp_setup")
|
| 664 |
local_rank = rank % local_world_size
|
| 665 |
torch.cuda.set_device(local_rank)
|
| 666 |
-
print("after set device")
|
| 667 |
print(f"rank = {rank}, local_rank = {local_rank}, world_size = {world_size}, local_world_size = {local_world_size}")
|
| 668 |
|
| 669 |
config = TrainConfig()
|
|
|
|
| 81 |
world_size: Total number of processes
|
| 82 |
"""
|
| 83 |
|
| 84 |
+
#print("inside ddp_setup")
|
| 85 |
os.environ["MASTER_ADDR"] = master_addr
|
| 86 |
os.environ["MASTER_PORT"] = master_port
|
| 87 |
+
#print("ddp_setup, rank =", rank)
|
| 88 |
init_process_group(
|
| 89 |
backend="nccl",
|
| 90 |
init_method=f"tcp://{master_addr}:{master_port}",
|
|
|
|
| 574 |
'unet_state_dict': self.nn_model.module.state_dict(),
|
| 575 |
# 'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 576 |
}
|
| 577 |
+
save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-epoch{ep}-{socket.gethostbyname(socket.gethostname())}"
|
| 578 |
torch.save(model_state, save_name)
|
| 579 |
print(f'cuda:{torch.cuda.current_device()} saved model at ' + save_name)
|
| 580 |
# print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
|
|
|
|
| 658 |
|
| 659 |
#num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
|
| 660 |
def train(rank, world_size, local_world_size, master_addr, master_port):
|
| 661 |
+
#print("before ddp_setup")
|
| 662 |
ddp_setup(rank, world_size, local_world_size, master_addr, master_port)
|
| 663 |
+
#print("after ddp_setup")
|
| 664 |
local_rank = rank % local_world_size
|
| 665 |
torch.cuda.set_device(local_rank)
|
| 666 |
+
#print("after set device")
|
| 667 |
print(f"rank = {rank}, local_rank = {local_rank}, world_size = {world_size}, local_world_size = {local_world_size}")
|
| 668 |
|
| 669 |
config = TrainConfig()
|
load_h5.py
CHANGED
|
@@ -96,10 +96,10 @@ class Dataset4h5(Dataset):
|
|
| 96 |
print(f"dataset content: {f.keys()}")
|
| 97 |
max_num_image = len(f['brightness_temp'])#.shape[0]
|
| 98 |
field_shape = f['brightness_temp'].shape[1:]
|
| 99 |
-
print(f"{max_num_image} images of shape {field_shape} can be loaded")
|
| 100 |
#print(f"field.shape = {field_shape}")
|
| 101 |
self.params_keys = list(f['params']['keys'])
|
| 102 |
-
print(f"params keys = {self.params_keys}")
|
|
|
|
| 103 |
|
| 104 |
# if self.idx is None:
|
| 105 |
# if self.shuffle:
|
|
|
|
| 96 |
print(f"dataset content: {f.keys()}")
|
| 97 |
max_num_image = len(f['brightness_temp'])#.shape[0]
|
| 98 |
field_shape = f['brightness_temp'].shape[1:]
|
|
|
|
| 99 |
#print(f"field.shape = {field_shape}")
|
| 100 |
self.params_keys = list(f['params']['keys'])
|
| 101 |
+
print(f"{max_num_image} images of shape {field_shape} can be loaded with different params.keys {self.params_keys}")
|
| 102 |
+
#print(f"params keys = {self.params_keys}")
|
| 103 |
|
| 104 |
# if self.idx is None:
|
| 105 |
# if self.shuffle:
|