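{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Adapt ContextUnet and the related code to work for the 2D case first, compare it with diffusers.UNet2DModel and optimize it, then rewrite it for the 3D case.\n", "- After trying diffusers' UNet2DModel, the loss dropped from 0.3 to 0.2 but was still high, which suggests there are problems outside the UNet that can be improved.\n", "- After switching to diffusers' DDPMScheduler and DDPMPipeline, the loss dropped below 0.1, sometimes as low as 0.004, so the problem in my code is mainly in the DDPM part. The DDPMScheduler part is fairly short and seems fine, so the issue is probably some piece of DDPMPipeline that my code lacks.\n", "- I had a typo in my DDPMScheduler that kept beta_t very small; after fixing it, the loss dropped from 0.2 to as low as 0.02 and stays below 0.1.\n", "- diffusers' DDPMScheduler seems to work somewhat better: its loss is always a bit lower than with my DDPMScheduler. At epoch 19, the loss is about 0.02 with the former and about 0.07 with the latter. The former also supports adding noise to 3D images, so I might as well reuse the existing implementation. Still, I would like to know why my loss is higher.\n", "- I realized that the third-party DDPMScheduler does not support input parameters in its sample function, so in the end I still need my own DDPMScheduler. In the meantime I can use theirs to debug my ContextUnet.\n", "- I need to extend my ContextUnet to support images of different dimensionality, since I first need to compare against the original paper before extending to the 3D case.\n", "- I have converted my ContextUnet to 2D mode. Compared with loss=0.037 for diffusers.UNet2DModel, my Unet gives loss=0.07. The images generated by my Unet also look strange, so my Unet has problems too. I need to revert the code to the original Unet and find the problem.\n", "- I limited the number of pixels along the redshift direction to 64 to compare the two Unets. Results:\\\n", "UNet2DModel loss: 0.03, 0.0655, 0.05, 0.02, 0.05\\\n", "ContextUnet loss: 0.1, 0.16, 0.1, 0.2186, 0.06\n", "- I reverted ContextUnet to the original author's version; the loss is 0.05 and the output images look good. My main change was restoring the normalization function he originally used, which also takes a swish parameter. When I have time I can look into exactly what affected the training results. I also found that to get separate, clean curves in tensorboard, the runs need to be written to different folders.\n", "- Verified: GroupNorm works better than BatchNorm.\n", "- Extended to accept inputs of different dimensionality.\n", "- Merge in the parameters cond, guide_w, and drop_out.\n", "- The generated 21cm images are not dark enough where they should be dark; with MNIST digit images this does not seem to be a problem.\n", "- When generating MNIST digits with the diffusion model, I found that although the generated data also contains negative values, such as -0.1, the plotted images show the expected black. The data distribution is not much different from that of the 21cm results, so I now plan to switch the code back to the 21cm case.\n", "- I consolidated everything into the ddpm21cm module, which handles both training and sample generation. There is currently a bug: sampling always runs into CUDA out of memory, whereas resuming the model separately and then sampling does not.\n", "- Solved: the problem was that I forgot to write with torch.no_grad():\n", "- Next step: generate 800 lightcones and, in parallel, work out how to compute the global signal and the power spectrum.\n", "- When the number of training images reaches 5000, the generated images closely resemble the test data.\n", "- It takes 62 mins to generate 8 images of shape (64,64,64), which is even slower than simulation, which takes ~5 mins for each image. 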
Besides, the batch_size during training and the number of images to be generated are limited to 2 and 8, respectively.\n", "- The slowness can be addressed by using multiple GPUs, and the limit on the number of images can be addressed by mixed precision and multiple GPUs.\n", "- In addition, the performance of DDPM can look better compared to computation-intensive simulations. " ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass\n", "import h5py\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader, Dataset\n", "# from datasets import Dataset\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import random\n", "# from abc import ABC, abstractmethod\n", "import torch.nn.functional as F\n", "import math\n", "# from PIL import Image\n", "import os\n", "from torch.utils.tensorboard import SummaryWriter\n", "import copy\n", "from tqdm.auto import tqdm\n", "# from torchvision import transforms\n", "# from diffusers import UNet2DModel#, UNet3DConditionModel\n", "# from diffusers import DDPMScheduler\n", "from diffusers.utils import make_image_grid\n", "import datetime\n", "from pathlib import Path\n", "from diffusers.optimization import get_cosine_schedule_with_warmup\n", "from accelerate import notebook_launcher, Accelerator\n", "from huggingface_hub import create_repo, upload_folder\n", "from load_h5 import Dataset4h5\n", "\n", "from context_unet import ContextUnet\n", "\n", "from huggingface_hub import notebook_login\n", "\n", "import torch.multiprocessing as mp\n", "from torch.utils.data.distributed import DistributedSampler\n", "from torch.nn.parallel import DistributedDataParallel as DDP\n", "from torch.distributed import init_process_group, destroy_process_group" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "def ddp_setup(rank: int, world_size: int):\n", " \"\"\"\n", " Args:\n", " rank: Unique identifier of each 
process\n", " world_size: Total number of processes\n", " \"\"\"\n", " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", " os.environ[\"MASTER_PORT\"] = \"12355\"\n", " torch.cuda.set_device(rank)\n", " init_process_group(backend=\"nccl\", rank=rank, world_size=world_size)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9bbf7e9db9ce426d9c59d6f6d8e8df29", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='