Spaces:
Sleeping
Sleeping
Anirudh Balaraman commited on
Commit ·
6f43d62
1
Parent(s): c769d20
fix pytest
Browse files- Makefile +1 -1
- run_cspca.py +12 -19
- run_pirads.py +4 -8
- src/data/data_loader.py +35 -2
Makefile
CHANGED
|
@@ -27,4 +27,4 @@ clean:
|
|
| 27 |
|
| 28 |
# Updated 'check' to clean before running (optional)
|
| 29 |
# This ensures you are testing from a "blank slate"
|
| 30 |
-
check: format lint typecheck
|
|
|
|
| 27 |
|
| 28 |
# Updated 'check' to clean before running (optional)
|
| 29 |
# This ensures you are testing from a "blank slate"
|
| 30 |
+
check: format lint typecheck clean
|
run_cspca.py
CHANGED
|
@@ -21,9 +21,9 @@ def main_worker(args):
|
|
| 21 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 22 |
|
| 23 |
if args.mode == "train":
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
mil_model = mil_model.to(args.device)
|
| 28 |
|
| 29 |
model_dir = os.path.join(args.logdir, "models")
|
|
@@ -66,11 +66,8 @@ def main_worker(args):
|
|
| 66 |
|
| 67 |
|
| 68 |
cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 72 |
-
cspca_model.load_state_dict(checkpt["state_dict"])
|
| 73 |
-
|
| 74 |
cspca_model = cspca_model.to(args.device)
|
| 75 |
if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
|
| 76 |
auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
|
|
@@ -195,15 +192,14 @@ if __name__ == "__main__":
|
|
| 195 |
if args.dataset_json is None:
|
| 196 |
logging.error("Dataset path not provided. Quitting.")
|
| 197 |
sys.exit(1)
|
| 198 |
-
if
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
sys.exit(1)
|
| 205 |
-
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 206 |
|
|
|
|
| 207 |
if args.device == torch.device("cuda"):
|
| 208 |
torch.backends.cudnn.benchmark = True
|
| 209 |
|
|
@@ -218,6 +214,3 @@ if __name__ == "__main__":
|
|
| 218 |
args.tile_count = 5
|
| 219 |
|
| 220 |
main_worker(args)
|
| 221 |
-
|
| 222 |
-
if args.dry_run:
|
| 223 |
-
shutil.rmtree(args.logdir)
|
|
|
|
| 21 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 22 |
|
| 23 |
if args.mode == "train":
|
| 24 |
+
|
| 25 |
+
checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
|
| 26 |
+
mil_model.load_state_dict(checkpoint["state_dict"])
|
| 27 |
mil_model = mil_model.to(args.device)
|
| 28 |
|
| 29 |
model_dir = os.path.join(args.logdir, "models")
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
|
| 69 |
+
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 70 |
+
cspca_model.load_state_dict(checkpt["state_dict"])
|
|
|
|
|
|
|
|
|
|
| 71 |
cspca_model = cspca_model.to(args.device)
|
| 72 |
if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
|
| 73 |
auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
|
|
|
|
| 192 |
if args.dataset_json is None:
|
| 193 |
logging.error("Dataset path not provided. Quitting.")
|
| 194 |
sys.exit(1)
|
| 195 |
+
if args.checkpoint_pirads is None and args.mode == "train":
|
| 196 |
+
logging.error("PI-RADS checkpoint path not provided. Quitting.")
|
| 197 |
+
sys.exit(1)
|
| 198 |
+
elif args.checkpoint_cspca is None and args.mode == "test":
|
| 199 |
+
logging.error("csPCa checkpoint path not provided. Quitting.")
|
| 200 |
+
sys.exit(1)
|
|
|
|
|
|
|
| 201 |
|
| 202 |
+
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 203 |
if args.device == torch.device("cuda"):
|
| 204 |
torch.backends.cudnn.benchmark = True
|
| 205 |
|
|
|
|
| 214 |
args.tile_count = 5
|
| 215 |
|
| 216 |
main_worker(args)
|
|
|
|
|
|
|
|
|
run_pirads.py
CHANGED
|
@@ -26,7 +26,7 @@ def main_worker(args):
|
|
| 26 |
model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
| 27 |
start_epoch = 0
|
| 28 |
best_acc = 0.0
|
| 29 |
-
if args.checkpoint is not None
|
| 30 |
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
| 31 |
model.load_state_dict(checkpoint["state_dict"])
|
| 32 |
|
|
@@ -285,10 +285,9 @@ if __name__ == "__main__":
|
|
| 285 |
if args.dataset_json is None:
|
| 286 |
logging.error("Dataset JSON file not provided. Quitting.")
|
| 287 |
sys.exit(1)
|
| 288 |
-
if
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
sys.exit(1)
|
| 292 |
|
| 293 |
if args.dry_run:
|
| 294 |
logging.info("Dry run mode enabled.")
|
|
@@ -320,6 +319,3 @@ if __name__ == "__main__":
|
|
| 320 |
main_worker(args)
|
| 321 |
|
| 322 |
wandb.finish()
|
| 323 |
-
|
| 324 |
-
if args.dry_run:
|
| 325 |
-
shutil.rmtree(args.logdir)
|
|
|
|
| 26 |
model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
| 27 |
start_epoch = 0
|
| 28 |
best_acc = 0.0
|
| 29 |
+
if args.checkpoint is not None:
|
| 30 |
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
| 31 |
model.load_state_dict(checkpoint["state_dict"])
|
| 32 |
|
|
|
|
| 285 |
if args.dataset_json is None:
|
| 286 |
logging.error("Dataset JSON file not provided. Quitting.")
|
| 287 |
sys.exit(1)
|
| 288 |
+
if args.checkpoint is None and args.mode == "test":
|
| 289 |
+
logging.error("Model checkpoint path not provided. Quitting.")
|
| 290 |
+
sys.exit(1)
|
|
|
|
| 291 |
|
| 292 |
if args.dry_run:
|
| 293 |
logging.info("Dry run mode enabled.")
|
|
|
|
| 319 |
main_worker(args)
|
| 320 |
|
| 321 |
wandb.finish()
|
|
|
|
|
|
|
|
|
src/data/data_loader.py
CHANGED
|
@@ -26,6 +26,29 @@ from .custom_transforms import (
|
|
| 26 |
NormalizeIntensity_customd,
|
| 27 |
)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def list_data_collate(batch: list):
|
| 31 |
"""
|
|
@@ -107,13 +130,23 @@ def data_transform(args: argparse.Namespace) -> Transform:
|
|
| 107 |
def get_dataloader(
|
| 108 |
args: argparse.Namespace, split: Literal["train", "test"]
|
| 109 |
) -> torch.utils.data.DataLoader:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
data_list = load_decathlon_datalist(
|
| 111 |
data_list_file_path=args.dataset_json,
|
| 112 |
data_list_key=split,
|
| 113 |
base_dir=args.data_root,
|
| 114 |
)
|
| 115 |
-
if args.dry_run:
|
| 116 |
-
data_list = data_list[:2] # Use only 8 samples for dry run
|
| 117 |
cache_dir_ = os.path.join(args.logdir, "cache")
|
| 118 |
os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
|
| 119 |
transform = data_transform(args)
|
|
|
|
| 26 |
NormalizeIntensity_customd,
|
| 27 |
)
|
| 28 |
|
| 29 |
+
class DummyMILDataset(torch.utils.data.Dataset):
|
| 30 |
+
def __init__(self, args, num_samples=8):
|
| 31 |
+
self.num_samples = num_samples
|
| 32 |
+
self.args = args
|
| 33 |
+
|
| 34 |
+
def __len__(self):
|
| 35 |
+
return self.num_samples
|
| 36 |
+
|
| 37 |
+
def __getitem__(self, index):
|
| 38 |
+
# Simulate the output of your 'data_transform'
|
| 39 |
+
# A list of dictionaries, one for each 'tile_count' (patch)
|
| 40 |
+
bag = []
|
| 41 |
+
label_value = float(index % 2)
|
| 42 |
+
for _ in range(self.args.tile_count):
|
| 43 |
+
item = {
|
| 44 |
+
# Shape: (Channels=3, Depth, H, W) based on your Transposed(indices=(0, 3, 1, 2))
|
| 45 |
+
"image": torch.randn(3, self.args.depth, self.args.tile_size, self.args.tile_size),
|
| 46 |
+
"label": torch.tensor(label_value, dtype=torch.float32)
|
| 47 |
+
}
|
| 48 |
+
if self.args.use_heatmap:
|
| 49 |
+
item["final_heatmap"] = torch.randn(1, self.args.depth, self.args.tile_size, self.args.tile_size)
|
| 50 |
+
bag.append(item)
|
| 51 |
+
return bag
|
| 52 |
|
| 53 |
def list_data_collate(batch: list):
|
| 54 |
"""
|
|
|
|
| 130 |
def get_dataloader(
|
| 131 |
args: argparse.Namespace, split: Literal["train", "test"]
|
| 132 |
) -> torch.utils.data.DataLoader:
|
| 133 |
+
|
| 134 |
+
if args.dry_run:
|
| 135 |
+
print(f"🛠️ DRY RUN: Creating synthetic {split} dataloader...")
|
| 136 |
+
dummy_ds = DummyMILDataset(args, num_samples=args.batch_size * 2)
|
| 137 |
+
return torch.utils.data.DataLoader(
|
| 138 |
+
dummy_ds,
|
| 139 |
+
batch_size=args.batch_size,
|
| 140 |
+
collate_fn=list_data_collate, # Uses your custom stacking logic
|
| 141 |
+
num_workers=0 # Keep it simple for dry run
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
data_list = load_decathlon_datalist(
|
| 146 |
data_list_file_path=args.dataset_json,
|
| 147 |
data_list_key=split,
|
| 148 |
base_dir=args.data_root,
|
| 149 |
)
|
|
|
|
|
|
|
| 150 |
cache_dir_ = os.path.join(args.logdir, "cache")
|
| 151 |
os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
|
| 152 |
transform = data_transform(args)
|