Spaces:
Runtime error
Runtime error
Anirudh Balaraman committed on
Commit ·
95dc457
1
Parent(s): 6f43d62
fix pytest
Browse files- Makefile +2 -5
- run_cspca.py +0 -3
- src/data/data_loader.py +8 -6
- temp.ipynb +243 -0
- tests/__init__.py +0 -0
- tests/test_run.py +142 -112
Makefile
CHANGED
|
@@ -20,11 +20,8 @@ clean:
|
|
| 20 |
@python3 -Bc "import shutil, pathlib; \
|
| 21 |
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('__pycache__')]; \
|
| 22 |
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ipynb_checkpoints')]; \
|
| 23 |
-
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.monai-cache')];
|
| 24 |
-
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.mypy_cache')]; \
|
| 25 |
-
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ruff_cache')]; \
|
| 26 |
-
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.pytest_cache')]"
|
| 27 |
|
| 28 |
# Updated 'check' to clean before running (optional)
|
| 29 |
# This ensures you are testing from a "blank slate"
|
| 30 |
-
check: format lint typecheck clean
|
|
|
|
| 20 |
@python3 -Bc "import shutil, pathlib; \
|
| 21 |
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('__pycache__')]; \
|
| 22 |
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ipynb_checkpoints')]; \
|
| 23 |
+
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.monai-cache')];"
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Updated 'check' to clean before running (optional)
|
| 26 |
# This ensures you are testing from a "blank slate"
|
| 27 |
+
check: format lint typecheck test clean
|
run_cspca.py
CHANGED
|
@@ -21,7 +21,6 @@ def main_worker(args):
|
|
| 21 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 22 |
|
| 23 |
if args.mode == "train":
|
| 24 |
-
|
| 25 |
checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
|
| 26 |
mil_model.load_state_dict(checkpoint["state_dict"])
|
| 27 |
mil_model = mil_model.to(args.device)
|
|
@@ -64,7 +63,6 @@ def main_worker(args):
|
|
| 64 |
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 65 |
shutil.rmtree(cache_dir_path)
|
| 66 |
|
| 67 |
-
|
| 68 |
cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
|
| 69 |
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 70 |
cspca_model.load_state_dict(checkpt["state_dict"])
|
|
@@ -92,7 +90,6 @@ def main_worker(args):
|
|
| 92 |
get_metrics(metrics_dict)
|
| 93 |
|
| 94 |
|
| 95 |
-
|
| 96 |
def parse_args():
|
| 97 |
parser = argparse.ArgumentParser(
|
| 98 |
description="Multiple Instance Learning (MIL) for csPCa risk prediction."
|
|
|
|
| 21 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 22 |
|
| 23 |
if args.mode == "train":
|
|
|
|
| 24 |
checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
|
| 25 |
mil_model.load_state_dict(checkpoint["state_dict"])
|
| 26 |
mil_model = mil_model.to(args.device)
|
|
|
|
| 63 |
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 64 |
shutil.rmtree(cache_dir_path)
|
| 65 |
|
|
|
|
| 66 |
cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
|
| 67 |
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 68 |
cspca_model.load_state_dict(checkpt["state_dict"])
|
|
|
|
| 90 |
get_metrics(metrics_dict)
|
| 91 |
|
| 92 |
|
|
|
|
| 93 |
def parse_args():
|
| 94 |
parser = argparse.ArgumentParser(
|
| 95 |
description="Multiple Instance Learning (MIL) for csPCa risk prediction."
|
src/data/data_loader.py
CHANGED
|
@@ -26,6 +26,7 @@ from .custom_transforms import (
|
|
| 26 |
NormalizeIntensity_customd,
|
| 27 |
)
|
| 28 |
|
|
|
|
| 29 |
class DummyMILDataset(torch.utils.data.Dataset):
|
| 30 |
def __init__(self, args, num_samples=8):
|
| 31 |
self.num_samples = num_samples
|
|
@@ -43,13 +44,16 @@ class DummyMILDataset(torch.utils.data.Dataset):
|
|
| 43 |
item = {
|
| 44 |
# Shape: (Channels=3, Depth, H, W) based on your Transposed(indices=(0, 3, 1, 2))
|
| 45 |
"image": torch.randn(3, self.args.depth, self.args.tile_size, self.args.tile_size),
|
| 46 |
-
"label": torch.tensor(label_value, dtype=torch.float32)
|
| 47 |
}
|
| 48 |
if self.args.use_heatmap:
|
| 49 |
-
item["final_heatmap"] = torch.randn(
|
|
|
|
|
|
|
| 50 |
bag.append(item)
|
| 51 |
return bag
|
| 52 |
|
|
|
|
| 53 |
def list_data_collate(batch: list):
|
| 54 |
"""
|
| 55 |
Combine instances from a list of dicts into a single dict, by stacking them along first dim
|
|
@@ -130,18 +134,16 @@ def data_transform(args: argparse.Namespace) -> Transform:
|
|
| 130 |
def get_dataloader(
|
| 131 |
args: argparse.Namespace, split: Literal["train", "test"]
|
| 132 |
) -> torch.utils.data.DataLoader:
|
| 133 |
-
|
| 134 |
if args.dry_run:
|
| 135 |
print(f"🛠️ DRY RUN: Creating synthetic {split} dataloader...")
|
| 136 |
dummy_ds = DummyMILDataset(args, num_samples=args.batch_size * 2)
|
| 137 |
return torch.utils.data.DataLoader(
|
| 138 |
dummy_ds,
|
| 139 |
batch_size=args.batch_size,
|
| 140 |
-
collate_fn=list_data_collate,
|
| 141 |
-
num_workers=0
|
| 142 |
)
|
| 143 |
|
| 144 |
-
|
| 145 |
data_list = load_decathlon_datalist(
|
| 146 |
data_list_file_path=args.dataset_json,
|
| 147 |
data_list_key=split,
|
|
|
|
| 26 |
NormalizeIntensity_customd,
|
| 27 |
)
|
| 28 |
|
| 29 |
+
|
| 30 |
class DummyMILDataset(torch.utils.data.Dataset):
|
| 31 |
def __init__(self, args, num_samples=8):
|
| 32 |
self.num_samples = num_samples
|
|
|
|
| 44 |
item = {
|
| 45 |
# Shape: (Channels=3, Depth, H, W) based on your Transposed(indices=(0, 3, 1, 2))
|
| 46 |
"image": torch.randn(3, self.args.depth, self.args.tile_size, self.args.tile_size),
|
| 47 |
+
"label": torch.tensor(label_value, dtype=torch.float32),
|
| 48 |
}
|
| 49 |
if self.args.use_heatmap:
|
| 50 |
+
item["final_heatmap"] = torch.randn(
|
| 51 |
+
1, self.args.depth, self.args.tile_size, self.args.tile_size
|
| 52 |
+
)
|
| 53 |
bag.append(item)
|
| 54 |
return bag
|
| 55 |
|
| 56 |
+
|
| 57 |
def list_data_collate(batch: list):
|
| 58 |
"""
|
| 59 |
Combine instances from a list of dicts into a single dict, by stacking them along first dim
|
|
|
|
| 134 |
def get_dataloader(
|
| 135 |
args: argparse.Namespace, split: Literal["train", "test"]
|
| 136 |
) -> torch.utils.data.DataLoader:
|
|
|
|
| 137 |
if args.dry_run:
|
| 138 |
print(f"🛠️ DRY RUN: Creating synthetic {split} dataloader...")
|
| 139 |
dummy_ds = DummyMILDataset(args, num_samples=args.batch_size * 2)
|
| 140 |
return torch.utils.data.DataLoader(
|
| 141 |
dummy_ds,
|
| 142 |
batch_size=args.batch_size,
|
| 143 |
+
collate_fn=list_data_collate, # Uses your custom stacking logic
|
| 144 |
+
num_workers=0, # Keep it simple for dry run
|
| 145 |
)
|
| 146 |
|
|
|
|
| 147 |
data_list = load_decathlon_datalist(
|
| 148 |
data_list_file_path=args.dataset_json,
|
| 149 |
data_list_key=split,
|
temp.ipynb
CHANGED
|
@@ -77,6 +77,249 @@
|
|
| 77 |
"from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset"
|
| 78 |
]
|
| 79 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
{
|
| 81 |
"cell_type": "code",
|
| 82 |
"execution_count": 2,
|
|
|
|
| 77 |
"from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset"
|
| 78 |
]
|
| 79 |
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": 1,
|
| 83 |
+
"id": "bc433898",
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [],
|
| 86 |
+
"source": [
|
| 87 |
+
"import subprocess\n",
|
| 88 |
+
"import sys\n",
|
| 89 |
+
"from pathlib import Path\n",
|
| 90 |
+
"import torch\n",
|
| 91 |
+
"import pytest\n",
|
| 92 |
+
"import argparse\n",
|
| 93 |
+
"from src.train.train_pirads import get_attention_scores\n"
|
| 94 |
+
]
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"cell_type": "code",
|
| 98 |
+
"execution_count": 8,
|
| 99 |
+
"id": "f1c90aff",
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"outputs": [],
|
| 102 |
+
"source": [
|
| 103 |
+
"batch_size = 2\n",
|
| 104 |
+
"num_patches = 4\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"# Sample 0: Target = 3 (Cancer), Sample 1: Target = 0 (PI-RADS 2)\n",
|
| 107 |
+
"data = torch.randn(batch_size, num_patches, 1, 8, 8)\n",
|
| 108 |
+
"target = torch.tensor([3.0, 0.0])\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"# Create heatmaps: Sample 0 has one \"hot\" patch\n",
|
| 111 |
+
"heatmap = torch.zeros(batch_size, num_patches, 1, 8, 8)\n",
|
| 112 |
+
"heatmap[0, 0] = 10.0 # High attention on patch 0 for the first sample\n",
|
| 113 |
+
"heatmap[0, 3] = 2.0 \n",
|
| 114 |
+
"heatmap[1, 2] = 5.0 # Should be overridden by PI-RADS 2 logic anyway\n"
|
| 115 |
+
]
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"cell_type": "code",
|
| 119 |
+
"execution_count": 25,
|
| 120 |
+
"id": "80cb444f",
|
| 121 |
+
"metadata": {},
|
| 122 |
+
"outputs": [],
|
| 123 |
+
"source": [
|
| 124 |
+
"def mock_args():\n",
|
| 125 |
+
" # Mocking argparse for the device\n",
|
| 126 |
+
" args = argparse.Namespace()\n",
|
| 127 |
+
" args.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 128 |
+
" return args"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "code",
|
| 133 |
+
"execution_count": 41,
|
| 134 |
+
"id": "6528fd4d",
|
| 135 |
+
"metadata": {},
|
| 136 |
+
"outputs": [
|
| 137 |
+
{
|
| 138 |
+
"ename": "AssertionError",
|
| 139 |
+
"evalue": "",
|
| 140 |
+
"output_type": "error",
|
| 141 |
+
"traceback": [
|
| 142 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 143 |
+
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
|
| 144 |
+
"Cell \u001b[0;32mIn[41], line 23\u001b[0m\n\u001b[1;32m 21\u001b[0m idx \u001b[38;5;241m=\u001b[39m (shuffled_images[\u001b[38;5;241m0\u001b[39m, :, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m5.0\u001b[39m)\u001b[38;5;241m.\u001b[39mnonzero(as_tuple\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# The attention score at that same index should be the maximum\u001b[39;00m\n\u001b[0;32m---> 23\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m att_labels[\u001b[38;5;241m0\u001b[39m, idx] \u001b[38;5;241m==\u001b[39m att_labels[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mmedian()\n",
|
| 145 |
+
"\u001b[0;31mAssertionError\u001b[0m: "
|
| 146 |
+
]
|
| 147 |
+
}
|
| 148 |
+
],
|
| 149 |
+
"source": [
|
| 150 |
+
"num_patches = 10\n",
|
| 151 |
+
"\n",
|
| 152 |
+
"# Distinct data per patch: [0, 1, 2, 3...]\n",
|
| 153 |
+
"data = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()\n",
|
| 154 |
+
"target = torch.tensor([3.0])\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"# Heatmap matches the data indices so we can track the \"label\"\n",
|
| 157 |
+
"heatmap = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"att_labels, shuffled_images = get_attention_scores(data, target, heatmap, mock_args)\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"idx= (shuffled_images[0, :, 0, 0, 0] == 9.0).nonzero(as_tuple=True)[0]\n",
|
| 163 |
+
"# The attention score at that same index should be the maximum\n",
|
| 164 |
+
"assert att_labels[0, idx] == att_labels[0].max()\n",
|
| 165 |
+
"\n",
|
| 166 |
+
"idx = (shuffled_images[0, :, 0, 0, 0] == 0.0).nonzero(as_tuple=True)[0]\n",
|
| 167 |
+
"# The attention score at that same index should be the minimum\n",
|
| 168 |
+
"assert att_labels[0, idx] == att_labels[0].min()\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"idx = (shuffled_images[0, :, 0, 0, 0] == 5.0).nonzero(as_tuple=True)[0]\n",
|
| 171 |
+
"# The attention score at that same index should be the median\n",
|
| 172 |
+
"assert att_labels[0, idx] == att_labels[0].median()"
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"cell_type": "code",
|
| 177 |
+
"execution_count": 2,
|
| 178 |
+
"id": "90f5acab",
|
| 179 |
+
"metadata": {},
|
| 180 |
+
"outputs": [],
|
| 181 |
+
"source": [
|
| 182 |
+
"import subprocess\n",
|
| 183 |
+
"import sys\n",
|
| 184 |
+
"from pathlib import Path\n",
|
| 185 |
+
"import torch\n",
|
| 186 |
+
"import pytest\n",
|
| 187 |
+
"import argparse\n",
|
| 188 |
+
"from src.train.train_pirads import get_attention_scores\n",
|
| 189 |
+
"import monai\n",
|
| 190 |
+
"from monai.transforms import Transform\n",
|
| 191 |
+
"from src.data.custom_transforms import NormalizeIntensity_custom"
|
| 192 |
+
]
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"cell_type": "code",
|
| 196 |
+
"execution_count": 3,
|
| 197 |
+
"id": "e3a2dc6c",
|
| 198 |
+
"metadata": {},
|
| 199 |
+
"outputs": [],
|
| 200 |
+
"source": [
|
| 201 |
+
"img = torch.zeros((2, 4, 4), dtype=torch.float32)\n",
|
| 202 |
+
"mask = torch.zeros((1, 4, 4), dtype=torch.float32)"
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"cell_type": "code",
|
| 207 |
+
"execution_count": 4,
|
| 208 |
+
"id": "98a500df",
|
| 209 |
+
"metadata": {},
|
| 210 |
+
"outputs": [
|
| 211 |
+
{
|
| 212 |
+
"data": {
|
| 213 |
+
"text/plain": [
|
| 214 |
+
"tensor([[[0., 0., 0., 0.],\n",
|
| 215 |
+
" [0., 0., 0., 0.],\n",
|
| 216 |
+
" [0., 0., 0., 0.],\n",
|
| 217 |
+
" [0., 0., 0., 0.]],\n",
|
| 218 |
+
"\n",
|
| 219 |
+
" [[0., 0., 0., 0.],\n",
|
| 220 |
+
" [0., 0., 0., 0.],\n",
|
| 221 |
+
" [0., 0., 0., 0.],\n",
|
| 222 |
+
" [0., 0., 0., 0.]]])"
|
| 223 |
+
]
|
| 224 |
+
},
|
| 225 |
+
"execution_count": 4,
|
| 226 |
+
"metadata": {},
|
| 227 |
+
"output_type": "execute_result"
|
| 228 |
+
}
|
| 229 |
+
],
|
| 230 |
+
"source": [
|
| 231 |
+
"img"
|
| 232 |
+
]
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"cell_type": "code",
|
| 236 |
+
"execution_count": 5,
|
| 237 |
+
"id": "c9974f43",
|
| 238 |
+
"metadata": {},
|
| 239 |
+
"outputs": [],
|
| 240 |
+
"source": [
|
| 241 |
+
"img[0, :, :] = 100.0 # Background\n",
|
| 242 |
+
"img[0, 0, 0] = 10.0 # Masked pixel 1\n",
|
| 243 |
+
"img[0, 0, 1] = 20.0 # Masked pixel 2\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"# --- Channel 1 Setup ---\n",
|
| 246 |
+
"# Inside mask: Values [2, 4]\n",
|
| 247 |
+
"# Outside mask: Value 50\n",
|
| 248 |
+
"img[1, :, :] = 50.0 # Background\n",
|
| 249 |
+
"img[1, 0, 0] = 2.0 # Masked pixel 1\n",
|
| 250 |
+
"img[1, 0, 1] = 4.0 # Masked pixel 2\n",
|
| 251 |
+
"\n",
|
| 252 |
+
"# --- Mask Setup ---\n",
|
| 253 |
+
"# Selects only the top-left two pixels (0,0) and (0,1)\n",
|
| 254 |
+
"mask[0, 0, 0] = 1\n",
|
| 255 |
+
"mask[0, 0, 1] = 1\n"
|
| 256 |
+
]
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"cell_type": "code",
|
| 260 |
+
"execution_count": null,
|
| 261 |
+
"id": "eb910fda",
|
| 262 |
+
"metadata": {},
|
| 263 |
+
"outputs": [],
|
| 264 |
+
"source": [
|
| 265 |
+
"data = torch.rand(1, 10, 10)\n",
|
| 266 |
+
"mask = torch.randint(0, 2, (1, 10, 10)).float()\n",
|
| 267 |
+
"normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True)\n",
|
| 268 |
+
"out = normalizer(data, mask)"
|
| 269 |
+
]
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"cell_type": "code",
|
| 273 |
+
"execution_count": 25,
|
| 274 |
+
"id": "923341a3",
|
| 275 |
+
"metadata": {},
|
| 276 |
+
"outputs": [],
|
| 277 |
+
"source": [
|
| 278 |
+
"masked = data[mask != 0]\n",
|
| 279 |
+
"mean_ = torch.mean(masked.float())\n",
|
| 280 |
+
"std_ = torch.std(masked.float(), unbiased=False)\n",
|
| 281 |
+
"\n",
|
| 282 |
+
"epsilon = 1e-8\n",
|
| 283 |
+
"normalized_data = (data - mean_) / (std_ + epsilon)"
|
| 284 |
+
]
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"cell_type": "code",
|
| 288 |
+
"execution_count": 23,
|
| 289 |
+
"id": "e844cde1",
|
| 290 |
+
"metadata": {},
|
| 291 |
+
"outputs": [
|
| 292 |
+
{
|
| 293 |
+
"data": {
|
| 294 |
+
"text/plain": [
|
| 295 |
+
"tensor([ 1.4106, -0.1975, 0.3907, 1.2870, -0.7974, -1.2061, 0.7028, 1.2778,\n",
|
| 296 |
+
" 0.4667, -0.3361, -0.7842, -1.6296, -1.2037, 1.3582, -0.5648, -0.3055,\n",
|
| 297 |
+
" -0.3313, 0.0328, -1.0675, 0.6328, -0.2215, -1.3372, 0.5165, 1.9302,\n",
|
| 298 |
+
" 0.8875, 0.6793, 0.5553, 0.4335, 0.6390, -1.3707, 1.6053, 1.8626,\n",
|
| 299 |
+
" -0.3923, 0.2319, 0.3911, -0.4683, -1.1255, -1.6464, -0.2123, -0.5415,\n",
|
| 300 |
+
" 0.1401, -0.2822, 1.5019, -0.5117, -1.6047, -0.2322, -1.3080, 0.0130,\n",
|
| 301 |
+
" 1.8028, 0.5602, -1.6317])"
|
| 302 |
+
]
|
| 303 |
+
},
|
| 304 |
+
"execution_count": 23,
|
| 305 |
+
"metadata": {},
|
| 306 |
+
"output_type": "execute_result"
|
| 307 |
+
}
|
| 308 |
+
],
|
| 309 |
+
"source": [
|
| 310 |
+
"masked"
|
| 311 |
+
]
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"cell_type": "code",
|
| 315 |
+
"execution_count": 26,
|
| 316 |
+
"id": "a9a20f58",
|
| 317 |
+
"metadata": {},
|
| 318 |
+
"outputs": [],
|
| 319 |
+
"source": [
|
| 320 |
+
"torch.testing.assert_close(out, normalized_data)"
|
| 321 |
+
]
|
| 322 |
+
},
|
| 323 |
{
|
| 324 |
"cell_type": "code",
|
| 325 |
"execution_count": 2,
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_run.py
CHANGED
|
@@ -1,139 +1,169 @@
|
|
| 1 |
-
import
|
| 2 |
-
import sys
|
| 3 |
-
from pathlib import Path
|
| 4 |
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
# Path to your run_pirads.py script
|
| 12 |
-
repo_root = Path(__file__).parent.parent
|
| 13 |
-
script_path = repo_root / "run_pirads.py"
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
# Make sure the file exists
|
| 19 |
-
assert config_path.exists(), f"Config file not found: {config_path}"
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
str(script_path),
|
| 26 |
-
"--mode",
|
| 27 |
-
"train",
|
| 28 |
-
"--config",
|
| 29 |
-
str(config_path),
|
| 30 |
-
"--dry_run",
|
| 31 |
-
],
|
| 32 |
-
capture_output=True,
|
| 33 |
-
text=True,
|
| 34 |
-
)
|
| 35 |
|
| 36 |
-
#
|
| 37 |
-
|
|
|
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
def test_run_pirads_inference():
|
| 41 |
-
"""
|
| 42 |
-
Test that run_cspca.py runs without crashing using an existing YAML config.
|
| 43 |
-
"""
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
| 49 |
-
#
|
| 50 |
-
|
|
|
|
| 51 |
|
| 52 |
-
#
|
| 53 |
-
|
| 54 |
|
| 55 |
-
|
| 56 |
-
result = subprocess.run(
|
| 57 |
-
[
|
| 58 |
-
sys.executable,
|
| 59 |
-
str(script_path),
|
| 60 |
-
"--mode",
|
| 61 |
-
"test",
|
| 62 |
-
"--config",
|
| 63 |
-
str(config_path),
|
| 64 |
-
"--dry_run",
|
| 65 |
-
],
|
| 66 |
-
capture_output=True,
|
| 67 |
-
text=True,
|
| 68 |
-
)
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
|
|
|
| 72 |
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
"""
|
| 76 |
-
Test that
|
|
|
|
| 77 |
"""
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
script_path = repo_root / "run_cspca.py"
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
|
|
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
[
|
| 92 |
-
sys.executable,
|
| 93 |
-
str(script_path),
|
| 94 |
-
"--mode",
|
| 95 |
-
"train",
|
| 96 |
-
"--config",
|
| 97 |
-
str(config_path),
|
| 98 |
-
"--dry_run",
|
| 99 |
-
],
|
| 100 |
-
capture_output=True,
|
| 101 |
-
text=True,
|
| 102 |
-
)
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
"""
|
| 110 |
-
Test
|
|
|
|
| 111 |
"""
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
import pytest
|
| 4 |
+
import torch
|
| 5 |
|
| 6 |
+
from src.data.custom_transforms import NormalizeIntensity_custom
|
| 7 |
+
from src.data.data_loader import get_dataloader
|
| 8 |
+
from src.model.cspca_model import CSPCAModel
|
| 9 |
+
from src.model.mil import MILModel3D
|
| 10 |
+
from src.train import train_cspca, train_pirads
|
| 11 |
+
from src.train.train_pirads import get_attention_scores
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
@pytest.fixture
|
| 15 |
+
def mock_args():
|
| 16 |
+
# Mocking argparse for the device
|
| 17 |
+
args = argparse.Namespace()
|
| 18 |
+
args.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
+
return args
|
| 20 |
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
def test_get_attention_scores_logic(mock_args):
|
| 23 |
+
# Setup: 2 samples, 4 patches, images of size 8x8
|
| 24 |
+
batch_size = 2
|
| 25 |
+
num_patches = 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
# Sample 0: Target = 3 (Cancer), Sample 1: Target = 0 (PI-RADS 2)
|
| 28 |
+
data = torch.randn(batch_size, num_patches, 1, 8, 8)
|
| 29 |
+
target = torch.tensor([3.0, 0.0])
|
| 30 |
|
| 31 |
+
# Create heatmaps: Sample 0 has one "hot" patch
|
| 32 |
+
heatmap = torch.zeros(batch_size, num_patches, 1, 8, 8)
|
| 33 |
+
heatmap[0, 0] = 10.0 # High attention on patch 0 for the first sample
|
| 34 |
+
heatmap[1, :] = 5.0 # Should be overridden by PI-RADS 2 logic anyway
|
| 35 |
+
|
| 36 |
+
att_labels, shuffled_images = get_attention_scores(data, target, heatmap, mock_args)
|
| 37 |
+
|
| 38 |
+
# --- TEST 1: Normalization ---
|
| 39 |
+
sums = att_labels.sum(dim=1)
|
| 40 |
+
torch.testing.assert_close(sums, torch.ones(batch_size).to(mock_args.device))
|
| 41 |
+
|
| 42 |
+
# --- TEST 2: PI-RADS 2 Uniformity ---
|
| 43 |
+
pirads_2_scores = att_labels[1]
|
| 44 |
+
expected_uniform = torch.ones(num_patches).to(mock_args.device) / num_patches
|
| 45 |
+
torch.testing.assert_close(pirads_2_scores, expected_uniform)
|
| 46 |
+
|
| 47 |
+
# --- TEST 3: Output Shapes ---
|
| 48 |
+
assert att_labels.shape == (batch_size, num_patches)
|
| 49 |
+
assert shuffled_images.shape == data.shape
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
+
def test_shuffling_consistency(mock_args):
|
| 53 |
+
# Verify that the image and label are shuffled with the SAME permutation
|
| 54 |
+
num_patches = 10
|
| 55 |
|
| 56 |
+
# Distinct data per patch: [0, 1, 2, 3...]
|
| 57 |
+
data = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()
|
| 58 |
+
target = torch.tensor([3.0])
|
| 59 |
|
| 60 |
+
# Heatmap matches the data indices so we can track the "label"
|
| 61 |
+
heatmap = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()
|
| 62 |
|
| 63 |
+
att_labels, shuffled_images = get_attention_scores(data, target, heatmap, mock_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
idx = (shuffled_images[0, :, 0, 0, 0] == 9.0).nonzero(as_tuple=True)[0]
|
| 66 |
+
# The attention score at that same index should be the maximum
|
| 67 |
+
assert att_labels[0, idx] == att_labels[0].max()
|
| 68 |
|
| 69 |
+
idx = (shuffled_images[0, :, 0, 0, 0] == 0.0).nonzero(as_tuple=True)[0]
|
| 70 |
+
# The attention score at that same index should be the minimum
|
| 71 |
+
assert att_labels[0, idx] == att_labels[0].min()
|
| 72 |
|
| 73 |
+
shuffled_images = shuffled_images.cpu().squeeze() # Shape [10]
|
| 74 |
+
att_labels = att_labels.cpu().squeeze() # Shape [10]
|
| 75 |
+
|
| 76 |
+
sorted_vals, original_indices = torch.sort(shuffled_images)
|
| 77 |
+
sorted_labels = att_labels[original_indices]
|
| 78 |
+
|
| 79 |
+
for i in range(len(sorted_labels) - 1):
|
| 80 |
+
assert sorted_labels[i] <= sorted_labels[i + 1], (
|
| 81 |
+
f"Alignment broken at index {i}: Image val {sorted_vals[i]} has higher label than {sorted_vals[i + 1]}"
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_normalize_intensity_custom_masked_stats():
|
| 86 |
"""
|
| 87 |
+
Test that statistics (mean/std) are calculated ONLY from the masked region,
|
| 88 |
+
but applied to the whole image.
|
| 89 |
"""
|
| 90 |
|
| 91 |
+
img = torch.zeros((2, 4, 4), dtype=torch.float32)
|
| 92 |
+
mask = torch.zeros((1, 4, 4), dtype=torch.float32)
|
|
|
|
| 93 |
|
| 94 |
+
img[0, :, :] = 100.0
|
| 95 |
+
img[0, 0, 0] = 10.0
|
| 96 |
+
img[0, 0, 1] = 20.0
|
| 97 |
|
| 98 |
+
img[1, :, :] = 50.0
|
| 99 |
+
img[1, 0, 0] = 2.0
|
| 100 |
+
img[1, 0, 1] = 4.0
|
| 101 |
|
| 102 |
+
mask[0, 0, 0] = 1
|
| 103 |
+
mask[0, 0, 1] = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
+
normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True)
|
| 106 |
+
out = normalizer(img, mask)
|
| 107 |
|
| 108 |
+
assert torch.isclose(out[0, 0, 0], torch.tensor(-1.0)), "Ch0 masked value 1 incorrect"
|
| 109 |
+
assert torch.isclose(out[0, 0, 1], torch.tensor(1.0)), "Ch0 masked value 2 incorrect"
|
| 110 |
+
assert torch.isclose(out[0, 1, 1], torch.tensor(17.0)), "Ch0 background normalization incorrect"
|
| 111 |
|
| 112 |
+
assert torch.isclose(out[1, 0, 0], torch.tensor(-1.0)), "Ch1 masked value 1 incorrect"
|
| 113 |
+
assert torch.isclose(out[1, 1, 1], torch.tensor(47.0)), "Ch1 background normalization incorrect"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_normalize_intensity_constant_area():
|
| 117 |
"""
|
| 118 |
+
Test edge case where the area under the mask has 0 variance (constant value).
|
| 119 |
+
Std should default to 1.0 to avoid division by zero.
|
| 120 |
"""
|
| 121 |
+
img = torch.ones((1, 4, 4)) * 10.0 # All values are 10
|
| 122 |
+
mask = torch.ones((1, 4, 4))
|
| 123 |
+
|
| 124 |
+
normalizer = NormalizeIntensity_custom(channel_wise=True)
|
| 125 |
+
out = normalizer(img, mask)
|
| 126 |
+
assert torch.allclose(out, torch.zeros_like(out))
|
| 127 |
+
|
| 128 |
+
data = torch.rand(1, 10, 10)
|
| 129 |
+
mask = torch.randint(0, 2, (1, 10, 10)).float()
|
| 130 |
+
normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True)
|
| 131 |
+
out = normalizer(data, mask)
|
| 132 |
+
|
| 133 |
+
masked = data[mask != 0]
|
| 134 |
+
mean_val = torch.mean(masked.float())
|
| 135 |
+
std_val = torch.std(masked.float(), unbiased=False)
|
| 136 |
+
|
| 137 |
+
epsilon = 1e-8
|
| 138 |
+
normalized_data = (data - mean_val) / (std_val + epsilon)
|
| 139 |
+
|
| 140 |
+
torch.testing.assert_close(out, normalized_data)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def test_run_models():
|
| 144 |
+
args = argparse.Namespace()
|
| 145 |
+
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 146 |
+
args.epochs = 1
|
| 147 |
+
args.batch_size = 2
|
| 148 |
+
args.tile_size = 10
|
| 149 |
+
args.tile_count = 5
|
| 150 |
+
args.use_heatmap = True
|
| 151 |
+
args.amp = False
|
| 152 |
+
args.num_classes = 4
|
| 153 |
+
args.dry_run = True
|
| 154 |
+
args.depth = 3
|
| 155 |
+
|
| 156 |
+
model = MILModel3D(num_classes=args.num_classes, mil_mode="att_trans")
|
| 157 |
+
model.to(args.device)
|
| 158 |
+
params = model.parameters()
|
| 159 |
+
loader = get_dataloader(args, split="train")
|
| 160 |
+
optimizer = torch.optim.AdamW(params, lr=1e-5, weight_decay=1e-5)
|
| 161 |
+
scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp)
|
| 162 |
+
|
| 163 |
+
_ = train_pirads.train_epoch(model, loader, optimizer, scaler=scaler, epoch=0, args=args)
|
| 164 |
+
_ = train_pirads.val_epoch(model, loader, epoch=0, args=args)
|
| 165 |
+
|
| 166 |
+
cspca_model = CSPCAModel(backbone=model).to(args.device)
|
| 167 |
+
optimizer_cspca = torch.optim.AdamW(cspca_model.parameters(), lr=1e-5)
|
| 168 |
+
_ = train_cspca.train_epoch(cspca_model, loader, optimizer_cspca, epoch=0, args=args)
|
| 169 |
+
_ = train_cspca.val_epoch(cspca_model, loader, epoch=0, args=args)
|