0721-1845
Browse files
- diffusion.py +25 -15
- load_h5.py +35 -11
- quantify_results.ipynb +0 -0
diffusion.py
CHANGED
|
@@ -241,9 +241,9 @@ class TrainConfig:
|
|
| 241 |
stride = (2,2) if dim == 2 else (2,2,2)
|
| 242 |
num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 243 |
batch_size = 50#1#2#50#20#2#100 # 10
|
| 244 |
-
n_epoch = 100#30#120#5#4# 10#50#20#20#2#5#25 # 120
|
| 245 |
HII_DIM = 64
|
| 246 |
-
num_redshift = 64#512#128#64#512#256#256#64#512#128
|
| 247 |
channel = 1
|
| 248 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 249 |
|
|
@@ -366,18 +366,28 @@ class DDPM21CM:
|
|
| 366 |
self.ranges_dict = config.ranges_dict
|
| 367 |
|
| 368 |
def load(self):
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
# self.shape_loaded = dataset.images.shape
|
| 371 |
# print("shape_loaded =", self.shape_loaded)
|
| 372 |
# print(f"load, current_device() = {torch.cuda.current_device()}")
|
| 373 |
self.dataloader = DataLoader(
|
| 374 |
dataset=dataset,
|
| 375 |
batch_size=self.config.batch_size,
|
| 376 |
-
shuffle=False,
|
| 377 |
-
num_workers=
|
| 378 |
pin_memory=True,
|
| 379 |
persistent_workers=True,
|
| 380 |
-
sampler=DistributedSampler(dataset),
|
| 381 |
)
|
| 382 |
|
| 383 |
del dataset
|
|
@@ -414,14 +424,14 @@ class DDPM21CM:
|
|
| 414 |
|
| 415 |
|
| 416 |
# print("!!!!!!!!!!!!!!!!, before prepare, self.dataloader.sampler =", self.dataloader.sampler)
|
| 417 |
-
|
| 418 |
-
# self.accelerator.prepare(
|
| 419 |
-
# self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler
|
| 420 |
-
# )
|
| 421 |
-
self.nn_model, self.optimizer, self.lr_scheduler = \
|
| 422 |
self.accelerator.prepare(
|
| 423 |
-
self.nn_model, self.optimizer, self.lr_scheduler
|
| 424 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
|
| 426 |
# print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.sampler =", self.dataloader.sampler)
|
| 427 |
# print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.batch_sampler =", self.dataloader.batch_sampler)
|
|
@@ -430,7 +440,7 @@ class DDPM21CM:
|
|
| 430 |
global_step = 0
|
| 431 |
for ep in range(self.config.n_epoch):
|
| 432 |
self.ddpm.train()
|
| 433 |
-
self.dataloader.sampler.set_epoch(ep)
|
| 434 |
|
| 435 |
pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
|
| 436 |
pbar_train.set_description(f"device {torch.cuda.current_device()}, Epoch {ep}")
|
|
@@ -527,7 +537,7 @@ class DDPM21CM:
|
|
| 527 |
# n_sample = params.shape[0]
|
| 528 |
# file = self.config.resume
|
| 529 |
|
| 530 |
-
print(f"device {torch.cuda.current_device()}, sample, params = {params}")
|
| 531 |
if params is None:
|
| 532 |
params = torch.tensor([4.4, 131.341])
|
| 533 |
# params_backup = params.numpy().copy()
|
|
@@ -583,7 +593,7 @@ class DDPM21CM:
|
|
| 583 |
return x_last
|
| 584 |
# %%
|
| 585 |
|
| 586 |
-
num_train_image_list = [8000]
|
| 587 |
|
| 588 |
def train(rank, world_size):
|
| 589 |
config = TrainConfig()
|
|
|
|
| 241 |
stride = (2,2) if dim == 2 else (2,2,2)
|
| 242 |
num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 243 |
batch_size = 50#1#2#50#20#2#100 # 10
|
| 244 |
+
n_epoch = 50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
|
| 245 |
HII_DIM = 64
|
| 246 |
+
num_redshift = 64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
|
| 247 |
channel = 1
|
| 248 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 249 |
|
|
|
|
| 366 |
self.ranges_dict = config.ranges_dict
|
| 367 |
|
| 368 |
def load(self):
|
| 369 |
+
# rank = torch.cuda.current_device()
|
| 370 |
+
dataset = Dataset4h5(
|
| 371 |
+
self.config.dataset_name,
|
| 372 |
+
num_image=self.config.num_image,
|
| 373 |
+
idx = 'random',
|
| 374 |
+
HII_DIM=self.config.HII_DIM,
|
| 375 |
+
num_redshift=self.config.num_redshift,
|
| 376 |
+
drop_prob=self.config.drop_prob,
|
| 377 |
+
dim=self.config.dim,
|
| 378 |
+
ranges_dict=self.ranges_dict
|
| 379 |
+
)
|
| 380 |
# self.shape_loaded = dataset.images.shape
|
| 381 |
# print("shape_loaded =", self.shape_loaded)
|
| 382 |
# print(f"load, current_device() = {torch.cuda.current_device()}")
|
| 383 |
self.dataloader = DataLoader(
|
| 384 |
dataset=dataset,
|
| 385 |
batch_size=self.config.batch_size,
|
| 386 |
+
shuffle=True,#False,
|
| 387 |
+
num_workers=len(os.sched_getaffinity(0)),
|
| 388 |
pin_memory=True,
|
| 389 |
persistent_workers=True,
|
| 390 |
+
# sampler=DistributedSampler(dataset),
|
| 391 |
)
|
| 392 |
|
| 393 |
del dataset
|
|
|
|
| 424 |
|
| 425 |
|
| 426 |
# print("!!!!!!!!!!!!!!!!, before prepare, self.dataloader.sampler =", self.dataloader.sampler)
|
| 427 |
+
self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
self.accelerator.prepare(
|
| 429 |
+
self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler
|
| 430 |
)
|
| 431 |
+
# self.nn_model, self.optimizer, self.lr_scheduler = \
|
| 432 |
+
# self.accelerator.prepare(
|
| 433 |
+
# self.nn_model, self.optimizer, self.lr_scheduler
|
| 434 |
+
# )
|
| 435 |
|
| 436 |
# print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.sampler =", self.dataloader.sampler)
|
| 437 |
# print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.batch_sampler =", self.dataloader.batch_sampler)
|
|
|
|
| 440 |
global_step = 0
|
| 441 |
for ep in range(self.config.n_epoch):
|
| 442 |
self.ddpm.train()
|
| 443 |
+
# self.dataloader.sampler.set_epoch(ep)
|
| 444 |
|
| 445 |
pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
|
| 446 |
pbar_train.set_description(f"device {torch.cuda.current_device()}, Epoch {ep}")
|
|
|
|
| 537 |
# n_sample = params.shape[0]
|
| 538 |
# file = self.config.resume
|
| 539 |
|
| 540 |
+
# print(f"device {torch.cuda.current_device()}, sample, params = {params}")
|
| 541 |
if params is None:
|
| 542 |
params = torch.tensor([4.4, 131.341])
|
| 543 |
# params_backup = params.numpy().copy()
|
|
|
|
| 593 |
return x_last
|
| 594 |
# %%
|
| 595 |
|
| 596 |
+
num_train_image_list = [200]#[8000]
|
| 597 |
|
| 598 |
def train(rank, world_size):
|
| 599 |
config = TrainConfig()
|
load_h5.py
CHANGED
|
@@ -26,14 +26,28 @@ import datetime
|
|
| 26 |
# from huggingface_hub import create_repo, upload_folder
|
| 27 |
|
| 28 |
class Dataset4h5(Dataset):
|
| 29 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
super().__init__()
|
| 31 |
|
| 32 |
self.dir_name = dir_name
|
| 33 |
self.num_image = num_image
|
| 34 |
-
self.field = field
|
| 35 |
-
self.shuffle = shuffle
|
| 36 |
self.idx = idx
|
|
|
|
|
|
|
| 37 |
self.num_redshift = num_redshift
|
| 38 |
self.HII_DIM = HII_DIM
|
| 39 |
self.drop_prob = drop_prob
|
|
@@ -81,14 +95,24 @@ class Dataset4h5(Dataset):
|
|
| 81 |
self.params_keys = list(f['params']['keys'])
|
| 82 |
print(f"params keys = {self.params_keys}")
|
| 83 |
|
| 84 |
-
if self.idx is None:
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
else:
|
| 93 |
print(f"loading {len(self.idx)} images with idx = {self.idx}")
|
| 94 |
|
|
|
|
| 26 |
# from huggingface_hub import create_repo, upload_folder
|
| 27 |
|
| 28 |
class Dataset4h5(Dataset):
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
dir_name,
|
| 32 |
+
num_image=10,
|
| 33 |
+
field='brightness_temp',
|
| 34 |
+
idx=None,
|
| 35 |
+
num_redshift=512,
|
| 36 |
+
HII_DIM=64,
|
| 37 |
+
rescale=True,
|
| 38 |
+
drop_prob = 0,
|
| 39 |
+
dim=2,
|
| 40 |
+
transform=True,
|
| 41 |
+
ranges_dict=None,
|
| 42 |
+
# shuffle=False,
|
| 43 |
+
):
|
| 44 |
super().__init__()
|
| 45 |
|
| 46 |
self.dir_name = dir_name
|
| 47 |
self.num_image = num_image
|
|
|
|
|
|
|
| 48 |
self.idx = idx
|
| 49 |
+
self.field = field
|
| 50 |
+
# self.shuffle = shuffle
|
| 51 |
self.num_redshift = num_redshift
|
| 52 |
self.HII_DIM = HII_DIM
|
| 53 |
self.drop_prob = drop_prob
|
|
|
|
| 95 |
self.params_keys = list(f['params']['keys'])
|
| 96 |
print(f"params keys = {self.params_keys}")
|
| 97 |
|
| 98 |
+
# if self.idx is None:
|
| 99 |
+
# if self.shuffle:
|
| 100 |
+
# self.idx = np.sort(random.sample(range(max_num_image), self.num_image))
|
| 101 |
+
# print(f"loading {self.num_image} images randomly")
|
| 102 |
+
# # print(self.idx)
|
| 103 |
+
# else:
|
| 104 |
+
# self.idx = range(self.num_image)
|
| 105 |
+
# print(f"loading {len(self.idx)} images with idx = {self.idx}")
|
| 106 |
+
if self.idx == "random":
|
| 107 |
+
self.idx = np.sort(random.sample(range(max_num_image), self.num_image))
|
| 108 |
+
print(f"loading {self.num_image} images randomly with idx = {self.idx}")
|
| 109 |
+
# print(self.idx)
|
| 110 |
+
elif self.idx == "range":
|
| 111 |
+
rank = torch.cuda.current_device()
|
| 112 |
+
self.idx = range(
|
| 113 |
+
rank*self.num_image, (rank+1)*self.num_image
|
| 114 |
+
)
|
| 115 |
+
print(f"loading {len(self.idx)} images with idx = {self.idx}")
|
| 116 |
else:
|
| 117 |
print(f"loading {len(self.idx)} images with idx = {self.idx}")
|
| 118 |
|
quantify_results.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|