from __future__ import annotations

import argparse
import json
import logging
import os
from datetime import datetime
from pathlib import Path

import monai
import torch
import torch.distributed as dist
from monai.data import DataLoader, partition_dataset
from monai.networks.schedulers import RFlowScheduler
from monai.networks.schedulers.ddpm import DDPMPredictionType
from monai.transforms import Compose
from monai.utils import first
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel

from .diff_model_setting import initialize_distributed, load_config, setup_logging
from .utils import define_instance


def load_filenames(data_list_path: str) -> list:
    """
    Load filenames from the JSON data list.

    Args:
        data_list_path (str): Path to the JSON data list file.

    Returns:
        list: List of filenames.
    """
    with open(data_list_path, "r") as file:
        json_data = json.load(file)
    filenames_train = json_data["training"]
    return [_item["image"].replace(".nii.gz", "_emb.nii.gz") for _item in filenames_train]
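
# The JSON data list consumed by load_filenames() is assumed to follow the
# decathlon-style layout used elsewhere in MONAI; a minimal sketch (filenames
# are illustrative only):
#
#   {
#       "training": [
#           {"image": "case_0000.nii.gz"},
#           {"image": "case_0001.nii.gz"}
#       ]
#   }
#
# Each "image" entry is mapped to its embedding counterpart written by the
# autoencoder stage, e.g. "case_0000.nii.gz" -> "case_0000_emb.nii.gz".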


def prepare_data(
    train_files: list,
    device: torch.device,
    cache_rate: float,
    num_workers: int = 2,
    batch_size: int = 1,
    include_body_region: bool = False,
) -> DataLoader:
    """
    Prepare training data.

    Args:
        train_files (list): List of training files.
        device (torch.device): Device to use for training.
        cache_rate (float): Cache rate for dataset.
        num_workers (int): Number of workers for data loading.
        batch_size (int): Mini-batch size.
        include_body_region (bool): Whether to include body region in data.

    Returns:
        DataLoader: Data loader for training.
    """

    def _load_data_from_file(file_path, key):
        # Read a single field from the sidecar JSON file and return it as a float tensor.
        with open(file_path) as f:
            return torch.FloatTensor(json.load(f)[key])
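
    # The "spacing" (and, when enabled, "top_region_index"/"bottom_region_index")
    # entries in train_files all point at the same sidecar JSON; each Lambdad
    # below extracts one key from it. A sketch of the assumed sidecar content
    # (values are illustrative only):
    #
    #   {"spacing": [1.0, 1.0, 1.0],
    #    "top_region_index": [0, 1, 0, 0],
    #    "bottom_region_index": [0, 0, 1, 0]}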
    train_transforms_list = [
        monai.transforms.LoadImaged(keys=["image"]),
        monai.transforms.EnsureChannelFirstd(keys=["image"]),
        monai.transforms.Lambdad(keys="spacing", func=lambda x: _load_data_from_file(x, "spacing")),
        # Scale the conditioning values by 100 before feeding them to the network.
        monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
    ]
    if include_body_region:
        train_transforms_list += [
            monai.transforms.Lambdad(
                keys="top_region_index", func=lambda x: _load_data_from_file(x, "top_region_index")
            ),
            monai.transforms.Lambdad(
                keys="bottom_region_index", func=lambda x: _load_data_from_file(x, "bottom_region_index")
            ),
            monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2),
            monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2),
        ]
    train_transforms = Compose(train_transforms_list)

    train_ds = monai.data.CacheDataset(
        data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers
    )

    return DataLoader(train_ds, num_workers=num_workers, batch_size=batch_size, shuffle=True)
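
# Each element of train_files pairs one embedding volume with its sidecar JSON
# (paths below are illustrative only):
#
#   {"image": "emb_dir/case_0000_emb.nii.gz", "spacing": "emb_dir/case_0000_emb.nii.gz.json"}
#
# with "top_region_index"/"bottom_region_index" keys (pointing at the same JSON)
# added when body-region conditioning is enabled; see diff_model_train() below.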


def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> torch.nn.Module:
    """
    Load the UNet model.

    Args:
        args (argparse.Namespace): Configuration arguments.
        device (torch.device): Device to load the model on.
        logger (logging.Logger): Logger for logging information.

    Returns:
        torch.nn.Module: Loaded UNet model.
    """
    unet = define_instance(args, "diffusion_unet_def").to(device)
    # Replace BatchNorm layers with SyncBatchNorm so statistics are synchronized
    # across ranks in distributed training.
    unet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(unet)

    if dist.is_initialized():
        unet = DistributedDataParallel(unet, device_ids=[device], find_unused_parameters=True)

    if args.existing_ckpt_filepath is None:
        logger.info("Training from scratch.")
    else:
        checkpoint_unet = torch.load(f"{args.existing_ckpt_filepath}", map_location=device, weights_only=False)
        # When wrapped in DDP, the state dict belongs to the inner module.
        if dist.is_initialized():
            unet.module.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True)
        else:
            unet.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True)
        logger.info(f"Pretrained checkpoint {args.existing_ckpt_filepath} loaded.")

    return unet


def calculate_scale_factor(train_loader: DataLoader, device: torch.device, logger: logging.Logger) -> torch.Tensor:
    """
    Calculate the scaling factor for the dataset.

    Args:
        train_loader (DataLoader): Data loader for training.
        device (torch.device): Device to use for calculation.
        logger (logging.Logger): Logger for logging information.

    Returns:
        torch.Tensor: Calculated scaling factor.
    """
    check_data = first(train_loader)
    z = check_data["image"].to(device)
    scale_factor = 1 / torch.std(z)
    logger.info(f"Scaling factor set to {scale_factor}.")

    if dist.is_initialized():
        dist.barrier()
        dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG)
    logger.info(f"scale_factor -> {scale_factor}.")
    return scale_factor
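
# The latents are rescaled to roughly unit standard deviation before noise is
# added (scale_factor = 1 / std over one sample batch, averaged across ranks).
# This follows the common latent-diffusion practice of training the denoiser
# on unit-variance latents; the factor is stored in the checkpoint so the
# scaling can be undone later.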


def create_optimizer(model: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
    """
    Create optimizer for training.

    Args:
        model (torch.nn.Module): Model to optimize.
        lr (float): Learning rate.

    Returns:
        torch.optim.Optimizer: Created optimizer.
    """
    return torch.optim.Adam(params=model.parameters(), lr=lr)


def create_lr_scheduler(optimizer: torch.optim.Optimizer, total_steps: int) -> torch.optim.lr_scheduler.PolynomialLR:
    """
    Create learning rate scheduler.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer to schedule.
        total_steps (int): Total number of training steps.

    Returns:
        torch.optim.lr_scheduler.PolynomialLR: Created learning rate scheduler.
    """
    return torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)
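
# PolynomialLR with power=2.0 decays the learning rate quadratically to zero
# over the schedule:
#
#   lr(t) = lr_0 * (1 - t / total_steps) ** 2,   t = 0, 1, ..., total_steps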


def train_one_epoch(
    epoch: int,
    unet: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler.PolynomialLR,
    loss_pt: torch.nn.L1Loss,
    scaler: GradScaler,
    scale_factor: torch.Tensor,
    noise_scheduler: torch.nn.Module,
    num_images_per_batch: int,
    num_train_timesteps: int,
    device: torch.device,
    logger: logging.Logger,
    local_rank: int,
    amp: bool = True,
) -> torch.Tensor:
    """
    Train the model for one epoch.

    Args:
        epoch (int): Current epoch number.
        unet (torch.nn.Module): UNet model.
        train_loader (DataLoader): Data loader for training.
        optimizer (torch.optim.Optimizer): Optimizer.
        lr_scheduler (torch.optim.lr_scheduler.PolynomialLR): Learning rate scheduler.
        loss_pt (torch.nn.L1Loss): Loss function.
        scaler (GradScaler): Gradient scaler for mixed precision training.
        scale_factor (torch.Tensor): Scaling factor.
        noise_scheduler (torch.nn.Module): Noise scheduler.
        num_images_per_batch (int): Number of images per batch.
        num_train_timesteps (int): Number of training timesteps.
        device (torch.device): Device to use for training.
        logger (logging.Logger): Logger for logging information.
        local_rank (int): Local rank for distributed training.
        amp (bool): Use automatic mixed precision training.

    Returns:
        torch.Tensor: Training loss for the epoch.
    """
    # DistributedDataParallel does not forward custom attributes, so read them
    # from the wrapped module when running distributed.
    _unet = unet.module if isinstance(unet, DistributedDataParallel) else unet
    include_body_region = _unet.include_top_region_index_input
    include_modality = _unet.num_class_embeds is not None

    if local_rank == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        logger.info(f"Epoch {epoch + 1}, lr {current_lr}.")

    _iter = 0
    # loss_torch[0] accumulates the summed loss, loss_torch[1] the iteration count.
    loss_torch = torch.zeros(2, dtype=torch.float, device=device)

    unet.train()
    for train_data in train_loader:
        current_lr = optimizer.param_groups[0]["lr"]

        _iter += 1
        images = train_data["image"].to(device)
        images = images * scale_factor

        if include_body_region:
            top_region_index_tensor = train_data["top_region_index"].to(device)
            bottom_region_index_tensor = train_data["bottom_region_index"].to(device)

        if include_modality:
            modality_tensor = torch.ones((len(images),), dtype=torch.long).to(device)
        spacing_tensor = train_data["spacing"].to(device)

        optimizer.zero_grad(set_to_none=True)

        with autocast("cuda", enabled=amp):
            noise = torch.randn_like(images)

            # Rectified-flow schedulers draw timesteps from their own distribution;
            # DDPM-style schedulers use uniform integers in [0, num_train_timesteps).
            if isinstance(noise_scheduler, RFlowScheduler):
                timesteps = noise_scheduler.sample_timesteps(images)
            else:
                timesteps = torch.randint(0, num_train_timesteps, (images.shape[0],), device=images.device).long()

            noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
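
            # Assemble keyword inputs for the UNet: the noisy latents, timesteps,
            # and spacing conditioning are always passed; body-region indices and
            # the modality class label are added only when those inputs are
            # enabled on the model.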
            unet_inputs = {
                "x": noisy_latent,
                "timesteps": timesteps,
                "spacing_tensor": spacing_tensor,
            }

            if include_body_region:
                unet_inputs.update(
                    {
                        "top_region_index_tensor": top_region_index_tensor,
                        "bottom_region_index_tensor": bottom_region_index_tensor,
                    }
                )
            if include_modality:
                unet_inputs.update(
                    {
                        "class_labels": modality_tensor,
                    }
                )
            model_output = unet(**unet_inputs)

            # Pick the regression target that matches the scheduler's prediction type.
            if noise_scheduler.prediction_type == DDPMPredictionType.EPSILON:
                # The model predicts the added noise.
                model_gt = noise
            elif noise_scheduler.prediction_type == DDPMPredictionType.SAMPLE:
                # The model predicts the clean latents directly.
                model_gt = images
            elif noise_scheduler.prediction_type == DDPMPredictionType.V_PREDICTION:
                # The model predicts the velocity, here the difference between
                # the clean latents and the noise.
                model_gt = images - noise
            else:
                raise ValueError(
                    "noise scheduler prediction type has to be chosen from "
                    f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]"
                )

            loss = loss_pt(model_output.float(), model_gt.float())

        if amp:
            # Scale the loss to keep small gradients representable in reduced precision.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        lr_scheduler.step()

        loss_torch[0] += loss.item()
        loss_torch[1] += 1.0

        if local_rank == 0:
            logger.info(
                "[{0}] epoch {1}, iter {2}/{3}, loss: {4:.4f}, lr: {5:.12f}.".format(
                    str(datetime.now())[:19], epoch + 1, _iter, len(train_loader), loss.item(), current_lr
                )
            )

    if dist.is_initialized():
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)

    return loss_torch
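
# The returned tensor holds (summed loss, iteration count); after the SUM
# all-reduce these are totals across all ranks, so loss_torch[0] / loss_torch[1]
# is the global mean loss for the epoch.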


def save_checkpoint(
    epoch: int,
    unet: torch.nn.Module,
    loss_torch_epoch: float,
    num_train_timesteps: int,
    scale_factor: torch.Tensor,
    ckpt_folder: str,
    args: argparse.Namespace,
) -> None:
    """
    Save checkpoint.

    Args:
        epoch (int): Current epoch number.
        unet (torch.nn.Module): UNet model.
        loss_torch_epoch (float): Training loss for the epoch.
        num_train_timesteps (int): Number of training timesteps.
        scale_factor (torch.Tensor): Scaling factor.
        ckpt_folder (str): Checkpoint folder path.
        args (argparse.Namespace): Configuration arguments.
    """
    unet_state_dict = unet.module.state_dict() if dist.is_initialized() else unet.state_dict()
    torch.save(
        {
            "epoch": epoch + 1,
            "loss": loss_torch_epoch,
            "num_train_timesteps": num_train_timesteps,
            "scale_factor": scale_factor,
            "unet_state_dict": unet_state_dict,
        },
        f"{ckpt_folder}/{args.model_filename}",
    )
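
# Alongside the weights, the checkpoint stores num_train_timesteps and
# scale_factor, presumably so a later inference run can rebuild the same noise
# scheduler and undo the latent scaling without revisiting the training data.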


def diff_model_train(
    env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True
) -> None:
    """
    Main function to train a diffusion model.

    Args:
        env_config_path (str): Path to the environment configuration file.
        model_config_path (str): Path to the model configuration file.
        model_def_path (str): Path to the model definition file.
        num_gpus (int): Number of GPUs to use for training.
        amp (bool): Use automatic mixed precision training.
    """
    args = load_config(env_config_path, model_config_path, model_def_path)
    local_rank, world_size, device = initialize_distributed(num_gpus)
    logger = setup_logging("training")

    logger.info(f"Using device {device} (world size {world_size}).")

    if local_rank == 0:
        logger.info(f"[config] ckpt_folder -> {args.model_dir}.")
        logger.info(f"[config] data_root -> {args.embedding_base_dir}.")
        logger.info(f"[config] data_list -> {args.json_data_list}.")
        logger.info(f"[config] lr -> {args.diffusion_unet_train['lr']}.")
        logger.info(f"[config] num_epochs -> {args.diffusion_unet_train['n_epochs']}.")
        logger.info(f"[config] num_train_timesteps -> {args.noise_scheduler['num_train_timesteps']}.")

    Path(args.model_dir).mkdir(parents=True, exist_ok=True)

    unet = load_unet(args, device, logger)
    noise_scheduler = define_instance(args, "noise_scheduler")
    # Read the flag from the wrapped module when DDP is active; DDP does not
    # forward custom attributes.
    _unet = unet.module if isinstance(unet, DistributedDataParallel) else unet
    include_body_region = _unet.include_top_region_index_input

    filenames_train = load_filenames(args.json_data_list)
    if local_rank == 0:
        logger.info(f"num_files_train: {len(filenames_train)}")

    train_files = []
    for _i in range(len(filenames_train)):
        str_img = os.path.join(args.embedding_base_dir, filenames_train[_i])
        if not os.path.exists(str_img):
            continue

        # The sidecar JSON sits next to the embedding volume.
        str_info = os.path.join(args.embedding_base_dir, filenames_train[_i]) + ".json"
        train_files_i = {"image": str_img, "spacing": str_info}
        if include_body_region:
            train_files_i["top_region_index"] = str_info
            train_files_i["bottom_region_index"] = str_info
        train_files.append(train_files_i)
    if dist.is_initialized():
        # Give every rank an equally sized, disjoint shard of the training files.
        train_files = partition_dataset(
            data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True
        )[local_rank]

    train_loader = prepare_data(
        train_files,
        device,
        args.diffusion_unet_train["cache_rate"],
        batch_size=args.diffusion_unet_train["batch_size"],
        include_body_region=include_body_region,
    )

    scale_factor = calculate_scale_factor(train_loader, device, logger)
    optimizer = create_optimizer(unet, args.diffusion_unet_train["lr"])

    total_steps = int(
        args.diffusion_unet_train["n_epochs"] * len(train_loader.dataset) / args.diffusion_unet_train["batch_size"]
    )
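
    # After partition_dataset(), len(train_loader.dataset) is the per-rank share
    # of the data, so total_steps approximates the number of optimizer steps each
    # rank takes over the whole run.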
    lr_scheduler = create_lr_scheduler(optimizer, total_steps)
    loss_pt = torch.nn.L1Loss()
    scaler = GradScaler("cuda")

    torch.set_float32_matmul_precision("highest")
    logger.info("torch.set_float32_matmul_precision -> highest.")

    for epoch in range(args.diffusion_unet_train["n_epochs"]):
        loss_torch = train_one_epoch(
            epoch,
            unet,
            train_loader,
            optimizer,
            lr_scheduler,
            loss_pt,
            scaler,
            scale_factor,
            noise_scheduler,
            args.diffusion_unet_train["batch_size"],
            args.noise_scheduler["num_train_timesteps"],
            device,
            logger,
            local_rank,
            amp=amp,
        )

        loss_torch = loss_torch.tolist()
        # Only rank 0 (or a single-GPU run) logs the epoch loss and writes the checkpoint.
        if torch.cuda.device_count() == 1 or local_rank == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            logger.info(f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}.")

            save_checkpoint(
                epoch,
                unet,
                loss_torch_epoch,
                args.noise_scheduler["num_train_timesteps"],
                scale_factor,
                args.model_dir,
                args,
            )

    if dist.is_initialized():
        dist.destroy_process_group()
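
# Because of the relative imports at the top, this file is presumably launched
# as a module; a single-GPU sketch (the package name "scripts" is an assumption
# and depends on where this file lives):
#
#   python -m scripts.diff_model_train --env_config ./configs/environment_maisi_diff_model.json --num_gpus 1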


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Diffusion Model Training")
    parser.add_argument(
        "--env_config",
        type=str,
        default="./configs/environment_maisi_diff_model.json",
        help="Path to environment configuration file",
    )
    parser.add_argument(
        "--model_config",
        type=str,
        default="./configs/config_maisi_diff_model.json",
        help="Path to model training/inference configuration",
    )
    parser.add_argument(
        "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
    )
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
    parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")

    args = parser.parse_args()
    diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)