File size: 3,838 Bytes
4f926db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Prediction script for deterministic baselines.
Generates N=16 replicated sample predictions for evaluate.py compatibility.
"""
import argparse
import os
import sys
import numpy as np
import torch
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from dataset import LIDCTestDataset
from models import get_model


# Number of replicated sample files written per test case; 16 mirrors the paper's N=16.
NUM_SAMPLES: int = 16


@torch.no_grad()
def predict(model, loader, output_dir, device, num_samples=NUM_SAMPLES, threshold=0.5):
    """Run inference and save each prediction as ``num_samples`` identical PNG files.

    A deterministic model yields one mask per input. That mask is replicated so the
    on-disk layout is ``{sample_id}_sample_{s}.png`` for ``s`` in ``range(num_samples)``,
    which is the layout evaluate.py consumes.

    Args:
        model: Segmentation network. It is switched to eval mode and called on each batch.
        loader: Iterable yielding ``(images, sample_ids)`` batches.
        output_dir: Directory receiving the PNG files (created if missing).
        device: Device each image batch is moved to before the forward pass.
        num_samples: Number of identical copies written per input.
        threshold: Cut-off applied to ``torch.sigmoid(outputs)``. Pixels strictly
            above it are written as 255 and all others as 0.

    Returns:
        Number of inputs processed (not the number of files written).
    """
    import io  # local import: used only to encode each mask once in memory

    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    total = 0
    for images, sample_ids in tqdm(loader, desc="Predicting"):
        # non_blocking lets the host->device copy overlap when the loader pins memory.
        images = images.to(device, non_blocking=True)
        outputs = model(images)
        # NOTE(review): assumes outputs are (B, 1, H, W) logits for every model -- confirm.
        preds = (torch.sigmoid(outputs) > threshold).cpu().numpy().astype(np.uint8) * 255

        for i, sample_id in enumerate(sample_ids):
            # Default collation turns integer ids into 0-dim tensors. Unwrap them so
            # filenames read "3_sample_0.png" rather than "tensor(3)_sample_0.png".
            if torch.is_tensor(sample_id):
                sample_id = sample_id.item()
            pred_mask = preds[i, 0]  # [H, W]

            # Every replicated sample is byte-identical, so run the PNG encoder once
            # per input instead of once per copy.
            buffer = io.BytesIO()
            Image.fromarray(pred_mask).save(buffer, format="PNG")
            png_bytes = buffer.getvalue()

            for s in range(num_samples):
                out_path = os.path.join(output_dir, f"{sample_id}_sample_{s}.png")
                with open(out_path, "wb") as fh:
                    fh.write(png_bytes)

            total += 1

    print(f"Saved {total} predictions × {num_samples} samples = {total * num_samples} files")
    return total


def main():
    """Command-line entry point: load a trained deterministic baseline and write its predictions.

    Parses CLI arguments, restores model weights from a checkpoint, runs ``predict``
    over the flat test set, and prints the matching evaluate.py command line.
    """
    parser = argparse.ArgumentParser(description="Run prediction for deterministic baselines")
    parser.add_argument("--model", type=str, required=True,
                        choices=["unet", "attention_unet", "unetpp", "transunet", "nnunet"],
                        help="Model architecture")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to checkpoint (default: checkpoints/{model}_best.pth)")
    parser.add_argument("--test_dir", type=str, default="data/flat_test",
                        help="Path to flat test data")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory (default: results/{model})")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--num_samples", type=int, default=NUM_SAMPLES, help="Number of sample replications")
    args = parser.parse_args()

    # Defaults
    if args.checkpoint is None:
        args.checkpoint = os.path.join("checkpoints", f"{args.model}_best.pth")
    if args.output_dir is None:
        args.output_dir = os.path.join("results", args.model)

    # Fail fast with a CLI error instead of a traceback deep inside torch / the save loop.
    if args.num_samples < 1:
        parser.error("--num_samples must be >= 1")
    if not os.path.isfile(args.checkpoint):
        parser.error(f"checkpoint not found: {args.checkpoint}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model
    model = get_model(args.model, in_channels=1, num_classes=1)

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can execute
    # code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
    # Training checkpoints wrap the weights under "model_state_dict" alongside metadata.
    # Otherwise, treat the file as a bare state_dict with no metadata.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        meta = checkpoint
    else:
        state_dict = checkpoint
        meta = {}
    model.load_state_dict(state_dict)
    model = model.to(device)

    epoch = meta.get("epoch", "?")
    val_dice = meta.get("val_dice")
    dice_str = "n/a" if val_dice is None else f"{float(val_dice):.4f}"
    print(f"Loaded {args.model} from {args.checkpoint} (epoch {epoch}, val_dice={dice_str})")

    # Create dataset; pinned memory only helps (and only avoids a warning) on CUDA.
    test_dataset = LIDCTestDataset(args.test_dir)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=args.num_workers,
                             pin_memory=device.type == "cuda")

    # Run prediction (predict prints its own file count summary).
    predict(model, test_loader, args.output_dir, device, args.num_samples)

    print(f"\nPredictions saved to {args.output_dir}")
    print("Ready for evaluation:")
    print(f"  python evaluate.py --samples_dir {args.output_dir} --gt_dir data/testing --results_file results/{args.model}_eval.csv")


if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module has no side effects.
    main()