siddharthdhara17 commited on
Commit
4f926db
·
verified ·
1 Parent(s): 04fefb0

Upload baselines/predict.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. baselines/predict.py +97 -0
baselines/predict.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prediction script for deterministic baselines.
3
+ Generates N=16 replicated sample predictions for evaluate.py compatibility.
4
+ """
5
+ import argparse
6
+ import os
7
+ import sys
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+
14
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
15
+ from dataset import LIDCTestDataset
16
+ from models import get_model
17
+
18
+
19
+ NUM_SAMPLES = 16 # Match paper's N=16
20
+
21
+
22
+ @torch.no_grad()
23
+ def predict(model, loader, output_dir, device, num_samples=NUM_SAMPLES):
24
+ """Run inference and save predictions as N replicated samples."""
25
+ model.eval()
26
+ os.makedirs(output_dir, exist_ok=True)
27
+
28
+ total = 0
29
+ for images, sample_ids in tqdm(loader, desc="Predicting"):
30
+ images = images.to(device)
31
+ outputs = model(images)
32
+ preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy().astype(np.uint8) * 255
33
+
34
+ for i in range(len(sample_ids)):
35
+ sample_id = sample_ids[i]
36
+ pred_mask = preds[i, 0] # [H, W]
37
+
38
+ # Save N=16 identical copies as samples
39
+ for s in range(num_samples):
40
+ out_path = os.path.join(output_dir, f"{sample_id}_sample_{s}.png")
41
+ Image.fromarray(pred_mask).save(out_path)
42
+
43
+ total += 1
44
+
45
+ print(f"Saved {total} predictions × {num_samples} samples = {total * num_samples} files")
46
+ return total
47
+
48
+
49
+ def main():
50
+ parser = argparse.ArgumentParser(description="Run prediction for deterministic baselines")
51
+ parser.add_argument("--model", type=str, required=True,
52
+ choices=["unet", "attention_unet", "unetpp", "transunet", "nnunet"],
53
+ help="Model architecture")
54
+ parser.add_argument("--checkpoint", type=str, default=None,
55
+ help="Path to checkpoint (default: checkpoints/{model}_best.pth)")
56
+ parser.add_argument("--test_dir", type=str, default="data/flat_test",
57
+ help="Path to flat test data")
58
+ parser.add_argument("--output_dir", type=str, default=None,
59
+ help="Output directory (default: results/{model})")
60
+ parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
61
+ parser.add_argument("--num_workers", type=int, default=4, help="DataLoader workers")
62
+ parser.add_argument("--num_samples", type=int, default=NUM_SAMPLES, help="Number of sample replications")
63
+ args = parser.parse_args()
64
+
65
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+ print(f"Using device: {device}")
67
+
68
+ # Defaults
69
+ if args.checkpoint is None:
70
+ args.checkpoint = os.path.join("checkpoints", f"{args.model}_best.pth")
71
+ if args.output_dir is None:
72
+ args.output_dir = os.path.join("results", args.model)
73
+
74
+ # Load model
75
+ model = get_model(args.model, in_channels=1, num_classes=1)
76
+
77
+ checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
78
+ model.load_state_dict(checkpoint["model_state_dict"])
79
+ model = model.to(device)
80
+
81
+ print(f"Loaded {args.model} from {args.checkpoint} (epoch {checkpoint['epoch']}, val_dice={checkpoint['val_dice']:.4f})")
82
+
83
+ # Create dataset
84
+ test_dataset = LIDCTestDataset(args.test_dir)
85
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
86
+ shuffle=False, num_workers=args.num_workers, pin_memory=True)
87
+
88
+ # Run prediction
89
+ n = predict(model, test_loader, args.output_dir, device, args.num_samples)
90
+
91
+ print(f"\nPredictions saved to {args.output_dir}")
92
+ print(f"Ready for evaluation:")
93
+ print(f" python evaluate.py --samples_dir {args.output_dir} --gt_dir data/testing --results_file results/{args.model}_eval.csv")
94
+
95
+
96
+ if __name__ == "__main__":
97
+ main()