File size: 4,838 Bytes
457db56 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | """
Generate and save an ensemble of binary segmentation-mask samples for each test image using a trained CIMD model.
"""
import argparse
import os
import random
import sys
import numpy as np
import torch
import torch as th
from PIL import Image
from tqdm import tqdm
# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from guided_diffusion import dist_util, logger
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset
from guided_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
# Fix every RNG this script touches (Python, NumPy, PyTorch CPU and all CUDA
# devices) so repeated runs draw the same diffusion noise.
seed = 42
random.seed(seed)
np.random.seed(seed)
th.manual_seed(seed)
th.cuda.manual_seed_all(seed)
def _load_model(args):
    """Create the model/diffusion pair described by ``args`` and load its weights.

    The model is moved to ``dist_util.dev()``, optionally converted to fp16,
    and switched to eval mode.

    Returns:
        ``(model, diffusion)``. The remaining two values returned by
        ``create_model_and_diffusion`` are not needed for sampling and are dropped.
    """
    logger.log("creating model and diffusion...")
    model, diffusion, _, _ = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    logger.log(f"loading model from: {args.model_path}")
    state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu")

    # Checkpoints saved from a DataParallel/DDP-wrapped model typically prefix
    # every key with "module."; a bare model would reject those keys.
    prefix = "module."
    if "train_2023-05-24-16-53-06" in args.model_path:
        # Legacy special case for this specific older checkpoint, kept verbatim.
        state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items()}
    elif state_dict and all(k.startswith(prefix) for k in state_dict):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()
    return model, diffusion


def _image_id(image_ids, index_in_batch, fallback_number):
    """Return the id string used in output filenames for one batch element.

    Falls back to ``image_<NNNN>`` (using ``fallback_number``) when the loader
    supplied no id for this element. Tensor ids, e.g. integer ids collated
    by the DataLoader, are unwrapped with ``.item()``. This gives ``"5"``
    rather than ``"tensor(5)"``.
    """
    if image_ids is not None and len(image_ids) > index_in_batch:
        ident = image_ids[index_in_batch]
        if isinstance(ident, th.Tensor):
            ident = ident.item()
        return str(ident)
    return f"image_{fallback_number:04d}"


def _sample_to_mask(sample):
    """Min-max normalize one sampled map and threshold it at 0.5.

    Returns:
        A uint8 NumPy array with values in {0, 1} and the squeezed shape of
        ``sample``. A constant (or NaN-range) sample yields an all-zero mask
        explicitly instead of dividing by zero.
    """
    sample = sample.squeeze()
    lo = sample.min()
    value_range = sample.max() - lo
    if not value_range > 0:
        return np.zeros(tuple(sample.shape), dtype=np.uint8)
    normalized = (sample - lo) / value_range
    return (normalized > 0.5).cpu().numpy().astype(np.uint8)


def main():
    """Sample an ensemble of binary masks for every test image and save them as PNGs.

    For each image of the test split, ``args.num_ensemble`` diffusion samples
    are drawn. Each sample is min-max normalized, thresholded at 0.5 and
    written to ``<output_dir>/samples/<image_id>_sample_<jj>.png``. Any
    ``--batch_size`` is supported; every image in a batch gets its own
    ensemble.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure(dir=args.output_dir)

    model, diffusion = _load_model(args)

    logger.log("loading data...")
    ds = CustomLIDCDataset(
        data_root=args.data_dir,
        split="test",
        image_size=args.image_size,
        dataset_type=args.dataset_type,
        split_strategy=args.split_strategy,
    )
    datal = th.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=False)
    logger.log(f"found {len(ds)} images for sampling...")

    samples_save_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_save_dir, exist_ok=True)
    logger.log(f"saving samples to {samples_save_dir}")

    # The sampler choice does not change per batch, so resolve it once.
    sample_fn = (
        diffusion.ddim_sample_loop_known if args.use_ddim else diffusion.p_sample_loop_known
    )
    num_ensemble = args.num_ensemble

    for i, (batch, _label_gts, image_ids) in enumerate(
        tqdm(datal, desc="Sampling", total=len(datal))
    ):
        batch_ids = [
            _image_id(image_ids, b, i * args.batch_size + b)
            for b in range(batch.shape[0])
        ]
        logger.log(f"sampling for {', '.join(batch_ids)}...")

        # repeat_interleave lays the copies out as [x0]*E + [x1]*E + ...,
        # so ensemble member j of batch element b sits at index b*E + j.
        model_input = batch.repeat_interleave(num_ensemble, dim=0)
        # Extra channel of Gaussian noise that the diffusion process starts from.
        noise_channel = th.randn_like(model_input[:, :1, ...])
        img_and_noise = th.cat((model_input, noise_channel), dim=1)

        # Inference only: no autograd graph is needed across the sampling loop.
        with th.no_grad():
            sample, _, _ = sample_fn(
                model,
                (model_input.shape[0], 1, args.image_size, args.image_size),
                img_and_noise,
                clip_denoised=args.clip_denoised,
                model_kwargs={},
            )

        for b, image_id in enumerate(batch_ids):
            for j in range(num_ensemble):
                binary_mask = _sample_to_mask(sample[b * num_ensemble + j])
                mask_image = Image.fromarray(binary_mask * 255)
                output_filename = f"{image_id}_sample_{j:02d}.png"
                mask_image.save(os.path.join(samples_save_dir, output_filename))

    logger.log("sampling complete.")
def create_argparser():
    """Build the command-line parser for this sampling script.

    Script-specific options are merged with every model/diffusion default.
    The model/diffusion defaults are applied last, so they take precedence
    on any key that appears in both.
    """
    script_defaults = {
        "data_dir": "/path/to/your/data/test",  # e.g. /.../lidc_npy/test
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "clip_denoised": True,
        "batch_size": 1,  # images per DataLoader batch
        "use_ddim": False,
        "model_path": "",  # path to the trained .pt checkpoint; must be set
        "output_dir": "./results/samples/",  # logger dir; masks go to <output_dir>/samples
        "num_ensemble": 16,  # samples drawn per input image
    }
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, {**script_defaults, **model_and_diffusion_defaults()})
    return parser
# Run sampling only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|