| """ |
| Prediction script for deterministic baselines. |
| Generates N=16 replicated sample predictions for evaluate.py compatibility. |
| """ |
import argparse
import io
import os
import sys

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

# Put this script's own directory on sys.path so the sibling `dataset` and
# `models` modules import no matter which working directory the script runs from.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from dataset import LIDCTestDataset  # noqa: E402
from models import get_model  # noqa: E402
|
|
|
|
# Default number of identical copies written per predicted mask. The module
# docstring says the N=16 replicas exist for evaluate.py compatibility
# (presumably evaluate.py expects several samples per case — confirm there).
# Overridable from the CLI via --num_samples.
NUM_SAMPLES: int = 16
|
|
|
|
@torch.no_grad()
def predict(model, loader, output_dir, device, num_samples=NUM_SAMPLES):
    """Run inference and save each predicted mask as ``num_samples`` identical PNGs.

    A deterministic model yields one mask per input. That mask is replicated
    into ``{output_dir}/{sample_id}_sample_{s}.png`` for ``s`` in
    ``range(num_samples)``, giving the multi-sample layout used for evaluation.

    Args:
        model: Segmentation network; switched to eval mode here. Its raw output
            is passed through a sigmoid and thresholded at 0.5, i.e. treated
            as logits. Each item is read at channel 0, so the output is assumed
            to be ``(batch, channels, H, W)`` with the mask in channel 0 —
            TODO confirm for every architecture ``get_model`` can return.
        loader: Iterable of ``(images, sample_ids)`` batches. Each sample id is
            interpolated into the output file names.
        output_dir: Directory the PNGs are written to; created if missing.
        device: Device the image batches are moved to before the forward pass.
        num_samples: Number of copies written per mask; must be >= 1.

    Returns:
        Number of distinct masks predicted. Files written = result * num_samples.

    Raises:
        ValueError: If ``num_samples`` is less than 1 (previously this wrote no
            files at all and still reported success).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    total = 0
    for images, sample_ids in tqdm(loader, desc="Predicting"):
        images = images.to(device)
        outputs = model(images)
        # Binary {0, 1} mask scaled to {0, 255} so the PNG is a viewable 8-bit image.
        preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy().astype(np.uint8) * 255

        for sample_id, pred in zip(sample_ids, preds):
            # All replicas are byte-identical, so encode the PNG once and write
            # the same bytes N times instead of re-encoding for every copy.
            buffer = io.BytesIO()
            Image.fromarray(pred[0]).save(buffer, format="PNG")
            png_bytes = buffer.getvalue()

            for s in range(num_samples):
                out_path = os.path.join(output_dir, f"{sample_id}_sample_{s}.png")
                with open(out_path, "wb") as fh:
                    fh.write(png_bytes)

            total += 1

    print(f"Saved {total} predictions × {num_samples} samples = {total * num_samples} files")
    return total
|
|
|
|
def _parse_args():
    """Parse CLI arguments, fill in model-dependent path defaults, and validate them."""
    parser = argparse.ArgumentParser(description="Run prediction for deterministic baselines")
    parser.add_argument("--model", type=str, required=True,
                        choices=["unet", "attention_unet", "unetpp", "transunet", "nnunet"],
                        help="Model architecture")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to checkpoint (default: checkpoints/{model}_best.pth)")
    parser.add_argument("--test_dir", type=str, default="data/flat_test",
                        help="Path to flat test data")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory (default: results/{model})")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--num_samples", type=int, default=NUM_SAMPLES, help="Number of sample replications")
    args = parser.parse_args()

    if args.num_samples < 1:
        parser.error(f"--num_samples must be >= 1, got {args.num_samples}")

    # These defaults depend on --model, so they can only be resolved after parsing.
    if args.checkpoint is None:
        args.checkpoint = os.path.join("checkpoints", f"{args.model}_best.pth")
    if args.output_dir is None:
        args.output_dir = os.path.join("results", args.model)

    # Fail fast with a usage message rather than a torch.load traceback.
    if not os.path.isfile(args.checkpoint):
        parser.error(f"checkpoint not found: {args.checkpoint}")

    return args


def _load_model(model_name, checkpoint_path, device):
    """Build ``model_name`` with weights from ``checkpoint_path``, moved to ``device``."""
    model = get_model(model_name, in_channels=1, num_classes=1)

    # NOTE(review): weights_only=False unpickles arbitrary Python objects, which
    # can execute code embedded in a malicious file. Only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)

    # epoch / val_dice are informational only; don't crash after the weights
    # loaded successfully just because a checkpoint lacks them.
    epoch = checkpoint.get("epoch", "?")
    val_dice = checkpoint.get("val_dice")
    dice_str = f"{val_dice:.4f}" if val_dice is not None else "n/a"
    print(f"Loaded {model_name} from {checkpoint_path} (epoch {epoch}, val_dice={dice_str})")
    return model


def main():
    """CLI entry point: load a trained baseline, predict on the test set, and save PNG samples.

    Prints the evaluate.py command to run on the generated samples afterwards.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = _load_model(args.model, args.checkpoint, device)

    test_dataset = LIDCTestDataset(args.test_dir)
    # Pinned host memory only speeds up host->GPU copies. On CPU-only machines
    # PyTorch warns that pin_memory is set but unusable, so enable it only with CUDA.
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    predict(model, test_loader, args.output_dir, device, args.num_samples)

    print(f"\nPredictions saved to {args.output_dir}")
    print("Ready for evaluation:")
    print(f" python evaluate.py --samples_dir {args.output_dir} --gt_dir data/testing --results_file results/{args.model}_eval.csv")
|
|
|
|
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|