NoMAISI / scripts /inference.py
ft42's picture
Upload 63 files
599a397 verified
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# # MAISI Inference Script
import argparse
import json
import logging
import os
import sys
import tempfile
import monai
import torch
from monai.apps import download_url
from monai.config import print_config
from monai.transforms import LoadImage, Orientation
from monai.utils import set_determinism
from scripts.sample import LDMSampler, check_input
from scripts.utils import define_instance
from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image
def main():
parser = argparse.ArgumentParser(description="maisi.controlnet.training")
parser.add_argument(
"-e",
"--environment-file",
default="./configs/environment.json",
help="environment json file that stores environment path",
)
parser.add_argument(
"-c",
"--config-file",
default="./configs/config_maisi.json",
help="config json file that stores network hyper-parameters",
)
parser.add_argument(
"-i",
"--inference-file",
default="./configs/config_infer.json",
help="config json file that stores inference hyper-parameters",
)
parser.add_argument(
"-x",
"--extra-config-file",
default=None,
help="config json file that stores inference extra parameters",
)
parser.add_argument(
"-s",
"--random-seed",
default=None,
help="random seed, can be None or int",
)
parser.add_argument(
"--version",
default="maisi3d-rflow",
type=str,
help="maisi_version, choose from ['maisi3d-ddpm', 'maisi3d-rflow']",
)
args = parser.parse_args()
# Step 0: configuration
logger = logging.getLogger("maisi.inference")
maisi_version = args.version
# ## Set deterministic training for reproducibility
if args.random_seed is not None:
set_determinism(seed=args.random_seed)
# ## Setup data directory
# You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.
# This allows you to save results and reuse downloads.
# If not specified a temporary directory will be used.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
os.makedirs(directory, exist_ok=True)
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
# TODO: remove the `files` after the files are uploaded to the NGC
files = [
{
"path": "models/autoencoder_epoch273.pt",
"url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials"
"/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt",
},
{
"path": "models/mask_generation_autoencoder.pt",
"url": "https://developer.download.nvidia.com/assets/Clara/monai"
"/tutorials/mask_generation_autoencoder.pt",
},
{
"path": "models/mask_generation_diffusion_unet.pt",
"url": "https://developer.download.nvidia.com/assets/Clara/monai"
"/tutorials/model_zoo/model_maisi_mask_generation_diffusion_unet_v2.pt",
},
{
"path": "configs/all_anatomy_size_condtions.json",
"url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/all_anatomy_size_condtions.json",
},
{
"path": "datasets/all_masks_flexible_size_and_spacing_4000.zip",
"url": "https://developer.download.nvidia.com/assets/Clara/monai"
"/tutorials/all_masks_flexible_size_and_spacing_4000.zip",
},
]
if maisi_version == "maisi3d-ddpm":
files += [
{
"path": "models/diff_unet_3d_ddpm.pt",
"url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo"
"/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt",
},
{
"path": "models/controlnet_3d_ddpm.pt",
"url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo"
"/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt",
},
{
"path": "configs/candidate_masks_flexible_size_and_spacing_3000.json",
"url": "https://developer.download.nvidia.com/assets/Clara/monai"
"/tutorials/candidate_masks_flexible_size_and_spacing_3000.json",
},
]
elif maisi_version == "maisi3d-rflow":
files += [
{
"path": "models/diff_unet_3d_rflow.pt",
"url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/"
"diff_unet_ckpt_rflow_epoch19350.pt",
},
{
"path": "models/controlnet_3d_rflow.pt",
"url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/"
"controlnet_rflow_epoch60.pt",
},
{
"path": "configs/candidate_masks_flexible_size_and_spacing_4000.json",
"url": "https://developer.download.nvidia.com/assets/Clara/monai"
"/tutorials/candidate_masks_flexible_size_and_spacing_4000.json",
},
]
else:
raise ValueError(
f"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}."
)
for file in files:
file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"])
download_url(url=file["url"], filepath=file["path"])
# ## Read in environment setting, including data directory, model directory, and output directory
# The information for data directory, model directory, and output directory are saved in ./configs/environment.json
env_dict = json.load(open(args.environment_file, "r"))
for k, v in env_dict.items():
# Update the path to the downloaded dataset in MONAI_DATA_DIRECTORY
val = v if "datasets/" not in v else os.path.join(root_dir, v)
setattr(args, k, val)
print(f"{k}: {val}")
print("Global config variables have been loaded.")
# ## Read in configuration setting, including network definition, body region and anatomy to generate, etc.
#
# The information for the inference input, like body region and anatomy to generate, is stored in "./configs/config_infer.json".
# Please refer to README.md for the details.
config_dict = json.load(open(args.config_file, "r"))
for k, v in config_dict.items():
setattr(args, k, v)
# check the format of inference inputs
config_infer_dict = json.load(open(args.inference_file, "r"))
# override num_split if asked
if "autoencoder_tp_num_splits" in config_infer_dict:
args.autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"]
args.mask_generation_autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"]
for k, v in config_infer_dict.items():
setattr(args, k, v)
print(f"{k}: {v}")
#
# ## Read in optional extra configuration setting - typically acceleration options (TRT)
#
#
if args.extra_config_file is not None:
extra_config_dict = json.load(open(args.extra_config_file, "r"))
for k, v in extra_config_dict.items():
setattr(args, k, v)
print(f"{k}: {v}")
check_input(
args.body_region,
args.anatomy_list,
args.label_dict_json,
args.output_size,
args.spacing,
args.controllable_anatomy_size,
)
latent_shape = [args.latent_channels, args.output_size[0] // 4, args.output_size[1] // 4, args.output_size[2] // 4]
print("Network definition and inference inputs have been loaded.")
# ## Initialize networks and noise scheduler, then load the trained model weights.
# The networks and noise scheduler are defined in `config_file`. We will read them in and load the model weights.
noise_scheduler = define_instance(args, "noise_scheduler")
mask_generation_noise_scheduler = define_instance(args, "mask_generation_noise_scheduler")
device = torch.device("cuda")
autoencoder = define_instance(args, "autoencoder").to(device)
checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True)
autoencoder.load_state_dict(checkpoint_autoencoder)
diffusion_unet = define_instance(args, "diffusion_unet").to(device)
checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path, weights_only=False)
diffusion_unet.load_state_dict(checkpoint_diffusion_unet["unet_state_dict"], strict=True)
scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device)
controlnet = define_instance(args, "controlnet").to(device)
checkpoint_controlnet = torch.load(args.trained_controlnet_path, weights_only=False)
monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True)
mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder").to(device)
checkpoint_mask_generation_autoencoder = torch.load(
args.trained_mask_generation_autoencoder_path, weights_only=True
)
mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)
mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion").to(device)
checkpoint_mask_generation_diffusion_unet = torch.load(
args.trained_mask_generation_diffusion_path, weights_only=False
)
mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet["unet_state_dict"])
mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet["scale_factor"]
print("All the trained model weights have been loaded.")
# ## Define the LDM Sampler, which contains functions that will perform the inference.
ldm_sampler = LDMSampler(
args.body_region,
args.anatomy_list,
args.all_mask_files_json,
args.all_anatomy_size_conditions_json,
args.all_mask_files_base_dir,
args.label_dict_json,
args.label_dict_remap_json,
autoencoder,
diffusion_unet,
controlnet,
noise_scheduler,
scale_factor,
mask_generation_autoencoder,
mask_generation_diffusion_unet,
mask_generation_scale_factor,
mask_generation_noise_scheduler,
device,
latent_shape,
args.mask_generation_latent_shape,
args.output_size,
args.output_dir,
args.controllable_anatomy_size,
image_output_ext=args.image_output_ext,
label_output_ext=args.label_output_ext,
spacing=args.spacing,
modality=args.modality,
num_inference_steps=args.num_inference_steps,
mask_generation_num_inference_steps=args.mask_generation_num_inference_steps,
random_seed=args.random_seed,
autoencoder_sliding_window_infer_size=args.autoencoder_sliding_window_infer_size,
autoencoder_sliding_window_infer_overlap=args.autoencoder_sliding_window_infer_overlap,
)
print(f"The generated image/mask pairs will be saved in {args.output_dir}.")
output_filenames = ldm_sampler.sample_multiple_images(args.num_output_samples)
print("MAISI image/mask generation finished")
if __name__ == "__main__":
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
torch.cuda.reset_peak_memory_stats()
main()
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3) # Convert to GB
print(f"Peak GPU memory usage: {peak_memory_gb:.2f} GB")