|
|
|
|
|
import logging |
|
|
|
|
|
import warnings |
|
|
warnings.filterwarnings("ignore", category=FutureWarning) |
|
|
|
|
|
from dataclasses import dataclass |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import random |
|
|
|
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
|
|
|
import os |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
import copy |
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
|
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from huggingface_hub import create_repo, upload_folder |
|
|
|
|
|
from load_h5 import Dataset4h5 |
|
|
from context_unet import ContextUnet |
|
|
|
|
|
from huggingface_hub import notebook_login |
|
|
|
|
|
import torch.multiprocessing as mp |
|
|
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.distributed import init_process_group, destroy_process_group |
|
|
import torch.distributed as dist |
|
|
|
|
|
import argparse |
|
|
import socket |
|
|
import sys |
|
|
from datetime import timedelta |
|
|
from time import time |
|
|
|
|
|
from torch.cuda.amp import autocast, GradScaler |
|
|
from random import getrandbits |
|
|
|
|
|
import subprocess |
|
|
|
|
|
|
|
|
def ddp_setup(rank: int, world_size: int, master_addr, master_port):
    """Initialize the default torch.distributed process group (NCCL backend).

    Args:
        rank: Unique identifier of each process (global rank across all nodes)
        world_size: Total number of processes
        master_addr: hostname/IP of the rank-0 rendezvous server
        master_port: port (string) of the rank-0 rendezvous server
    """

    # Exported for any library code that reads the env-var rendezvous settings;
    # the explicit tcp:// init_method below is what actually drives the init.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port

    init_process_group(
        backend="nccl",  # GPU collectives; requires one CUDA device per process
        init_method=f"tcp://{master_addr}:{master_port}",
        rank=rank,
        world_size=world_size,
        timeout=timedelta(minutes=20)  # generous window for slow multi-node startup
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DDPMScheduler(nn.Module):
    """Linear-beta DDPM scheduler: forward noising (q(x_t|x_0)) and ancestral sampling."""

    def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu', config=None):
        """
        Args:
            betas: (beta_1, beta_T) endpoints of the linear variance schedule.
            num_timesteps: number of diffusion steps T.
            img_shape: shape of one image, channel-first, without the batch dim.
            device: device the schedule tensors (and sampled noise) live on.
            config: training config; only `.global_rank` is read here (logging).
        """
        super().__init__()

        beta_1, beta_T = betas
        assert 0 < beta_1 <= beta_T <= 1, "ensure 0 < beta_1 <= beta_T <= 1"
        self.device = device
        self.num_timesteps = num_timesteps
        self.img_shape = img_shape
        # Linear schedule beta_t for t = 1..T.
        self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps)

        self.beta_t = self.beta_t.to(self.device)

        # alpha_t = 1 - beta_t; bar_alpha_t = prod_{s<=t} alpha_s (closed-form noising).
        self.alpha_t = 1 - self.beta_t

        self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)

        self.config = config

    def add_noise(self, clean_images):
        """Apply the forward process at uniformly random timesteps.

        Returns:
            (noisy_images, noise, ts): x_t from q(x_t|x_0), the Gaussian noise
            used, and the per-sample timestep indices.
        """
        shape = clean_images.shape
        # Trailing singleton dims so per-sample scalars broadcast over the image dims.
        expand = torch.ones(len(shape)-1, dtype=int)

        noise = torch.randn_like(clean_images).to(self.device)
        ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)

        # x_t = sqrt(bar_alpha_t) * x_0 + sqrt(1 - bar_alpha_t) * eps
        noisy_images = (
            clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())
            + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())
        )

        return noisy_images, noise, ts

    def sample(self, nn_model, params, device, guide_w = 0):
        """Ancestral DDPM sampling from pure noise down to t=0.

        Args:
            nn_model: noise predictor; called as nn_model(x, t) when guide_w == -1
                (unconditional) else nn_model(x, t, c).
            params: conditioning tensor; len(params) sets the number of samples.
            device: device to run sampling on.
            guide_w: -1 selects the unconditional branch; any other value runs
                conditioned on `params` (no guidance mixing is applied here).

        Returns:
            (x_i, x_i_entire): final samples as numpy, plus the intermediate-state
            array — NOTE(review): nothing is ever appended to x_i_entire in this
            version, so it is always empty.
        """
        n_sample = len(params)

        x_i = torch.randn(n_sample, *self.img_shape)
        x_i = x_i.to(device)

        if guide_w != -1:
            c_i = params

        x_i_entire = []

        pbar_sample = tqdm(total=self.num_timesteps, file=sys.stderr, disable=True)
        pbar_sample.set_description(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} sampling")
        for i in reversed(range(0, self.num_timesteps)):

            t_is = torch.tensor([i]).to(device)
            t_is = t_is.repeat(n_sample)

            # No fresh noise is injected at the final (i == 0) step.
            z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else torch.tensor(0.)

            if guide_w == -1:
                # Unconditional model.
                eps = nn_model(x_i, t_is)

            else:
                # Conditional model.
                eps = nn_model(x_i, t_is, c_i)

            # Posterior mean step: x_{t-1} = (x_t - eps * beta_t / sqrt(1 - bar_alpha_t)) / sqrt(alpha_t) + sigma_t * z
            x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z

            pbar_sample.update(1)

        x_i_entire = np.array(x_i_entire)
        x_i = x_i.detach().cpu().numpy()
        return x_i, x_i_entire
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EMA:
    """Exponential moving average of model parameters.

    Maintains a shadow model whose parameters track a live model via
    ema = beta * ema + (1 - beta) * current.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of `ma_model` toward `current_model` in place."""
        for live, shadow in zip(current_model.parameters(), ma_model.parameters()):
            shadow.data = self.update_average(shadow.data, live.data)

    def update_average(self, old, new):
        """One EMA blend; passes `new` straight through when no average exists yet."""
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model):
        """Advance one EMA step: update the shadow model, then bump the counter."""
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Hard-copy the live model's weights into the shadow model."""
        ema_model.load_state_dict(model.state_dict())
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Hyper-parameters and paths for DDPM training and sampling.

    NOTE(review): none of the attributes carry type annotations, so @dataclass
    registers no fields — these are plain class attributes and TrainConfig()
    takes no arguments; the launcher mutates them on the instance afterwards.
    """

    # Hugging Face Hub upload settings.
    push_to_hub = False
    hub_model_id = "Xsmos/ml21cm"
    hub_private_repo = False
    dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
    # Evaluated once at class-creation time.
    device = "cuda" if torch.cuda.is_available() else 'cpu'

    world_size = 1  # overwritten with the per-node world size at launch

    # Data geometry: 2D slices vs 3D lightcones.
    dim = 3
    stride = (2,4) if dim == 2 else (2,2,4)  # UNet downsampling strides per axis
    num_image = 32
    batch_size = 1
    n_epoch = 100
    HII_DIM = 64          # spatial resolution per transverse axis
    num_redshift = 1024   # extent along the redshift/frequency axis
    startat = 0

    channel = 1
    img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)

    # Physical ranges used to normalize conditioning params and images.
    ranges_dict = dict(
        params = {
            0: [4, 6],      # log10(Tvir)
            1: [10, 250],   # zeta
        },
        images = {
            0: [-338, 54],  # brightness-temperature range
        }
    )

    num_timesteps = 1000  # diffusion steps T

    n_param = 2
    guide_w = 0   # -1 selects the unconditional model path
    dropout = 0

    # Exponential moving average of model weights (off by default).
    ema=False
    ema_rate=0.995

    save_period = 10  # checkpoint every N epochs

    lrate = 1e-4
    lr_warmup_steps = 0
    output_dir = "./outputs/"
    save_name = os.path.join(output_dir, 'model')  # checkpoint filename prefix

    resume = False  # set to a checkpoint path to resume training/sampling

    gradient_accumulation_steps = 1

    # UNet channel multipliers per resolution level.
    channel_mult = (1,2,2,2,4)

    str_len = 140  # width used when .center()-padding log lines
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_gpu_info(device):
    """Summarize CUDA memory usage for `device`, in whole MiB.

    Returns:
        dict with keys 'total' (device capacity), 'used' (bytes currently
        allocated by tensors) and 'free' (reserved-but-unallocated memory,
        i.e. what the caching allocator can hand out without growing).
    """
    capacity = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)

    mib = 1024 ** 2
    return {
        'total': int(capacity / mib),
        'used': int(allocated / mib),
        'free': int((reserved - allocated) / mib),
    }
|
|
|
|
|
class DDPM21CM:
    """End-to-end DDPM trainer/sampler for 21-cm lightcone images, DDP-aware.

    Wraps a ContextUnet noise predictor in DistributedDataParallel, trains it
    against the DDPMScheduler forward process under AMP, and generates new
    images by ancestral sampling conditioned on (Tvir, zeta) parameters.
    """

    def __init__(self, config):
        """Build scheduler, DDP-wrapped UNet, optional EMA copy, optimizer and LR schedule.

        Args:
            config: TrainConfig-like object; device, world_size and global_rank
                must already be set (done by the spawn entry points).
        """
        # Unique run tag: SLURM job id when available, timestamp otherwise.
        config.run_name = os.environ.get("SLURM_JOB_ID", datetime.now().strftime("%d%H%M%S"))
        self.config = config
        self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, config=config,)

        self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, channel_mult=config.channel_mult, use_checkpoint=config.use_checkpoint, dropout=config.dropout)

        self.nn_model.train()
        self.nn_model.to(self.ddpm.device)
        self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])

        if config.resume and os.path.exists(config.resume):
            # DDP hides the real network behind .module; checkpoints store the bare state dict.
            self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
            print(f"{config.run_name} cuda:{torch.cuda.current_device()}/{self.config.global_rank} resumed nn_model from {config.resume} with {sum(x.numel() for x in self.nn_model.parameters())} parameters, {datetime.now().strftime('%d-%H:%M:%S.%f')}".center(self.config.str_len,'+'))
        else:
            print(f"{config.run_name} cuda:{torch.cuda.current_device()}/{self.config.global_rank} initialized nn_model randomly with {sum(x.numel() for x in self.nn_model.parameters())} parameters, {datetime.now().strftime('%d-%H:%M:%S.%f')}".center(self.config.str_len,'+'))

        if config.ema:
            self.ema = EMA(config.ema_rate)
            if config.resume and os.path.exists(config.resume):
                # FIX: `dropout` was previously passed to Tensor.to() (a TypeError
                # at runtime). Build the EMA net with the same architecture kwargs
                # as nn_model so the checkpoint's state dict actually matches.
                self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, channel_mult=config.channel_mult, use_checkpoint=config.use_checkpoint, dropout=config.dropout).to(config.device)
                # NOTE(review): save() below only stores 'unet_state_dict'; confirm
                # the checkpoint being resumed actually carries 'ema_unet_state_dict'.
                self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])
                print(f"resumed ema_model from {config.resume}")
            else:
                self.ema_model = copy.deepcopy(self.nn_model).eval().requires_grad_(False)

        self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)
        # One cosine cycle spanning the expected total number of optimizer steps.
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer = self.optimizer,
            T_max = int(config.num_image / config.batch_size * config.n_epoch / config.gradient_accumulation_steps),
        )

        self.ranges_dict = config.ranges_dict
        self.scaler = GradScaler()

    def load(self):
        """Create the HDF5-backed dataset and its DataLoader; enforce grad-accumulation divisibility."""
        dataset = Dataset4h5(
            self.config.dataset_name,
            num_image=self.config.num_image,
            idx = 'range',
            HII_DIM=self.config.HII_DIM,
            num_redshift=self.config.num_redshift,
            startat=self.config.startat,
            dim=self.config.dim,
            ranges_dict=self.ranges_dict,
            # Cap the one-off dataset read at a single worker per rank.
            num_workers=min(1,len(os.sched_getaffinity(0))//self.config.world_size),
            str_len = self.config.str_len,
        )

        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            # Share the available CPU affinity across the ranks on this node.
            num_workers=len(os.sched_getaffinity(0))//self.config.world_size,
            pin_memory=True,
            persistent_workers=True,
        )
        # Each optimizer step must consume exactly gradient_accumulation_steps batches.
        if len(self.dataloader) % self.config.gradient_accumulation_steps != 0:
            raise ValueError(f"len(self.dataloader) % self.config.gradient_accumulation_steps = {len(self.dataloader) % self.config.gradient_accumulation_steps} instead of 0. Make sure len(dataloader)={len(self.dataloader)} is dividable by gradient_accumulation_steps={self.config.gradient_accumulation_steps}.")

        del dataset

    def transform(self, img, idx=0):
        """Random augmentation for 3D batches: axis flips plus an x<->y transpose.

        NOTE(review): the flip is applied to img[idx] only (the idx-th sample of
        the batch), while the transpose applies to the whole batch — confirm
        this asymmetry is intended.
        """
        # Pick a random subset of the two transverse axes (dims 2 and 3) to flip.
        flip_xy = [i+2 for i in range(2) if getrandbits(1)]
        img[idx] = torch.flip(img[idx], dims=flip_xy)

        # 50% chance of swapping the two transverse axes.
        if getrandbits(1):
            img = img.transpose(2,3)

        return img

    def train(self):
        """Full training loop: AMP forward/backward, gradient accumulation, clipping,
        cosine LR schedule, optional EMA, rank-0 TensorBoard logging and checkpoints."""
        self.load()

        # Rank 0 owns output dirs, optional Hub repo, and the TensorBoard writer.
        if self.config.global_rank == 0:
            if self.config.output_dir is not None:
                os.makedirs(self.config.output_dir, exist_ok=True)
            if self.config.push_to_hub:
                self.repo_id = create_repo(
                    repo_id=self.config.hub_model_id or Path(self.config.output_dir).name, exist_ok=True
                ).repo_id

            self.config.logger = SummaryWriter(f"logs/{self.config.run_name}")

        # Keep ranks in lockstep before entering the training loop.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        else:
            print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} torch.distributed.is_initialized False!!!!!!!!!!!!!!!")

        global_step = 0
        for ep in range(self.config.n_epoch):
            self.ddpm.train()  # scheduler is an nn.Module; kept for mode consistency
            pbar_train = tqdm(total=len(self.dataloader), file=sys.stderr, disable=True)
            pbar_train.set_description(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} Epoch {ep}")
            epoch_start = time()
            for i, (x, c) in enumerate(self.dataloader):
                if self.config.dim == 3:
                    x = self.transform(x)

                x = x.to(self.config.device)

                with autocast(enabled=self.config.autocast):
                    xt, noise, ts = self.ddpm.add_noise(x)

                    if self.config.guide_w == -1:
                        # Unconditional training.
                        noise_pred = self.nn_model(xt, ts)
                    else:
                        # Conditioned on the (normalized) astro parameters.
                        c = c.to(self.config.device)
                        noise_pred = self.nn_model(xt, ts, c)

                    # Simple epsilon-prediction objective, scaled for accumulation.
                    loss = F.mse_loss(noise, noise_pred)
                    loss = loss / self.config.gradient_accumulation_steps

                # Fail fast on divergence instead of silently training on NaNs.
                if torch.isnan(loss).any():
                    raise ValueError(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} Epoch {ep}, loss: {loss}")

                self.scaler.scale(loss).backward()

                # Step only every gradient_accumulation_steps micro-batches.
                if (i+1) % self.config.gradient_accumulation_steps == 0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.nn_model.parameters(), max_norm=1.0)

                    self.scaler.step(self.optimizer)
                    self.lr_scheduler.step()

                    self.scaler.update()
                    self.optimizer.zero_grad()

                    if self.config.ema:
                        self.ema.step_ema(self.ema_model, self.nn_model)

                pbar_train.update(1)

                logs = dict(
                    loss=loss.detach().item(),
                    lr=self.optimizer.param_groups[0]['lr'],
                    step=global_step
                )
                pbar_train.set_postfix(**logs)

                if self.config.global_rank == 0:
                    self.config.logger.add_scalar("MSE", logs["loss"], global_step = global_step)
                    self.config.logger.add_scalar("learning_rate", logs["lr"], global_step = global_step)
                global_step += 1

            # Leftover micro-batches whose gradients never reached an optimizer step.
            if (i+1) % self.config.gradient_accumulation_steps != 0:
                print(f"(i+1)%self.config.gradient_accumulation_steps = {(i+1)%self.config.gradient_accumulation_steps}, i = {i}, scg = {self.config.gradient_accumulation_steps}".center(self.config.str_len,'-'))

            self.save(ep)
            print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} Epoch{ep}:{i+1}/{len(self.dataloader)} costs {(time()-epoch_start)/60:.2f} min", flush=True)

        del self.nn_model
        if self.config.ema:
            del self.ema_model

    def save(self, ep):
        """Rank-0-only checkpointing: optionally push the folder to the Hub and
        save the bare UNet state dict every `save_period` epochs and at the end."""
        if self.config.global_rank == 0:
            if ep == self.config.n_epoch-1 or (ep+1) % self.config.save_period == 0:
                self.nn_model.eval()
                with torch.no_grad():
                    if self.config.push_to_hub:
                        upload_folder(
                            repo_id = self.repo_id,
                            folder_path = ".",
                            commit_message = f"{self.config.run_name}",
                            ignore_patterns = ["step_*", "epoch_*", "*.npy", "__pycache__"],
                        )
                    if self.config.save_name:
                        model_state = {
                            'epoch': ep,
                            'unet_state_dict': self.nn_model.module.state_dict(),
                        }
                        save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-node{int(os.environ['SLURM_NNODES'])}-epoch{ep}-{self.config.run_name}"
                        torch.save(model_state, save_name)
                        print(f'cuda:{torch.cuda.current_device()}/{self.config.global_rank} saved model at ' + save_name)
                # FIX: restore training mode — previously the model silently stayed
                # in eval() (dropout disabled) for every epoch after the first save.
                self.nn_model.train()

    def rescale(self, params, ranges, to: list):
        """Affine-map each column of `params` from its physical range into `to`.

        Args:
            params: (N, P) or (P,) tensor of physical parameter values.
            ranges: mapping i -> [lo, hi] physical range for column i.
            to: [lo, hi] target interval.

        Returns:
            (N, P) tensor of rescaled values; the input is not modified.
        """
        value = params.clone()

        # Promote a 1D parameter vector to a single-row 2D tensor.
        if value.ndim == 1:
            value = value.view(-1,len(value))

        # Normalize each column to [0, 1]...
        for i in range(np.shape(value)[1]):
            value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])

        # ...then stretch/shift into the requested interval.
        value = value * (to[1]-to[0]) + to[0]
        return value

    def sample(self, params: torch.Tensor = None, num_new_img_per_gpu=192, ema=False, entire=False, save=True):
        """Generate `num_new_img_per_gpu` images conditioned on one parameter pair.

        Args:
            params: 1D tensor (Tvir, zeta) in physical units; defaults to a demo pair.
            num_new_img_per_gpu: batch size to sample on this device.
            ema: only affects the output filename suffix (no EMA weights are used here).
            entire: also save the intermediate-trajectory array (empty in this version).
            save: write results to .npy files under config.output_dir.

        Returns:
            numpy array of the final sampled images.
        """
        if params is None:
            params = torch.tensor([4.4, 131.341])

        params_backup = params.numpy().copy()
        # The network is conditioned on parameters normalized to [0, 1].
        params_normalized = self.rescale(params, self.ranges_dict['params'], to=[0,1])

        print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} sampling {num_new_img_per_gpu} images with normalized params = {params_normalized}, {datetime.now().strftime('%d-%H:%M:%S.%f')}")
        params_normalized = params_normalized.repeat(num_new_img_per_gpu,1)
        assert params_normalized.dim() == 2, "params_normalized must be a 2D torch.tensor"

        self.nn_model.eval()
        sample_start = time()
        with torch.no_grad():
            with autocast(enabled=self.config.autocast):
                x_last, x_entire = self.ddpm.sample(
                    nn_model=self.nn_model,
                    params=params_normalized.to(self.config.device),
                    device=self.config.device,
                    guide_w=self.config.guide_w
                )

        if save:
            savetime = datetime.now().strftime("%d%H%M%S")
            savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]:.3f}-zeta{params_backup[1]:.3f}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}.npy")
            if not os.path.exists(self.config.output_dir):
                os.makedirs(self.config.output_dir)
            np.save(savename, x_last)
            print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} saved {x_last.shape} to {savename} with {(time()-sample_start)/60:.2f} min", flush=True)

            if entire:
                savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]:.3f}-zeta{params_backup[1]:.3f}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}_entire.npy")
                np.save(savename, x_entire)
                print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} saved images of shape {x_entire.shape} to {savename}")

        return x_last
|
|
|
|
|
|
|
|
|
|
|
def train(rank, world_size, local_world_size, master_addr, master_port, config):
    """Per-process DDP training entry point (spawned once per local GPU).

    Args:
        rank: local rank on this node; also used as the CUDA device index.
        world_size: total number of processes across all nodes.
        local_world_size: number of processes (GPUs) on this node.
        master_addr: rendezvous host for process-group init.
        master_port: rendezvous port for process-group init.
        config: TrainConfig instance; device/world_size/global_rank are set here.
    """
    # Global rank = local rank offset by this node's SLURM node index.
    global_rank = rank + local_world_size * int(os.environ["SLURM_NODEID"])
    ddp_setup(global_rank, world_size, master_addr, master_port)
    torch.cuda.set_device(rank)

    config.device = f"cuda:{rank}"
    # NOTE(review): stores the LOCAL world size, not the global one — confirm intended.
    config.world_size = local_world_size
    config.global_rank = global_rank

    ddpm21cm = DDPM21CM(config)
    ddpm21cm.train()
    destroy_process_group()
|
|
|
|
|
|
|
|
def generate_samples(rank, world_size, local_world_size, master_addr, master_port, config, num_new_img_per_gpu, max_num_img_per_gpu, params):
    """Per-process DDP sampling entry point.

    Generates `num_new_img_per_gpu` images on this GPU in chunks of at most
    `max_num_img_per_gpu`; each chunk is written to its own .npy inside
    DDPM21CM.sample (save=True by default).
    """
    # Global rank = local rank offset by this node's SLURM node index.
    global_rank = rank + local_world_size * int(os.environ["SLURM_NODEID"])
    ddp_setup(global_rank, world_size, master_addr, master_port)
    torch.cuda.set_device(rank)

    config.device = f"cuda:{rank}"
    config.world_size = local_world_size
    config.global_rank = global_rank

    ddpm21cm = DDPM21CM(config)

    # Full-size chunks.
    for _ in range(num_new_img_per_gpu // max_num_img_per_gpu):
        sample = ddpm21cm.sample(
            params=params,
            num_new_img_per_gpu=max_num_img_per_gpu,
        )

    # Remainder chunk when the request is not an exact multiple.
    if num_new_img_per_gpu % max_num_img_per_gpu:
        sample_extra = ddpm21cm.sample(
            params=params,
            num_new_img_per_gpu=num_new_img_per_gpu % max_num_img_per_gpu,
        )

    dist.destroy_process_group()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE(review): --train doubles as an enable flag AND the dataset path; any
    # non-empty string turns training on.
    parser.add_argument("--train", type=str, required=False, help="whether to train the model", default=False)

    # --resume likewise doubles as a flag and a checkpoint path (enables sampling).
    parser.add_argument("--resume", type=str, required=False, help="filename of the model to resume", default=False)
    parser.add_argument("--num_new_img_per_gpu", type=int, required=False, default=4)
    parser.add_argument("--max_num_img_per_gpu", type=int, required=False, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, required=False, default=1)
    parser.add_argument("--num_image", type=int, required=False, default=32)
    parser.add_argument("--n_epoch", type=int, required=False, default=50)
    parser.add_argument("--batch_size", type=int, required=False, default=2)
    parser.add_argument("--channel_mult", type=float, nargs="+", required=False, default=(1,2,2,2,4))
    # int-valued flags (0/1) converted to bool below.
    parser.add_argument("--autocast", type=int, required=False, default=False)
    parser.add_argument("--use_checkpoint", type=int, required=False, default=False)
    parser.add_argument("--dropout", type=float, required=False, default=0)
    parser.add_argument("--lrate", type=float, required=False, default=1e-4)

    args = parser.parse_args()

    # DDP rendezvous info and topology come from the SLURM launch environment.
    master_addr = os.environ["MASTER_ADDR"]
    master_port = os.environ["MASTER_PORT"]
    local_world_size = torch.cuda.device_count()
    total_nodes = int(os.environ["SLURM_NNODES"])
    world_size = local_world_size * total_nodes

    # Override config defaults with the CLI options.
    config = TrainConfig()
    config.gradient_accumulation_steps = args.gradient_accumulation_steps
    config.num_image = args.num_image
    config.n_epoch = args.n_epoch
    config.batch_size = args.batch_size
    config.channel_mult = args.channel_mult
    config.autocast = bool(args.autocast)
    config.use_checkpoint = bool(args.use_checkpoint)
    config.dropout = args.dropout
    config.lrate = args.lrate

    # Training phase: one process per local GPU.
    if args.train:
        config.dataset_name = args.train
        print(f" training, ip = {socket.gethostbyname(socket.gethostname())}, local_world_size = {local_world_size}, world_size = {world_size}, {datetime.now().strftime('%d-%H:%M:%S.%f')} ".center(config.str_len,'#'))
        mp.spawn(
            train,
            args=(world_size, local_world_size, master_addr, master_port, config),
            nprocs=local_world_size,
            join=True,
        )

    # Sampling phase: resume a checkpoint and generate images per parameter pair.
    if args.resume:
        num_new_img_per_gpu = args.num_new_img_per_gpu
        max_num_img_per_gpu = args.max_num_img_per_gpu

        config.resume = args.resume

        # (log10(Tvir), zeta) pairs to condition the sampler on.
        params_pairs = [
            (4.4, 131.341),
            (5.6, 19.037),
            (4.699, 30),
            (5.477, 200),
            (4.8, 131.341),
        ]

        for params in params_pairs:
            print(f"sampling, {params}, ip = {socket.gethostbyname(socket.gethostname())}, local_world_size = {local_world_size}, world_size = {world_size}, {datetime.now().strftime('%d-%H:%M:%S.%f')}".center(config.str_len,'#'))
            mp.spawn(
                generate_samples,
                args=(world_size, local_world_size, master_addr, master_port, config, num_new_img_per_gpu, max_num_img_per_gpu, torch.tensor(params)),
                nprocs=local_world_size,
                join=True,
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|