# backup_diffusion.py (commit 8a8e86b by Xsmos, verified, parent a15aedd)
# %% [markdown]
# ## Adapt ContextUnet and the related code so that it first works for the 2D case; compare it against diffusers.Unet2DModel and optimize; finally rewrite it for the 3D case.
# - After trying diffusers' Unet2DModel, the loss dropped from 0.3 to 0.2 but is still high, which means there are problems outside Unet2DModel that can be optimized.
# - After switching to diffusers' DDPMScheduler and DDPMPipeline, the loss fell below 0.1, sometimes as low as 0.004, so the problem in my code lies mainly in the DDPM part. My DDPMScheduler is short and seems fine, so the issue is probably some piece of code in DDPMPipeline that mine is missing.
# - I had a typo in my DDPMScheduler that kept beta_t very small; after fixing it the loss drops from 0.2 to 0.02 and stays below 0.1.
# - diffusers' DDPMScheduler seems to work a bit better: its loss is always slightly lower than mine. At epoch 19 the former reaches ~0.02 while mine is ~0.07. It also supports noising 3D images, so reusing the existing wheel is tempting, but I want to understand why my loss is higher.
# - I realized that their DDPMScheduler's sample function does not accept my conditioning inputs, so in the end I still need my own DDPMScheduler. I can, however, use theirs first to debug my ContextUnet.
# - I need to extend my ContextUnet to support images of different dimensionalities; after all, I need to compare against the original paper before extending to the 3D case.
# - I converted my ContextUnet to 2D mode; compared with diffusers.Unet2DModel's loss = 0.037, my Unet's loss = 0.07. The images my Unet generates also look strange, so my Unet has problems too. I need to roll the code back to the original Unet and locate the issue.
# - I limited the number of pixels along the redshift direction to 64 to compare the two Unets. Comparison:\
# Unet2DModel loss: 0.03, 0.0655, 0.05, 0.02, 0.05\
# ContextUnet loss: 0.1, 0.16, 0.1, 0.2186, 0.06
# - I reverted ContextUnet to the original author's version; the loss is 0.05 and the output images look good. My main change was restoring his original normalization function, which also takes a swish parameter. When I have time I can study which part exactly affected the training results. I also found that for TensorBoard curves to display separately and cleanly, they need to go into different folders.
# - Verified: GroupNorm works better than BatchNorm.
# - Extended to accept different dimensionalities.
# - Merged the cond, guide_w, and drop_out parameters.
# - The generated 21cm images are not dark enough where they should be; switching to MNIST digit images seems to avoid the problem.
# - When generating MNIST digits with the diffusion model, I found that although the generated data also contains negative values such as -0.1, the plotted images are the expected black. The data distribution is not very different from that of the 21cm results, so I now plan to roll the code back to the 21cm case.
# - I unified the ddpm21cm module so it handles both training and sample generation, but there is currently a bug: sampling always runs CUDA out of memory, whereas resuming the model separately and then sampling does not.
# - Solved: the problem was that I forgot to write with torch.no_grad():
# - Next: generate 800 lightcones, and meanwhile study how to compute the global signal and the power spectrum.
# - When the number of training images reaches 5000, the generated images closely resemble the test data.
# - It takes 62 minutes to generate 8 images of shape (64, 64, 64), which is even slower than the simulation (~5 minutes per image). Besides, the batch_size during training and the number of images generated at once are limited to 2 and 8, respectively.
# - The slowness can be addressed with multiple GPUs, and the limit on the number of images can be addressed with mixed precision and multiple GPUs.
# - In addition, the DDPM output can look better than that of the computation-intensive simulations.
# 1 GPU, batch_size = 10, num_image = 3200, 50s for each epoch
# 4 GPU, batch_size = 10, num_image = 3200,

# %%
from dataclasses import dataclass
import h5py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
# from datasets import Dataset
import matplotlib.pyplot as plt
import numpy as np
import random
# from abc import ABC, abstractmethod
import torch.nn.functional as F
import math
# from PIL import Image
import os
from torch.utils.tensorboard import SummaryWriter
import copy
from tqdm.auto import tqdm
# from torchvision import transforms
# from diffusers import UNet2DModel#, UNet3DConditionModel
# from diffusers import DDPMScheduler
from diffusers.utils import make_image_grid
import datetime
from pathlib import Path
from diffusers.optimization import get_cosine_schedule_with_warmup
from accelerate import notebook_launcher, Accelerator
from huggingface_hub import create_repo, upload_folder

from load_h5 import Dataset4h5
from context_unet import ContextUnet

from huggingface_hub import notebook_login

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch.distributed as dist

# %%
def ddp_setup(rank: int, world_size: int):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.cuda.set_device(rank)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)

# %%
# notebook_login()

# %% [markdown]
# # Add noise:
#
# \begin{align*}
# x_t &\sim \mathcal N\left(\sqrt{1-\beta_t}\ x_{t-1},\ \beta_t \right) \\
# x_t &\equiv \sqrt{1-\beta_t}\ x_{t-1} + \sqrt{\beta_t}\ \epsilon\\
# \epsilon &\sim \mathcal N(0,1)\\
# \alpha_t & \equiv 1 - \beta_t\\
# & ...\\
# x_t &= \sqrt{\bar {\alpha_t}} x_0 + \epsilon\ \sqrt{1 - \bar{\alpha_t}}\\
# \bar {\alpha_t} &\equiv \prod_{i=1}^t \alpha_i\\
# &= \exp\left({\ln{\prod_{i=1}^t \alpha_i}}\right)\\
# &= \exp\left({\sum_{i=1}^t\ln{ \alpha_i}}\right)
# \end{align*}

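# %%
# Sanity check of the derivation above (a standalone numpy-only sketch with a toy
# beta schedule spanning the same (1e-4, 0.02) range used below; not part of the
# training path): cumprod matches the exp-sum-log form, and the closed-form
# marginal keeps unit variance for unit-variance x_0.

```python
import numpy as np

# toy linear beta schedule, same endpoints as the DDPMScheduler below
beta = np.linspace(1e-4, 0.02, 1000)
alpha = 1.0 - beta

# bar_alpha_t = prod_{i<=t} alpha_i, computed both ways shown in the derivation
bar_alpha = np.cumprod(alpha)
bar_alpha_logsum = np.exp(np.cumsum(np.log(alpha)))
assert np.allclose(bar_alpha, bar_alpha_logsum)

# for unit-variance x_0: Var[x_t] = bar_alpha_t * 1 + (1 - bar_alpha_t) = 1
var_xt = bar_alpha + (1.0 - bar_alpha)
assert np.allclose(var_xt, 1.0)
```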
# %%
class DDPMScheduler(nn.Module):
    def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu', dtype=torch.float32):
        super().__init__()

        beta_1, beta_T = betas
        assert 0 < beta_1 <= beta_T <= 1, "ensure 0 < beta_1 <= beta_T <= 1"
        self.device = device
        self.num_timesteps = num_timesteps
        self.img_shape = img_shape
        self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps)
        self.beta_t = self.beta_t.to(self.device)

        self.alpha_t = 1 - self.beta_t
        # self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
        self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
        self.dtype = dtype  # torch.float16 if use_fp16 else torch.float32

    def add_noise(self, clean_images):
        shape = clean_images.shape
        # singleton dims used to broadcast the per-sample coefficients over the image axes
        expand = torch.ones(len(shape) - 1, dtype=int)

        noise = torch.randn_like(clean_images).to(self.device)
        ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)

        # closed form: x_t = sqrt(bar_alpha_t) x_0 + sqrt(1 - bar_alpha_t) eps
        noisy_images = (
            clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())
            + noise * torch.sqrt(1 - self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())
        )

        return noisy_images, noise, ts

    def sample(self, nn_model, params, device, guide_w=0):
        n_sample = len(params)
        x_i = torch.randn(n_sample, *self.img_shape).to(device)
        if guide_w != -1:
            c_i = params
            uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)
            c_i = torch.cat((c_i, uncond_tokens), 0)

        x_i_entire = []  # keep track of generated steps in case we want to plot something
        pbar_sample = tqdm(total=self.num_timesteps)
        pbar_sample.set_description(f"cuda:{torch.cuda.current_device()} sampling")
        for i in reversed(range(0, self.num_timesteps)):
            t_is = torch.tensor([i]).to(device)
            t_is = t_is.repeat(n_sample)

            z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else 0

            if guide_w == -1:
                # unconditional model
                eps = nn_model(x_i, t_is)
            else:
                # double batch: conditional and unconditional halves
                x_i = x_i.repeat(2, *torch.ones(len(self.img_shape), dtype=int).tolist())
                t_is = t_is.repeat(2)

                # split predictions and compute the guidance weighting
                eps = nn_model(x_i, t_is, c_i)
                eps1 = eps[:n_sample]
                eps2 = eps[n_sample:]
                eps = eps1 + guide_w * (eps1 - eps2)
                # eps = (1+guide_w)*eps1 - guide_w*eps2
                x_i = x_i[:n_sample]

            x_i = 1/torch.sqrt(self.alpha_t[i]) * (x_i - eps*self.beta_t[i]/torch.sqrt(1 - self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z

            pbar_sample.update(1)

            # store only part of the intermediate steps
            # if i % 20 == 0:
            #     x_i_entire.append(x_i.detach().cpu().numpy())
        x_i_entire = np.array(x_i_entire)
        x_i = x_i.detach().cpu().numpy()
        return x_i, x_i_entire


# ddpm_scheduler = DDPMScheduler((1e-4, 0.02), 10)
# noisy_images, noise, ts = ddpm_scheduler.add_noise(images)

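# %%
# Sanity check (standalone numpy sketch with toy random tensors): the guidance
# weighting used in sample(), eps1 + guide_w*(eps1 - eps2), is algebraically the
# same as the commented-out alternative (1+guide_w)*eps1 - guide_w*eps2.

```python
import numpy as np

rng = np.random.default_rng(0)
eps1 = rng.standard_normal((4, 1, 8, 8))  # toy "conditional" prediction
eps2 = rng.standard_normal((4, 1, 8, 8))  # toy "unconditional" prediction

for w in (0.0, 0.1, 2.0):
    guided_a = eps1 + w * (eps1 - eps2)   # form used in sample()
    guided_b = (1 + w) * eps1 - w * eps2  # commented-out alternative
    assert np.allclose(guided_a, guided_b)

# w = 0 reduces to the plain conditional prediction
assert np.allclose(eps1 + 0.0 * (eps1 - eps2), eps1)
```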
# %%
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model):
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())

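# %%
# Sanity check of the EMA.update_average rule above (standalone sketch, same rate
# as TrainConfig.ema_rate): starting from 0 and tracking a constant target of 1,
# k updates give exactly 1 - beta**k.

```python
beta = 0.995  # same rate as TrainConfig.ema_rate
avg, target = 0.0, 1.0
for k in range(1, 4):
    avg = beta * avg + (1 - beta) * target  # EMA.update_average rule
    # closed form after k updates from 0 toward a constant target: 1 - beta**k
    assert abs(avg - (1 - beta**k)) < 1e-12
```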
# %%
@dataclass
class TrainConfig:
    ###########################
    ## hardcoding these here ##
    ###########################
    push_to_hub = True
    hub_model_id = "Xsmos/ml21cm"
    hub_private_repo = False
    dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
    device = "cuda" if torch.cuda.is_available() else 'cpu'
    world_size = 1  # torch.cuda.device_count()

    dim = 2
    stride = (2, 4) if dim == 2 else (2, 2, 2)
    num_image = 1000
    batch_size = 10
    n_epoch = 50
    HII_DIM = 64
    num_redshift = 512
    channel = 1
    img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)

    ranges_dict = dict(
        params = {
            0: [4, 6],     # ION_Tvir_MIN
            1: [10, 250],  # HII_EFF_FACTOR
        },
        images = {
            0: [0, 80],    # brightness_temp
        }
    )

    num_timesteps = 1000  # DDPM time steps
    # n_sample = 24  # the number of samples in the sampling process
    n_param = 2
    guide_w = 0  # -1 disables conditioning; otherwise strength of generative guidance
    drop_prob = 0  # only takes effect when guide_w != -1
    ema = False  # whether to use ema
    ema_rate = 0.995

    save_period = n_epoch // 3  # the period of sampling
    lrate = 1e-4
    lr_warmup_steps = 0
    output_dir = "./outputs/"
    save_name = os.path.join(output_dir, 'model_state')
    resume = False  # whether to resume from trained checkpoints

    use_fp16 = False
    dtype = torch.float16 if use_fp16 else torch.float32
    mixed_precision = "fp16"
    gradient_accumulation_steps = 1

# config = TrainConfig()
# print("device =", config.device)

# %%
# import os
# print(os.cpu_count())
# print(len(os.sched_getaffinity(0)))
# import torch
# data = torch.randn((64,64))
# print(data.dtype)

# %%
# def check_params_consistency(model, rank, world_size):
#     all_params_consistent = True
#     for name, param in model.named_parameters():
#         if param.requires_grad:
#             param_tensor = param.detach().clone()
#             dist.all_reduce(param_tensor, op=dist.ReduceOp.SUM)
#             param_tensor /= world_size
#
#             if not torch.allclose(param_tensor, param.detach()):
#                 all_params_consistent = False
#                 if rank == 0:
#                     print(f"Parameter {name} is not consistent across GPUs.")
#     if rank == 0 and all_params_consistent:
#         print("All model parameters are consistent across GPUs.")
#     return all_params_consistent

# def check_gradients_consistency(model, rank, world_size):
#     all_gradients_consistent = True
#     for name, param in model.named_parameters():
#         if param.requires_grad and param.grad is not None:
#             grad_tensor = param.grad.detach().clone()
#             dist.all_reduce(grad_tensor, op=dist.ReduceOp.SUM)
#             grad_tensor /= world_size
#
#             if not torch.allclose(grad_tensor, param.grad.detach()):
#                 all_gradients_consistent = False
#                 if rank == 0:
#                     print(f"Gradient {name} is not consistent across GPUs.")
#     if rank == 0 and all_gradients_consistent:
#         print("All model gradients are consistent across GPUs.")
#     return all_gradients_consistent

class DDPM21CM:
    def __init__(self, config):
        config.run_name = datetime.datetime.now().strftime("%m%d-%H%M")  # the unique name of each experiment
        self.config = config
        self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, dtype=config.dtype)

        # initialize the unet
        self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype)
        self.nn_model.train()
        self.nn_model.to(self.ddpm.device)
        self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])

        if config.resume and os.path.exists(config.resume):
            self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
            print(f"cuda:{torch.cuda.current_device()} resumed nn_model from {config.resume}")
        else:
            print(f"cuda:{torch.cuda.current_device()} initialized nn_model randomly")

        # number of parameters to be trained
        self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
        print(f" Number of parameters for nn_model: {self.number_of_params} ".center(120, '-'))

        # whether to use ema
        if config.ema:
            self.ema = EMA(config.ema_rate)
            if config.resume and os.path.exists(config.resume):
                self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype).to(config.device)
                self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])
                print(f"resumed ema_model from {config.resume}")
            else:
                self.ema_model = copy.deepcopy(self.nn_model).eval().requires_grad_(False)

        self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)
        self.lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer=self.optimizer,
            num_warmup_steps=config.lr_warmup_steps,
            num_training_steps=int(config.num_image / config.world_size / config.batch_size * config.n_epoch),
        )

        self.ranges_dict = config.ranges_dict

    def load(self):
        dataset = Dataset4h5(
            self.config.dataset_name,
            num_image=self.config.num_image,
            idx="random",  # 'range'
            HII_DIM=self.config.HII_DIM,
            num_redshift=self.config.num_redshift,
            drop_prob=self.config.drop_prob,
            dim=self.config.dim,
            ranges_dict=self.ranges_dict,
            num_workers=len(os.sched_getaffinity(0)) // self.config.world_size,
        )
        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=len(os.sched_getaffinity(0)) // self.config.world_size,
            pin_memory=True,
            persistent_workers=True,
            # sampler=DistributedSampler(dataset),
        )
        del dataset

    def train(self):
        ###################
        ## training loop ##
        ###################
        self.load()
        self.accelerator = Accelerator(
            mixed_precision=self.config.mixed_precision,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            log_with="tensorboard",
            project_dir=os.path.join(self.config.output_dir, "logs"),
        )
        if torch.cuda.current_device() == 0:
            if self.config.output_dir is not None:
                os.makedirs(self.config.output_dir, exist_ok=True)
            if self.config.push_to_hub:
                self.repo_id = create_repo(
                    repo_id=self.config.hub_model_id or Path(self.config.output_dir).name, exist_ok=True
                ).repo_id
            self.accelerator.init_trackers(f"{self.config.run_name}")

        self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \
            self.accelerator.prepare(
                self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler
            )

        global_step = 0
        for ep in range(self.config.n_epoch):
            self.ddpm.train()

            pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
            pbar_train.set_description(f"cuda:{torch.cuda.current_device()}, Epoch {ep}")
            for i, (x, c) in enumerate(self.dataloader):
                with self.accelerator.accumulate(self.nn_model):
                    x = x.to(self.config.device)
                    xt, noise, ts = self.ddpm.add_noise(x)

                    if self.config.guide_w == -1:
                        noise_pred = self.nn_model(xt, ts)
                    else:
                        c = c.to(self.config.device)
                        noise_pred = self.nn_model(xt, ts, c)

                    loss = F.mse_loss(noise, noise_pred)
                    self.accelerator.backward(loss)
                    self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                    # ema update
                    if self.config.ema:
                        self.ema.step_ema(self.ema_model, self.nn_model)

                pbar_train.update(1)
                logs = dict(
                    loss=loss.detach().item(),
                    lr=self.optimizer.param_groups[0]['lr'],
                    step=global_step
                )
                pbar_train.set_postfix(**logs)

                self.accelerator.log(logs, step=global_step)
                global_step += 1

            self.save(ep)
            # # check consistency of parameters and gradients
            # rank = torch.cuda.current_device()
            # params_consistent = check_params_consistency(self.ddpm, rank, self.config.world_size)
            # gradients_consistent = check_gradients_consistency(self.ddpm, rank, self.config.world_size)
            # # if any consistency check fails, print a warning on all ranks
            # if not (params_consistent and gradients_consistent):
            #     print(f"Rank {rank}: Parameter or gradient inconsistency detected.")

        del self.nn_model
        if self.config.ema:
            del self.ema_model
        torch.cuda.empty_cache()

    def save(self, ep):
        # save model
        if torch.cuda.current_device() == 0:
            if ep == self.config.n_epoch - 1 or (ep + 1) % self.config.save_period == 0:
                self.nn_model.eval()
                with torch.no_grad():
                    if self.config.push_to_hub:
                        upload_folder(
                            repo_id=self.repo_id,
                            folder_path=".",  # config.output_dir
                            commit_message=f"{self.config.run_name}",
                            ignore_patterns=["step_*", "epoch_*", "*.npy", "__pycache__"],
                        )
                    if self.config.save_name:
                        model_state = {
                            'epoch': ep,
                            'unet_state_dict': self.nn_model.module.state_dict(),
                            # 'ema_unet_state_dict': self.ema_model.state_dict(),
                        }
                        save_name = self.config.save_name + f"-N{self.config.num_image}-device_count{self.config.world_size}-epoch{ep}"
                        torch.save(model_state, save_name)
                        print(f'cuda:{torch.cuda.current_device()} saved model at ' + save_name)

    def rescale(self, params, ranges, to: list):
        value = params.clone()

        if value.ndim == 1:
            value = value.view(-1, len(value))

        for i in range(np.shape(value)[1]):
            # normalize each parameter by its range, then map onto the `to` interval
            value[:, i] = (value[:, i] - ranges[i][0]) / (ranges[i][1] - ranges[i][0])
        value = value * (to[1] - to[0]) + to[0]
        return value

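# %%
# Sanity check of the rescale method above (standalone numpy sketch with toy
# values mirroring ranges_dict['params']): it is plain min-max normalization per
# parameter followed by an affine map onto the `to` interval.

```python
import numpy as np

def rescale(params, ranges, to):
    """Min-max normalize each column by its range, then map onto the `to` interval."""
    value = np.array(params, dtype=float, copy=True)
    if value.ndim == 1:
        value = value.reshape(1, -1)
    for i in range(value.shape[1]):
        lo, hi = ranges[i]
        value[:, i] = (value[:, i] - lo) / (hi - lo)  # -> [0, 1]
    return value * (to[1] - to[0]) + to[0]            # -> [to[0], to[1]]

ranges = {0: [4, 6], 1: [10, 250]}  # as in ranges_dict['params']
assert np.allclose(rescale([4.0, 10.0], ranges, to=[0, 1]), [[0.0, 0.0]])
assert np.allclose(rescale([6.0, 250.0], ranges, to=[-1, 1]), [[1.0, 1.0]])
```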
    def sample(self, params: torch.Tensor = None, num_new_img_per_gpu=192, ema=False, entire=False, save=True):
        if params is None:
            params = torch.tensor([4.4, 131.341])
        params_backup = params.numpy().copy()
        params_normalized = self.rescale(params, self.ranges_dict['params'], to=[0, 1])

        print(f"cuda:{torch.cuda.current_device()} sampling {num_new_img_per_gpu} images with normalized params = {params_normalized}")
        params_normalized = params_normalized.repeat(num_new_img_per_gpu, 1)
        assert params_normalized.dim() == 2, "params_normalized must be a 2D torch.Tensor"

        self.nn_model.eval()

        with torch.no_grad():
            x_last, x_entire = self.ddpm.sample(
                nn_model=self.nn_model,
                params=params_normalized.to(self.config.device),
                device=self.config.device,
                guide_w=self.config.guide_w
            )

        if save:
            savetime = datetime.datetime.now().strftime("%m%d-%H%M")
            savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{torch.cuda.current_device()}-{savetime}{'ema' if ema else ''}.npy")
            print(f"saving {savename} ...")
            np.save(savename, x_last)

            if entire:
                savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{torch.cuda.current_device()}-{savetime}{'ema' if ema else ''}_entire.npy")
                print(f"saving {savename} ...")
                np.save(savename, x_entire)
        return x_last

# %%

num_train_image_list = [6000]

def train(rank, world_size):
    ddp_setup(rank, world_size)
    config = TrainConfig()
    config.device = f"cuda:{rank}"
    config.world_size = world_size

    for i, num_image in enumerate(num_train_image_list):
        config.num_image = num_image
        ddpm21cm = DDPM21CM(config)
        print(f"run_name = {ddpm21cm.config.run_name}")
        ddpm21cm.train()
    destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f" training, world_size = {world_size} ".center(120, '-'))

    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
    # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')

678
+
+ # %%
+
+ # An earlier generate_samples draft collected the per-chunk samples in a list,
+ # concatenated them, and gathered all ranks' arrays onto rank 0 with
+ # dist.all_gather_object(samples_list, samples); the version below instead
+ # lets each rank generate and report its chunks independently.
+
+ def generate_samples(rank, world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, params):
+     ddp_setup(rank, world_size)
+     ddpm21cm = DDPM21CM(config)
+
+     # generate in chunks so that at most max_num_img_per_gpu images are sampled at once
+     for _ in range(num_new_img_per_gpu // max_num_img_per_gpu):
+         sample = ddpm21cm.sample(
+             params=params,
+             num_new_img_per_gpu=max_num_img_per_gpu,
+         )
+         print(f"cuda:{torch.cuda.current_device()} generated sample of shape: {sample.shape}")
+
+     # rank 0 could hand results back to the parent via the Manager proxy,
+     # e.g. return_dict['samples'] = samples
+     dist.destroy_process_group()
+
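Note that the loop above runs `num_new_img_per_gpu // max_num_img_per_gpu` times, silently dropping the remainder whenever the two are not divisible (200 and 20 happen to divide evenly here). A remainder-safe split, as a sketch with hypothetical names:

```python
def chunk_sizes(total, max_per_chunk):
    """Split `total` images into chunk sizes of at most `max_per_chunk`, keeping the remainder."""
    full, rest = divmod(total, max_per_chunk)
    return [max_per_chunk] * full + ([rest] if rest else [])

print(chunk_sizes(200, 20))  # ten chunks of 20
print(chunk_sizes(50, 20))   # [20, 20, 10]
```

Iterating over `chunk_sizes(num_new_img_per_gpu, max_num_img_per_gpu)` instead of the bare `range` would generate exactly the requested number of images for any pair of values.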
+ if __name__ == "__main__":
+     world_size = torch.cuda.device_count()
+     num_new_img_per_gpu = 200
+     max_num_img_per_gpu = 20
+
+     config = TrainConfig()
+     config.world_size = world_size
+
+     for num_image in num_train_image_list:
+         config.num_image = num_image
+         config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
+         # single-GPU checkpoints:
+         # config.resume = f"./outputs/model_state-N{config.num_image}-device_count1-epoch{config.n_epoch-1}"
+
+         # shared proxy dict that spawned workers can write results into
+         manager = mp.Manager()
+         return_dict = manager.dict()
+
+         params_pairs = [
+             (4.4, 131.341),
+             (5.6, 19.037),
+             (4.699, 30),
+             (5.477, 200),
+             (4.8, 131.341),
+         ]
+         for params in params_pairs:
+             print(f" sampling for {params}, world_size = {world_size} ".center(120, '-'))
+             mp.spawn(
+                 generate_samples,
+                 args=(world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, torch.tensor(params)),
+                 nprocs=world_size,
+                 join=True,
+             )
+
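`config.resume` must reproduce exactly the filename pattern written at train time. A hypothetical helper pair that builds and parses the pattern would keep the two sides in sync (names here are illustrative, not part of the original code):

```python
import re

def checkpoint_path(num_image, device_count, epoch, outdir="./outputs"):
    """Build the model_state filename in the same format config.resume expects."""
    return f"{outdir}/model_state-N{num_image}-device_count{device_count}-epoch{epoch}"

def parse_checkpoint_path(path):
    """Recover (num_image, device_count, epoch) from a checkpoint path."""
    m = re.search(r"model_state-N(\d+)-device_count(\d+)-epoch(\d+)$", path)
    if m is None:
        raise ValueError(f"unrecognized checkpoint path: {path}")
    return tuple(int(g) for g in m.groups())

print(parse_checkpoint_path(checkpoint_path(6000, 4, 99)))  # (6000, 4, 99)
```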
+ # %%
+ # ls -lth outputs | head
+
+ # %%
+ # def plot_grid(samples, c=None, row=1, col=2):
+ #     """Plot each z-slice of a batch of 3D samples (N, C, H, W, D) as a grid of PNGs."""
+ #     print("samples.shape =", samples.shape)
+ #     for j in range(samples.shape[4]):
+ #         plt.figure(figsize=(12, 6), dpi=400)
+ #         for i in range(len(samples)):
+ #             plt.subplot(row, col, i + 1)
+ #             plt.imshow(samples[i, 0, :, :, j], cmap='gray')  # vmin=-1, vmax=1
+ #             plt.xticks([])
+ #             plt.yticks([])
+ #         # plt.suptitle(f"ION_Tvir_MIN = {c[0][0]}, HII_EFF_FACTOR = {c[0][1]}")
+ #         plt.tight_layout()
+ #         plt.subplots_adjust(wspace=0, hspace=0)
+ #         plt.savefig(f"test3D-{j:03d}.png")
+ #         plt.close()
+
+ # data = np.load("outputs/Tvir4.400000095367432-zeta131.34100341796875-N1000.npy")
+ # plot_grid(data)
+
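The commented `plot_grid` lays slices out with matplotlib subplots; the same row-by-col arrangement can be built directly as a single numpy mosaic array, which avoids per-panel axes entirely. A sketch, assuming samples shaped (N, 1, H, W):

```python
import numpy as np

def mosaic(samples, row, col):
    """Tile the first row*col images of samples (N, 1, H, W) into one (row*H, col*W) array."""
    n, _, h, w = samples.shape
    assert n >= row * col, "not enough images to fill the grid"
    grid = samples[:row * col, 0].reshape(row, col, h, w)
    # reorder axes to (row, H, col, W), then merge into a single 2D image
    return grid.transpose(0, 2, 1, 3).reshape(row * h, col * w)

x = np.arange(8).reshape(2, 1, 2, 2)
print(mosaic(x, 1, 2))  # [[0 1 4 5]
                        #  [2 3 6 7]]
```

The result can be displayed or saved with a single `plt.imshow` / `plt.imsave` call.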
+ # %%
+ # def plot(filename, row=4, col=6):
+ #     """Plot the first row*col 2D samples stored in a .npy file."""
+ #     samples = np.load(filename)
+ #     params = filename.split('guide_w')[-1][:-4]
+ #     print("plotting", samples.shape, params)
+ #     plt.figure(figsize=(8, 8))
+ #     for i in range(row * col):
+ #         plt.subplot(row, col, i + 1)
+ #         plt.imshow(samples[i, 0, :, :], cmap='gray')  # vmin=-1, vmax=1
+ #         plt.xticks([])
+ #         plt.yticks([])
+ #     plt.suptitle(params)
+ #     plt.tight_layout()
+ #     plt.subplots_adjust(wspace=0, hspace=0)
+ #     plt.show()
+ #     # plt.savefig('outputs/' + params + '.png')
+
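The commented `plot` recovers its title with `filename.split('guide_w')[-1][:-4]`, which silently assumes a 4-character extension. A pathlib version (a sketch, the `guide_w` filename pattern assumed from the code above) is extension-safe:

```python
from pathlib import Path

def params_from_filename(filename):
    """Return the text after 'guide_w' in the stem: '.../sample_guide_w0.5.npy' -> '0.5'."""
    return Path(filename).stem.split('guide_w')[-1]

print(params_from_filename("outputs/sample_guide_w0.5.npy"))  # 0.5
```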
822
+
823
+