Prostate-Inference / run_cspca.py
Anirudh Balaraman
dry_run fix
bc6f900
import argparse
import logging
import os
import shutil
import sys
from pathlib import Path
import torch
import yaml
from monai.utils import set_determinism
from src.data.data_loader import get_dataloader
from src.model.cspca_model import CSPCAModel
from src.model.mil import MILModel3D
from src.train.train_cspca import train_epoch, val_epoch
from src.utils import get_metrics, save_cspca_checkpoint, setup_logging
def _clear_cache(cache_dir_path):
    """Remove the run's ``cache`` directory when it is present.

    Args:
        cache_dir_path: ``pathlib.Path`` of the directory to delete.
            Presumably populated by the data loader -- confirm against
            ``get_dataloader``.
    """
    # ``is_dir()`` is already False for a missing path, so a separate
    # ``exists()`` check is unnecessary.
    if cache_dir_path.is_dir():
        shutil.rmtree(cache_dir_path)


def _train_cspca(args, mil_model):
    """Fine-tune a csPCa model on top of ``mil_model`` and keep the best epoch.

    The backbone's ``net``, ``myfc`` and ``transformer`` sub-modules are frozen;
    only the remaining parameters are handed to the optimizer. A checkpoint is
    written through ``save_cspca_checkpoint`` whenever the validation loss
    improves on the best value seen so far.

    Args:
        args: Run configuration; reads ``logdir``, ``device``, ``optim_lr`` and
            ``epochs`` and is forwarded to the loaders and epoch functions.
        mil_model: Backbone already loaded with PI-RADS weights.

    Returns:
        Path (``str``) of the checkpoint expected at
        ``<logdir>/models/cspca_model.pth``.
    """
    model_dir = os.path.join(args.logdir, "models")
    os.makedirs(model_dir, exist_ok=True)

    set_determinism(seed=42)

    train_loader = get_dataloader(args, split="train")
    # NOTE(review): validation draws from the "test" split, the same split that
    # is later used for the reported metrics -- confirm this is intended.
    valid_loader = get_dataloader(args, split="test")

    cspca_model = CSPCAModel(backbone=mil_model).to(args.device)

    # Freeze the pretrained backbone parts so they are excluded from training.
    backbone = cspca_model.backbone
    for submodule in (backbone.net, backbone.myfc, backbone.transformer):
        for param in submodule.parameters():
            param.requires_grad = False

    trainable_params = [p for p in cspca_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.optim_lr)

    # NOTE(review): ``args.val_every`` is not consulted here; validation runs
    # after every epoch.
    best_val_loss = float("inf")
    for epoch in range(args.epochs):
        train_loss, train_auc = train_epoch(
            cspca_model, train_loader, optimizer, epoch=epoch, args=args
        )
        logging.info(f"EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}")

        val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
        logging.info(
            f"EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
        )

        if val_metric["loss"] < best_val_loss:
            best_val_loss = val_metric["loss"]
            save_cspca_checkpoint(cspca_model, val_metric, model_dir)

    # Assumes save_cspca_checkpoint writes "cspca_model.pth" into model_dir --
    # TODO confirm against its implementation.
    return os.path.join(model_dir, "cspca_model.pth")


def _evaluate_cspca(args, mil_model, cache_dir_path):
    """Load the csPCa checkpoint and evaluate it once per seed.

    For every seed in ``range(args.num_seeds)`` a fresh "test" loader is built,
    ``val_epoch`` is run, and AUC / sensitivity / specificity are collected.
    The collected lists are finally passed to ``get_metrics``.

    Args:
        args: Run configuration; reads ``checkpoint_cspca``, ``device`` and
            ``num_seeds``.
        mil_model: Backbone wrapped by the csPCa model.
        cache_dir_path: Cache directory removed after each seed.
    """
    cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
    checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
    # load_state_dict copies values into the parameters already placed on
    # args.device, so no further device transfer is required afterwards.
    cspca_model.load_state_dict(checkpt["state_dict"])

    if all(key in checkpt for key in ("auc", "sensitivity", "specificity")):
        auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
        logging.info(
            f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, Sensitivity: {sens}, Specificity: {spec} on the test set."
        )
    else:
        logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")

    metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
    for seed in range(args.num_seeds):
        set_determinism(seed=seed)
        test_loader = get_dataloader(args, split="test")
        test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
        for metric_name, values in metrics_dict.items():
            values.append(test_metric[metric_name])
        _clear_cache(cache_dir_path)

    get_metrics(metrics_dict)


def main_worker(args):
    """Run csPCa training (optional) followed by multi-seed evaluation.

    In ``train`` mode the PI-RADS checkpoint is loaded into the MIL backbone,
    the csPCa model is fine-tuned, and ``args.checkpoint_cspca`` is pointed at
    the checkpoint produced by training. In both modes the checkpoint named by
    ``args.checkpoint_cspca`` is then loaded and evaluated.

    Args:
        args: Namespace holding the run configuration. ``args.checkpoint_cspca``
            is overwritten in ``train`` mode.
    """
    mil_model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
    cache_dir_path = Path(args.logdir) / "cache"

    if args.mode == "train":
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # PI-RADS checkpoints that come from a trusted source.
        checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
        mil_model.load_state_dict(checkpoint["state_dict"])
        mil_model = mil_model.to(args.device)

        args.checkpoint_cspca = _train_cspca(args, mil_model)
        _clear_cache(cache_dir_path)

    _evaluate_cspca(args, mil_model, cache_dir_path)
def parse_args(argv=None):
    """Parse command-line options and overlay values from a YAML config file.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which is the behaviour existing
            callers rely on.

    Returns:
        argparse.Namespace with every option defined below. If ``--config`` is
        given, each top-level key of that YAML file is written onto the
        namespace afterwards, so config values take precedence over
        command-line values and may introduce additional attributes.
    """
    parser = argparse.ArgumentParser(
        description="Multiple Instance Learning (MIL) for csPCa risk prediction."
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["train", "test"],
        required=True,
        help="Operation mode: train or test",
    )
    parser.add_argument("--run_name", type=str, default="train_cspca", help="run name for log file")
    parser.add_argument("--config", type=str, help="Path to YAML config file")
    parser.add_argument("--project_dir", default=None, help="path to project directory")
    parser.add_argument("--data_root", default=None, help="path to root folder of images")
    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
    parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
    parser.add_argument(
        "--mil_mode",
        default="att_trans",
        help="MIL algorithm: choose either att_trans or att_pyramid",
    )
    parser.add_argument(
        "--tile_count",
        default=24,
        type=int,
        help="number of patches (instances) to extract from MRI input",
    )
    parser.add_argument(
        "--tile_size", default=64, type=int, help="size of square patch (instance) in pixels"
    )
    parser.add_argument(
        "--depth", default=3, type=int, help="number of slices in each 3D patch (instance)"
    )
    # Heatmap guidance is on by default; --no_heatmap is the switch that
    # actually changes the value.
    parser.add_argument(
        "--use_heatmap",
        action="store_true",
        help="enable weak attention heatmap guided patch generation",
    )
    parser.add_argument(
        "--no_heatmap", dest="use_heatmap", action="store_false", help="disable heatmap"
    )
    parser.set_defaults(use_heatmap=True)
    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
    # parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--checkpoint_pirads", default=None, help="Load PI-RADS model")
    parser.add_argument(
        "--epochs", "--max_epochs", default=30, type=int, help="number of training epochs"
    )
    parser.add_argument("--batch_size", default=32, type=int, help="number of MRI scans per batch")
    parser.add_argument("--optim_lr", default=2e-4, type=float, help="initial learning rate")
    # parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
    parser.add_argument(
        "--val_every",
        "--val_interval",
        default=1,
        type=int,
        help="run validation after this number of epochs, default 1 to run every epoch",
    )
    parser.add_argument("--checkpoint_cspca", default=None, help="load existing checkpoint")
    parser.add_argument(
        "--num_seeds", default=20, type=int, help="number of seeds to be run to build CI"
    )
    args = parser.parse_args(argv)

    if args.config:
        with open(args.config, encoding="utf-8") as config_file:
            # safe_load returns None for an empty document; fall back to an
            # empty mapping so the update below does not raise TypeError.
            config = yaml.safe_load(config_file) or {}
        vars(args).update(config)
    return args
if __name__ == "__main__":
    args = parse_args()

    # Without an explicit project directory, use the folder holding this script.
    if args.project_dir is None:
        args.project_dir = Path(__file__).resolve().parent

    # A SLURM job name, when present and non-empty, replaces the run name.
    job_name = os.getenv("SLURM_JOB_NAME")
    if job_name:
        args.run_name = job_name

    # Each run logs into <project_dir>/logs/<run_name>/<run_name>.log.
    args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
    os.makedirs(args.logdir, exist_ok=True)
    args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
    setup_logging(args.logfile)

    # Record the effective configuration at the top of the log.
    logging.info("Argument values:")
    for k, v in vars(args).items():
        logging.info(f"{k} => {v}")
    logging.info("-----------------")

    # Determine the first missing mandatory input, if any, and abort on it.
    problem = None
    if args.dataset_json is None:
        problem = "Dataset path not provided. Quitting."
    elif args.mode == "train" and args.checkpoint_pirads is None:
        problem = "PI-RADS checkpoint path not provided. Quitting."
    elif args.mode == "test" and args.checkpoint_cspca is None:
        problem = "csPCa checkpoint path not provided. Quitting."
    if problem is not None:
        logging.error(problem)
        sys.exit(1)

    # Use CUDA when available; cuDNN benchmarking is switched on only then.
    cuda_available = torch.cuda.is_available()
    args.device = torch.device("cuda" if cuda_available else "cpu")
    if cuda_available:
        torch.backends.cudnn.benchmark = True

    main_worker(args)