Anirudh Balaraman committed on
Commit
95dc457
·
1 Parent(s): 6f43d62

fix pytest

Browse files
Files changed (6) hide show
  1. Makefile +2 -5
  2. run_cspca.py +0 -3
  3. src/data/data_loader.py +8 -6
  4. temp.ipynb +243 -0
  5. tests/__init__.py +0 -0
  6. tests/test_run.py +142 -112
Makefile CHANGED
@@ -20,11 +20,8 @@ clean:
20
  @python3 -Bc "import shutil, pathlib; \
21
  [shutil.rmtree(p) for p in pathlib.Path('.').rglob('__pycache__')]; \
22
  [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ipynb_checkpoints')]; \
23
- [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.monai-cache')]; \
24
- [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.mypy_cache')]; \
25
- [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ruff_cache')]; \
26
- [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.pytest_cache')]"
27
 
28
  # Updated 'check' to clean before running (optional)
29
  # This ensures you are testing from a "blank slate"
30
- check: format lint typecheck clean
 
20
  @python3 -Bc "import shutil, pathlib; \
21
  [shutil.rmtree(p) for p in pathlib.Path('.').rglob('__pycache__')]; \
22
  [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ipynb_checkpoints')]; \
23
+ [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.monai-cache')];"
 
 
 
24
 
25
  # Updated 'check' to clean before running (optional)
26
  # This ensures you are testing from a "blank slate"
27
+ check: format lint typecheck test clean
run_cspca.py CHANGED
@@ -21,7 +21,6 @@ def main_worker(args):
21
  cache_dir_path = Path(os.path.join(args.logdir, "cache"))
22
 
23
  if args.mode == "train":
24
-
25
  checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
26
  mil_model.load_state_dict(checkpoint["state_dict"])
27
  mil_model = mil_model.to(args.device)
@@ -64,7 +63,6 @@ def main_worker(args):
64
  if cache_dir_path.exists() and cache_dir_path.is_dir():
65
  shutil.rmtree(cache_dir_path)
66
 
67
-
68
  cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
69
  checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
70
  cspca_model.load_state_dict(checkpt["state_dict"])
@@ -92,7 +90,6 @@ def main_worker(args):
92
  get_metrics(metrics_dict)
93
 
94
 
95
-
96
  def parse_args():
97
  parser = argparse.ArgumentParser(
98
  description="Multiple Instance Learning (MIL) for csPCa risk prediction."
 
21
  cache_dir_path = Path(os.path.join(args.logdir, "cache"))
22
 
23
  if args.mode == "train":
 
24
  checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
25
  mil_model.load_state_dict(checkpoint["state_dict"])
26
  mil_model = mil_model.to(args.device)
 
63
  if cache_dir_path.exists() and cache_dir_path.is_dir():
64
  shutil.rmtree(cache_dir_path)
65
 
 
66
  cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
67
  checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
68
  cspca_model.load_state_dict(checkpt["state_dict"])
 
90
  get_metrics(metrics_dict)
91
 
92
 
 
93
  def parse_args():
94
  parser = argparse.ArgumentParser(
95
  description="Multiple Instance Learning (MIL) for csPCa risk prediction."
src/data/data_loader.py CHANGED
@@ -26,6 +26,7 @@ from .custom_transforms import (
26
  NormalizeIntensity_customd,
27
  )
28
 
 
29
  class DummyMILDataset(torch.utils.data.Dataset):
30
  def __init__(self, args, num_samples=8):
31
  self.num_samples = num_samples
@@ -43,13 +44,16 @@ class DummyMILDataset(torch.utils.data.Dataset):
43
  item = {
44
  # Shape: (Channels=3, Depth, H, W) based on your Transposed(indices=(0, 3, 1, 2))
45
  "image": torch.randn(3, self.args.depth, self.args.tile_size, self.args.tile_size),
46
- "label": torch.tensor(label_value, dtype=torch.float32)
47
  }
48
  if self.args.use_heatmap:
49
- item["final_heatmap"] = torch.randn(1, self.args.depth, self.args.tile_size, self.args.tile_size)
 
 
50
  bag.append(item)
51
  return bag
52
 
 
53
  def list_data_collate(batch: list):
54
  """
55
  Combine instances from a list of dicts into a single dict, by stacking them along first dim
@@ -130,18 +134,16 @@ def data_transform(args: argparse.Namespace) -> Transform:
130
  def get_dataloader(
131
  args: argparse.Namespace, split: Literal["train", "test"]
132
  ) -> torch.utils.data.DataLoader:
133
-
134
  if args.dry_run:
135
  print(f"🛠️ DRY RUN: Creating synthetic {split} dataloader...")
136
  dummy_ds = DummyMILDataset(args, num_samples=args.batch_size * 2)
137
  return torch.utils.data.DataLoader(
138
  dummy_ds,
139
  batch_size=args.batch_size,
140
- collate_fn=list_data_collate, # Uses your custom stacking logic
141
- num_workers=0 # Keep it simple for dry run
142
  )
143
 
144
-
145
  data_list = load_decathlon_datalist(
146
  data_list_file_path=args.dataset_json,
147
  data_list_key=split,
 
26
  NormalizeIntensity_customd,
27
  )
28
 
29
+
30
  class DummyMILDataset(torch.utils.data.Dataset):
31
  def __init__(self, args, num_samples=8):
32
  self.num_samples = num_samples
 
44
  item = {
45
  # Shape: (Channels=3, Depth, H, W) based on your Transposed(indices=(0, 3, 1, 2))
46
  "image": torch.randn(3, self.args.depth, self.args.tile_size, self.args.tile_size),
47
+ "label": torch.tensor(label_value, dtype=torch.float32),
48
  }
49
  if self.args.use_heatmap:
50
+ item["final_heatmap"] = torch.randn(
51
+ 1, self.args.depth, self.args.tile_size, self.args.tile_size
52
+ )
53
  bag.append(item)
54
  return bag
55
 
56
+
57
  def list_data_collate(batch: list):
58
  """
59
  Combine instances from a list of dicts into a single dict, by stacking them along first dim
 
134
  def get_dataloader(
135
  args: argparse.Namespace, split: Literal["train", "test"]
136
  ) -> torch.utils.data.DataLoader:
 
137
  if args.dry_run:
138
  print(f"🛠️ DRY RUN: Creating synthetic {split} dataloader...")
139
  dummy_ds = DummyMILDataset(args, num_samples=args.batch_size * 2)
140
  return torch.utils.data.DataLoader(
141
  dummy_ds,
142
  batch_size=args.batch_size,
143
+ collate_fn=list_data_collate, # Uses your custom stacking logic
144
+ num_workers=0, # Keep it simple for dry run
145
  )
146
 
 
147
  data_list = load_decathlon_datalist(
148
  data_list_file_path=args.dataset_json,
149
  data_list_key=split,
temp.ipynb CHANGED
@@ -77,6 +77,249 @@
77
  "from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset"
78
  ]
79
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  {
81
  "cell_type": "code",
82
  "execution_count": 2,
 
77
  "from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset"
78
  ]
79
  },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 1,
83
+ "id": "bc433898",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "import subprocess\n",
88
+ "import sys\n",
89
+ "from pathlib import Path\n",
90
+ "import torch\n",
91
+ "import pytest\n",
92
+ "import argparse\n",
93
+ "from src.train.train_pirads import get_attention_scores\n"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 8,
99
+ "id": "f1c90aff",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "batch_size = 2\n",
104
+ "num_patches = 4\n",
105
+ "\n",
106
+ "# Sample 0: Target = 3 (Cancer), Sample 1: Target = 0 (PI-RADS 2)\n",
107
+ "data = torch.randn(batch_size, num_patches, 1, 8, 8)\n",
108
+ "target = torch.tensor([3.0, 0.0])\n",
109
+ "\n",
110
+ "# Create heatmaps: Sample 0 has one \"hot\" patch\n",
111
+ "heatmap = torch.zeros(batch_size, num_patches, 1, 8, 8)\n",
112
+ "heatmap[0, 0] = 10.0 # High attention on patch 0 for the first sample\n",
113
+ "heatmap[0, 3] = 2.0 \n",
114
+ "heatmap[1, 2] = 5.0 # Should be overridden by PI-RADS 2 logic anyway\n"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 25,
120
+ "id": "80cb444f",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "def mock_args():\n",
125
+ " # Mocking argparse for the device\n",
126
+ " args = argparse.Namespace()\n",
127
+ " args.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
128
+ " return args"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 41,
134
+ "id": "6528fd4d",
135
+ "metadata": {},
136
+ "outputs": [
137
+ {
138
+ "ename": "AssertionError",
139
+ "evalue": "",
140
+ "output_type": "error",
141
+ "traceback": [
142
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
143
+ "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
144
+ "Cell \u001b[0;32mIn[41], line 23\u001b[0m\n\u001b[1;32m 21\u001b[0m idx \u001b[38;5;241m=\u001b[39m (shuffled_images[\u001b[38;5;241m0\u001b[39m, :, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m5.0\u001b[39m)\u001b[38;5;241m.\u001b[39mnonzero(as_tuple\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# The attention score at that same index should be the maximum\u001b[39;00m\n\u001b[0;32m---> 23\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m att_labels[\u001b[38;5;241m0\u001b[39m, idx] \u001b[38;5;241m==\u001b[39m att_labels[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mmedian()\n",
145
+ "\u001b[0;31mAssertionError\u001b[0m: "
146
+ ]
147
+ }
148
+ ],
149
+ "source": [
150
+ "num_patches = 10\n",
151
+ "\n",
152
+ "# Distinct data per patch: [0, 1, 2, 3...]\n",
153
+ "data = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()\n",
154
+ "target = torch.tensor([3.0])\n",
155
+ "\n",
156
+ "# Heatmap matches the data indices so we can track the \"label\"\n",
157
+ "heatmap = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()\n",
158
+ "\n",
159
+ "att_labels, shuffled_images = get_attention_scores(data, target, heatmap, mock_args)\n",
160
+ "\n",
161
+ "\n",
162
+ "idx= (shuffled_images[0, :, 0, 0, 0] == 9.0).nonzero(as_tuple=True)[0]\n",
163
+ "# The attention score at that same index should be the maximum\n",
164
+ "assert att_labels[0, idx] == att_labels[0].max()\n",
165
+ "\n",
166
+ "idx = (shuffled_images[0, :, 0, 0, 0] == 0.0).nonzero(as_tuple=True)[0]\n",
167
+ "# The attention score at that same index should be the maximum\n",
168
+ "assert att_labels[0, idx] == att_labels[0].min()\n",
169
+ "\n",
170
+ "idx = (shuffled_images[0, :, 0, 0, 0] == 5.0).nonzero(as_tuple=True)[0]\n",
171
+ "# The attention score at that same index should be the maximum\n",
172
+ "assert att_labels[0, idx] == att_labels[0].median()"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 2,
178
+ "id": "90f5acab",
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "import subprocess\n",
183
+ "import sys\n",
184
+ "from pathlib import Path\n",
185
+ "import torch\n",
186
+ "import pytest\n",
187
+ "import argparse\n",
188
+ "from src.train.train_pirads import get_attention_scores\n",
189
+ "import monai\n",
190
+ "from monai.transforms import Transform\n",
191
+ "from src.data.custom_transforms import NormalizeIntensity_custom"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": 3,
197
+ "id": "e3a2dc6c",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "img = torch.zeros((2, 4, 4), dtype=torch.float32)\n",
202
+ "mask = torch.zeros((1, 4, 4), dtype=torch.float32)"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": 4,
208
+ "id": "98a500df",
209
+ "metadata": {},
210
+ "outputs": [
211
+ {
212
+ "data": {
213
+ "text/plain": [
214
+ "tensor([[[0., 0., 0., 0.],\n",
215
+ " [0., 0., 0., 0.],\n",
216
+ " [0., 0., 0., 0.],\n",
217
+ " [0., 0., 0., 0.]],\n",
218
+ "\n",
219
+ " [[0., 0., 0., 0.],\n",
220
+ " [0., 0., 0., 0.],\n",
221
+ " [0., 0., 0., 0.],\n",
222
+ " [0., 0., 0., 0.]]])"
223
+ ]
224
+ },
225
+ "execution_count": 4,
226
+ "metadata": {},
227
+ "output_type": "execute_result"
228
+ }
229
+ ],
230
+ "source": [
231
+ "img"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": 5,
237
+ "id": "c9974f43",
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": [
241
+ "img[0, :, :] = 100.0 # Background\n",
242
+ "img[0, 0, 0] = 10.0 # Masked pixel 1\n",
243
+ "img[0, 0, 1] = 20.0 # Masked pixel 2\n",
244
+ "\n",
245
+ "# --- Channel 1 Setup ---\n",
246
+ "# Inside mask: Values [2, 4]\n",
247
+ "# Outside mask: Value 50\n",
248
+ "img[1, :, :] = 50.0 # Background\n",
249
+ "img[1, 0, 0] = 2.0 # Masked pixel 1\n",
250
+ "img[1, 0, 1] = 4.0 # Masked pixel 2\n",
251
+ "\n",
252
+ "# --- Mask Setup ---\n",
253
+ "# Selects only the top-left two pixels (0,0) and (0,1)\n",
254
+ "mask[0, 0, 0] = 1\n",
255
+ "mask[0, 0, 1] = 1\n"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "id": "eb910fda",
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "data = torch.rand(1, 10, 10)\n",
266
+ "mask = torch.randint(0, 2, (1, 10, 10)).float()\n",
267
+ "normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True)\n",
268
+ "out = normalizer(data, mask)"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": 25,
274
+ "id": "923341a3",
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": [
278
+ "masked = data[mask != 0]\n",
279
+ "mean_ = torch.mean(masked.float())\n",
280
+ "std_ = torch.std(masked.float(), unbiased=False)\n",
281
+ "\n",
282
+ "epsilon = 1e-8\n",
283
+ "normalized_data = (data - mean_) / (std_ + epsilon)"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": 23,
289
+ "id": "e844cde1",
290
+ "metadata": {},
291
+ "outputs": [
292
+ {
293
+ "data": {
294
+ "text/plain": [
295
+ "tensor([ 1.4106, -0.1975, 0.3907, 1.2870, -0.7974, -1.2061, 0.7028, 1.2778,\n",
296
+ " 0.4667, -0.3361, -0.7842, -1.6296, -1.2037, 1.3582, -0.5648, -0.3055,\n",
297
+ " -0.3313, 0.0328, -1.0675, 0.6328, -0.2215, -1.3372, 0.5165, 1.9302,\n",
298
+ " 0.8875, 0.6793, 0.5553, 0.4335, 0.6390, -1.3707, 1.6053, 1.8626,\n",
299
+ " -0.3923, 0.2319, 0.3911, -0.4683, -1.1255, -1.6464, -0.2123, -0.5415,\n",
300
+ " 0.1401, -0.2822, 1.5019, -0.5117, -1.6047, -0.2322, -1.3080, 0.0130,\n",
301
+ " 1.8028, 0.5602, -1.6317])"
302
+ ]
303
+ },
304
+ "execution_count": 23,
305
+ "metadata": {},
306
+ "output_type": "execute_result"
307
+ }
308
+ ],
309
+ "source": [
310
+ "masked"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 26,
316
+ "id": "a9a20f58",
317
+ "metadata": {},
318
+ "outputs": [],
319
+ "source": [
320
+ "torch.testing.assert_close(out, normalized_data)"
321
+ ]
322
+ },
323
  {
324
  "cell_type": "code",
325
  "execution_count": 2,
tests/__init__.py ADDED
File without changes
tests/test_run.py CHANGED
@@ -1,139 +1,169 @@
1
- import subprocess
2
- import sys
3
- from pathlib import Path
4
 
 
 
5
 
6
- def test_run_pirads_training():
7
- """
8
- Test that run_cspca.py runs without crashing using an existing YAML config.
9
- """
 
 
10
 
11
- # Path to your run_pirads.py script
12
- repo_root = Path(__file__).parent.parent
13
- script_path = repo_root / "run_pirads.py"
14
 
15
- # Path to your existing config.yaml
16
- config_path = repo_root / "config" / "config_pirads_train.yaml" # adjust this path
 
 
 
 
17
 
18
- # Make sure the file exists
19
- assert config_path.exists(), f"Config file not found: {config_path}"
20
 
21
- # Run the script with the config
22
- result = subprocess.run(
23
- [
24
- sys.executable,
25
- str(script_path),
26
- "--mode",
27
- "train",
28
- "--config",
29
- str(config_path),
30
- "--dry_run",
31
- ],
32
- capture_output=True,
33
- text=True,
34
- )
35
 
36
- # Check that it ran without errors
37
- assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def test_run_pirads_inference():
41
- """
42
- Test that run_cspca.py runs without crashing using an existing YAML config.
43
- """
44
 
45
- # Path to your run_pirads.py script
46
- repo_root = Path(__file__).parent.parent
47
- script_path = repo_root / "run_pirads.py"
48
 
49
- # Path to your existing config.yaml
50
- config_path = repo_root / "config" / "config_pirads_test.yaml" # adjust this path
 
51
 
52
- # Make sure the file exists
53
- assert config_path.exists(), f"Config file not found: {config_path}"
54
 
55
- # Run the script with the config
56
- result = subprocess.run(
57
- [
58
- sys.executable,
59
- str(script_path),
60
- "--mode",
61
- "test",
62
- "--config",
63
- str(config_path),
64
- "--dry_run",
65
- ],
66
- capture_output=True,
67
- text=True,
68
- )
69
 
70
- # Check that it ran without errors
71
- assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
 
72
 
 
 
 
73
 
74
- def test_run_cspca_training():
 
 
 
 
 
 
 
 
 
 
 
 
75
  """
76
- Test that run_cspca.py runs without crashing using an existing YAML config.
 
77
  """
78
 
79
- # Path to your run_cspca.py script
80
- repo_root = Path(__file__).parent.parent
81
- script_path = repo_root / "run_cspca.py"
82
 
83
- # Path to your existing config.yaml
84
- config_path = repo_root / "config" / "config_cspca_train.yaml" # adjust this path
 
85
 
86
- # Make sure the file exists
87
- assert config_path.exists(), f"Config file not found: {config_path}"
 
88
 
89
- # Run the script with the config
90
- result = subprocess.run(
91
- [
92
- sys.executable,
93
- str(script_path),
94
- "--mode",
95
- "train",
96
- "--config",
97
- str(config_path),
98
- "--dry_run",
99
- ],
100
- capture_output=True,
101
- text=True,
102
- )
103
 
104
- # Check that it ran without errors
105
- assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
106
 
 
 
 
107
 
108
- def test_run_cspca_inference():
 
 
 
 
109
  """
110
- Test that run_cspca.py runs without crashing using an existing YAML config.
 
111
  """
112
-
113
- # Path to your run_cspca.py script
114
- repo_root = Path(__file__).parent.parent
115
- script_path = repo_root / "run_cspca.py"
116
-
117
- # Path to your existing config.yaml
118
- config_path = repo_root / "config" / "config_cspca_test.yaml" # adjust this path
119
-
120
- # Make sure the file exists
121
- assert config_path.exists(), f"Config file not found: {config_path}"
122
-
123
- # Run the script with the config
124
- result = subprocess.run(
125
- [
126
- sys.executable,
127
- str(script_path),
128
- "--mode",
129
- "test",
130
- "--config",
131
- str(config_path),
132
- "--dry_run",
133
- ],
134
- capture_output=True,
135
- text=True,
136
- )
137
-
138
- # Check that it ran without errors
139
- assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
 
 
2
 
3
+ import pytest
4
+ import torch
5
 
6
+ from src.data.custom_transforms import NormalizeIntensity_custom
7
+ from src.data.data_loader import get_dataloader
8
+ from src.model.cspca_model import CSPCAModel
9
+ from src.model.mil import MILModel3D
10
+ from src.train import train_cspca, train_pirads
11
+ from src.train.train_pirads import get_attention_scores
12
 
 
 
 
13
 
14
+ @pytest.fixture
15
+ def mock_args():
16
+ # Mocking argparse for the device
17
+ args = argparse.Namespace()
18
+ args.device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ return args
20
 
 
 
21
 
22
+ def test_get_attention_scores_logic(mock_args):
23
+ # Setup: 2 samples, 4 patches, images of size 8x8
24
+ batch_size = 2
25
+ num_patches = 4
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # Sample 0: Target = 3 (Cancer), Sample 1: Target = 0 (PI-RADS 2)
28
+ data = torch.randn(batch_size, num_patches, 1, 8, 8)
29
+ target = torch.tensor([3.0, 0.0])
30
 
31
+ # Create heatmaps: Sample 0 has one "hot" patch
32
+ heatmap = torch.zeros(batch_size, num_patches, 1, 8, 8)
33
+ heatmap[0, 0] = 10.0 # High attention on patch 0 for the first sample
34
+ heatmap[1, :] = 5.0 # Should be overridden by PI-RADS 2 logic anyway
35
+
36
+ att_labels, shuffled_images = get_attention_scores(data, target, heatmap, mock_args)
37
+
38
+ # --- TEST 1: Normalization ---
39
+ sums = att_labels.sum(dim=1)
40
+ torch.testing.assert_close(sums, torch.ones(batch_size).to(mock_args.device))
41
+
42
+ # --- TEST 2: PI-RADS 2 Uniformity ---
43
+ pirads_2_scores = att_labels[1]
44
+ expected_uniform = torch.ones(num_patches).to(mock_args.device) / num_patches
45
+ torch.testing.assert_close(pirads_2_scores, expected_uniform)
46
+
47
+ # --- TEST 4: Output Shapes ---
48
+ assert att_labels.shape == (batch_size, num_patches)
49
+ assert shuffled_images.shape == data.shape
50
 
 
 
 
 
51
 
52
+ def test_shuffling_consistency(mock_args):
53
+ # Verify that the image and label are shuffled with the SAME permutation
54
+ num_patches = 10
55
 
56
+ # Distinct data per patch: [0, 1, 2, 3...]
57
+ data = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()
58
+ target = torch.tensor([3.0])
59
 
60
+ # Heatmap matches the data indices so we can track the "label"
61
+ heatmap = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()
62
 
63
+ att_labels, shuffled_images = get_attention_scores(data, target, heatmap, mock_args)
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ idx = (shuffled_images[0, :, 0, 0, 0] == 9.0).nonzero(as_tuple=True)[0]
66
+ # The attention score at that same index should be the maximum
67
+ assert att_labels[0, idx] == att_labels[0].max()
68
 
69
+ idx = (shuffled_images[0, :, 0, 0, 0] == 0.0).nonzero(as_tuple=True)[0]
70
+ # The attention score at that same index should be the minimum
71
+ assert att_labels[0, idx] == att_labels[0].min()
72
 
73
+ shuffled_images = shuffled_images.cpu().squeeze() # Shape [10]
74
+ att_labels = att_labels.cpu().squeeze() # Shape [10]
75
+
76
+ sorted_vals, original_indices = torch.sort(shuffled_images)
77
+ sorted_labels = att_labels[original_indices]
78
+
79
+ for i in range(len(sorted_labels) - 1):
80
+ assert sorted_labels[i] <= sorted_labels[i + 1], (
81
+ f"Alignment broken at index {i}: Image val {sorted_vals[i]} has higher label than {sorted_vals[i + 1]}"
82
+ )
83
+
84
+
85
+ def test_normalize_intensity_custom_masked_stats():
86
  """
87
+ Test that statistics (mean/std) are calculated ONLY from the masked region,
88
+ but applied to the whole image.
89
  """
90
 
91
+ img = torch.zeros((2, 4, 4), dtype=torch.float32)
92
+ mask = torch.zeros((1, 4, 4), dtype=torch.float32)
 
93
 
94
+ img[0, :, :] = 100.0
95
+ img[0, 0, 0] = 10.0
96
+ img[0, 0, 1] = 20.0
97
 
98
+ img[1, :, :] = 50.0
99
+ img[1, 0, 0] = 2.0
100
+ img[1, 0, 1] = 4.0
101
 
102
+ mask[0, 0, 0] = 1
103
+ mask[0, 0, 1] = 1
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True)
106
+ out = normalizer(img, mask)
107
 
108
+ assert torch.isclose(out[0, 0, 0], torch.tensor(-1.0)), "Ch0 masked value 1 incorrect"
109
+ assert torch.isclose(out[0, 0, 1], torch.tensor(1.0)), "Ch0 masked value 2 incorrect"
110
+ assert torch.isclose(out[0, 1, 1], torch.tensor(17.0)), "Ch0 background normalization incorrect"
111
 
112
+ assert torch.isclose(out[1, 0, 0], torch.tensor(-1.0)), "Ch1 masked value 1 incorrect"
113
+ assert torch.isclose(out[1, 1, 1], torch.tensor(47.0)), "Ch1 background normalization incorrect"
114
+
115
+
116
+ def test_normalize_intensity_constant_area():
117
  """
118
+ Test edge case where the area under the mask has 0 variance (constant value).
119
+ Std should default to 1.0 to avoid division by zero.
120
  """
121
+ img = torch.ones((1, 4, 4)) * 10.0 # All values are 10
122
+ mask = torch.ones((1, 4, 4))
123
+
124
+ normalizer = NormalizeIntensity_custom(channel_wise=True)
125
+ out = normalizer(img, mask)
126
+ assert torch.allclose(out, torch.zeros_like(out))
127
+
128
+ data = torch.rand(1, 10, 10)
129
+ mask = torch.randint(0, 2, (1, 10, 10)).float()
130
+ normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True)
131
+ out = normalizer(data, mask)
132
+
133
+ masked = data[mask != 0]
134
+ mean_val = torch.mean(masked.float())
135
+ std_val = torch.std(masked.float(), unbiased=False)
136
+
137
+ epsilon = 1e-8
138
+ normalized_data = (data - mean_val) / (std_val + epsilon)
139
+
140
+ torch.testing.assert_close(out, normalized_data)
141
+
142
+
143
+ def test_run_models():
144
+ args = argparse.Namespace()
145
+ args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
146
+ args.epochs = 1
147
+ args.batch_size = 2
148
+ args.tile_size = 10
149
+ args.tile_count = 5
150
+ args.use_heatmap = True
151
+ args.amp = False
152
+ args.num_classes = 4
153
+ args.dry_run = True
154
+ args.depth = 3
155
+
156
+ model = MILModel3D(num_classes=args.num_classes, mil_mode="att_trans")
157
+ model.to(args.device)
158
+ params = model.parameters()
159
+ loader = get_dataloader(args, split="train")
160
+ optimizer = torch.optim.AdamW(params, lr=1e-5, weight_decay=1e-5)
161
+ scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp)
162
+
163
+ _ = train_pirads.train_epoch(model, loader, optimizer, scaler=scaler, epoch=0, args=args)
164
+ _ = train_pirads.val_epoch(model, loader, epoch=0, args=args)
165
+
166
+ cspca_model = CSPCAModel(backbone=model).to(args.device)
167
+ optimizer_cspca = torch.optim.AdamW(cspca_model.parameters(), lr=1e-5)
168
+ _ = train_cspca.train_epoch(cspca_model, loader, optimizer_cspca, epoch=0, args=args)
169
+ _ = train_cspca.val_epoch(cspca_model, loader, epoch=0, args=args)