"""
Unified sampling script for baseline models.

Generates N segmentation samples per test image and saves them as PNGs
in the same format as CIMD, for evaluation.

Usage:
    python sample.py --model prob_unet \
        --checkpoint ./results/prob_unet/best_model.pt \
        --data_dir /workspace/multiannotator_dataset \
        --num_samples 16 \
        --output_dir ./results/prob_unet/samples_n16
"""
import argparse
import os
import random
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from dataset import MultiAnnotatorDataset
from models import MODEL_REGISTRY
def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's CPU generator and all CUDA generators.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def _parse_args():
    """Build the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Sample from baseline models")
    parser.add_argument("--model", type=str, required=True,
                        choices=list(MODEL_REGISTRY.keys()))
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--data_dir", type=str,
                        default="/workspace/multiannotator_dataset")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save sample masks")
    parser.add_argument("--num_samples", type=int, default=16)
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--latent_dim", type=int, default=6,
                        help="Latent dimension (ignored for phiseg and "
                             "gen_prob_unet, which are always built with 2)")
    parser.add_argument("--base_ch", type=int, default=32)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def _load_model(args, device):
    """Build the requested model, load its checkpoint weights, and set eval mode."""
    model_cls = MODEL_REGISTRY[args.model]
    # phiseg / gen_prob_unet are always constructed with latent_dim=2 and
    # ignore --latent_dim. NOTE(review): presumably this mirrors how those
    # checkpoints were trained — confirm before changing, or load will fail.
    if args.model in ("phiseg", "gen_prob_unet"):
        latent_dim = 2
    else:
        latent_dim = args.latent_dim
    model = model_cls(in_ch=1, num_classes=1, latent_dim=latent_dim,
                      base_ch=args.base_ch)

    # Checkpoint is expected to be a dict holding "model_state_dict".
    ckpt = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()
    print(f"Loaded checkpoint from {args.checkpoint}")
    return model


def _to_uint8_mask(sample):
    """Convert one 2-D sample array into a 0/255 ``uint8`` mask image array."""
    # Threshold instead of a bare astype(np.uint8): the cast truncates
    # fractional values (e.g. 0.7 -> 0), which would blank the mask if the
    # model returns soft probabilities. For hard 0/1 or boolean samples the
    # result is identical to the cast. Assumes values are in [0, 1] (or bool).
    return (sample > 0.5).astype(np.uint8) * 255


def main():
    """Sample ``--num_samples`` masks per test image and save them as PNGs.

    Parses CLI arguments, seeds all RNGs, loads the chosen model from its
    checkpoint, then iterates over the test split and writes every sample to
    ``<output_dir>/samples/<image_id>_sample_<jj>.png`` as an 8-bit
    single-channel 0/255 mask.
    """
    args = _parse_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---- Model ----
    model = _load_model(args, device)

    # ---- Dataset ----
    test_ds = MultiAnnotatorDataset(
        data_root=args.data_dir, split="test", image_size=args.image_size)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False)
    print(f"Test set: {len(test_ds)} images")
    print(f"Sampling {args.num_samples} masks per image...")

    samples_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)

    # Ground-truth masks are not needed for sampling.
    for images, _gt_masks, image_ids in tqdm(test_loader, desc="Sampling"):
        images = images.to(device)
        # Sample N masks: expected shape [B, N, 1, H, W].
        with torch.no_grad():
            samples = model.sample(images, num_samples=args.num_samples)
        # One device->host transfer per batch rather than one per mask.
        samples = samples.cpu().numpy()
        for b in range(images.size(0)):
            image_id = image_ids[b]
            for j in range(args.num_samples):
                mask_img = Image.fromarray(_to_uint8_mask(samples[b, j, 0]))
                fname = f"{image_id}_sample_{j:02d}.png"
                mask_img.save(os.path.join(samples_dir, fname))
    print(f"Sampling complete. Saved to {samples_dir}")
if __name__ == "__main__":
    # Run sampling only when executed as a script, not when imported.
    main()
|