0718-2225
Browse files- backup_diffusion.ipynb +953 -0
- diffusion.py +7 -18
backup_diffusion.ipynb
ADDED
|
@@ -0,0 +1,953 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"attachments": {},
|
| 5 |
+
"cell_type": "markdown",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"## 改編ContextUnet及相關代碼,使其首先對二維的情況適用。並於diffusers.Unet2DModel作比較並加以優化。最後再改寫爲3維的情形。\n",
|
| 9 |
+
"- 經試用diffusers的Unet2DModel,發現loss從0.3降到0.2但仍然很高,説明存在非Unet2DModel的問題可以優化\n",
|
| 10 |
+
"- 改用diffusers的DDMPScheduler和DDPMPipeline后,loss降低至0.1以下,有時甚至可以低至0.004,可見我的代碼問題主要出在DDPM部分。DDPMScheduler部分比較簡短,似乎沒有問題,所以問題應該在DDPMPipeline裏某一部分代碼是我代碼欠缺的。\n",
|
| 11 |
+
"- 我在DDPMScheduler部分有一個typo,導致beta_t一直很小,修正后loss從0.2能降低至0.02, 維持在0.1以下\n",
|
| 12 |
+
"- 用diffusers的DDPMScheduler似乎效果要好一些,loss總是比我的DDPMScheduler要小一點。儅epoch為19時,前者的loss約0.02,後者loss約0.07。而且前者還支持3維圖像的加噪,不如直接用別人的輪子。但我想知道爲什麽我的loss會高一些。\n",
|
| 13 |
+
"- 我意識到別人的DDPMScheduler在sample函數中沒有兼容輸入參數,所以歸根結底還是需要我的DDPMscheduler。不過我可以先用別人的來debug我的ContextUnet.\n",
|
| 14 |
+
"- 我需要將我的ContextUnet擴展兼容不同維度的照片,畢竟我本身也需要和原文獻對比完了再拓展到三維的情形\n",
|
| 15 |
+
"- 我已將我的ContextUnet轉成了2維的模式,與diffusers.Unet2DModel的loss=0.037相比,我的Unet的loss=0.07。同時我的Unet生成的圖像看上去很奇怪,説明我的Unet也有問題。我需要將代碼退回原Unet,並檢查問題所在。\n",
|
| 16 |
+
"- 我將紅移方向的像素的數量限制在了64.以此比較兩個Unet的差別。經比較:\\\n",
|
| 17 |
+
"Unet2DModel loss:0.03, 0.0655, 0.05, 0.02, 0.05\\\n",
|
| 18 |
+
"ContextUnet loss: 0.1, 0.16, 0.1, 0.2186, 0.06\n",
|
| 19 |
+
"- 我把ContextUnet退回到了原作者的版本,結果loss=0.05,輸出的照片也不錯。我主要的改動是改回了他原用的normalization函數,其中還有個參數swish。有時間我可以研究一下具體是哪裏影響了訓練的結果。另外我發現了要想tensorboard的圖綫獨立美觀,需要把他們放在不同的文件夾下\n",
|
| 20 |
+
"- 經過驗證,GroupNorm比batchNorm效果要好\n",
|
| 21 |
+
"- 已擴展爲接受不同維度的情形\n",
|
| 22 |
+
"- 融合cond, guide_w, drop_out這些參數\n",
|
| 23 |
+
"- 生成的21cm圖像該暗的地方不夠暗,似乎換成MNIST的數字圖像就沒問題\n",
|
| 24 |
+
"- 我用diffusion模型生成MNIST的數字時發現,儘管生成的數據的範圍也存在負數數值,如-0.1,但畫出來的圖像卻是理想的黑色。數據的分佈與21cm的結果的分佈沒多大差別,我現在打算把代碼退回到21cm的情形\n",
|
| 25 |
+
"- 我統一了ddpm21cm這個module,能統一實現訓練和生成樣本,但目前有個bug, sample時總是會cuda out of memory,然而單獨resume model並sample就不會。\n",
|
| 26 |
+
"- 解決了,問題出在我忘了寫with torch.no_grad():\n",
|
| 27 |
+
"- 接下來就是生成800個lightcones,與此同時研究如何計算global signal以及power spectrum\n",
|
| 28 |
+
"- 儅訓練圖片的數量達到5000時,生成的圖片與檢測數據的相似程度很高\n",
|
| 29 |
+
"- it takes 62 mins to generated 8 images with shape of (64,64,64), which is even slower than simulation, which takes ~5 mins for each image. Besides, the batch_size during training and num of images to be generated are limited to be 2 and 8, respectively.\n",
|
| 30 |
+
"- the slowerness can be solved by using multi-GPUs, and the limited-num-of-images can be solved by multi-accuracy, multi-GPUs.\n",
|
| 31 |
+
"- In addtion, the performance of DDPM can looks better compared to computation-intensive simulations. "
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": 31,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"from dataclasses import dataclass\n",
|
| 41 |
+
"import h5py\n",
|
| 42 |
+
"import torch\n",
|
| 43 |
+
"import torch.nn as nn\n",
|
| 44 |
+
"from torch.utils.data import DataLoader, Dataset\n",
|
| 45 |
+
"# from datasets import Dataset\n",
|
| 46 |
+
"import matplotlib.pyplot as plt\n",
|
| 47 |
+
"import numpy as np\n",
|
| 48 |
+
"import random\n",
|
| 49 |
+
"# from abc import ABC, abstractmethod\n",
|
| 50 |
+
"import torch.nn.functional as F\n",
|
| 51 |
+
"import math\n",
|
| 52 |
+
"# from PIL import Image\n",
|
| 53 |
+
"import os\n",
|
| 54 |
+
"from torch.utils.tensorboard import SummaryWriter\n",
|
| 55 |
+
"import copy\n",
|
| 56 |
+
"from tqdm.auto import tqdm\n",
|
| 57 |
+
"# from torchvision import transforms\n",
|
| 58 |
+
"# from diffusers import UNet2DModel#, UNet3DConditionModel\n",
|
| 59 |
+
"# from diffusers import DDPMScheduler\n",
|
| 60 |
+
"from diffusers.utils import make_image_grid\n",
|
| 61 |
+
"import datetime\n",
|
| 62 |
+
"from pathlib import Path\n",
|
| 63 |
+
"from diffusers.optimization import get_cosine_schedule_with_warmup\n",
|
| 64 |
+
"from accelerate import notebook_launcher, Accelerator\n",
|
| 65 |
+
"from huggingface_hub import create_repo, upload_folder\n",
|
| 66 |
+
"from load_h5 import Dataset4h5\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"from context_unet import ContextUnet\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"from huggingface_hub import notebook_login\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"import torch.multiprocessing as mp\n",
|
| 73 |
+
"from torch.utils.data.distributed import DistributedSampler\n",
|
| 74 |
+
"from torch.nn.parallel import DistributedDataParallel as DDP\n",
|
| 75 |
+
"from torch.distributed import init_process_group, destroy_process_group"
|
| 76 |
+
]
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"cell_type": "code",
|
| 80 |
+
"execution_count": 32,
|
| 81 |
+
"metadata": {},
|
| 82 |
+
"outputs": [],
|
| 83 |
+
"source": [
|
| 84 |
+
"def ddp_setup(rank: int, world_size: int):\n",
|
| 85 |
+
" \"\"\"\n",
|
| 86 |
+
" Args:\n",
|
| 87 |
+
" rank: Unique identifier of each process\n",
|
| 88 |
+
" world_size: Total number of processes\n",
|
| 89 |
+
" \"\"\"\n",
|
| 90 |
+
" os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
|
| 91 |
+
" os.environ[\"MASTER_PORT\"] = \"12355\"\n",
|
| 92 |
+
" torch.cuda.set_device(rank)\n",
|
| 93 |
+
" init_process_group(backend=\"nccl\", rank=rank, world_size=world_size)"
|
| 94 |
+
]
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"cell_type": "code",
|
| 98 |
+
"execution_count": 34,
|
| 99 |
+
"metadata": {},
|
| 100 |
+
"outputs": [
|
| 101 |
+
{
|
| 102 |
+
"data": {
|
| 103 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 104 |
+
"model_id": "9bbf7e9db9ce426d9c59d6f6d8e8df29",
|
| 105 |
+
"version_major": 2,
|
| 106 |
+
"version_minor": 0
|
| 107 |
+
},
|
| 108 |
+
"text/plain": [
|
| 109 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"output_type": "display_data"
|
| 114 |
+
}
|
| 115 |
+
],
|
| 116 |
+
"source": [
|
| 117 |
+
"notebook_login()"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"attachments": {},
|
| 122 |
+
"cell_type": "markdown",
|
| 123 |
+
"metadata": {},
|
| 124 |
+
"source": [
|
| 125 |
+
"# Add noise:\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"\\begin{align*}\n",
|
| 128 |
+
"x_t &\\sim \\mathcal N\\left(\\sqrt{1-\\beta_t}\\ x_{t-1},\\ \\beta_t \\right) \\\\\n",
|
| 129 |
+
"x_t &\\equiv \\sqrt{1-\\beta_t}\\ x_{t-1} + \\sqrt{\\beta_t}\\ \\epsilon\\\\\n",
|
| 130 |
+
"\\epsilon &\\sim \\mathcal N(0,1)\\\\\n",
|
| 131 |
+
"\\alpha_t & \\equiv 1 - \\beta_t\\\\\n",
|
| 132 |
+
"& ...\\\\\n",
|
| 133 |
+
"x_t &= \\sqrt{\\bar {\\alpha_t}} x_0 + \\epsilon\\ \\sqrt{1 - \\bar{\\alpha_t}}\\\\\n",
|
| 134 |
+
"\\bar {\\alpha_t} &\\equiv \\prod_{i=1}^t \\alpha_i\\\\\n",
|
| 135 |
+
"&= \\exp\\left({\\ln{\\prod_{i=1}^t \\alpha_i}}\\right)\\\\\n",
|
| 136 |
+
"&= \\exp\\left({\\sum_{i=1}^t\\ln{ \\alpha_i}}\\right)\n",
|
| 137 |
+
"\\end{align*}"
|
| 138 |
+
]
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"cell_type": "code",
|
| 142 |
+
"execution_count": 4,
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"outputs": [],
|
| 145 |
+
"source": [
|
| 146 |
+
"class DDPMScheduler(nn.Module):\n",
|
| 147 |
+
" def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu'):\n",
|
| 148 |
+
" super().__init__()\n",
|
| 149 |
+
" \n",
|
| 150 |
+
" beta_1, beta_T = betas\n",
|
| 151 |
+
" assert 0 < beta_1 <= beta_T <= 1, \"ensure 0 < beta_1 <= beta_T <= 1\"\n",
|
| 152 |
+
" self.device = device\n",
|
| 153 |
+
" self.num_timesteps = num_timesteps\n",
|
| 154 |
+
" self.img_shape = img_shape\n",
|
| 155 |
+
" self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1\n",
|
| 156 |
+
" self.beta_t = self.beta_t.to(self.device)\n",
|
| 157 |
+
"\n",
|
| 158 |
+
" # self.drop_prob = drop_prob\n",
|
| 159 |
+
" # self.cond = cond\n",
|
| 160 |
+
" self.alpha_t = 1 - self.beta_t\n",
|
| 161 |
+
" # self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))\n",
|
| 162 |
+
" self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)\n",
|
| 163 |
+
"\n",
|
| 164 |
+
" def add_noise(self, clean_images):\n",
|
| 165 |
+
" shape = clean_images.shape\n",
|
| 166 |
+
" expand = torch.ones(len(shape)-1, dtype=int)\n",
|
| 167 |
+
" # ts_expand = ts.view(ts.shape[0], *expand.tolist())\n",
|
| 168 |
+
" # expand = [1 for i in range(len(shape)-1)]\n",
|
| 169 |
+
"\n",
|
| 170 |
+
" noise = torch.randn_like(clean_images).to(self.device)\n",
|
| 171 |
+
" ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)\n",
|
| 172 |
+
" \n",
|
| 173 |
+
" # test_expand = test.view(test.shape[0],*expand)\n",
|
| 174 |
+
" # extend_dim = [None for i in range(shape.dim()-1)]\n",
|
| 175 |
+
" noisy_images = (\n",
|
| 176 |
+
" clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
|
| 177 |
+
" + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
|
| 178 |
+
" )\n",
|
| 179 |
+
" # print(x_t.shape)\n",
|
| 180 |
+
"\n",
|
| 181 |
+
" return noisy_images, noise, ts\n",
|
| 182 |
+
"\n",
|
| 183 |
+
" def sample(self, nn_model, params, device, guide_w = 0):\n",
|
| 184 |
+
" n_sample = len(params) #params.shape[0]\n",
|
| 185 |
+
" # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
|
| 186 |
+
" x_i = torch.randn(n_sample, *self.img_shape).to(device)\n",
|
| 187 |
+
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 188 |
+
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 189 |
+
" if guide_w != -1:\n",
|
| 190 |
+
" c_i = params\n",
|
| 191 |
+
" uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)\n",
|
| 192 |
+
" # uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)\n",
|
| 193 |
+
" # uncond_tokens = uncond_tokens.repeat(int(n_sample),1)\n",
|
| 194 |
+
" c_i = torch.cat((c_i, uncond_tokens), 0)\n",
|
| 195 |
+
"\n",
|
| 196 |
+
" x_i_entire = [] # keep track of generated steps in case want to plot something\n",
|
| 197 |
+
" # print(\"self.num_timesteps =\", self.num_timesteps)\n",
|
| 198 |
+
" # for i in range(self.num_timesteps, 0, -1):\n",
|
| 199 |
+
" # print(f'sampling!!!')\n",
|
| 200 |
+
" pbar_sample = tqdm(total=self.num_timesteps)\n",
|
| 201 |
+
" pbar_sample.set_description(\"Sampling\")\n",
|
| 202 |
+
" for i in reversed(range(0, self.num_timesteps)):\n",
|
| 203 |
+
" # print(f'sampling timestep {i:4d}',end='\\r')\n",
|
| 204 |
+
" t_is = torch.tensor([i]).to(device)\n",
|
| 205 |
+
" t_is = t_is.repeat(n_sample)\n",
|
| 206 |
+
"\n",
|
| 207 |
+
" z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else 0\n",
|
| 208 |
+
"\n",
|
| 209 |
+
" if guide_w == -1:\n",
|
| 210 |
+
" # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
|
| 211 |
+
" eps = nn_model(x_i, t_is)\n",
|
| 212 |
+
" # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
|
| 213 |
+
" else:\n",
|
| 214 |
+
" # double batch\n",
|
| 215 |
+
" x_i = x_i.repeat(2, *torch.ones(len(self.img_shape), dtype=int).tolist())\n",
|
| 216 |
+
" t_is = t_is.repeat(2)\n",
|
| 217 |
+
"\n",
|
| 218 |
+
" # split predictions and compute weighting\n",
|
| 219 |
+
" # print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\n",
|
| 220 |
+
" eps = nn_model(x_i, t_is, c_i)\n",
|
| 221 |
+
" eps1 = eps[:n_sample]\n",
|
| 222 |
+
" eps2 = eps[n_sample:]\n",
|
| 223 |
+
" eps = eps1 + guide_w*(eps1 - eps2)\n",
|
| 224 |
+
" # eps = (1+guide_w)*eps1 - guide_w*eps2\n",
|
| 225 |
+
" x_i = x_i[:n_sample]\n",
|
| 226 |
+
" # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
|
| 227 |
+
" \n",
|
| 228 |
+
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 229 |
+
" x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
|
| 230 |
+
" \n",
|
| 231 |
+
" pbar_sample.update(1)\n",
|
| 232 |
+
" # pbar_sample.set_postfix(step=i)\n",
|
| 233 |
+
" \n",
|
| 234 |
+
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 235 |
+
" # store only part of the intermediate steps\n",
|
| 236 |
+
" if i%20==0:# or i==0:# or i<8:\n",
|
| 237 |
+
" x_i_entire.append(x_i.detach().cpu().numpy())\n",
|
| 238 |
+
" x_i = x_i.detach().cpu().numpy()\n",
|
| 239 |
+
" x_i_entire = np.array(x_i_entire)\n",
|
| 240 |
+
" return x_i, x_i_entire\n",
|
| 241 |
+
"\n",
|
| 242 |
+
"\n",
|
| 243 |
+
"# ddpm_scheduler = DDPMScheduler((1e-4,0.02),10)\n",
|
| 244 |
+
"# noisy_images, noise, ts = ddpm_scheduler.add_noise(images)"
|
| 245 |
+
]
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
"cell_type": "code",
|
| 249 |
+
"execution_count": 5,
|
| 250 |
+
"metadata": {},
|
| 251 |
+
"outputs": [],
|
| 252 |
+
"source": [
|
| 253 |
+
"class EMA:\n",
|
| 254 |
+
" def __init__(self, beta):\n",
|
| 255 |
+
" super().__init__()\n",
|
| 256 |
+
" self.beta = beta\n",
|
| 257 |
+
" self.step = 0\n",
|
| 258 |
+
"\n",
|
| 259 |
+
" def update_model_average(self, ma_model, current_model):\n",
|
| 260 |
+
" for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):\n",
|
| 261 |
+
" old_weight, up_weight = ma_params.data, current_params.data\n",
|
| 262 |
+
" ma_params.data = self.update_average(old_weight, up_weight)\n",
|
| 263 |
+
"\n",
|
| 264 |
+
" def update_average(self, old, new):\n",
|
| 265 |
+
" if old is None:\n",
|
| 266 |
+
" return new\n",
|
| 267 |
+
" return old * self.beta + (1 - self.beta) * new\n",
|
| 268 |
+
"\n",
|
| 269 |
+
" def step_ema(self, ema_model, model):\n",
|
| 270 |
+
" self.update_model_average(ema_model, model)\n",
|
| 271 |
+
" self.step += 1\n",
|
| 272 |
+
"\n",
|
| 273 |
+
" def reset_parameters(self, ema_model, model):\n",
|
| 274 |
+
" ema_model.load_state_dict(model.state_dict())\n",
|
| 275 |
+
" "
|
| 276 |
+
]
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"cell_type": "code",
|
| 280 |
+
"execution_count": 6,
|
| 281 |
+
"metadata": {},
|
| 282 |
+
"outputs": [],
|
| 283 |
+
"source": [
|
| 284 |
+
"@dataclass\n",
|
| 285 |
+
"class TrainConfig:\n",
|
| 286 |
+
" ###########################\n",
|
| 287 |
+
" ## hardcoding these here ##\n",
|
| 288 |
+
" ###########################\n",
|
| 289 |
+
" push_to_hub = True\n",
|
| 290 |
+
" hub_model_id = \"Xsmos/ml21cm\"\n",
|
| 291 |
+
" hub_private_repo = False\n",
|
| 292 |
+
" dataset_name = \"/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5\"\n",
|
| 293 |
+
" device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 294 |
+
" # world_size = torch.cuda.device_count()\n",
|
| 295 |
+
" # repeat = 2\n",
|
| 296 |
+
"\n",
|
| 297 |
+
" # dim = 2\n",
|
| 298 |
+
" dim = 3\n",
|
| 299 |
+
" stride = (2,2) if dim == 2 else (2,2,1)\n",
|
| 300 |
+
" num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
|
| 301 |
+
" batch_size = 2#2#50#20#2#100 # 10\n",
|
| 302 |
+
" n_epoch = 10#50#20#20#2#5#25 # 120\n",
|
| 303 |
+
" HII_DIM = 28#64\n",
|
| 304 |
+
" num_redshift = 4#128#64#512#256#256#64#512#128\n",
|
| 305 |
+
" channel = 1\n",
|
| 306 |
+
" img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
|
| 307 |
+
"\n",
|
| 308 |
+
" ranges_dict = dict(\n",
|
| 309 |
+
" params = {\n",
|
| 310 |
+
" 0: [4, 6], # ION_Tvir_MIN\n",
|
| 311 |
+
" 1: [10, 250], # HII_EFF_FACTOR\n",
|
| 312 |
+
" },\n",
|
| 313 |
+
" images = {\n",
|
| 314 |
+
" 0: [0, 80], # brightness_temp\n",
|
| 315 |
+
" }\n",
|
| 316 |
+
" )\n",
|
| 317 |
+
"\n",
|
| 318 |
+
" num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
|
| 319 |
+
" # n_sample = 24 # 64, the number of samples in sampling process\n",
|
| 320 |
+
" n_param = 2\n",
|
| 321 |
+
" guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",
|
| 322 |
+
" drop_prob = 0#0.28 # only takes effect when guide_w != -1\n",
|
| 323 |
+
" ema=True # whether to use ema\n",
|
| 324 |
+
" ema_rate=0.995\n",
|
| 325 |
+
"\n",
|
| 326 |
+
" # seed = 0\n",
|
| 327 |
+
" # save_dir = './outputs/'\n",
|
| 328 |
+
"\n",
|
| 329 |
+
" save_freq = 0#.1 # the period of sampling\n",
|
| 330 |
+
" # general parameters for the name and logger \n",
|
| 331 |
+
" # device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 332 |
+
" lrate = 1e-4\n",
|
| 333 |
+
" lr_warmup_steps = 0#5#00\n",
|
| 334 |
+
" output_dir = \"./outputs/\"\n",
|
| 335 |
+
" save_name = os.path.join(output_dir, 'model_state')\n",
|
| 336 |
+
" # save_freq = 1 #10 # the period of saving model\n",
|
| 337 |
+
" # cond = True # if training using the conditional information\n",
|
| 338 |
+
" # lr_decay = False #True# if using the learning rate decay\n",
|
| 339 |
+
" resume = save_name # if resume from the trained checkpoints\n",
|
| 340 |
+
" # params_single = torch.tensor([0.2,0.80000023])\n",
|
| 341 |
+
" # params = torch.tile(params_single,(n_sample,1)).to(device)\n",
|
| 342 |
+
" # params = params\n",
|
| 343 |
+
" # data_dir = './data' # data directory\n",
|
| 344 |
+
"\n",
|
| 345 |
+
"\n",
|
| 346 |
+
" mixed_precision = \"fp16\"\n",
|
| 347 |
+
" gradient_accumulation_steps = 1\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" # date = datetime.datetime.now().strftime(\"%m%d-%H%M\")\n",
|
| 350 |
+
" # run_name = f'{date}' # the unique name of each experiment\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"# config = TrainConfig()\n",
|
| 353 |
+
"# print(\"device =\", config.device)"
|
| 354 |
+
]
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"cell_type": "code",
|
| 358 |
+
"execution_count": 7,
|
| 359 |
+
"metadata": {},
|
| 360 |
+
"outputs": [],
|
| 361 |
+
"source": [
|
| 362 |
+
"# import os\n",
|
| 363 |
+
"# print(os.cpu_count())\n",
|
| 364 |
+
"# print(len(os.sched_getaffinity(0)))\n",
|
| 365 |
+
"# import torch\n",
|
| 366 |
+
"# data = torch.randn((64,64))\n",
|
| 367 |
+
"# print(data.dtype)"
|
| 368 |
+
]
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"cell_type": "code",
|
| 372 |
+
"execution_count": 8,
|
| 373 |
+
"metadata": {},
|
| 374 |
+
"outputs": [],
|
| 375 |
+
"source": [
|
| 376 |
+
"# @dataclass\n",
|
| 377 |
+
"class DDPM21CM:\n",
|
| 378 |
+
" def __init__(self, config):\n",
|
| 379 |
+
" # config = TrainConfig()\n",
|
| 380 |
+
" # date = datetime.datetime.now().strftime(\"%m%d-%H%M\")\n",
|
| 381 |
+
" config.run_name = datetime.datetime.now().strftime(\"%m%d-%H%M\") # the unique name of each experiment\n",
|
| 382 |
+
" self.config = config\n",
|
| 383 |
+
" # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
|
| 384 |
+
" # # self.shape_loaded = dataset.images.shape\n",
|
| 385 |
+
" # # print(\"shape_loaded =\", self.shape_loaded)\n",
|
| 386 |
+
" # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)\n",
|
| 387 |
+
" # del dataset\n",
|
| 388 |
+
" self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
|
| 389 |
+
"\n",
|
| 390 |
+
" # initialize the unet\n",
|
| 391 |
+
" self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)\n",
|
| 392 |
+
"\n",
|
| 393 |
+
" if config.resume and os.path.exists(config.resume):\n",
|
| 394 |
+
" # resume_file = os.path.join(config.output_dir, f\"{config.resume}\")\n",
|
| 395 |
+
" self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])\n",
|
| 396 |
+
" print(f\"resumed nn_model from {config.resume}\")\n",
|
| 397 |
+
" # nn_model = ContextUnet(n_param=1, image_size=28)\n",
|
| 398 |
+
" self.nn_model.train()\n",
|
| 399 |
+
" self.nn_model.to(self.ddpm.device)\n",
|
| 400 |
+
" # print(\"nn_model.device =\", ddpm.device)\n",
|
| 401 |
+
" # number of parameters to be trained\n",
|
| 402 |
+
" self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())\n",
|
| 403 |
+
" print(f\"Number of parameters for nn_model: {self.number_of_params}\")\n",
|
| 404 |
+
"\n",
|
| 405 |
+
" # whether to use ema\n",
|
| 406 |
+
" if config.ema:\n",
|
| 407 |
+
" self.ema = EMA(config.ema_rate)\n",
|
| 408 |
+
" if config.resume and os.path.exists(config.resume):\n",
|
| 409 |
+
" self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
|
| 410 |
+
" self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])\n",
|
| 411 |
+
" print(f\"resumed ema_model from {config.resume}\")\n",
|
| 412 |
+
" else:\n",
|
| 413 |
+
" self.ema_model = copy.deepcopy(self.nn_model).eval().requires_grad_(False)\n",
|
| 414 |
+
"\n",
|
| 415 |
+
" self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)\n",
|
| 416 |
+
" self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
|
| 417 |
+
" optimizer=self.optimizer,\n",
|
| 418 |
+
" num_warmup_steps=config.lr_warmup_steps,\n",
|
| 419 |
+
" num_training_steps=(int(config.num_image/config.batch_size) * config.n_epoch),\n",
|
| 420 |
+
" # num_training_steps=(len(self.dataloader) * config.n_epoch),\n",
|
| 421 |
+
" )\n",
|
| 422 |
+
"\n",
|
| 423 |
+
" self.ranges_dict = config.ranges_dict\n",
|
| 424 |
+
"\n",
|
| 425 |
+
" def load(self):\n",
|
| 426 |
+
" dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)\n",
|
| 427 |
+
" # self.shape_loaded = dataset.images.shape\n",
|
| 428 |
+
" # print(\"shape_loaded =\", self.shape_loaded)\n",
|
| 429 |
+
" self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True, num_workers=len(os.sched_getaffinity(0)), pin_memory=True)\n",
|
| 430 |
+
" # del dataset\n",
|
| 431 |
+
" # self.accelerate(self.config)\n",
|
| 432 |
+
" del dataset\n",
|
| 433 |
+
"\n",
|
| 434 |
+
" # def accelerate(self):\n",
|
| 435 |
+
"\n",
|
| 436 |
+
" def train(self):\n",
|
| 437 |
+
" ################### \n",
|
| 438 |
+
" ## training loop ##\n",
|
| 439 |
+
" ###################\n",
|
| 440 |
+
" # plot_unet = True\n",
|
| 441 |
+
"\n",
|
| 442 |
+
" self.load()\n",
|
| 443 |
+
" self.accelerator = Accelerator(\n",
|
| 444 |
+
" mixed_precision=self.config.mixed_precision,\n",
|
| 445 |
+
" gradient_accumulation_steps=self.config.gradient_accumulation_steps,\n",
|
| 446 |
+
" log_with=\"tensorboard\",\n",
|
| 447 |
+
" project_dir=os.path.join(self.config.output_dir, \"logs\"),\n",
|
| 448 |
+
" )\n",
|
| 449 |
+
" print(\"self.accelerator.is_main_process:\", self.accelerator.is_main_process)\n",
|
| 450 |
+
" if self.accelerator.is_main_process:\n",
|
| 451 |
+
" if self.config.output_dir is not None:\n",
|
| 452 |
+
" os.makedirs(self.config.output_dir, exist_ok=True)\n",
|
| 453 |
+
" if self.config.push_to_hub:\n",
|
| 454 |
+
" self.repo_id = create_repo(\n",
|
| 455 |
+
" repo_id=self.config.hub_model_id or Path(self.config.output_dir).name, exist_ok=True\n",
|
| 456 |
+
" ).repo_id\n",
|
| 457 |
+
" self.accelerator.init_trackers(f\"{self.config.run_name}\")\n",
|
| 458 |
+
"\n",
|
| 459 |
+
" self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \\\n",
|
| 460 |
+
" self.accelerator.prepare(\n",
|
| 461 |
+
" self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler\n",
|
| 462 |
+
" )\n",
|
| 463 |
+
" \n",
|
| 464 |
+
" global_step = 0\n",
|
| 465 |
+
" for ep in range(self.config.n_epoch):\n",
|
| 466 |
+
" self.ddpm.train()\n",
|
| 467 |
+
"\n",
|
| 468 |
+
" pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)\n",
|
| 469 |
+
" pbar_train.set_description(f\"Epoch {ep}\")\n",
|
| 470 |
+
" for i, (x, c) in enumerate(self.dataloader):\n",
|
| 471 |
+
" with self.accelerator.accumulate(self.nn_model):\n",
|
| 472 |
+
" x = x.to(self.config.device)\n",
|
| 473 |
+
" xt, noise, ts = self.ddpm.add_noise(x)\n",
|
| 474 |
+
" \n",
|
| 475 |
+
" if self.config.guide_w == -1:\n",
|
| 476 |
+
" noise_pred = self.nn_model(xt, ts)\n",
|
| 477 |
+
" else:\n",
|
| 478 |
+
" c = c.to(self.config.device)\n",
|
| 479 |
+
" noise_pred = self.nn_model(xt, ts, c)\n",
|
| 480 |
+
" \n",
|
| 481 |
+
" loss = F.mse_loss(noise, noise_pred)\n",
|
| 482 |
+
" self.accelerator.backward(loss)\n",
|
| 483 |
+
" self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)\n",
|
| 484 |
+
" self.optimizer.step()\n",
|
| 485 |
+
" self.lr_scheduler.step()\n",
|
| 486 |
+
" self.optimizer.zero_grad()\n",
|
| 487 |
+
"\n",
|
| 488 |
+
" # ema update\n",
|
| 489 |
+
" if self.config.ema:\n",
|
| 490 |
+
" self.ema.step_ema(self.ema_model, self.nn_model)\n",
|
| 491 |
+
"\n",
|
| 492 |
+
" pbar_train.update(1)\n",
|
| 493 |
+
" logs = dict(\n",
|
| 494 |
+
" loss=loss.detach().item(),\n",
|
| 495 |
+
" lr=self.optimizer.param_groups[0]['lr'],\n",
|
| 496 |
+
" step=global_step\n",
|
| 497 |
+
" )\n",
|
| 498 |
+
" pbar_train.set_postfix(**logs)\n",
|
| 499 |
+
"\n",
|
| 500 |
+
" self.accelerator.log(logs, step=global_step)\n",
|
| 501 |
+
" global_step += 1\n",
|
| 502 |
+
"\n",
|
| 503 |
+
" # if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\n",
|
| 504 |
+
" self.save(ep)\n",
|
| 505 |
+
"\n",
|
| 506 |
+
" del self.nn_model\n",
|
| 507 |
+
" if self.config.ema:\n",
|
| 508 |
+
" del self.ema_model\n",
|
| 509 |
+
" torch.cuda.empty_cache()\n",
|
| 510 |
+
"\n",
|
| 511 |
+
" def save(self, ep):\n",
|
| 512 |
+
" # save model\n",
|
| 513 |
+
" if self.accelerator.is_main_process:\n",
|
| 514 |
+
" if ep == self.config.n_epoch-1 or (ep+1)*self.config.save_freq==1:\n",
|
| 515 |
+
" self.nn_model.eval()\n",
|
| 516 |
+
" with torch.no_grad():\n",
|
| 517 |
+
" if self.config.push_to_hub:\n",
|
| 518 |
+
" upload_folder(\n",
|
| 519 |
+
" repo_id = self.repo_id,\n",
|
| 520 |
+
" folder_path = \".\",#config.output_dir,\n",
|
| 521 |
+
" commit_message = f\"{self.config.run_name}\",\n",
|
| 522 |
+
" ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\", \"__pycache__\"],\n",
|
| 523 |
+
" )\n",
|
| 524 |
+
" if self.config.save_name:\n",
|
| 525 |
+
" model_state = {\n",
|
| 526 |
+
" 'epoch': ep,\n",
|
| 527 |
+
" 'unet_state_dict': self.nn_model.state_dict(),\n",
|
| 528 |
+
" 'ema_unet_state_dict': self.ema_model.state_dict(),\n",
|
| 529 |
+
" }\n",
|
| 530 |
+
" torch.save(model_state, self.config.save_name+f\"-N{self.config.num_image}\")\n",
|
| 531 |
+
" print('saved model at ' + self.config.save_name+f\"-N{self.config.num_image}\")\n",
|
| 532 |
+
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 533 |
+
"\n",
|
| 534 |
+
" # def rescale(self, value, type='params', to_ranges=[0,1]):\n",
|
| 535 |
+
" # for i, from_ranges in self.ranges_dict[type].items():\n",
|
| 536 |
+
" # value[i] = (value[i] - from_ranges[0])/(from_ranges[1]-from_ranges[0]) # normalize\n",
|
| 537 |
+
" # value[i] = \n",
|
| 538 |
+
" def rescale(self, value, ranges, to: list):\n",
|
| 539 |
+
" if value.ndim == 1:\n",
|
| 540 |
+
" value = value.view(-1,len(value))\n",
|
| 541 |
+
" \n",
|
| 542 |
+
" for i in range(np.shape(value)[1]):\n",
|
| 543 |
+
" value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])\n",
|
| 544 |
+
" # print(f\"i = {i}, value.min = {value[:,i].min()}, value.max = {value[:,i].max()}\")\n",
|
| 545 |
+
" value = value * (to[1]-to[0]) + to[0]\n",
|
| 546 |
+
" return value \n",
|
| 547 |
+
"\n",
|
| 548 |
+
" def sample(self, file, params:torch.tensor=None, repeat=192, ema=False, entire=False):\n",
|
| 549 |
+
" # n_sample = params.shape[0]\n",
|
| 550 |
+
" \n",
|
| 551 |
+
" if params is None:\n",
|
| 552 |
+
" params = torch.tensor([0.20000000000000018, 0.5055875000000001])\n",
|
| 553 |
+
" params_backup = params.numpy().copy()\n",
|
| 554 |
+
" else:\n",
|
| 555 |
+
" params_backup = params.numpy().copy()\n",
|
| 556 |
+
" params = self.rescale(params, self.ranges_dict['params'], to=[0,1])\n",
|
| 557 |
+
"\n",
|
| 558 |
+
" print(f\"sampling {repeat} images with normalized params = {params}\")\n",
|
| 559 |
+
" params = params.repeat(repeat,1)\n",
|
| 560 |
+
" assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
|
| 561 |
+
" # print(\"params =\", params)\n",
|
| 562 |
+
" # print(\"params =\", params)\n",
|
| 563 |
+
" # print(\"len(params) =\", len(params))\n",
|
| 564 |
+
" # model = self.ema_model if ema else self.nn_model\n",
|
| 565 |
+
" # del self.ema_model, self.nn\n",
|
| 566 |
+
" # params = torch.tile(params, (n_sample,1)).to(device)\n",
|
| 567 |
+
"\n",
|
| 568 |
+
" nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)\n",
|
| 569 |
+
" if ema:\n",
|
| 570 |
+
" nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])\n",
|
| 571 |
+
" else:\n",
|
| 572 |
+
" nn_model.load_state_dict(torch.load(file)['unet_state_dict'])\n",
|
| 573 |
+
" print(f\"nn_model resumed from {file}\")\n",
|
| 574 |
+
" # nn_model = ContextUnet(n_param=1, image_size=28)\n",
|
| 575 |
+
" # nn_model.train()\n",
|
| 576 |
+
" nn_model.to(self.ddpm.device)\n",
|
| 577 |
+
" nn_model.eval()\n",
|
| 578 |
+
"\n",
|
| 579 |
+
" # self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
|
| 580 |
+
" # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
|
| 581 |
+
" # print(f\"resumed ema_model from {config.resume}\")\n",
|
| 582 |
+
"\n",
|
| 583 |
+
" with torch.no_grad():\n",
|
| 584 |
+
" x_last, x_entire = self.ddpm.sample(\n",
|
| 585 |
+
" nn_model=nn_model, \n",
|
| 586 |
+
" params=params.to(self.config.device), \n",
|
| 587 |
+
" device=self.config.device, \n",
|
| 588 |
+
" guide_w=self.config.guide_w\n",
|
| 589 |
+
" )\n",
|
| 590 |
+
"\n",
|
| 591 |
+
" # np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else ''}.npy\"), x_last)\n",
|
| 592 |
+
" np.save(os.path.join(self.config.output_dir, f\"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}.npy\"), x_last)\n",
|
| 593 |
+
"\n",
|
| 594 |
+
" if entire:\n",
|
| 595 |
+
" np.save(os.path.join(self.config.output_dir, f\"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}_entire.npy\"), x_last)\n",
|
| 596 |
+
"# print(\"device =\", config.device)"
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"cell_type": "code",
|
| 601 |
+
"execution_count": 12,
|
| 602 |
+
"metadata": {},
|
| 603 |
+
"outputs": [
|
| 604 |
+
{
|
| 605 |
+
"name": "stderr",
|
| 606 |
+
"output_type": "stream",
|
| 607 |
+
"text": [
|
| 608 |
+
"Traceback (most recent call last):\n",
|
| 609 |
+
" File \"<string>\", line 1, in <module>\n",
|
| 610 |
+
" File \"/storage/home/hcoda1/3/bxia34/.conda/envs/diffusers/lib/python3.9/multiprocessing/spawn.py\", line 116, in spawn_main\n",
|
| 611 |
+
" exitcode = _main(fd, parent_sentinel)\n",
|
| 612 |
+
" File \"/storage/home/hcoda1/3/bxia34/.conda/envs/diffusers/lib/python3.9/multiprocessing/spawn.py\", line 126, in _main\n",
|
| 613 |
+
" self = reduction.pickle.load(from_parent)\n",
|
| 614 |
+
"AttributeError: Can't get attribute 'single_main' on <module '__main__' (built-in)>\n"
|
| 615 |
+
]
|
| 616 |
+
},
|
| 617 |
+
{
|
| 618 |
+
"ename": "ProcessExitedException",
|
| 619 |
+
"evalue": "process 0 terminated with exit code 1",
|
| 620 |
+
"output_type": "error",
|
| 621 |
+
"traceback": [
|
| 622 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 623 |
+
"\u001b[0;31mProcessExitedException\u001b[0m Traceback (most recent call last)",
|
| 624 |
+
"Cell \u001b[0;32mIn[12], line 21\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m__name__\u001b[39m \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m__main__\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 17\u001b[0m \u001b[39m# torch.multiprocessing.set_start_method(\"spawn\")\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[39m# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\u001b[39;00m\n\u001b[1;32m 19\u001b[0m world_size \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\u001b[39m#torch.cuda.device_count()\u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m mp\u001b[39m.\u001b[39;49mspawn(single_main, args\u001b[39m=\u001b[39;49m(world_size,), nprocs\u001b[39m=\u001b[39;49mworld_size)\n\u001b[1;32m 22\u001b[0m \u001b[39m# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')\u001b[39;00m\n",
|
| 625 |
+
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:240\u001b[0m, in \u001b[0;36mspawn\u001b[0;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[1;32m 236\u001b[0m msg \u001b[39m=\u001b[39m (\u001b[39m'\u001b[39m\u001b[39mThis method only supports start_method=spawn (got: \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m).\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m 237\u001b[0m \u001b[39m'\u001b[39m\u001b[39mTo use a different start_method use:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m 238\u001b[0m \u001b[39m'\u001b[39m\u001b[39m torch.multiprocessing.start_processes(...)\u001b[39m\u001b[39m'\u001b[39m \u001b[39m%\u001b[39m start_method)\n\u001b[1;32m 239\u001b[0m warnings\u001b[39m.\u001b[39mwarn(msg)\n\u001b[0;32m--> 240\u001b[0m \u001b[39mreturn\u001b[39;00m start_processes(fn, args, nprocs, join, daemon, start_method\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mspawn\u001b[39;49m\u001b[39m'\u001b[39;49m)\n",
|
| 626 |
+
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:198\u001b[0m, in \u001b[0;36mstart_processes\u001b[0;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[39mreturn\u001b[39;00m context\n\u001b[1;32m 197\u001b[0m \u001b[39m# Loop on join until it returns True or raises an exception.\u001b[39;00m\n\u001b[0;32m--> 198\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m context\u001b[39m.\u001b[39;49mjoin():\n\u001b[1;32m 199\u001b[0m \u001b[39mpass\u001b[39;00m\n",
|
| 627 |
+
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:149\u001b[0m, in \u001b[0;36mProcessContext.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[39mraise\u001b[39;00m ProcessExitedException(\n\u001b[1;32m 141\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mprocess \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with signal \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m 142\u001b[0m (error_index, name),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 146\u001b[0m signal_name\u001b[39m=\u001b[39mname\n\u001b[1;32m 147\u001b[0m )\n\u001b[1;32m 148\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 149\u001b[0m \u001b[39mraise\u001b[39;00m ProcessExitedException(\n\u001b[1;32m 150\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mprocess \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with exit code \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m 151\u001b[0m (error_index, exitcode),\n\u001b[1;32m 152\u001b[0m error_index\u001b[39m=\u001b[39merror_index,\n\u001b[1;32m 153\u001b[0m error_pid\u001b[39m=\u001b[39mfailed_process\u001b[39m.\u001b[39mpid,\n\u001b[1;32m 154\u001b[0m exit_code\u001b[39m=\u001b[39mexitcode\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 157\u001b[0m original_trace \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39merror_queues[error_index]\u001b[39m.\u001b[39mget()\n\u001b[1;32m 158\u001b[0m msg \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\n\u001b[39;00m\u001b[39m-- Process \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with the following error:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m error_index\n",
|
| 628 |
+
"\u001b[0;31mProcessExitedException\u001b[0m: process 0 terminated with exit code 1"
|
| 629 |
+
]
|
| 630 |
+
}
|
| 631 |
+
],
|
| 632 |
+
"source": [
|
| 633 |
+
"def single_main(rank, world_size):\n",
|
| 634 |
+
" config = TrainConfig()\n",
|
| 635 |
+
" ddp_setup(rank, world_size)\n",
|
| 636 |
+
" \n",
|
| 637 |
+
" num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]\n",
|
| 638 |
+
" for i, num_image in enumerate(num_image_list):\n",
|
| 639 |
+
" config.num_image = num_image\n",
|
| 640 |
+
" # config.world_size = world_size\n",
|
| 641 |
+
" \n",
|
| 642 |
+
" ddpm21cm = DDPM21CM(config)\n",
|
| 643 |
+
" print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
|
| 644 |
+
" print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
|
| 645 |
+
" ddpm21cm.train()\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" \n",
|
| 648 |
+
"if __name__ == \"__main__\":\n",
|
| 649 |
+
" # torch.multiprocessing.set_start_method(\"spawn\")\n",
|
| 650 |
+
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
| 651 |
+
" world_size = 1#torch.cuda.device_count()\n",
|
| 652 |
+
"\n",
|
| 653 |
+
" mp.spawn(single_main, args=(world_size,), nprocs=world_size)\n",
|
| 654 |
+
" # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
|
| 655 |
+
]
|
| 656 |
+
},
|
| 657 |
+
{
|
| 658 |
+
"cell_type": "code",
|
| 659 |
+
"execution_count": null,
|
| 660 |
+
"metadata": {},
|
| 661 |
+
"outputs": [],
|
| 662 |
+
"source": [
|
| 663 |
+
"# torch.cuda.set_device(0)"
|
| 664 |
+
]
|
| 665 |
+
},
|
| 666 |
+
{
|
| 667 |
+
"cell_type": "code",
|
| 668 |
+
"execution_count": null,
|
| 669 |
+
"metadata": {},
|
| 670 |
+
"outputs": [
|
| 671 |
+
{
|
| 672 |
+
"name": "stdout",
|
| 673 |
+
"output_type": "stream",
|
| 674 |
+
"text": [
|
| 675 |
+
"True\n",
|
| 676 |
+
"2\n",
|
| 677 |
+
"['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__path__', '__file__', '__cached__', '__builtins__', '__annotations__', 'contextlib', 'os', 'torch', 'Device', 'traceback', 'warnings', 'threading', 'List', 'Optional', 'Tuple', 'Union', 'Any', '_utils', '_get_device_index', '_dummy_type', 'classproperty', 'graphs', 'CUDAGraph', 'graph_pool_handle', 'graph', 'make_graphed_callables', 'is_current_stream_capturing', 'streams', 'ExternalStream', 'Stream', 'Event', '_device', '_cudart', '_initialized', '_tls', '_initialization_lock', '_queued_calls', '_is_in_bad_fork', '_device_t', '_LazySeedTracker', '_lazy_seed_tracker', '_CudaDeviceProperties', 'has_magma', 'has_half', 'default_generators', 'is_available', 'is_bf16_supported', '_sleep', '_check_capability', '_check_cubins', 'is_initialized', '_lazy_call', 'DeferredCudaCallError', 'init', '_lazy_init', 'cudart', 'cudaStatus', 'CudaError', 'check_error', 'device', 'device_of', 'set_device', 'get_device_name', 'get_device_capability', 'get_device_properties', 'can_device_access_peer', 'StreamContext', 'stream', 'set_stream', 'device_count', 'get_arch_list', 'get_gencode_flags', 'current_device', 'synchronize', 'ipc_collect', 'current_stream', 'default_stream', 'current_blas_handle', 'set_sync_debug_mode', 'get_sync_debug_mode', 'memory_usage', 'utilization', 'memory', 'caching_allocator_alloc', 'caching_allocator_delete', 'set_per_process_memory_fraction', 'empty_cache', 'memory_stats', 'memory_stats_as_nested_dict', 'reset_accumulated_memory_stats', 'reset_peak_memory_stats', 'reset_max_memory_allocated', 'reset_max_memory_cached', 'memory_allocated', 'max_memory_allocated', 'memory_reserved', 'max_memory_reserved', 'memory_cached', 'max_memory_cached', 'memory_snapshot', 'memory_summary', 'list_gpu_processes', 'mem_get_info', 'random', 'get_rng_state', 'get_rng_state_all', 'set_rng_state', 'set_rng_state_all', 'manual_seed', 'manual_seed_all', 'seed', 'seed_all', 'initial_seed', '_lazy_new', '_CudaBase', 'ByteStorage', 'DoubleStorage', 'FloatStorage', 'HalfStorage', 'LongStorage', 'IntStorage', 'ShortStorage', 'CharStorage', 'BoolStorage', 'BFloat16Storage', 'ComplexDoubleStorage', 'ComplexFloatStorage', 'sparse', 'profiler', 'nvtx', 'amp', 'jiterator', 'ByteTensor', 'CharTensor', 'DoubleTensor', 'FloatTensor', 'IntTensor', 'LongTensor', 'ShortTensor', 'HalfTensor', 'BoolTensor', 'BFloat16Tensor', 'nccl', '_get_device_properties']\n"
|
| 678 |
+
]
|
| 679 |
+
}
|
| 680 |
+
],
|
| 681 |
+
"source": [
|
| 682 |
+
"print(torch.cuda.is_available())\n",
|
| 683 |
+
"print(torch.cuda.device_count())\n",
|
| 684 |
+
"print(torch.cuda.__dir__())"
|
| 685 |
+
]
|
| 686 |
+
},
|
| 687 |
+
{
|
| 688 |
+
"cell_type": "code",
|
| 689 |
+
"execution_count": 17,
|
| 690 |
+
"metadata": {},
|
| 691 |
+
"outputs": [
|
| 692 |
+
{
|
| 693 |
+
"name": "stdout",
|
| 694 |
+
"output_type": "stream",
|
| 695 |
+
"text": [
|
| 696 |
+
"True\n",
|
| 697 |
+
"<class 'torch.cuda.device'>\n",
|
| 698 |
+
"Quadro RTX 6000\n",
|
| 699 |
+
"0\n",
|
| 700 |
+
"(7, 5)\n",
|
| 701 |
+
"_CudaDeviceProperties(name='Quadro RTX 6000', major=7, minor=5, total_memory=24212MB, multi_processor_count=72)\n"
|
| 702 |
+
]
|
| 703 |
+
}
|
| 704 |
+
],
|
| 705 |
+
"source": [
|
| 706 |
+
"print(torch.cuda.is_initialized())\n",
|
| 707 |
+
"print(torch.cuda.device)\n",
|
| 708 |
+
"print(torch.cuda.get_device_name())\n",
|
| 709 |
+
"print(torch.cuda.current_device())\n",
|
| 710 |
+
"print(torch.cuda.get_device_capability())\n",
|
| 711 |
+
"print(torch.cuda.get_device_properties(torch.cuda.device))\n",
|
| 712 |
+
"# print('here')\n",
|
| 713 |
+
"# print(torch.cuda.memory_usage())\n",
|
| 714 |
+
"# print(torch.cuda.utilization())\n",
|
| 715 |
+
"# print(torch.cuda.memory())\n",
|
| 716 |
+
"# print('here')\n",
|
| 717 |
+
"# print(torch.cuda.memory_summary())"
|
| 718 |
+
]
|
| 719 |
+
},
|
| 720 |
+
{
|
| 721 |
+
"attachments": {},
|
| 722 |
+
"cell_type": "markdown",
|
| 723 |
+
"metadata": {},
|
| 724 |
+
"source": [
|
| 725 |
+
"# Sampling"
|
| 726 |
+
]
|
| 727 |
+
},
|
| 728 |
+
{
|
| 729 |
+
"cell_type": "code",
|
| 730 |
+
"execution_count": null,
|
| 731 |
+
"metadata": {},
|
| 732 |
+
"outputs": [],
|
| 733 |
+
"source": [
|
| 734 |
+
"if __name__ == \"__main__\":\n",
|
| 735 |
+
" # num_image_list = [1600,3200,6400,12800,25600]\n",
|
| 736 |
+
" num_image_list = [1000]\n",
|
| 737 |
+
" # num_image_list = [3200,6400,12800,25600]\n",
|
| 738 |
+
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
| 739 |
+
" repeat = 2\n",
|
| 740 |
+
" config = TrainConfig()\n",
|
| 741 |
+
" for i, num_image in enumerate(num_image_list):\n",
|
| 742 |
+
" config.num_image = num_image\n",
|
| 743 |
+
" ddpm21cm = DDPM21CM(config)\n",
|
| 744 |
+
"\n",
|
| 745 |
+
" ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor([4.4, 131.341]), repeat=repeat)\n",
|
| 746 |
+
"\n",
|
| 747 |
+
" # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((5.6, 19.037)), repeat=repeat)\n",
|
| 748 |
+
"\n",
|
| 749 |
+
" # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.699, 30)), repeat=repeat)\n",
|
| 750 |
+
"\n",
|
| 751 |
+
" # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((5.477, 200)), repeat=repeat)\n",
|
| 752 |
+
"\n",
|
| 753 |
+
" # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.8, 131.341)), repeat=repeat)"
|
| 754 |
+
]
|
| 755 |
+
},
|
| 756 |
+
{
|
| 757 |
+
"cell_type": "code",
|
| 758 |
+
"execution_count": 13,
|
| 759 |
+
"metadata": {},
|
| 760 |
+
"outputs": [],
|
| 761 |
+
"source": [
|
| 762 |
+
"# ls -lth outputs | head"
|
| 763 |
+
]
|
| 764 |
+
},
|
| 765 |
+
{
|
| 766 |
+
"cell_type": "code",
|
| 767 |
+
"execution_count": null,
|
| 768 |
+
"metadata": {},
|
| 769 |
+
"outputs": [],
|
| 770 |
+
"source": [
|
| 771 |
+
"def plot_grid(samples, c=None, row=1, col=2):\n",
|
| 772 |
+
" print(\"samples.shape =\", samples.shape)\n",
|
| 773 |
+
" for j in range(samples.shape[4]):\n",
|
| 774 |
+
" plt.figure(figsize = (12,6), dpi=400)\n",
|
| 775 |
+
" for i in range(len(samples)):\n",
|
| 776 |
+
" plt.subplot(row,col,i+1)\n",
|
| 777 |
+
" plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)\n",
|
| 778 |
+
" plt.xticks([])\n",
|
| 779 |
+
" plt.yticks([])\n",
|
| 780 |
+
" # plt.suptitle(f\"ION_Tvir_MIN = {c[0][0]}, HII_EFF_FACTOR = {c[0][1]}\")\n",
|
| 781 |
+
" # plt.show()\n",
|
| 782 |
+
" # plt.suptitle('simulations')\n",
|
| 783 |
+
" plt.tight_layout()\n",
|
| 784 |
+
" plt.subplots_adjust(wspace=0, hspace=0)\n",
|
| 785 |
+
" plt.savefig(f\"test3D-{j:03d}.png\")\n",
|
| 786 |
+
" plt.close()\n",
|
| 787 |
+
" # plt.show()\n",
|
| 788 |
+
" \n",
|
| 789 |
+
"data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\")\n",
|
| 790 |
+
"# print(data.shape)\n",
|
| 791 |
+
"plot_grid(data)\n",
|
| 792 |
+
"# plt.imshow(data)"
|
| 793 |
+
]
|
| 794 |
+
},
|
| 795 |
+
{
|
| 796 |
+
"cell_type": "code",
|
| 797 |
+
"execution_count": null,
|
| 798 |
+
"metadata": {},
|
| 799 |
+
"outputs": [],
|
| 800 |
+
"source": [
|
| 801 |
+
"# config = TrainConfig()\n",
|
| 802 |
+
"# def plot(filename, row=4, col=6):\n",
|
| 803 |
+
"# samples = np.load(filename)\n",
|
| 804 |
+
"# params = filename.split('guide_w')[-1][:-4]\n",
|
| 805 |
+
"# print(\"plotting\", samples.shape, params)\n",
|
| 806 |
+
"# plt.figure(figsize = (8,8))\n",
|
| 807 |
+
"# for i in range(24):\n",
|
| 808 |
+
"# plt.subplot(row,col,i+1)\n",
|
| 809 |
+
"# plt.imshow(samples[i,0,:,:], cmap='gray')#, vmin=-1, vmax=1)\n",
|
| 810 |
+
"# plt.xticks([])\n",
|
| 811 |
+
"# plt.yticks([])\n",
|
| 812 |
+
"# # plt.show()\n",
|
| 813 |
+
"# plt.suptitle(params)\n",
|
| 814 |
+
"# plt.tight_layout()\n",
|
| 815 |
+
"# plt.subplots_adjust(wspace=0, hspace=0) \n",
|
| 816 |
+
"# plt.show()\n",
|
| 817 |
+
"# # plt.savefig('outputs/'+params+'.png')\n",
|
| 818 |
+
"# # plt.close()\n",
|
| 819 |
+
"# # plt.imshow(images[0,0])\n",
|
| 820 |
+
"# # plt.show()"
|
| 821 |
+
]
|
| 822 |
+
},
|
| 823 |
+
{
|
| 824 |
+
"cell_type": "code",
|
| 825 |
+
"execution_count": 1,
|
| 826 |
+
"metadata": {},
|
| 827 |
+
"outputs": [],
|
| 828 |
+
"source": [
|
| 829 |
+
"# import torch\n",
|
| 830 |
+
"# print(torch.__version__)"
|
| 831 |
+
]
|
| 832 |
+
},
|
| 833 |
+
{
|
| 834 |
+
"cell_type": "code",
|
| 835 |
+
"execution_count": 9,
|
| 836 |
+
"metadata": {},
|
| 837 |
+
"outputs": [],
|
| 838 |
+
"source": [
|
| 839 |
+
"# import torch\n",
|
| 840 |
+
"# import os\n",
|
| 841 |
+
"\n",
|
| 842 |
+
"# def compare_models(num_gpus):\n",
|
| 843 |
+
"# model_states = []\n",
|
| 844 |
+
" \n",
|
| 845 |
+
"# for gpu_id in range(num_gpus):\n",
|
| 846 |
+
"# filename = f\"outputs/model_state-N40-device{gpu_id}\"\n",
|
| 847 |
+
"# if os.path.exists(filename):\n",
|
| 848 |
+
"# state_dict = torch.load(filename, map_location='cpu')\n",
|
| 849 |
+
"# model_states.append(state_dict)\n",
|
| 850 |
+
"# print(filename)\n",
|
| 851 |
+
"# else:\n",
|
| 852 |
+
"# print(f\"File {filename} not found!\")\n",
|
| 853 |
+
"# return False\n",
|
| 854 |
+
" \n",
|
| 855 |
+
"# # Compare all model state_dicts\n",
|
| 856 |
+
"# print(\"len(model_states) =\", len(model_states))\n",
|
| 857 |
+
"# base_state = model_states[0]\n",
|
| 858 |
+
"# for state in model_states[1:]:\n",
|
| 859 |
+
"# for key in base_state.keys():\n",
|
| 860 |
+
"# # print(key, base_state[key], state[key])\n",
|
| 861 |
+
"# print(key)\n",
|
| 862 |
+
"# print(\"epoch\", base_state['epoch'], state['epoch'])\n",
|
| 863 |
+
"\n",
|
| 864 |
+
"# print(base_state['unet_state_dict'].keys())\n",
|
| 865 |
+
"# for key in base_state['unet_state_dict']:\n",
|
| 866 |
+
"# # print(key)\n",
|
| 867 |
+
"# if not torch.equal(base_state['unet_state_dict'][key], state['unet_state_dict'][key]):\n",
|
| 868 |
+
"# print(\"different\")\n",
|
| 869 |
+
"# return \n",
|
| 870 |
+
"# # else:\n",
|
| 871 |
+
"# print(\"exactly same\")\n",
|
| 872 |
+
"\n",
|
| 873 |
+
"# # if key == 'epoch':\n",
|
| 874 |
+
"# # print(base_state[key], state[key])\n",
|
| 875 |
+
"# # else:\n",
|
| 876 |
+
"# # print(base_state[key], state[key])\n",
|
| 877 |
+
"# # if not torch.equal(base_state[key], state[key]):\n",
|
| 878 |
+
"# # # if not (base_state[key] == state[key]):\n",
|
| 879 |
+
"# # print(f\"Mismatch found in parameter {key}\")\n",
|
| 880 |
+
"# # return False\n",
|
| 881 |
+
" \n",
|
| 882 |
+
"# # print(\"All models are identical!\")\n",
|
| 883 |
+
"# # return True\n",
|
| 884 |
+
"\n",
|
| 885 |
+
"# if __name__ == \"__main__\":\n",
|
| 886 |
+
"# # epoch_to_check = 0 # specify the epoch you want to check\n",
|
| 887 |
+
"# num_gpus = torch.cuda.device_count() # specify the number of GPUs used in training\n",
|
| 888 |
+
"# compare_models(num_gpus)"
|
| 889 |
+
]
|
| 890 |
+
},
|
| 891 |
+
{
|
| 892 |
+
"cell_type": "code",
|
| 893 |
+
"execution_count": 6,
|
| 894 |
+
"metadata": {},
|
| 895 |
+
"outputs": [],
|
| 896 |
+
"source": [
|
| 897 |
+
"import numpy as np\n",
|
| 898 |
+
"test = np.random.normal(0,1,(800,1,64,64,512))"
|
| 899 |
+
]
|
| 900 |
+
},
|
| 901 |
+
{
|
| 902 |
+
"cell_type": "code",
|
| 903 |
+
"execution_count": 7,
|
| 904 |
+
"metadata": {},
|
| 905 |
+
"outputs": [
|
| 906 |
+
{
|
| 907 |
+
"data": {
|
| 908 |
+
"text/plain": [
|
| 909 |
+
"12.5"
|
| 910 |
+
]
|
| 911 |
+
},
|
| 912 |
+
"execution_count": 7,
|
| 913 |
+
"metadata": {},
|
| 914 |
+
"output_type": "execute_result"
|
| 915 |
+
}
|
| 916 |
+
],
|
| 917 |
+
"source": [
|
| 918 |
+
"(test.itemsize*test.size) / 1024/1024/1024"
|
| 919 |
+
]
|
| 920 |
+
},
|
| 921 |
+
{
|
| 922 |
+
"cell_type": "code",
|
| 923 |
+
"execution_count": 8,
|
| 924 |
+
"metadata": {},
|
| 925 |
+
"outputs": [],
|
| 926 |
+
"source": [
|
| 927 |
+
"del test"
|
| 928 |
+
]
|
| 929 |
+
}
|
| 930 |
+
],
|
| 931 |
+
"metadata": {
|
| 932 |
+
"kernelspec": {
|
| 933 |
+
"display_name": "base",
|
| 934 |
+
"language": "python",
|
| 935 |
+
"name": "python3"
|
| 936 |
+
},
|
| 937 |
+
"language_info": {
|
| 938 |
+
"codemirror_mode": {
|
| 939 |
+
"name": "ipython",
|
| 940 |
+
"version": 3
|
| 941 |
+
},
|
| 942 |
+
"file_extension": ".py",
|
| 943 |
+
"mimetype": "text/x-python",
|
| 944 |
+
"name": "python",
|
| 945 |
+
"nbconvert_exporter": "python",
|
| 946 |
+
"pygments_lexer": "ipython3",
|
| 947 |
+
"version": "3.9.19"
|
| 948 |
+
},
|
| 949 |
+
"orig_nbformat": 4
|
| 950 |
+
},
|
| 951 |
+
"nbformat": 4,
|
| 952 |
+
"nbformat_minor": 2
|
| 953 |
+
}
|
diffusion.py
CHANGED
|
@@ -239,7 +239,7 @@ class TrainConfig:
|
|
| 239 |
stride = (2,2) if dim == 2 else (2,2,4)
|
| 240 |
num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
batch_size = 10#1#2#50#20#2#100 # 10
|
| 242 |
-
n_epoch = 5#4# 10#50#20#20#2#5#25 # 120
|
| 243 |
HII_DIM = 64
|
| 244 |
num_redshift = 512#128#64#512#256#256#64#512#128
|
| 245 |
channel = 1
|
|
@@ -266,7 +266,7 @@ class TrainConfig:
|
|
| 266 |
# seed = 0
|
| 267 |
# save_dir = './outputs/'
|
| 268 |
|
| 269 |
-
save_period =
|
| 270 |
# general parameters for the name and logger
|
| 271 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 272 |
lrate = 1e-4
|
|
@@ -580,13 +580,16 @@ class DDPM21CM:
|
|
| 580 |
# else:
|
| 581 |
return x_last
|
| 582 |
# %%
|
|
|
|
|
|
|
|
|
|
| 583 |
def train(rank, world_size):
|
| 584 |
config = TrainConfig()
|
| 585 |
config.world_size = world_size
|
| 586 |
|
| 587 |
ddp_setup(rank, world_size)
|
| 588 |
|
| 589 |
-
|
| 590 |
for i, num_image in enumerate(num_train_image_list):
|
| 591 |
config.num_image = num_image
|
| 592 |
# config.world_size = world_size
|
|
@@ -607,20 +610,6 @@ if __name__ == "__main__":
|
|
| 607 |
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
|
| 608 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|
| 609 |
|
| 610 |
-
# %%
|
| 611 |
-
# print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
|
| 612 |
-
# print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
|
| 613 |
-
# print("torch.cuda.current_device() =", torch.cuda.current_device())
|
| 614 |
-
# print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
|
| 615 |
-
# print("torch.cuda.get_device_properties(torch.cuda.device) =", torch.cuda.get_device_properties(torch.cuda.device))
|
| 616 |
-
# print('here')
|
| 617 |
-
# print(torch.cuda.memory_usage())
|
| 618 |
-
# print(torch.cuda.utilization())
|
| 619 |
-
# print(torch.cuda.memory())
|
| 620 |
-
# print('here')
|
| 621 |
-
# print(torch.cuda.memory_summary())
|
| 622 |
-
# %% [markdown]
|
| 623 |
-
# # Sampling
|
| 624 |
|
| 625 |
# %%
|
| 626 |
|
|
@@ -677,7 +666,7 @@ if __name__ == "__main__":
|
|
| 677 |
world_size = torch.cuda.device_count()
|
| 678 |
print(f" sampling, world_size = {world_size} ".center(100,'-'))
|
| 679 |
# num_train_image_list = [1600,3200,6400,12800,25600]
|
| 680 |
-
num_train_image_list = [
|
| 681 |
num_new_img_per_gpu = 40
|
| 682 |
max_num_img_per_gpu = 20
|
| 683 |
|
|
|
|
| 239 |
stride = (2,2) if dim == 2 else (2,2,4)
|
| 240 |
num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
batch_size = 10#1#2#50#20#2#100 # 10
|
| 242 |
+
n_epoch = 10#5#4# 10#50#20#20#2#5#25 # 120
|
| 243 |
HII_DIM = 64
|
| 244 |
num_redshift = 512#128#64#512#256#256#64#512#128
|
| 245 |
channel = 1
|
|
|
|
| 266 |
# seed = 0
|
| 267 |
# save_dir = './outputs/'
|
| 268 |
|
| 269 |
+
save_period = np.infty#.1 # the period of sampling
|
| 270 |
# general parameters for the name and logger
|
| 271 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 272 |
lrate = 1e-4
|
|
|
|
| 580 |
# else:
|
| 581 |
return x_last
|
| 582 |
# %%
|
| 583 |
+
|
| 584 |
+
num_train_image_list = [5000]
|
| 585 |
+
|
| 586 |
def train(rank, world_size):
|
| 587 |
config = TrainConfig()
|
| 588 |
config.world_size = world_size
|
| 589 |
|
| 590 |
ddp_setup(rank, world_size)
|
| 591 |
|
| 592 |
+
#[3200]#[200]#[1600,3200,6400,12800,25600]
|
| 593 |
for i, num_image in enumerate(num_train_image_list):
|
| 594 |
config.num_image = num_image
|
| 595 |
# config.world_size = world_size
|
|
|
|
| 610 |
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
|
| 611 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|
| 612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
|
| 614 |
# %%
|
| 615 |
|
|
|
|
| 666 |
world_size = torch.cuda.device_count()
|
| 667 |
print(f" sampling, world_size = {world_size} ".center(100,'-'))
|
| 668 |
# num_train_image_list = [1600,3200,6400,12800,25600]
|
| 669 |
+
# num_train_image_list = [5000]
|
| 670 |
num_new_img_per_gpu = 40
|
| 671 |
max_num_img_per_gpu = 20
|
| 672 |
|