| """ |
| Unified sampling script for baseline models. |
| |
| Generates N segmentation samples per test image, saved as PNGs |
| in the same format as CIMD for evaluation. |
| |
| Usage: |
| python sample.py --model prob_unet \ |
| --checkpoint ./results/prob_unet/best_model.pt \ |
| --data_dir /workspace/multiannotator_dataset \ |
| --num_samples 16 \ |
| --output_dir ./results/prob_unet/samples_n16 |
| """ |
| import argparse |
| import os |
| import random |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| from dataset import MultiAnnotatorDataset |
| from models import MODEL_REGISTRY |
|
|
|
|
def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, the
    PyTorch CPU generator, and the generators of all CUDA devices.

    Args:
        seed: Integer seed applied to each generator (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Sample from baseline models")
    parser.add_argument("--model", type=str, required=True,
                        choices=list(MODEL_REGISTRY.keys()))
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--data_dir", type=str,
                        default="/workspace/multiannotator_dataset")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save sample masks")
    parser.add_argument("--num_samples", type=int, default=16)
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--latent_dim", type=int, default=6)
    parser.add_argument("--base_ch", type=int, default=32)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def _build_model(args, device):
    """Instantiate the requested model, load its weights, and return it in eval mode."""
    model_cls = MODEL_REGISTRY[args.model]
    # "phiseg" and "gen_prob_unet" are always built with latent_dim=2 and
    # ignore --latent_dim; presumably this mirrors how they were trained
    # (NOTE(review): confirm against the training script).
    if args.model in ("phiseg", "gen_prob_unet"):
        latent_dim = 2
    else:
        latent_dim = args.latent_dim
    model = model_cls(in_ch=1, num_classes=1, latent_dim=latent_dim,
                      base_ch=args.base_ch)

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location=device)
    # Accept both a training checkpoint ({"model_state_dict": ...}) and a
    # bare state_dict saved via torch.save(model.state_dict(), path).
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print(f"Loaded checkpoint from {args.checkpoint}")
    return model


def _to_mask_array(sample):
    """Convert one CPU sample tensor into a uint8 mask with values 0/255."""
    # Threshold rather than cast: astype(np.uint8) truncates fractional
    # values (0.9 -> 0) and wraps negatives. For samples that are already
    # binary {0, 1} this produces exactly the same mask as a plain cast.
    # Assumes samples are binary or probabilities in [0, 1] -- confirm
    # against model.sample.
    return (sample.numpy() > 0.5).astype(np.uint8) * 255


def _id_to_str(image_id):
    """Return a filename-safe string for an id taken from a collated batch."""
    # The default DataLoader collate batches int ids into a tensor; indexing
    # it yields a 0-d tensor whose str() is "tensor(3)". Unwrap it first.
    if torch.is_tensor(image_id):
        image_id = image_id.item()
    return str(image_id)


def main():
    """Sample masks from a trained baseline model for every test image.

    Parses CLI arguments, loads the model checkpoint, and for each image of
    the "test" split calls ``model.sample(images, num_samples=...)``. It then
    writes each sample as a PNG named ``<image_id>_sample_<jj>.png`` under
    ``<output_dir>/samples``.
    """
    args = _parse_args()

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = _build_model(args, device)

    test_ds = MultiAnnotatorDataset(
        data_root=args.data_dir, split="test", image_size=args.image_size)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False)

    print(f"Test set: {len(test_ds)} images")
    print(f"Sampling {args.num_samples} masks per image...")

    samples_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)

    for images, _gt_masks, image_ids in tqdm(test_loader, desc="Sampling"):
        images = images.to(device)

        with torch.no_grad():
            samples = model.sample(images, num_samples=args.num_samples)
        # One device -> host copy per batch instead of one per saved mask.
        samples = samples.cpu()

        for b in range(images.size(0)):
            image_id = _id_to_str(image_ids[b])
            for j in range(args.num_samples):
                # Indexing [b, j, 0] selects batch item b, sample j, channel 0.
                mask = _to_mask_array(samples[b, j, 0])
                fname = f"{image_id}_sample_{j:02d}.png"
                Image.fromarray(mask).save(os.path.join(samples_dir, fname))

    print(f"Sampling complete. Saved to {samples_dir}")
|
|
|
if __name__ == "__main__":
    # Run the sampling CLI only when executed as a script, not on import.
    main()
|
|