0716-2131
Browse files
- context_unet.py +1 -1
- diffusion.py +82 -55
- quantify_results.ipynb +26 -7
context_unet.py
CHANGED
|
@@ -330,7 +330,7 @@ class ContextUnet(nn.Module):
|
|
| 330 |
elif image_size == 128:
|
| 331 |
channel_mult = (1, 1, 2, 3, 4)
|
| 332 |
elif image_size == 64:
|
| 333 |
-
channel_mult = (1, 2,
|
| 334 |
elif image_size == 32:
|
| 335 |
channel_mult = (1, 2, 2, 4)
|
| 336 |
elif image_size == 28:
|
|
|
|
| 330 |
elif image_size == 128:
|
| 331 |
channel_mult = (1, 1, 2, 3, 4)
|
| 332 |
elif image_size == 64:
|
| 333 |
+
channel_mult = (1, 2, 4, 4, 4)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
|
| 334 |
elif image_size == 32:
|
| 335 |
channel_mult = (1, 2, 2, 4)
|
| 336 |
elif image_size == 28:
|
diffusion.py
CHANGED
|
@@ -239,7 +239,7 @@ class TrainConfig:
|
|
| 239 |
stride = (2,2) if dim == 2 else (2,2,4)
|
| 240 |
num_image = 1000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
batch_size = 1#2#50#20#2#100 # 10
|
| 242 |
-
n_epoch = 4# 10#50#20#20#2#5#25 # 120
|
| 243 |
HII_DIM = 64
|
| 244 |
num_redshift = 512#128#64#512#256#256#64#512#128
|
| 245 |
channel = 1
|
|
@@ -508,7 +508,10 @@ class DDPM21CM:
|
|
| 508 |
# for i, from_ranges in self.ranges_dict[type].items():
|
| 509 |
# value[i] = (value[i] - from_ranges[0])/(from_ranges[1]-from_ranges[0]) # normalize
|
| 510 |
# value[i] =
|
| 511 |
-
def rescale(self,
|
|
|
|
|
|
|
|
|
|
| 512 |
if value.ndim == 1:
|
| 513 |
value = value.view(-1,len(value))
|
| 514 |
|
|
@@ -518,20 +521,21 @@ class DDPM21CM:
|
|
| 518 |
value = value * (to[1]-to[0]) + to[0]
|
| 519 |
return value
|
| 520 |
|
| 521 |
-
def sample(self, params:torch.tensor=None,
|
| 522 |
# n_sample = params.shape[0]
|
| 523 |
# file = self.config.resume
|
| 524 |
|
|
|
|
| 525 |
if params is None:
|
| 526 |
-
params = torch.tensor([
|
| 527 |
-
params_backup = params.numpy().copy()
|
| 528 |
-
else:
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
print(f"device {torch.cuda.current_device()} sampling {
|
| 533 |
-
|
| 534 |
-
assert
|
| 535 |
# print("params =", params)
|
| 536 |
# print("params =", params)
|
| 537 |
# print("len(params) =", len(params))
|
|
@@ -557,18 +561,24 @@ class DDPM21CM:
|
|
| 557 |
with torch.no_grad():
|
| 558 |
x_last, x_entire = self.ddpm.sample(
|
| 559 |
nn_model=self.nn_model,
|
| 560 |
-
params=
|
| 561 |
device=self.config.device,
|
| 562 |
guide_w=self.config.guide_w
|
| 563 |
)
|
| 564 |
|
| 565 |
if save:
|
| 566 |
# np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
|
| 567 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
if entire:
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
|
|
|
|
|
|
| 572 |
# %%
|
| 573 |
def train(rank, world_size):
|
| 574 |
config = TrainConfig()
|
|
@@ -576,8 +586,8 @@ def train(rank, world_size):
|
|
| 576 |
|
| 577 |
ddp_setup(rank, world_size)
|
| 578 |
|
| 579 |
-
|
| 580 |
-
for i, num_image in enumerate(
|
| 581 |
config.num_image = num_image
|
| 582 |
# config.world_size = world_size
|
| 583 |
|
|
@@ -614,68 +624,85 @@ if __name__ == "__main__":
|
|
| 614 |
|
| 615 |
# %%
|
| 616 |
|
| 617 |
-
def generate_samples(ddpm21cm,
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
ddp_setup(rank, world_size)
|
| 639 |
ddpm21cm = DDPM21CM(config)
|
| 640 |
|
| 641 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 642 |
|
| 643 |
# print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}, samples.shape = {np.shape(samples)}")
|
| 644 |
-
if rank == 0:
|
| 645 |
-
|
| 646 |
# print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
|
| 647 |
|
| 648 |
dist.destroy_process_group()
|
| 649 |
|
| 650 |
|
| 651 |
-
if __name__ ==
|
| 652 |
-
print(" sampling ".center(100,'-'))
|
| 653 |
world_size = torch.cuda.device_count()
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
|
|
|
|
|
|
|
|
|
| 658 |
|
| 659 |
# print("config = TrainConfig()")
|
| 660 |
config = TrainConfig()
|
| 661 |
config.world_size = world_size
|
| 662 |
# print("config.world_size = world_size")
|
| 663 |
|
| 664 |
-
for num_image in
|
| 665 |
config.num_image = num_image
|
| 666 |
-
config.resume = f"./outputs/model_state-N{num_image}-
|
| 667 |
|
| 668 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 669 |
manager = mp.Manager()
|
| 670 |
return_dict = manager.dict()
|
| 671 |
|
| 672 |
-
mp.spawn(
|
| 673 |
|
| 674 |
# print("---"*30)
|
| 675 |
# print(f"device {torch.cuda.current_device()}, keys = {return_dict.keys()}")
|
| 676 |
-
if "samples" in return_dict:
|
| 677 |
-
|
| 678 |
-
|
| 679 |
|
| 680 |
|
| 681 |
# %%
|
|
|
|
| 239 |
stride = (2,2) if dim == 2 else (2,2,4)
|
| 240 |
num_image = 1000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
batch_size = 1#2#50#20#2#100 # 10
|
| 242 |
+
n_epoch = 8#4# 10#50#20#20#2#5#25 # 120
|
| 243 |
HII_DIM = 64
|
| 244 |
num_redshift = 512#128#64#512#256#256#64#512#128
|
| 245 |
channel = 1
|
|
|
|
| 508 |
# for i, from_ranges in self.ranges_dict[type].items():
|
| 509 |
# value[i] = (value[i] - from_ranges[0])/(from_ranges[1]-from_ranges[0]) # normalize
|
| 510 |
# value[i] =
|
| 511 |
+
def rescale(self, params, ranges, to: list):
|
| 512 |
+
# value = np.array(params).copy()
|
| 513 |
+
value = params.clone()
|
| 514 |
+
|
| 515 |
if value.ndim == 1:
|
| 516 |
value = value.view(-1,len(value))
|
| 517 |
|
|
|
|
| 521 |
value = value * (to[1]-to[0]) + to[0]
|
| 522 |
return value
|
| 523 |
|
| 524 |
+
def sample(self, params:torch.tensor=None, num_new_img_per_gpu=192, ema=False, entire=False, save=True):
|
| 525 |
# n_sample = params.shape[0]
|
| 526 |
# file = self.config.resume
|
| 527 |
|
| 528 |
+
print(f"device {torch.cuda.current_device()}, sample, params = {params}")
|
| 529 |
if params is None:
|
| 530 |
+
params = torch.tensor([4.4, 131.341])
|
| 531 |
+
# params_backup = params.numpy().copy()
|
| 532 |
+
# else:
|
| 533 |
+
params_backup = params.numpy().copy()
|
| 534 |
+
params_normalized = self.rescale(params, self.ranges_dict['params'], to=[0,1])
|
| 535 |
+
|
| 536 |
+
print(f"device {torch.cuda.current_device()} sampling {num_new_img_per_gpu} images with normalized params = {params_normalized}")
|
| 537 |
+
params_normalized = params_normalized.repeat(num_new_img_per_gpu,1)
|
| 538 |
+
assert params_normalized.dim() == 2, "params_normalized must be a 2D torch.tensor"
|
| 539 |
# print("params =", params)
|
| 540 |
# print("params =", params)
|
| 541 |
# print("len(params) =", len(params))
|
|
|
|
| 561 |
with torch.no_grad():
|
| 562 |
x_last, x_entire = self.ddpm.sample(
|
| 563 |
nn_model=self.nn_model,
|
| 564 |
+
params=params_normalized.to(self.config.device),
|
| 565 |
device=self.config.device,
|
| 566 |
guide_w=self.config.guide_w
|
| 567 |
)
|
| 568 |
|
| 569 |
if save:
|
| 570 |
# np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
|
| 571 |
+
savetime = datetime.datetime.now().strftime("%m%d-%H%M")
|
| 572 |
+
savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{torch.cuda.current_device()}-{savetime}{'ema' if ema else ''}.npy")
|
| 573 |
+
print(f"saving {savename} ...")
|
| 574 |
+
np.save(savename, x_last)
|
| 575 |
+
|
| 576 |
if entire:
|
| 577 |
+
savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{torch.cuda.current_device()}-{savetime}{'ema' if ema else ''}_entire.npy")
|
| 578 |
+
print(f"saving {savename} ...")
|
| 579 |
+
np.save(savename, x_entire)
|
| 580 |
+
# else:
|
| 581 |
+
return x_last
|
| 582 |
# %%
|
| 583 |
def train(rank, world_size):
|
| 584 |
config = TrainConfig()
|
|
|
|
| 586 |
|
| 587 |
ddp_setup(rank, world_size)
|
| 588 |
|
| 589 |
+
num_train_image_list = [10]#[200]#[1600,3200,6400,12800,25600]
|
| 590 |
+
for i, num_image in enumerate(num_train_image_list):
|
| 591 |
config.num_image = num_image
|
| 592 |
# config.world_size = world_size
|
| 593 |
|
|
|
|
| 624 |
|
| 625 |
# %%
|
| 626 |
|
| 627 |
+
# def generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params):
|
| 628 |
+
# # samples = []
|
| 629 |
+
# for _ in range(num_new_img_per_gpu // max_num_img_per_gpu):
|
| 630 |
+
# sample = ddpm21cm.sample(
|
| 631 |
+
# params=params,
|
| 632 |
+
# num_new_img_per_gpu=max_num_img_per_gpu
|
| 633 |
+
# )
|
| 634 |
+
|
| 635 |
+
# print(f"device {torch.cuda.current_device()} generated sample of shape: {sample.shape}")
|
| 636 |
+
|
| 637 |
+
# # samples.append(sample)
|
| 638 |
+
# # ddpm21cm.sample(params=torch.tensor((5.6, 19.037)), num_new_img_per_gpu=max_num_img_per_gpu)
|
| 639 |
+
# # ddpm21cm.sample(params=torch.tensor((4.699, 30)), num_new_img_per_gpu=max_num_img_per_gpu)
|
| 640 |
+
# # ddpm21cm.sample(params=torch.tensor((5.477, 200)), num_new_img_per_gpu=max_num_img_per_gpu)
|
| 641 |
+
# # ddpm21cm.sample(params=torch.tensor((4.8, 131.341)), num_new_img_per_gpu=max_num_img_per_gpu)
|
| 642 |
+
# # samples = np.concatenate(samples, axis=0)
|
| 643 |
+
|
| 644 |
+
# # samples_list = [np.empty_like(samples) for _ in range(world_size)]
|
| 645 |
+
# # dist.all_gather_object(samples_list, samples)
|
| 646 |
+
|
| 647 |
+
# # if rank == 0:
|
| 648 |
+
# # all_samples = np.concatenate(samples_list, axis=0)
|
| 649 |
+
# # return all_samples
|
| 650 |
+
# # else:
|
| 651 |
+
# # return None
|
| 652 |
+
|
| 653 |
+
def generate_samples(rank, world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, params):
    """Worker entry point for ``mp.spawn``: sample images on one process.

    Joins the process group, builds a ``DDPM21CM`` from ``config`` and calls
    its ``sample`` method in batches of at most ``max_num_img_per_gpu`` images
    until ``num_new_img_per_gpu`` images have been requested. Nothing is
    handed back to the parent process; each batch is only reported on stdout
    (``sample`` is called with its default arguments).

    Args:
        rank: Index of this worker, supplied by ``mp.spawn``.
        world_size: Total number of spawned workers.
        config: Configuration object passed to ``DDPM21CM``.
        num_new_img_per_gpu: Total number of images this worker should sample.
        max_num_img_per_gpu: Upper bound on the images requested per
            ``sample`` call.
        return_dict: Shared dict created by the parent. Currently unused; kept
            so the signature matches the ``mp.spawn`` argument tuple.
        params: Conditioning parameters forwarded unchanged to ``sample``.

    Raises:
        ValueError: If ``max_num_img_per_gpu`` is not positive.
    """
    if max_num_img_per_gpu <= 0:
        raise ValueError(
            f"max_num_img_per_gpu must be positive, got {max_num_img_per_gpu}"
        )

    ddp_setup(rank, world_size)
    try:
        ddpm21cm = DDPM21CM(config)

        # Full-size batches first, then one smaller batch for any remainder,
        # so the requested total is honoured even when it is not a multiple
        # of the batch size.
        remaining = num_new_img_per_gpu
        while remaining > 0:
            batch_size = min(remaining, max_num_img_per_gpu)
            sample = ddpm21cm.sample(
                params=params,
                num_new_img_per_gpu=batch_size,
            )
            print(f"device {torch.cuda.current_device()} generated sample of shape: {sample.shape}")
            remaining -= batch_size
    finally:
        # Always tear the process group down, even if sampling failed.
        dist.destroy_process_group()
|
| 674 |
|
| 675 |
|
| 676 |
+
if __name__ == "__main__":
    # Sampling driver: spawn one worker process per visible CUDA device.
    world_size = torch.cuda.device_count()
    print(f" sampling, world_size = {world_size} ".center(100,'-'))

    # Training-set sizes whose checkpoints are sampled from,
    # e.g. [1600, 3200, 6400, 12800, 25600] for a full sweep.
    num_train_image_list = [2000]
    num_new_img_per_gpu = 8  # images requested from each worker
    max_num_img_per_gpu = 1  # images per individual sample() call

    # Conditioning parameters; sample() labels them Tvir and zeta in the
    # output file names.
    params = torch.tensor([4.4, 131.341])

    config = TrainConfig()
    config.world_size = world_size

    for num_image in num_train_image_list:
        config.num_image = num_image
        config.resume = f"./outputs/model_state-N{num_image}-epoch3-device0"

        # The context manager shuts the manager's server process down after
        # each run instead of leaving one alive per loop iteration.
        with mp.Manager() as manager:
            return_dict = manager.dict()
            mp.spawn(
                generate_samples,
                args=(world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, params),
                nprocs=world_size,
                join=True,
            )
|
| 706 |
|
| 707 |
|
| 708 |
# %%
|
quantify_results.ipynb
CHANGED
|
@@ -1971,24 +1971,43 @@
|
|
| 1971 |
},
|
| 1972 |
{
|
| 1973 |
"cell_type": "code",
|
| 1974 |
-
"execution_count":
|
| 1975 |
"metadata": {},
|
| 1976 |
-
"outputs": [
|
| 1977 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1978 |
},
|
| 1979 |
{
|
| 1980 |
"cell_type": "code",
|
| 1981 |
-
"execution_count":
|
| 1982 |
"metadata": {},
|
| 1983 |
"outputs": [],
|
| 1984 |
-
"source": [
|
|
|
|
|
|
|
| 1985 |
},
|
| 1986 |
{
|
| 1987 |
"cell_type": "code",
|
| 1988 |
-
"execution_count":
|
| 1989 |
"metadata": {},
|
| 1990 |
"outputs": [],
|
| 1991 |
-
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1992 |
},
|
| 1993 |
{
|
| 1994 |
"cell_type": "code",
|
|
|
|
| 1971 |
},
|
| 1972 |
{
|
| 1973 |
"cell_type": "code",
|
| 1974 |
+
"execution_count": 6,
|
| 1975 |
"metadata": {},
|
| 1976 |
+
"outputs": [
|
| 1977 |
+
{
|
| 1978 |
+
"name": "stdout",
|
| 1979 |
+
"output_type": "stream",
|
| 1980 |
+
"text": [
|
| 1981 |
+
"(1, 1, 64, 64, 512)\n"
|
| 1982 |
+
]
|
| 1983 |
+
}
|
| 1984 |
+
],
|
| 1985 |
+
"source": [
|
| 1986 |
+
"import numpy as np\n",
|
| 1987 |
+
"data = np.load('/storage/home/hcoda1/3/bxia34/p-jw254-0/ml21cm/outputs/Tvir4.400000095367432-zeta131.34100341796875-N2000-device0-0716-1726.npy')\n",
|
| 1988 |
+
"print(data.shape)"
|
| 1989 |
+
]
|
| 1990 |
},
|
| 1991 |
{
|
| 1992 |
"cell_type": "code",
|
| 1993 |
+
"execution_count": 7,
|
| 1994 |
"metadata": {},
|
| 1995 |
"outputs": [],
|
| 1996 |
+
"source": [
|
| 1997 |
+
"Tb = data[0,0]"
|
| 1998 |
+
]
|
| 1999 |
},
|
| 2000 |
{
|
| 2001 |
"cell_type": "code",
|
| 2002 |
+
"execution_count": 8,
|
| 2003 |
"metadata": {},
|
| 2004 |
"outputs": [],
|
| 2005 |
+
"source": [
|
| 2006 |
+
"import matplotlib.pyplot as plt\n",
|
| 2007 |
+
"for i in range(Tb.shape[-1]):\n",
|
| 2008 |
+
" plt.imshow(Tb[:,:,i])\n",
|
| 2009 |
+
" plt.savefig(f\"Tb{i:03d}.png\")"
|
| 2010 |
+
]
|
| 2011 |
},
|
| 2012 |
{
|
| 2013 |
"cell_type": "code",
|