Anirudh Balaraman commited on
Commit
c769d20
·
1 Parent(s): 99c1404

fix dry run

Browse files
Files changed (2) hide show
  1. run_cspca.py +15 -10
  2. run_pirads.py +5 -4
run_cspca.py CHANGED
@@ -21,8 +21,9 @@ def main_worker(args):
21
  cache_dir_path = Path(os.path.join(args.logdir, "cache"))
22
 
23
  if args.mode == "train":
24
- checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
25
- mil_model.load_state_dict(checkpoint["state_dict"])
 
26
  mil_model = mil_model.to(args.device)
27
 
28
  model_dir = os.path.join(args.logdir, "models")
@@ -65,8 +66,11 @@ def main_worker(args):
65
 
66
 
67
  cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
68
- checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
69
- cspca_model.load_state_dict(checkpt["state_dict"])
 
 
 
70
  cspca_model = cspca_model.to(args.device)
71
  if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
72
  auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
@@ -191,12 +195,13 @@ if __name__ == "__main__":
191
  if args.dataset_json is None:
192
  logging.error("Dataset path not provided. Quitting.")
193
  sys.exit(1)
194
- if args.checkpoint_pirads is None and args.mode == "train":
195
- logging.error("PI-RADS checkpoint path not provided. Quitting.")
196
- sys.exit(1)
197
- elif args.checkpoint_cspca is None and args.mode == "test":
198
- logging.error("csPCa checkpoint path not provided. Quitting.")
199
- sys.exit(1)
 
200
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
201
 
202
  if args.device == torch.device("cuda"):
 
21
  cache_dir_path = Path(os.path.join(args.logdir, "cache"))
22
 
23
  if args.mode == "train":
24
+ if not args.dry_run:
25
+ checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
26
+ mil_model.load_state_dict(checkpoint["state_dict"])
27
  mil_model = mil_model.to(args.device)
28
 
29
  model_dir = os.path.join(args.logdir, "models")
 
66
 
67
 
68
  cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
69
+
70
+ if not args.dry_run:
71
+ checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
72
+ cspca_model.load_state_dict(checkpt["state_dict"])
73
+
74
  cspca_model = cspca_model.to(args.device)
75
  if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
76
  auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
 
195
  if args.dataset_json is None:
196
  logging.error("Dataset path not provided. Quitting.")
197
  sys.exit(1)
198
+ if not args.dry_run:
199
+ if args.checkpoint_pirads is None and args.mode == "train":
200
+ logging.error("PI-RADS checkpoint path not provided. Quitting.")
201
+ sys.exit(1)
202
+ elif args.checkpoint_cspca is None and args.mode == "test":
203
+ logging.error("csPCa checkpoint path not provided. Quitting.")
204
+ sys.exit(1)
205
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
206
 
207
  if args.device == torch.device("cuda"):
run_pirads.py CHANGED
@@ -26,7 +26,7 @@ def main_worker(args):
26
  model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
27
  start_epoch = 0
28
  best_acc = 0.0
29
- if args.checkpoint is not None:
30
  checkpoint = torch.load(args.checkpoint, map_location="cpu")
31
  model.load_state_dict(checkpoint["state_dict"])
32
 
@@ -285,9 +285,10 @@ if __name__ == "__main__":
285
  if args.dataset_json is None:
286
  logging.error("Dataset JSON file not provided. Quitting.")
287
  sys.exit(1)
288
- if args.checkpoint is None and args.mode == "test":
289
- logging.error("Model checkpoint path not provided. Quitting.")
290
- sys.exit(1)
 
291
 
292
  if args.dry_run:
293
  logging.info("Dry run mode enabled.")
 
26
  model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
27
  start_epoch = 0
28
  best_acc = 0.0
29
+ if args.checkpoint is not None and not args.dry_run:
30
  checkpoint = torch.load(args.checkpoint, map_location="cpu")
31
  model.load_state_dict(checkpoint["state_dict"])
32
 
 
285
  if args.dataset_json is None:
286
  logging.error("Dataset JSON file not provided. Quitting.")
287
  sys.exit(1)
288
+ if not args.dry_run:
289
+ if args.checkpoint is None and args.mode == "test":
290
+ logging.error("Model checkpoint path not provided. Quitting.")
291
+ sys.exit(1)
292
 
293
  if args.dry_run:
294
  logging.info("Dry run mode enabled.")