Anirudh Balaraman committed on
Commit
a4ef78c
·
1 Parent(s): 7a1d40a

Add salient patch info

Browse files
Files changed (3) hide show
  1. run_inference.py +36 -5
  2. src/utils.py +130 -3
  3. temp.ipynb +419 -173
run_inference.py CHANGED
@@ -34,7 +34,7 @@ from pathlib import Path
34
  from src.model.MIL import MILModel_3D
35
  from src.model.csPCa_model import csPCa_Model
36
  from src.data.data_loader import get_dataloader
37
- from src.utils import save_cspca_checkpoint, get_metrics, setup_logging, save_pirads_checkpoint
38
  from src.train import train_cspca, train_pirads
39
  import SimpleITK as sitk
40
 
@@ -57,6 +57,7 @@ import argparse
57
  import yaml
58
  from src.data.data_loader import data_transform, list_data_collate
59
  from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
 
60
 
61
  def parse_args():
62
 
@@ -120,7 +121,7 @@ if __name__ == "__main__":
120
 
121
  transform = data_transform(args)
122
  files = os.listdir(args.t2_dir)
123
- data_list = []
124
  for file in files:
125
  temp = {}
126
  temp['image'] = os.path.join(args.t2_dir, file)
@@ -129,9 +130,9 @@ if __name__ == "__main__":
129
  temp['heatmap'] = os.path.join(args.heatmapdir, file)
130
  temp['mask'] = os.path.join(args.seg_dir, file)
131
  temp['label'] = 0 # dummy label
132
- data_list.append(temp)
133
 
134
- dataset = Dataset(data=data_list, transform=transform)
135
  loader = torch.utils.data.DataLoader(
136
  dataset,
137
  batch_size=1,
@@ -147,6 +148,7 @@ if __name__ == "__main__":
147
  pirads_model.eval()
148
  cspca_risk_list = []
149
  cspca_model.eval()
 
150
  with torch.no_grad():
151
  for idx, batch_data in enumerate(loader):
152
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
@@ -158,5 +160,34 @@ if __name__ == "__main__":
158
  output = output.squeeze(1)
159
  cspca_risk_list.append(output.item())
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  for i,j in enumerate(files):
162
- logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
 
 
 
 
 
 
 
 
 
34
  from src.model.MIL import MILModel_3D
35
  from src.model.csPCa_model import csPCa_Model
36
  from src.data.data_loader import get_dataloader
37
+ from src.utils import save_cspca_checkpoint, get_metrics, setup_logging, save_pirads_checkpoint, get_parent_image, get_patch_coordinate
38
  from src.train import train_cspca, train_pirads
39
  import SimpleITK as sitk
40
 
 
57
  import yaml
58
  from src.data.data_loader import data_transform, list_data_collate
59
  from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
60
+ import json
61
 
62
  def parse_args():
63
 
 
121
 
122
  transform = data_transform(args)
123
  files = os.listdir(args.t2_dir)
124
+ args.data_list = []
125
  for file in files:
126
  temp = {}
127
  temp['image'] = os.path.join(args.t2_dir, file)
 
130
  temp['heatmap'] = os.path.join(args.heatmapdir, file)
131
  temp['mask'] = os.path.join(args.seg_dir, file)
132
  temp['label'] = 0 # dummy label
133
+ args.data_list.append(temp)
134
 
135
+ dataset = Dataset(data=args.data_list, transform=transform)
136
  loader = torch.utils.data.DataLoader(
137
  dataset,
138
  batch_size=1,
 
148
  pirads_model.eval()
149
  cspca_risk_list = []
150
  cspca_model.eval()
151
+ top5_patches = []
152
  with torch.no_grad():
153
  for idx, batch_data in enumerate(loader):
154
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
 
160
  output = output.squeeze(1)
161
  cspca_risk_list.append(output.item())
162
 
163
+ sh = data.shape
164
+ x = data.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5])
165
+ x = cspca_model.backbone.net(x)
166
+ x = x.reshape(sh[0], sh[1], -1)
167
+ x = x.permute(1, 0, 2)
168
+ x = cspca_model.backbone.transformer(x)
169
+ x = x.permute(1, 0, 2)
170
+ a = cspca_model.backbone.attention(x)
171
+ a = torch.softmax(a, dim=1)
172
+ a = a.view(-1)
173
+ top5_values, top5_indices = torch.topk(a, 5)
174
+
175
+ patches_top_5 = []
176
+ for i in range(5):
177
+ patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()
178
+ patches_top_5.append(patch_temp)
179
+
180
+ parent_image = get_parent_image(args)
181
+
182
+ coords = get_patch_coordinate(patches_top_5, parent_image, args)
183
+
184
  for i,j in enumerate(files):
185
+ logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
186
+
187
+ output_dict = {
188
+ 'Predicted PIRAD Score': pirads_list[i] + 2.0,
189
+ 'csPCa risk': cspca_risk_list[i],
190
+ 'Top left coordinate of top 5 patches(x,y,z)': coords,
191
+ }
192
+ with open(os.path.join(args.output_dir, "results.json"), 'w') as f:
193
+ json.dump(output_dict, f, indent=4)
src/utils.py CHANGED
@@ -14,19 +14,52 @@ import torch.nn.functional as F
14
  from monai.config import KeysCollection
15
  from monai.metrics import Cumulative, CumulativeAverage
16
  from monai.networks.nets import milmodel, resnet, MILModel
17
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  from sklearn.metrics import cohen_kappa_score
19
  from torch.cuda.amp import GradScaler, autocast
20
  from torch.utils.data.dataloader import default_collate
21
  from torchvision.models.resnet import ResNet50_Weights
22
-
23
  from torch.utils.data.distributed import DistributedSampler
24
  from torch.utils.tensorboard import SummaryWriter
 
25
 
26
  import matplotlib.pyplot as plt
27
 
28
  import wandb
29
  import math
 
30
 
31
  from src.model.MIL import MILModel_3D
32
  from src.model.csPCa_model import csPCa_Model
@@ -96,4 +129,98 @@ def validate_steps(steps):
96
  f"Step '{step}' requires '{req}' to be executed before it. "
97
  f"Given order: {steps}"
98
  )
99
- sys.exit(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from monai.config import KeysCollection
15
  from monai.metrics import Cumulative, CumulativeAverage
16
  from monai.networks.nets import milmodel, resnet, MILModel
17
+ from monai.transforms import (
18
+ Compose,
19
+ GridPatchd,
20
+ LoadImaged,
21
+ MapTransform,
22
+ RandFlipd,
23
+ RandGridPatchd,
24
+ RandRotate90d,
25
+ ScaleIntensityRanged,
26
+ SplitDimd,
27
+ ToTensord,
28
+ ConcatItemsd,
29
+ SelectItemsd,
30
+ EnsureChannelFirstd,
31
+ RepeatChanneld,
32
+ DeleteItemsd,
33
+ EnsureTyped,
34
+ ClipIntensityPercentilesd,
35
+ MaskIntensityd,
36
+ HistogramNormalized,
37
+ RandBiasFieldd,
38
+ RandCropByPosNegLabeld,
39
+ NormalizeIntensityd,
40
+ SqueezeDimd,
41
+ CropForegroundd,
42
+ ScaleIntensityd,
43
+ SpatialPadd,
44
+ CenterSpatialCropd,
45
+ ScaleIntensityd,
46
+ Transposed,
47
+ RandWeightedCropd,
48
+ )
49
  from sklearn.metrics import cohen_kappa_score
50
  from torch.cuda.amp import GradScaler, autocast
51
  from torch.utils.data.dataloader import default_collate
52
  from torchvision.models.resnet import ResNet50_Weights
53
+ from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
54
  from torch.utils.data.distributed import DistributedSampler
55
  from torch.utils.tensorboard import SummaryWriter
56
+ import matplotlib.patches as patches
57
 
58
  import matplotlib.pyplot as plt
59
 
60
  import wandb
61
  import math
62
+ from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
63
 
64
  from src.model.MIL import MILModel_3D
65
  from src.model.csPCa_model import csPCa_Model
 
129
  f"Step '{step}' requires '{req}' to be executed before it. "
130
  f"Given order: {steps}"
131
  )
132
+ sys.exit(1)
133
+
134
+ def get_patch_coordinate(patches_top_5, parent_image, args):
135
+
136
+ sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
137
+ coords = []
138
+ rows, h, w, slices = sample.shape
139
+
140
+ for i in range(rows):
141
+ for j in range(slices):
142
+ if j == 0:
143
+ for k in range(parent_image.shape[2]):
144
+ img_temp = parent_image[:, :, k]
145
+ H, W = img_temp.shape
146
+ h, w = sample[i, :, :, j].shape
147
+ a,b = 0, 0 # Initialize a and b
148
+ bool1 = False
149
+ for l in range(H - h + 1):
150
+ for m in range(W - w + 1):
151
+ if np.array_equal(img_temp[l:l+h, m:m+w], sample[i, :, :, j]):
152
+ a,b = l, m # top-left corner
153
+ coords.append((a,b,k))
154
+ bool1 = True
155
+ break
156
+ if bool1:
157
+ break
158
+
159
+ if bool1:
160
+ break
161
+
162
+ return coords
163
+
164
+
165
def get_parent_image(args, index=0):
    """Load one full (untiled) preprocessed image from ``args.data_list``.

    Applies the same intensity preprocessing used at inference time
    (mask-guided percentile clipping and normalisation) but no patch
    extraction, so the returned volume can be searched for the salient
    patches with :func:`get_patch_coordinate`.

    Args:
        args: namespace holding ``data_list`` — a list of dicts with
            "image", "mask" and "label" file paths, as built by the caller.
        index: which entry of ``args.data_list`` to load. Defaults to 0,
            preserving the original behaviour of always loading the first
            file; callers iterating over many files can now pass each index.

    Returns:
        numpy.ndarray: the first channel of the preprocessed image volume.
    """
    transform_image = Compose(
        [
            LoadImaged(keys=["image", "mask"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
            ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
            NormalizeIntensity_customd(keys=["image"], mask_key="mask", channel_wise=True),
            # Label is a dummy here; kept so the shared transform pipeline
            # shape matches the inference transform.
            EnsureTyped(keys=["label"], dtype=torch.float32),
            ToTensord(keys=["image", "label"]),
        ]
    )
    dataset_image = Dataset(data=args.data_list, transform=transform_image)
    return dataset_image[index]["image"][0].numpy()
177
+
178
+ '''
179
+ def visualise_patches():
180
+ sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
181
+ rows = len(patches_top_5)
182
+ img = sample[0]
183
+ coords = []
184
+ rows, h, w, slices = sample.shape
185
+
186
+ fig, axes = plt.subplots(nrows=rows, ncols=slices, figsize=(slices * 3, rows * 3))
187
+
188
+ for i in range(rows):
189
+ for j in range(slices):
190
+ ax = axes[i, j]
191
+
192
+ if j == 0:
193
+
194
+ for k in range(parent_image.shape[2]):
195
+ img_temp = parent_image[:, :, k]
196
+ H, W = img_temp.shape
197
+ h, w = sample[i, :, :, j].shape
198
+ a,b = 0, 0 # Initialize a and b
199
+ bool1 = False
200
+ for l in range(H - h + 1):
201
+ for m in range(W - w + 1):
202
+ if np.array_equal(img_temp[l:l+h, m:m+w], sample[i, :, :, j]):
203
+ a,b = l, m # top-left corner
204
+ coords.append((a,b,k))
205
+ bool1 = True
206
+ break
207
+ if bool1:
208
+ break
209
+
210
+ if bool1:
211
+ break
212
+
213
+
214
+
215
+
216
+ ax.imshow(parent_image[:, :, k+j], cmap='gray')
217
+ rect = patches.Rectangle((b, a), args.tile_size, args.tile_size,
218
+ linewidth=2, edgecolor='red', facecolor='none')
219
+ ax.add_patch(rect)
220
+ ax.axis('off')
221
+
222
+
223
+ plt.tight_layout()
224
+ plt.show()
225
+ a=1
226
+ '''
temp.ipynb CHANGED
@@ -2,10 +2,19 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 18,
6
  "id": "cec738fb",
7
  "metadata": {},
8
- "outputs": [],
 
 
 
 
 
 
 
 
 
9
  "source": [
10
  "import argparse\n",
11
  "import os\n",
@@ -70,7 +79,7 @@
70
  },
71
  {
72
  "cell_type": "code",
73
- "execution_count": null,
74
  "id": "c91a5802",
75
  "metadata": {},
76
  "outputs": [
@@ -78,9 +87,9 @@
78
  "name": "stderr",
79
  "output_type": "stream",
80
  "text": [
81
- " 0%| | 0/1 [00:02<?, ?it/s]\n",
82
  "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
83
- " 0%| | 0/1 [00:07<?, ?it/s]\n"
84
  ]
85
  }
86
  ],
@@ -104,36 +113,23 @@
104
  },
105
  {
106
  "cell_type": "code",
107
- "execution_count": 6,
108
- "id": "f32be6e9",
109
  "metadata": {},
110
- "outputs": [
111
- {
112
- "data": {
113
- "text/plain": [
114
- "{'margin': 0.2,\n",
115
- " 't2_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/t2_histmatched',\n",
116
- " 'dwi_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/DWI_histmatched',\n",
117
- " 'adc_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/ADC_histmatched',\n",
118
- " 'output_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed',\n",
119
- " 'project_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate',\n",
120
- " 'seg_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/prostate_mask',\n",
121
- " 'heatmapdir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/heatmaps/'}"
122
- ]
123
- },
124
- "execution_count": 6,
125
- "metadata": {},
126
- "output_type": "execute_result"
127
- }
128
- ],
129
  "source": [
130
- "vars(args)"
 
 
 
 
 
131
  ]
132
  },
133
  {
134
  "cell_type": "code",
135
- "execution_count": 8,
136
- "id": "118c2549",
137
  "metadata": {},
138
  "outputs": [
139
  {
@@ -141,155 +137,23 @@
141
  "output_type": "stream",
142
  "text": [
143
  "enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
 
144
  "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n"
145
  ]
146
- },
147
- {
148
- "data": {
149
- "text/plain": [
150
- "MILModel_3D(\n",
151
- " (attention): Sequential(\n",
152
- " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
153
- " (1): Tanh()\n",
154
- " (2): Linear(in_features=2048, out_features=1, bias=True)\n",
155
- " )\n",
156
- " (transformer): TransformerEncoder(\n",
157
- " (layers): ModuleList(\n",
158
- " (0-3): 4 x TransformerEncoderLayer(\n",
159
- " (self_attn): MultiheadAttention(\n",
160
- " (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n",
161
- " )\n",
162
- " (linear1): Linear(in_features=512, out_features=2048, bias=True)\n",
163
- " (dropout): Dropout(p=0.0, inplace=False)\n",
164
- " (linear2): Linear(in_features=2048, out_features=512, bias=True)\n",
165
- " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
166
- " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
167
- " (dropout1): Dropout(p=0.0, inplace=False)\n",
168
- " (dropout2): Dropout(p=0.0, inplace=False)\n",
169
- " )\n",
170
- " )\n",
171
- " )\n",
172
- " (myfc): Linear(in_features=512, out_features=4, bias=True)\n",
173
- " (net): ResNet(\n",
174
- " (conv1): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)\n",
175
- " (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
176
- " (act): ReLU(inplace=True)\n",
177
- " (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
178
- " (layer1): Sequential(\n",
179
- " (0): ResNetBlock(\n",
180
- " (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
181
- " (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
182
- " (act): ReLU(inplace=True)\n",
183
- " (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
184
- " (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
185
- " )\n",
186
- " (1): ResNetBlock(\n",
187
- " (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
188
- " (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
189
- " (act): ReLU(inplace=True)\n",
190
- " (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
191
- " (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
192
- " )\n",
193
- " )\n",
194
- " (layer2): Sequential(\n",
195
- " (0): ResNetBlock(\n",
196
- " (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)\n",
197
- " (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
198
- " (act): ReLU(inplace=True)\n",
199
- " (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
200
- " (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
201
- " (downsample): Sequential(\n",
202
- " (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(2, 2, 2))\n",
203
- " (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
204
- " )\n",
205
- " )\n",
206
- " (1): ResNetBlock(\n",
207
- " (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
208
- " (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
209
- " (act): ReLU(inplace=True)\n",
210
- " (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
211
- " (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
212
- " )\n",
213
- " )\n",
214
- " (layer3): Sequential(\n",
215
- " (0): ResNetBlock(\n",
216
- " (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)\n",
217
- " (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
218
- " (act): ReLU(inplace=True)\n",
219
- " (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
220
- " (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
221
- " (downsample): Sequential(\n",
222
- " (0): Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(2, 2, 2))\n",
223
- " (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
224
- " )\n",
225
- " )\n",
226
- " (1): ResNetBlock(\n",
227
- " (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
228
- " (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
229
- " (act): ReLU(inplace=True)\n",
230
- " (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
231
- " (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
232
- " )\n",
233
- " )\n",
234
- " (layer4): Sequential(\n",
235
- " (0): ResNetBlock(\n",
236
- " (conv1): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)\n",
237
- " (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
238
- " (act): ReLU(inplace=True)\n",
239
- " (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
240
- " (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
241
- " (downsample): Sequential(\n",
242
- " (0): Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(2, 2, 2))\n",
243
- " (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
244
- " )\n",
245
- " )\n",
246
- " (1): ResNetBlock(\n",
247
- " (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
248
- " (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
249
- " (act): ReLU(inplace=True)\n",
250
- " (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
251
- " (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
252
- " )\n",
253
- " )\n",
254
- " (avgpool): AdaptiveAvgPool3d(output_size=(1, 1, 1))\n",
255
- " (fc): Identity()\n",
256
- " )\n",
257
- ")"
258
- ]
259
- },
260
- "execution_count": 8,
261
- "metadata": {},
262
- "output_type": "execute_result"
263
  }
264
  ],
265
  "source": [
266
  "args.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
267
- "args.num_classes = 4\n",
268
- "args.mil_mode = 'att_trans'\n",
269
  "pirads_model = MILModel_3D(\n",
270
  " num_classes=args.num_classes, \n",
271
  " mil_mode=args.mil_mode \n",
272
  ")\n",
273
  "pirads_checkpoint = torch.load(os.path.join(args.project_dir, 'models', 'pirads.pt'), map_location=\"cpu\")\n",
274
  "pirads_model.load_state_dict(pirads_checkpoint[\"state_dict\"])\n",
275
- "pirads_model.to(args.device)"
276
- ]
277
- },
278
- {
279
- "cell_type": "code",
280
- "execution_count": 31,
281
- "id": "01467cae",
282
- "metadata": {},
283
- "outputs": [
284
- {
285
- "name": "stderr",
286
- "output_type": "stream",
287
- "text": [
288
- "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n"
289
- ]
290
- }
291
- ],
292
- "source": [
293
  "cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)\n",
294
  "checkpt = torch.load(os.path.join(args.project_dir, 'models', 'cspca_model.pth'), map_location=\"cpu\")\n",
295
  "cspca_model.load_state_dict(checkpt['state_dict'])\n",
@@ -298,18 +162,41 @@
298
  },
299
  {
300
  "cell_type": "code",
301
- "execution_count": null,
302
- "id": "aa7850f3",
303
  "metadata": {},
304
  "outputs": [],
305
  "source": [
306
- "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  ]
308
  },
309
  {
310
  "cell_type": "code",
311
- "execution_count": 33,
312
- "id": "bd8884f8",
313
  "metadata": {},
314
  "outputs": [],
315
  "source": [
@@ -317,6 +204,7 @@
317
  "pirads_model.eval()\n",
318
  "cspca_risk_list = []\n",
319
  "cspca_model.eval()\n",
 
320
  "with torch.no_grad():\n",
321
  " for idx, batch_data in enumerate(loader):\n",
322
  " data = batch_data[\"image\"].as_subclass(torch.Tensor).to(args.device)\n",
@@ -326,13 +214,371 @@
326
  "\n",
327
  " output = cspca_model(data)\n",
328
  " output = output.squeeze(1)\n",
329
- " cspca_risk_list.append(output.item())"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  ]
331
  },
332
  {
333
  "cell_type": "code",
334
  "execution_count": null,
335
- "id": "4cf061ec",
336
  "metadata": {},
337
  "outputs": [],
338
  "source": []
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "cec738fb",
7
  "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "If you have questions or suggestions, feel free to open an issue at https://github.com/DIAGNijmegen/picai_prep\n",
14
+ "\n"
15
+ ]
16
+ }
17
+ ],
18
  "source": [
19
  "import argparse\n",
20
  "import os\n",
 
79
  },
80
  {
81
  "cell_type": "code",
82
+ "execution_count": 2,
83
  "id": "c91a5802",
84
  "metadata": {},
85
  "outputs": [
 
87
  "name": "stderr",
88
  "output_type": "stream",
89
  "text": [
90
+ " 0%| | 0/1 [00:03<?, ?it/s]\n",
91
  "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
92
+ " 0%| | 0/1 [00:05<?, ?it/s]\n"
93
  ]
94
  }
95
  ],
 
113
  },
114
  {
115
  "cell_type": "code",
116
+ "execution_count": 4,
117
+ "id": "8b5d382e",
118
  "metadata": {},
119
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  "source": [
121
+ "args.num_classes = 4\n",
122
+ "args.mil_mode = \"att_trans\"\n",
123
+ "args.use_heatmap = True\n",
124
+ "args.tile_size = 64\n",
125
+ "args.tile_count = 24\n",
126
+ "args.depth = 3\n"
127
  ]
128
  },
129
  {
130
  "cell_type": "code",
131
+ "execution_count": 5,
132
+ "id": "4cf061ec",
133
  "metadata": {},
134
  "outputs": [
135
  {
 
137
  "output_type": "stream",
138
  "text": [
139
  "enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
140
+ "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
141
  "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n"
142
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  }
144
  ],
145
  "source": [
146
  "args.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
147
+ "\n",
148
+ "\n",
149
  "pirads_model = MILModel_3D(\n",
150
  " num_classes=args.num_classes, \n",
151
  " mil_mode=args.mil_mode \n",
152
  ")\n",
153
  "pirads_checkpoint = torch.load(os.path.join(args.project_dir, 'models', 'pirads.pt'), map_location=\"cpu\")\n",
154
  "pirads_model.load_state_dict(pirads_checkpoint[\"state_dict\"])\n",
155
+ "pirads_model.to(args.device)\n",
156
+ "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  "cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)\n",
158
  "checkpt = torch.load(os.path.join(args.project_dir, 'models', 'cspca_model.pth'), map_location=\"cpu\")\n",
159
  "cspca_model.load_state_dict(checkpt['state_dict'])\n",
 
162
  },
163
  {
164
  "cell_type": "code",
165
+ "execution_count": 6,
166
+ "id": "fac15515",
167
  "metadata": {},
168
  "outputs": [],
169
  "source": [
170
+ "transform = data_transform(args)\n",
171
+ "files = os.listdir(args.t2_dir)\n",
172
+ "data_list = []\n",
173
+ "for file in files:\n",
174
+ " temp = {}\n",
175
+ " temp['image'] = os.path.join(args.t2_dir, file)\n",
176
+ " temp['dwi'] = os.path.join(args.dwi_dir, file)\n",
177
+ " temp['adc'] = os.path.join(args.adc_dir, file)\n",
178
+ " temp['heatmap'] = os.path.join(args.heatmapdir, file)\n",
179
+ " temp['mask'] = os.path.join(args.seg_dir, file)\n",
180
+ " temp['label'] = 0 # dummy label\n",
181
+ " data_list.append(temp)\n",
182
+ "\n",
183
+ "dataset = Dataset(data=data_list, transform=transform)\n",
184
+ "loader = torch.utils.data.DataLoader(\n",
185
+ " dataset,\n",
186
+ " batch_size=1,\n",
187
+ " shuffle=False,\n",
188
+ " num_workers=0,\n",
189
+ " pin_memory=True,\n",
190
+ " multiprocessing_context= None,\n",
191
+ " sampler=None,\n",
192
+ " collate_fn=list_data_collate,\n",
193
+ ")"
194
  ]
195
  },
196
  {
197
  "cell_type": "code",
198
+ "execution_count": 7,
199
+ "id": "eb80047b",
200
  "metadata": {},
201
  "outputs": [],
202
  "source": [
 
204
  "pirads_model.eval()\n",
205
  "cspca_risk_list = []\n",
206
  "cspca_model.eval()\n",
207
+ "top5_patches = []\n",
208
  "with torch.no_grad():\n",
209
  " for idx, batch_data in enumerate(loader):\n",
210
  " data = batch_data[\"image\"].as_subclass(torch.Tensor).to(args.device)\n",
 
214
  "\n",
215
  " output = cspca_model(data)\n",
216
  " output = output.squeeze(1)\n",
217
+ " cspca_risk_list.append(output.item())\n",
218
+ "\n",
219
+ " sh = data.shape\n",
220
+ " x = data.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5])\n",
221
+ " x = cspca_model.backbone.net(x)\n",
222
+ " x = x.reshape(sh[0], sh[1], -1)\n",
223
+ " x = x.permute(1, 0, 2)\n",
224
+ " x = cspca_model.backbone.transformer(x)\n",
225
+ " x = x.permute(1, 0, 2)\n",
226
+ " a = cspca_model.backbone.attention(x)\n",
227
+ " a = torch.softmax(a, dim=1)\n",
228
+ " a = a.view(-1)\n",
229
+ " top5_values, top5_indices = torch.topk(a, 5)\n",
230
+ "\n",
231
+ " patches_top_5 = []\n",
232
+ " for i in range(5):\n",
233
+ " patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()\n",
234
+ " patches_top_5.append(patch_temp)"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": 8,
240
+ "id": "dbcfc97f",
241
+ "metadata": {},
242
+ "outputs": [
243
+ {
244
+ "data": {
245
+ "text/plain": [
246
+ "[array([[[-0.43333182, -0.52243334, -1.4245864 , ..., -1.7698549 ,\n",
247
+ " -1.6027895 , -1.257521 ],\n",
248
+ " [-0.7786003 , -1.2797965 , -2.0817103 , ..., -2.0371597 ,\n",
249
+ " -2.0817103 , -2.059435 ],\n",
250
+ " [-1.4802749 , -1.658478 , -1.7809926 , ..., -1.4691372 ,\n",
251
+ " -2.1596742 , -1.9591957 ],\n",
252
+ " ...,\n",
253
+ " [ 0.8697783 , 0.34630668, 0.6358867 , ..., 0.624749 ,\n",
254
+ " 0.24606745, 0.03445129],\n",
255
+ " [ 0.44654593, 0.23492976, 0.7806767 , ..., 1.2373221 ,\n",
256
+ " 0.2906182 , -1.0124918 ],\n",
257
+ " [ 0.3017559 , -0.2551287 , 0.27948052, ..., 1.1036698 ,\n",
258
+ " -0.16602719, -1.2463834 ]],\n",
259
+ " \n",
260
+ " [[-0.84542644, -1.3800358 , -1.4357241 , ..., -2.5383558 ,\n",
261
+ " -2.0817103 , -0.99021643],\n",
262
+ " [-0.4444695 , -1.3466226 , -1.8255434 , ..., -1.335485 ,\n",
263
+ " -1.6362027 , -1.4579996 ],\n",
264
+ " [-1.2129703 , -1.7921304 , -2.0371597 , ..., 0.14582822,\n",
265
+ " -0.04351256, -0.82315105],\n",
266
+ " ...,\n",
267
+ " [-0.4556072 , 1.4378005 , 1.5603153 , ..., 0.2906182 ,\n",
268
+ " -0.85656416, -1.1906949 ],\n",
269
+ " [-1.2352457 , -0.35536796, 1.1816336 , ..., -0.39991874,\n",
270
+ " -1.3800358 , -1.2129703 ],\n",
271
+ " [-1.2909342 , -1.1127311 , 0.5467852 , ..., -0.9568034 ,\n",
272
+ " -1.4023111 , -0.9011149 ]],\n",
273
+ " \n",
274
+ " [[-2.1596742 , -1.914645 , -1.7809926 , ..., -1.0124918 ,\n",
275
+ " 0.03445129, 0.44654593],\n",
276
+ " [-2.0037465 , -1.8923696 , -1.8700942 , ..., -1.2129703 ,\n",
277
+ " -1.1572819 , -0.50015795],\n",
278
+ " [-1.8478189 , -1.9369203 , -1.9814711 , ..., -0.65608567,\n",
279
+ " -1.2352457 , -1.4134488 ],\n",
280
+ " ...,\n",
281
+ " [-0.6338103 , -1.0124918 , -0.16602719, ..., 0.41313285,\n",
282
+ " 0.13469052, -0.80087566],\n",
283
+ " [-0.7451872 , -1.3577603 , -0.2996795 , ..., 0.34630668,\n",
284
+ " 0.68043745, 0.45768362],\n",
285
+ " [ 0.0121759 , -0.7674626 , -0.33309257, ..., 0.09013975,\n",
286
+ " 0.46882132, 1.0034306 ]]], dtype=float32),\n",
287
+ " array([[[-0.64494795, -0.789738 , -0.5447087 , ..., 0.0233136 ,\n",
288
+ " 0.14582822, -0.35536796],\n",
289
+ " [-0.8120134 , -0.7340495 , -0.42219412, ..., -0.31081718,\n",
290
+ " 0.46882132, 0.15696591],\n",
291
+ " [-0.08806333, 0.03445129, 0.12355283, ..., -0.2774041 ,\n",
292
+ " 0.73612595, 0.7249882 ],\n",
293
+ " ...,\n",
294
+ " [ 0.09013975, 0.0233136 , -0.24399103, ..., 0.34630668,\n",
295
+ " 0.914329 , 0.6358867 ],\n",
296
+ " [ 0.2906182 , 0.335169 , 0.624749 , ..., 0.24606745,\n",
297
+ " 0.9254667 , 0.9922929 ],\n",
298
+ " [ 0.44654593, 0.70271283, 1.0479814 , ..., -0.09920102,\n",
299
+ " 0.37971976, 0.70271283]],\n",
300
+ " \n",
301
+ " [[-0.23285334, -0.01009948, -0.1326141 , ..., 1.1593583 ,\n",
302
+ " 1.5603153 , 1.5603153 ],\n",
303
+ " [-0.23285334, 0.13469052, 0.1792413 , ..., 0.98115516,\n",
304
+ " 1.5603153 , 1.5603153 ],\n",
305
+ " [-0.36650565, -0.01009948, 0.190379 , ..., 0.8697783 ,\n",
306
+ " 1.5603153 , 1.5603153 ],\n",
307
+ " ...,\n",
308
+ " [-0.16602719, -0.06578795, 0.190379 , ..., -1.4357241 ,\n",
309
+ " -1.368898 , -1.6027895 ],\n",
310
+ " [-0.2662664 , -0.35536796, 0.190379 , ..., -1.513688 ,\n",
311
+ " -1.3800358 , -1.6362027 ],\n",
312
+ " [-0.789738 , -0.7786003 , -0.17716487, ..., -1.7141665 ,\n",
313
+ " -1.5693765 , -1.9591957 ]],\n",
314
+ " \n",
315
+ " [[-0.12147641, -0.01009948, 0.0233136 , ..., 0.769539 ,\n",
316
+ " 0.8140898 , 0.2906182 ],\n",
317
+ " [-0.35536796, -0.2774041 , -0.16602719, ..., 0.55792284,\n",
318
+ " 0.68043745, 0.335169 ],\n",
319
+ " [-0.18830256, 0.1681036 , 0.7918144 , ..., 0.70271283,\n",
320
+ " 0.4910967 , 0.10127745],\n",
321
+ " ...,\n",
322
+ " [-0.97907877, -0.97907877, -0.9011149 , ..., -0.43333182,\n",
323
+ " -0.5447087 , -0.1437518 ],\n",
324
+ " [-1.1238688 , -1.1461442 , -1.0347673 , ..., -0.7117741 ,\n",
325
+ " -0.12147641, 0.479959 ],\n",
326
+ " [-0.8788395 , -0.62267256, -0.9122526 , ..., -0.92339027,\n",
327
+ " -0.42219412, 0.22379206]]], dtype=float32),\n",
328
+ " array([[[ 2.7948052e-01, 1.0368437e+00, 1.5603153e+00, ...,\n",
329
+ " -1.2797965e+00, -8.3428878e-01, -2.4399103e-01],\n",
330
+ " [ 6.7864366e-02, 3.4630668e-01, 1.1704960e+00, ...,\n",
331
+ " -7.7860028e-01, -8.3428878e-01, -3.1081718e-01],\n",
332
+ " [ 2.6834285e-01, 9.0139754e-02, 2.3492976e-01, ...,\n",
333
+ " -6.3381028e-01, -9.5680338e-01, -4.4446951e-01],\n",
334
+ " ...,\n",
335
+ " [-1.3688980e+00, -1.2241080e+00, -9.4566566e-01, ...,\n",
336
+ " -1.0570426e+00, -2.6626641e-01, 1.2707351e+00],\n",
337
+ " [-1.0793180e+00, -8.8997722e-01, -1.1015934e+00, ...,\n",
338
+ " -7.6746261e-01, -5.4650255e-02, 1.0257059e+00],\n",
339
+ " [-7.3404950e-01, -6.5608567e-01, -1.0124918e+00, ...,\n",
340
+ " -5.3357106e-01, 2.6834285e-01, 8.5864055e-01]],\n",
341
+ " \n",
342
+ " [[ 1.1593583e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
343
+ " -4.5560721e-01, -1.1033872e-01, 8.6977828e-01],\n",
344
+ " [ 8.4750289e-01, 1.5603153e+00, 1.5603153e+00, ...,\n",
345
+ " -1.3020718e+00, -1.0459049e+00, 6.3588673e-01],\n",
346
+ " [ 7.9002060e-02, 1.3469052e-01, -3.1081718e-01, ...,\n",
347
+ " -1.3688980e+00, -1.0570426e+00, 4.9109671e-01],\n",
348
+ " ...,\n",
349
+ " [ 1.5603153e+00, 6.4702439e-01, -4.3512560e-02, ...,\n",
350
+ " 5.2450979e-01, 8.9205366e-01, -2.9967949e-01],\n",
351
+ " [ 1.2930106e+00, 6.1361134e-01, -4.6674490e-01, ...,\n",
352
+ " 2.7948052e-01, 8.4750289e-01, 7.8067672e-01],\n",
353
+ " [ 1.0382106e-03, 2.5720516e-01, -4.1105643e-01, ...,\n",
354
+ " 6.2474900e-01, 1.5603153e+00, 1.5603153e+00]],\n",
355
+ " \n",
356
+ " [[-9.2339027e-01, -1.0236295e+00, -1.0347673e+00, ...,\n",
357
+ " 2.6834285e-01, -1.9944026e-01, -1.1033872e-01],\n",
358
+ " [-8.2315105e-01, -1.1127311e+00, -9.7907877e-01, ...,\n",
359
+ " 1.2818729e+00, 3.2403129e-01, 6.7864366e-02],\n",
360
+ " [-8.0087566e-01, -9.7907877e-01, -8.0087566e-01, ...,\n",
361
+ " 1.5046268e+00, 6.1361134e-01, -1.7716487e-01],\n",
362
+ " ...,\n",
363
+ " [ 1.3932499e+00, 6.8043745e-01, 2.3492976e-01, ...,\n",
364
+ " -1.8144057e+00, -1.2129703e+00, 2.3313597e-02],\n",
365
+ " [ 1.5603153e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
366
+ " -1.7587173e+00, -1.1572819e+00, 5.6906056e-01],\n",
367
+ " [ 1.5603153e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
368
+ " -1.8366811e+00, -8.7883949e-01, 1.3598367e+00]]], dtype=float32),\n",
369
+ " array([[[-1.4245864 , -1.3466226 , -1.079318 , ..., -0.36650565,\n",
370
+ " -0.66722333, -1.079318 ],\n",
371
+ " [-1.1350064 , -1.2352457 , -1.4357241 , ..., -0.01009948,\n",
372
+ " -0.41105643, -0.7563249 ],\n",
373
+ " [-1.0459049 , -0.84542644, -1.368898 , ..., -0.22171564,\n",
374
+ " -0.4778826 , -0.82315105],\n",
375
+ " ...,\n",
376
+ " [-2.2599134 , -2.1151235 , -1.6139272 , ..., 0.5245098 ,\n",
377
+ " 0.41313285, 0.37971976],\n",
378
+ " [-2.560631 , -2.204225 , -1.7921304 , ..., 0.10127745,\n",
379
+ " 0.3128936 , 0.37971976],\n",
380
+ " [-2.2933266 , -2.1596742 , -2.0705726 , ..., 0.0233136 ,\n",
381
+ " 0.3017559 , 0.4910967 ]],\n",
382
+ " \n",
383
+ " [[-1.0570426 , -1.3911734 , -1.658478 , ..., 0.769539 ,\n",
384
+ " 0.7806767 , 0.95887977],\n",
385
+ " [-1.6362027 , -1.4134488 , -1.5693765 , ..., 1.0145682 ,\n",
386
+ " 1.1816336 , 1.1370829 ],\n",
387
+ " [-1.914645 , -1.1127311 , -1.224108 , ..., 0.68043745,\n",
388
+ " 0.9922929 , 0.85864055],\n",
389
+ " ...,\n",
390
+ " [-2.1262612 , -2.026022 , -1.7921304 , ..., 0.4910967 ,\n",
391
+ " 0.9031913 , 1.0702567 ],\n",
392
+ " [-2.3712904 , -2.5494936 , -2.304464 , ..., 0.70271283,\n",
393
+ " 0.914329 , 1.0479814 ],\n",
394
+ " [-2.2710512 , -2.3267395 , -2.4715295 , ..., 1.0145682 ,\n",
395
+ " 0.7806767 , 0.7472636 ]],\n",
396
+ " \n",
397
+ " [[-0.94566566, -0.6003972 , -0.85656416, ..., 1.2373221 ,\n",
398
+ " 1.5603153 , 1.5603153 ],\n",
399
+ " [-1.1795572 , -0.38878104, -0.37764335, ..., 1.5603153 ,\n",
400
+ " 1.5603153 , 1.5603153 ],\n",
401
+ " [-1.335485 , -1.3577603 , -0.32195488, ..., 1.5603153 ,\n",
402
+ " 1.5603153 , 1.5603153 ],\n",
403
+ " ...,\n",
404
+ " [-1.224108 , -1.1795572 , -1.4914126 , ..., -0.04351256,\n",
405
+ " -0.5112957 , -0.64494795],\n",
406
+ " [-1.3132095 , -1.8589565 , -1.7587173 , ..., -0.03237487,\n",
407
+ " -0.42219412, -0.1437518 ],\n",
408
+ " [-1.6473403 , -2.4381166 , -2.1262612 , ..., 0.37971976,\n",
409
+ " 0.34630668, 0.46882132]]], dtype=float32),\n",
410
+ " array([[[-0.84542644, -0.8120134 , -0.6894987 , ..., 0.46882132,\n",
411
+ " -0.01009948, -0.39991874],\n",
412
+ " [-0.65608567, -0.8120134 , -0.6338103 , ..., 0.59133595,\n",
413
+ " 0.04558898, -0.5112957 ],\n",
414
+ " [-0.96794105, -0.8342888 , -0.35536796, ..., 0.5801982 ,\n",
415
+ " 0.13469052, -0.37764335],\n",
416
+ " ...,\n",
417
+ " [-0.09920102, -0.11033872, -1.224108 , ..., -0.92339027,\n",
418
+ " -1.2686588 , -0.85656416],\n",
419
+ " [ 0.0121759 , -0.01009948, -1.0236295 , ..., -1.2129703 ,\n",
420
+ " -1.4023111 , -0.7340495 ],\n",
421
+ " [ 0.23492976, -0.07692564, -0.70063645, ..., -0.70063645,\n",
422
+ " -0.934528 , -1.0124918 ]],\n",
423
+ " \n",
424
+ " [[-1.6139272 , -0.97907877, -0.7340495 , ..., -0.8342888 ,\n",
425
+ " -0.6894987 , -1.0347673 ],\n",
426
+ " [-1.0459049 , -0.5892595 , -0.789738 , ..., -1.257521 ,\n",
427
+ " -1.0124918 , -1.4357241 ],\n",
428
+ " [-0.42219412, -0.5892595 , -1.4134488 , ..., -1.5025504 ,\n",
429
+ " -1.2463834 , -1.4802749 ],\n",
430
+ " ...,\n",
431
+ " [ 0.10127745, 0.40199515, 0.13469052, ..., 0.25720516,\n",
432
+ " 0.55792284, 0.12355283],\n",
433
+ " [-0.2885418 , -0.2551287 , 0.41313285, ..., 0.46882132,\n",
434
+ " -0.33309257, -1.2797965 ],\n",
435
+ " [-0.2996795 , -0.64494795, 0.9366044 , ..., 0.56906056,\n",
436
+ " -0.41105643, -1.3466226 ]],\n",
437
+ " \n",
438
+ " [[-1.1238688 , -1.4134488 , -1.7364419 , ..., -1.1350064 ,\n",
439
+ " -1.6918911 , -2.3156018 ],\n",
440
+ " [-1.4914126 , -1.3020718 , -0.99021643, ..., -1.658478 ,\n",
441
+ " -1.7698549 , -1.8812319 ],\n",
442
+ " [-1.5582387 , -0.85656416, -0.43333182, ..., -1.6027895 ,\n",
443
+ " -1.914645 , -1.7698549 ],\n",
444
+ " ...,\n",
445
+ " [ 0.6358867 , 0.9031913 , 1.0702567 , ..., 0.11241514,\n",
446
+ " 0.07900206, 0.34630668],\n",
447
+ " [ 0.70271283, 1.1593583 , 0.9254667 , ..., 0.335169 ,\n",
448
+ " 0.41313285, 0.23492976],\n",
449
+ " [ 0.35744438, 1.1482205 , 0.8697783 , ..., -0.2662664 ,\n",
450
+ " 0.56906056, 0.624749 ]]], dtype=float32)]"
451
+ ]
452
+ },
453
+ "execution_count": 8,
454
+ "metadata": {},
455
+ "output_type": "execute_result"
456
+ }
457
+ ],
458
+ "source": [
459
+ "patches_top_5"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "execution_count": 11,
465
+ "id": "4edb20e7",
466
+ "metadata": {},
467
+ "outputs": [],
468
+ "source": [
469
+ "import argparse\n",
470
+ "import os\n",
471
+ "import shutil\n",
472
+ "import time\n",
473
+ "import yaml\n",
474
+ "import sys\n",
475
+ "import gdown\n",
476
+ "import numpy as np\n",
477
+ "import torch\n",
478
+ "import torch.distributed as dist\n",
479
+ "import torch.multiprocessing as mp\n",
480
+ "import torch.nn as nn\n",
481
+ "import torch.nn.functional as F\n",
482
+ "from monai.config import KeysCollection\n",
483
+ "from monai.metrics import Cumulative, CumulativeAverage\n",
484
+ "from monai.networks.nets import milmodel, resnet, MILModel\n",
485
+ "from monai.transforms import (\n",
486
+ " Compose,\n",
487
+ " GridPatchd,\n",
488
+ " LoadImaged,\n",
489
+ " MapTransform,\n",
490
+ " RandFlipd,\n",
491
+ " RandGridPatchd,\n",
492
+ " RandRotate90d,\n",
493
+ " ScaleIntensityRanged,\n",
494
+ " SplitDimd,\n",
495
+ " ToTensord,\n",
496
+ " ConcatItemsd, \n",
497
+ " SelectItemsd,\n",
498
+ " EnsureChannelFirstd,\n",
499
+ " RepeatChanneld,\n",
500
+ " DeleteItemsd,\n",
501
+ " EnsureTyped,\n",
502
+ " ClipIntensityPercentilesd,\n",
503
+ " MaskIntensityd,\n",
504
+ " HistogramNormalized,\n",
505
+ " RandBiasFieldd,\n",
506
+ " RandCropByPosNegLabeld,\n",
507
+ " NormalizeIntensityd,\n",
508
+ " SqueezeDimd,\n",
509
+ " CropForegroundd,\n",
510
+ " ScaleIntensityd,\n",
511
+ " SpatialPadd,\n",
512
+ " CenterSpatialCropd,\n",
513
+ " ScaleIntensityd,\n",
514
+ " Transposed,\n",
515
+ " RandWeightedCropd,\n",
516
+ ")\n",
517
+ "from sklearn.metrics import cohen_kappa_score\n",
518
+ "from torch.cuda.amp import GradScaler, autocast\n",
519
+ "from torch.utils.data.dataloader import default_collate\n",
520
+ "from torchvision.models.resnet import ResNet50_Weights\n",
521
+ "from src.data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd\n",
522
+ "from torch.utils.data.distributed import DistributedSampler\n",
523
+ "from torch.utils.tensorboard import SummaryWriter\n",
524
+ "\n",
525
+ "import matplotlib.pyplot as plt\n",
526
+ "\n",
527
+ "import wandb\n",
528
+ "import math\n",
529
+ "from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset\n",
530
+ "\n",
531
+ "from src.model.MIL import MILModel_3D\n",
532
+ "from src.model.csPCa_model import csPCa_Model\n",
533
+ "\n",
534
+ "import logging\n",
535
+ "from pathlib import Path"
536
+ ]
537
+ },
538
+ {
539
+ "cell_type": "code",
540
+ "execution_count": 13,
541
+ "id": "e42cc132",
542
+ "metadata": {},
543
+ "outputs": [],
544
+ "source": [
545
+ "transform_image = Compose(\n",
546
+ " [\n",
547
+ " LoadImaged(keys=[\"image\", \"mask\"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),\n",
548
+ " ClipMaskIntensityPercentilesd(keys=[\"image\"], lower=0, upper=99.5, mask_key=\"mask\"),\n",
549
+ " NormalizeIntensity_customd(keys=[\"image\"], mask_key=\"mask\", channel_wise=True),\n",
550
+ " EnsureTyped(keys=[\"label\"], dtype=torch.float32),\n",
551
+ " ToTensord(keys=[\"image\", \"label\"]),\n",
552
+ " ]\n",
553
+ ")\n",
554
+ "dataset_image = Dataset(data=data_list, transform=transform_image)\n"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "code",
559
+ "execution_count": 19,
560
+ "id": "bcdddd9e",
561
+ "metadata": {},
562
+ "outputs": [
563
+ {
564
+ "data": {
565
+ "text/plain": [
566
+ "(270, 270, 28)"
567
+ ]
568
+ },
569
+ "execution_count": 19,
570
+ "metadata": {},
571
+ "output_type": "execute_result"
572
+ }
573
+ ],
574
+ "source": [
575
+ "dataset_image[0]['image'][0].numpy().shape"
576
  ]
577
  },
578
  {
579
  "cell_type": "code",
580
  "execution_count": null,
581
+ "id": "56072a2b",
582
  "metadata": {},
583
  "outputs": [],
584
  "source": []