"""
Unified sampling script for baseline models.

Generates N segmentation samples per test image and saves them as PNGs in the
same format as CIMD for evaluation.

Usage:
    python sample.py --model prob_unet \
        --checkpoint ./results/prob_unet/best_model.pt \
        --data_dir /workspace/multiannotator_dataset \
        --num_samples 16 \
        --output_dir ./results/prob_unet/samples_n16
"""

import argparse
import os
import random

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from dataset import MultiAnnotatorDataset
from models import MODEL_REGISTRY

# Models that this script always builds with a 2-dimensional latent space,
# regardless of --latent_dim (kept as-is so existing checkpoints still load).
_FIXED_LATENT_DIM_MODELS = ("phiseg", "gen_prob_unet")
_FIXED_LATENT_DIM = 2
# Latent dimension used for every other model when --latent_dim is not given.
_DEFAULT_LATENT_DIM = 6


def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _parse_args():
    """Build the CLI parser and return the parsed, validated arguments."""
    parser = argparse.ArgumentParser(description="Sample from baseline models")
    parser.add_argument("--model", type=str, required=True,
                        choices=list(MODEL_REGISTRY.keys()))
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--data_dir", type=str,
                        default="/workspace/multiannotator_dataset")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save sample masks")
    parser.add_argument("--num_samples", type=int, default=16)
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=1)
    # Default None so we can tell whether the user set it explicitly; it
    # resolves to 6 (or 2 for phiseg/gen_prob_unet) in _resolve_latent_dim.
    parser.add_argument("--latent_dim", type=int, default=None,
                        help="Latent dimension (default: 6; phiseg and "
                             "gen_prob_unet always use 2)")
    parser.add_argument("--base_ch", type=int, default=32)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    if args.num_samples < 1:
        parser.error("--num_samples must be >= 1")
    return args


def _resolve_latent_dim(model_name, requested):
    """Return the latent dimension to build ``model_name`` with.

    phiseg / gen_prob_unet are always built with a latent dim of 2 (the value
    this script has always used); an explicit, conflicting ``--latent_dim``
    is reported instead of being silently ignored. Other models use the
    requested value, or 6 when none was given.
    """
    if model_name in _FIXED_LATENT_DIM_MODELS:
        if requested is not None and requested != _FIXED_LATENT_DIM:
            print(f"Warning: --latent_dim={requested} is ignored for "
                  f"{model_name}; using latent_dim={_FIXED_LATENT_DIM}")
        return _FIXED_LATENT_DIM
    return _DEFAULT_LATENT_DIM if requested is None else requested


def _build_model(args, device):
    """Instantiate the requested model, load its checkpoint and set eval mode."""
    model_cls = MODEL_REGISTRY[args.model]
    model = model_cls(in_ch=1, num_classes=1,
                      latent_dim=_resolve_latent_dim(args.model, args.latent_dim),
                      base_ch=args.base_ch)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location=device)
    # Accept both {"model_state_dict": ...} training checkpoints and a bare
    # state_dict saved directly.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


def _to_mask_uint8(sample):
    """Convert one 2-D sample array into a 0/255 uint8 mask suitable for PNG.

    Values above 0.5 become foreground (255), everything else 0. For samples
    that are already binary (0/1 or bool) this matches
    ``sample.astype(np.uint8) * 255``; for soft outputs in [0, 1] it avoids
    the plain uint8 cast truncating every value below 1.0 to background.
    (Assumes model.sample yields binary masks or probabilities — TODO confirm
    per model; logits would need a threshold of 0 instead.)
    """
    return np.where(sample > 0.5, 255, 0).astype(np.uint8)


def _id_to_str(image_id):
    """Render an image id for use in a filename.

    DataLoader collates numeric ids into a tensor, whose elements would
    otherwise format as e.g. ``tensor(3)``; string ids pass through unchanged.
    """
    if torch.is_tensor(image_id):
        image_id = image_id.item()
    return str(image_id)


def _save_samples(samples_np, image_ids, num_samples, samples_dir):
    """Write ``samples_np[b, j, 0]`` as ``{image_id}_sample_{j:02d}.png``.

    ``samples_np`` is expected to be shaped [B, N, 1, H, W] (the layout
    model.sample is documented to return in this script).
    """
    for b in range(samples_np.shape[0]):
        image_id = _id_to_str(image_ids[b])
        for j in range(num_samples):
            mask_img = Image.fromarray(_to_mask_uint8(samples_np[b, j, 0]))
            fname = f"{image_id}_sample_{j:02d}.png"
            mask_img.save(os.path.join(samples_dir, fname))


def main():
    """Sample ``--num_samples`` masks per test image and save them as PNGs."""
    args = _parse_args()

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---- Model ----
    model = _build_model(args, device)
    print(f"Loaded checkpoint from {args.checkpoint}")

    # ---- Dataset ----
    test_ds = MultiAnnotatorDataset(
        data_root=args.data_dir, split="test", image_size=args.image_size)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False)
    print(f"Test set: {len(test_ds)} images")
    print(f"Sampling {args.num_samples} masks per image...")

    samples_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)

    # Ground-truth masks are not needed for sampling.
    for images, _gt_masks, image_ids in tqdm(test_loader, desc="Sampling"):
        images = images.to(device)
        # Sample N masks: returns [B, N, 1, H, W]
        with torch.no_grad():
            samples = model.sample(images, num_samples=args.num_samples)
        # One device->host transfer per batch rather than per (image, sample).
        samples_np = samples.detach().cpu().numpy()
        _save_samples(samples_np, image_ids, args.num_samples, samples_dir)

    print(f"Sampling complete. Saved to {samples_dir}")


if __name__ == "__main__":
    main()