File size: 3,689 Bytes
f290261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Unified sampling script for baseline models.

Generates N segmentation samples per test image, saved as PNGs
in the same format as CIMD for evaluation.

Usage:
    python sample.py --model prob_unet \
        --checkpoint ./results/prob_unet/best_model.pt \
        --data_dir /workspace/multiannotator_dataset \
        --num_samples 16 \
        --output_dir ./results/prob_unet/samples_n16
"""
import argparse
import os
import random

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from dataset import MultiAnnotatorDataset
from models import MODEL_REGISTRY


def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    The same value is handed, in order, to Python's ``random`` module,
    NumPy's global RNG, PyTorch's default generator and PyTorch's CUDA
    generators.

    Args:
        seed: Integer seed passed to each generator (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def _parse_args():
    """Parse the command-line options of the sampling script.

    Returns:
        argparse.Namespace: The parsed options. ``--model``, ``--checkpoint``
        and ``--output_dir`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description="Sample from baseline models")
    parser.add_argument("--model", type=str, required=True,
                        choices=list(MODEL_REGISTRY.keys()))
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--data_dir", type=str,
                        default="/workspace/multiannotator_dataset")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save sample masks")
    parser.add_argument("--num_samples", type=int, default=16)
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--latent_dim", type=int, default=6)
    parser.add_argument("--base_ch", type=int, default=32)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def main():
    """Sample segmentation masks for every test image and save them as PNGs.

    Steps:
        1. Parse the CLI options and seed the RNGs.
        2. Build the model selected by ``--model`` and load its weights from
           the ``model_state_dict`` entry of the checkpoint.
        3. Iterate the ``test`` split, calling ``model.sample`` once per batch.
        4. Write each mask to ``<output_dir>/samples/`` as
           ``<image_id>_sample_<jj>.png``.
    """
    args = _parse_args()

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---- Model ----
    model_cls = MODEL_REGISTRY[args.model]
    if args.model in ("phiseg", "gen_prob_unet"):
        # These two models are always built with latent_dim=2, so --latent_dim
        # is ignored for them.
        # NOTE(review): presumably this mirrors the training configuration;
        # confirm against the training script.
        model = model_cls(in_ch=1, num_classes=1, latent_dim=2,
                          base_ch=args.base_ch)
    else:
        model = model_cls(in_ch=1, num_classes=1, latent_dim=args.latent_dim,
                          base_ch=args.base_ch)

    # Load checkpoint
    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()
    print(f"Loaded checkpoint from {args.checkpoint}")

    # ---- Dataset ----
    test_ds = MultiAnnotatorDataset(
        data_root=args.data_dir, split="test", image_size=args.image_size)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False)

    print(f"Test set: {len(test_ds)} images")
    print(f"Sampling {args.num_samples} masks per image...")

    samples_dir = os.path.join(args.output_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)

    # The ground-truth masks yielded by the loader are not needed for sampling.
    for images, _gt_masks, image_ids in tqdm(test_loader, desc="Sampling"):
        images = images.to(device)

        # Sample N masks: returns [B, N, 1, H, W]
        with torch.no_grad():
            samples = model.sample(images, num_samples=args.num_samples)

        # Move the whole batch to the host and convert it once, instead of
        # issuing one device-to-host copy per (image, sample) pair.
        # NOTE(review): the uint8 cast truncates fractional values, so this
        # assumes model.sample returns values in {0, 1} - confirm.
        masks = samples.cpu().numpy().astype(np.uint8) * 255

        for b in range(images.size(0)):
            image_id = image_ids[b]
            for j in range(args.num_samples):
                mask_img = Image.fromarray(masks[b, j, 0])
                fname = f"{image_id}_sample_{j:02d}.png"
                mask_img.save(os.path.join(samples_dir, fname))

    print(f"Sampling complete. Saved to {samples_dir}")


# Run the sampling CLI only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()