|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
from datetime import datetime |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from monai.data import MetaTensor, decollate_batch |
|
|
from monai.networks.utils import copy_model_state |
|
|
from monai.transforms import SaveImage |
|
|
from monai.utils import RankFilter |
|
|
|
|
|
from .sample import check_input, ldm_conditional_sample_one_image |
|
|
from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp, prepare_maisi_controlnet_infer_dataloader |
|
|
|
|
|
|
|
|
@torch.inference_mode()
def main():
    """Run MAISI ControlNet inference.

    Parses CLI args, merges the three JSON configs into ``args``, loads the
    autoencoder / diffusion-unet / controlnet checkpoints, then samples one
    synthetic image per validation batch and saves the image/label pair as
    NIfTI files under ``args.output_dir``.

    Raises:
        ValueError: if a configured checkpoint path does not exist on disk.
    """
    parser = argparse.ArgumentParser(description="maisi.controlnet.infer")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./configs/environment_maisi_controlnet_train.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./configs/config_maisi.json",
        help="config json file that stores network hyper-parameters",
    )
    parser.add_argument(
        "-t",
        "--training-config",
        default="./configs/config_maisi_controlnet_train.json",
        help="config json file that stores training hyper-parameters",
    )
    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")

    args = parser.parse_args()

    logger = logging.getLogger("maisi.controlnet.infer")

    # Step 0: device / distributed setup. With DDP, rank and world size come
    # from the launcher's environment (torchrun sets LOCAL_RANK / WORLD_SIZE).
    use_ddp = args.gpus > 1
    if use_ddp:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = setup_ddp(rank, world_size)
        # Suppress duplicate log lines from non-zero ranks.
        logger.addFilter(RankFilter())
    else:
        rank = 0
        world_size = 1
        device = torch.device(f"cuda:{rank}")

    torch.cuda.set_device(device)
    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
    logger.info(f"World_size: {world_size}")

    # Step 1: merge the three JSON configs into `args`, in order; later files
    # win on duplicate keys (environment < network config < training config).
    with open(args.environment_file, "r") as env_file:
        env_dict = json.load(env_file)
    with open(args.config_file, "r") as config_file:
        config_dict = json.load(config_file)
    with open(args.training_config, "r") as training_config_file:
        training_config_dict = json.load(training_config_file)

    for config in (env_dict, config_dict, training_config_dict):
        for k, v in config.items():
            setattr(args, k, v)

    # Step 2: dataloader over the inference label/metadata list (sharded by rank).
    val_loader = prepare_maisi_controlnet_infer_dataloader(
        json_data_list=args.json_data_list,
        data_base_dir=args.data_base_dir,
        batch_size=args.controlnet_train["batch_size"],
        cache_rate=args.controlnet_train["cache_rate"],
        rank=rank,
        world_size=world_size,
    )

    # Step 3: autoencoder (latent encoder/decoder).
    autoencoder = define_instance(args, "autoencoder_def").to(device)
    if args.trained_autoencoder_path is not None:
        if not os.path.exists(args.trained_autoencoder_path):
            raise ValueError("Please download the autoencoder checkpoint.")
        autoencoder_ckpt = torch.load(args.trained_autoencoder_path, weights_only=True)
        autoencoder.load_state_dict(autoencoder_ckpt)
        # BUGFIX: this previously logged "diffusion model" for the autoencoder path.
        logger.info(f"Load trained autoencoder from {args.trained_autoencoder_path}.")
    else:
        logger.info("trained autoencoder model is not loaded.")

    # Step 4: diffusion unet. Its checkpoint also carries the latent
    # `scale_factor` used during sampling.
    unet = define_instance(args, "diffusion_unet_def").to(device)
    include_body_region = unet.include_top_region_index_input

    if args.trained_diffusion_path is not None:
        if not os.path.exists(args.trained_diffusion_path):
            raise ValueError("Please download the trained diffusion unet checkpoint.")
        diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False)
        unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
        scale_factor = diffusion_model_ckpt["scale_factor"]
        logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.")
        logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
    else:
        logger.info("trained diffusion model is not loaded.")
        scale_factor = 1.0
        logger.info(f"set scale_factor -> {scale_factor}.")

    # Step 5: controlnet — initialized from the unet weights, then overwritten
    # by the trained controlnet checkpoint if one is configured.
    controlnet = define_instance(args, "controlnet_def").to(device)
    copy_model_state(controlnet, unet.state_dict())
    if args.trained_controlnet_path is not None:
        if not os.path.exists(args.trained_controlnet_path):
            raise ValueError("Please download the trained ControlNet checkpoint.")
        controlnet.load_state_dict(
            torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"]
        )
        logger.info(f"load trained controlnet model from {args.trained_controlnet_path}")
    else:
        logger.info("trained controlnet is not loaded.")

    noise_scheduler = define_instance(args, "noise_scheduler")

    autoencoder.eval()
    controlnet.eval()
    unet.eval()

    # Step 6: sample one synthetic image per batch and save image/label pairs.
    for batch in val_loader:
        labels = batch["label"].to(device)

        # Body-region conditioning is only fed when the unet was built with it.
        if include_body_region:
            top_region_index_tensor = batch["top_region_index"].to(device)
            bottom_region_index_tensor = batch["bottom_region_index"].to(device)
        else:
            top_region_index_tensor = None
            bottom_region_index_tensor = None

        spacing_tensor = batch["spacing"].to(device)
        modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device)
        # NOTE(review): spacing appears to be stored scaled by 100 in the batch
        # (hence the /100 here) — confirm against the dataloader.
        out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist())

        dim = batch["dim"]
        output_size = (dim[0].item(), dim[1].item(), dim[2].item())
        # Latents are spatially downsampled 4x per axis relative to the image.
        latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4)

        # Validate only output_size and spacing; the other inputs do not apply here.
        check_input(None, None, None, output_size, out_spacing, None)

        synthetic_images, _ = ldm_conditional_sample_one_image(
            autoencoder=autoencoder,
            diffusion_unet=unet,
            controlnet=controlnet,
            noise_scheduler=noise_scheduler,
            scale_factor=scale_factor,
            device=device,
            combine_label_or=labels,
            top_region_index_tensor=top_region_index_tensor,
            bottom_region_index_tensor=bottom_region_index_tensor,
            spacing_tensor=spacing_tensor,
            modality_tensor=modality_tensor,
            latent_shape=latent_shape,
            output_size=output_size,
            noise_factor=1.0,
            num_inference_steps=args.controlnet_infer["num_inference_steps"],
            autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"],
            autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"],
        )

        # Save the synthetic image and its conditioning label with matching
        # metadata so the output filenames are derived from the source label.
        labels = decollate_batch(batch)[0]["label"]
        labels.meta.setdefault("filename_or_obj", "default_name.nii.gz")
        synthetic_images = MetaTensor(synthetic_images.squeeze(0), meta=labels.meta)
        img_saver = SaveImage(
            output_dir=args.output_dir,
            output_postfix="image",
            separate_folder=False,
        )
        img_saver(synthetic_images)
        label_saver = SaveImage(
            output_dir=args.output_dir,
            output_postfix="label",
            separate_folder=False,
        )
        label_saver(labels)
    if use_ddp:
        dist.destroy_process_group()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Log INFO and above to stdout, with millisecond-resolution timestamps.
    log_format = "[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s"
    log_datefmt = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt=log_datefmt)
    main()
|
|
|