20051216
Browse files- context_unet.py +1 -1
- diffusion.py +15 -6
- load_h5.py +1 -1
- perlmutter_diffusion.sbatch +13 -19
- quantify_results.ipynb +2 -2
context_unet.py
CHANGED
|
@@ -179,7 +179,7 @@ class ResBlock(TimestepBlock):
|
|
| 179 |
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
| 180 |
),
|
| 181 |
)
|
| 182 |
-
|
| 183 |
self.out_layers = nn.Sequential(
|
| 184 |
# nn.BatchNorm2d(self.out_channels),
|
| 185 |
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
|
|
|
|
| 179 |
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
| 180 |
),
|
| 181 |
)
|
| 182 |
+
print(f"resnet: dropout = {dropout}")
|
| 183 |
self.out_layers = nn.Sequential(
|
| 184 |
# nn.BatchNorm2d(self.out_channels),
|
| 185 |
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
|
diffusion.py
CHANGED
|
@@ -51,6 +51,8 @@ from time import time
|
|
| 51 |
from torch.cuda.amp import autocast, GradScaler
|
| 52 |
from random import getrandbits
|
| 53 |
|
|
|
|
|
|
|
| 54 |
# %%
|
| 55 |
def ddp_setup(rank: int, world_size: int, master_addr, master_port):
|
| 56 |
"""
|
|
@@ -268,7 +270,8 @@ class TrainConfig:
|
|
| 268 |
# n_sample = 24 # 64, the number of samples in sampling process
|
| 269 |
n_param = 2
|
| 270 |
guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance
|
| 271 |
-
|
|
|
|
| 272 |
ema=False # whether to use ema
|
| 273 |
ema_rate=0.995
|
| 274 |
|
|
@@ -365,7 +368,7 @@ class DDPM21CM:
|
|
| 365 |
self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, config=config,)#, dtype=config.dtype
|
| 366 |
|
| 367 |
# initialize the unet
|
| 368 |
-
self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, channel_mult=config.channel_mult, use_checkpoint=config.use_checkpoint)#, dtype=config.dtype)
|
| 369 |
|
| 370 |
self.nn_model.train()
|
| 371 |
self.nn_model.to(self.ddpm.device)
|
|
@@ -386,7 +389,7 @@ class DDPM21CM:
|
|
| 386 |
if config.ema:
|
| 387 |
self.ema = EMA(config.ema_rate)
|
| 388 |
if config.resume and os.path.exists(config.resume):
|
| 389 |
-
self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)#, dtype=config.dtype
|
| 390 |
self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])
|
| 391 |
print(f"resumed ema_model from {config.resume}")
|
| 392 |
else:
|
|
@@ -409,7 +412,7 @@ class DDPM21CM:
|
|
| 409 |
HII_DIM=self.config.HII_DIM,
|
| 410 |
num_redshift=self.config.num_redshift,
|
| 411 |
startat=self.config.startat,
|
| 412 |
-
drop_prob=self.config.drop_prob,
|
| 413 |
dim=self.config.dim,
|
| 414 |
ranges_dict=self.ranges_dict,
|
| 415 |
num_workers=min(1,len(os.sched_getaffinity(0))//self.config.world_size),
|
|
@@ -505,6 +508,10 @@ class DDPM21CM:
|
|
| 505 |
c = c.to(self.config.device)
|
| 506 |
noise_pred = self.nn_model(xt, ts, c)#.to(x.dtype)
|
| 507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
loss = F.mse_loss(noise, noise_pred)
|
| 509 |
loss = loss / self.config.gradient_accumulation_steps
|
| 510 |
|
|
@@ -610,7 +617,7 @@ class DDPM21CM:
|
|
| 610 |
params_backup = params.numpy().copy()
|
| 611 |
params_normalized = self.rescale(params, self.ranges_dict['params'], to=[0,1])
|
| 612 |
|
| 613 |
-
print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} sampling {num_new_img_per_gpu} images with normalized params = {params_normalized}")
|
| 614 |
params_normalized = params_normalized.repeat(num_new_img_per_gpu,1)
|
| 615 |
assert params_normalized.dim() == 2, "params_normalized must be a 2D torch.tensor"
|
| 616 |
# print("params =", params)
|
|
@@ -705,6 +712,7 @@ if __name__ == "__main__":
|
|
| 705 |
parser.add_argument("--channel_mult", type=float, nargs="+", required=False, default=(1,2,2,2,4))
|
| 706 |
parser.add_argument("--autocast", type=int, required=False, default=False)
|
| 707 |
parser.add_argument("--use_checkpoint", type=int, required=False, default=False)
|
|
|
|
| 708 |
|
| 709 |
args = parser.parse_args()
|
| 710 |
|
|
@@ -722,6 +730,7 @@ if __name__ == "__main__":
|
|
| 722 |
config.channel_mult = args.channel_mult
|
| 723 |
config.autocast = bool(args.autocast)
|
| 724 |
config.use_checkpoint = bool(args.use_checkpoint)
|
|
|
|
| 725 |
|
| 726 |
############################ training ################################
|
| 727 |
if args.train:
|
|
@@ -756,7 +765,7 @@ if __name__ == "__main__":
|
|
| 756 |
]
|
| 757 |
|
| 758 |
for params in params_pairs:
|
| 759 |
-
print(f"sampling
|
| 760 |
mp.spawn(
|
| 761 |
generate_samples,
|
| 762 |
args=(world_size, local_world_size, master_addr, master_port, config, num_new_img_per_gpu, max_num_img_per_gpu, torch.tensor(params)),
|
|
|
|
| 51 |
from torch.cuda.amp import autocast, GradScaler
|
| 52 |
from random import getrandbits
|
| 53 |
|
| 54 |
+
import subprocess
|
| 55 |
+
|
| 56 |
# %%
|
| 57 |
def ddp_setup(rank: int, world_size: int, master_addr, master_port):
|
| 58 |
"""
|
|
|
|
| 270 |
# n_sample = 24 # 64, the number of samples in sampling process
|
| 271 |
n_param = 2
|
| 272 |
guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance
|
| 273 |
+
dropout = 0
|
| 274 |
+
#drop_prob = 0.1 #0.28 # only takes effect when guide_w != -1
|
| 275 |
ema=False # whether to use ema
|
| 276 |
ema_rate=0.995
|
| 277 |
|
|
|
|
| 368 |
self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, config=config,)#, dtype=config.dtype
|
| 369 |
|
| 370 |
# initialize the unet
|
| 371 |
+
self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, channel_mult=config.channel_mult, use_checkpoint=config.use_checkpoint, dropout=config.dropout)#, dtype=config.dtype)
|
| 372 |
|
| 373 |
self.nn_model.train()
|
| 374 |
self.nn_model.to(self.ddpm.device)
|
|
|
|
| 389 |
if config.ema:
|
| 390 |
self.ema = EMA(config.ema_rate)
|
| 391 |
if config.resume and os.path.exists(config.resume):
|
| 392 |
+
self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device, dropout=config.dropout)#, dtype=config.dtype
|
| 393 |
self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])
|
| 394 |
print(f"resumed ema_model from {config.resume}")
|
| 395 |
else:
|
|
|
|
| 412 |
HII_DIM=self.config.HII_DIM,
|
| 413 |
num_redshift=self.config.num_redshift,
|
| 414 |
startat=self.config.startat,
|
| 415 |
+
#drop_prob=self.config.drop_prob,
|
| 416 |
dim=self.config.dim,
|
| 417 |
ranges_dict=self.ranges_dict,
|
| 418 |
num_workers=min(1,len(os.sched_getaffinity(0))//self.config.world_size),
|
|
|
|
| 508 |
c = c.to(self.config.device)
|
| 509 |
noise_pred = self.nn_model(xt, ts, c)#.to(x.dtype)
|
| 510 |
|
| 511 |
+
#if ep == 0 and i == 0 and self.config.global_rank == 0:
|
| 512 |
+
# result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
| 513 |
+
# print(result.stdout, flush=True)
|
| 514 |
+
|
| 515 |
loss = F.mse_loss(noise, noise_pred)
|
| 516 |
loss = loss / self.config.gradient_accumulation_steps
|
| 517 |
|
|
|
|
| 617 |
params_backup = params.numpy().copy()
|
| 618 |
params_normalized = self.rescale(params, self.ranges_dict['params'], to=[0,1])
|
| 619 |
|
| 620 |
+
print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} sampling {num_new_img_per_gpu} images with normalized params = {params_normalized}, {datetime.now().strftime('%d-%H:%M:%S.%f')}")
|
| 621 |
params_normalized = params_normalized.repeat(num_new_img_per_gpu,1)
|
| 622 |
assert params_normalized.dim() == 2, "params_normalized must be a 2D torch.tensor"
|
| 623 |
# print("params =", params)
|
|
|
|
| 712 |
parser.add_argument("--channel_mult", type=float, nargs="+", required=False, default=(1,2,2,2,4))
|
| 713 |
parser.add_argument("--autocast", type=int, required=False, default=False)
|
| 714 |
parser.add_argument("--use_checkpoint", type=int, required=False, default=False)
|
| 715 |
+
parser.add_argument("--dropout", type=float, required=False, default=0)
|
| 716 |
|
| 717 |
args = parser.parse_args()
|
| 718 |
|
|
|
|
| 730 |
config.channel_mult = args.channel_mult
|
| 731 |
config.autocast = bool(args.autocast)
|
| 732 |
config.use_checkpoint = bool(args.use_checkpoint)
|
| 733 |
+
config.dropout = args.dropout
|
| 734 |
|
| 735 |
############################ training ################################
|
| 736 |
if args.train:
|
|
|
|
| 765 |
]
|
| 766 |
|
| 767 |
for params in params_pairs:
|
| 768 |
+
print(f"sampling, {params}, ip = {socket.gethostbyname(socket.gethostname())}, local_world_size = {local_world_size}, world_size = {world_size}, {datetime.now().strftime('%d-%H:%M:%S.%f')}".center(config.str_len,'#'))
|
| 769 |
mp.spawn(
|
| 770 |
generate_samples,
|
| 771 |
args=(world_size, local_world_size, master_addr, master_port, config, num_new_img_per_gpu, max_num_img_per_gpu, torch.tensor(params)),
|
load_h5.py
CHANGED
|
@@ -91,7 +91,7 @@ class Dataset4h5(Dataset):
|
|
| 91 |
field_shape = f['brightness_temp'].shape[1:]
|
| 92 |
#print(f"field.shape = {field_shape}")
|
| 93 |
self.params_keys = list(f['params']['keys'])
|
| 94 |
-
print(f"{max_num_image} images of shape {field_shape} can be loaded with
|
| 95 |
#print(f"params keys = {self.params_keys}")
|
| 96 |
|
| 97 |
if self.idx == "random":
|
|
|
|
| 91 |
field_shape = f['brightness_temp'].shape[1:]
|
| 92 |
#print(f"field.shape = {field_shape}")
|
| 93 |
self.params_keys = list(f['params']['keys'])
|
| 94 |
+
print(f"{max_num_image} {f['brightness_temp'].dtype} images of shape {field_shape} can be loaded with params.keys {self.params_keys}")
|
| 95 |
#print(f"params keys = {self.params_keys}")
|
| 96 |
|
| 97 |
if self.idx == "random":
|
perlmutter_diffusion.sbatch
CHANGED
|
@@ -2,10 +2,10 @@
|
|
| 2 |
#SBATCH -A m4717
|
| 3 |
#SBATCH -J diffusion
|
| 4 |
#SBATCH -C gpu&hbm80g
|
| 5 |
-
#SBATCH -q
|
| 6 |
-
#SBATCH -
|
| 7 |
-
#SBATCH --gpus-per-node=
|
| 8 |
-
#SBATCH -t
|
| 9 |
#SBATCH --ntasks-per-node=1
|
| 10 |
#SBATCH -oReport-%j
|
| 11 |
#SBATCH --mail-type=BEGIN,END,FAIL
|
|
@@ -25,27 +25,21 @@ MASTER_PORT=$((10000 + RANDOM % 10000)) #12355
|
|
| 25 |
#export OMP_NUM_THREADS=1
|
| 26 |
export MASTER_ADDR=$MASTER_ADDR
|
| 27 |
export MASTER_PORT=$MASTER_PORT
|
| 28 |
-
|
| 29 |
-
#echo $MASTER_ADDR
|
| 30 |
-
#echo $MASTER_PORT
|
| 31 |
-
#nc -zv $MASTER_ADDR $MASTER_PORT
|
| 32 |
-
|
| 33 |
-
#export NCCL_DEBUG=INFO
|
| 34 |
-
#export NCCL_DEBUG_SUBSYS=ALL
|
| 35 |
cat $0
|
| 36 |
-
#nvidia-smi
|
| 37 |
|
| 38 |
srun python diffusion.py \
|
| 39 |
-
--num_image
|
| 40 |
--batch_size 2 \
|
| 41 |
-
--n_epoch
|
| 42 |
-
--channel_mult
|
| 43 |
-
--num_new_img_per_gpu
|
| 44 |
-
--max_num_img_per_gpu
|
| 45 |
--gradient_accumulation_steps 1 \
|
| 46 |
--autocast 1 \
|
| 47 |
--use_checkpoint 1 \
|
| 48 |
-
--
|
| 49 |
-
|
|
|
|
| 50 |
|
| 51 |
date
|
|
|
|
| 2 |
#SBATCH -A m4717
|
| 3 |
#SBATCH -J diffusion
|
| 4 |
#SBATCH -C gpu&hbm80g
|
| 5 |
+
#SBATCH -q regular #shared
|
| 6 |
+
#SBATCH -N4
|
| 7 |
+
#SBATCH --gpus-per-node=4
|
| 8 |
+
#SBATCH -t 48:00:00
|
| 9 |
#SBATCH --ntasks-per-node=1
|
| 10 |
#SBATCH -oReport-%j
|
| 11 |
#SBATCH --mail-type=BEGIN,END,FAIL
|
|
|
|
| 25 |
#export OMP_NUM_THREADS=1
|
| 26 |
export MASTER_ADDR=$MASTER_ADDR
|
| 27 |
export MASTER_PORT=$MASTER_PORT
|
| 28 |
+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
cat $0
|
|
|
|
| 30 |
|
| 31 |
srun python diffusion.py \
|
| 32 |
+
--num_image 1600 \
|
| 33 |
--batch_size 2 \
|
| 34 |
+
--n_epoch 40 \
|
| 35 |
+
--channel_mult 1 1 2 2 4 4 \
|
| 36 |
+
--num_new_img_per_gpu 4 \
|
| 37 |
+
--max_num_img_per_gpu 2 \
|
| 38 |
--gradient_accumulation_steps 1 \
|
| 39 |
--autocast 1 \
|
| 40 |
--use_checkpoint 1 \
|
| 41 |
+
--dropout 0.1 \
|
| 42 |
+
--train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
|
| 43 |
+
#--resume ./outputs/model-N1280-device_count4-node5-epoch24-13133235 \
|
| 44 |
|
| 45 |
date
|
quantify_results.ipynb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a974cc33812b8a680d4704e46268d54a185da33abeab18704d8100f69369b692
|
| 3 |
+
size 16920502
|