|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
import json
|
|
|
import logging
|
|
|
import os
|
|
|
import sys
|
|
|
import tempfile
|
|
|
|
|
|
import monai
|
|
|
import torch
|
|
|
from monai.apps import download_url
|
|
|
from monai.config import print_config
|
|
|
from monai.transforms import LoadImage, Orientation
|
|
|
from monai.utils import set_determinism
|
|
|
|
|
|
from scripts.sample import LDMSampler, check_input
|
|
|
from scripts.utils import define_instance
|
|
|
from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image
|
|
|
|
|
|
|
|
|
def _load_json(path):
    """Read and parse a JSON file, closing the handle deterministically."""
    with open(path, "r") as f:
        return json.load(f)


def main():
    """Run MAISI latent-diffusion image/mask generation end to end.

    Steps: parse CLI arguments, download the pretrained checkpoints and mask
    resources for the selected MAISI version, merge the environment /
    network-config / inference-config JSON files into the argparse namespace,
    build the autoencoder, diffusion UNet, ControlNet and mask-generation
    models on GPU, then sample ``args.num_output_samples`` image/mask pairs
    into ``args.output_dir`` via ``LDMSampler``.

    Raises:
        ValueError: if ``--version`` is not one of
            ``['maisi3d-ddpm', 'maisi3d-rflow']``.
    """
    parser = argparse.ArgumentParser(description="maisi.controlnet.training")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./configs/environment.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./configs/config_maisi.json",
        help="config json file that stores network hyper-parameters",
    )
    parser.add_argument(
        "-i",
        "--inference-file",
        default="./configs/config_infer.json",
        help="config json file that stores inference hyper-parameters",
    )
    parser.add_argument(
        "-x",
        "--extra-config-file",
        default=None,
        help="config json file that stores inference extra parameters",
    )
    parser.add_argument(
        "-s",
        "--random-seed",
        default=None,
        # Fix: argparse delivers strings, but set_determinism/torch.manual_seed
        # require an int seed; type=int converts while keeping default None.
        type=int,
        help="random seed, can be None or int",
    )
    parser.add_argument(
        "--version",
        default="maisi3d-rflow",
        type=str,
        help="maisi_version, choose from ['maisi3d-ddpm', 'maisi3d-rflow']",
    )
    args = parser.parse_args()

    maisi_version = args.version

    # Seed all RNGs for reproducible sampling when a seed is given.
    if args.random_seed is not None:
        set_determinism(seed=args.random_seed)

    # Downloads are cached under MONAI_DATA_DIRECTORY when set; otherwise a
    # throwaway temp directory is used.
    directory = os.environ.get("MONAI_DATA_DIRECTORY")
    if directory is not None:
        os.makedirs(directory, exist_ok=True)
    root_dir = tempfile.mkdtemp() if directory is None else directory
    print(root_dir)

    # Checkpoints and mask resources shared by both MAISI versions.
    files = [
        {
            "path": "models/autoencoder_epoch273.pt",
            "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials"
            "/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt",
        },
        {
            "path": "models/mask_generation_autoencoder.pt",
            "url": "https://developer.download.nvidia.com/assets/Clara/monai"
            "/tutorials/mask_generation_autoencoder.pt",
        },
        {
            "path": "models/mask_generation_diffusion_unet.pt",
            "url": "https://developer.download.nvidia.com/assets/Clara/monai"
            "/tutorials/model_zoo/model_maisi_mask_generation_diffusion_unet_v2.pt",
        },
        {
            "path": "configs/all_anatomy_size_condtions.json",
            "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/all_anatomy_size_condtions.json",
        },
        {
            "path": "datasets/all_masks_flexible_size_and_spacing_4000.zip",
            "url": "https://developer.download.nvidia.com/assets/Clara/monai"
            "/tutorials/all_masks_flexible_size_and_spacing_4000.zip",
        },
    ]

    # Version-specific diffusion UNet / ControlNet weights and candidate masks.
    if maisi_version == "maisi3d-ddpm":
        files += [
            {
                "path": "models/diff_unet_3d_ddpm.pt",
                "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo"
                "/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt",
            },
            {
                "path": "models/controlnet_3d_ddpm.pt",
                "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo"
                "/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt",
            },
            {
                "path": "configs/candidate_masks_flexible_size_and_spacing_3000.json",
                "url": "https://developer.download.nvidia.com/assets/Clara/monai"
                "/tutorials/candidate_masks_flexible_size_and_spacing_3000.json",
            },
        ]
    elif maisi_version == "maisi3d-rflow":
        files += [
            {
                "path": "models/diff_unet_3d_rflow.pt",
                "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/"
                "diff_unet_ckpt_rflow_epoch19350.pt",
            },
            {
                "path": "models/controlnet_3d_rflow.pt",
                "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/"
                "controlnet_rflow_epoch60.pt",
            },
            {
                "path": "configs/candidate_masks_flexible_size_and_spacing_4000.json",
                "url": "https://developer.download.nvidia.com/assets/Clara/monai"
                "/tutorials/candidate_masks_flexible_size_and_spacing_4000.json",
            },
        ]
    else:
        raise ValueError(
            f"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}."
        )

    # Only dataset archives are re-rooted under root_dir; model/config files
    # keep their relative paths. download_url skips files already present.
    for file in files:
        file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"])
        download_url(url=file["url"], filepath=file["path"])

    # Environment paths: dataset-relative entries are re-rooted at root_dir,
    # everything is attached to args as attributes.
    env_dict = _load_json(args.environment_file)
    for k, v in env_dict.items():
        val = v if "datasets/" not in v else os.path.join(root_dir, v)
        setattr(args, k, val)
        print(f"{k}: {val}")
    print("Global config variables have been loaded.")

    # Network hyper-parameters.
    config_dict = _load_json(args.config_file)
    for k, v in config_dict.items():
        setattr(args, k, v)

    # Inference hyper-parameters. The optional tensor-parallel split override
    # must be written into the autoencoder definitions before instantiation.
    config_infer_dict = _load_json(args.inference_file)
    if "autoencoder_tp_num_splits" in config_infer_dict:
        args.autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"]
        args.mask_generation_autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"]
    for k, v in config_infer_dict.items():
        setattr(args, k, v)
        print(f"{k}: {v}")

    # Optional extra overrides, applied last so they take precedence.
    if args.extra_config_file is not None:
        extra_config_dict = _load_json(args.extra_config_file)
        for k, v in extra_config_dict.items():
            setattr(args, k, v)
            print(f"{k}: {v}")

    # Validate the user-requested body region / anatomy / geometry combination.
    check_input(
        args.body_region,
        args.anatomy_list,
        args.label_dict_json,
        args.output_size,
        args.spacing,
        args.controllable_anatomy_size,
    )
    # The autoencoder downsamples each spatial dimension by a factor of 4.
    latent_shape = [args.latent_channels, args.output_size[0] // 4, args.output_size[1] // 4, args.output_size[2] // 4]
    print("Network definition and inference inputs have been loaded.")

    noise_scheduler = define_instance(args, "noise_scheduler")
    mask_generation_noise_scheduler = define_instance(args, "mask_generation_noise_scheduler")

    device = torch.device("cuda")

    # Image autoencoder: checkpoint is a plain state dict.
    autoencoder = define_instance(args, "autoencoder").to(device)
    checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True)
    autoencoder.load_state_dict(checkpoint_autoencoder)

    # Diffusion UNet: checkpoint also carries the latent scale factor.
    diffusion_unet = define_instance(args, "diffusion_unet").to(device)
    checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path, weights_only=False)
    diffusion_unet.load_state_dict(checkpoint_diffusion_unet["unet_state_dict"], strict=True)
    scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device)

    # ControlNet: seed shared layers from the UNet weights first, then load
    # the trained ControlNet checkpoint on top.
    controlnet = define_instance(args, "controlnet").to(device)
    checkpoint_controlnet = torch.load(args.trained_controlnet_path, weights_only=False)
    monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
    controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True)

    # Mask-generation autoencoder and diffusion UNet.
    mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder").to(device)
    checkpoint_mask_generation_autoencoder = torch.load(
        args.trained_mask_generation_autoencoder_path, weights_only=True
    )
    mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)

    mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion").to(device)
    checkpoint_mask_generation_diffusion_unet = torch.load(
        args.trained_mask_generation_diffusion_path, weights_only=False
    )
    mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet["unet_state_dict"])
    mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet["scale_factor"]

    print("All the trained model weights have been loaded.")

    ldm_sampler = LDMSampler(
        args.body_region,
        args.anatomy_list,
        args.all_mask_files_json,
        args.all_anatomy_size_conditions_json,
        args.all_mask_files_base_dir,
        args.label_dict_json,
        args.label_dict_remap_json,
        autoencoder,
        diffusion_unet,
        controlnet,
        noise_scheduler,
        scale_factor,
        mask_generation_autoencoder,
        mask_generation_diffusion_unet,
        mask_generation_scale_factor,
        mask_generation_noise_scheduler,
        device,
        latent_shape,
        args.mask_generation_latent_shape,
        args.output_size,
        args.output_dir,
        args.controllable_anatomy_size,
        image_output_ext=args.image_output_ext,
        label_output_ext=args.label_output_ext,
        spacing=args.spacing,
        modality=args.modality,
        num_inference_steps=args.num_inference_steps,
        mask_generation_num_inference_steps=args.mask_generation_num_inference_steps,
        random_seed=args.random_seed,
        autoencoder_sliding_window_infer_size=args.autoencoder_sliding_window_infer_size,
        autoencoder_sliding_window_infer_overlap=args.autoencoder_sliding_window_infer_overlap,
    )

    print(f"The generated image/mask pairs will be saved in {args.output_dir}.")
    ldm_sampler.sample_multiple_images(args.num_output_samples)
    print("MAISI image/mask generation finished")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Route all log records (including the "maisi.inference" logger used by
    # the MAISI scripts) to stdout with millisecond-timestamped lines.
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # Reset the CUDA peak-memory counter so the report below covers main() only.
    # NOTE(review): requires a CUDA-enabled torch build — this raises on CPU-only setups.
    torch.cuda.reset_peak_memory_stats()
    main()
    # max_memory_allocated() returns bytes; convert to GiB for the summary line.
    peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"Peak GPU memory usage: {peak_memory_gb:.2f} GB")
|
|
|
|