File size: 5,675 Bytes
457db56 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | import argparse
import os
import sys
import torch as th
from PIL import Image
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from collections import OrderedDict
from tqdm import tqdm
# Add project root to path to allow importing 'guided_diffusion'
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from guided_diffusion import dist_util, logger
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset
from guided_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
def _strip_ddp_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from every key.

    Checkpoints saved from a ``DistributedDataParallel``-wrapped model carry that
    prefix; the bare model built in ``main`` expects unprefixed keys. Keys without
    the prefix are kept unchanged.
    """
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _adapt_output_channels(state_dict, model_state):
    """Trim the final output layer (``out.2``) of ``state_dict`` to match the model.

    Only the case where the checkpoint has exactly twice as many output channels
    as the model is handled (e.g. a 2-channel checkpoint loaded into a 1-channel
    model); the leading channels are kept and the rest dropped. The target size
    is read from the model's own state, so a checkpoint that already matches is
    left untouched.
    NOTE(review): presumably the extra channels are the learned-variance half of
    a ``learn_sigma`` checkpoint, stored after the predicted channels — confirm.

    Args:
        state_dict: checkpoint tensors keyed by parameter name; modified in place.
        model_state: ``model.state_dict()`` of the model that will load it.

    Returns:
        The number of channels kept if the layer was trimmed, otherwise 0.
    """
    weight_key, bias_key = "out.2.weight", "out.2.bias"
    if weight_key not in state_dict or weight_key not in model_state:
        return 0
    wanted = model_state[weight_key].shape[0]
    if state_dict[weight_key].shape[0] != 2 * wanted:
        return 0
    state_dict[weight_key] = state_dict[weight_key][:wanted, ...]
    if bias_key in state_dict:
        state_dict[bias_key] = state_dict[bias_key][:wanted, ...]
    return wanted


def _shard_for_this_rank(dataset):
    """Return the strided slice of ``dataset`` assigned to the current process.

    Rank ``r`` of ``W`` processes gets indices ``r, r + W, r + 2W, ...``. Unlike
    ``DistributedSampler`` with its default ``drop_last=False``, no indices are
    padded or repeated, so no two processes sample or write the same image.
    """
    if th.distributed.is_available() and th.distributed.is_initialized():
        rank = th.distributed.get_rank()
        world_size = th.distributed.get_world_size()
    else:
        rank, world_size = 0, 1
    return th.utils.data.Subset(dataset, range(rank, len(dataset), world_size))


def _save_ensemble(masks, image_id, save_dir):
    """Save every mask of one image plus its pixel-wise mean and majority vote.

    Args:
        masks: list of equally-sized 2-D ``uint8`` arrays holding 0/1 values.
        image_id: used as the filename prefix.
        save_dir: existing output directory.

    Writes ``{image_id}_sample_NN.png``, ``{image_id}_ensemble_avg.png`` and
    ``{image_id}_majority_vote.png``. Does nothing for an empty ``masks`` list.
    """
    if not masks:
        return
    for j, mask in enumerate(masks):
        Image.fromarray(mask * 255).save(os.path.join(save_dir, f"{image_id}_sample_{j:02d}.png"))
    masks_stack = np.stack(masks, axis=0)
    avg_map = np.mean(masks_stack, axis=0)
    Image.fromarray((avg_map * 255).astype(np.uint8)).save(
        os.path.join(save_dir, f"{image_id}_ensemble_avg.png")
    )
    # Strict majority: foreground only where more than half of the samples agree.
    vote_mask = (np.sum(masks_stack, axis=0) > len(masks) / 2).astype(np.uint8)
    Image.fromarray(vote_mask * 255).save(os.path.join(save_dir, f"{image_id}_majority_vote.png"))


def main():
    """Sample mask ensembles for the test split and save them as PNG files.

    For every test image, ``args.num_ensemble`` masks are sampled with
    ``diffusion.p_sample_loop_known`` and thresholded at 0.5. Each mask is
    written to ``<output_dir>/samples`` together with a pixel-wise average and a
    majority-vote mask. Images are split across distributed processes, and each
    process writes the files for its own images.
    """
    args = create_argparser().parse_args()
    dist_util.setup_dist()
    # LOCAL_RANK is set by multi-process launchers; fall back to 0 for a plain run.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    th.cuda.set_device(local_rank)
    is_main = dist_util.get_rank() == 0
    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
        logger.configure(dir=args.output_dir)

    logger.log("creating model and diffusion...")
    model, diffusion, _, _ = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    logger.log(f"loading model from: {args.model_path}")
    state_dict = _strip_ddp_prefix(dist_util.load_state_dict(args.model_path, map_location="cpu"))
    kept = _adapt_output_channels(state_dict, model.state_dict())
    if kept and is_main:
        logger.log(f"Adapting checkpoint from {2 * kept} output channels to {kept}...")
    model.load_state_dict(state_dict)
    model.to(dist_util.dev())
    # Sampling calls the bare model directly, so no DistributedDataParallel wrapper
    # is needed. convert_to_fp16 must be called on the bare model, because a DDP
    # wrapper does not forward arbitrary methods to the module it wraps.
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading data...")
    dataset = CustomLIDCDataset(
        data_root=args.data_dir,
        split="test",
        image_size=args.image_size,
        dataset_type=args.dataset_type,
        split_strategy=args.split_strategy,
    )
    dataloader = th.utils.data.DataLoader(
        _shard_for_this_rank(dataset),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
    )

    samples_save_dir = os.path.join(args.output_dir, "samples")
    # Every rank writes into this directory, so every rank ensures it exists;
    # exist_ok makes the concurrent calls safe.
    os.makedirs(samples_save_dir, exist_ok=True)
    if is_main:
        logger.log(f"saving samples to {samples_save_dir}")

    # Progress bar only on rank 0 to keep the console readable.
    progress = tqdm(dataloader, desc="Sampling Progress") if is_main else dataloader
    ensemble = args.num_ensemble
    for batch, _, image_ids in progress:
        # Skip entries without an image id; their outputs could not be named.
        keep = [i for i, image_id in enumerate(image_ids) if image_id]
        if not keep:
            continue
        if len(keep) < len(image_ids):
            batch = batch[keep]
        ids = [str(image_ids[i]) for i in keep]

        # repeat_interleave puts the copies of one image next to each other:
        # rows [b * ensemble, (b + 1) * ensemble) all belong to image b.
        model_input = batch.to(dist_util.dev()).repeat_interleave(ensemble, dim=0)
        noise_channel = th.randn_like(model_input[:, :1, ...])
        img_and_noise = th.cat((model_input, noise_channel), dim=1)
        sample_tensor, _, _ = diffusion.p_sample_loop_known(
            model,
            (model_input.shape[0], 1, args.image_size, args.image_size),
            img_and_noise,
            clip_denoised=args.clip_denoised,
            model_kwargs={},
        )
        # Threshold the single output channel of every sample into a 0/1 mask.
        masks = (sample_tensor[:, 0] > 0.5).cpu().numpy().astype(np.uint8)
        for b, image_id in enumerate(ids):
            _save_ensemble(
                list(masks[b * ensemble:(b + 1) * ensemble]),
                image_id,
                samples_save_dir,
            )

    dist_util.barrier()  # wait for every process to finish writing before reporting
    logger.log("sampling complete.")
def create_argparser():
    """Build the command-line parser for this sampling script.

    Script-level options are registered first, followed by every key from
    ``model_and_diffusion_defaults()``. Where a key appears in both, the
    model/diffusion default takes precedence.
    """
    script_defaults = {
        "data_dir": "./data/LIDC",
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "clip_denoised": True,
        "batch_size": 1,
        # NOTE(review): main() in this file never reads use_ddim; sampling
        # always goes through diffusion.p_sample_loop_known.
        "use_ddim": False,
        "model_path": "",
        "output_dir": "./results/samples/",
        "num_ensemble": 16,
    }
    merged = {**script_defaults, **model_and_diffusion_defaults()}
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, merged)
    return parser
if __name__ == "__main__":
    # Run the sampler only when executed as a script, not when imported.
    main()
|