"""Distributed ensemble sampling for a guided-diffusion segmentation model.

For every item of the "test" split of ``CustomLIDCDataset`` the script draws
``num_ensemble`` samples with ``diffusion.p_sample_loop_known``, thresholds
them into binary masks and writes, per image id:

* ``<id>_sample_XX.png``    - each individual binary mask,
* ``<id>_ensemble_avg.png`` - the pixel-wise mean of the masks,
* ``<id>_majority_vote.png``- pixels set in more than half of the masks.

The script reads ``LOCAL_RANK`` from the environment, so it is expected to be
started by a launcher that sets it.
"""

import argparse
import os
import sys
from collections import OrderedDict

import numpy as np
import torch as th
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

# Add project root to path to allow importing 'guided_diffusion'
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from guided_diffusion import dist_util, logger  # noqa: E402
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset  # noqa: E402
from guided_diffusion.script_util import (  # noqa: E402
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)

# Prefix that DistributedDataParallel adds to every parameter name.
_DDP_PREFIX = 'module.'


def _strip_ddp_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys."""
    return OrderedDict(
        (key[len(_DDP_PREFIX):] if key.startswith(_DDP_PREFIX) else key, value)
        for key, value in state_dict.items()
    )


def _adapt_output_channels(state_dict):
    """Truncate a 2-channel ``out.2`` layer to its first channel, in place.

    Returns:
        True when the checkpoint was adapted, False when ``out.2.weight`` is
        absent or does not have exactly two output channels.
    """
    weight = state_dict.get('out.2.weight')
    if weight is None or weight.shape[0] != 2:
        return False
    # Keep only the first channel's weights and bias.
    state_dict['out.2.weight'] = weight[:1, ...]
    state_dict['out.2.bias'] = state_dict['out.2.bias'][:1, ...]
    return True


def _save_ensemble(masks, image_id, save_dir, num_ensemble):
    """Write the individual masks and the ensemble summaries for one image.

    Args:
        masks: list of uint8 arrays holding 0/1 values, one per ensemble member.
        image_id: string used as the file-name prefix.
        save_dir: existing directory that receives the PNG files.
        num_ensemble: ensemble size used for the majority-vote threshold.
    """
    for idx, mask in enumerate(masks):
        Image.fromarray(mask * 255).save(
            os.path.join(save_dir, f"{image_id}_sample_{idx:02d}.png")
        )

    if not masks:
        return

    stack = np.stack(masks, axis=0)
    avg_map = np.mean(stack, axis=0)
    Image.fromarray((avg_map * 255).astype(np.uint8)).save(
        os.path.join(save_dir, f"{image_id}_ensemble_avg.png")
    )
    vote_mask = (np.sum(stack, axis=0) > num_ensemble / 2).astype(np.uint8)
    Image.fromarray(vote_mask * 255).save(
        os.path.join(save_dir, f"{image_id}_majority_vote.png")
    )


def main():
    """Load the checkpoint, sample an ensemble per test image and save PNGs."""
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    local_rank = int(os.environ["LOCAL_RANK"])
    th.cuda.set_device(local_rank)
    is_main = dist_util.get_rank() == 0

    # Every rank creates the directory (idempotent) so that no rank can reach
    # a write before rank 0 has created it.
    os.makedirs(args.output_dir, exist_ok=True)
    logger.configure(dir=args.output_dir)

    logger.log("creating model and diffusion...")
    # Create the CORRECT model structure (with out_channels=1)
    model, diffusion, _, _ = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    logger.log(f"loading model from: {args.model_path}")
    state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu")

    # The checkpoint may come from a DDP-wrapped model ("module." prefix) that
    # was trained with 2 output channels while the current model has 1.
    new_state_dict = _strip_ddp_prefix(state_dict)
    if _adapt_output_channels(new_state_dict) and is_main:
        logger.log("Adapting checkpoint from 2 output channels to 1...")
    model.load_state_dict(new_state_dict)

    model.to(dist_util.dev())
    # convert_to_fp16 is a method of the raw model; the DDP wrapper does not
    # forward it, so the conversion has to happen before wrapping.
    if args.use_fp16:
        model.convert_to_fp16()
    model = DDP(model, device_ids=[local_rank])
    model.eval()

    logger.log("loading data...")
    dataset = CustomLIDCDataset(
        data_root=args.data_dir,
        split="test",
        image_size=args.image_size,
        dataset_type=args.dataset_type,
        split_strategy=args.split_strategy,
    )
    sampler = DistributedSampler(dataset, shuffle=False)
    dataloader = th.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4
    )

    samples_save_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_save_dir, exist_ok=True)
    if is_main:
        logger.log(f"saving samples to {samples_save_dir}")

    # Wrap the dataloader with tqdm on the main process (rank 0) for a clean progress bar
    progress = tqdm(dataloader, desc="Sampling Progress") if is_main else dataloader

    sample_fn = diffusion.p_sample_loop_known
    for batch, _, image_ids in progress:
        # An id is usable when it is truthy (e.g. a non-empty string).
        usable = [bool(raw_id) for raw_id in image_ids]
        if not any(usable):
            continue

        # Each image is repeated num_ensemble times back to back, so samples
        # [i * num_ensemble, (i + 1) * num_ensemble) belong to image i.
        model_input = batch.to(dist_util.dev()).repeat_interleave(args.num_ensemble, dim=0)
        noise_channel = th.randn_like(model_input[:, :1, ...])
        img_and_noise = th.cat((model_input, noise_channel), dim=1)

        with th.no_grad():
            sample_tensor, _, _ = sample_fn(
                model.module,  # Use model.module to access the raw model inside the DDP wrapper
                (model_input.shape[0], 1, args.image_size, args.image_size),
                img_and_noise,
                clip_denoised=args.clip_denoised,
                model_kwargs={},
            )

        # NOTE(review): assumes sample_tensor keeps the requested leading
        # dimension (batch * num_ensemble) - confirm against p_sample_loop_known.
        for pos, raw_id in enumerate(image_ids):
            if not usable[pos]:
                continue
            start = pos * args.num_ensemble
            members = sample_tensor[start:start + args.num_ensemble]
            masks = [
                (member.squeeze().cpu().numpy() > 0.5).astype(np.uint8)
                for member in members
            ]
            # File names are keyed by image id and each rank handles its own
            # shard of ids, so every rank writes the summaries for its images.
            _save_ensemble(masks, str(raw_id), samples_save_dir, args.num_ensemble)

    dist_util.barrier()  # Wait for all processes to finish before logging completion
    logger.log("sampling complete.")


def create_argparser():
    """Build the CLI parser from script defaults plus the model/diffusion defaults."""
    defaults = dict(
        data_dir="./data/LIDC",
        dataset_type="lidc",
        split_strategy="all_annotations",
        clip_denoised=True,
        batch_size=1,
        use_ddim=False,
        model_path="",
        output_dir="./results/samples/",
        num_ensemble=16,
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()