Upload baselines/predict.py with huggingface_hub
Browse files- baselines/predict.py +97 -0
baselines/predict.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prediction script for deterministic baselines.
|
| 3 |
+
Generates N=16 replicated sample predictions for evaluate.py compatibility.
|
| 4 |
+
"""
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
+
from dataset import LIDCTestDataset
|
| 16 |
+
from models import get_model
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
NUM_SAMPLES = 16 # Match paper's N=16
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@torch.no_grad()
|
| 23 |
+
def predict(model, loader, output_dir, device, num_samples=NUM_SAMPLES):
|
| 24 |
+
"""Run inference and save predictions as N replicated samples."""
|
| 25 |
+
model.eval()
|
| 26 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 27 |
+
|
| 28 |
+
total = 0
|
| 29 |
+
for images, sample_ids in tqdm(loader, desc="Predicting"):
|
| 30 |
+
images = images.to(device)
|
| 31 |
+
outputs = model(images)
|
| 32 |
+
preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy().astype(np.uint8) * 255
|
| 33 |
+
|
| 34 |
+
for i in range(len(sample_ids)):
|
| 35 |
+
sample_id = sample_ids[i]
|
| 36 |
+
pred_mask = preds[i, 0] # [H, W]
|
| 37 |
+
|
| 38 |
+
# Save N=16 identical copies as samples
|
| 39 |
+
for s in range(num_samples):
|
| 40 |
+
out_path = os.path.join(output_dir, f"{sample_id}_sample_{s}.png")
|
| 41 |
+
Image.fromarray(pred_mask).save(out_path)
|
| 42 |
+
|
| 43 |
+
total += 1
|
| 44 |
+
|
| 45 |
+
print(f"Saved {total} predictions × {num_samples} samples = {total * num_samples} files")
|
| 46 |
+
return total
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def main():
|
| 50 |
+
parser = argparse.ArgumentParser(description="Run prediction for deterministic baselines")
|
| 51 |
+
parser.add_argument("--model", type=str, required=True,
|
| 52 |
+
choices=["unet", "attention_unet", "unetpp", "transunet", "nnunet"],
|
| 53 |
+
help="Model architecture")
|
| 54 |
+
parser.add_argument("--checkpoint", type=str, default=None,
|
| 55 |
+
help="Path to checkpoint (default: checkpoints/{model}_best.pth)")
|
| 56 |
+
parser.add_argument("--test_dir", type=str, default="data/flat_test",
|
| 57 |
+
help="Path to flat test data")
|
| 58 |
+
parser.add_argument("--output_dir", type=str, default=None,
|
| 59 |
+
help="Output directory (default: results/{model})")
|
| 60 |
+
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
|
| 61 |
+
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader workers")
|
| 62 |
+
parser.add_argument("--num_samples", type=int, default=NUM_SAMPLES, help="Number of sample replications")
|
| 63 |
+
args = parser.parse_args()
|
| 64 |
+
|
| 65 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 66 |
+
print(f"Using device: {device}")
|
| 67 |
+
|
| 68 |
+
# Defaults
|
| 69 |
+
if args.checkpoint is None:
|
| 70 |
+
args.checkpoint = os.path.join("checkpoints", f"{args.model}_best.pth")
|
| 71 |
+
if args.output_dir is None:
|
| 72 |
+
args.output_dir = os.path.join("results", args.model)
|
| 73 |
+
|
| 74 |
+
# Load model
|
| 75 |
+
model = get_model(args.model, in_channels=1, num_classes=1)
|
| 76 |
+
|
| 77 |
+
checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
|
| 78 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 79 |
+
model = model.to(device)
|
| 80 |
+
|
| 81 |
+
print(f"Loaded {args.model} from {args.checkpoint} (epoch {checkpoint['epoch']}, val_dice={checkpoint['val_dice']:.4f})")
|
| 82 |
+
|
| 83 |
+
# Create dataset
|
| 84 |
+
test_dataset = LIDCTestDataset(args.test_dir)
|
| 85 |
+
test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
|
| 86 |
+
shuffle=False, num_workers=args.num_workers, pin_memory=True)
|
| 87 |
+
|
| 88 |
+
# Run prediction
|
| 89 |
+
n = predict(model, test_loader, args.output_dir, device, args.num_samples)
|
| 90 |
+
|
| 91 |
+
print(f"\nPredictions saved to {args.output_dir}")
|
| 92 |
+
print(f"Ready for evaluation:")
|
| 93 |
+
print(f" python evaluate.py --samples_dir {args.output_dir} --gt_dir data/testing --results_file results/{args.model}_eval.csv")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
main()
|