|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import argparse
|
|
|
import logging
|
|
|
import os
|
|
|
import random
|
|
|
from datetime import datetime
|
|
|
|
|
|
import nibabel as nib
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import torch.distributed as dist
|
|
|
from monai.inferers import sliding_window_inference
|
|
|
from monai.inferers.inferer import SlidingWindowInferer
|
|
|
from monai.networks.schedulers import RFlowScheduler
|
|
|
from monai.utils import set_determinism
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
from .diff_model_setting import initialize_distributed, load_config, setup_logging
|
|
|
from .sample import ReconModel, check_input
|
|
|
from .utils import define_instance, dynamic_infer
|
|
|
|
|
|
|
|
|
def set_random_seed(seed: int | None) -> int:
    """
    Set the global random seed for reproducibility.

    Args:
        seed (int | None): Desired random seed. If None, a seed is drawn
            uniformly from [0, 99999] so the run is still reproducible once
            the returned value is logged.

    Returns:
        int: The seed that was actually applied.
    """
    # The original annotation claimed `int`, but callers pass None to request
    # a randomly drawn seed; the returned value lets them log/reuse it.
    random_seed = random.randint(0, 99999) if seed is None else seed
    set_determinism(random_seed)
    return random_seed
|
|
|
|
|
|
|
|
|
def load_models(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> tuple:
    """
    Load the autoencoder and UNet models.

    Args:
        args (argparse.Namespace): Configuration arguments; must provide
            ``autoencoder_def``, ``diffusion_unet_def``,
            ``trained_autoencoder_path``, ``model_dir`` and ``model_filename``.
        device (torch.device): Device to load models on.
        logger (logging.Logger): Logger for logging information.

    Returns:
        tuple: Loaded autoencoder, UNet model, and scale factor.

    Raises:
        Exception: Re-raised if the autoencoder checkpoint cannot be loaded.
    """
    autoencoder = define_instance(args, "autoencoder_def").to(device)
    try:
        checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True)
        autoencoder.load_state_dict(checkpoint_autoencoder)
    except Exception:
        # Re-raise instead of continuing: proceeding with a randomly
        # initialized autoencoder would silently decode garbage images.
        logger.error("The trained_autoencoder_path does not exist!")
        raise

    unet = define_instance(args, "diffusion_unet_def").to(device)
    # The UNet checkpoint bundles both the weights and the latent scale factor.
    checkpoint = torch.load(f"{args.model_dir}/{args.model_filename}", map_location=device, weights_only=False)
    unet.load_state_dict(checkpoint["unet_state_dict"], strict=True)
    logger.info(f"checkpoints {args.model_dir}/{args.model_filename} loaded.")

    scale_factor = checkpoint["scale_factor"]
    logger.info(f"scale_factor -> {scale_factor}.")

    return autoencoder, unet, scale_factor
|
|
|
|
|
|
|
|
|
def prepare_tensors(args: argparse.Namespace, device: torch.device) -> tuple:
    """
    Build the conditioning tensors required by the diffusion UNet.

    Args:
        args (argparse.Namespace): Configuration arguments; reads the
            ``diffusion_unet_inference`` section.
        device (torch.device): Device on which the tensors are allocated.

    Returns:
        tuple: (top_region_index_tensor, bottom_region_index_tensor,
        spacing_tensor, modality_tensor).
    """
    inference_cfg = args.diffusion_unet_inference

    def _scaled_half_tensor(values) -> torch.Tensor:
        # Scale by 100 and prepend a batch axis, matching the conditioning
        # convention the training side used for these inputs.
        scaled = np.asarray(values, dtype=float) * 1e2
        return torch.from_numpy(scaled[np.newaxis, :]).half().to(device)

    top_region_index_tensor = _scaled_half_tensor(inference_cfg["top_region_index"])
    bottom_region_index_tensor = _scaled_half_tensor(inference_cfg["bottom_region_index"])
    spacing_tensor = _scaled_half_tensor(inference_cfg["spacing"])

    # One class label per batch entry, all set to the requested modality id.
    modality_tensor = inference_cfg["modality"] * torch.ones((len(spacing_tensor)), dtype=torch.long).to(device)

    return top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor
|
|
|
|
|
|
|
|
|
def run_inference(
    args: argparse.Namespace,
    device: torch.device,
    autoencoder: torch.nn.Module,
    unet: torch.nn.Module,
    scale_factor: float,
    top_region_index_tensor: torch.Tensor,
    bottom_region_index_tensor: torch.Tensor,
    spacing_tensor: torch.Tensor,
    modality_tensor: torch.Tensor,
    output_size: tuple,
    divisor: int,
    logger: logging.Logger,
) -> np.ndarray:
    """
    Run the inference to generate synthetic images.

    Denoises a latent sampled from Gaussian noise with the UNet + scheduler,
    then decodes it with the autoencoder via sliding-window inference and
    rescales intensities to the [-1000, 1000] HU-like range as int16.

    Args:
        args (argparse.Namespace): Configuration arguments.
        device (torch.device): Device to run inference on.
        autoencoder (torch.nn.Module): Autoencoder model.
        unet (torch.nn.Module): UNet model.
        scale_factor (float): Scale factor for the model.
        top_region_index_tensor (torch.Tensor): Top region index tensor.
        bottom_region_index_tensor (torch.Tensor): Bottom region index tensor.
        spacing_tensor (torch.Tensor): Spacing tensor.
        modality_tensor (torch.Tensor): Modality tensor.
        output_size (tuple): Output size of the synthetic image.
        divisor (int): Divisor for downsample level; latent spatial dims are
            output_size // divisor.
        logger (logging.Logger): Logger for logging information.

    Returns:
        np.ndarray: Generated synthetic image data (int16).
    """
    # Which optional conditioning inputs this UNet was built with.
    include_body_region = unet.include_top_region_index_input
    include_modality = unet.num_class_embeds is not None

    # Initial latent: pure Gaussian noise at the downsampled spatial size.
    noise = torch.randn(
        (
            1,
            args.latent_channels,
            output_size[0] // divisor,
            output_size[1] // divisor,
            output_size[2] // divisor,
        ),
        device=device,
    )
    logger.info(f"noise: {noise.device}, {noise.dtype}, {type(noise)}")

    image = noise
    noise_scheduler = define_instance(args, "noise_scheduler")
    if isinstance(noise_scheduler, RFlowScheduler):
        # RFlowScheduler additionally needs the latent element count to set
        # its timestep schedule.
        noise_scheduler.set_timesteps(
            num_inference_steps=args.diffusion_unet_inference["num_inference_steps"],
            input_img_size_numel=torch.prod(torch.tensor(noise.shape[2:])),
        )
    else:
        noise_scheduler.set_timesteps(num_inference_steps=args.diffusion_unet_inference["num_inference_steps"])

    recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device)
    autoencoder.eval()
    unet.eval()

    all_timesteps = noise_scheduler.timesteps
    # Pair each timestep with its successor (final step pairs with 0); the
    # RFlow step() consumes the next timestep explicitly.
    all_next_timesteps = torch.cat((all_timesteps[1:], torch.tensor([0], dtype=all_timesteps.dtype)))
    progress_bar = tqdm(
        zip(all_timesteps, all_next_timesteps),
        total=min(len(all_timesteps), len(all_next_timesteps)),
    )
    # NOTE(review): autocast is hard-wired to "cuda"; presumably this script
    # is only ever run on GPU — confirm before running on CPU.
    with torch.amp.autocast("cuda", enabled=True):
        for t, next_t in progress_bar:
            # Mandatory UNet inputs; optional conditioning added below.
            unet_inputs = {
                "x": image,
                "timesteps": torch.Tensor((t,)).to(device),
                "spacing_tensor": spacing_tensor,
            }
            if include_body_region:
                unet_inputs.update(
                    {
                        "top_region_index_tensor": top_region_index_tensor,
                        "bottom_region_index_tensor": bottom_region_index_tensor,
                    }
                )
            if include_modality:
                unet_inputs.update(
                    {
                        "class_labels": modality_tensor,
                    }
                )
            model_output = unet(**unet_inputs)
            # Scheduler step signatures differ: RFlow also takes next_t.
            if not isinstance(noise_scheduler, RFlowScheduler):
                image, _ = noise_scheduler.step(model_output, t, image)
            else:
                image, _ = noise_scheduler.step(model_output, t, image, next_t)

    # Decode the final latent in overlapping 80^3 windows with Gaussian
    # blending to limit peak memory on large volumes.
    inferer = SlidingWindowInferer(
        roi_size=[80, 80, 80],
        sw_batch_size=1,
        progress=True,
        mode="gaussian",
        overlap=0.4,
        sw_device=device,
        device=device,
    )
    synthetic_images = dynamic_infer(inferer, recon_model, image)
    data = synthetic_images.squeeze().cpu().detach().numpy()
    # Map decoder output from [0, 1] to [-1000, 1000] and clip outliers.
    a_min, a_max, b_min, b_max = -1000, 1000, 0, 1
    data = (data - b_min) / (b_max - b_min) * (a_max - a_min) + a_min
    data = np.clip(data, a_min, a_max)
    return np.int16(data)
|
|
|
|
|
|
|
|
|
def save_image(
    data: np.ndarray,
    output_size: tuple,
    out_spacing: tuple,
    output_path: str,
    logger: logging.Logger,
) -> None:
    """
    Save the generated synthetic image to a NIfTI file.

    Args:
        data (np.ndarray): Synthetic image data.
        output_size (tuple): Output size of the image (currently unused; kept
            for backward compatibility with existing callers).
        out_spacing (tuple): Voxel spacing of the output image, written into
            the diagonal of the affine.
        output_path (str): Path to save the output image.
        logger (logging.Logger): Logger for logging information.
    """
    # Diagonal affine encoding only the voxel spacing (no rotation/offset).
    out_affine = np.eye(4)
    for i in range(3):
        out_affine[i, i] = out_spacing[i]

    new_image = nib.Nifti1Image(data, affine=out_affine)
    # os.path.dirname returns "" for a bare filename, and makedirs("") raises
    # FileNotFoundError — only create directories when there is one.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    nib.save(new_image, output_path)
    logger.info(f"Saved {output_path}.")
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
    """
    Main function to run the diffusion model inference.

    Loads configuration, initializes (optionally distributed) execution,
    generates one synthetic volume per rank, and saves it as NIfTI.

    Args:
        env_config_path (str): Path to the environment configuration file.
        model_config_path (str): Path to the model configuration file.
        model_def_path (str): Path to the model definition file.
        num_gpus (int): Number of GPUs to use for distributed inference.
    """
    args = load_config(env_config_path, model_config_path, model_def_path)
    local_rank, world_size, device = initialize_distributed(num_gpus)
    logger = setup_logging("inference")
    # Offset the configured seed by rank so each rank generates a distinct
    # volume. NOTE(review): a configured seed of 0 is falsy and is treated
    # the same as "no seed" (random seed drawn) — confirm this is intended.
    random_seed = set_random_seed(
        args.diffusion_unet_inference["random_seed"] + local_rank
        if args.diffusion_unet_inference["random_seed"]
        else None
    )
    logger.info(f"Using {device} of {world_size} with random seed: {random_seed}")

    output_size = tuple(args.diffusion_unet_inference["dim"])
    out_spacing = tuple(args.diffusion_unet_inference["spacing"])
    output_prefix = args.output_prefix
    ckpt_filepath = f"{args.model_dir}/{args.model_filename}"

    # Only rank 0 echoes the resolved configuration to avoid duplicate logs.
    if local_rank == 0:
        logger.info(f"[config] ckpt_filepath -> {ckpt_filepath}.")
        logger.info(f"[config] random_seed -> {random_seed}.")
        logger.info(f"[config] output_prefix -> {output_prefix}.")
        logger.info(f"[config] output_size -> {output_size}.")
        logger.info(f"[config] out_spacing -> {out_spacing}.")

    # Validate only size/spacing here; the other check_input slots are unused.
    check_input(None, None, None, output_size, out_spacing, None)

    autoencoder, unet, scale_factor = load_models(args, device, logger)
    # Infer the UNet depth from whichever config list is present, then derive
    # the latent-space downsampling divisor from it.
    num_downsample_level = max(
        1,
        (
            len(args.diffusion_unet_def["num_channels"])
            if isinstance(args.diffusion_unet_def["num_channels"], list)
            else len(args.diffusion_unet_def["attention_levels"])
        ),
    )
    divisor = 2 ** (num_downsample_level - 2)
    logger.info(f"num_downsample_level -> {num_downsample_level}, divisor -> {divisor}.")

    top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor = prepare_tensors(args, device)
    data = run_inference(
        args,
        device,
        autoencoder,
        unet,
        scale_factor,
        top_region_index_tensor,
        bottom_region_index_tensor,
        spacing_tensor,
        modality_tensor,
        output_size,
        divisor,
        logger,
    )

    # Filename encodes seed, size, spacing, timestamp, and rank so parallel
    # ranks and repeated runs never collide.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_path = "{0}/{1}_seed{2}_size{3:d}x{4:d}x{5:d}_spacing{6:.2f}x{7:.2f}x{8:.2f}_{9}_rank{10}.nii.gz".format(
        args.output_dir,
        output_prefix,
        random_seed,
        output_size[0],
        output_size[1],
        output_size[2],
        out_spacing[0],
        out_spacing[1],
        out_spacing[2],
        timestamp,
        local_rank,
    )
    save_image(data, output_size, out_spacing, output_path, logger)

    if dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: collect config paths and GPU count, then run inference.
    cli = argparse.ArgumentParser(description="Diffusion Model Inference")
    cli.add_argument(
        "--env_config",
        type=str,
        default="./configs/environment_maisi_diff_model_train.json",
        help="Path to environment configuration file",
    )
    cli.add_argument(
        "--model_config",
        type=str,
        default="./configs/config_maisi_diff_model_train.json",
        help="Path to model training/inference configuration",
    )
    cli.add_argument(
        "--model_def",
        type=str,
        default="./configs/config_maisi.json",
        help="Path to model definition file",
    )
    cli.add_argument(
        "--num_gpus",
        type=int,
        default=1,
        help="Number of GPUs to use for distributed inference",
    )

    cli_args = cli.parse_args()
    diff_model_infer(cli_args.env_config, cli_args.model_config, cli_args.model_def, cli_args.num_gpus)
|
|
|
|