Spaces:
Sleeping
Sleeping
Anirudh Balaraman commited on
Commit ·
c769d20
1
Parent(s): 99c1404
fix dry run
Browse files- run_cspca.py +15 -10
- run_pirads.py +5 -4
run_cspca.py
CHANGED
|
@@ -21,8 +21,9 @@ def main_worker(args):
|
|
| 21 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 22 |
|
| 23 |
if args.mode == "train":
|
| 24 |
-
|
| 25 |
-
|
|
|
|
| 26 |
mil_model = mil_model.to(args.device)
|
| 27 |
|
| 28 |
model_dir = os.path.join(args.logdir, "models")
|
|
@@ -65,8 +66,11 @@ def main_worker(args):
|
|
| 65 |
|
| 66 |
|
| 67 |
cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
cspca_model = cspca_model.to(args.device)
|
| 71 |
if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
|
| 72 |
auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
|
|
@@ -191,12 +195,13 @@ if __name__ == "__main__":
|
|
| 191 |
if args.dataset_json is None:
|
| 192 |
logging.error("Dataset path not provided. Quitting.")
|
| 193 |
sys.exit(1)
|
| 194 |
-
if
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
| 200 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 201 |
|
| 202 |
if args.device == torch.device("cuda"):
|
|
|
|
| 21 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 22 |
|
| 23 |
if args.mode == "train":
|
| 24 |
+
if not args.dry_run:
|
| 25 |
+
checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
|
| 26 |
+
mil_model.load_state_dict(checkpoint["state_dict"])
|
| 27 |
mil_model = mil_model.to(args.device)
|
| 28 |
|
| 29 |
model_dir = os.path.join(args.logdir, "models")
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
|
| 69 |
+
|
| 70 |
+
if not args.dry_run:
|
| 71 |
+
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 72 |
+
cspca_model.load_state_dict(checkpt["state_dict"])
|
| 73 |
+
|
| 74 |
cspca_model = cspca_model.to(args.device)
|
| 75 |
if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
|
| 76 |
auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
|
|
|
|
| 195 |
if args.dataset_json is None:
|
| 196 |
logging.error("Dataset path not provided. Quitting.")
|
| 197 |
sys.exit(1)
|
| 198 |
+
if not args.dry_run:
|
| 199 |
+
if args.checkpoint_pirads is None and args.mode == "train":
|
| 200 |
+
logging.error("PI-RADS checkpoint path not provided. Quitting.")
|
| 201 |
+
sys.exit(1)
|
| 202 |
+
elif args.checkpoint_cspca is None and args.mode == "test":
|
| 203 |
+
logging.error("csPCa checkpoint path not provided. Quitting.")
|
| 204 |
+
sys.exit(1)
|
| 205 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 206 |
|
| 207 |
if args.device == torch.device("cuda"):
|
run_pirads.py
CHANGED
|
@@ -26,7 +26,7 @@ def main_worker(args):
|
|
| 26 |
model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
| 27 |
start_epoch = 0
|
| 28 |
best_acc = 0.0
|
| 29 |
-
if args.checkpoint is not None:
|
| 30 |
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
| 31 |
model.load_state_dict(checkpoint["state_dict"])
|
| 32 |
|
|
@@ -285,9 +285,10 @@ if __name__ == "__main__":
|
|
| 285 |
if args.dataset_json is None:
|
| 286 |
logging.error("Dataset JSON file not provided. Quitting.")
|
| 287 |
sys.exit(1)
|
| 288 |
-
if
|
| 289 |
-
|
| 290 |
-
|
|
|
|
| 291 |
|
| 292 |
if args.dry_run:
|
| 293 |
logging.info("Dry run mode enabled.")
|
|
|
|
| 26 |
model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
| 27 |
start_epoch = 0
|
| 28 |
best_acc = 0.0
|
| 29 |
+
if args.checkpoint is not None and not args.dry_run:
|
| 30 |
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
| 31 |
model.load_state_dict(checkpoint["state_dict"])
|
| 32 |
|
|
|
|
| 285 |
if args.dataset_json is None:
|
| 286 |
logging.error("Dataset JSON file not provided. Quitting.")
|
| 287 |
sys.exit(1)
|
| 288 |
+
if not args.dry_run:
|
| 289 |
+
if args.checkpoint is None and args.mode == "test":
|
| 290 |
+
logging.error("Model checkpoint path not provided. Quitting.")
|
| 291 |
+
sys.exit(1)
|
| 292 |
|
| 293 |
if args.dry_run:
|
| 294 |
logging.info("Dry run mode enabled.")
|