File size: 4,838 Bytes
457db56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
Generate and save an ensemble of image samples from a trained CIMD model.
"""
import argparse
import os
import random
import sys

import numpy as np
import torch
import torch as th
from PIL import Image
from tqdm import tqdm

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from guided_diffusion import dist_util, logger
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)

# Fix every random number generator this script touches so that repeated runs
# start from the same state (Python's `random`, NumPy, and PyTorch CPU/CUDA).
seed = 42
for _seed_fn in (
    random.seed,
    np.random.seed,
    th.manual_seed,
    th.cuda.manual_seed_all,
):
    _seed_fn(seed)

def _image_id_at(image_ids, index, fallback):
    """Return the identifier of batch item ``index`` as a string.

    ``fallback`` is returned when the loader supplied no identifier for that
    position (``image_ids`` is ``None``, empty, or shorter than the batch).
    """
    if image_ids is not None and len(image_ids) > index:
        return str(image_ids[index])
    return fallback


def _binarize_sample(sample):
    """Convert one generated sample into a ``uint8`` mask with values 0 and 1.

    The sample is squeezed, min-max normalised to the 0-1 range and thresholded
    at 0.5. A sample with no dynamic range (all values equal, or containing
    NaN) cannot be normalised and yields an all-zero mask instead of dividing
    by zero.
    """
    sample = sample.squeeze()
    low, high = sample.min(), sample.max()
    span = high - low
    if not span > 0:
        return np.zeros(tuple(sample.shape), dtype=np.uint8)
    normalised = (sample - low) / span
    return (normalised > 0.5).cpu().numpy().astype(np.uint8)


def main():
    """Sample an ensemble of masks for every test image and save them as PNGs.

    Steps: parse the command line, build the model and diffusion objects, load
    the checkpoint given by ``--model_path``, iterate over the ``test`` split
    of ``CustomLIDCDataset`` and, for each image, draw ``--num_ensemble``
    samples. Each sample is binarised and written to
    ``<output_dir>/samples/<image_id>_sample_<NN>.png``.

    Exits with a usage error (via ``argparse``) when ``--model_path`` is empty
    or ``--num_ensemble`` is smaller than 1.
    """
    parser = create_argparser()
    args = parser.parse_args()

    # Fail fast on unusable arguments, before any distributed/model setup.
    if not args.model_path:
        parser.error("--model_path must point to a trained model checkpoint")
    if args.num_ensemble < 1:
        parser.error("--num_ensemble must be at least 1")

    dist_util.setup_dist()
    logger.configure(dir=args.output_dir)

    logger.log("creating model and diffusion...")
    model, diffusion, _, _ = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    logger.log(f"loading model from: {args.model_path}")
    state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu")

    # Legacy checkpoint from this specific training run: its parameter names
    # carry a "module." component that has to be removed before loading.
    if "train_2023-05-24-16-53-06" in args.model_path:
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)

    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading data...")
    ds = CustomLIDCDataset(
        data_root=args.data_dir,
        split="test",
        image_size=args.image_size,
        dataset_type=args.dataset_type,
        split_strategy=args.split_strategy,
    )

    datal = th.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=False)
    logger.log(f"found {len(ds)} images for sampling...")

    # Create the directory for samples if it doesn't exist
    samples_save_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_save_dir, exist_ok=True)
    logger.log(f"saving samples to {samples_save_dir}")

    # The sampler choice does not depend on the batch, so pick it once.
    sample_fn = (
        diffusion.ddim_sample_loop_known if args.use_ddim else diffusion.p_sample_loop_known
    )

    for i, (batch, label_gts, image_ids) in enumerate(
        tqdm(datal, desc="Sampling", total=len(datal))
    ):
        images_in_batch = batch.shape[0]

        # One identifier per image in the batch. The fallback uses the global
        # image position so names stay unique when batch_size > 1 (and equals
        # the batch index when batch_size == 1).
        batch_ids = [
            _image_id_at(image_ids, b, f"image_{i * args.batch_size + b:04d}")
            for b in range(images_in_batch)
        ]
        for image_id in batch_ids:
            logger.log(f"sampling for {image_id}...")

        # Replicate every image once per ensemble member. repeat_interleave
        # keeps the copies of one image adjacent: rows
        # [b * num_ensemble, (b + 1) * num_ensemble) all belong to image b.
        model_input = batch.repeat_interleave(args.num_ensemble, dim=0)

        # Create a random noise channel to start the diffusion process
        noise_channel = th.randn_like(model_input[:, :1, ...])

        # The input to the sampling function is (image, noise)
        img_and_noise = th.cat((model_input, noise_channel), dim=1)

        # Inference only: make sure no autograd graph is kept around.
        with th.no_grad():
            sample, _, _ = sample_fn(
                model,
                (model_input.shape[0], 1, args.image_size, args.image_size),
                img_and_noise,
                clip_denoised=args.clip_denoised,
                model_kwargs={},
            )

        # Save every ensemble member of every image in the batch.
        # NOTE(review): assumes the sampler returns rows in the same order as
        # its input batch - confirm against the diffusion implementation.
        for b, image_id in enumerate(batch_ids):
            first_row = b * args.num_ensemble
            for j in range(args.num_ensemble):
                binary_mask = _binarize_sample(sample[first_row + j])

                # Convert to image format (0-255)
                mask_image = Image.fromarray(binary_mask * 255)

                # Save the image with a clear naming convention
                output_filename = f"{image_id}_sample_{j:02d}.png"
                mask_image.save(os.path.join(samples_save_dir, output_filename))

    logger.log("sampling complete.")

def create_argparser():
    """Build the command-line parser for the sampling script.

    The script's own options are combined with the values returned by
    ``model_and_diffusion_defaults()``; the latter are applied last, so they
    take precedence for any key present in both. Every resulting entry is
    registered on the parser through ``add_dict_to_argparser``.
    """
    script_options = {
        "data_dir": "/path/to/your/data/test",  # e.g., /.../lidc_npy/test
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "clip_denoised": True,
        "batch_size": 1,  # Process one image at a time
        "use_ddim": False,
        "model_path": "",  # IMPORTANT: Path to your trained .pt model
        "output_dir": "./results/samples/",  # Directory to save results and samples
        "num_ensemble": 16,  # Number of samples to generate per image
    }
    merged_options = {**script_options, **model_and_diffusion_defaults()}

    cli = argparse.ArgumentParser()
    add_dict_to_argparser(cli, merged_options)
    return cli

# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()