|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
from datetime import timedelta |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import torch.nn.functional as F |
|
|
from monai.networks.utils import copy_model_state |
|
|
from monai.utils import RankFilter |
|
|
from monai.networks.schedulers import RFlowScheduler |
|
|
|
|
|
from monai.networks.schedulers.ddpm import DDPMPredictionType |
|
|
from torch.amp import GradScaler, autocast |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
from .utils import binarize_labels, define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp |
|
|
|
|
|
|
|
|
def main() -> None:
    """Train a MAISI ControlNet against a frozen, pretrained diffusion UNet.

    Parses three JSON config files (environment paths, network definitions,
    training hyper-parameters) and merges every key onto the argparse
    namespace, then runs the ControlNet training loop with optional
    multi-GPU DDP, mixed precision (AMP), and TensorBoard logging.
    Saves a "<exp_name>_current.pt" checkpoint every epoch and a
    "<exp_name>_best.pt" checkpoint whenever the (all-reduced) epoch loss
    improves.

    Raises:
        ValueError: if a configured diffusion-UNet or ControlNet checkpoint
            path does not exist on disk.
    """
    parser = argparse.ArgumentParser(description="maisi.controlnet.training")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./configs/environment_maisi_controlnet_train.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./configs/config_maisi-ddpm.json",
        help="config json file that stores network hyper-parameters",
    )
    parser.add_argument(
        "-t",
        "--training-config",
        default="./configs/config_maisi_controlnet_train.json",
        help="config json file that stores training hyper-parameters",
    )
    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")

    args = parser.parse_args()

    logger = logging.getLogger("maisi.controlnet.training")

    # Multi-GPU: rank/world_size come from torchrun's environment; RankFilter
    # suppresses duplicate log lines from non-zero ranks.
    use_ddp = args.gpus > 1
    if use_ddp:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = setup_ddp(rank, world_size)
        logger.addFilter(RankFilter())
    else:
        rank = 0
        world_size = 1
        device = torch.device(f"cuda:{rank}")

    torch.cuda.set_device(device)
    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
    logger.info(f"World_size: {world_size}")

    with open(args.environment_file, "r") as env_file:
        env_dict = json.load(env_file)
    with open(args.config_file, "r") as config_file:
        config_dict = json.load(config_file)
    with open(args.training_config, "r") as training_config_file:
        training_config_dict = json.load(training_config_file)

    # Flatten all three JSON configs onto the args namespace; later files win
    # on key collisions. Attributes used below (tfevent_path, exp_name,
    # controlnet_train, model_dir, ...) originate here.
    for k, v in env_dict.items():
        setattr(args, k, v)
    for k, v in config_dict.items():
        setattr(args, k, v)
    for k, v in training_config_dict.items():
        setattr(args, k, v)

    # TensorBoard writer exists only on rank 0; all uses below are guarded
    # by `rank == 0` checks.
    if rank == 0:
        tensorboard_path = os.path.join(args.tfevent_path, args.exp_name)
        Path(tensorboard_path).mkdir(parents=True, exist_ok=True)
        tensorboard_writer = SummaryWriter(tensorboard_path)

    # Only the training split is used; the validation loader is discarded.
    train_loader, _ = prepare_maisi_controlnet_json_dataloader(
        json_data_list=args.json_data_list,
        data_base_dir=args.data_base_dir,
        rank=rank,
        world_size=world_size,
        batch_size=args.controlnet_train["batch_size"],
        cache_rate=args.controlnet_train["cache_rate"],
        fold=args.controlnet_train["fold"],
    )

    # The diffusion UNet stays frozen; its architecture flags decide which
    # extra conditioning tensors (body region, modality) the loop must feed.
    unet = define_instance(args, "diffusion_unet_def").to(device)
    include_body_region = unet.include_top_region_index_input
    include_modality = unet.num_class_embeds is not None

    if args.trained_diffusion_path is not None:
        if not os.path.exists(args.trained_diffusion_path):
            raise ValueError("Please download the trained diffusion unet checkpoint.")
        # weights_only=False: checkpoint stores non-tensor entries (scale_factor).
        diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False)
        unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
        # scale_factor rescales the latents to the range the UNet was trained on.
        scale_factor = diffusion_model_ckpt["scale_factor"]
        logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.")
        logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
    else:
        logger.info("trained diffusion model is not loaded.")
        scale_factor = 1.0
        logger.info(f"set scale_factor -> {scale_factor}.")

    controlnet = define_instance(args, "controlnet_def").to(device)

    # Initialize ControlNet from the UNet's weights (matching layers only),
    # the standard ControlNet warm start.
    copy_model_state(controlnet, unet.state_dict())

    if args.trained_controlnet_path is not None:
        if not os.path.exists(args.trained_controlnet_path):
            raise ValueError("Please download the trained ControlNet checkpoint.")
        controlnet.load_state_dict(
            torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"]
        )
        logger.info(f"load trained controlnet model from {args.trained_controlnet_path}")
    else:
        logger.info("train controlnet model from scratch.")

    # Freeze the diffusion UNet: only the ControlNet is optimized.
    for p in unet.parameters():
        p.requires_grad = False

    noise_scheduler = define_instance(args, "noise_scheduler")

    if use_ddp:
        # find_unused_parameters=True: some ControlNet branches may not run
        # every step depending on the conditioning inputs.
        controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)

    weighted_loss = args.controlnet_train["weighted_loss"]
    weighted_loss_label = args.controlnet_train["weighted_loss_label"]
    optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"])
    # NOTE(review): computed over the full dataset, not the per-rank loader
    # length, and stays a float; PolynomialLR accepts it but total_iters will
    # overshoot per-rank steps under DDP — confirm this is intended.
    total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"]
    logger.info(f"total number of training steps: {total_steps}.")

    # Polynomial (power=2) decay of the LR down to 0 over total_steps.
    lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)

    n_epochs = args.controlnet_train["n_epochs"]
    scaler = GradScaler("cuda")
    total_step = 0
    best_loss = 1e4  # sentinel; first epoch loss always becomes the best

    if weighted_loss > 1.0:
        logger.info(f"apply weighted loss = {weighted_loss} on labels: {weighted_loss_label}")

    controlnet.train()
    unet.eval()
    prev_time = time.time()
    for epoch in range(n_epochs):
        epoch_loss_ = 0
        for step, batch in enumerate(train_loader):
            # Latents are rescaled by the checkpoint's scale_factor.
            images = batch["image"].to(device) * scale_factor
            labels = batch["label"].to(device)

            if include_body_region:
                top_region_index_tensor = batch["top_region_index"].to(device)
                bottom_region_index_tensor = batch["bottom_region_index"].to(device)

            if include_modality:
                # All-ones class labels — presumably a single fixed modality
                # id (e.g. CT); TODO confirm against num_class_embeds usage.
                modality_tensor = torch.ones((len(images),), dtype=torch.long).to(device)
            spacing_tensor = batch["spacing"].to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast("cuda", enabled=True):
                noise_shape = list(images.shape)
                noise = torch.randn(noise_shape, dtype=images.dtype).to(device)

                # One-hot/binarized label mask is the ControlNet's condition.
                controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()

                # Rectified-flow schedulers sample their own timesteps;
                # DDPM-style schedulers use uniform random integer timesteps.
                if isinstance(noise_scheduler, RFlowScheduler):
                    timesteps = noise_scheduler.sample_timesteps(images)
                else:
                    timesteps = torch.randint(
                        0, noise_scheduler.num_train_timesteps, (images.shape[0],), device=images.device
                    ).long()

                # create noisy latent
                noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)

                # ControlNet forward produces residuals injected into the UNet.
                controlnet_inputs = {
                    "x": noisy_latent,
                    "timesteps": timesteps,
                    "controlnet_cond": controlnet_cond,
                }
                if include_modality:
                    controlnet_inputs.update(
                        {
                            "class_labels": modality_tensor,
                        }
                    )
                down_block_res_samples, mid_block_res_sample = controlnet(**controlnet_inputs)

                # Frozen UNet forward, conditioned on the ControlNet residuals.
                unet_inputs = {
                    "x": noisy_latent,
                    "timesteps": timesteps,
                    "spacing_tensor": spacing_tensor,
                    "down_block_additional_residuals": down_block_res_samples,
                    "mid_block_additional_residual": mid_block_res_sample,
                }

                if include_body_region:
                    unet_inputs.update(
                        {
                            "top_region_index_tensor": top_region_index_tensor,
                            "bottom_region_index_tensor": bottom_region_index_tensor,
                        }
                    )
                if include_modality:
                    unet_inputs.update(
                        {
                            "class_labels": modality_tensor,
                        }
                    )
                model_output = unet(**unet_inputs)

                # Regression target depends on the scheduler's prediction type.
                if noise_scheduler.prediction_type == DDPMPredictionType.EPSILON:
                    # predict noise
                    model_gt = noise
                elif noise_scheduler.prediction_type == DDPMPredictionType.SAMPLE:
                    # predict the clean sample
                    model_gt = images
                elif noise_scheduler.prediction_type == DDPMPredictionType.V_PREDICTION:
                    # velocity target as defined here: sample minus noise
                    model_gt = images - noise
                else:
                    raise ValueError(
                        "noise scheduler prediction type has to be chosen from ",
                        f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]",
                    )

                if weighted_loss > 1.0:
                    # Up-weight voxels belonging to the configured label ids.
                    # NOTE(review): roi.repeat with 5 dims assumes 3D volumes
                    # (N, C, D, H, W) — confirm for other dimensionalities.
                    weights = torch.ones_like(images).to(images.device)
                    roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(images.device)
                    interpolate_label = F.interpolate(labels, size=images.shape[2:], mode="nearest")
                    # assign larger weights for ROI (tumor)
                    for label in weighted_loss_label:
                        roi[interpolate_label == label] = 1
                    weights[roi.repeat(1, images.shape[1], 1, 1, 1) == 1] = weighted_loss
                    loss = (F.l1_loss(model_output.float(), model_gt.float(), reduction="none") * weights).mean()
                else:
                    loss = F.l1_loss(model_output.float(), model_gt.float())

            # Standard AMP sequence: scale -> backward -> step -> update.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            total_step += 1

            if rank == 0:
                tensorboard_writer.add_scalar(
                    "train/train_controlnet_loss_iter", loss.detach().cpu().item(), total_step
                )
                # ETA from the wall-clock time of the last batch only.
                batches_done = step + 1
                batches_left = len(train_loader) - batches_done
                time_left = timedelta(seconds=batches_left * (time.time() - prev_time))
                prev_time = time.time()
                logger.info(
                    "\r[Epoch %d/%d] [Batch %d/%d] [LR: %.8f] [loss: %.4f] ETA: %s "
                    % (
                        epoch + 1,
                        n_epochs,
                        step + 1,
                        len(train_loader),
                        lr_scheduler.get_last_lr()[0],
                        loss.detach().cpu().item(),
                        time_left,
                    )
                )
            epoch_loss_ += loss.detach()

        epoch_loss = epoch_loss_ / (step + 1)

        if use_ddp:
            # Average the per-rank epoch losses so every rank (and the saved
            # checkpoint) sees the global mean.
            dist.barrier()
            dist.all_reduce(epoch_loss, op=torch.distributed.ReduceOp.AVG)

        if rank == 0:
            tensorboard_writer.add_scalar("train/train_controlnet_loss_epoch", epoch_loss.cpu().item(), total_step)

            # Unwrap DDP before saving so the state dict keys are plain.
            controlnet_state_dict = controlnet.module.state_dict() if world_size > 1 else controlnet.state_dict()
            torch.save(
                {
                    "epoch": epoch + 1,
                    "loss": epoch_loss,
                    "controlnet_state_dict": controlnet_state_dict,
                },
                f"{args.model_dir}/{args.exp_name}_current.pt",
            )

            if epoch_loss < best_loss:
                best_loss = epoch_loss
                logger.info(f"best loss -> {best_loss}.")
                torch.save(
                    {
                        "epoch": epoch + 1,
                        "loss": best_loss,
                        "controlnet_state_dict": controlnet_state_dict,
                    },
                    f"{args.model_dir}/{args.exp_name}_best.pt",
                )

    torch.cuda.empty_cache()
    if use_ddp:
        dist.destroy_process_group()
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
logging.basicConfig( |
|
|
stream=sys.stdout, |
|
|
level=logging.INFO, |
|
|
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", |
|
|
datefmt="%Y-%m-%d %H:%M:%S", |
|
|
) |
|
|
main() |