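{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Adapt ContextUnet and the related code so that it first works in 2D, compare it against diffusers.Unet2DModel and optimize it, and finally rewrite it for the 3D case.\n", "- After trying diffusers' Unet2DModel, the loss dropped from 0.3 to 0.2 but was still high, which means there are problems outside the Unet that can be fixed.\n", "- After switching to diffusers' DDPMScheduler and DDPMPipeline, the loss fell below 0.1, sometimes as low as 0.004, so the problem in my code lies mainly in the DDPM part. The DDPMScheduler part is short and looks fine, so the issue should be some piece of code in DDPMPipeline that my code is missing.\n", "- I had a typo in my DDPMScheduler that kept beta_t very small; after fixing it, the loss drops from 0.2 to about 0.02 and stays below 0.1.\n", "- diffusers' DDPMScheduler seems to work somewhat better: its loss is always a bit lower than with my DDPMScheduler. At epoch 19, the former reaches a loss of about 0.02 and the latter about 0.07. The former also supports adding noise to 3D images, so it may be simpler to reuse the existing implementation. Still, I want to understand why my loss is higher.\n", "- I realized that diffusers' DDPMScheduler does not accept the conditioning parameters in its sampling function, so in the end I still need my own DDPMScheduler. For now, though, I can use the diffusers scheduler to debug my ContextUnet.\n", "- I need to extend my ContextUnet to support images of different dimensionality, since I first need to compare against the original paper before extending to the 3D case.\n", "- I have converted my ContextUnet to 2D mode. Compared with diffusers.Unet2DModel (loss = 0.037), my Unet reaches loss = 0.07, and the images it generates look strange, so my Unet also has problems. I need to revert to the original Unet and track down the cause.\n", "- I limited the number of pixels along the redshift axis to 64 to compare the two Unets. Results:\\\n", "Unet2DModel loss: 0.03, 0.0655, 0.05, 0.02, 0.05\\\n", "ContextUnet loss: 0.1, 0.16, 0.1, 0.2186, 0.06\n", "- I reverted ContextUnet to the original author's version and got loss = 0.05 with good-looking output images. The main change was restoring his original normalization function, which also takes a swish parameter. When time permits, I can investigate exactly which part affected the training results. I also found that to keep TensorBoard curves separate and readable, each run has to be logged to a different folder.\n", "- Verified: GroupNorm works better than BatchNorm.\n", "- Extended to accept inputs of different dimensionality.\n", "- Integrated the cond, guide_w and drop_out parameters.\n", "- In the generated 21cm images, regions that should be dark are not dark enough; with MNIST digit images this problem does not seem to occur.\n", "- When generating MNIST digits with the diffusion model, I found that although the generated data also contain negative values such as -0.1, the plotted images show the desired black. The data distribution is not much different from that of the 21cm results, so I now plan to revert the code to the 21cm case.\n", "- I unified the ddpm21cm module so that it handles both training and sample generation, but there is currently a bug: sampling always runs out of CUDA memory, whereas resuming the model and sampling separately does not.\n", "- Solved: I had forgotten to wrap sampling in `with torch.no_grad():`.\n", "- Next: generate 800 lightcones and, in parallel, work out how to compute the global signal and the power spectrum.\n", "- When the number of training images reaches 5000, the generated images closely resemble the test data.\n", "- It takes 62 min to generate 8 images of shape (64,64,64), which is even slower than the simulation (~5 min per image). 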
Besides, the batch size during training and the number of generated images are limited to 2 and 8, respectively.\n", "- The slowness can be addressed with multiple GPUs, and the limit on the number of images with mixed precision and multiple GPUs.\n", "- In addition, DDPM's performance may compare favorably with computation-intensive simulations. " ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass\n", "import h5py\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader, Dataset\n", "# from datasets import Dataset\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import random\n", "# from abc import ABC, abstractmethod\n", "import torch.nn.functional as F\n", "import math\n", "# from PIL import Image\n", "import os\n", "from torch.utils.tensorboard import SummaryWriter\n", "import copy\n", "from tqdm.auto import tqdm\n", "# from torchvision import transforms\n", "# from diffusers import UNet2DModel#, UNet3DConditionModel\n", "# from diffusers import DDPMScheduler\n", "from diffusers.utils import make_image_grid\n", "import datetime\n", "from pathlib import Path\n", "from diffusers.optimization import get_cosine_schedule_with_warmup\n", "from accelerate import notebook_launcher, Accelerator\n", "from huggingface_hub import create_repo, upload_folder\n", "from load_h5 import Dataset4h5\n", "\n", "from context_unet import ContextUnet\n", "\n", "from huggingface_hub import notebook_login\n", "\n", "import torch.multiprocessing as mp\n", "from torch.utils.data.distributed import DistributedSampler\n", "from torch.nn.parallel import DistributedDataParallel as DDP\n", "from torch.distributed import init_process_group, destroy_process_group" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "def ddp_setup(rank: int, world_size: int):\n", " \"\"\"\n", " Args:\n", " rank: Unique identifier of each 
process\n", " world_size: Total number of processes\n", " \"\"\"\n", " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", " os.environ[\"MASTER_PORT\"] = \"12355\"\n", " torch.cuda.set_device(rank)\n", " init_process_group(backend=\"nccl\", rank=rank, world_size=world_size)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9bbf7e9db9ce426d9c59d6f6d8e8df29", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='
0 else 0\n", "\n", " if guide_w == -1:\n", " # eps = nn_model(x_i, t_is, return_dict=False)[0]\n", " eps = nn_model(x_i, t_is)\n", " # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n", " else:\n", " # double batch\n", " x_i = x_i.repeat(2, *torch.ones(len(self.img_shape), dtype=int).tolist())\n", " t_is = t_is.repeat(2)\n", "\n", " # split predictions and compute weighting\n", " # print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\n", " eps = nn_model(x_i, t_is, c_i)\n", " eps1 = eps[:n_sample]\n", " eps2 = eps[n_sample:]\n", " eps = eps1 + guide_w*(eps1 - eps2)\n", " # eps = (1+guide_w)*eps1 - guide_w*eps2\n", " x_i = x_i[:n_sample]\n", " # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n", " \n", " # print(\"x_i.shape =\", x_i.shape)\n", " x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n", " \n", " pbar_sample.update(1)\n", " # pbar_sample.set_postfix(step=i)\n", " \n", " # print(\"x_i.shape =\", x_i.shape)\n", " # store only part of the intermediate steps\n", " if i%20==0:# or i==0:# or i<8:\n", " x_i_entire.append(x_i.detach().cpu().numpy())\n", " x_i = x_i.detach().cpu().numpy()\n", " x_i_entire = np.array(x_i_entire)\n", " return x_i, x_i_entire\n", "\n", "\n", "# ddpm_scheduler = DDPMScheduler((1e-4,0.02),10)\n", "# noisy_images, noise, ts = ddpm_scheduler.add_noise(images)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class EMA:\n", " def __init__(self, beta):\n", " super().__init__()\n", " self.beta = beta\n", " self.step = 0\n", "\n", " def update_model_average(self, ma_model, current_model):\n", " for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):\n", " old_weight, up_weight = ma_params.data, current_params.data\n", " 
ma_params.data = self.update_average(old_weight, up_weight)\n", "\n", " def update_average(self, old, new):\n", " if old is None:\n", " return new\n", " return old * self.beta + (1 - self.beta) * new\n", "\n", " def step_ema(self, ema_model, model):\n", " self.update_model_average(ema_model, model)\n", " self.step += 1\n", "\n", " def reset_parameters(self, ema_model, model):\n", " ema_model.load_state_dict(model.state_dict())\n", " " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class TrainConfig:\n", " ###########################\n", " ## hardcoding these here ##\n", " ###########################\n", " push_to_hub = True\n", " hub_model_id = \"Xsmos/ml21cm\"\n", " hub_private_repo = False\n", " dataset_name = \"/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5\"\n", " device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", " # world_size = torch.cuda.device_count()\n", " # repeat = 2\n", "\n", " # dim = 2\n", " dim = 3\n", " stride = (2,2) if dim == 2 else (2,2,1)\n", " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n", " batch_size = 2#2#50#20#2#100 # 10\n", " n_epoch = 10#50#20#20#2#5#25 # 120\n", " HII_DIM = 28#64\n", " num_redshift = 4#128#64#512#256#256#64#512#128\n", " channel = 1\n", " img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n", "\n", " ranges_dict = dict(\n", " params = {\n", " 0: [4, 6], # ION_Tvir_MIN\n", " 1: [10, 250], # HII_EFF_FACTOR\n", " },\n", " images = {\n", " 0: [0, 80], # brightness_temp\n", " }\n", " )\n", "\n", " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n", " # n_sample = 24 # 64, the number of samples in sampling process\n", " n_param = 2\n", " guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n", " drop_prob = 0#0.28 # only takes effect when guide_w != -1\n", " ema=True # whether to use ema\n", " ema_rate=0.995\n", 
"\n", " # seed = 0\n", " # save_dir = './outputs/'\n", "\n", " save_freq = 0#.1 # besides the final epoch, also save once when (ep+1)*save_freq == 1, i.e. at epoch 1/save_freq; 0 disables this\n", " # general parameters for the name and logger \n", " # device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " lrate = 1e-4\n", " lr_warmup_steps = 0#5#00\n", " output_dir = \"./outputs/\"\n", " save_name = os.path.join(output_dir, 'model_state')\n", " # save_freq = 1 #10 # the period of saving model\n", " # cond = True # if training using the conditional information\n", " # lr_decay = False #True# if using the learning rate decay\n", " resume = save_name # resume from this checkpoint if it exists; note that save() writes to save_name + f'-N{num_image}'\n", " # params_single = torch.tensor([0.2,0.80000023])\n", " # params = torch.tile(params_single,(n_sample,1)).to(device)\n", " # params = params\n", " # data_dir = './data' # data directory\n", "\n", "\n", " mixed_precision = \"fp16\"\n", " gradient_accumulation_steps = 1\n", "\n", " # date = datetime.datetime.now().strftime(\"%m%d-%H%M\")\n", " # run_name = f'{date}' # the unique name of each experiment\n", "\n", "# config = TrainConfig()\n", "# print(\"device =\", config.device)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# import os\n", "# print(os.cpu_count())\n", "# print(len(os.sched_getaffinity(0)))\n", "# import torch\n", "# data = torch.randn((64,64))\n", "# print(data.dtype)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# @dataclass\n", "class DDPM21CM:\n", " def __init__(self, config):\n", " # config = TrainConfig()\n", " # date = datetime.datetime.now().strftime(\"%m%d-%H%M\")\n", " config.run_name = datetime.datetime.now().strftime(\"%m%d-%H%M\") # the unique name of each experiment\n", " self.config = config\n", " # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n", " # # self.shape_loaded = dataset.images.shape\n", " # 
# print(\"shape_loaded =\", self.shape_loaded)\n", " # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)\n", " # del dataset\n", " self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n", "\n", " # initialize the unet\n", " self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)\n", "\n", " if config.resume and os.path.exists(config.resume):\n", " # resume_file = os.path.join(config.output_dir, f\"{config.resume}\")\n", " self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])\n", " print(f\"resumed nn_model from {config.resume}\")\n", " # nn_model = ContextUnet(n_param=1, image_size=28)\n", " self.nn_model.train()\n", " self.nn_model.to(self.ddpm.device)\n", " # print(\"nn_model.device =\", ddpm.device)\n", " # number of parameters to be trained\n", " self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())\n", " print(f\"Number of parameters for nn_model: {self.number_of_params}\")\n", "\n", " # whether to use ema\n", " if config.ema:\n", " self.ema = EMA(config.ema_rate)\n", " if config.resume and os.path.exists(config.resume):\n", " self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n", " self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])\n", " print(f\"resumed ema_model from {config.resume}\")\n", " else:\n", " self.ema_model = copy.deepcopy(self.nn_model).eval().requires_grad_(False)\n", "\n", " self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)\n", " self.lr_scheduler = get_cosine_schedule_with_warmup(\n", " optimizer=self.optimizer,\n", " num_warmup_steps=config.lr_warmup_steps,\n", " num_training_steps=(int(config.num_image/config.batch_size) * config.n_epoch),\n", " # num_training_steps=(len(self.dataloader) 
* config.n_epoch),\n", " )\n", "\n", " self.ranges_dict = config.ranges_dict\n", "\n", " def load(self):\n", " dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)\n", " # self.shape_loaded = dataset.images.shape\n", " # print(\"shape_loaded =\", self.shape_loaded)\n", " self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True, num_workers=len(os.sched_getaffinity(0)), pin_memory=True)\n", " # del dataset\n", " # self.accelerate(self.config)\n", " del dataset\n", "\n", " # def accelerate(self):\n", "\n", " def train(self):\n", " ################### \n", " ## training loop ##\n", " ###################\n", " # plot_unet = True\n", "\n", " self.load()\n", " self.accelerator = Accelerator(\n", " mixed_precision=self.config.mixed_precision,\n", " gradient_accumulation_steps=self.config.gradient_accumulation_steps,\n", " log_with=\"tensorboard\",\n", " project_dir=os.path.join(self.config.output_dir, \"logs\"),\n", " )\n", " print(\"self.accelerator.is_main_process:\", self.accelerator.is_main_process)\n", " if self.accelerator.is_main_process:\n", " if self.config.output_dir is not None:\n", " os.makedirs(self.config.output_dir, exist_ok=True)\n", " if self.config.push_to_hub:\n", " self.repo_id = create_repo(\n", " repo_id=self.config.hub_model_id or Path(self.config.output_dir).name, exist_ok=True\n", " ).repo_id\n", " self.accelerator.init_trackers(f\"{self.config.run_name}\")\n", "\n", " self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \\\n", " self.accelerator.prepare(\n", " self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler\n", " )\n", " \n", " global_step = 0\n", " for ep in range(self.config.n_epoch):\n", " self.ddpm.train()\n", "\n", " pbar_train = tqdm(total=len(self.dataloader), disable=not 
self.accelerator.is_local_main_process)\n", " pbar_train.set_description(f\"Epoch {ep}\")\n", " for i, (x, c) in enumerate(self.dataloader):\n", " with self.accelerator.accumulate(self.nn_model):\n", " x = x.to(self.config.device)\n", " xt, noise, ts = self.ddpm.add_noise(x)\n", " \n", " if self.config.guide_w == -1:\n", " noise_pred = self.nn_model(xt, ts)\n", " else:\n", " c = c.to(self.config.device)\n", " noise_pred = self.nn_model(xt, ts, c)\n", " \n", " loss = F.mse_loss(noise, noise_pred)\n", " self.accelerator.backward(loss)\n", " self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)\n", " self.optimizer.step()\n", " self.lr_scheduler.step()\n", " self.optimizer.zero_grad()\n", "\n", " # ema update\n", " if self.config.ema:\n", " self.ema.step_ema(self.ema_model, self.nn_model)\n", "\n", " pbar_train.update(1)\n", " logs = dict(\n", " loss=loss.detach().item(),\n", " lr=self.optimizer.param_groups[0]['lr'],\n", " step=global_step\n", " )\n", " pbar_train.set_postfix(**logs)\n", "\n", " self.accelerator.log(logs, step=global_step)\n", " global_step += 1\n", "\n", " # if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\n", " self.save(ep)\n", "\n", " del self.nn_model\n", " if self.config.ema:\n", " del self.ema_model\n", " torch.cuda.empty_cache()\n", "\n", " def save(self, ep):\n", " # save model\n", " if self.accelerator.is_main_process:\n", " if ep == self.config.n_epoch-1 or (ep+1)*self.config.save_freq==1:\n", " self.nn_model.eval()\n", " with torch.no_grad():\n", " if self.config.push_to_hub:\n", " upload_folder(\n", " repo_id = self.repo_id,\n", " folder_path = \".\",#config.output_dir,\n", " commit_message = f\"{self.config.run_name}\",\n", " ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\", \"__pycache__\"],\n", " )\n", " if self.config.save_name:\n", " model_state = {\n", " 'epoch': ep,\n", " # unwrap the (possibly DDP-wrapped) model so the saved keys carry no 'module.' prefix\n", " 'unet_state_dict': self.accelerator.unwrap_model(self.nn_model).state_dict(),\n", " 'ema_unet_state_dict': self.ema_model.state_dict() if self.config.ema else None,\n", " }\n", " 
torch.save(model_state, self.config.save_name+f\"-N{self.config.num_image}\")\n", " print('saved model at ' + self.config.save_name+f\"-N{self.config.num_image}\")\n", " # switch back to training mode, since save() can also run before the last epoch\n", " self.nn_model.train()\n", " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n", "\n", " # def rescale(self, value, type='params', to_ranges=[0,1]):\n", " # for i, from_ranges in self.ranges_dict[type].items():\n", " # value[i] = (value[i] - from_ranges[0])/(from_ranges[1]-from_ranges[0]) # normalize\n", " # value[i] = \n", " def rescale(self, value, ranges, to: list):\n", " # work on a float copy so the caller's tensor is not modified in place\n", " value = value.clone().float()\n", " if value.ndim == 1:\n", " value = value.view(-1,len(value))\n", " \n", " for i in range(np.shape(value)[1]):\n", " value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])\n", " # print(f\"i = {i}, value.min = {value[:,i].min()}, value.max = {value[:,i].max()}\")\n", " value = value * (to[1]-to[0]) + to[0]\n", " return value \n", "\n", " def sample(self, file, params:torch.tensor=None, repeat=192, ema=False, entire=False):\n", " # n_sample = params.shape[0]\n", " \n", " if params is None:\n", " params = torch.tensor([0.20000000000000018, 0.5055875000000001])\n", " params_backup = params.numpy().copy()\n", " else:\n", " params_backup = params.numpy().copy()\n", " params = self.rescale(params, self.ranges_dict['params'], to=[0,1])\n", "\n", " print(f\"sampling {repeat} images with normalized params = {params}\")\n", " params = params.repeat(repeat,1)\n", " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n", " # print(\"params =\", params)\n", " # print(\"params =\", params)\n", " # print(\"len(params) =\", len(params))\n", " # model = self.ema_model if ema else self.nn_model\n", " # del self.ema_model, self.nn\n", " # params = torch.tile(params, (n_sample,1)).to(device)\n", "\n", " nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)\n", " if ema:\n", " 
nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])\n", " else:\n", " nn_model.load_state_dict(torch.load(file)['unet_state_dict'])\n", " print(f\"nn_model resumed from {file}\")\n", " # nn_model = ContextUnet(n_param=1, image_size=28)\n", " # nn_model.train()\n", " nn_model.to(self.ddpm.device)\n", " nn_model.eval()\n", "\n", " # self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n", " # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n", " # print(f\"resumed ema_model from {config.resume}\")\n", "\n", " with torch.no_grad():\n", " x_last, x_entire = self.ddpm.sample(\n", " nn_model=nn_model, \n", " params=params.to(self.config.device), \n", " device=self.config.device, \n", " guide_w=self.config.guide_w\n", " )\n", "\n", " # np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else ''}.npy\"), x_last)\n", " np.save(os.path.join(self.config.output_dir, f\"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}.npy\"), x_last)\n", "\n", " if entire:\n", " # save the stored intermediate denoising steps, not the final sample\n", " np.save(os.path.join(self.config.output_dir, f\"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}_entire.npy\"), x_entire)\n", "# print(\"device =\", config.device)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"<string>\", line 1, in <module>\n", " File \"/storage/home/hcoda1/3/bxia34/.conda/envs/diffusers/lib/python3.9/multiprocessing/spawn.py\", line 116, in spawn_main\n", " exitcode = _main(fd, parent_sentinel)\n", " File \"/storage/home/hcoda1/3/bxia34/.conda/envs/diffusers/lib/python3.9/multiprocessing/spawn.py\", line 126, in _main\n", " self = reduction.pickle.load(from_parent)\n", "AttributeError: Can't get 
attribute 'single_main' on \n" ] }, { "ename": "ProcessExitedException", "evalue": "process 0 terminated with exit code 1", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mProcessExitedException\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[12], line 21\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m__name__\u001b[39m \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m__main__\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 17\u001b[0m \u001b[39m# torch.multiprocessing.set_start_method(\"spawn\")\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[39m# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\u001b[39;00m\n\u001b[1;32m 19\u001b[0m world_size \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\u001b[39m#torch.cuda.device_count()\u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m mp\u001b[39m.\u001b[39;49mspawn(single_main, args\u001b[39m=\u001b[39;49m(world_size,), nprocs\u001b[39m=\u001b[39;49mworld_size)\n\u001b[1;32m 22\u001b[0m \u001b[39m# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')\u001b[39;00m\n", "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:240\u001b[0m, in \u001b[0;36mspawn\u001b[0;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[1;32m 236\u001b[0m msg \u001b[39m=\u001b[39m (\u001b[39m'\u001b[39m\u001b[39mThis method only supports start_method=spawn (got: \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m).\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m 237\u001b[0m \u001b[39m'\u001b[39m\u001b[39mTo use a different start_method use:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m 238\u001b[0m \u001b[39m'\u001b[39m\u001b[39m 
torch.multiprocessing.start_processes(...)\u001b[39m\u001b[39m'\u001b[39m \u001b[39m%\u001b[39m start_method)\n\u001b[1;32m 239\u001b[0m warnings\u001b[39m.\u001b[39mwarn(msg)\n\u001b[0;32m--> 240\u001b[0m \u001b[39mreturn\u001b[39;00m start_processes(fn, args, nprocs, join, daemon, start_method\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mspawn\u001b[39;49m\u001b[39m'\u001b[39;49m)\n", "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:198\u001b[0m, in \u001b[0;36mstart_processes\u001b[0;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[39mreturn\u001b[39;00m context\n\u001b[1;32m 197\u001b[0m \u001b[39m# Loop on join until it returns True or raises an exception.\u001b[39;00m\n\u001b[0;32m--> 198\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m context\u001b[39m.\u001b[39;49mjoin():\n\u001b[1;32m 199\u001b[0m \u001b[39mpass\u001b[39;00m\n", "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:149\u001b[0m, in \u001b[0;36mProcessContext.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[39mraise\u001b[39;00m ProcessExitedException(\n\u001b[1;32m 141\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mprocess \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with signal \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m 142\u001b[0m (error_index, name),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 146\u001b[0m signal_name\u001b[39m=\u001b[39mname\n\u001b[1;32m 147\u001b[0m )\n\u001b[1;32m 148\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 149\u001b[0m \u001b[39mraise\u001b[39;00m ProcessExitedException(\n\u001b[1;32m 150\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mprocess \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with exit code 
\u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m 151\u001b[0m (error_index, exitcode),\n\u001b[1;32m 152\u001b[0m error_index\u001b[39m=\u001b[39merror_index,\n\u001b[1;32m 153\u001b[0m error_pid\u001b[39m=\u001b[39mfailed_process\u001b[39m.\u001b[39mpid,\n\u001b[1;32m 154\u001b[0m exit_code\u001b[39m=\u001b[39mexitcode\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 157\u001b[0m original_trace \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39merror_queues[error_index]\u001b[39m.\u001b[39mget()\n\u001b[1;32m 158\u001b[0m msg \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\n\u001b[39;00m\u001b[39m-- Process \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with the following error:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m error_index\n", "\u001b[0;31mProcessExitedException\u001b[0m: process 0 terminated with exit code 1" ] } ], "source": [ "def single_main(rank, world_size):\n", " config = TrainConfig()\n", " ddp_setup(rank, world_size)\n", " \n", " num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]\n", " for i, num_image in enumerate(num_image_list):\n", " config.num_image = num_image\n", " # config.world_size = world_size\n", " \n", " ddpm21cm = DDPM21CM(config)\n", " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n", " print(f\"run_name = {ddpm21cm.config.run_name}\")\n", " ddpm21cm.train()\n", "\n", " \n", "if __name__ == \"__main__\":\n", " # torch.multiprocessing.set_start_method(\"spawn\")\n", " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n", " world_size = 1#torch.cuda.device_count()\n", "\n", " mp.spawn(single_main, args=(world_size,), nprocs=world_size)\n", " # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# torch.cuda.set_device(0)" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "2\n", "['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__path__', '__file__', '__cached__', '__builtins__', '__annotations__', 'contextlib', 'os', 'torch', 'Device', 'traceback', 'warnings', 'threading', 'List', 'Optional', 'Tuple', 'Union', 'Any', '_utils', '_get_device_index', '_dummy_type', 'classproperty', 'graphs', 'CUDAGraph', 'graph_pool_handle', 'graph', 'make_graphed_callables', 'is_current_stream_capturing', 'streams', 'ExternalStream', 'Stream', 'Event', '_device', '_cudart', '_initialized', '_tls', '_initialization_lock', '_queued_calls', '_is_in_bad_fork', '_device_t', '_LazySeedTracker', '_lazy_seed_tracker', '_CudaDeviceProperties', 'has_magma', 'has_half', 'default_generators', 'is_available', 'is_bf16_supported', '_sleep', '_check_capability', '_check_cubins', 'is_initialized', '_lazy_call', 'DeferredCudaCallError', 'init', '_lazy_init', 'cudart', 'cudaStatus', 'CudaError', 'check_error', 'device', 'device_of', 'set_device', 'get_device_name', 'get_device_capability', 'get_device_properties', 'can_device_access_peer', 'StreamContext', 'stream', 'set_stream', 'device_count', 'get_arch_list', 'get_gencode_flags', 'current_device', 'synchronize', 'ipc_collect', 'current_stream', 'default_stream', 'current_blas_handle', 'set_sync_debug_mode', 'get_sync_debug_mode', 'memory_usage', 'utilization', 'memory', 'caching_allocator_alloc', 'caching_allocator_delete', 'set_per_process_memory_fraction', 'empty_cache', 'memory_stats', 'memory_stats_as_nested_dict', 'reset_accumulated_memory_stats', 'reset_peak_memory_stats', 'reset_max_memory_allocated', 'reset_max_memory_cached', 'memory_allocated', 'max_memory_allocated', 'memory_reserved', 'max_memory_reserved', 'memory_cached', 'max_memory_cached', 'memory_snapshot', 'memory_summary', 'list_gpu_processes', 'mem_get_info', 'random', 'get_rng_state', 
'get_rng_state_all', 'set_rng_state', 'set_rng_state_all', 'manual_seed', 'manual_seed_all', 'seed', 'seed_all', 'initial_seed', '_lazy_new', '_CudaBase', 'ByteStorage', 'DoubleStorage', 'FloatStorage', 'HalfStorage', 'LongStorage', 'IntStorage', 'ShortStorage', 'CharStorage', 'BoolStorage', 'BFloat16Storage', 'ComplexDoubleStorage', 'ComplexFloatStorage', 'sparse', 'profiler', 'nvtx', 'amp', 'jiterator', 'ByteTensor', 'CharTensor', 'DoubleTensor', 'FloatTensor', 'IntTensor', 'LongTensor', 'ShortTensor', 'HalfTensor', 'BoolTensor', 'BFloat16Tensor', 'nccl', '_get_device_properties']\n" ] } ], "source": [ "print(torch.cuda.is_available())\n", "print(torch.cuda.device_count())\n", "print(torch.cuda.__dir__())" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "\n", "Quadro RTX 6000\n", "0\n", "(7, 5)\n", "_CudaDeviceProperties(name='Quadro RTX 6000', major=7, minor=5, total_memory=24212MB, multi_processor_count=72)\n" ] } ], "source": [ "print(torch.cuda.is_initialized())\n", "print(torch.cuda.device)\n", "print(torch.cuda.get_device_name())\n", "print(torch.cuda.current_device())\n", "print(torch.cuda.get_device_capability())\n", "print(torch.cuda.get_device_properties(torch.cuda.device))\n", "# print('here')\n", "# print(torch.cuda.memory_usage())\n", "# print(torch.cuda.utilization())\n", "# print(torch.cuda.memory())\n", "# print('here')\n", "# print(torch.cuda.memory_summary())" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Sampling" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if __name__ == \"__main__\":\n", " # num_image_list = [1600,3200,6400,12800,25600]\n", " num_image_list = [1000]\n", " # num_image_list = [3200,6400,12800,25600]\n", " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n", " repeat = 2\n", " config = TrainConfig()\n", " for i, num_image in 
enumerate(num_image_list):\n", " config.num_image = num_image\n", " ddpm21cm = DDPM21CM(config)\n", "\n", " ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor([4.4, 131.341]), repeat=repeat)\n", "\n", " # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((5.6, 19.037)), repeat=repeat)\n", "\n", " # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.699, 30)), repeat=repeat)\n", "\n", " # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((5.477, 200)), repeat=repeat)\n", "\n", " # ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.8, 131.341)), repeat=repeat)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# ls -lth outputs | head" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_grid(samples, c=None, row=1, col=2):\n", " print(\"samples.shape =\", samples.shape)\n", " for j in range(samples.shape[4]):\n", " plt.figure(figsize = (12,6), dpi=400)\n", " for i in range(len(samples)):\n", " plt.subplot(row,col,i+1)\n", " plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)\n", " plt.xticks([])\n", " plt.yticks([])\n", " # plt.suptitle(f\"ION_Tvir_MIN = {c[0][0]}, HII_EFF_FACTOR = {c[0][1]}\")\n", " # plt.show()\n", " # plt.suptitle('simulations')\n", " plt.tight_layout()\n", " plt.subplots_adjust(wspace=0, hspace=0)\n", " plt.savefig(f\"test3D-{j:03d}.png\")\n", " plt.close()\n", " # plt.show()\n", " \n", "data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\")\n", "# print(data.shape)\n", "plot_grid(data)\n", "# plt.imshow(data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# config = TrainConfig()\n", "# def plot(filename, row=4, col=6):\n", "# samples = np.load(filename)\n", "# params = filename.split('guide_w')[-1][:-4]\n", "# 
print(\"plotting\", samples.shape, params)\n", "# plt.figure(figsize = (8,8))\n", "# for i in range(24):\n", "# plt.subplot(row,col,i+1)\n", "# plt.imshow(samples[i,0,:,:], cmap='gray')#, vmin=-1, vmax=1)\n", "# plt.xticks([])\n", "# plt.yticks([])\n", "# # plt.show()\n", "# plt.suptitle(params)\n", "# plt.tight_layout()\n", "# plt.subplots_adjust(wspace=0, hspace=0) \n", "# plt.show()\n", "# # plt.savefig('outputs/'+params+'.png')\n", "# # plt.close()\n", "# # plt.imshow(images[0,0])\n", "# # plt.show()" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# import torch\n", "# print(torch.__version__)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# import torch\n", "# import os\n", "\n", "# def compare_models(num_gpus):\n", "# model_states = []\n", " \n", "# for gpu_id in range(num_gpus):\n", "# filename = f\"outputs/model_state-N40-device{gpu_id}\"\n", "# if os.path.exists(filename):\n", "# state_dict = torch.load(filename, map_location='cpu')\n", "# model_states.append(state_dict)\n", "# print(filename)\n", "# else:\n", "# print(f\"File {filename} not found!\")\n", "# return False\n", " \n", "# # Compare all model state_dicts\n", "# print(\"len(model_states) =\", len(model_states))\n", "# base_state = model_states[0]\n", "# for state in model_states[1:]:\n", "# for key in base_state.keys():\n", "# # print(key, base_state[key], state[key])\n", "# print(key)\n", "# print(\"epoch\", base_state['epoch'], state['epoch'])\n", "\n", "# print(base_state['unet_state_dict'].keys())\n", "# for key in base_state['unet_state_dict']:\n", "# # print(key)\n", "# if not torch.equal(base_state['unet_state_dict'][key], state['unet_state_dict'][key]):\n", "# print(\"different\")\n", "# return \n", "# # else:\n", "# print(\"exactly same\")\n", "\n", "# # if key == 'epoch':\n", "# # print(base_state[key], state[key])\n", "# # else:\n", "# # print(base_state[key], state[key])\n", "# # if not 
torch.equal(base_state[key], state[key]):\n", "# # # if not (base_state[key] == state[key]):\n", "# # print(f\"Mismatch found in parameter {key}\")\n", "# # return False\n", " \n", "# # print(\"All models are identical!\")\n", "# # return True\n", "\n", "# if __name__ == \"__main__\":\n", "# # epoch_to_check = 0 # specify the epoch you want to check\n", "# num_gpus = torch.cuda.device_count() # specify the number of GPUs used in training\n", "# compare_models(num_gpus)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "test = np.random.normal(0,1,(800,1,64,64,512))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "12.5" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(test.itemsize*test.size) / 1024/1024/1024" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "del test" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }