# Source: Ambiguous_BaselineRuns/cimd/scripts/segmentation_sample.py
# Uploaded by siddharthdhara17 using huggingface_hub
# (commit 457db56, verified)
"""
Generate and save an ensemble of binary segmentation-mask samples for every
test image using a trained CIMD model.
"""
import argparse
import os
import random
import sys
import numpy as np
import torch
import torch as th
from PIL import Image
from tqdm import tqdm
# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from guided_diffusion import dist_util, logger
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset
from guided_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
# Fix every random number generator the script touches (PyTorch CPU, PyTorch
# CUDA, NumPy, and the stdlib) to one seed so that runs are repeatable.
seed = 42
for _seed_rng in (
    th.manual_seed,
    th.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    _seed_rng(seed)
del _seed_rng  # keep the module namespace limited to `seed`
def _image_id_for(image_ids, position, fallback):
    """Return the identifier of the image at ``position`` in a batch as a string.

    ``fallback`` is returned when ``image_ids`` is missing, empty, or shorter
    than the batch. Indexing is used instead of a truthiness test because a
    multi-element tensor cannot be evaluated as a boolean and a one-element
    tensor holding ``0`` evaluates to ``False``.
    """
    try:
        raw_id = image_ids[position]
    except (IndexError, KeyError, TypeError):
        return fallback
    if isinstance(raw_id, th.Tensor) and raw_id.numel() == 1:
        # str() of a tensor looks like "tensor(5)", which is a poor file name.
        raw_id = raw_id.item()
    return str(raw_id)


def _binarize_sample(sample):
    """Turn one generated sample into a uint8 mask with values 0 and 255.

    The sample is min-max normalised to the 0-1 range and thresholded at 0.5.
    A sample without any spread (constant, or containing NaN) yields an empty
    mask; this is the same result the plain ``0 / 0`` division would produce
    through NaN comparisons, just made explicit.
    """
    sample = sample.squeeze()
    low = sample.min()
    value_range = sample.max() - low
    if value_range > 0:
        normalised = (sample - low) / value_range
    else:
        normalised = th.zeros_like(sample)
    binary_mask = (normalised > 0.5).cpu().numpy().astype(np.uint8)
    return binary_mask * 255


def main():
    """Sample an ensemble of segmentation masks for every test image.

    Steps, in order:
      1. Parse the command line and configure distributed state and logging.
      2. Build the model/diffusion pair and load the checkpoint given by
         ``--model_path``.
      3. Iterate over the ``test`` split of ``CustomLIDCDataset``.
      4. For each image, draw ``--num_ensemble`` samples and write each one as
         ``<output_dir>/samples/<image_id>_sample_<NN>.png``.

    Every image of a batch is processed and saved, so ``--batch_size`` values
    larger than 1 no longer drop all but the first image of each batch.
    """
    args = create_argparser().parse_args()
    dist_util.setup_dist()
    logger.configure(dir=args.output_dir)

    logger.log("creating model and diffusion...")
    model, diffusion, _, _ = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    logger.log(f"loading model from: {args.model_path}")
    state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu")
    # Legacy checkpoint from one specific training run whose parameter names
    # carry a "module." component; remove it so the names match the bare model.
    if "train_2023-05-24-16-53-06" in args.model_path:
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading data...")
    ds = CustomLIDCDataset(
        data_root=args.data_dir,
        split="test",
        image_size=args.image_size,
        dataset_type=args.dataset_type,
        split_strategy=args.split_strategy,
    )
    datal = th.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=False)
    logger.log(f"found {len(ds)} images for sampling...")

    # Create the directory for samples if it doesn't exist
    samples_save_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_save_dir, exist_ok=True)
    logger.log(f"saving samples to {samples_save_dir}")

    # The sampler choice does not depend on the batch, so pick it once.
    sample_fn = (
        diffusion.ddim_sample_loop_known if args.use_ddim else diffusion.p_sample_loop_known
    )

    for i, (batch, _label_gts, image_ids) in enumerate(
        tqdm(datal, desc="Sampling", total=len(datal))
    ):
        images_in_batch = batch.shape[0]
        # One identifier per image. The fallback is a running index over the
        # whole dataset, which equals the batch index when batch_size is 1.
        batch_ids = [
            _image_id_for(image_ids, k, f"image_{i * args.batch_size + k:04d}")
            for k in range(images_in_batch)
        ]
        logger.log(f"sampling for {', '.join(batch_ids)}...")

        # Replicate each input image once per ensemble member. With
        # repeat_interleave the copies of image k occupy the contiguous rows
        # [k * num_ensemble, (k + 1) * num_ensemble).
        model_input = batch.repeat_interleave(args.num_ensemble, dim=0)
        # Create a random noise channel to start the diffusion process
        noise_channel = th.randn_like(model_input[:, :1, ...])
        # The input to the sampling function is (image, noise)
        img_and_noise = th.cat((model_input, noise_channel), dim=1)
        model_kwargs = {}

        # Sampling is inference only. The no_grad context is a safeguard in
        # case the sampling loop does not disable autograd itself.
        with th.no_grad():
            sample, _, _ = sample_fn(
                model,
                (model_input.shape[0], 1, args.image_size, args.image_size),
                img_and_noise,
                clip_denoised=args.clip_denoised,
                model_kwargs=model_kwargs,
            )

        # Save every ensemble member of every image in the batch.
        for k, image_id in enumerate(batch_ids):
            for j in range(args.num_ensemble):
                mask = _binarize_sample(sample[k * args.num_ensemble + j])
                output_filename = f"{image_id}_sample_{j:02d}.png"
                Image.fromarray(mask).save(
                    os.path.join(samples_save_dir, output_filename)
                )

    logger.log("sampling complete.")
def create_argparser():
    """Build the command-line parser for the sampling script.

    The script-specific options below are merged with the values returned by
    ``model_and_diffusion_defaults()``; on a name clash the model/diffusion
    value wins because it is merged last. Every resulting key is registered
    on the parser through ``add_dict_to_argparser``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    script_options = {
        # Root of the test data, e.g. /.../lidc_npy/test
        "data_dir": "/path/to/your/data/test",
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "clip_denoised": True,
        # Process one image at a time
        "batch_size": 1,
        "use_ddim": False,
        # IMPORTANT: path to the trained .pt model
        "model_path": "",
        # Directory that receives the log and the "samples" folder
        "output_dir": "./results/samples/",
        # Number of samples to generate per image
        "num_ensemble": 16,
    }
    merged_options = {**script_options, **model_and_diffusion_defaults()}

    cli_parser = argparse.ArgumentParser()
    add_dict_to_argparser(cli_parser, merged_options)
    return cli_parser
# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()