| """ |
Generate an ensemble of samples per test image with a trained CIMD model and save each sample as a binary-mask PNG.
| """ |
| import argparse |
| import os |
| import random |
| import sys |
|
|
| import numpy as np |
| import torch |
| import torch as th |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| |
# Make the repository root importable so the local ``guided_diffusion`` package
# resolves. Prepend (rather than append) so this checkout takes precedence over
# any other installed distribution that happens to be named ``guided_diffusion``.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, _REPO_ROOT)
|
|
| from guided_diffusion import dist_util, logger |
| from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset |
| from guided_diffusion.script_util import ( |
| model_and_diffusion_defaults, |
| create_model_and_diffusion, |
| add_dict_to_argparser, |
| args_to_dict, |
| ) |
|
|
| |
# Seed every RNG this script may draw from (Python's ``random``, NumPy's global
# generator, PyTorch's CPU generator and all CUDA generators) once at import
# time, aiming for reproducible sampling runs.
seed = 42
for _seeder in (random.seed, np.random.seed, th.manual_seed, th.cuda.manual_seed_all):
    _seeder(seed)
del _seeder
|
|
# Checkpoint run whose keys are known to carry a "module." prefix; it is always
# stripped for this run so it keeps loading.
_LEGACY_PREFIXED_RUN = "train_2023-05-24-16-53-06"
_DDP_PREFIX = "module."


def _strip_ddp_prefix(state_dict, model_path):
    """Remove a leading ``"module."`` from checkpoint keys when it is present.

    Checkpoints saved through a ``DataParallel``/``DistributedDataParallel``
    wrapper name every parameter ``"module.<name>"``, which a strict
    ``load_state_dict`` on the bare model rejects. The prefix is stripped when
    ``model_path`` contains ``_LEGACY_PREFIXED_RUN`` or when *every* key carries
    it; otherwise ``state_dict`` is returned unchanged.

    Only a *leading* prefix is removed, so a key such as ``"submodule.weight"``
    is never mangled (a plain substring replace would yield ``"subweight"``).
    """
    keys = list(state_dict)
    needs_strip = _LEGACY_PREFIXED_RUN in model_path or (
        bool(keys) and all(k.startswith(_DDP_PREFIX) for k in keys)
    )
    if not needs_strip:
        return state_dict
    return {
        (k[len(_DDP_PREFIX):] if k.startswith(_DDP_PREFIX) else k): v
        for k, v in state_dict.items()
    }


def _image_id(image_ids, index_in_batch, global_index):
    """Return the identifier used in output filenames for one image of a batch.

    ``image_ids`` is whatever the DataLoader collated for the third dataset
    field (a list of strings or a tensor of ints, presumably). Tensor elements
    are unwrapped with ``.item()`` so ids read ``5`` instead of ``tensor(5)``.
    Falls back to ``image_<global_index:04d>`` when no usable id exists.
    """
    if image_ids is not None and index_in_batch < len(image_ids):
        raw = image_ids[index_in_batch]
        if isinstance(raw, th.Tensor):
            raw = raw.item()
        image_id = str(raw)
        if image_id:
            return image_id
    return f"image_{global_index:04d}"


def _to_binary_mask(sample, threshold=0.5):
    """Min-max normalize one sample to [0, 1] and threshold it.

    Returns a ``uint8`` NumPy array holding 0 or 255 (ready for
    ``Image.fromarray``). A constant sample has no range to normalize; it maps
    to an all-zero mask instead of dividing by zero.
    """
    sample = sample.squeeze().float()
    lo, hi = sample.min(), sample.max()
    span = hi - lo
    if span <= 0:
        return np.zeros(tuple(sample.shape), dtype=np.uint8)
    normalized = (sample - lo) / span
    return (normalized > threshold).cpu().numpy().astype(np.uint8) * 255


def main():
    """Sample ``num_ensemble`` masks for every test image and save them as PNGs.

    Builds the model/diffusion from the CLI arguments and loads the checkpoint
    at ``--model_path``. It then iterates over the ``"test"`` split of
    ``CustomLIDCDataset`` and draws ``num_ensemble`` samples per image with
    ``p_sample_loop_known`` (or ``ddim_sample_loop_known`` with ``--use_ddim``).
    Each sample is min-max normalized, thresholded at 0.5 and written to
    ``<output_dir>/samples/<image_id>_sample_<j:02d>.png``.

    NOTE(review): the DataLoader has no distributed sampler, so every process
    started by ``dist_util.setup_dist`` would sample the full split and write
    the same files; presumably this runs as a single process — confirm.
    """
    parser = create_argparser()
    args = parser.parse_args()
    if not args.model_path:
        parser.error("--model_path is required")
    if args.num_ensemble < 1:
        parser.error("--num_ensemble must be at least 1")

    dist_util.setup_dist()
    logger.configure(dir=args.output_dir)

    logger.log("creating model and diffusion...")
    model, diffusion, _, _ = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    logger.log(f"loading model from: {args.model_path}")
    state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu")
    model.load_state_dict(_strip_ddp_prefix(state_dict, args.model_path))

    device = dist_util.dev()
    model.to(device)
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading data...")
    ds = CustomLIDCDataset(
        data_root=args.data_dir,
        split="test",
        image_size=args.image_size,
        dataset_type=args.dataset_type,
        split_strategy=args.split_strategy,
    )
    datal = th.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=False)
    logger.log(f"found {len(ds)} images for sampling...")

    samples_save_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_save_dir, exist_ok=True)
    logger.log(f"saving samples to {samples_save_dir}")

    sample_fn = (
        diffusion.ddim_sample_loop_known if args.use_ddim else diffusion.p_sample_loop_known
    )
    n_ens = args.num_ensemble
    images_done = 0  # running image index across batches, used for fallback ids

    for batch, _label_gts, image_ids in tqdm(datal, desc="Sampling", total=len(datal)):
        n_images = batch.shape[0]
        ids = [_image_id(image_ids, b, images_done + b) for b in range(n_images)]
        logger.log(f"sampling for {', '.join(ids)}...")

        # repeat_interleave keeps the copies of one image contiguous: rows
        # [b * n_ens, (b + 1) * n_ens) all belong to image b.
        model_input = batch.repeat_interleave(n_ens, dim=0)
        # Append one Gaussian-noise channel; presumably this is the slot for
        # the 1-channel mask being denoised (see the sample shape below).
        noise_channel = th.randn_like(model_input[:, :1, ...])
        # Noise is drawn on the batch's own device; only the finished input
        # is moved next to the model.
        img_and_noise = th.cat((model_input, noise_channel), dim=1).to(device)

        with th.no_grad():  # inference only: don't build autograd graphs
            sample, _, _ = sample_fn(
                model,
                (model_input.shape[0], 1, args.image_size, args.image_size),
                img_and_noise,
                clip_denoised=args.clip_denoised,
                model_kwargs={},
            )

        for b, image_id in enumerate(ids):
            for j in range(n_ens):
                mask = _to_binary_mask(sample[b * n_ens + j])
                output_filename = f"{image_id}_sample_{j:02d}.png"
                Image.fromarray(mask).save(os.path.join(samples_save_dir, output_filename))
        images_done += n_images

    logger.log("sampling complete.")
|
|
def create_argparser():
    """Build the command-line parser for this sampling script.

    The script-level options below are merged with every key from
    ``model_and_diffusion_defaults()``. The model/diffusion values win on any
    name clash. The merged mapping is handed to ``add_dict_to_argparser``,
    presumably yielding one ``--<key>`` option per entry.
    """
    sampling_options = {
        "data_dir": "/path/to/your/data/test",
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "clip_denoised": True,
        "batch_size": 1,
        "use_ddim": False,
        "model_path": "",
        "output_dir": "./results/samples/",
        "num_ensemble": 16,
    }
    all_options = {**sampling_options, **model_and_diffusion_defaults()}
    cli = argparse.ArgumentParser()
    add_dict_to_argparser(cli, all_options)
    return cli
|
|
# Run sampling only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|