""" Generate and save an ensemble of image samples from a trained CIMD model. """ import argparse import os import random import sys import numpy as np import torch import torch as th from PIL import Image from tqdm import tqdm # Add project root to path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from guided_diffusion import dist_util, logger from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset from guided_diffusion.script_util import ( model_and_diffusion_defaults, create_model_and_diffusion, add_dict_to_argparser, args_to_dict, ) # Set a seed for reproducibility seed = 42 th.manual_seed(seed) th.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) def main(): args = create_argparser().parse_args() dist_util.setup_dist() logger.configure(dir=args.output_dir) logger.log("creating model and diffusion...") model, diffusion, _, _ = create_model_and_diffusion( **args_to_dict(args, model_and_diffusion_defaults().keys()) ) logger.log(f"loading model from: {args.model_path}") state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu") # This logic was in your original sample script, keeping it in case it's needed for older checkpoints if "train_2023-05-24-16-53-06" in args.model_path: state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} model.load_state_dict(state_dict) model.to(dist_util.dev()) if args.use_fp16: model.convert_to_fp16() model.eval() logger.log("loading data...") ds = CustomLIDCDataset( data_root=args.data_dir, split="test", image_size=args.image_size, dataset_type=args.dataset_type, split_strategy=args.split_strategy, ) datal = th.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=False) logger.log(f"found {len(ds)} images for sampling...") # Create the directory for samples if it doesn't exist samples_save_dir = os.path.join(args.output_dir, "samples") os.makedirs(samples_save_dir, exist_ok=True) logger.log(f"saving samples to {samples_save_dir}") for i, (batch, label_gts, image_ids) in enumerate( tqdm(datal, desc="Sampling", total=len(datal)) ): image_id = str(image_ids[0]) if image_ids else f"image_{i:04d}" logger.log(f"sampling for {image_id}...") # The model expects the input image to be replicated for each ensemble member model_input = batch.repeat_interleave(args.num_ensemble, dim=0) # Create a random noise channel to start the diffusion process noise_channel = th.randn_like(model_input[:, :1, ...]) # The input to the sampling function is (image, noise) img_and_noise = th.cat((model_input, noise_channel), dim=1) model_kwargs = {} sample_fn = ( diffusion.p_sample_loop_known if not args.use_ddim else diffusion.ddim_sample_loop_known ) # Generate the ensemble of samples sample, _, _ = sample_fn( model, (model_input.shape[0], 1, args.image_size, args.image_size), img_and_noise, clip_denoised=args.clip_denoised, model_kwargs=model_kwargs, ) # Process and save each sample in the ensemble for j in range(args.num_ensemble): # Isolate one sample single_sample = sample[j].squeeze() # Normalize to 0-1 range and threshold to get a binary mask single_sample = (single_sample - single_sample.min()) / (single_sample.max() - single_sample.min()) binary_mask = (single_sample > 0.5).cpu().numpy().astype(np.uint8) # Convert to image format (0-255) mask_image = Image.fromarray(binary_mask * 255) # Save the image with a clear naming convention output_filename = f"{image_id}_sample_{j:02d}.png" mask_image.save(os.path.join(samples_save_dir, output_filename)) logger.log("sampling complete.") def create_argparser(): defaults = dict( data_dir="/path/to/your/data/test", # e.g., /.../lidc_npy/test dataset_type="lidc", split_strategy="all_annotations", clip_denoised=True, batch_size=1, # Process one image at a time use_ddim=False, model_path="", # IMPORTANT: Path to your trained .pt model output_dir="./results/samples/", # Directory to save results and samples num_ensemble=16, # Number of samples to generate per image ) defaults.update(model_and_diffusion_defaults()) parser = argparse.ArgumentParser() add_dict_to_argparser(parser, defaults) return parser if __name__ == "__main__": main()