"""Train and/or evaluate the csPCa risk-prediction model on top of a MIL backbone.

In ``train`` mode a PI-RADS-pretrained ``MILModel3D`` checkpoint is loaded, the
backbone submodules ``net``, ``myfc`` and ``transformer`` are frozen, and the
remaining trainable parameters of ``CSPCAModel`` are optimised with AdamW. The
checkpoint with the lowest validation loss is then evaluated. In ``test`` mode an
existing csPCa checkpoint is evaluated directly. Evaluation is repeated for
``--num_seeds`` seeds and the per-seed metrics are passed to ``get_metrics``.
"""

import argparse
import logging
import os
import shutil
import sys
from pathlib import Path

import torch
import yaml
from monai.utils import set_determinism

from src.data.data_loader import get_dataloader
from src.model.cspca_model import CSPCAModel
from src.model.mil import MILModel3D
from src.train.train_cspca import train_epoch, val_epoch
from src.utils import get_metrics, save_cspca_checkpoint, setup_logging


def _remove_cache_dir(cache_dir_path):
    """Delete the on-disk cache directory under the log dir if it exists."""
    # Path.is_dir() is False for non-existent paths, so no separate exists() check.
    if cache_dir_path.is_dir():
        shutil.rmtree(cache_dir_path)


def _train(args, mil_model, cache_dir_path):
    """Train the csPCa head on a frozen PI-RADS backbone.

    Args:
        args: parsed CLI/config namespace (reads checkpoint_pirads, device, logdir,
            epochs, val_every, optim_lr and whatever the data loaders need).
        mil_model: freshly constructed ``MILModel3D``; its weights are replaced by
            the PI-RADS checkpoint and it is moved to ``args.device`` in place.
        cache_dir_path: cache directory removed once training finishes.

    Returns:
        Path of the best (lowest validation loss) csPCa checkpoint.

    Raises:
        RuntimeError: if no checkpoint was saved during training.
    """
    checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
    mil_model.load_state_dict(checkpoint["state_dict"])
    mil_model.to(args.device)  # nn.Module.to moves parameters in place

    model_dir = os.path.join(args.logdir, "models")
    os.makedirs(model_dir, exist_ok=True)

    set_determinism(seed=42)
    train_loader = get_dataloader(args, split="train")
    # NOTE(review): model selection below uses the "test" split; confirm this is
    # intended (a separate validation split would avoid selecting on test data).
    valid_loader = get_dataloader(args, split="test")

    cspca_model = CSPCAModel(backbone=mil_model).to(args.device)

    # Freeze the pretrained feature extractor; only the remaining parameters train.
    for submodule in (
        cspca_model.backbone.net,
        cspca_model.backbone.myfc,
        cspca_model.backbone.transformer,
    ):
        submodule.requires_grad_(False)

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr
    )

    # Guard against 0/negative values, which would break the modulo below.
    val_every = max(1, args.val_every)
    best_val_loss = float("inf")
    saved_checkpoint = False
    for epoch in range(args.epochs):
        train_loss, train_auc = train_epoch(
            cspca_model, train_loader, optimizer, epoch=epoch, args=args
        )
        logging.info(f"EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}")

        # Validate every `val_every` epochs and always on the final epoch.
        is_last_epoch = epoch == args.epochs - 1
        if (epoch + 1) % val_every != 0 and not is_last_epoch:
            continue

        val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
        logging.info(
            f"EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
        )
        if val_metric["loss"] < best_val_loss:
            best_val_loss = val_metric["loss"]
            save_cspca_checkpoint(cspca_model, val_metric, model_dir)
            saved_checkpoint = True

    _remove_cache_dir(cache_dir_path)

    if not saved_checkpoint:
        # Without this, evaluation would load a missing or stale checkpoint file.
        raise RuntimeError(
            "No csPCa checkpoint was saved during training "
            "(no epochs run or no finite validation loss)."
        )
    # NOTE(review): assumes save_cspca_checkpoint writes this exact filename.
    return os.path.join(model_dir, "cspca_model.pth")


def main_worker(args):
    """Optionally train, then evaluate the csPCa model over several seeds.

    In ``train`` mode ``args.checkpoint_cspca`` is overwritten with the path of the
    best checkpoint produced by training before evaluation starts.
    """
    mil_model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
    cache_dir_path = Path(os.path.join(args.logdir, "cache"))

    if args.mode == "train":
        args.checkpoint_cspca = _train(args, mil_model, cache_dir_path)

    cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
    # NOTE(review): relies on torch.load's default `weights_only`; if the checkpoint
    # stores non-tensor metric objects this may fail on newer torch — confirm.
    checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
    cspca_model.load_state_dict(checkpt["state_dict"])

    if all(key in checkpt for key in ("auc", "sensitivity", "specificity")):
        auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
        logging.info(
            f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, "
            f"Sensitivity: {sens}, Specificity: {spec} on the test set."
        )
    else:
        logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")

    metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
    for seed in range(args.num_seeds):
        set_determinism(seed=seed)
        test_loader = get_dataloader(args, split="test")
        test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
        for key, values in metrics_dict.items():
            values.append(test_metric[key])
        # Cache is cleared after every seed — presumably so the next seed's loader
        # does not reuse cached samples; confirm against get_dataloader.
        _remove_cache_dir(cache_dir_path)

    get_metrics(metrics_dict)


def parse_args():
    """Parse command-line arguments, then overlay values from an optional YAML config.

    Values from ``--config`` override those given on the command line.
    """
    parser = argparse.ArgumentParser(
        description="Multiple Instance Learning (MIL) for csPCa risk prediction."
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["train", "test"],
        required=True,
        help="Operation mode: train or test",
    )
    parser.add_argument("--run_name", type=str, default="train_cspca", help="run name for log file")
    parser.add_argument("--config", type=str, help="Path to YAML config file")
    parser.add_argument("--project_dir", default=None, help="path to project directory")
    parser.add_argument("--data_root", default=None, help="path to root folder of images")
    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
    parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
    parser.add_argument(
        "--mil_mode",
        default="att_trans",
        help="MIL algorithm: choose either att_trans or att_pyramid",
    )
    parser.add_argument(
        "--tile_count",
        default=24,
        type=int,
        help="number of patches (instances) to extract from MRI input",
    )
    parser.add_argument(
        "--tile_size", default=64, type=int, help="size of square patch (instance) in pixels"
    )
    parser.add_argument(
        "--depth", default=3, type=int, help="number of slices in each 3D patch (instance)"
    )
    parser.add_argument(
        "--use_heatmap",
        action="store_true",
        help="enable weak attention heatmap guided patch generation",
    )
    parser.add_argument(
        "--no_heatmap", dest="use_heatmap", action="store_false", help="disable heatmap"
    )
    parser.set_defaults(use_heatmap=True)
    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
    parser.add_argument("--checkpoint_pirads", default=None, help="Load PI-RADS model")
    parser.add_argument(
        "--epochs", "--max_epochs", default=30, type=int, help="number of training epochs"
    )
    parser.add_argument("--batch_size", default=32, type=int, help="number of MRI scans per batch")
    parser.add_argument("--optim_lr", default=2e-4, type=float, help="initial learning rate")
    parser.add_argument(
        "--val_every",
        "--val_interval",
        default=1,
        type=int,
        help="run validation after this number of epochs, default 1 to run every epoch",
    )
    parser.add_argument("--checkpoint_cspca", default=None, help="load existing checkpoint")
    parser.add_argument(
        "--num_seeds", default=20, type=int, help="number of seeds to be run to build CI"
    )

    args = parser.parse_args()
    if args.config:
        with open(args.config, encoding="utf-8") as config_file:
            # safe_load returns None for an empty file; treat that as "no overrides".
            config = yaml.safe_load(config_file) or {}
        if not isinstance(config, dict):
            parser.error(f"config file {args.config} must contain a YAML mapping")
        vars(args).update(config)
    return args


if __name__ == "__main__":
    args = parse_args()
    if args.project_dir is None:
        # Default to the directory containing this script.
        args.project_dir = Path(__file__).resolve().parent

    # If the script is submitted via SLURM, the job name is used as the run name.
    slurm_job_name = os.getenv("SLURM_JOB_NAME")
    if slurm_job_name:
        args.run_name = slurm_job_name

    args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
    os.makedirs(args.logdir, exist_ok=True)
    args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
    setup_logging(args.logfile)

    logging.info("Argument values:")
    for k, v in vars(args).items():
        logging.info(f"{k} => {v}")
    logging.info("-----------------")

    if args.dataset_json is None:
        logging.error("Dataset path not provided. Quitting.")
        sys.exit(1)
    if args.checkpoint_pirads is None and args.mode == "train":
        logging.error("PI-RADS checkpoint path not provided. Quitting.")
        sys.exit(1)
    elif args.checkpoint_cspca is None and args.mode == "test":
        logging.error("csPCa checkpoint path not provided. Quitting.")
        sys.exit(1)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.device.type == "cuda":
        torch.backends.cudnn.benchmark = True

    main_worker(args)