# %% [markdown]
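# ## Adapt ContextUnet and the related code: first make it work for 2D images, compare it with diffusers' UNet2DModel and optimize it, then rewrite it for the 3D case.
# - After trying diffusers' UNet2DModel, the loss dropped from 0.3 to 0.2 but was still high. This means there were problems outside the UNet that could be fixed.
# - After switching to diffusers' DDPMScheduler and DDPMPipeline, the loss fell below 0.1, sometimes as low as 0.004. The problem in my code was therefore mainly in the DDPM part. My DDPMScheduler is short and looked correct, so the cause had to be some part of DDPMPipeline that my code was missing.
# - I had a typo in my DDPMScheduler that kept beta_t very small. After fixing it, the loss drops from 0.2 to 0.02 and stays below 0.1.
# - diffusers' DDPMScheduler still seems to work somewhat better, and its loss is always a little lower than with my DDPMScheduler. At epoch 19, the loss is about 0.02 with the former and about 0.07 with the latter. The former also supports adding noise to 3D images, so it would be simpler to reuse it. However, I would like to understand why my loss is higher.
# - I realized that diffusers' DDPMScheduler does not accept the conditioning parameters in its sampling function, so in the end I still need my own DDPMScheduler. For now, though, I can use the diffusers one to debug my ContextUnet.
# - I need to extend ContextUnet to images of different dimensionality. I first have to compare against the original paper and only then extend to 3D.
# - I have converted ContextUnet to 2D. My UNet reaches loss = 0.07, compared with loss = 0.037 for diffusers.UNet2DModel. The images it generates also look strange, so my UNet has a problem as well. I need to revert to the original UNet and find the cause.
# - I limited the number of pixels along the redshift axis to 64 to compare the two UNets. Results:\
# UNet2DModel loss: 0.03, 0.0655, 0.05, 0.02, 0.05\
# ContextUnet loss: 0.1, 0.16, 0.1, 0.2186, 0.06
# - I reverted ContextUnet to the original author's version. The loss became 0.05 and the output images look good. The main change was going back to his original normalization function, which also takes a `swish` parameter. When I have time, I can investigate exactly what affected the training results. I also found that TensorBoard curves stay separate and readable only if each run is logged to its own directory.
# - Tests confirm that GroupNorm works better than BatchNorm.
# - ContextUnet now accepts inputs of different dimensionality.
# - Integrated the `cond`, `guide_w` and `drop_out` parameters.
# - In the generated 21cm images, regions that should be dark are not dark enough. With MNIST digit images this does not seem to happen.
# - When generating MNIST digits with the diffusion model, I found that the generated data also contain negative values (e.g. -0.1), yet the plotted images show the expected black. The data distribution is not very different from that of the 21cm results, so I now plan to switch the code back to the 21cm case.
# - I unified the ddpm21cm module so that it handles both training and sampling. There is currently a bug: sampling always runs out of CUDA memory, whereas resuming the model and sampling on its own does not.
# - Solved: I had forgotten to write `with torch.no_grad():`.
# - Next: generate 800 lightcones, and at the same time work out how to compute the global signal and the power spectrum.
# - Once the number of training images reaches 5000, the generated images closely resemble the test data.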
# - It takes 62 min to generate 8 images of shape (64, 64, 64), which is even slower than the simulation (~5 min per image). In addition, the training batch size is limited to 2 and the number of generated images to 8.
# - Multiple GPUs could address the slowness. Mixed precision combined with multiple GPUs could address the limit on the number of images.
# - In addition, DDPM may compare more favorably against more computationally intensive simulations.
# 1 GPU, batch_size = 10, num_image = 3200, 50s for each epoch
# 4 GPU, batch_size = 10, num_image = 3200, 
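
# %% [markdown]
# The timing notes above can be converted to a per-image basis. This is a back-of-the-envelope sketch that uses only the numbers quoted there. The multi-GPU speed-up is an idealized assumption (perfect linear scaling that ignores communication and memory limits), not a measurement.

# %%
# ```python
# Per-image cost of DDPM sampling vs. simulation, using the numbers quoted in the notes above.
ddpm_minutes_total = 62       # 8 images of shape (64, 64, 64) on one GPU
ddpm_num_images = 8
sim_minutes_per_image = 5     # approximate cost of one simulation

ddpm_minutes_per_image = ddpm_minutes_total / ddpm_num_images  # 7.75 min

def ideal_minutes_per_image(num_gpus: int) -> float:
    """Per-image sampling time under the idealized assumption of perfect linear scaling across GPUs."""
    return ddpm_minutes_per_image / num_gpus

# Smallest GPU count at which the idealized per-image DDPM cost drops below the simulation cost.
breakeven_gpus = next(n for n in range(1, 65) if ideal_minutes_per_image(n) < sim_minutes_per_image)

# Training throughput quoted for one GPU: 3200 images, batch_size = 10, 50 s per epoch.
train_steps_per_epoch = 3200 // 10                   # 320 optimizer steps
train_seconds_per_step = 50 / train_steps_per_epoch
# ```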

# %%
from dataclasses import dataclass
import h5py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
# from datasets import Dataset
import matplotlib.pyplot as plt
import numpy as np
import random
# from abc import ABC, abstractmethod
import torch.nn.functional as F
import math
# from PIL import Image
import os
from torch.utils.tensorboard import SummaryWriter
import copy
from tqdm.auto import tqdm
# from torchvision import transforms
# from diffusers import UNet2DModel#, UNet3DConditionModel
# from diffusers import DDPMScheduler
from diffusers.utils import make_image_grid
import datetime
from pathlib import Path
from diffusers.optimization import get_cosine_schedule_with_warmup
from accelerate import notebook_launcher, Accelerator
from huggingface_hub import create_repo, upload_folder

from load_h5 import Dataset4h5
from context_unet import ContextUnet

from huggingface_hub import notebook_login

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch.distributed as dist

# %%
def ddp_setup(rank: int, world_size: int):
    """
    Initialize the NCCL process group for single-node, multi-GPU training.

    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.cuda.set_device(rank)  # bind this process to GPU `rank`
    init_process_group(backend="nccl", rank=rank, world_size=world_size)

# %%
# notebook_login()
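
# %% [markdown]
# `ddp_setup` starts one process per GPU. For data parallelism, each process should iterate over its own shard of the dataset. In `DDPM21CM.load()` below, the `DistributedSampler` argument is commented out, so sharding depends on how `accelerator.prepare` treats the `DataLoader`.
#
# The pure-Python sketch below mimics how `torch.utils.data.DistributedSampler` (with `shuffle=False`, `drop_last=False`) splits indices among ranks, and counts the resulting optimizer steps per epoch. It illustrates that sampler's scheme only; it is not the sampler itself, and the example values are hypothetical.
#
# The sketch also shows a pitfall in the cosine schedule. `DDPM21CM` computes `num_training_steps` as `int(num_image / world_size / batch_size * n_epoch)`. This formula can undercount when the divisions are not exact, in which case the schedule reaches its end before the last optimizer step.

# %%
# ```python
import math

def shard_indices(num_samples: int, world_size: int, rank: int) -> list:
    """Mimic DistributedSampler(shuffle=False, drop_last=False): pad the index list by
    wrapping around to a multiple of world_size, then take every world_size-th index from `rank`."""
    indices = list(range(num_samples))
    total_size = math.ceil(num_samples / world_size) * world_size
    indices += indices[: total_size - num_samples]  # padding: every rank gets the same number of samples
    return indices[rank:total_size:world_size]

def steps_per_rank_per_epoch(num_samples: int, world_size: int, batch_size: int) -> int:
    """Batches seen by each rank per epoch when the last, possibly partial, batch is kept."""
    per_rank = len(shard_indices(num_samples, world_size, rank=0))
    return math.ceil(per_rank / batch_size)

# Hypothetical example: 1000 images on 3 GPUs, batch_size 10, 50 epochs.
num_image, world_size, batch_size, n_epoch = 1000, 3, 10, 50
shards = [shard_indices(num_image, world_size, r) for r in range(world_size)]
exact_steps = steps_per_rank_per_epoch(num_image, world_size, batch_size) * n_epoch
truncated_steps = int(num_image / world_size / batch_size * n_epoch)  # the num_training_steps formula
# ```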

# %% [markdown]
# # Add noise (forward diffusion process)
# 
# \begin{align*}
# x_t &\sim \mathcal N\left(\sqrt{1-\beta_t}\ x_{t-1},\ \beta_t \mathbf I\right) \\
# x_t &\equiv \sqrt{1-\beta_t}\ x_{t-1} + \sqrt{\beta_t}\ \epsilon\\
# \epsilon &\sim \mathcal N(0,\mathbf I)\\
# \alpha_t & \equiv 1 - \beta_t\\
# x_t &= \sqrt{\alpha_t\alpha_{t-1}}\ x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\ \epsilon \quad \text{(a sum of independent Gaussians is Gaussian)}\\
# & \ \ \vdots\\
# x_t &= \sqrt{\bar\alpha_t}\ x_0 + \sqrt{1 - \bar\alpha_t}\ \epsilon\\
# \bar\alpha_t &\equiv \prod_{i=1}^t \alpha_i\\
# &= \exp\left(\ln\prod_{i=1}^t \alpha_i\right)\\
# &= \exp\left(\sum_{i=1}^t\ln \alpha_i\right)
# \end{align*}
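
# %% [markdown]
# The identities above can be checked numerically. This is a minimal NumPy sketch that assumes the same linear schedule as the `DDPMScheduler` below: `betas=(1e-4, 0.02)`, 1000 steps, and 0-indexed timesteps. The sketch does three things:
#
# - It checks that the cumulative product equals the exp-of-summed-logs form.
# - It checks that iterating the one-step variance $\mathrm{Var}_t = \alpha_t \mathrm{Var}_{t-1} + \beta_t$ reproduces $1-\bar\alpha_t$.
# - It noises batches of 2D and 3D images by reshaping the per-sample coefficients to `(batch, 1, ..., 1)`. This is the same broadcasting that `add_noise` performs with `.view(...)`.
#
# `add_noise_np` is a hypothetical helper for illustration, not part of the code base.

# %%
# ```python
import numpy as np

num_timesteps = 1000
beta = np.linspace(1e-4, 0.02, num_timesteps)    # beta_t, t = 0 .. T-1 (0-indexed as in the code)
alpha = 1.0 - beta
bar_alpha = np.cumprod(alpha)                     # \bar\alpha_t = prod_{i<=t} alpha_i
bar_alpha_log = np.exp(np.cumsum(np.log(alpha)))  # same quantity via exp(sum(log alpha_i))

# Variance of x_t given x_0 from the one-step recursion Var_t = alpha_t * Var_{t-1} + beta_t, starting at 0.
var = np.empty(num_timesteps)
prev = 0.0
for t in range(num_timesteps):
    prev = alpha[t] * prev + beta[t]
    var[t] = prev

def add_noise_np(x0: np.ndarray, rng: np.random.Generator):
    """Closed-form forward process for a batch of any dimensionality (B, C, ...)."""
    ts = rng.integers(0, num_timesteps, size=x0.shape[0])
    coef_shape = (x0.shape[0],) + (1,) * (x0.ndim - 1)  # one coefficient per sample, broadcast over the rest
    noise = rng.standard_normal(x0.shape)
    xt = (np.sqrt(bar_alpha[ts]).reshape(coef_shape) * x0
          + np.sqrt(1.0 - bar_alpha[ts]).reshape(coef_shape) * noise)
    return xt, noise, ts

rng = np.random.default_rng(0)
xt_2d, noise_2d, ts_2d = add_noise_np(np.zeros((4, 1, 8, 16)), rng)   # 2D images
xt_3d, noise_3d, ts_3d = add_noise_np(np.ones((2, 1, 4, 4, 8)), rng)  # 3D images
# ```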

# %%
class DDPMScheduler(nn.Module):
    """Linear beta schedule with closed-form forward noising and (optionally guided) ancestral sampling."""
    def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu', dtype=torch.float32):
        super().__init__()
        
        beta_1, beta_T = betas
        assert 0 < beta_1 <= beta_T <= 1, "ensure 0 < beta_1 <= beta_T <= 1"
        self.device = device
        self.num_timesteps = num_timesteps
        self.img_shape = img_shape
        # beta_t grows linearly from beta_1 to beta_T (timesteps are 0-indexed)
        self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps)
        self.beta_t = self.beta_t.to(self.device)

        self.alpha_t = 1 - self.beta_t
        # \bar\alpha_t = prod_{i<=t} alpha_i, equivalent to torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
        self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
        self.dtype = dtype

    def add_noise(self, clean_images):
        """Draw one random timestep per image and return (x_t, noise, t), where
        x_t = sqrt(bar_alpha_t) * x_0 + sqrt(1 - bar_alpha_t) * noise."""
        shape = clean_images.shape
        # per-sample coefficients are reshaped to (B, 1, ..., 1) so they broadcast over 2D and 3D images
        expand = [1] * (len(shape) - 1)

        noise = torch.randn_like(clean_images).to(self.device)
        ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)

        noisy_images = (
            clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand)
            + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(shape[0], *expand)
            )

        return noisy_images, noise, ts

    def sample(self, nn_model, params, device, guide_w = 0):
        """Run the reverse process from pure noise.

        guide_w == -1: unconditional sampling, nn_model(x, t).
        Otherwise: classifier-free guidance, eps = eps_cond + guide_w * (eps_cond - eps_uncond),
        where the unconditional prediction uses all-zero parameters.
        """
        n_sample = len(params)
        x_i = torch.randn(n_sample, *self.img_shape).to(device)
        if guide_w != -1:
            uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)
            # first half of the doubled batch is conditional, second half unconditional
            c_i = torch.cat((params, uncond_tokens), 0)

        x_i_entire = [] # keep track of generated steps in case want to plot something
        pbar_sample = tqdm(total=self.num_timesteps)
        pbar_sample.set_description(f"cuda:{torch.cuda.current_device()} sampling" if torch.cuda.is_available() else "cpu sampling")
        for i in reversed(range(0, self.num_timesteps)):
            t_is = torch.tensor([i]).to(device)
            t_is = t_is.repeat(n_sample)

            # no noise is added at the last step (i == 0)
            z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else 0

            if guide_w == -1:
                eps = nn_model(x_i, t_is)
            else:
                # double the batch so conditional and unconditional predictions share one forward pass
                x_i = x_i.repeat(2, *[1] * len(self.img_shape))
                t_is = t_is.repeat(2)

                # split predictions and compute weighting
                eps = nn_model(x_i, t_is, c_i)
                eps1 = eps[:n_sample]
                eps2 = eps[n_sample:]
                eps = eps1 + guide_w*(eps1 - eps2)  # equivalently (1+guide_w)*eps1 - guide_w*eps2
                x_i = x_i[:n_sample]

            # x_{t-1} = (x_t - beta_t / sqrt(1 - bar_alpha_t) * eps) / sqrt(alpha_t) + sqrt(beta_t) * z
            x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
            
            pbar_sample.update(1)
            
            # store only part of the intermediate steps
            # if i%20==0:
            #     x_i_entire.append(x_i.detach().cpu().numpy())
        pbar_sample.close()
        x_i_entire = np.array(x_i_entire)
        x_i = x_i.detach().cpu().numpy()
        return x_i, x_i_entire


# ddpm_scheduler = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=10, img_shape=(1, 64, 512))
# noisy_images, noise, ts = ddpm_scheduler.add_noise(images)
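
# %% [markdown]
# `DDPMScheduler.sample` applies the reverse update $x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon\right) + \sqrt{\beta_t}\,z$ with the guided noise estimate $\epsilon = \epsilon_c + w(\epsilon_c - \epsilon_u)$.
#
# Below is a small NumPy check, assuming the network predicts the true noise exactly. Under that assumption, the deterministic part of the update equals the mean of the true posterior $q(x_{t-1}\mid x_t, x_0)$. The guidance line is also shown to equal $(1+w)\epsilon_c - w\epsilon_u$, the alternative form noted in a comment in `sample`.
#
# With `guide_w = 0`, the `TrainConfig` default, the guided estimate reduces to the conditional one, although `sample` still runs the doubled batch.

# %%
# ```python
import numpy as np

T = 1000
beta_t = np.linspace(1e-4, 0.02, T)
alpha_t = 1.0 - beta_t
bar_alpha_t = np.cumprod(alpha_t)

rng = np.random.default_rng(1)
x0 = rng.standard_normal((2, 1, 8, 8))
eps_true = rng.standard_normal(x0.shape)
t = 500
x_t = np.sqrt(bar_alpha_t[t]) * x0 + np.sqrt(1 - bar_alpha_t[t]) * eps_true

# Deterministic part of the update used in DDPMScheduler.sample (the z term is omitted).
mean_from_eps = (x_t - eps_true * beta_t[t] / np.sqrt(1 - bar_alpha_t[t])) / np.sqrt(alpha_t[t])

# Mean of the true posterior q(x_{t-1} | x_t, x_0).
posterior_mean = (np.sqrt(bar_alpha_t[t - 1]) * beta_t[t] / (1 - bar_alpha_t[t]) * x0
                  + np.sqrt(alpha_t[t]) * (1 - bar_alpha_t[t - 1]) / (1 - bar_alpha_t[t]) * x_t)

def guided_eps(eps_cond, eps_uncond, w):
    """Classifier-free guidance as written in sample: eps_cond + w * (eps_cond - eps_uncond)."""
    return eps_cond + w * (eps_cond - eps_uncond)

eps_c, eps_u, w = rng.standard_normal(4), rng.standard_normal(4), 0.5
# ```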

# %%
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model):
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())
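
# %% [markdown]
# `EMA.update_average` is an exponential moving average, $\theta_\text{ema} \leftarrow \beta\,\theta_\text{ema} + (1-\beta)\,\theta$. A rough way to read `ema_rate = 0.995` is as an averaging window of about $1/(1-\beta) = 200$ optimizer steps.
#
# The scalar sketch below uses pure Python, with one number standing in for a single parameter. It shows that after $n$ updates towards a constant target, the remaining gap shrinks by a factor of $\beta^n$. The zero start only makes that gap visible. In the class above, the EMA model starts as a copy of `nn_model`, so there is no such start-up bias.

# %%
# ```python
def ema_update(old: float, new: float, beta: float) -> float:
    """Same rule as EMA.update_average, applied to a single scalar."""
    return old * beta + (1 - beta) * new

ema_beta = 0.995                   # TrainConfig.ema_rate
ema_value, target = 0.0, 1.0       # zero start only to make the remaining gap visible
ema_history = []
for _ in range(200):
    ema_value = ema_update(ema_value, target, ema_beta)
    ema_history.append(ema_value)

effective_window = 1 / (1 - ema_beta)     # about 200 optimizer steps
remaining_gap = target - ema_history[-1]  # equals ema_beta ** 200, about 0.37
# ```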
        

# %%
@dataclass
class TrainConfig:
    ###########################
    ## hardcoding these here ##
    ###########################
    push_to_hub = True 
    hub_model_id = "Xsmos/ml21cm"
    hub_private_repo = False
    dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
    device = "cuda" if torch.cuda.is_available() else 'cpu'
    # device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else 'cpu'
    world_size = 1#torch.cuda.device_count()
    # repeat = 2

    dim = 2  # 2: (channel, HII_DIM, num_redshift) slices; 3: (channel, HII_DIM, HII_DIM, num_redshift) cubes
    stride = (2,4) if dim == 2 else (2,2,2)
    num_image = 1000  # other values tried: 800, 2000, 2560, 3000, 5000, 7000, 10000, 15000, 20000, 25600
    batch_size = 10  # other values tried: 1, 2, 20, 50, 100
    n_epoch = 50  # other values tried: 2, 4, 5, 10, 20, 25, 30, 100, 120
    HII_DIM = 64
    num_redshift = 512  # other values tried: 64, 128, 256 (256 once ran out of CUDA memory)
    channel = 1
    img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)

    ranges_dict = dict(
        params = {
            0: [4, 6], # ION_Tvir_MIN
            1: [10, 250], # HII_EFF_FACTOR
            },
        images = {
            0: [0, 80], # brightness_temp
            }
        )

    num_timesteps = 1000  # DDPM time steps (500 also tried)
    # n_sample = 24 # 64, the number of samples in sampling process
    n_param = 2  # number of conditioning parameters (ION_Tvir_MIN, HII_EFF_FACTOR)
    guide_w = 0  # strength of classifier-free guidance; -1 trains and samples unconditionally (also tried: 0.1, 0.5, 2)
    drop_prob = 0  # 0.28 also tried; only takes effect when guide_w != -1
    ema = False  # whether to use EMA
    ema_rate = 0.995

    # seed = 0
    # save_dir = './outputs/'

    save_period = max(1, n_epoch // 3)  # period (in epochs) for saving checkpoints; at least 1 so `(ep+1) % save_period` never divides by zero
    # general parameters for the name and logger    
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    lrate = 1e-4
    lr_warmup_steps = 0#5#00
    output_dir = "./outputs/"
    save_name = os.path.join(output_dir, 'model_state')
    # save_period = 1 #10 # the period of saving model
    # cond = True # if training using the conditional information
    # lr_decay = False #True# if using the learning rate decay
    resume = False # if resume from the trained checkpoints
    # params_single = torch.tensor([0.2,0.80000023])
    # params = torch.tile(params_single,(n_sample,1)).to(device)
    # params =  params
    # data_dir = './data' # data directory

    use_fp16 = False
    dtype = torch.float16 if use_fp16 else torch.float32
    mixed_precision = "fp16"  # passed to Accelerator in train(); set independently of use_fp16/dtype above
    gradient_accumulation_steps = 1

    # date = datetime.datetime.now().strftime("%m%d-%H%M")
    # run_name = f'{date}' # the unique name of each experiment

# config = TrainConfig()
# print("device =", config.device)
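
# %% [markdown]
# The attributes of `TrainConfig` carry no type annotations, so `@dataclass` does not turn them into dataclass fields. This has three consequences:
#
# - They are plain class attributes.
# - The generated `__init__` takes no arguments.
# - Derived values such as `save_period` are computed once, when the class body runs, so changing `n_epoch` on an instance later does not update them.
#
# The minimal sketch below uses a hypothetical `MiniConfig` to show these consequences. It also shows one way to override values safely: a subclass whose body recomputes the derived attribute.

# %%
# ```python
import dataclasses
from dataclasses import dataclass

@dataclass
class MiniConfig:                # hypothetical stand-in that mirrors the TrainConfig pattern
    n_epoch = 50                 # no annotation: a class attribute, not a dataclass field
    save_period = n_epoch // 3   # evaluated once, when the class body runs

cfg = MiniConfig()
cfg.n_epoch = 9                  # instance override: save_period is NOT recomputed

try:
    MiniConfig(n_epoch=9)        # the generated __init__ accepts no arguments
    init_accepts_kwargs = True
except TypeError:
    init_accepts_kwargs = False

class ShortRun(MiniConfig):      # a subclass body re-evaluates the derived value
    n_epoch = 9
    save_period = n_epoch // 3
# ```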

# %%
# import os
# print(os.cpu_count())
# print(len(os.sched_getaffinity(0)))
# import torch
# data = torch.randn((64,64))
# print(data.dtype)
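
# %% [markdown]
# The scratch cell above compares `os.cpu_count()` with `len(os.sched_getaffinity(0))`. `DDPM21CM.load()` uses the latter, divided by `world_size`, as `num_workers`. Two caveats apply:
#
# - `os.sched_getaffinity` is only available on some platforms (for example Linux, but not macOS or Windows).
# - `DataLoader` raises a `ValueError` for `persistent_workers=True` when `num_workers == 0`.
#
# Below is a small sketch of a hypothetical helper, `workers_per_process`, that falls back to `os.cpu_count()`. It lets the caller switch persistent workers on only when there are workers.

# %%
# ```python
import os

def workers_per_process(world_size: int) -> int:
    """Hypothetical helper: CPUs usable by this process, split evenly among `world_size` processes."""
    if hasattr(os, "sched_getaffinity"):
        available = len(os.sched_getaffinity(0))   # CPUs this process may run on (respects pinning)
    else:
        available = os.cpu_count() or 1            # fallback where sched_getaffinity is unavailable
    return available // max(world_size, 1)

num_workers = workers_per_process(world_size=1)
persistent_workers = num_workers > 0  # DataLoader rejects persistent workers when num_workers == 0
# ```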

# %%
# @dataclass

# def check_params_consistency(model, rank, world_size):
#     all_params_consistent = True
#     for name, param in model.named_parameters():
#         if param.requires_grad:
#             param_tensor = param.detach().clone()
#             dist.all_reduce(param_tensor, op=dist.ReduceOp.SUM)
#             param_tensor /= world_size

#             if not torch.allclose(param_tensor, param.detach()):
#                 all_params_consistent = False
#                 if rank == 0:
#                     print(f"Parameter {name} is not consistent across GPUs.")
#     if rank == 0 and all_params_consistent:
#         print("All model parameters are consistent across GPUs.")
#     return all_params_consistent

# def check_gradients_consistency(model, rank, world_size):
#     all_gradients_consistent = True
#     for name, param in model.named_parameters():
#         if param.requires_grad and param.grad is not None:
#             grad_tensor = param.grad.detach().clone()
#             dist.all_reduce(grad_tensor, op=dist.ReduceOp.SUM)
#             grad_tensor /= world_size

#             if not torch.allclose(grad_tensor, param.grad.detach()):
#                 all_gradients_consistent = False
#                 if rank == 0:
#                     print(f"Gradient {name} is not consistent across GPUs.")
#     if rank == 0 and all_gradients_consistent:
#         print("All model gradients are consistent across GPUs.")
#     return all_gradients_consistent
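
# %% [markdown]
# `DDPM21CM.rescale` below min-max normalizes each conditioning parameter with the physical ranges in `TrainConfig.ranges_dict['params']`: `ION_Tvir_MIN` in [4, 6] and `HII_EFF_FACTOR` in [10, 250]. It then maps [0, 1] onto the `to` interval. Casting to float first matters, because writing fractional results into an integer tensor in place truncates them.
#
# Below is a NumPy sketch of the same mapping, plus an inverse that can be used to label generated samples with physical values. The inverse, `unscale_np`, is an addition for illustration and is not part of the original code.

# %%
# ```python
import numpy as np

param_ranges = {0: [4, 6], 1: [10, 250]}   # TrainConfig.ranges_dict['params']: ION_Tvir_MIN, HII_EFF_FACTOR

def rescale_np(params, ranges, to=(0.0, 1.0)):
    """Min-max normalize each column with its own range, then map [0, 1] onto `to`."""
    value = np.atleast_2d(np.asarray(params, dtype=np.float64)).copy()  # float copy; 1-D input -> (1, n)
    for i in range(value.shape[1]):
        lo, hi = ranges[i]
        value[:, i] = (value[:, i] - lo) / (hi - lo)
    return value * (to[1] - to[0]) + to[0]

def unscale_np(normalized, ranges, frm=(0.0, 1.0)):
    """Inverse of rescale_np (an addition, not in the original code): map back to physical values."""
    value = (np.atleast_2d(np.asarray(normalized, dtype=np.float64)) - frm[0]) / (frm[1] - frm[0])
    for i in range(value.shape[1]):
        lo, hi = ranges[i]
        value[:, i] = value[:, i] * (hi - lo) + lo
    return value

default_params = [4.4, 131.341]            # default used by DDPM21CM.sample
normalized = rescale_np(default_params, param_ranges)
recovered = unscale_np(normalized, param_ranges)
# ```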

class DDPM21CM:
    def __init__(self, config):
        config.run_name = datetime.datetime.now().strftime("%m%d-%H%M") # the unique name of each experiment
        self.config = config
        self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, dtype=config.dtype)

        # initialize the unet and wrap it for distributed data-parallel training
        self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype)
        self.nn_model.train()
        self.nn_model.to(self.ddpm.device)
        self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])

        if config.resume and os.path.exists(config.resume):
            # map_location loads the checkpoint onto this process's device instead of the device it was saved from
            self.nn_model.module.load_state_dict(torch.load(config.resume, map_location=config.device)['unet_state_dict'])
            print(f"cuda:{torch.cuda.current_device()} resumed nn_model from {config.resume}")
        else:
            print(f"cuda:{torch.cuda.current_device()} initialized nn_model randomly")

        # number of parameters to be trained
        self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
        print(f" Number of parameters for nn_model: {self.number_of_params} ".center(120,'-'))

        # whether to use ema
        if config.ema:
            self.ema = EMA(config.ema_rate)
            if config.resume and os.path.exists(config.resume):
                self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype).to(config.device)
                self.ema_model.load_state_dict(torch.load(config.resume, map_location=config.device)['ema_unet_state_dict'])
                print(f"resumed ema_model from {config.resume}")
            else:
                # copy the underlying module rather than the DDP wrapper
                self.ema_model = copy.deepcopy(self.nn_model.module).eval().requires_grad_(False)

        self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)
        self.lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer=self.optimizer,
            num_warmup_steps=config.lr_warmup_steps,
            num_training_steps=int(config.num_image / config.world_size / config.batch_size * config.n_epoch),
        )

        self.ranges_dict = config.ranges_dict

    def load(self):
        # split the available CPU cores evenly among the processes
        num_workers = len(os.sched_getaffinity(0)) // self.config.world_size
        dataset = Dataset4h5(
            self.config.dataset_name,
            num_image=self.config.num_image,
            idx="random",  # or 'range'
            HII_DIM=self.config.HII_DIM,
            num_redshift=self.config.num_redshift,
            drop_prob=self.config.drop_prob,
            dim=self.config.dim,
            ranges_dict=self.ranges_dict,
            num_workers=num_workers,
            )
        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=num_workers > 0,  # DataLoader rejects persistent workers when num_workers == 0
            # sampler=DistributedSampler(dataset),
            )

        del dataset

    # def accelerate(self):

    def train(self):
        ###################
        ## training loop ##
        ###################
        self.load()
        self.accelerator = Accelerator(
            mixed_precision=self.config.mixed_precision,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            log_with="tensorboard",
            project_dir=os.path.join(self.config.output_dir, "logs"),
        )
        # only the first GPU creates the output directory, the Hub repository and the TensorBoard tracker
        if torch.cuda.current_device() == 0:
            if self.config.output_dir is not None:
                os.makedirs(self.config.output_dir, exist_ok=True)
            if self.config.push_to_hub:
                self.repo_id = create_repo(
                    repo_id=self.config.hub_model_id or Path(self.config.output_dir).name, exist_ok=True
                ).repo_id
            self.accelerator.init_trackers(f"{self.config.run_name}")

        self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \
            self.accelerator.prepare(
            self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler
            )

        global_step = 0
        for ep in range(self.config.n_epoch):
            # save() switches the model to eval mode, so switch it back at the start of every epoch
            self.nn_model.train()
            # self.dataloader.sampler.set_epoch(ep)

            pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
            pbar_train.set_description(f"cuda:{torch.cuda.current_device()}, Epoch {ep}")
            for i, (x, c) in enumerate(self.dataloader):
                with self.accelerator.accumulate(self.nn_model):
                    x = x.to(self.config.device)
                    # forward process: noisy images, the noise that was added, and the sampled timesteps
                    xt, noise, ts = self.ddpm.add_noise(x)

                    if self.config.guide_w == -1:
                        noise_pred = self.nn_model(xt, ts)
                    else:
                        c = c.to(self.config.device)
                        noise_pred = self.nn_model(xt, ts, c)

                    # the network is trained to predict the added noise
                    loss = F.mse_loss(noise, noise_pred)
                    self.accelerator.backward(loss)
                    self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # ema update
                if self.config.ema:
                    self.ema.step_ema(self.ema_model, self.nn_model)

                pbar_train.update(1)
                logs = dict(
                    loss=loss.detach().item(),
                    lr=self.optimizer.param_groups[0]['lr'],
                    step=global_step
                )
                pbar_train.set_postfix(**logs)

                self.accelerator.log(logs, step=global_step)
                global_step += 1
            pbar_train.close()

            self.save(ep)
            # # check that parameters and gradients are consistent across GPUs
            # rank = torch.cuda.current_device()
            # params_consistent = check_params_consistency(self.nn_model, rank, self.config.world_size)
            # gradients_consistent = check_gradients_consistency(self.nn_model, rank, self.config.world_size)
            # # if any consistency check fails, print a warning on every rank
            # if not (params_consistent and gradients_consistent):
            #     print(f"Rank {rank}: Parameter or gradient inconsistency detected.")

        del self.nn_model
        if self.config.ema:
            del self.ema_model
        torch.cuda.empty_cache()

    def save(self, ep):
        # save (and optionally upload) the model on the first GPU only, every save_period epochs and at the last epoch
        if torch.cuda.current_device() == 0:
            if ep == self.config.n_epoch-1 or (ep+1) % self.config.save_period == 0:
                self.nn_model.eval()
                with torch.no_grad():
                    if self.config.save_name:
                        model_state = {
                            'epoch': ep,
                            'unet_state_dict': self.nn_model.module.state_dict(),
                            # 'ema_unet_state_dict': self.ema_model.state_dict(),
                            }
                        save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-epoch{ep}"
                        torch.save(model_state, save_name)
                        print(f'cuda:{torch.cuda.current_device()} saved model at ' + save_name)
                    # upload after saving so that the checkpoint written above is included
                    if self.config.push_to_hub:
                        upload_folder(
                            repo_id = self.repo_id,
                            folder_path = ".",
                            commit_message = f"{self.config.run_name}",
                            ignore_patterns = ["step_*", "epoch_*", "*.npy", "__pycache__"],
                            )

    def rescale(self, params, ranges, to: list):
        """Min-max normalize each parameter column with its physical range, then map [0, 1] onto `to`."""
        # work on a float copy: writing fractional values into an integer tensor in place would truncate them
        value = params.clone().to(torch.float32)

        if value.ndim == 1:
            value = value.view(-1, len(value))  # (n_param,) -> (1, n_param)

        for i in range(value.shape[1]):
            value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])
        value = value * (to[1]-to[0]) + to[0]
        return value

    def sample(self, params:torch.tensor=None, num_new_img_per_gpu=192, ema=False, entire=False, save=True):
        """Generate `num_new_img_per_gpu` images on this GPU for one set of physical
        parameters (ION_Tvir_MIN, HII_EFF_FACTOR); defaults to [4.4, 131.341]."""
        if params is None:
            params = torch.tensor([4.4, 131.341])
        params_backup = params.cpu().numpy().copy()  # physical values, used in the output file name
        params_normalized = self.rescale(params, self.ranges_dict['params'], to=[0,1])

        print(f"cuda:{torch.cuda.current_device()} sampling {num_new_img_per_gpu} images with normalized params = {params_normalized}")
        params_normalized = params_normalized.repeat(num_new_img_per_gpu,1)
        assert params_normalized.dim() == 2, "params_normalized must be a 2D torch.tensor"

        self.nn_model.eval()

        # without no_grad, autograd keeps the graph of every timestep and CUDA runs out of memory
        with torch.no_grad():
            x_last, x_entire = self.ddpm.sample(
                nn_model=self.nn_model, 
                params=params_normalized.to(self.config.device), 
                device=self.config.device, 
                guide_w=self.config.guide_w
                )

        if save:
            savetime = datetime.datetime.now().strftime("%m%d-%H%M")
            base = os.path.join(
                self.config.output_dir,
                f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}"
                f"-device{torch.cuda.current_device()}-{savetime}{'ema' if ema else ''}",
            )
            print(f"saving {base}.npy ...")
            np.save(base + ".npy", x_last)

            if entire:
                print(f"saving {base}_entire.npy ...")
                np.save(base + "_entire.npy", x_entire)
        return x_last
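
# Sketch (hypothetical helpers, not part of DDPM21CM): a vectorized variant of
# DDPM21CM.rescale above plus its inverse. It assumes `ranges` holds one
# (min, max) pair per parameter column, either as a sequence or as a dict keyed
# by column index, like self.ranges_dict['params'] appears to be. Broadcasting
# replaces the per-column loop, and unscale_params maps normalized conditioning
# values back to physical units (Tvir, zeta).
import torch


def _ranges_tensor(ranges, like: torch.Tensor) -> torch.Tensor:
    """Return ranges as an (n_param, 2) tensor matching `like`'s dtype and device."""
    if isinstance(ranges, dict):  # e.g. {0: (min0, max0), 1: (min1, max1)}
        ranges = [ranges[k] for k in sorted(ranges)]
    return torch.as_tensor(ranges, dtype=like.dtype, device=like.device)


def _as_2d_float(x: torch.Tensor) -> torch.Tensor:
    """Copy `x` as a floating-point tensor of shape (batch, n_param)."""
    x = x.detach().clone()
    if not x.is_floating_point():
        x = x.float()
    return x.unsqueeze(0) if x.ndim == 1 else x


def rescale_params(params: torch.Tensor, ranges, to=(0.0, 1.0)) -> torch.Tensor:
    """Map each column from its (min, max) in `ranges` onto [to[0], to[1]]."""
    value = _as_2d_float(params)
    r = _ranges_tensor(ranges, value)
    lo, hi = r[:, 0], r[:, 1]  # each of shape (n_param,), broadcast over rows
    return (value - lo) / (hi - lo) * (to[1] - to[0]) + to[0]


def unscale_params(value: torch.Tensor, ranges, frm=(0.0, 1.0)) -> torch.Tensor:
    """Inverse of rescale_params: map from [frm[0], frm[1]] back to each column's range."""
    value = _as_2d_float(value)
    r = _ranges_tensor(ranges, value)
    lo, hi = r[:, 0], r[:, 1]
    return (value - frm[0]) / (frm[1] - frm[0]) * (hi - lo) + lo
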
# %%

num_train_image_list = [6000]  # alternative values: [60], [8000], [1000], [100]

def train(rank, world_size):
    """Per-process DDP entry point launched by mp.spawn (one process per GPU)."""
    ddp_setup(rank, world_size)
    config = TrainConfig()
    config.device = f"cuda:{rank}"
    config.world_size = world_size

    # alternative values: [3200], [200], [1600, 3200, 6400, 12800, 25600]
    for num_image in num_train_image_list:
        config.num_image = num_image
        ddpm21cm = DDPM21CM(config)
        print(f"run_name = {ddpm21cm.config.run_name}")
        ddpm21cm.train()

    # tear down the process group once, after all runs; destroying it inside
    # the loop would break every run after the first
    destroy_process_group()
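

# Sketch (hypothetical helper): DDPM21CM.train saves checkpoints as
#   config.save_name + "-N{num_image}-device_count{world_size}-epoch{ep}"
# while the sampling block below rebuilds the same name by hand for
# config.resume. Deriving both from one function keeps them in sync. This
# assumes config.save_name is "./outputs/model_state", which is what the
# hard-coded resume path implies.
def checkpoint_path(save_name: str, num_image: int, world_size: int, epoch: int) -> str:
    """Return the checkpoint file name written by DDPM21CM.train for one epoch."""
    return f"{save_name}-N{num_image}-device_count{world_size}-epoch{epoch}"

# Usage (sketch): config.resume = checkpoint_path(config.save_name, config.num_image,
#                                                 world_size, config.n_epoch - 1)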

if __name__ == "__main__":  # training; append "and False" to the condition to skip it
    world_size = torch.cuda.device_count()
    print(f" training, world_size = {world_size} ".center(120, '-'))
    # mp.spawn calls train(rank, world_size) in world_size processes, rank = 0..world_size-1
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


# %%

# def generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params):
#     # samples = []
#     for _ in range(num_new_img_per_gpu // max_num_img_per_gpu):    
#         sample = ddpm21cm.sample(
#             params=params, 
#             num_new_img_per_gpu=max_num_img_per_gpu
#             )

#         print(f"cuda:{torch.cuda.current_device()} generated sample of shape: {sample.shape}")

#         # samples.append(sample)
#         # ddpm21cm.sample(params=torch.tensor((5.6, 19.037)), num_new_img_per_gpu=max_num_img_per_gpu)
#         # ddpm21cm.sample(params=torch.tensor((4.699, 30)), num_new_img_per_gpu=max_num_img_per_gpu)
#         # ddpm21cm.sample(params=torch.tensor((5.477, 200)), num_new_img_per_gpu=max_num_img_per_gpu)
#         # ddpm21cm.sample(params=torch.tensor((4.8, 131.341)), num_new_img_per_gpu=max_num_img_per_gpu)
#     # samples = np.concatenate(samples, axis=0)

#     # samples_list = [np.empty_like(samples) for _ in range(world_size)]
#     # dist.all_gather_object(samples_list, samples)

#     # if rank == 0:
#     #     all_samples = np.concatenate(samples_list, axis=0)
#     #     return all_samples
#     # else:
#     #     return None

def generate_samples(rank, world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, params):
    """Per-process sampling entry point for mp.spawn; each rank saves its own .npy files."""
    ddp_setup(rank, world_size)
    config.device = f"cuda:{rank}"  # one GPU per rank, matching train()
    ddpm21cm = DDPM21CM(config)

    # sample in chunks of max_num_img_per_gpu to limit the per-call batch size
    for _ in range(num_new_img_per_gpu // max_num_img_per_gpu):
        sample = ddpm21cm.sample(
            params=params,
            num_new_img_per_gpu=max_num_img_per_gpu
            )
        print(f"cuda:{torch.cuda.current_device()} generated sample of shape: {sample.shape}")

    # return_dict is currently unused: samples are only written to disk by sample()
    dist.destroy_process_group()
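

# Sketch (hypothetical helper): generate_samples() runs
# num_new_img_per_gpu // max_num_img_per_gpu calls, so any remainder is
# silently skipped (e.g. 45 requested with a cap of 20 yields only 40 images).
# Splitting the request into capped chunks that include the remainder avoids that.
def batch_sizes(total: int, max_per_batch: int) -> list:
    """Split `total` images into batches of at most `max_per_batch` images each."""
    if max_per_batch <= 0:
        raise ValueError("max_per_batch must be positive")
    full, rest = divmod(total, max_per_batch)
    return [max_per_batch] * full + ([rest] if rest else [])

# Usage (sketch):
#     for n in batch_sizes(num_new_img_per_gpu, max_num_img_per_gpu):
#         ddpm21cm.sample(params=params, num_new_img_per_gpu=n)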


if __name__ == "__main__":  # sampling; runs after the training block above
    world_size = torch.cuda.device_count()
    num_new_img_per_gpu = 200  # images generated per GPU for each parameter pair
    max_num_img_per_gpu = 20   # images per ddpm21cm.sample() call

    config = TrainConfig()
    config.world_size = world_size

    # (Tvir, zeta) pairs to sample
    params_pairs = [
        (4.4, 131.341),
        (5.6, 19.037),
        (4.699, 30),
        (5.477, 200),
        (4.8, 131.341),
    ]

    # one Manager for the whole run instead of a new server process per iteration
    manager = mp.Manager()
    return_dict = manager.dict()

    for num_image in num_train_image_list:
        config.num_image = num_image
        # final-epoch checkpoint from train() (assumes config.save_name == "./outputs/model_state");
        # a single-GPU checkpoint would end in "-device_count1-epoch..." instead
        config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"

        for params in params_pairs:
            print(f" sampling for {params}, world_size = {world_size} ".center(120, '-'))
            mp.spawn(
                generate_samples,
                args=(world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, torch.tensor(params)),
                nprocs=world_size,
                join=True,
            )


# %%
# ls -lth outputs | head
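
# Sketch (hypothetical helper): DDPM21CM.sample() formats params_backup, a
# float32 NumPy array, straight into the file name, which yields names such as
# "Tvir4.400000095367432-zeta131.34100341796875-..." (see the commented np.load
# further down). Formatting with a fixed number of significant digits gives the
# short form instead. Switching to it would change the names of newly written
# files, so any loader would need to build names the same way.
import numpy as np


def param_tag(tvir, zeta, digits: int = 7) -> str:
    """Return a compact tag such as 'Tvir4.4-zeta131.341'.

    np.asarray accepts Python floats, NumPy scalars and CPU torch scalars alike.
    """
    t = float(np.asarray(tvir))
    z = float(np.asarray(zeta))
    return f"Tvir{t:.{digits}g}-zeta{z:.{digits}g}"
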

# # %%
# def plot_grid(samples, c=None, row=1, col=2):
#     print("samples.shape =", samples.shape)
#     for j in range(samples.shape[4]):
#         plt.figure(figsize = (12,6), dpi=400)
#         for i in range(len(samples)):
#             plt.subplot(row,col,i+1)
#             plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)
#             plt.xticks([])
#             plt.yticks([])
#         # plt.suptitle(f"ION_Tvir_MIN = {c[0][0]}, HII_EFF_FACTOR = {c[0][1]}")
#             # plt.show()
#         # plt.suptitle('simulations')
#         plt.tight_layout()
#         plt.subplots_adjust(wspace=0, hspace=0)
#         plt.savefig(f"test3D-{j:03d}.png")
#         plt.close()
#         # plt.show()
    
# data = np.load("outputs/Tvir4.400000095367432-zeta131.34100341796875-N1000.npy")
# # print(data.shape)
# plot_grid(data)
# plt.imshow(data)

# %%
# config = TrainConfig()
# def plot(filename, row=4, col=6):
#     samples = np.load(filename)
#     params = filename.split('guide_w')[-1][:-4]
#     print("plotting", samples.shape, params)
#     plt.figure(figsize = (8,8))
#     for i in range(24):
#         plt.subplot(row,col,i+1)
#         plt.imshow(samples[i,0,:,:], cmap='gray')#, vmin=-1, vmax=1)
#         plt.xticks([])
#         plt.yticks([])
#         # plt.show()
#     plt.suptitle(params)
#     plt.tight_layout()
#     plt.subplots_adjust(wspace=0, hspace=0) 
#     plt.show()
#     # plt.savefig('outputs/'+params+'.png')
#     # plt.close()
#     # plt.imshow(images[0,0])
#     # plt.show()

# %%