"""
Prediction script for deterministic baselines.

Generates N=16 replicated sample predictions for evaluate.py compatibility.
"""

import argparse
import io
import os
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from dataset import LIDCTestDataset
from models import get_model

NUM_SAMPLES = 16  # Match paper's N=16


def _save_replicated_png(mask, output_dir, sample_id, num_samples):
    """Write ``mask`` as ``{sample_id}_sample_{s}.png`` for each s in range(num_samples).

    The model is deterministic, so every copy is identical: the PNG is encoded
    once and the same bytes are written ``num_samples`` times rather than
    re-encoding the image for every copy.
    """
    buffer = io.BytesIO()
    Image.fromarray(mask).save(buffer, format="PNG")
    png_bytes = buffer.getvalue()
    for s in range(num_samples):
        out_path = os.path.join(output_dir, f"{sample_id}_sample_{s}.png")
        with open(out_path, "wb") as f:
            f.write(png_bytes)


@torch.no_grad()
def predict(model, loader, output_dir, device, num_samples=NUM_SAMPLES):
    """Run inference and save each prediction as N replicated samples.

    Args:
        model: Segmentation network. Its raw output goes through a sigmoid and
            is thresholded at 0.5. Assumes a (B, 1, H, W) logit map -- TODO
            confirm for every architecture returned by ``get_model``.
        loader: DataLoader yielding ``(images, sample_ids)`` batches.
        output_dir: Directory for the PNG files (created if missing).
        device: Device the model lives on; each image batch is moved there.
        num_samples: Number of identical copies written per prediction.

    Returns:
        Number of distinct predictions (not files) written.
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    total = 0
    for images, sample_ids in tqdm(loader, desc="Predicting"):
        outputs = model(images.to(device))
        # Binary mask stored as 0/255 uint8 so it saves as an 8-bit PNG.
        preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy().astype(np.uint8) * 255

        for sample_id, pred in zip(sample_ids, preds):
            # pred[0] drops the channel axis -> [H, W]
            _save_replicated_png(pred[0], output_dir, sample_id, num_samples)
            total += 1

    print(f"Saved {total} predictions × {num_samples} samples = {total * num_samples} files")
    return total


def main():
    """Parse CLI arguments, load a trained checkpoint and write test-set predictions."""
    parser = argparse.ArgumentParser(description="Run prediction for deterministic baselines")
    parser.add_argument("--model", type=str, required=True,
                        choices=["unet", "attention_unet", "unetpp", "transunet", "nnunet"],
                        help="Model architecture")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to checkpoint (default: checkpoints/{model}_best.pth)")
    parser.add_argument("--test_dir", type=str, default="data/flat_test",
                        help="Path to flat test data")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory (default: results/{model})")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--num_samples", type=int, default=NUM_SAMPLES,
                        help="Number of sample replications")
    args = parser.parse_args()

    # A non-positive count would silently write zero files yet still report success.
    if args.num_samples < 1:
        parser.error("--num_samples must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Defaults
    if args.checkpoint is None:
        args.checkpoint = os.path.join("checkpoints", f"{args.model}_best.pth")
    if args.output_dir is None:
        args.output_dir = os.path.join("results", args.model)

    # Load model
    model = get_model(args.model, in_channels=1, num_classes=1)
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)

    # Metadata is informational only; don't crash if a checkpoint lacks it.
    epoch = checkpoint.get("epoch", "?")
    val_dice = checkpoint.get("val_dice")
    dice_str = f"{val_dice:.4f}" if val_dice is not None else "?"
    print(f"Loaded {args.model} from {args.checkpoint} (epoch {epoch}, val_dice={dice_str})")

    # Create dataset; pinned host memory only helps host->CUDA transfers.
    test_dataset = LIDCTestDataset(args.test_dir)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=device.type == "cuda")

    # Run prediction
    predict(model, test_loader, args.output_dir, device, args.num_samples)

    print(f"\nPredictions saved to {args.output_dir}")
    print("Ready for evaluation:")
    print(f"  python evaluate.py --samples_dir {args.output_dir} --gt_dir data/testing --results_file results/{args.model}_eval.csv")


if __name__ == "__main__":
    main()