import argparse import logging import os import shutil import sys import time from pathlib import Path import numpy as np import torch import wandb import yaml from monai.utils import set_determinism from torch.utils.tensorboard import SummaryWriter from src.data.data_loader import get_dataloader from src.model.mil import MILModel3D from src.train.train_pirads import train_epoch, val_epoch from src.utils import save_pirads_checkpoint, setup_logging def main_worker(args): if args.device == torch.device("cuda"): torch.backends.cudnn.benchmark = True model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode) start_epoch = 0 best_acc = 0.0 if args.checkpoint is not None: checkpoint = torch.load(args.checkpoint, map_location="cpu") model.load_state_dict(checkpoint["state_dict"]) if "epoch" in checkpoint: start_epoch = checkpoint["epoch"] if "best_acc" in checkpoint: best_acc = checkpoint["best_acc"] logging.info( "=> loaded checkpoint %s (epoch %d) (bestacc %f)", args.checkpoint, start_epoch, best_acc, ) cache_dir_ = os.path.join(args.logdir, "cache") model.to(args.device) params = model.parameters() if args.mode == "train": train_loader = get_dataloader(args, split="train") valid_loader = get_dataloader(args, split="test") logging.info( f"Dataset training: {len(train_loader.dataset)}, test: {len(valid_loader.dataset)}" ) if args.mil_mode in ["att_trans", "att_trans_pyramid"]: params = [ { "params": list(model.attention.parameters()) + list(model.myfc.parameters()) + list(model.net.parameters()) }, {"params": list(model.transformer.parameters()), "lr": 6e-5, "weight_decay": 0.1}, ] optimizer = torch.optim.AdamW(params, lr=args.optim_lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs, eta_min=0 ) scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp) if args.logdir is not None: writer = SummaryWriter(log_dir=args.logdir) logging.info(f"Writing Tensorboard logs to {writer.log_dir}") else: writer = None # RUN TRAINING n_epochs = args.epochs val_loss_min = float("inf") epochs_no_improve = 0 for epoch in range(start_epoch, n_epochs): logging.info(f"{time.ctime()} | Epoch: {epoch}") epoch_time = time.time() train_loss, train_acc, train_att_loss, batch_norm = train_epoch( model, train_loader, optimizer, scaler=scaler, epoch=epoch, args=args ) logging.info( "Final training %d/%d loss: %.4f attention loss: %.4f acc: %.4f time %.2fs", epoch, n_epochs - 1, train_loss, train_att_loss, train_acc, time.time() - epoch_time, ) if writer is not None: writer.add_scalar("train_loss", train_loss, epoch) writer.add_scalar("train_attention_loss", train_att_loss, epoch) writer.add_scalar("train_acc", train_acc, epoch) wandb.log( { "Train Loss": train_loss, "Train Accuracy": train_acc, "Train Attention Loss": train_att_loss, "Batch Norm": batch_norm, }, step=epoch, ) model_new_best = False val_acc = 0 if (epoch + 1) % args.val_every == 0: epoch_time = time.time() val_loss, val_acc, qwk = val_epoch(model, valid_loader, epoch=epoch, args=args) logging.info( "Final test %d/%d loss: %.4f acc: %.4f qwk: %.4f time %.2fs", epoch, n_epochs - 1, val_loss, val_acc, qwk, time.time() - epoch_time, ) if writer is not None: writer.add_scalar("test_loss", val_loss, epoch) writer.add_scalar("test_acc", val_acc, epoch) writer.add_scalar("test_qwk", qwk, epoch) # val_acc = qwk wandb.log( {"Test Loss": val_loss, "Test Accuracy": val_acc, "Cohen Kappa": qwk}, step=epoch, ) if val_loss < val_loss_min: logging.info("Loss (%.6f --> %.6f)", val_loss_min, val_loss) val_loss_min = val_loss model_new_best = True if args.logdir is not None: save_pirads_checkpoint( model, epoch, args, best_acc=val_acc, filename=f"model_{epoch}.pt" ) if model_new_best: logging.info("Copying to model.pt new best model") shutil.copyfile( os.path.join(args.logdir, f"model_{epoch}.pt"), os.path.join(args.logdir, "model.pt"), ) epochs_no_improve = 0 else: epochs_no_improve += 1 if epochs_no_improve == args.early_stop: logging.info("Early stopping!") break scheduler.step() logging.info("ALL DONE") elif args.mode == "test": kappa_list = [] for seed in list(range(args.num_seeds)): set_determinism(seed=seed) valid_loader = get_dataloader(args, split=args.mode) logging.info("test:", str(len(valid_loader.dataset))) val_loss, val_acc, qwk = val_epoch(model, valid_loader, epoch=0, args=args) kappa_list.append(qwk) logging.info(f"Seed {seed}, QWK: {qwk}") if os.path.exists(cache_dir_): logging.info(f"Removing cache directory {cache_dir_}") shutil.rmtree(cache_dir_) logging.info(f"Mean QWK over {args.num_seeds} seeds: {np.mean(kappa_list)}") if os.path.exists(cache_dir_): logging.info(f"Removing cache directory {cache_dir_}") shutil.rmtree(cache_dir_) def parse_args(): parser = argparse.ArgumentParser( description="Multiple Instance Learning (MIL) for PIRADS Classification." ) parser.add_argument( "--mode", type=str, choices=["train", "test"], required=True, help="operation mode: train or infer", ) parser.add_argument( "--wandb", action="store_true", help="Add this flag to enable WandB logging" ) parser.add_argument( "--project_name", type=str, default="Classification_prostate", help="WandB project name" ) parser.add_argument( "--run_name", type=str, default="train_pirads", help="run name for WandB logging" ) parser.add_argument("--config", type=str, help="path to YAML config file") parser.add_argument("--project_dir", default=None, help="path to project firectory") parser.add_argument("--data_root", default=None, help="path to root folder of images") parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file") parser.add_argument("--num_classes", default=4, type=int, help="number of output classes") parser.add_argument( "--mil_mode", default="att_trans", help="MIL algorithm: choose either att_trans or att_pyramid", ) parser.add_argument( "--tile_count", default=24, type=int, help="number of patches (instances) to extract from MRI input", ) parser.add_argument( "--tile_size", default=64, type=int, help="size of square patch (instance) in pixels" ) parser.add_argument( "--depth", default=3, type=int, help="number of slices in each 3D patch (instance)" ) parser.add_argument( "--use_heatmap", action="store_true", help="enable weak attention heatmap guided patch generation", ) parser.add_argument( "--no_heatmap", dest="use_heatmap", action="store_false", help="disable heatmap" ) parser.set_defaults(use_heatmap=True) parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading") parser.add_argument("--checkpoint", default=None, help="load existing checkpoint") parser.add_argument( "--epochs", "--max_epochs", default=50, type=int, help="number of training epochs" ) parser.add_argument("--early_stop", default=40, type=int, help="early stopping criteria") parser.add_argument("--batch_size", default=4, type=int, help="number of MRI scans per batch") parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate") parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay") parser.add_argument("--amp", action="store_true", help="use AMP, recommended") parser.add_argument( "--val_every", "--val_interval", default=1, type=int, help="run validation after this number of epochs, default 1 to run every epoch", ) args = parser.parse_args() if args.config: with open(args.config) as config_file: config = yaml.safe_load(config_file) args.__dict__.update(config) return args if __name__ == "__main__": args = parse_args() if args.project_dir is None: args.project_dir = Path(__file__).resolve().parent # Set project directory slurm_job_name = os.getenv( "SLURM_JOB_NAME" ) # If the script is submitted via slurm, job name is the run name if slurm_job_name: args.run_name = slurm_job_name args.logdir = os.path.join(args.project_dir, "logs", args.run_name) os.makedirs(args.logdir, exist_ok=True) args.logfile = os.path.join(args.logdir, f"{args.run_name}.log") setup_logging(args.logfile) logging.info("Argument values:") for k, v in vars(args).items(): logging.info(f"{k} => {v}") logging.info("-----------------") args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if args.device == torch.device("cpu"): args.amp = False if args.dataset_json is None: logging.error("Dataset JSON file not provided. Quitting.") sys.exit(1) if args.checkpoint is None and args.mode == "test": logging.error("Model checkpoint path not provided. Quitting.") sys.exit(1) mode_wandb = "online" if args.wandb and args.mode != "test" else "disabled" config_wandb = { "learning_rate": args.optim_lr, "batch_size": args.batch_size, "epochs": args.epochs, "patch size": args.tile_size, "patch count": args.tile_count, } wandb.init( project=args.project_name, name=args.run_name, dir=os.path.join(args.logdir, "wandb"), config=config_wandb, mode=mode_wandb, ) main_worker(args) wandb.finish()