Commit 17184615
- context_unet.py +1 -1
- diffusion.py +18 -14
- phoenix_diffusion.sbatch +2 -3
context_unet.py
CHANGED

@@ -334,7 +334,7 @@ class ContextUnet(nn.Module):
         elif image_size == 128:
             channel_mult = (1, 1, 2, 3, 4)
         elif image_size == 64:
-            channel_mult = (1,2,2,
+            channel_mult = (1,2,2,2,4)#(1,1,2,2,4)#(1,1,1,2,2)#(0.5,1,1,2,2)#(1,1,2)#(1,2)#(1,1,2,2)#(1,1,2,2,4)#(2,2,4,4,4)#(1, 2, 4)#(2,4,4,4,8)#(1, 2, 2, 4, 4)#(1, 2, 2, 4, 8)#(1, 1, 2, 2, 4, 4)#(1, 2, 4, 8, 16)#(1, 2, 3, 4)#(1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
         elif image_size == 32:
             channel_mult = (1, 2, 2, 4)
         elif image_size == 28:
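The changed line picks the UNet channel multipliers for 64-pixel inputs; in the usual UNet convention the tuple length also sets the number of resolution levels. A minimal sketch of the pattern, outside this repo (base_channels and channels_per_level are illustrative names, not code from context_unet.py):

    # Sketch: resolution-dependent channel multipliers for a UNet.
    # base_channels and channels_per_level are hypothetical names.
    CHANNEL_MULT = {
        128: (1, 1, 2, 3, 4),
        64:  (1, 2, 2, 2, 4),   # the value this commit settles on
        32:  (1, 2, 2, 4),
    }

    def channels_per_level(image_size, base_channels=64):
        # Expand multipliers into absolute channel counts per level.
        return [int(base_channels * m) for m in CHANNEL_MULT[image_size]]

    print(channels_per_level(64))  # [64, 128, 128, 128, 256]

With five entries, a 64-pixel path typically downsamples four times (64 -> 32 -> 16 -> 8 -> 4), the first level keeping full resolution.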
diffusion.py
CHANGED

@@ -271,7 +271,7 @@ class TrainConfig:
     stride = (2,4) if dim == 2 else (2,2,2)
     num_image = 2000#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
     batch_size = 1#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
-    n_epoch =
+    n_epoch = 20#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
     HII_DIM = 64
     num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
     channel = 1
@@ -304,7 +304,7 @@ class TrainConfig:
     lrate = 1e-4
     lr_warmup_steps = 0#5#00
     output_dir = "./outputs/"
-    save_name = os.path.join(output_dir, '
+    save_name = os.path.join(output_dir, 'model')
     # save_period = 1 #10 # the period of saving model
     # cond = True # if training using the conditional information
     # lr_decay = False #True# if using the learning rate decay
@@ -378,7 +378,7 @@ class DDPM21CM:
         # )
         # config = TrainConfig()
         # date = datetime.datetime.now().strftime("%m%d-%H%M")
-        config.run_name = datetime.datetime.now().strftime("%
+        config.run_name = datetime.datetime.now().strftime("%d%H%M%S") # the unique name of each experiment
         self.config = config
         # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)
         # # self.shape_loaded = dataset.images.shape
@@ -537,7 +537,7 @@ class DDPM21CM:
                 loss = F.mse_loss(noise, noise_pred)
                 #print(f"loss.dtype =", loss.dtype)
                 self.accelerator.backward(loss)
-                self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)
+                #self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)
                 self.optimizer.step()
                 self.lr_scheduler.step()
                 self.optimizer.zero_grad()
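This hunk turns gradient clipping off by commenting the call out. For reference, Accelerate's clip_grad_norm_ must sit between backward() and optimizer.step(); a minimal sketch of the step with clipping behind a flag (model, optimizer, and the flag name are placeholders, not this repo's):

    from accelerate import Accelerator
    import torch.nn.functional as F

    accelerator = Accelerator()
    clip_gradients = False  # equivalent to what this commit does

    def training_step(model, optimizer, lr_scheduler, x_noisy, t, noise):
        noise_pred = model(x_noisy, t)        # predict the injected noise
        loss = F.mse_loss(noise, noise_pred)
        accelerator.backward(loss)            # handles AMP gradient scaling
        if clip_gradients:
            # unscales gradients first, then clips; only valid after backward()
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        return loss.detach()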
@@ -593,7 +593,7 @@ class DDPM21CM:
                 'unet_state_dict': self.nn_model.module.state_dict(),
                 # 'ema_unet_state_dict': self.ema_model.state_dict(),
             }
-            save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-node{int(os.environ['SLURM_NNODES'])}-epoch{ep}-{
+            save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-node{int(os.environ['SLURM_NNODES'])}-epoch{ep}-{self.config.run_name}"
             torch.save(model_state, save_name)
             print(f'cuda:{torch.cuda.current_device()}/{self.config.global_rank} saved model at ' + save_name)
             # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
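The checkpoint stores self.nn_model.module.state_dict() rather than the wrapper's: under DistributedDataParallel the real network sits in .module, and unwrapping keeps the saved keys free of a "module." prefix so the file loads into a bare model later. A small self-contained sketch (unwrap is a hypothetical helper, not from this repo):

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def unwrap(model):
        # DDP keeps the wrapped network in .module; pass bare models through.
        return model.module if isinstance(model, DDP) else model

    net = nn.Linear(4, 4)
    torch.save({'unet_state_dict': unwrap(net).state_dict()}, 'ckpt.pt')

    state = torch.load('ckpt.pt')['unet_state_dict']
    nn.Linear(4, 4).load_state_dict(state)  # keys are 'weight'/'bias', no prefix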
@@ -676,13 +676,15 @@ class DDPM21CM:
 
         if save:
             # np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
-            savetime = datetime.datetime.now().strftime("%
-            savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}.npy")
+            savetime = datetime.datetime.now().strftime("%d%H%M%S")
+            savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]:.3f}-zeta{params_backup[1]:.3f}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}.npy")
+            if not os.path.exists(self.config.output_dir):
+                os.makedirs(self.config.output_dir)
             np.save(savename, x_last)
             print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} saved images of shape {x_last.shape} to {savename}")
 
         if entire:
-            savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}_entire.npy")
+            savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]:.3f}-zeta{params_backup[1]:.3f}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}_entire.npy")
             np.save(savename, x_entire)
             print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} saved images of shape {x_entire.shape} to {savename}")
         # else:
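The added existence check keeps np.save from failing with FileNotFoundError when ./outputs/ is missing. One caveat worth noting: check-then-create can race when several ranks reach it at once, and os.makedirs(..., exist_ok=True) collapses both steps into a single idempotent call. A sketch of that alternative:

    import os
    import numpy as np

    output_dir = "./outputs/"  # same default as TrainConfig.output_dir

    # Equivalent to the commit's check-then-create, but safe even if another
    # process creates the directory between the check and the call.
    os.makedirs(output_dir, exist_ok=True)

    np.save(os.path.join(output_dir, "sample.npy"), np.zeros((2, 2)))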
@@ -690,13 +692,13 @@
 # %%
 
 #num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
-def train(rank, world_size, local_world_size, master_addr, master_port):
+def train(rank, world_size, local_world_size, master_addr, master_port, config):
     global_rank = rank + local_world_size * int(os.environ["SLURM_NODEID"])
     ddp_setup(global_rank, world_size, master_addr, master_port)
     torch.cuda.set_device(rank)
     #print(f"rank = {rank}, global_rank = {global_rank}, world_size = {world_size}, local_world_size = {local_world_size}")
 
-    config = TrainConfig()
+    #config = TrainConfig()
     config.device = f"cuda:{rank}"
     config.world_size = local_world_size
     config.global_rank = global_rank
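The new signature threads one TrainConfig built in the launcher into every spawned worker, so overrides set once on the command line (such as gradient_accumulation_steps, moved below) reach all ranks instead of each process rebuilding a fresh config. A minimal standalone sketch of the pattern (Config and worker are made-up stand-ins):

    import torch.multiprocessing as mp
    from dataclasses import dataclass

    @dataclass
    class Config:                 # stand-in for TrainConfig
        gradient_accumulation_steps: int = 1
        device: str = "cpu"

    def worker(rank, world_size, config):
        # mp.spawn passes the process rank first; args follow in order.
        config.device = f"cuda:{rank}"   # per-rank fields filled in here
        print(rank, world_size, config)

    if __name__ == "__main__":
        cfg = Config(gradient_accumulation_steps=40)  # built once in the launcher
        mp.spawn(worker, args=(2, cfg), nprocs=2, join=True)

Note the config is pickled into each child, so workers get independent copies; mutations made inside worker do not propagate back to the parent process.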
@@ -753,12 +755,14 @@ if __name__ == "__main__":
     total_nodes = int(os.environ["SLURM_NNODES"])
     world_size = local_world_size * total_nodes #6#int(os.environ["SLURM_NTASKS"])
 
+    config = TrainConfig()
+    config.gradient_accumulation_steps = args.gradient_accumulation_steps
     ############################ training ################################
     if args.train == 1:
         print(f" training, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size} ".center(120,'-'))
         mp.spawn(
             train,
-            args=(world_size, local_world_size, master_addr, master_port),
+            args=(world_size, local_world_size, master_addr, master_port, config),
             nprocs=local_world_size,
             join=True,
         )
@@ -766,11 +770,11 @@ if __name__ == "__main__":
     if args.train == 0:
         num_new_img_per_gpu = args.num_new_img_per_gpu#200#4#200
         max_num_img_per_gpu = args.max_num_img_per_gpu#40#2#20
-        config = TrainConfig()
+        #config = TrainConfig()
         #config.world_size = world_size
-        config.dtype = torch.float32
+        #config.dtype = torch.float32
         config.resume = args.resume
-        config.gradient_accumulation_steps = args.gradient_accumulation_steps
+        #config.gradient_accumulation_steps = args.gradient_accumulation_steps
         # config.resume = f"./outputs/model_state-N30-device_count3-epoch4-172.27.149.181"
         # config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
         # config.resume = f"./outputs/model_state-N{config.num_image}-device_count1-epoch{config.n_epoch-1}"
phoenix_diffusion.sbatch
CHANGED

@@ -7,7 +7,6 @@
 #SBATCH --mem-per-gpu=16G             # Memory per core
 #SBATCH -t 08:00:00                   # Duration of the job (Ex: 15 mins)
 #SBATCH -oReport-%j                   # Combined output and error messages file
-#SBATCH --error=error-%j
 #SBATCH --mail-type=BEGIN,END,FAIL    # Mail preferences
 
 #module load gcc/10.3.0-o57x6h

@@ -31,9 +30,9 @@ export MASTER_PORT=$MASTER_PORT
 
 srun python diffusion.py \
     --train 1 \
-    --resume outputs/model_state-
+    --resume outputs/model_state-N2000-device_count1-node8-epoch19-172.27.144.173 \
     --num_new_img_per_gpu 50 \
     --max_num_img_per_gpu 2 \
-    --gradient_accumulation_steps
+    --gradient_accumulation_steps 40 \
 ######################################################################################
 
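With --gradient_accumulation_steps 40 and the per-GPU batch_size = 1 from TrainConfig, each optimizer update effectively averages gradients over batch_size x accumulation_steps x world_size samples. A worked example, assuming the 8-node, one-GPU-per-node layout suggested by the node8 / device_count1 parts of the resume name:

    # Effective batch size under gradient accumulation (illustrative numbers).
    batch_size = 1                    # per-GPU micro-batch (TrainConfig)
    gradient_accumulation_steps = 40  # new sbatch flag
    world_size = 8                    # assumed: 8 nodes x 1 GPU each

    effective_batch = batch_size * gradient_accumulation_steps * world_size
    print(effective_batch)            # 320 samples per optimizer step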