Xsmos committed on
Commit
d1e26d8
·
verified ·
1 Parent(s): aeab678

0718-2225
Files changed (2)
  1. backup_diffusion.ipynb +953 -0
  2. diffusion.py +7 -18
backup_diffusion.ipynb ADDED
@@ -0,0 +1,953 @@
+ {
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Adapt ContextUnet and the related code so that it first works for the 2D case, compare it against diffusers.Unet2DModel and optimize, then rewrite it for the 3D case.\n",
+ "- After trying diffusers' Unet2DModel, the loss dropped from 0.3 to 0.2 but is still high, which shows there are problems outside Unet2DModel that can be optimized.\n",
+ "- After switching to diffusers' DDPMScheduler and DDPMPipeline, the loss fell below 0.1, sometimes as low as 0.004, so the problems in my code lie mainly in the DDPM part. The DDPMScheduler part is short and seems fine, so the issue should be some piece of code in DDPMPipeline that mine lacks.\n",
+ "- I had a typo in my DDPMScheduler that kept beta_t very small; after fixing it the loss drops from 0.2 to 0.02 and stays below 0.1.\n",
+ "- diffusers' DDPMScheduler seems to work somewhat better: its loss is always a bit lower than my DDPMScheduler's. At epoch 19, the former's loss is about 0.02 and the latter's about 0.07. It also supports adding noise to 3D images, so I might as well reuse their wheel, but I want to know why my loss is higher.\n",
+ "- I realized their DDPMScheduler's sample function does not accept my input parameters, so in the end I still need my own DDPMScheduler. I can, however, use theirs first to debug my ContextUnet.\n",
+ "- I need to extend my ContextUnet to handle images of different dimensionalities; after all, I also need to compare with the original paper before extending to the 3D case.\n",
+ "- I have converted my ContextUnet to 2D mode: compared with diffusers.Unet2DModel's loss of 0.037, my Unet's loss is 0.07. Meanwhile the images my Unet generates look strange, so my Unet has problems too. I need to roll the code back to the original Unet and find out where the problem is.\n",
+ "- I limited the number of pixels along the redshift direction to 64 to compare the two Unets. Comparison:\\\n",
+ "Unet2DModel loss: 0.03, 0.0655, 0.05, 0.02, 0.05\\\n",
+ "ContextUnet loss: 0.1, 0.16, 0.1, 0.2186, 0.06\n",
+ "- I rolled ContextUnet back to the original author's version; the loss is 0.05 and the output images look good. My main change was reverting to his original normalization function, which also has a swish parameter. When I have time I should investigate exactly what affected the training result. I also found that for clean, separate curves in TensorBoard, the runs need to be placed in different folders.\n",
+ "- Verified: GroupNorm works better than BatchNorm.\n",
+ "- Extended to accept different dimensionalities.\n",
+ "- Merged the cond, guide_w, drop_out parameters.\n",
+ "- The generated 21cm images are not dark enough where they should be dark; switching to MNIST digit images seems to avoid the problem.\n",
+ "- When generating MNIST digits with the diffusion model, I found that although the generated data also contains negative values such as -0.1, the plotted images are ideally black. The data distribution does not differ much from that of the 21cm results; I now plan to roll the code back to the 21cm case.\n",
+ "- I unified the ddpm21cm module so it handles both training and sample generation, but there is currently a bug: sampling always runs CUDA out of memory, whereas resuming the model separately and sampling does not.\n",
+ "- Solved: the problem was that I forgot to write with torch.no_grad():\n",
+ "- Next is to generate 800 lightcones, and meanwhile study how to compute the global signal and the power spectrum.\n",
+ "- When the number of training images reaches 5000, the generated images closely resemble the test data.\n",
+ "- It takes 62 mins to generate 8 images with shape (64,64,64), which is even slower than the simulation, which takes ~5 mins per image. Besides, the batch_size during training and the number of images to be generated are limited to 2 and 8, respectively.\n",
+ "- The slowness can be addressed with multiple GPUs, and the limit on the number of images with mixed precision and multiple GPUs.\n",
+ "- In addition, DDPM's outputs can look better than those of computation-intensive simulations. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dataclasses import dataclass\n",
+ "import h5py\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "# from datasets import Dataset\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import random\n",
+ "# from abc import ABC, abstractmethod\n",
+ "import torch.nn.functional as F\n",
+ "import math\n",
+ "# from PIL import Image\n",
+ "import os\n",
+ "from torch.utils.tensorboard import SummaryWriter\n",
+ "import copy\n",
+ "from tqdm.auto import tqdm\n",
+ "# from torchvision import transforms\n",
+ "# from diffusers import UNet2DModel#, UNet3DConditionModel\n",
+ "# from diffusers import DDPMScheduler\n",
+ "from diffusers.utils import make_image_grid\n",
+ "import datetime\n",
+ "from pathlib import Path\n",
+ "from diffusers.optimization import get_cosine_schedule_with_warmup\n",
+ "from accelerate import notebook_launcher, Accelerator\n",
+ "from huggingface_hub import create_repo, upload_folder\n",
+ "from load_h5 import Dataset4h5\n",
+ "\n",
+ "from context_unet import ContextUnet\n",
+ "\n",
+ "from huggingface_hub import notebook_login\n",
+ "\n",
+ "import torch.multiprocessing as mp\n",
+ "from torch.utils.data.distributed import DistributedSampler\n",
+ "from torch.nn.parallel import DistributedDataParallel as DDP\n",
+ "from torch.distributed import init_process_group, destroy_process_group"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def ddp_setup(rank: int, world_size: int):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " rank: Unique identifier of each process\n",
+ " world_size: Total number of processes\n",
+ " \"\"\"\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ " os.environ[\"MASTER_PORT\"] = \"12355\"\n",
+ " torch.cuda.set_device(rank)\n",
+ " init_process_group(backend=\"nccl\", rank=rank, world_size=world_size)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9bbf7e9db9ce426d9c59d6f6d8e8df29",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "notebook_login()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Add noise:\n",
+ "\n",
+ "\\begin{align*}\n",
+ "x_t &\\sim \\mathcal N\\left(\\sqrt{1-\\beta_t}\\ x_{t-1},\\ \\beta_t \\right) \\\\\n",
+ "x_t &\\equiv \\sqrt{1-\\beta_t}\\ x_{t-1} + \\sqrt{\\beta_t}\\ \\epsilon\\\\\n",
+ "\\epsilon &\\sim \\mathcal N(0,1)\\\\\n",
+ "\\alpha_t & \\equiv 1 - \\beta_t\\\\\n",
+ "& ...\\\\\n",
+ "x_t &= \\sqrt{\\bar {\\alpha_t}} x_0 + \\epsilon\\ \\sqrt{1 - \\bar{\\alpha_t}}\\\\\n",
+ "\\bar {\\alpha_t} &\\equiv \\prod_{i=1}^t \\alpha_i\\\\\n",
+ "&= \\exp\\left({\\ln{\\prod_{i=1}^t \\alpha_i}}\\right)\\\\\n",
+ "&= \\exp\\left({\\sum_{i=1}^t\\ln{ \\alpha_i}}\\right)\n",
+ "\\end{align*}"
+ ]
+ },
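The closed-form jump from x_0 to x_t in the last equations above can be checked numerically in a few lines; a minimal standalone sketch (the beta range (1e-4, 0.02) and 1000 steps match the values used later in this notebook, while the toy batch shape is arbitrary):

```python
import torch

# Linear beta schedule, as in the scheduler class
num_timesteps = 1000
beta_t = torch.linspace(1e-4, 0.02, num_timesteps)
alpha_t = 1 - beta_t
bar_alpha_t = torch.cumprod(alpha_t, dim=0)  # \bar{alpha}_t = prod_i alpha_i

# Jump straight from x_0 to x_t without iterating t times
x0 = torch.randn(4, 1, 28, 28)  # a toy batch standing in for clean images
t = 500
noise = torch.randn_like(x0)
x_t = torch.sqrt(bar_alpha_t[t]) * x0 + torch.sqrt(1 - bar_alpha_t[t]) * noise

print(x_t.shape)  # torch.Size([4, 1, 28, 28])
```

Since bar_alpha_t decays monotonically toward 0, x_t interpolates from nearly-clean data at small t to nearly-pure noise at large t.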
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class DDPMScheduler(nn.Module):\n",
+ " def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu'):\n",
+ " super().__init__()\n",
+ " \n",
+ " beta_1, beta_T = betas\n",
+ " assert 0 < beta_1 <= beta_T <= 1, \"ensure 0 < beta_1 <= beta_T <= 1\"\n",
+ " self.device = device\n",
+ " self.num_timesteps = num_timesteps\n",
+ " self.img_shape = img_shape\n",
+ " self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1\n",
+ " self.beta_t = self.beta_t.to(self.device)\n",
+ "\n",
+ " # self.drop_prob = drop_prob\n",
+ " # self.cond = cond\n",
+ " self.alpha_t = 1 - self.beta_t\n",
+ " # self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))\n",
+ " self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)\n",
+ "\n",
+ " def add_noise(self, clean_images):\n",
+ " shape = clean_images.shape\n",
+ " expand = torch.ones(len(shape)-1, dtype=int)\n",
+ " # ts_expand = ts.view(ts.shape[0], *expand.tolist())\n",
+ " # expand = [1 for i in range(len(shape)-1)]\n",
+ "\n",
+ " noise = torch.randn_like(clean_images).to(self.device)\n",
+ " ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)\n",
+ " \n",
+ " # test_expand = test.view(test.shape[0],*expand)\n",
+ " # extend_dim = [None for i in range(shape.dim()-1)]\n",
+ " noisy_images = (\n",
+ " clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
+ " + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
+ " )\n",
+ " # print(x_t.shape)\n",
+ "\n",
+ " return noisy_images, noise, ts\n",
+ "\n",
+ " def sample(self, nn_model, params, device, guide_w = 0):\n",
+ " n_sample = len(params) #params.shape[0]\n",
+ " # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
+ " x_i = torch.randn(n_sample, *self.img_shape).to(device)\n",
+ " # print(\"x_i.shape =\", x_i.shape)\n",
+ " if guide_w != -1:\n",
+ " c_i = params\n",
+ " uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)\n",
+ " # uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)\n",
+ " # uncond_tokens = uncond_tokens.repeat(int(n_sample),1)\n",
+ " c_i = torch.cat((c_i, uncond_tokens), 0)\n",
+ "\n",
+ " x_i_entire = [] # keep track of generated steps in case want to plot something\n",
+ " # print(\"self.num_timesteps =\", self.num_timesteps)\n",
+ " # for i in range(self.num_timesteps, 0, -1):\n",
+ " # print(f'sampling!!!')\n",
+ " pbar_sample = tqdm(total=self.num_timesteps)\n",
+ " pbar_sample.set_description(\"Sampling\")\n",
+ " for i in reversed(range(0, self.num_timesteps)):\n",
+ " # print(f'sampling timestep {i:4d}',end='\\r')\n",
+ " t_is = torch.tensor([i]).to(device)\n",
+ " t_is = t_is.repeat(n_sample)\n",
+ "\n",
+ " z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else 0\n",
+ "\n",
+ " if guide_w == -1:\n",
+ " # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
+ " eps = nn_model(x_i, t_is)\n",
+ " else:\n",
+ " # double batch\n",
+ " x_i = x_i.repeat(2, *torch.ones(len(self.img_shape), dtype=int).tolist())\n",
+ " t_is = t_is.repeat(2)\n",
+ "\n",
+ " # split predictions and compute weighting\n",
+ " # print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\n",
+ " eps = nn_model(x_i, t_is, c_i)\n",
+ " eps1 = eps[:n_sample]\n",
+ " eps2 = eps[n_sample:]\n",
+ " eps = eps1 + guide_w*(eps1 - eps2)\n",
+ " # eps = (1+guide_w)*eps1 - guide_w*eps2\n",
+ " x_i = x_i[:n_sample]\n",
+ " \n",
+ " # print(\"x_i.shape =\", x_i.shape)\n",
+ " x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
+ " \n",
+ " pbar_sample.update(1)\n",
+ " # pbar_sample.set_postfix(step=i)\n",
+ " \n",
+ " # store only part of the intermediate steps\n",
+ " if i%20==0:# or i==0:# or i<8:\n",
+ " x_i_entire.append(x_i.detach().cpu().numpy())\n",
+ " x_i = x_i.detach().cpu().numpy()\n",
+ " x_i_entire = np.array(x_i_entire)\n",
+ " return x_i, x_i_entire\n",
+ "\n",
+ "\n",
+ "# ddpm_scheduler = DDPMScheduler((1e-4,0.02),10)\n",
+ "# noisy_images, noise, ts = ddpm_scheduler.add_noise(images)"
+ ]
+ },
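The guidance arithmetic inside sample() above, where the doubled batch is split into conditional and unconditional halves, can be illustrated in isolation; a small sketch in which random tensors stand in for the network's noise predictions:

```python
import torch

n_sample, guide_w = 4, 0.5
# Stand-in for the doubled-batch prediction: first half conditioned on
# the parameter tokens, second half on zeroed (unconditional) tokens.
eps = torch.randn(2 * n_sample, 1, 8, 8)
eps_cond = eps[:n_sample]
eps_uncond = eps[n_sample:]

# Classifier-free guidance: push the prediction away from the
# unconditional estimate by a factor guide_w.
eps_guided = eps_cond + guide_w * (eps_cond - eps_uncond)

# Algebraically identical form used in the commented-out line above
eps_alt = (1 + guide_w) * eps_cond - guide_w * eps_uncond
print(torch.allclose(eps_guided, eps_alt, atol=1e-6))  # True
```

With guide_w = 0 this reduces to the plain conditional prediction, which matches the default used in the training config.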
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class EMA:\n",
+ " def __init__(self, beta):\n",
+ " super().__init__()\n",
+ " self.beta = beta\n",
+ " self.step = 0\n",
+ "\n",
+ " def update_model_average(self, ma_model, current_model):\n",
+ " for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):\n",
+ " old_weight, up_weight = ma_params.data, current_params.data\n",
+ " ma_params.data = self.update_average(old_weight, up_weight)\n",
+ "\n",
+ " def update_average(self, old, new):\n",
+ " if old is None:\n",
+ " return new\n",
+ " return old * self.beta + (1 - self.beta) * new\n",
+ "\n",
+ " def step_ema(self, ema_model, model):\n",
+ " self.update_model_average(ema_model, model)\n",
+ " self.step += 1\n",
+ "\n",
+ " def reset_parameters(self, ema_model, model):\n",
+ " ema_model.load_state_dict(model.state_dict())\n",
+ " "
+ ]
+ },
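The EMA update above keeps a shadow copy of the weights that moves only (1 - beta) of the way toward the live model after each step; a minimal standalone check of that blending rule, with a toy nn.Linear standing in for the Unet:

```python
import copy
import torch
import torch.nn as nn

beta = 0.995  # same rate as ema_rate in the training config

def update_average(old, new):
    # One EMA step: keep beta of the old weight, blend in (1 - beta) of the new
    return old * beta + (1 - beta) * new

model = nn.Linear(2, 2)
ema_model = copy.deepcopy(model).eval().requires_grad_(False)

with torch.no_grad():
    # Pretend an optimizer step shifted every live weight by 1.0
    model.weight += 1.0
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.data = update_average(ema_p.data, p.data)

# The EMA weights lag behind by beta of the shift
diff = (model.weight - ema_model.weight).abs().mean().item()
print(round(diff, 3))  # 0.995
```

Averaging over many steps is what smooths out the noisy per-batch updates and typically yields cleaner samples than the raw weights.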
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class TrainConfig:\n",
+ " ###########################\n",
+ " ## hardcoding these here ##\n",
+ " ###########################\n",
+ " push_to_hub = True\n",
+ " hub_model_id = \"Xsmos/ml21cm\"\n",
+ " hub_private_repo = False\n",
+ " dataset_name = \"/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5\"\n",
+ " device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ " # world_size = torch.cuda.device_count()\n",
+ " # repeat = 2\n",
+ "\n",
+ " # dim = 2\n",
+ " dim = 3\n",
+ " stride = (2,2) if dim == 2 else (2,2,1)\n",
+ " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
+ " batch_size = 2#2#50#20#2#100 # 10\n",
+ " n_epoch = 10#50#20#20#2#5#25 # 120\n",
+ " HII_DIM = 28#64\n",
+ " num_redshift = 4#128#64#512#256#256#64#512#128\n",
+ " channel = 1\n",
+ " img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
+ "\n",
+ " ranges_dict = dict(\n",
+ " params = {\n",
+ " 0: [4, 6], # ION_Tvir_MIN\n",
+ " 1: [10, 250], # HII_EFF_FACTOR\n",
+ " },\n",
+ " images = {\n",
+ " 0: [0, 80], # brightness_temp\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
+ " # n_sample = 24 # 64, the number of samples in sampling process\n",
+ " n_param = 2\n",
+ " guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",
+ " drop_prob = 0#0.28 # only takes effect when guide_w != -1\n",
+ " ema=True # whether to use ema\n",
+ " ema_rate=0.995\n",
+ "\n",
+ " # seed = 0\n",
+ " # save_dir = './outputs/'\n",
+ "\n",
+ " save_freq = 0#.1 # the period of sampling\n",
+ " # general parameters for the name and logger \n",
+ " # device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ " lrate = 1e-4\n",
+ " lr_warmup_steps = 0#5#00\n",
+ " output_dir = \"./outputs/\"\n",
+ " save_name = os.path.join(output_dir, 'model_state')\n",
+ " # save_freq = 1 #10 # the period of saving model\n",
+ " # cond = True # if training using the conditional information\n",
+ " # lr_decay = False #True# if using the learning rate decay\n",
+ " resume = save_name # if resume from the trained checkpoints\n",
+ " # params_single = torch.tensor([0.2,0.80000023])\n",
+ " # params = torch.tile(params_single,(n_sample,1)).to(device)\n",
+ " # params = params\n",
+ " # data_dir = './data' # data directory\n",
+ "\n",
+ "\n",
+ " mixed_precision = \"fp16\"\n",
+ " gradient_accumulation_steps = 1\n",
+ "\n",
+ " # date = datetime.datetime.now().strftime(\"%m%d-%H%M\")\n",
+ " # run_name = f'{date}' # the unique name of each experiment\n",
+ "\n",
+ "# config = TrainConfig()\n",
+ "# print(\"device =\", config.device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# import os\n",
+ "# print(os.cpu_count())\n",
+ "# print(len(os.sched_getaffinity(0)))\n",
+ "# import torch\n",
+ "# data = torch.randn((64,64))\n",
+ "# print(data.dtype)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @dataclass\n",
+ "class DDPM21CM:\n",
+ " def __init__(self, config):\n",
+ " # config = TrainConfig()\n",
+ " # date = datetime.datetime.now().strftime(\"%m%d-%H%M\")\n",
+ " config.run_name = datetime.datetime.now().strftime(\"%m%d-%H%M\") # the unique name of each experiment\n",
+ " self.config = config\n",
+ " # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
+ " # # self.shape_loaded = dataset.images.shape\n",
+ " # # print(\"shape_loaded =\", self.shape_loaded)\n",
+ " # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)\n",
+ " # del dataset\n",
+ " self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
+ "\n",
+ " # initialize the unet\n",
+ " self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)\n",
+ "\n",
+ " if config.resume and os.path.exists(config.resume):\n",
+ " # resume_file = os.path.join(config.output_dir, f\"{config.resume}\")\n",
+ " self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])\n",
+ " print(f\"resumed nn_model from {config.resume}\")\n",
+ " # nn_model = ContextUnet(n_param=1, image_size=28)\n",
+ " self.nn_model.train()\n",
+ " self.nn_model.to(self.ddpm.device)\n",
+ " # print(\"nn_model.device =\", ddpm.device)\n",
+ " # number of parameters to be trained\n",
+ " self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())\n",
+ " print(f\"Number of parameters for nn_model: {self.number_of_params}\")\n",
+ "\n",
+ " # whether to use ema\n",
+ " if config.ema:\n",
+ " self.ema = EMA(config.ema_rate)\n",
+ " if config.resume and os.path.exists(config.resume):\n",
+ " self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
+ " self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])\n",
+ " print(f\"resumed ema_model from {config.resume}\")\n",
+ " else:\n",
+ " self.ema_model = copy.deepcopy(self.nn_model).eval().requires_grad_(False)\n",
+ "\n",
+ " self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)\n",
+ " self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
+ " optimizer=self.optimizer,\n",
+ " num_warmup_steps=config.lr_warmup_steps,\n",
+ " num_training_steps=(int(config.num_image/config.batch_size) * config.n_epoch),\n",
+ " # num_training_steps=(len(self.dataloader) * config.n_epoch),\n",
+ " )\n",
+ "\n",
+ " self.ranges_dict = config.ranges_dict\n",
+ "\n",
+ " def load(self):\n",
+ " dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)\n",
+ " # self.shape_loaded = dataset.images.shape\n",
+ " # print(\"shape_loaded =\", self.shape_loaded)\n",
+ " self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True, num_workers=len(os.sched_getaffinity(0)), pin_memory=True)\n",
+ " # del dataset\n",
+ " # self.accelerate(self.config)\n",
+ " del dataset\n",
+ "\n",
+ " # def accelerate(self):\n",
+ "\n",
+ " def train(self):\n",
+ " ################### \n",
+ " ## training loop ##\n",
+ " ###################\n",
+ " # plot_unet = True\n",
+ "\n",
+ " self.load()\n",
+ " self.accelerator = Accelerator(\n",
+ " mixed_precision=self.config.mixed_precision,\n",
+ " gradient_accumulation_steps=self.config.gradient_accumulation_steps,\n",
+ " log_with=\"tensorboard\",\n",
+ " project_dir=os.path.join(self.config.output_dir, \"logs\"),\n",
+ " )\n",
+ " print(\"self.accelerator.is_main_process:\", self.accelerator.is_main_process)\n",
+ " if self.accelerator.is_main_process:\n",
+ " if self.config.output_dir is not None:\n",
+ " os.makedirs(self.config.output_dir, exist_ok=True)\n",
+ " if self.config.push_to_hub:\n",
+ " self.repo_id = create_repo(\n",
+ " repo_id=self.config.hub_model_id or Path(self.config.output_dir).name, exist_ok=True\n",
+ " ).repo_id\n",
+ " self.accelerator.init_trackers(f\"{self.config.run_name}\")\n",
+ "\n",
+ " self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \\\n",
+ " self.accelerator.prepare(\n",
+ " self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler\n",
+ " )\n",
+ " \n",
+ " global_step = 0\n",
+ " for ep in range(self.config.n_epoch):\n",
+ " self.ddpm.train()\n",
+ "\n",
+ " pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)\n",
+ " pbar_train.set_description(f\"Epoch {ep}\")\n",
+ " for i, (x, c) in enumerate(self.dataloader):\n",
+ " with self.accelerator.accumulate(self.nn_model):\n",
+ " x = x.to(self.config.device)\n",
+ " xt, noise, ts = self.ddpm.add_noise(x)\n",
+ " \n",
+ " if self.config.guide_w == -1:\n",
+ " noise_pred = self.nn_model(xt, ts)\n",
+ " else:\n",
+ " c = c.to(self.config.device)\n",
+ " noise_pred = self.nn_model(xt, ts, c)\n",
+ " \n",
+ " loss = F.mse_loss(noise, noise_pred)\n",
+ " self.accelerator.backward(loss)\n",
+ " self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)\n",
+ " self.optimizer.step()\n",
+ " self.lr_scheduler.step()\n",
+ " self.optimizer.zero_grad()\n",
+ "\n",
+ " # ema update\n",
+ " if self.config.ema:\n",
+ " self.ema.step_ema(self.ema_model, self.nn_model)\n",
+ "\n",
+ " pbar_train.update(1)\n",
+ " logs = dict(\n",
+ " loss=loss.detach().item(),\n",
+ " lr=self.optimizer.param_groups[0]['lr'],\n",
+ " step=global_step\n",
+ " )\n",
+ " pbar_train.set_postfix(**logs)\n",
+ "\n",
+ " self.accelerator.log(logs, step=global_step)\n",
+ " global_step += 1\n",
+ "\n",
+ " # if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\n",
+ " self.save(ep)\n",
+ "\n",
+ " del self.nn_model\n",
+ " if self.config.ema:\n",
+ " del self.ema_model\n",
+ " torch.cuda.empty_cache()\n",
+ "\n",
+ " def save(self, ep):\n",
+ " # save model\n",
+ " if self.accelerator.is_main_process:\n",
+ " if ep == self.config.n_epoch-1 or (ep+1)*self.config.save_freq==1:\n",
+ " self.nn_model.eval()\n",
+ " with torch.no_grad():\n",
+ " if self.config.push_to_hub:\n",
+ " upload_folder(\n",
+ " repo_id = self.repo_id,\n",
+ " folder_path = \".\",#config.output_dir,\n",
+ " commit_message = f\"{self.config.run_name}\",\n",
+ " ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\", \"__pycache__\"],\n",
+ " )\n",
+ " if self.config.save_name:\n",
+ " model_state = {\n",
+ " 'epoch': ep,\n",
+ " 'unet_state_dict': self.nn_model.state_dict(),\n",
+ " 'ema_unet_state_dict': self.ema_model.state_dict(),\n",
+ " }\n",
+ " torch.save(model_state, self.config.save_name+f\"-N{self.config.num_image}\")\n",
+ " print('saved model at ' + self.config.save_name+f\"-N{self.config.num_image}\")\n",
+ " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
+ "\n",
+ " # def rescale(self, value, type='params', to_ranges=[0,1]):\n",
+ " # for i, from_ranges in self.ranges_dict[type].items():\n",
+ " # value[i] = (value[i] - from_ranges[0])/(from_ranges[1]-from_ranges[0]) # normalize\n",
+ " # value[i] = \n",
+ " def rescale(self, value, ranges, to: list):\n",
+ " if value.ndim == 1:\n",
+ " value = value.view(-1,len(value))\n",
+ " \n",
+ " for i in range(np.shape(value)[1]):\n",
+ " value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])\n",
+ " # print(f\"i = {i}, value.min = {value[:,i].min()}, value.max = {value[:,i].max()}\")\n",
+ " value = value * (to[1]-to[0]) + to[0]\n",
+ " return value \n",
+ "\n",
+ " def sample(self, file, params:torch.tensor=None, repeat=192, ema=False, entire=False):\n",
+ " # n_sample = params.shape[0]\n",
+ " \n",
+ " if params is None:\n",
+ " params = torch.tensor([0.20000000000000018, 0.5055875000000001])\n",
+ " params_backup = params.numpy().copy()\n",
+ " else:\n",
+ " params_backup = params.numpy().copy()\n",
+ " params = self.rescale(params, self.ranges_dict['params'], to=[0,1])\n",
+ "\n",
+ " print(f\"sampling {repeat} images with normalized params = {params}\")\n",
+ " params = params.repeat(repeat,1)\n",
+ " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
+ " # print(\"params =\", params)\n",
+ " # print(\"len(params) =\", len(params))\n",
+ " # model = self.ema_model if ema else self.nn_model\n",
+ " # del self.ema_model, self.nn\n",
+ " # params = torch.tile(params, (n_sample,1)).to(device)\n",
+ "\n",
+ " nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)\n",
+ " if ema:\n",
+ " nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])\n",
+ " else:\n",
+ " nn_model.load_state_dict(torch.load(file)['unet_state_dict'])\n",
+ " print(f\"nn_model resumed from {file}\")\n",
+ " # nn_model = ContextUnet(n_param=1, image_size=28)\n",
+ " # nn_model.train()\n",
+ " nn_model.to(self.ddpm.device)\n",
+ " nn_model.eval()\n",
+ "\n",
+ " # self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
+ " # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
+ " # print(f\"resumed ema_model from {config.resume}\")\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " x_last, x_entire = self.ddpm.sample(\n",
+ " nn_model=nn_model, \n",
+ " params=params.to(self.config.device), \n",
+ " device=self.config.device, \n",
+ " guide_w=self.config.guide_w\n",
+ " )\n",
+ "\n",
+ " # np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else ''}.npy\"), x_last)\n",
+ " np.save(os.path.join(self.config.output_dir, f\"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}.npy\"), x_last)\n",
+ "\n",
+ " if entire:\n",
+ " # save the stored intermediate denoising steps as well\n",
+ " np.save(os.path.join(self.config.output_dir, f\"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}_entire.npy\"), x_entire)\n",
+ "# print(\"device =\", config.device)"
+ ]
+ },
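The rescale helper in the cell above is plain per-column min-max normalization; a standalone sketch of the same mapping (the ranges mirror ranges_dict['params'] from the training config):

```python
import torch

# Physical parameter ranges, as in ranges_dict['params']
ranges = [[4, 6], [10, 250]]  # ION_Tvir_MIN, HII_EFF_FACTOR

def rescale(value, ranges, to):
    # Map each parameter column from its physical range into the interval `to`
    if value.ndim == 1:
        value = value.view(-1, len(value))
    value = value.clone().float()
    for i in range(value.shape[1]):
        lo, hi = ranges[i]
        value[:, i] = (value[:, i] - lo) / (hi - lo)
    return value * (to[1] - to[0]) + to[0]

params = torch.tensor([[5.0, 130.0]])  # midpoint of both ranges
print(rescale(params, ranges, to=[0, 1]))  # tensor([[0.5000, 0.5000]])
```

Unlike the method above, this sketch clones its input, so the caller's tensor is left untouched; the in-class version relies on params_backup for that.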
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Traceback (most recent call last):\n",
+ " File \"<string>\", line 1, in <module>\n",
+ " File \"/storage/home/hcoda1/3/bxia34/.conda/envs/diffusers/lib/python3.9/multiprocessing/spawn.py\", line 116, in spawn_main\n",
+ " exitcode = _main(fd, parent_sentinel)\n",
+ " File \"/storage/home/hcoda1/3/bxia34/.conda/envs/diffusers/lib/python3.9/multiprocessing/spawn.py\", line 126, in _main\n",
+ " self = reduction.pickle.load(from_parent)\n",
+ "AttributeError: Can't get attribute 'single_main' on <module '__main__' (built-in)>\n"
+ ]
+ },
+ {
+ "ename": "ProcessExitedException",
+ "evalue": "process 0 terminated with exit code 1",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mProcessExitedException\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[12], line 21\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m__name__\u001b[39m \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m__main__\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 17\u001b[0m \u001b[39m# torch.multiprocessing.set_start_method(\"spawn\")\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[39m# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\u001b[39;00m\n\u001b[1;32m 19\u001b[0m world_size \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\u001b[39m#torch.cuda.device_count()\u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m mp\u001b[39m.\u001b[39;49mspawn(single_main, args\u001b[39m=\u001b[39;49m(world_size,), nprocs\u001b[39m=\u001b[39;49mworld_size)\n\u001b[1;32m 22\u001b[0m \u001b[39m# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')\u001b[39;00m\n",
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:240\u001b[0m, in \u001b[0;36mspawn\u001b[0;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[1;32m 236\u001b[0m msg \u001b[39m=\u001b[39m (\u001b[39m'\u001b[39m\u001b[39mThis method only supports start_method=spawn (got: \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m).\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m 237\u001b[0m \u001b[39m'\u001b[39m\u001b[39mTo use a different start_method use:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m 238\u001b[0m \u001b[39m'\u001b[39m\u001b[39m torch.multiprocessing.start_processes(...)\u001b[39m\u001b[39m'\u001b[39m \u001b[39m%\u001b[39m start_method)\n\u001b[1;32m 239\u001b[0m warnings\u001b[39m.\u001b[39mwarn(msg)\n\u001b[0;32m--> 240\u001b[0m \u001b[39mreturn\u001b[39;00m start_processes(fn, args, nprocs, join, daemon, start_method\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mspawn\u001b[39;49m\u001b[39m'\u001b[39;49m)\n",
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:198\u001b[0m, in \u001b[0;36mstart_processes\u001b[0;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[39mreturn\u001b[39;00m context\n\u001b[1;32m 197\u001b[0m \u001b[39m# Loop on join until it returns True or raises an exception.\u001b[39;00m\n\u001b[0;32m--> 198\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m context\u001b[39m.\u001b[39;49mjoin():\n\u001b[1;32m 199\u001b[0m \u001b[39mpass\u001b[39;00m\n",
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:149\u001b[0m, in \u001b[0;36mProcessContext.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[39mraise\u001b[39;00m ProcessExitedException(\n\u001b[1;32m 141\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mprocess \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with signal \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m 142\u001b[0m (error_index, name),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 146\u001b[0m signal_name\u001b[39m=\u001b[39mname\n\u001b[1;32m 147\u001b[0m )\n\u001b[1;32m 148\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 149\u001b[0m \u001b[39mraise\u001b[39;00m ProcessExitedException(\n\u001b[1;32m 150\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mprocess \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with exit code \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m 151\u001b[0m (error_index, exitcode),\n\u001b[1;32m 152\u001b[0m error_index\u001b[39m=\u001b[39merror_index,\n\u001b[1;32m 153\u001b[0m error_pid\u001b[39m=\u001b[39mfailed_process\u001b[39m.\u001b[39mpid,\n\u001b[1;32m 154\u001b[0m exit_code\u001b[39m=\u001b[39mexitcode\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 157\u001b[0m original_trace \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39merror_queues[error_index]\u001b[39m.\u001b[39mget()\n\u001b[1;32m 158\u001b[0m msg \u001b[39m=\u001b[39m \u001b[39m"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\n\u001b[39;00m\u001b[39m-- Process \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with the following error:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m error_index\n",
628
+ "\u001b[0;31mProcessExitedException\u001b[0m: process 0 terminated with exit code 1"
629
+ ]
630
+ }
631
+ ],
632
+ "source": [
633
+ "def single_main(rank, world_size):\n",
634
+ " config = TrainConfig()\n",
635
+ " ddp_setup(rank, world_size)\n",
636
+ " \n",
637
+ " num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]\n",
638
+ " for i, num_image in enumerate(num_image_list):\n",
639
+ " config.num_image = num_image\n",
640
+ " # config.world_size = world_size\n",
641
+ " \n",
642
+ " ddpm21cm = DDPM21CM(config)\n",
643
+ " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
644
+ " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
645
+ " ddpm21cm.train()\n",
646
+ "\n",
647
+ " \n",
648
+ "if __name__ == \"__main__\":\n",
649
+ " # torch.multiprocessing.set_start_method(\"spawn\")\n",
650
+ " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
651
+ " world_size = 1#torch.cuda.device_count()\n",
652
+ "\n",
653
+ " mp.spawn(single_main, args=(world_size,), nprocs=world_size)\n",
654
+ " # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
655
+ ]
656
+ },
657
+ {
658
+ "cell_type": "code",
659
+ "execution_count": null,
660
+ "metadata": {},
661
+ "outputs": [],
662
+ "source": [
663
+ "# torch.cuda.set_device(0)"
664
+ ]
665
+ },
666
+ {
667
+ "cell_type": "code",
668
+ "execution_count": null,
669
+ "metadata": {},
670
+ "outputs": [
671
+ {
672
+ "name": "stdout",
673
+ "output_type": "stream",
674
+ "text": [
675
+ "True\n",
676
+ "2\n",
677
+ "['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__path__', '__file__', '__cached__', '__builtins__', '__annotations__', 'contextlib', 'os', 'torch', 'Device', 'traceback', 'warnings', 'threading', 'List', 'Optional', 'Tuple', 'Union', 'Any', '_utils', '_get_device_index', '_dummy_type', 'classproperty', 'graphs', 'CUDAGraph', 'graph_pool_handle', 'graph', 'make_graphed_callables', 'is_current_stream_capturing', 'streams', 'ExternalStream', 'Stream', 'Event', '_device', '_cudart', '_initialized', '_tls', '_initialization_lock', '_queued_calls', '_is_in_bad_fork', '_device_t', '_LazySeedTracker', '_lazy_seed_tracker', '_CudaDeviceProperties', 'has_magma', 'has_half', 'default_generators', 'is_available', 'is_bf16_supported', '_sleep', '_check_capability', '_check_cubins', 'is_initialized', '_lazy_call', 'DeferredCudaCallError', 'init', '_lazy_init', 'cudart', 'cudaStatus', 'CudaError', 'check_error', 'device', 'device_of', 'set_device', 'get_device_name', 'get_device_capability', 'get_device_properties', 'can_device_access_peer', 'StreamContext', 'stream', 'set_stream', 'device_count', 'get_arch_list', 'get_gencode_flags', 'current_device', 'synchronize', 'ipc_collect', 'current_stream', 'default_stream', 'current_blas_handle', 'set_sync_debug_mode', 'get_sync_debug_mode', 'memory_usage', 'utilization', 'memory', 'caching_allocator_alloc', 'caching_allocator_delete', 'set_per_process_memory_fraction', 'empty_cache', 'memory_stats', 'memory_stats_as_nested_dict', 'reset_accumulated_memory_stats', 'reset_peak_memory_stats', 'reset_max_memory_allocated', 'reset_max_memory_cached', 'memory_allocated', 'max_memory_allocated', 'memory_reserved', 'max_memory_reserved', 'memory_cached', 'max_memory_cached', 'memory_snapshot', 'memory_summary', 'list_gpu_processes', 'mem_get_info', 'random', 'get_rng_state', 'get_rng_state_all', 'set_rng_state', 'set_rng_state_all', 'manual_seed', 'manual_seed_all', 'seed', 'seed_all', 'initial_seed', '_lazy_new', '_CudaBase', 'ByteStorage', 'DoubleStorage', 'FloatStorage', 'HalfStorage', 'LongStorage', 'IntStorage', 'ShortStorage', 'CharStorage', 'BoolStorage', 'BFloat16Storage', 'ComplexDoubleStorage', 'ComplexFloatStorage', 'sparse', 'profiler', 'nvtx', 'amp', 'jiterator', 'ByteTensor', 'CharTensor', 'DoubleTensor', 'FloatTensor', 'IntTensor', 'LongTensor', 'ShortTensor', 'HalfTensor', 'BoolTensor', 'BFloat16Tensor', 'nccl', '_get_device_properties']\n"
678
+ ]
679
+ }
680
+ ],
681
+ "source": [
682
+ "print(torch.cuda.is_available())\n",
683
+ "print(torch.cuda.device_count())\n",
684
+ "print(torch.cuda.__dir__())"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": 17,
690
+ "metadata": {},
691
+ "outputs": [
692
+ {
693
+ "name": "stdout",
694
+ "output_type": "stream",
695
+ "text": [
696
+ "True\n",
697
+ "<class 'torch.cuda.device'>\n",
698
+ "Quadro RTX 6000\n",
699
+ "0\n",
700
+ "(7, 5)\n",
701
+ "_CudaDeviceProperties(name='Quadro RTX 6000', major=7, minor=5, total_memory=24212MB, multi_processor_count=72)\n"
702
+ ]
703
+ }
704
+ ],
705
+ "source": [
706
+ "print(torch.cuda.is_initialized())\n",
707
+ "print(torch.cuda.device)\n",
708
+ "print(torch.cuda.get_device_name())\n",
709
+ "print(torch.cuda.current_device())\n",
710
+ "print(torch.cuda.get_device_capability())\n",
711
+ "print(torch.cuda.get_device_properties(torch.cuda.device))\n",
712
+ "# print('here')\n",
713
+ "# print(torch.cuda.memory_usage())\n",
714
+ "# print(torch.cuda.utilization())\n",
715
+ "# print(torch.cuda.memory())\n",
716
+ "# print('here')\n",
717
+ "# print(torch.cuda.memory_summary())"
718
+ ]
719
+ },
720
+ {
721
+ "attachments": {},
722
+ "cell_type": "markdown",
723
+ "metadata": {},
724
+ "source": [
725
+ "# Sampling"
726
+ ]
727
+ },
728
+ {
729
+ "cell_type": "code",
730
+ "execution_count": null,
731
+ "metadata": {},
732
+ "outputs": [],
733
+ "source": [
734
+ "if __name__ == \"__main__\":\n",
735
+ " # num_image_list = [1600,3200,6400,12800,25600]\n",
736
+ " num_image_list = [1000]\n",
737
+ " # num_image_list = [3200,6400,12800,25600]\n",
738
+ " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
739
+ " repeat = 2\n",
740
+ " config = TrainConfig()\n",
741
+ " for i, num_image in enumerate(num_image_list):\n",
742
+ " config.num_image = num_image\n",
743
+ " ddpm21cm = DDPM21CM(config)\n",
744
+ "\n",
745
+ " ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor([4.4, 131.341]), repeat=repeat)\n",
746
+ "\n",
747
+ " # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((5.6, 19.037)), repeat=repeat)\n",
748
+ "\n",
749
+ " # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.699, 30)), repeat=repeat)\n",
750
+ "\n",
751
+ " # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((5.477, 200)), repeat=repeat)\n",
752
+ "\n",
753
+ " # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.8, 131.341)), repeat=repeat)"
754
+ ]
755
+ },
756
+ {
757
+ "cell_type": "code",
758
+ "execution_count": 13,
759
+ "metadata": {},
760
+ "outputs": [],
761
+ "source": [
762
+ "# ls -lth outputs | head"
763
+ ]
764
+ },
765
+ {
766
+ "cell_type": "code",
767
+ "execution_count": null,
768
+ "metadata": {},
769
+ "outputs": [],
770
+ "source": [
771
+ "def plot_grid(samples, c=None, row=1, col=2):\n",
772
+ " print(\"samples.shape =\", samples.shape)\n",
773
+ " for j in range(samples.shape[4]):\n",
774
+ " plt.figure(figsize = (12,6), dpi=400)\n",
775
+ " for i in range(len(samples)):\n",
776
+ " plt.subplot(row,col,i+1)\n",
777
+ " plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)\n",
778
+ " plt.xticks([])\n",
779
+ " plt.yticks([])\n",
780
+ " # plt.suptitle(f\"ION_Tvir_MIN = {c[0][0]}, HII_EFF_FACTOR = {c[0][1]}\")\n",
781
+ " # plt.show()\n",
782
+ " # plt.suptitle('simulations')\n",
783
+ " plt.tight_layout()\n",
784
+ " plt.subplots_adjust(wspace=0, hspace=0)\n",
785
+ " plt.savefig(f\"test3D-{j:03d}.png\")\n",
786
+ " plt.close()\n",
787
+ " # plt.show()\n",
788
+ " \n",
789
+ "data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\")\n",
790
+ "# print(data.shape)\n",
791
+ "plot_grid(data)\n",
792
+ "# plt.imshow(data)"
793
+ ]
794
+ },
795
+ {
796
+ "cell_type": "code",
797
+ "execution_count": null,
798
+ "metadata": {},
799
+ "outputs": [],
800
+ "source": [
801
+ "# config = TrainConfig()\n",
802
+ "# def plot(filename, row=4, col=6):\n",
803
+ "# samples = np.load(filename)\n",
804
+ "# params = filename.split('guide_w')[-1][:-4]\n",
805
+ "# print(\"plotting\", samples.shape, params)\n",
806
+ "# plt.figure(figsize = (8,8))\n",
807
+ "# for i in range(24):\n",
808
+ "# plt.subplot(row,col,i+1)\n",
809
+ "# plt.imshow(samples[i,0,:,:], cmap='gray')#, vmin=-1, vmax=1)\n",
810
+ "# plt.xticks([])\n",
811
+ "# plt.yticks([])\n",
812
+ "# # plt.show()\n",
813
+ "# plt.suptitle(params)\n",
814
+ "# plt.tight_layout()\n",
815
+ "# plt.subplots_adjust(wspace=0, hspace=0) \n",
816
+ "# plt.show()\n",
817
+ "# # plt.savefig('outputs/'+params+'.png')\n",
818
+ "# # plt.close()\n",
819
+ "# # plt.imshow(images[0,0])\n",
820
+ "# # plt.show()"
821
+ ]
822
+ },
823
+ {
824
+ "cell_type": "code",
825
+ "execution_count": 1,
826
+ "metadata": {},
827
+ "outputs": [],
828
+ "source": [
829
+ "# import torch\n",
830
+ "# print(torch.__version__)"
831
+ ]
832
+ },
833
+ {
834
+ "cell_type": "code",
835
+ "execution_count": 9,
836
+ "metadata": {},
837
+ "outputs": [],
838
+ "source": [
839
+ "# import torch\n",
840
+ "# import os\n",
841
+ "\n",
842
+ "# def compare_models(num_gpus):\n",
843
+ "# model_states = []\n",
844
+ " \n",
845
+ "# for gpu_id in range(num_gpus):\n",
846
+ "# filename = f\"outputs/model_state-N40-device{gpu_id}\"\n",
847
+ "# if os.path.exists(filename):\n",
848
+ "# state_dict = torch.load(filename, map_location='cpu')\n",
849
+ "# model_states.append(state_dict)\n",
850
+ "# print(filename)\n",
851
+ "# else:\n",
852
+ "# print(f\"File {filename} not found!\")\n",
853
+ "# return False\n",
854
+ " \n",
855
+ "# # Compare all model state_dicts\n",
856
+ "# print(\"len(model_states) =\", len(model_states))\n",
857
+ "# base_state = model_states[0]\n",
858
+ "# for state in model_states[1:]:\n",
859
+ "# for key in base_state.keys():\n",
860
+ "# # print(key, base_state[key], state[key])\n",
861
+ "# print(key)\n",
862
+ "# print(\"epoch\", base_state['epoch'], state['epoch'])\n",
863
+ "\n",
864
+ "# print(base_state['unet_state_dict'].keys())\n",
865
+ "# for key in base_state['unet_state_dict']:\n",
866
+ "# # print(key)\n",
867
+ "# if not torch.equal(base_state['unet_state_dict'][key], state['unet_state_dict'][key]):\n",
868
+ "# print(\"different\")\n",
869
+ "# return \n",
870
+ "# # else:\n",
871
+ "# print(\"exactly same\")\n",
872
+ "\n",
873
+ "# # if key == 'epoch':\n",
874
+ "# # print(base_state[key], state[key])\n",
875
+ "# # else:\n",
876
+ "# # print(base_state[key], state[key])\n",
877
+ "# # if not torch.equal(base_state[key], state[key]):\n",
878
+ "# # # if not (base_state[key] == state[key]):\n",
879
+ "# # print(f\"Mismatch found in parameter {key}\")\n",
880
+ "# # return False\n",
881
+ " \n",
882
+ "# # print(\"All models are identical!\")\n",
883
+ "# # return True\n",
884
+ "\n",
885
+ "# if __name__ == \"__main__\":\n",
886
+ "# # epoch_to_check = 0 # specify the epoch you want to check\n",
887
+ "# num_gpus = torch.cuda.device_count() # specify the number of GPUs used in training\n",
888
+ "# compare_models(num_gpus)"
889
+ ]
890
+ },
891
+ {
892
+ "cell_type": "code",
893
+ "execution_count": 6,
894
+ "metadata": {},
895
+ "outputs": [],
896
+ "source": [
897
+ "import numpy as np\n",
898
+ "test = np.random.normal(0,1,(800,1,64,64,512))"
899
+ ]
900
+ },
901
+ {
902
+ "cell_type": "code",
903
+ "execution_count": 7,
904
+ "metadata": {},
905
+ "outputs": [
906
+ {
907
+ "data": {
908
+ "text/plain": [
909
+ "12.5"
910
+ ]
911
+ },
912
+ "execution_count": 7,
913
+ "metadata": {},
914
+ "output_type": "execute_result"
915
+ }
916
+ ],
917
+ "source": [
918
+ "(test.itemsize*test.size) / 1024/1024/1024"
919
+ ]
920
+ },
921
+ {
922
+ "cell_type": "code",
923
+ "execution_count": 8,
924
+ "metadata": {},
925
+ "outputs": [],
926
+ "source": [
927
+ "del test"
928
+ ]
929
+ }
930
+ ],
931
+ "metadata": {
932
+ "kernelspec": {
933
+ "display_name": "base",
934
+ "language": "python",
935
+ "name": "python3"
936
+ },
937
+ "language_info": {
938
+ "codemirror_mode": {
939
+ "name": "ipython",
940
+ "version": 3
941
+ },
942
+ "file_extension": ".py",
943
+ "mimetype": "text/x-python",
944
+ "name": "python",
945
+ "nbconvert_exporter": "python",
946
+ "pygments_lexer": "ipython3",
947
+ "version": "3.9.19"
948
+ },
949
+ "orig_nbformat": 4
950
+ },
951
+ "nbformat": 4,
952
+ "nbformat_minor": 2
953
+ }
diffusion.py CHANGED
@@ -239,7 +239,7 @@ class TrainConfig:
239
  stride = (2,2) if dim == 2 else (2,2,4)
240
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
  batch_size = 10#1#2#50#20#2#100 # 10
242
- n_epoch = 5#4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
  num_redshift = 512#128#64#512#256#256#64#512#128
245
  channel = 1
@@ -266,7 +266,7 @@ class TrainConfig:
266
  # seed = 0
267
  # save_dir = './outputs/'
268
 
269
- save_period = 1#np.infty#.1 # the period of sampling
270
  # general parameters for the name and logger
271
  # device = "cuda" if torch.cuda.is_available() else "cpu"
272
  lrate = 1e-4
@@ -580,13 +580,16 @@ class DDPM21CM:
580
  # else:
581
  return x_last
582
  # %%
 
 
 
583
  def train(rank, world_size):
584
  config = TrainConfig()
585
  config.world_size = world_size
586
 
587
  ddp_setup(rank, world_size)
588
 
589
- num_train_image_list = [200]#[3200]#[200]#[1600,3200,6400,12800,25600]
590
  for i, num_image in enumerate(num_train_image_list):
591
  config.num_image = num_image
592
  # config.world_size = world_size
@@ -607,20 +610,6 @@ if __name__ == "__main__":
607
  mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
608
  # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
609
 
610
- # %%
611
- # print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
612
- # print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
613
- # print("torch.cuda.current_device() =", torch.cuda.current_device())
614
- # print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
615
- # print("torch.cuda.get_device_properties(torch.cuda.device) =", torch.cuda.get_device_properties(torch.cuda.device))
616
- # print('here')
617
- # print(torch.cuda.memory_usage())
618
- # print(torch.cuda.utilization())
619
- # print(torch.cuda.memory())
620
- # print('here')
621
- # print(torch.cuda.memory_summary())
622
- # %% [markdown]
623
- # # Sampling
624
 
625
  # %%
626
 
@@ -677,7 +666,7 @@ if __name__ == "__main__":
677
  world_size = torch.cuda.device_count()
678
  print(f" sampling, world_size = {world_size} ".center(100,'-'))
679
  # num_train_image_list = [1600,3200,6400,12800,25600]
680
- num_train_image_list = [200]
681
  num_new_img_per_gpu = 40
682
  max_num_img_per_gpu = 20
683
 
 
239
  stride = (2,2) if dim == 2 else (2,2,4)
240
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
  batch_size = 10#1#2#50#20#2#100 # 10
242
+ n_epoch = 10#5#4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
  num_redshift = 512#128#64#512#256#256#64#512#128
245
  channel = 1
 
266
  # seed = 0
267
  # save_dir = './outputs/'
268
 
269
+ save_period = np.infty#.1 # the period of sampling
270
  # general parameters for the name and logger
271
  # device = "cuda" if torch.cuda.is_available() else "cpu"
272
  lrate = 1e-4
 
580
  # else:
581
  return x_last
582
  # %%
583
+
584
+ num_train_image_list = [5000]
585
+
586
  def train(rank, world_size):
587
  config = TrainConfig()
588
  config.world_size = world_size
589
 
590
  ddp_setup(rank, world_size)
591
 
592
+ #[3200]#[200]#[1600,3200,6400,12800,25600]
593
  for i, num_image in enumerate(num_train_image_list):
594
  config.num_image = num_image
595
  # config.world_size = world_size
 
610
  mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
611
  # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
 
614
  # %%
615
 
 
666
  world_size = torch.cuda.device_count()
667
  print(f" sampling, world_size = {world_size} ".center(100,'-'))
668
  # num_train_image_list = [1600,3200,6400,12800,25600]
669
+ # num_train_image_list = [5000]
670
  num_new_img_per_gpu = 40
671
  max_num_img_per_gpu = 20
672