File size: 5,675 Bytes
457db56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
import os
import sys
import torch as th
from PIL import Image
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from collections import OrderedDict
from tqdm import tqdm

# Add project root to path to allow importing 'guided_diffusion'
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from guided_diffusion import dist_util, logger
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)

def _strip_ddp_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

    DistributedDataParallel adds a ``module.`` prefix to parameter names when a
    wrapped model is saved during training; the bare model built here expects
    the unprefixed names. Keys without the prefix are copied through unchanged
    and insertion order is preserved.
    """
    prefix = "module."
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def _adapt_output_head(state_dict):
    """Slice a 2-channel final layer (``out.2``) down to its first channel, in place.

    The checkpoint may have been trained with two output channels while the
    model built for sampling has one. Only the first channel's weights (and
    bias, if present) are kept.

    Returns:
        True if the checkpoint was adapted, False if it was left unchanged.
    """
    weight = state_dict.get("out.2.weight")
    if weight is None or weight.shape[0] != 2:
        return False
    state_dict["out.2.weight"] = weight[:1, ...]
    bias = state_dict.get("out.2.bias")
    if bias is not None:
        state_dict["out.2.bias"] = bias[:1, ...]
    return True


def _binary_masks(samples, start, count):
    """Threshold ``samples[start:start + count]`` at 0.5 into uint8 {0, 1} numpy arrays."""
    return [
        (s.squeeze().cpu().numpy() > 0.5).astype(np.uint8)
        for s in samples[start:start + count]
    ]


def _ensemble_summaries(masks):
    """Combine a non-empty list of {0, 1} uint8 masks into two uint8 images.

    Returns:
        (avg_img, vote_img): ``avg_img`` is the per-pixel foreground fraction
        scaled to 0-255 (truncated); ``vote_img`` is 255 where strictly more
        than half of the masks are foreground, else 0.
    """
    stack = np.stack(masks, axis=0)
    avg_img = (np.mean(stack, axis=0) * 255).astype(np.uint8)
    vote_img = (np.sum(stack, axis=0) > len(masks) / 2).astype(np.uint8) * 255
    return avg_img, vote_img


def _save_masks(save_dir, image_id, masks):
    """Write one PNG per ensemble member plus average and majority-vote PNGs for an image."""
    for j, mask in enumerate(masks):
        Image.fromarray(mask * 255).save(os.path.join(save_dir, f"{image_id}_sample_{j:02d}.png"))
    if not masks:
        return
    avg_img, vote_img = _ensemble_summaries(masks)
    Image.fromarray(avg_img).save(os.path.join(save_dir, f"{image_id}_ensemble_avg.png"))
    Image.fromarray(vote_img).save(os.path.join(save_dir, f"{image_id}_majority_vote.png"))


def main():
    """Sample ``num_ensemble`` segmentations per test image and save them as PNGs.

    Each process loads the checkpoint, samples its shard of the test split
    (partitioned by ``DistributedSampler``), and writes, under
    ``<output_dir>/samples``, one binary mask per ensemble member plus an
    ensemble-average map and a majority-vote mask for every image it handled.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    # Fall back to device 0 when no launcher has set LOCAL_RANK.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    th.cuda.set_device(local_rank)
    rank = dist_util.get_rank()

    # Every rank creates the directory (exist_ok makes this safe) so no rank can
    # try to write into it before it exists.
    os.makedirs(args.output_dir, exist_ok=True)
    logger.configure(dir=args.output_dir)

    logger.log("creating model and diffusion...")
    # Build the architecture described by the CLI args (e.g. out_channels=1); the
    # checkpoint is adapted to it below if its output head has 2 channels.
    model, diffusion, _, _ = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    logger.log(f"loading model from: {args.model_path}")
    state_dict = _strip_ddp_prefix(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    if _adapt_output_head(state_dict) and rank == 0:
        logger.log("Adapting checkpoint from 2 output channels to 1...")
    model.load_state_dict(state_dict)

    model.to(dist_util.dev())
    # The model is deliberately NOT wrapped in DistributedDataParallel: inference
    # needs no gradient synchronisation, sampling only ever used the unwrapped
    # module, and DDP does not forward convert_to_fp16 (calling it on the wrapper
    # raised AttributeError).
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading data...")
    dataset = CustomLIDCDataset(
        data_root=args.data_dir,
        split="test",
        image_size=args.image_size,
        dataset_type=args.dataset_type,
        split_strategy=args.split_strategy,
    )
    sampler = DistributedSampler(dataset, shuffle=False)
    dataloader = th.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4
    )

    samples_save_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_save_dir, exist_ok=True)
    logger.log(f"saving samples to {samples_save_dir}")

    # Progress bar only on rank 0 to keep the console readable.
    progress = tqdm(dataloader, desc="Sampling Progress") if rank == 0 else dataloader
    num_ensemble = args.num_ensemble

    for batch, _, image_ids in progress:
        # (batch index, id) for every image with a usable id; empty ids are skipped.
        targets = [(i, str(raw_id)) for i, raw_id in enumerate(image_ids) if raw_id]
        if not targets:
            continue

        # repeat_interleave keeps each image's ensemble copies contiguous:
        # rows [i * num_ensemble, (i + 1) * num_ensemble) belong to batch image i.
        model_input = batch.to(dist_util.dev()).repeat_interleave(num_ensemble, dim=0)
        noise_channel = th.randn_like(model_input[:, :1, ...])
        img_and_noise = th.cat((model_input, noise_channel), dim=1)

        with th.no_grad():
            sample_tensor, _, _ = diffusion.p_sample_loop_known(
                model,
                (model_input.shape[0], 1, args.image_size, args.image_size),
                img_and_noise,
                clip_denoised=args.clip_denoised,
                model_kwargs={},
            )

        # Every rank writes results (including summaries) for the images in its own shard.
        for i, image_id in targets:
            masks = _binary_masks(sample_tensor, i * num_ensemble, num_ensemble)
            _save_masks(samples_save_dir, image_id, masks)

    dist_util.barrier()  # Wait for all processes to finish before logging completion
    logger.log("sampling complete.")

def create_argparser():
    """Build the command-line parser for this sampling script.

    Script-specific options are listed first. The project's model/diffusion
    defaults are then merged over them, so they take precedence on any shared
    key. The combined mapping is handed to ``add_dict_to_argparser``, which
    presumably registers one option per key.
    """
    sampling_options = {
        "data_dir": "./data/LIDC",
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "clip_denoised": True,
        "batch_size": 1,
        # NOTE(review): use_ddim is declared but not read by main() in this file.
        "use_ddim": False,
        "model_path": "",
        "output_dir": "./results/samples/",
        "num_ensemble": 16,
    }
    merged_options = {**sampling_options, **model_and_diffusion_defaults()}
    cli_parser = argparse.ArgumentParser()
    add_dict_to_argparser(cli_parser, merged_options)
    return cli_parser

# Entry point: run sampling only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()