# NoMAISI / scripts/infer_test_controlnet.py
# (uploaded by ft42 — "Upload 63 files", commit 599a397)
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
import sys
from datetime import datetime
import torch
import torch.distributed as dist
from monai.data import MetaTensor, decollate_batch
from monai.networks.utils import copy_model_state
from monai.transforms import SaveImage
from monai.utils import RankFilter
from .sample import check_input, ldm_conditional_sample_one_image
from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp, prepare_maisi_controlnet_infer_dataloader
@torch.inference_mode()
def main():
    """Run MAISI ControlNet inference over a label-conditioned dataset.

    Parses command-line arguments and merges the environment, network and
    training JSON configs into ``args``. Builds the autoencoder, diffusion UNet
    and ControlNet, optionally restoring trained weights. Then, for every batch
    of the inference dataloader, generates one synthetic image conditioned on
    the label mask and writes the image/label pair to ``args.output_dir``.

    When ``--gpus`` > 1 the process reads ``LOCAL_RANK`` and ``WORLD_SIZE`` from
    the environment and initialises distributed execution via ``setup_ddp``.
    The process group is always torn down on exit, including on error.

    Raises:
        ValueError: if a configured checkpoint path does not exist on disk.
        KeyError: if DDP is requested but ``LOCAL_RANK``/``WORLD_SIZE`` are unset.
    """
    parser = argparse.ArgumentParser(description="maisi.controlnet.infer")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./configs/environment_maisi_controlnet_train.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./configs/config_maisi.json",
        help="config json file that stores network hyper-parameters",
    )
    parser.add_argument(
        "-t",
        "--training-config",
        default="./configs/config_maisi_controlnet_train.json",
        help="config json file that stores training hyper-parameters",
    )
    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
    args = parser.parse_args()

    # Step 0: configuration
    logger = logging.getLogger("maisi.controlnet.infer")
    # whether to use distributed data parallel
    use_ddp = args.gpus > 1
    if use_ddp:
        # NOTE(review): LOCAL_RANK is used both as the device index and as the
        # data-sharding rank passed to the dataloader below. On multi-node runs
        # the global RANK may be what sharding needs; confirm against setup_ddp.
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = setup_ddp(rank, world_size)
        logger.addFilter(RankFilter())
    else:
        rank = 0
        world_size = 1
        device = torch.device(f"cuda:{rank}")

    try:
        torch.cuda.set_device(device)
        logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
        logger.info(f"World_size: {world_size}")

        # Read all three config files first, then copy their keys onto ``args``
        # in order: environment, network, training. A later file overrides an
        # earlier one on a key clash.
        config_dicts = []
        for config_path in (args.environment_file, args.config_file, args.training_config):
            with open(config_path, "r") as config_fp:
                config_dicts.append(json.load(config_fp))
        for config_dict in config_dicts:
            for key, value in config_dict.items():
                setattr(args, key, value)

        # Step 1: set data loader
        val_loader = prepare_maisi_controlnet_infer_dataloader(
            json_data_list=args.json_data_list,
            data_base_dir=args.data_base_dir,
            batch_size=args.controlnet_train["batch_size"],
            cache_rate=args.controlnet_train["cache_rate"],
            rank=rank,
            world_size=world_size,
        )

        # Step 2: define AE, diffusion model and controlnet
        # define AE
        autoencoder = define_instance(args, "autoencoder_def").to(device)
        # load trained autoencoder model
        if args.trained_autoencoder_path is not None:
            if not os.path.exists(args.trained_autoencoder_path):
                raise ValueError("Please download the autoencoder checkpoint.")
            # map_location puts the tensors on this process's device. This
            # matches the diffusion and ControlNet checkpoint loads below.
            autoencoder_ckpt = torch.load(args.trained_autoencoder_path, map_location=device, weights_only=True)
            autoencoder.load_state_dict(autoencoder_ckpt)
            logger.info(f"Load trained autoencoder model from {args.trained_autoencoder_path}.")
        else:
            logger.info("trained autoencoder model is not loaded.")

        # define diffusion Model
        unet = define_instance(args, "diffusion_unet_def").to(device)
        # Read from the UNet definition: whether body-region conditioning is
        # fed to the network.
        include_body_region = unet.include_top_region_index_input
        # NOTE(review): computed but not used in this function; kept as-is.
        include_modality = unet.num_class_embeds is not None
        # load trained diffusion model
        if args.trained_diffusion_path is not None:
            if not os.path.exists(args.trained_diffusion_path):
                raise ValueError("Please download the trained diffusion unet checkpoint.")
            diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False)
            unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
            # load scale factor from diffusion model checkpoint
            scale_factor = diffusion_model_ckpt["scale_factor"]
            logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.")
            logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
        else:
            logger.info("trained diffusion model is not loaded.")
            scale_factor = 1.0
            logger.info(f"set scale_factor -> {scale_factor}.")

        # define ControlNet
        controlnet = define_instance(args, "controlnet_def").to(device)
        # Initialise the ControlNet from the diffusion UNet's weights. A trained
        # ControlNet checkpoint, if given, then overrides them.
        copy_model_state(controlnet, unet.state_dict())
        # load trained controlnet model if it is provided
        if args.trained_controlnet_path is not None:
            if not os.path.exists(args.trained_controlnet_path):
                raise ValueError("Please download the trained ControlNet checkpoint.")
            controlnet.load_state_dict(
                torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"]
            )
            logger.info(f"load trained controlnet model from {args.trained_controlnet_path}")
        else:
            logger.info("trained controlnet is not loaded.")

        noise_scheduler = define_instance(args, "noise_scheduler")

        # Step 3: inference
        autoencoder.eval()
        controlnet.eval()
        unet.eval()

        # The saver configuration is identical for every batch, so build the
        # savers once rather than inside the loop.
        img_saver = SaveImage(
            output_dir=args.output_dir,
            output_postfix="image",
            separate_folder=False,
        )
        label_saver = SaveImage(
            output_dir=args.output_dir,
            output_postfix="label",
            separate_folder=False,
        )

        for batch in val_loader:
            # get label mask
            labels = batch["label"].to(device)
            # get corresponding conditions
            if include_body_region:
                top_region_index_tensor = batch["top_region_index"].to(device)
                bottom_region_index_tensor = batch["bottom_region_index"].to(device)
            else:
                top_region_index_tensor = None
                bottom_region_index_tensor = None
            spacing_tensor = batch["spacing"].to(device)
            modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device)
            # NOTE(review): the squeeze()/.item() calls here and decollate_batch(...)[0]
            # below look like they assume batch_size == 1. Confirm the configured
            # controlnet_train["batch_size"].
            # Dividing by 100 suggests "spacing" is stored scaled by 100; confirm
            # against the dataloader.
            out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist())
            # get target dimension
            dim = batch["dim"]
            output_size = (dim[0].item(), dim[1].item(), dim[2].item())
            # Latent grid is the output size divided by 4 along each spatial axis.
            latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4)
            # check if output_size and out_spacing are valid.
            check_input(None, None, None, output_size, out_spacing, None)
            # generate a single synthetic image using a latent diffusion model with controlnet.
            synthetic_images, _ = ldm_conditional_sample_one_image(
                autoencoder=autoencoder,
                diffusion_unet=unet,
                controlnet=controlnet,
                noise_scheduler=noise_scheduler,
                scale_factor=scale_factor,
                device=device,
                combine_label_or=labels,
                top_region_index_tensor=top_region_index_tensor,
                bottom_region_index_tensor=bottom_region_index_tensor,
                spacing_tensor=spacing_tensor,
                modality_tensor=modality_tensor,
                latent_shape=latent_shape,
                output_size=output_size,
                noise_factor=1.0,
                num_inference_steps=args.controlnet_infer["num_inference_steps"],
                autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"],
                autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"],
            )

            # Save the image/label pair. Take the first decollated sample's label
            # because it carries the metadata (including the source filename)
            # that the savers use to name the outputs.
            saved_label = decollate_batch(batch)[0]["label"]
            # Make sure a filename is present so SaveImage can derive output names.
            real_object_name = saved_label.meta.get("filename_or_obj", "default_name.nii.gz")
            saved_label.meta["filename_or_obj"] = real_object_name
            synthetic_images = MetaTensor(synthetic_images.squeeze(0), meta=saved_label.meta)
            img_saver(synthetic_images)
            label_saver(saved_label)
    finally:
        # Tear down the process group even if setup or inference raised.
        if use_ddp:
            dist.destroy_process_group()
if __name__ == "__main__":
    # Script entry point: send INFO-and-above log records to stdout with
    # millisecond timestamps, then run inference.
    _log_format = "[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s"
    _date_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        datefmt=_date_format,
        format=_log_format,
    )
    main()