Spaces:
Runtime error
Runtime error
Anirudh Balaraman commited on
Commit ·
a4ef78c
1
Parent(s): 7a1d40a
Add salient patch info
Browse files- run_inference.py +36 -5
- src/utils.py +130 -3
- temp.ipynb +419 -173
run_inference.py
CHANGED
|
@@ -34,7 +34,7 @@ from pathlib import Path
|
|
| 34 |
from src.model.MIL import MILModel_3D
|
| 35 |
from src.model.csPCa_model import csPCa_Model
|
| 36 |
from src.data.data_loader import get_dataloader
|
| 37 |
-
from src.utils import save_cspca_checkpoint, get_metrics, setup_logging, save_pirads_checkpoint
|
| 38 |
from src.train import train_cspca, train_pirads
|
| 39 |
import SimpleITK as sitk
|
| 40 |
|
|
@@ -57,6 +57,7 @@ import argparse
|
|
| 57 |
import yaml
|
| 58 |
from src.data.data_loader import data_transform, list_data_collate
|
| 59 |
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
|
|
|
| 60 |
|
| 61 |
def parse_args():
|
| 62 |
|
|
@@ -120,7 +121,7 @@ if __name__ == "__main__":
|
|
| 120 |
|
| 121 |
transform = data_transform(args)
|
| 122 |
files = os.listdir(args.t2_dir)
|
| 123 |
-
data_list = []
|
| 124 |
for file in files:
|
| 125 |
temp = {}
|
| 126 |
temp['image'] = os.path.join(args.t2_dir, file)
|
|
@@ -129,9 +130,9 @@ if __name__ == "__main__":
|
|
| 129 |
temp['heatmap'] = os.path.join(args.heatmapdir, file)
|
| 130 |
temp['mask'] = os.path.join(args.seg_dir, file)
|
| 131 |
temp['label'] = 0 # dummy label
|
| 132 |
-
data_list.append(temp)
|
| 133 |
|
| 134 |
-
dataset = Dataset(data=data_list, transform=transform)
|
| 135 |
loader = torch.utils.data.DataLoader(
|
| 136 |
dataset,
|
| 137 |
batch_size=1,
|
|
@@ -147,6 +148,7 @@ if __name__ == "__main__":
|
|
| 147 |
pirads_model.eval()
|
| 148 |
cspca_risk_list = []
|
| 149 |
cspca_model.eval()
|
|
|
|
| 150 |
with torch.no_grad():
|
| 151 |
for idx, batch_data in enumerate(loader):
|
| 152 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
|
@@ -158,5 +160,34 @@ if __name__ == "__main__":
|
|
| 158 |
output = output.squeeze(1)
|
| 159 |
cspca_risk_list.append(output.item())
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
for i,j in enumerate(files):
|
| 162 |
-
logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
from src.model.MIL import MILModel_3D
|
| 35 |
from src.model.csPCa_model import csPCa_Model
|
| 36 |
from src.data.data_loader import get_dataloader
|
| 37 |
+
from src.utils import save_cspca_checkpoint, get_metrics, setup_logging, save_pirads_checkpoint, get_parent_image, get_patch_coordinate
|
| 38 |
from src.train import train_cspca, train_pirads
|
| 39 |
import SimpleITK as sitk
|
| 40 |
|
|
|
|
| 57 |
import yaml
|
| 58 |
from src.data.data_loader import data_transform, list_data_collate
|
| 59 |
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
| 60 |
+
import json
|
| 61 |
|
| 62 |
def parse_args():
|
| 63 |
|
|
|
|
| 121 |
|
| 122 |
transform = data_transform(args)
|
| 123 |
files = os.listdir(args.t2_dir)
|
| 124 |
+
args.data_list = []
|
| 125 |
for file in files:
|
| 126 |
temp = {}
|
| 127 |
temp['image'] = os.path.join(args.t2_dir, file)
|
|
|
|
| 130 |
temp['heatmap'] = os.path.join(args.heatmapdir, file)
|
| 131 |
temp['mask'] = os.path.join(args.seg_dir, file)
|
| 132 |
temp['label'] = 0 # dummy label
|
| 133 |
+
args.data_list.append(temp)
|
| 134 |
|
| 135 |
+
dataset = Dataset(data=args.data_list, transform=transform)
|
| 136 |
loader = torch.utils.data.DataLoader(
|
| 137 |
dataset,
|
| 138 |
batch_size=1,
|
|
|
|
| 148 |
pirads_model.eval()
|
| 149 |
cspca_risk_list = []
|
| 150 |
cspca_model.eval()
|
| 151 |
+
top5_patches = []
|
| 152 |
with torch.no_grad():
|
| 153 |
for idx, batch_data in enumerate(loader):
|
| 154 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
|
|
|
| 160 |
output = output.squeeze(1)
|
| 161 |
cspca_risk_list.append(output.item())
|
| 162 |
|
| 163 |
+
sh = data.shape
|
| 164 |
+
x = data.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5])
|
| 165 |
+
x = cspca_model.backbone.net(x)
|
| 166 |
+
x = x.reshape(sh[0], sh[1], -1)
|
| 167 |
+
x = x.permute(1, 0, 2)
|
| 168 |
+
x = cspca_model.backbone.transformer(x)
|
| 169 |
+
x = x.permute(1, 0, 2)
|
| 170 |
+
a = cspca_model.backbone.attention(x)
|
| 171 |
+
a = torch.softmax(a, dim=1)
|
| 172 |
+
a = a.view(-1)
|
| 173 |
+
top5_values, top5_indices = torch.topk(a, 5)
|
| 174 |
+
|
| 175 |
+
patches_top_5 = []
|
| 176 |
+
for i in range(5):
|
| 177 |
+
patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()
|
| 178 |
+
patches_top_5.append(patch_temp)
|
| 179 |
+
|
| 180 |
+
parent_image = get_parent_image(args)
|
| 181 |
+
|
| 182 |
+
coords = get_patch_coordinate(patches_top_5, parent_image, args)
|
| 183 |
+
|
| 184 |
for i,j in enumerate(files):
|
| 185 |
+
logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
|
| 186 |
+
|
| 187 |
+
output_dict = {
|
| 188 |
+
'Predicted PIRAD Score': pirads_list[i] + 2.0,
|
| 189 |
+
'csPCa risk': cspca_risk_list[i],
|
| 190 |
+
'Top left coordinate of top 5 patches(x,y,z)': coords,
|
| 191 |
+
}
|
| 192 |
+
with open(os.path.join(args.output_dir, "results.json"), 'w') as f:
|
| 193 |
+
json.dump(output_dict, f, indent=4)
|
src/utils.py
CHANGED
|
@@ -14,19 +14,52 @@ import torch.nn.functional as F
|
|
| 14 |
from monai.config import KeysCollection
|
| 15 |
from monai.metrics import Cumulative, CumulativeAverage
|
| 16 |
from monai.networks.nets import milmodel, resnet, MILModel
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from sklearn.metrics import cohen_kappa_score
|
| 19 |
from torch.cuda.amp import GradScaler, autocast
|
| 20 |
from torch.utils.data.dataloader import default_collate
|
| 21 |
from torchvision.models.resnet import ResNet50_Weights
|
| 22 |
-
|
| 23 |
from torch.utils.data.distributed import DistributedSampler
|
| 24 |
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
| 25 |
|
| 26 |
import matplotlib.pyplot as plt
|
| 27 |
|
| 28 |
import wandb
|
| 29 |
import math
|
|
|
|
| 30 |
|
| 31 |
from src.model.MIL import MILModel_3D
|
| 32 |
from src.model.csPCa_model import csPCa_Model
|
|
@@ -96,4 +129,98 @@ def validate_steps(steps):
|
|
| 96 |
f"Step '{step}' requires '{req}' to be executed before it. "
|
| 97 |
f"Given order: {steps}"
|
| 98 |
)
|
| 99 |
-
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from monai.config import KeysCollection
|
| 15 |
from monai.metrics import Cumulative, CumulativeAverage
|
| 16 |
from monai.networks.nets import milmodel, resnet, MILModel
|
| 17 |
+
from monai.transforms import (
|
| 18 |
+
Compose,
|
| 19 |
+
GridPatchd,
|
| 20 |
+
LoadImaged,
|
| 21 |
+
MapTransform,
|
| 22 |
+
RandFlipd,
|
| 23 |
+
RandGridPatchd,
|
| 24 |
+
RandRotate90d,
|
| 25 |
+
ScaleIntensityRanged,
|
| 26 |
+
SplitDimd,
|
| 27 |
+
ToTensord,
|
| 28 |
+
ConcatItemsd,
|
| 29 |
+
SelectItemsd,
|
| 30 |
+
EnsureChannelFirstd,
|
| 31 |
+
RepeatChanneld,
|
| 32 |
+
DeleteItemsd,
|
| 33 |
+
EnsureTyped,
|
| 34 |
+
ClipIntensityPercentilesd,
|
| 35 |
+
MaskIntensityd,
|
| 36 |
+
HistogramNormalized,
|
| 37 |
+
RandBiasFieldd,
|
| 38 |
+
RandCropByPosNegLabeld,
|
| 39 |
+
NormalizeIntensityd,
|
| 40 |
+
SqueezeDimd,
|
| 41 |
+
CropForegroundd,
|
| 42 |
+
ScaleIntensityd,
|
| 43 |
+
SpatialPadd,
|
| 44 |
+
CenterSpatialCropd,
|
| 45 |
+
ScaleIntensityd,
|
| 46 |
+
Transposed,
|
| 47 |
+
RandWeightedCropd,
|
| 48 |
+
)
|
| 49 |
from sklearn.metrics import cohen_kappa_score
|
| 50 |
from torch.cuda.amp import GradScaler, autocast
|
| 51 |
from torch.utils.data.dataloader import default_collate
|
| 52 |
from torchvision.models.resnet import ResNet50_Weights
|
| 53 |
+
from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
|
| 54 |
from torch.utils.data.distributed import DistributedSampler
|
| 55 |
from torch.utils.tensorboard import SummaryWriter
|
| 56 |
+
import matplotlib.patches as patches
|
| 57 |
|
| 58 |
import matplotlib.pyplot as plt
|
| 59 |
|
| 60 |
import wandb
|
| 61 |
import math
|
| 62 |
+
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
| 63 |
|
| 64 |
from src.model.MIL import MILModel_3D
|
| 65 |
from src.model.csPCa_model import csPCa_Model
|
|
|
|
| 129 |
f"Step '{step}' requires '{req}' to be executed before it. "
|
| 130 |
f"Given order: {steps}"
|
| 131 |
)
|
| 132 |
+
sys.exit(1)
|
| 133 |
+
|
| 134 |
+
def get_patch_coordinate(patches_top_5, parent_image, args):
|
| 135 |
+
|
| 136 |
+
sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
|
| 137 |
+
coords = []
|
| 138 |
+
rows, h, w, slices = sample.shape
|
| 139 |
+
|
| 140 |
+
for i in range(rows):
|
| 141 |
+
for j in range(slices):
|
| 142 |
+
if j == 0:
|
| 143 |
+
for k in range(parent_image.shape[2]):
|
| 144 |
+
img_temp = parent_image[:, :, k]
|
| 145 |
+
H, W = img_temp.shape
|
| 146 |
+
h, w = sample[i, :, :, j].shape
|
| 147 |
+
a,b = 0, 0 # Initialize a and b
|
| 148 |
+
bool1 = False
|
| 149 |
+
for l in range(H - h + 1):
|
| 150 |
+
for m in range(W - w + 1):
|
| 151 |
+
if np.array_equal(img_temp[l:l+h, m:m+w], sample[i, :, :, j]):
|
| 152 |
+
a,b = l, m # top-left corner
|
| 153 |
+
coords.append((a,b,k))
|
| 154 |
+
bool1 = True
|
| 155 |
+
break
|
| 156 |
+
if bool1:
|
| 157 |
+
break
|
| 158 |
+
|
| 159 |
+
if bool1:
|
| 160 |
+
break
|
| 161 |
+
|
| 162 |
+
return coords
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def get_parent_image(args):
|
| 166 |
+
transform_image = Compose(
|
| 167 |
+
[
|
| 168 |
+
LoadImaged(keys=["image", "mask"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
|
| 169 |
+
ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
|
| 170 |
+
NormalizeIntensity_customd(keys=["image"], mask_key="mask", channel_wise=True),
|
| 171 |
+
EnsureTyped(keys=["label"], dtype=torch.float32),
|
| 172 |
+
ToTensord(keys=["image", "label"]),
|
| 173 |
+
]
|
| 174 |
+
)
|
| 175 |
+
dataset_image = Dataset(data=args.data_list, transform=transform_image)
|
| 176 |
+
return dataset_image[0]['image'][0].numpy()
|
| 177 |
+
|
| 178 |
+
'''
|
| 179 |
+
def visualise_patches():
|
| 180 |
+
sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
|
| 181 |
+
rows = len(patches_top_5)
|
| 182 |
+
img = sample[0]
|
| 183 |
+
coords = []
|
| 184 |
+
rows, h, w, slices = sample.shape
|
| 185 |
+
|
| 186 |
+
fig, axes = plt.subplots(nrows=rows, ncols=slices, figsize=(slices * 3, rows * 3))
|
| 187 |
+
|
| 188 |
+
for i in range(rows):
|
| 189 |
+
for j in range(slices):
|
| 190 |
+
ax = axes[i, j]
|
| 191 |
+
|
| 192 |
+
if j == 0:
|
| 193 |
+
|
| 194 |
+
for k in range(parent_image.shape[2]):
|
| 195 |
+
img_temp = parent_image[:, :, k]
|
| 196 |
+
H, W = img_temp.shape
|
| 197 |
+
h, w = sample[i, :, :, j].shape
|
| 198 |
+
a,b = 0, 0 # Initialize a and b
|
| 199 |
+
bool1 = False
|
| 200 |
+
for l in range(H - h + 1):
|
| 201 |
+
for m in range(W - w + 1):
|
| 202 |
+
if np.array_equal(img_temp[l:l+h, m:m+w], sample[i, :, :, j]):
|
| 203 |
+
a,b = l, m # top-left corner
|
| 204 |
+
coords.append((a,b,k))
|
| 205 |
+
bool1 = True
|
| 206 |
+
break
|
| 207 |
+
if bool1:
|
| 208 |
+
break
|
| 209 |
+
|
| 210 |
+
if bool1:
|
| 211 |
+
break
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
ax.imshow(parent_image[:, :, k+j], cmap='gray')
|
| 217 |
+
rect = patches.Rectangle((b, a), args.tile_size, args.tile_size,
|
| 218 |
+
linewidth=2, edgecolor='red', facecolor='none')
|
| 219 |
+
ax.add_patch(rect)
|
| 220 |
+
ax.axis('off')
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
plt.tight_layout()
|
| 224 |
+
plt.show()
|
| 225 |
+
a=1
|
| 226 |
+
'''
|
temp.ipynb
CHANGED
|
@@ -2,10 +2,19 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "cec738fb",
|
| 7 |
"metadata": {},
|
| 8 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
"source": [
|
| 10 |
"import argparse\n",
|
| 11 |
"import os\n",
|
|
@@ -70,7 +79,7 @@
|
|
| 70 |
},
|
| 71 |
{
|
| 72 |
"cell_type": "code",
|
| 73 |
-
"execution_count":
|
| 74 |
"id": "c91a5802",
|
| 75 |
"metadata": {},
|
| 76 |
"outputs": [
|
|
@@ -78,9 +87,9 @@
|
|
| 78 |
"name": "stderr",
|
| 79 |
"output_type": "stream",
|
| 80 |
"text": [
|
| 81 |
-
" 0%| | 0/1 [00:
|
| 82 |
"You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
| 83 |
-
" 0%| | 0/1 [00:
|
| 84 |
]
|
| 85 |
}
|
| 86 |
],
|
|
@@ -104,36 +113,23 @@
|
|
| 104 |
},
|
| 105 |
{
|
| 106 |
"cell_type": "code",
|
| 107 |
-
"execution_count":
|
| 108 |
-
"id": "
|
| 109 |
"metadata": {},
|
| 110 |
-
"outputs": [
|
| 111 |
-
{
|
| 112 |
-
"data": {
|
| 113 |
-
"text/plain": [
|
| 114 |
-
"{'margin': 0.2,\n",
|
| 115 |
-
" 't2_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/t2_histmatched',\n",
|
| 116 |
-
" 'dwi_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/DWI_histmatched',\n",
|
| 117 |
-
" 'adc_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/ADC_histmatched',\n",
|
| 118 |
-
" 'output_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed',\n",
|
| 119 |
-
" 'project_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate',\n",
|
| 120 |
-
" 'seg_dir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/prostate_mask',\n",
|
| 121 |
-
" 'heatmapdir': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/heatmaps/'}"
|
| 122 |
-
]
|
| 123 |
-
},
|
| 124 |
-
"execution_count": 6,
|
| 125 |
-
"metadata": {},
|
| 126 |
-
"output_type": "execute_result"
|
| 127 |
-
}
|
| 128 |
-
],
|
| 129 |
"source": [
|
| 130 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
]
|
| 132 |
},
|
| 133 |
{
|
| 134 |
"cell_type": "code",
|
| 135 |
-
"execution_count":
|
| 136 |
-
"id": "
|
| 137 |
"metadata": {},
|
| 138 |
"outputs": [
|
| 139 |
{
|
|
@@ -141,155 +137,23 @@
|
|
| 141 |
"output_type": "stream",
|
| 142 |
"text": [
|
| 143 |
"enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
|
|
|
|
| 144 |
"You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n"
|
| 145 |
]
|
| 146 |
-
},
|
| 147 |
-
{
|
| 148 |
-
"data": {
|
| 149 |
-
"text/plain": [
|
| 150 |
-
"MILModel_3D(\n",
|
| 151 |
-
" (attention): Sequential(\n",
|
| 152 |
-
" (0): Linear(in_features=512, out_features=2048, bias=True)\n",
|
| 153 |
-
" (1): Tanh()\n",
|
| 154 |
-
" (2): Linear(in_features=2048, out_features=1, bias=True)\n",
|
| 155 |
-
" )\n",
|
| 156 |
-
" (transformer): TransformerEncoder(\n",
|
| 157 |
-
" (layers): ModuleList(\n",
|
| 158 |
-
" (0-3): 4 x TransformerEncoderLayer(\n",
|
| 159 |
-
" (self_attn): MultiheadAttention(\n",
|
| 160 |
-
" (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n",
|
| 161 |
-
" )\n",
|
| 162 |
-
" (linear1): Linear(in_features=512, out_features=2048, bias=True)\n",
|
| 163 |
-
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
| 164 |
-
" (linear2): Linear(in_features=2048, out_features=512, bias=True)\n",
|
| 165 |
-
" (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
|
| 166 |
-
" (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
|
| 167 |
-
" (dropout1): Dropout(p=0.0, inplace=False)\n",
|
| 168 |
-
" (dropout2): Dropout(p=0.0, inplace=False)\n",
|
| 169 |
-
" )\n",
|
| 170 |
-
" )\n",
|
| 171 |
-
" )\n",
|
| 172 |
-
" (myfc): Linear(in_features=512, out_features=4, bias=True)\n",
|
| 173 |
-
" (net): ResNet(\n",
|
| 174 |
-
" (conv1): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)\n",
|
| 175 |
-
" (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 176 |
-
" (act): ReLU(inplace=True)\n",
|
| 177 |
-
" (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
| 178 |
-
" (layer1): Sequential(\n",
|
| 179 |
-
" (0): ResNetBlock(\n",
|
| 180 |
-
" (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 181 |
-
" (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 182 |
-
" (act): ReLU(inplace=True)\n",
|
| 183 |
-
" (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 184 |
-
" (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 185 |
-
" )\n",
|
| 186 |
-
" (1): ResNetBlock(\n",
|
| 187 |
-
" (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 188 |
-
" (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 189 |
-
" (act): ReLU(inplace=True)\n",
|
| 190 |
-
" (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 191 |
-
" (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 192 |
-
" )\n",
|
| 193 |
-
" )\n",
|
| 194 |
-
" (layer2): Sequential(\n",
|
| 195 |
-
" (0): ResNetBlock(\n",
|
| 196 |
-
" (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)\n",
|
| 197 |
-
" (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 198 |
-
" (act): ReLU(inplace=True)\n",
|
| 199 |
-
" (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 200 |
-
" (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 201 |
-
" (downsample): Sequential(\n",
|
| 202 |
-
" (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(2, 2, 2))\n",
|
| 203 |
-
" (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 204 |
-
" )\n",
|
| 205 |
-
" )\n",
|
| 206 |
-
" (1): ResNetBlock(\n",
|
| 207 |
-
" (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 208 |
-
" (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 209 |
-
" (act): ReLU(inplace=True)\n",
|
| 210 |
-
" (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 211 |
-
" (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 212 |
-
" )\n",
|
| 213 |
-
" )\n",
|
| 214 |
-
" (layer3): Sequential(\n",
|
| 215 |
-
" (0): ResNetBlock(\n",
|
| 216 |
-
" (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)\n",
|
| 217 |
-
" (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 218 |
-
" (act): ReLU(inplace=True)\n",
|
| 219 |
-
" (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 220 |
-
" (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 221 |
-
" (downsample): Sequential(\n",
|
| 222 |
-
" (0): Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(2, 2, 2))\n",
|
| 223 |
-
" (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 224 |
-
" )\n",
|
| 225 |
-
" )\n",
|
| 226 |
-
" (1): ResNetBlock(\n",
|
| 227 |
-
" (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 228 |
-
" (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 229 |
-
" (act): ReLU(inplace=True)\n",
|
| 230 |
-
" (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 231 |
-
" (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 232 |
-
" )\n",
|
| 233 |
-
" )\n",
|
| 234 |
-
" (layer4): Sequential(\n",
|
| 235 |
-
" (0): ResNetBlock(\n",
|
| 236 |
-
" (conv1): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)\n",
|
| 237 |
-
" (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 238 |
-
" (act): ReLU(inplace=True)\n",
|
| 239 |
-
" (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 240 |
-
" (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 241 |
-
" (downsample): Sequential(\n",
|
| 242 |
-
" (0): Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(2, 2, 2))\n",
|
| 243 |
-
" (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 244 |
-
" )\n",
|
| 245 |
-
" )\n",
|
| 246 |
-
" (1): ResNetBlock(\n",
|
| 247 |
-
" (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 248 |
-
" (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 249 |
-
" (act): ReLU(inplace=True)\n",
|
| 250 |
-
" (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)\n",
|
| 251 |
-
" (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 252 |
-
" )\n",
|
| 253 |
-
" )\n",
|
| 254 |
-
" (avgpool): AdaptiveAvgPool3d(output_size=(1, 1, 1))\n",
|
| 255 |
-
" (fc): Identity()\n",
|
| 256 |
-
" )\n",
|
| 257 |
-
")"
|
| 258 |
-
]
|
| 259 |
-
},
|
| 260 |
-
"execution_count": 8,
|
| 261 |
-
"metadata": {},
|
| 262 |
-
"output_type": "execute_result"
|
| 263 |
}
|
| 264 |
],
|
| 265 |
"source": [
|
| 266 |
"args.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 267 |
-
"
|
| 268 |
-
"
|
| 269 |
"pirads_model = MILModel_3D(\n",
|
| 270 |
" num_classes=args.num_classes, \n",
|
| 271 |
" mil_mode=args.mil_mode \n",
|
| 272 |
")\n",
|
| 273 |
"pirads_checkpoint = torch.load(os.path.join(args.project_dir, 'models', 'pirads.pt'), map_location=\"cpu\")\n",
|
| 274 |
"pirads_model.load_state_dict(pirads_checkpoint[\"state_dict\"])\n",
|
| 275 |
-
"pirads_model.to(args.device)"
|
| 276 |
-
|
| 277 |
-
},
|
| 278 |
-
{
|
| 279 |
-
"cell_type": "code",
|
| 280 |
-
"execution_count": 31,
|
| 281 |
-
"id": "01467cae",
|
| 282 |
-
"metadata": {},
|
| 283 |
-
"outputs": [
|
| 284 |
-
{
|
| 285 |
-
"name": "stderr",
|
| 286 |
-
"output_type": "stream",
|
| 287 |
-
"text": [
|
| 288 |
-
"You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n"
|
| 289 |
-
]
|
| 290 |
-
}
|
| 291 |
-
],
|
| 292 |
-
"source": [
|
| 293 |
"cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)\n",
|
| 294 |
"checkpt = torch.load(os.path.join(args.project_dir, 'models', 'cspca_model.pth'), map_location=\"cpu\")\n",
|
| 295 |
"cspca_model.load_state_dict(checkpt['state_dict'])\n",
|
|
@@ -298,18 +162,41 @@
|
|
| 298 |
},
|
| 299 |
{
|
| 300 |
"cell_type": "code",
|
| 301 |
-
"execution_count":
|
| 302 |
-
"id": "
|
| 303 |
"metadata": {},
|
| 304 |
"outputs": [],
|
| 305 |
"source": [
|
| 306 |
-
"\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
]
|
| 308 |
},
|
| 309 |
{
|
| 310 |
"cell_type": "code",
|
| 311 |
-
"execution_count":
|
| 312 |
-
"id": "
|
| 313 |
"metadata": {},
|
| 314 |
"outputs": [],
|
| 315 |
"source": [
|
|
@@ -317,6 +204,7 @@
|
|
| 317 |
"pirads_model.eval()\n",
|
| 318 |
"cspca_risk_list = []\n",
|
| 319 |
"cspca_model.eval()\n",
|
|
|
|
| 320 |
"with torch.no_grad():\n",
|
| 321 |
" for idx, batch_data in enumerate(loader):\n",
|
| 322 |
" data = batch_data[\"image\"].as_subclass(torch.Tensor).to(args.device)\n",
|
|
@@ -326,13 +214,371 @@
|
|
| 326 |
"\n",
|
| 327 |
" output = cspca_model(data)\n",
|
| 328 |
" output = output.squeeze(1)\n",
|
| 329 |
-
" cspca_risk_list.append(output.item())"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
]
|
| 331 |
},
|
| 332 |
{
|
| 333 |
"cell_type": "code",
|
| 334 |
"execution_count": null,
|
| 335 |
-
"id": "
|
| 336 |
"metadata": {},
|
| 337 |
"outputs": [],
|
| 338 |
"source": []
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
"id": "cec738fb",
|
| 7 |
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"If you have questions or suggestions, feel free to open an issue at https://github.com/DIAGNijmegen/picai_prep\n",
|
| 14 |
+
"\n"
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
"source": [
|
| 19 |
"import argparse\n",
|
| 20 |
"import os\n",
|
|
|
|
| 79 |
},
|
| 80 |
{
|
| 81 |
"cell_type": "code",
|
| 82 |
+
"execution_count": 2,
|
| 83 |
"id": "c91a5802",
|
| 84 |
"metadata": {},
|
| 85 |
"outputs": [
|
|
|
|
| 87 |
"name": "stderr",
|
| 88 |
"output_type": "stream",
|
| 89 |
"text": [
|
| 90 |
+
" 0%| | 0/1 [00:03<?, ?it/s]\n",
|
| 91 |
"You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
| 92 |
+
" 0%| | 0/1 [00:05<?, ?it/s]\n"
|
| 93 |
]
|
| 94 |
}
|
| 95 |
],
|
|
|
|
| 113 |
},
|
| 114 |
{
|
| 115 |
"cell_type": "code",
|
| 116 |
+
"execution_count": 4,
|
| 117 |
+
"id": "8b5d382e",
|
| 118 |
"metadata": {},
|
| 119 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
"source": [
|
| 121 |
+
"args.num_classes = 4\n",
|
| 122 |
+
"args.mil_mode = \"att_trans\"\n",
|
| 123 |
+
"args.use_heatmap = True\n",
|
| 124 |
+
"args.tile_size = 64\n",
|
| 125 |
+
"args.tile_count = 24\n",
|
| 126 |
+
"args.depth = 3\n"
|
| 127 |
]
|
| 128 |
},
|
| 129 |
{
|
| 130 |
"cell_type": "code",
|
| 131 |
+
"execution_count": 5,
|
| 132 |
+
"id": "4cf061ec",
|
| 133 |
"metadata": {},
|
| 134 |
"outputs": [
|
| 135 |
{
|
|
|
|
| 137 |
"output_type": "stream",
|
| 138 |
"text": [
|
| 139 |
"enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
|
| 140 |
+
"You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
| 141 |
"You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n"
|
| 142 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
}
|
| 144 |
],
|
| 145 |
"source": [
|
| 146 |
"args.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"\n",
|
| 149 |
"pirads_model = MILModel_3D(\n",
|
| 150 |
" num_classes=args.num_classes, \n",
|
| 151 |
" mil_mode=args.mil_mode \n",
|
| 152 |
")\n",
|
| 153 |
"pirads_checkpoint = torch.load(os.path.join(args.project_dir, 'models', 'pirads.pt'), map_location=\"cpu\")\n",
|
| 154 |
"pirads_model.load_state_dict(pirads_checkpoint[\"state_dict\"])\n",
|
| 155 |
+
"pirads_model.to(args.device)\n",
|
| 156 |
+
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
"cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)\n",
|
| 158 |
"checkpt = torch.load(os.path.join(args.project_dir, 'models', 'cspca_model.pth'), map_location=\"cpu\")\n",
|
| 159 |
"cspca_model.load_state_dict(checkpt['state_dict'])\n",
|
|
|
|
| 162 |
},
|
| 163 |
{
|
| 164 |
"cell_type": "code",
|
| 165 |
+
"execution_count": 6,
|
| 166 |
+
"id": "fac15515",
|
| 167 |
"metadata": {},
|
| 168 |
"outputs": [],
|
| 169 |
"source": [
|
| 170 |
+
"transform = data_transform(args)\n",
|
| 171 |
+
"files = os.listdir(args.t2_dir)\n",
|
| 172 |
+
"data_list = []\n",
|
| 173 |
+
"for file in files:\n",
|
| 174 |
+
" temp = {}\n",
|
| 175 |
+
" temp['image'] = os.path.join(args.t2_dir, file)\n",
|
| 176 |
+
" temp['dwi'] = os.path.join(args.dwi_dir, file)\n",
|
| 177 |
+
" temp['adc'] = os.path.join(args.adc_dir, file)\n",
|
| 178 |
+
" temp['heatmap'] = os.path.join(args.heatmapdir, file)\n",
|
| 179 |
+
" temp['mask'] = os.path.join(args.seg_dir, file)\n",
|
| 180 |
+
" temp['label'] = 0 # dummy label\n",
|
| 181 |
+
" data_list.append(temp)\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"dataset = Dataset(data=data_list, transform=transform)\n",
|
| 184 |
+
"loader = torch.utils.data.DataLoader(\n",
|
| 185 |
+
" dataset,\n",
|
| 186 |
+
" batch_size=1,\n",
|
| 187 |
+
" shuffle=False,\n",
|
| 188 |
+
" num_workers=0,\n",
|
| 189 |
+
" pin_memory=True,\n",
|
| 190 |
+
" multiprocessing_context= None,\n",
|
| 191 |
+
" sampler=None,\n",
|
| 192 |
+
" collate_fn=list_data_collate,\n",
|
| 193 |
+
")"
|
| 194 |
]
|
| 195 |
},
|
| 196 |
{
|
| 197 |
"cell_type": "code",
|
| 198 |
+
"execution_count": 7,
|
| 199 |
+
"id": "eb80047b",
|
| 200 |
"metadata": {},
|
| 201 |
"outputs": [],
|
| 202 |
"source": [
|
|
|
|
| 204 |
"pirads_model.eval()\n",
|
| 205 |
"cspca_risk_list = []\n",
|
| 206 |
"cspca_model.eval()\n",
|
| 207 |
+
"top5_patches = []\n",
|
| 208 |
"with torch.no_grad():\n",
|
| 209 |
" for idx, batch_data in enumerate(loader):\n",
|
| 210 |
" data = batch_data[\"image\"].as_subclass(torch.Tensor).to(args.device)\n",
|
|
|
|
| 214 |
"\n",
|
| 215 |
" output = cspca_model(data)\n",
|
| 216 |
" output = output.squeeze(1)\n",
|
| 217 |
+
" cspca_risk_list.append(output.item())\n",
|
| 218 |
+
"\n",
|
| 219 |
+
" sh = data.shape\n",
|
| 220 |
+
" x = data.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5])\n",
|
| 221 |
+
" x = cspca_model.backbone.net(x)\n",
|
| 222 |
+
" x = x.reshape(sh[0], sh[1], -1)\n",
|
| 223 |
+
" x = x.permute(1, 0, 2)\n",
|
| 224 |
+
" x = cspca_model.backbone.transformer(x)\n",
|
| 225 |
+
" x = x.permute(1, 0, 2)\n",
|
| 226 |
+
" a = cspca_model.backbone.attention(x)\n",
|
| 227 |
+
" a = torch.softmax(a, dim=1)\n",
|
| 228 |
+
" a = a.view(-1)\n",
|
| 229 |
+
" top5_values, top5_indices = torch.topk(a, 5)\n",
|
| 230 |
+
"\n",
|
| 231 |
+
" patches_top_5 = []\n",
|
| 232 |
+
" for i in range(5):\n",
|
| 233 |
+
" patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()\n",
|
| 234 |
+
" patches_top_5.append(patch_temp)"
|
| 235 |
+
]
|
| 236 |
+
},
|
| 237 |
+
{
|
| 238 |
+
"cell_type": "code",
|
| 239 |
+
"execution_count": 8,
|
| 240 |
+
"id": "dbcfc97f",
|
| 241 |
+
"metadata": {},
|
| 242 |
+
"outputs": [
|
| 243 |
+
{
|
| 244 |
+
"data": {
|
| 245 |
+
"text/plain": [
|
| 246 |
+
"[array([[[-0.43333182, -0.52243334, -1.4245864 , ..., -1.7698549 ,\n",
|
| 247 |
+
" -1.6027895 , -1.257521 ],\n",
|
| 248 |
+
" [-0.7786003 , -1.2797965 , -2.0817103 , ..., -2.0371597 ,\n",
|
| 249 |
+
" -2.0817103 , -2.059435 ],\n",
|
| 250 |
+
" [-1.4802749 , -1.658478 , -1.7809926 , ..., -1.4691372 ,\n",
|
| 251 |
+
" -2.1596742 , -1.9591957 ],\n",
|
| 252 |
+
" ...,\n",
|
| 253 |
+
" [ 0.8697783 , 0.34630668, 0.6358867 , ..., 0.624749 ,\n",
|
| 254 |
+
" 0.24606745, 0.03445129],\n",
|
| 255 |
+
" [ 0.44654593, 0.23492976, 0.7806767 , ..., 1.2373221 ,\n",
|
| 256 |
+
" 0.2906182 , -1.0124918 ],\n",
|
| 257 |
+
" [ 0.3017559 , -0.2551287 , 0.27948052, ..., 1.1036698 ,\n",
|
| 258 |
+
" -0.16602719, -1.2463834 ]],\n",
|
| 259 |
+
" \n",
|
| 260 |
+
" [[-0.84542644, -1.3800358 , -1.4357241 , ..., -2.5383558 ,\n",
|
| 261 |
+
" -2.0817103 , -0.99021643],\n",
|
| 262 |
+
" [-0.4444695 , -1.3466226 , -1.8255434 , ..., -1.335485 ,\n",
|
| 263 |
+
" -1.6362027 , -1.4579996 ],\n",
|
| 264 |
+
" [-1.2129703 , -1.7921304 , -2.0371597 , ..., 0.14582822,\n",
|
| 265 |
+
" -0.04351256, -0.82315105],\n",
|
| 266 |
+
" ...,\n",
|
| 267 |
+
" [-0.4556072 , 1.4378005 , 1.5603153 , ..., 0.2906182 ,\n",
|
| 268 |
+
" -0.85656416, -1.1906949 ],\n",
|
| 269 |
+
" [-1.2352457 , -0.35536796, 1.1816336 , ..., -0.39991874,\n",
|
| 270 |
+
" -1.3800358 , -1.2129703 ],\n",
|
| 271 |
+
" [-1.2909342 , -1.1127311 , 0.5467852 , ..., -0.9568034 ,\n",
|
| 272 |
+
" -1.4023111 , -0.9011149 ]],\n",
|
| 273 |
+
" \n",
|
| 274 |
+
" [[-2.1596742 , -1.914645 , -1.7809926 , ..., -1.0124918 ,\n",
|
| 275 |
+
" 0.03445129, 0.44654593],\n",
|
| 276 |
+
" [-2.0037465 , -1.8923696 , -1.8700942 , ..., -1.2129703 ,\n",
|
| 277 |
+
" -1.1572819 , -0.50015795],\n",
|
| 278 |
+
" [-1.8478189 , -1.9369203 , -1.9814711 , ..., -0.65608567,\n",
|
| 279 |
+
" -1.2352457 , -1.4134488 ],\n",
|
| 280 |
+
" ...,\n",
|
| 281 |
+
" [-0.6338103 , -1.0124918 , -0.16602719, ..., 0.41313285,\n",
|
| 282 |
+
" 0.13469052, -0.80087566],\n",
|
| 283 |
+
" [-0.7451872 , -1.3577603 , -0.2996795 , ..., 0.34630668,\n",
|
| 284 |
+
" 0.68043745, 0.45768362],\n",
|
| 285 |
+
" [ 0.0121759 , -0.7674626 , -0.33309257, ..., 0.09013975,\n",
|
| 286 |
+
" 0.46882132, 1.0034306 ]]], dtype=float32),\n",
|
| 287 |
+
" array([[[-0.64494795, -0.789738 , -0.5447087 , ..., 0.0233136 ,\n",
|
| 288 |
+
" 0.14582822, -0.35536796],\n",
|
| 289 |
+
" [-0.8120134 , -0.7340495 , -0.42219412, ..., -0.31081718,\n",
|
| 290 |
+
" 0.46882132, 0.15696591],\n",
|
| 291 |
+
" [-0.08806333, 0.03445129, 0.12355283, ..., -0.2774041 ,\n",
|
| 292 |
+
" 0.73612595, 0.7249882 ],\n",
|
| 293 |
+
" ...,\n",
|
| 294 |
+
" [ 0.09013975, 0.0233136 , -0.24399103, ..., 0.34630668,\n",
|
| 295 |
+
" 0.914329 , 0.6358867 ],\n",
|
| 296 |
+
" [ 0.2906182 , 0.335169 , 0.624749 , ..., 0.24606745,\n",
|
| 297 |
+
" 0.9254667 , 0.9922929 ],\n",
|
| 298 |
+
" [ 0.44654593, 0.70271283, 1.0479814 , ..., -0.09920102,\n",
|
| 299 |
+
" 0.37971976, 0.70271283]],\n",
|
| 300 |
+
" \n",
|
| 301 |
+
" [[-0.23285334, -0.01009948, -0.1326141 , ..., 1.1593583 ,\n",
|
| 302 |
+
" 1.5603153 , 1.5603153 ],\n",
|
| 303 |
+
" [-0.23285334, 0.13469052, 0.1792413 , ..., 0.98115516,\n",
|
| 304 |
+
" 1.5603153 , 1.5603153 ],\n",
|
| 305 |
+
" [-0.36650565, -0.01009948, 0.190379 , ..., 0.8697783 ,\n",
|
| 306 |
+
" 1.5603153 , 1.5603153 ],\n",
|
| 307 |
+
" ...,\n",
|
| 308 |
+
" [-0.16602719, -0.06578795, 0.190379 , ..., -1.4357241 ,\n",
|
| 309 |
+
" -1.368898 , -1.6027895 ],\n",
|
| 310 |
+
" [-0.2662664 , -0.35536796, 0.190379 , ..., -1.513688 ,\n",
|
| 311 |
+
" -1.3800358 , -1.6362027 ],\n",
|
| 312 |
+
" [-0.789738 , -0.7786003 , -0.17716487, ..., -1.7141665 ,\n",
|
| 313 |
+
" -1.5693765 , -1.9591957 ]],\n",
|
| 314 |
+
" \n",
|
| 315 |
+
" [[-0.12147641, -0.01009948, 0.0233136 , ..., 0.769539 ,\n",
|
| 316 |
+
" 0.8140898 , 0.2906182 ],\n",
|
| 317 |
+
" [-0.35536796, -0.2774041 , -0.16602719, ..., 0.55792284,\n",
|
| 318 |
+
" 0.68043745, 0.335169 ],\n",
|
| 319 |
+
" [-0.18830256, 0.1681036 , 0.7918144 , ..., 0.70271283,\n",
|
| 320 |
+
" 0.4910967 , 0.10127745],\n",
|
| 321 |
+
" ...,\n",
|
| 322 |
+
" [-0.97907877, -0.97907877, -0.9011149 , ..., -0.43333182,\n",
|
| 323 |
+
" -0.5447087 , -0.1437518 ],\n",
|
| 324 |
+
" [-1.1238688 , -1.1461442 , -1.0347673 , ..., -0.7117741 ,\n",
|
| 325 |
+
" -0.12147641, 0.479959 ],\n",
|
| 326 |
+
" [-0.8788395 , -0.62267256, -0.9122526 , ..., -0.92339027,\n",
|
| 327 |
+
" -0.42219412, 0.22379206]]], dtype=float32),\n",
|
| 328 |
+
" array([[[ 2.7948052e-01, 1.0368437e+00, 1.5603153e+00, ...,\n",
|
| 329 |
+
" -1.2797965e+00, -8.3428878e-01, -2.4399103e-01],\n",
|
| 330 |
+
" [ 6.7864366e-02, 3.4630668e-01, 1.1704960e+00, ...,\n",
|
| 331 |
+
" -7.7860028e-01, -8.3428878e-01, -3.1081718e-01],\n",
|
| 332 |
+
" [ 2.6834285e-01, 9.0139754e-02, 2.3492976e-01, ...,\n",
|
| 333 |
+
" -6.3381028e-01, -9.5680338e-01, -4.4446951e-01],\n",
|
| 334 |
+
" ...,\n",
|
| 335 |
+
" [-1.3688980e+00, -1.2241080e+00, -9.4566566e-01, ...,\n",
|
| 336 |
+
" -1.0570426e+00, -2.6626641e-01, 1.2707351e+00],\n",
|
| 337 |
+
" [-1.0793180e+00, -8.8997722e-01, -1.1015934e+00, ...,\n",
|
| 338 |
+
" -7.6746261e-01, -5.4650255e-02, 1.0257059e+00],\n",
|
| 339 |
+
" [-7.3404950e-01, -6.5608567e-01, -1.0124918e+00, ...,\n",
|
| 340 |
+
" -5.3357106e-01, 2.6834285e-01, 8.5864055e-01]],\n",
|
| 341 |
+
" \n",
|
| 342 |
+
" [[ 1.1593583e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
|
| 343 |
+
" -4.5560721e-01, -1.1033872e-01, 8.6977828e-01],\n",
|
| 344 |
+
" [ 8.4750289e-01, 1.5603153e+00, 1.5603153e+00, ...,\n",
|
| 345 |
+
" -1.3020718e+00, -1.0459049e+00, 6.3588673e-01],\n",
|
| 346 |
+
" [ 7.9002060e-02, 1.3469052e-01, -3.1081718e-01, ...,\n",
|
| 347 |
+
" -1.3688980e+00, -1.0570426e+00, 4.9109671e-01],\n",
|
| 348 |
+
" ...,\n",
|
| 349 |
+
" [ 1.5603153e+00, 6.4702439e-01, -4.3512560e-02, ...,\n",
|
| 350 |
+
" 5.2450979e-01, 8.9205366e-01, -2.9967949e-01],\n",
|
| 351 |
+
" [ 1.2930106e+00, 6.1361134e-01, -4.6674490e-01, ...,\n",
|
| 352 |
+
" 2.7948052e-01, 8.4750289e-01, 7.8067672e-01],\n",
|
| 353 |
+
" [ 1.0382106e-03, 2.5720516e-01, -4.1105643e-01, ...,\n",
|
| 354 |
+
" 6.2474900e-01, 1.5603153e+00, 1.5603153e+00]],\n",
|
| 355 |
+
" \n",
|
| 356 |
+
" [[-9.2339027e-01, -1.0236295e+00, -1.0347673e+00, ...,\n",
|
| 357 |
+
" 2.6834285e-01, -1.9944026e-01, -1.1033872e-01],\n",
|
| 358 |
+
" [-8.2315105e-01, -1.1127311e+00, -9.7907877e-01, ...,\n",
|
| 359 |
+
" 1.2818729e+00, 3.2403129e-01, 6.7864366e-02],\n",
|
| 360 |
+
" [-8.0087566e-01, -9.7907877e-01, -8.0087566e-01, ...,\n",
|
| 361 |
+
" 1.5046268e+00, 6.1361134e-01, -1.7716487e-01],\n",
|
| 362 |
+
" ...,\n",
|
| 363 |
+
" [ 1.3932499e+00, 6.8043745e-01, 2.3492976e-01, ...,\n",
|
| 364 |
+
" -1.8144057e+00, -1.2129703e+00, 2.3313597e-02],\n",
|
| 365 |
+
" [ 1.5603153e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
|
| 366 |
+
" -1.7587173e+00, -1.1572819e+00, 5.6906056e-01],\n",
|
| 367 |
+
" [ 1.5603153e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
|
| 368 |
+
" -1.8366811e+00, -8.7883949e-01, 1.3598367e+00]]], dtype=float32),\n",
|
| 369 |
+
" array([[[-1.4245864 , -1.3466226 , -1.079318 , ..., -0.36650565,\n",
|
| 370 |
+
" -0.66722333, -1.079318 ],\n",
|
| 371 |
+
" [-1.1350064 , -1.2352457 , -1.4357241 , ..., -0.01009948,\n",
|
| 372 |
+
" -0.41105643, -0.7563249 ],\n",
|
| 373 |
+
" [-1.0459049 , -0.84542644, -1.368898 , ..., -0.22171564,\n",
|
| 374 |
+
" -0.4778826 , -0.82315105],\n",
|
| 375 |
+
" ...,\n",
|
| 376 |
+
" [-2.2599134 , -2.1151235 , -1.6139272 , ..., 0.5245098 ,\n",
|
| 377 |
+
" 0.41313285, 0.37971976],\n",
|
| 378 |
+
" [-2.560631 , -2.204225 , -1.7921304 , ..., 0.10127745,\n",
|
| 379 |
+
" 0.3128936 , 0.37971976],\n",
|
| 380 |
+
" [-2.2933266 , -2.1596742 , -2.0705726 , ..., 0.0233136 ,\n",
|
| 381 |
+
" 0.3017559 , 0.4910967 ]],\n",
|
| 382 |
+
" \n",
|
| 383 |
+
" [[-1.0570426 , -1.3911734 , -1.658478 , ..., 0.769539 ,\n",
|
| 384 |
+
" 0.7806767 , 0.95887977],\n",
|
| 385 |
+
" [-1.6362027 , -1.4134488 , -1.5693765 , ..., 1.0145682 ,\n",
|
| 386 |
+
" 1.1816336 , 1.1370829 ],\n",
|
| 387 |
+
" [-1.914645 , -1.1127311 , -1.224108 , ..., 0.68043745,\n",
|
| 388 |
+
" 0.9922929 , 0.85864055],\n",
|
| 389 |
+
" ...,\n",
|
| 390 |
+
" [-2.1262612 , -2.026022 , -1.7921304 , ..., 0.4910967 ,\n",
|
| 391 |
+
" 0.9031913 , 1.0702567 ],\n",
|
| 392 |
+
" [-2.3712904 , -2.5494936 , -2.304464 , ..., 0.70271283,\n",
|
| 393 |
+
" 0.914329 , 1.0479814 ],\n",
|
| 394 |
+
" [-2.2710512 , -2.3267395 , -2.4715295 , ..., 1.0145682 ,\n",
|
| 395 |
+
" 0.7806767 , 0.7472636 ]],\n",
|
| 396 |
+
" \n",
|
| 397 |
+
" [[-0.94566566, -0.6003972 , -0.85656416, ..., 1.2373221 ,\n",
|
| 398 |
+
" 1.5603153 , 1.5603153 ],\n",
|
| 399 |
+
" [-1.1795572 , -0.38878104, -0.37764335, ..., 1.5603153 ,\n",
|
| 400 |
+
" 1.5603153 , 1.5603153 ],\n",
|
| 401 |
+
" [-1.335485 , -1.3577603 , -0.32195488, ..., 1.5603153 ,\n",
|
| 402 |
+
" 1.5603153 , 1.5603153 ],\n",
|
| 403 |
+
" ...,\n",
|
| 404 |
+
" [-1.224108 , -1.1795572 , -1.4914126 , ..., -0.04351256,\n",
|
| 405 |
+
" -0.5112957 , -0.64494795],\n",
|
| 406 |
+
" [-1.3132095 , -1.8589565 , -1.7587173 , ..., -0.03237487,\n",
|
| 407 |
+
" -0.42219412, -0.1437518 ],\n",
|
| 408 |
+
" [-1.6473403 , -2.4381166 , -2.1262612 , ..., 0.37971976,\n",
|
| 409 |
+
" 0.34630668, 0.46882132]]], dtype=float32),\n",
|
| 410 |
+
" array([[[-0.84542644, -0.8120134 , -0.6894987 , ..., 0.46882132,\n",
|
| 411 |
+
" -0.01009948, -0.39991874],\n",
|
| 412 |
+
" [-0.65608567, -0.8120134 , -0.6338103 , ..., 0.59133595,\n",
|
| 413 |
+
" 0.04558898, -0.5112957 ],\n",
|
| 414 |
+
" [-0.96794105, -0.8342888 , -0.35536796, ..., 0.5801982 ,\n",
|
| 415 |
+
" 0.13469052, -0.37764335],\n",
|
| 416 |
+
" ...,\n",
|
| 417 |
+
" [-0.09920102, -0.11033872, -1.224108 , ..., -0.92339027,\n",
|
| 418 |
+
" -1.2686588 , -0.85656416],\n",
|
| 419 |
+
" [ 0.0121759 , -0.01009948, -1.0236295 , ..., -1.2129703 ,\n",
|
| 420 |
+
" -1.4023111 , -0.7340495 ],\n",
|
| 421 |
+
" [ 0.23492976, -0.07692564, -0.70063645, ..., -0.70063645,\n",
|
| 422 |
+
" -0.934528 , -1.0124918 ]],\n",
|
| 423 |
+
" \n",
|
| 424 |
+
" [[-1.6139272 , -0.97907877, -0.7340495 , ..., -0.8342888 ,\n",
|
| 425 |
+
" -0.6894987 , -1.0347673 ],\n",
|
| 426 |
+
" [-1.0459049 , -0.5892595 , -0.789738 , ..., -1.257521 ,\n",
|
| 427 |
+
" -1.0124918 , -1.4357241 ],\n",
|
| 428 |
+
" [-0.42219412, -0.5892595 , -1.4134488 , ..., -1.5025504 ,\n",
|
| 429 |
+
" -1.2463834 , -1.4802749 ],\n",
|
| 430 |
+
" ...,\n",
|
| 431 |
+
" [ 0.10127745, 0.40199515, 0.13469052, ..., 0.25720516,\n",
|
| 432 |
+
" 0.55792284, 0.12355283],\n",
|
| 433 |
+
" [-0.2885418 , -0.2551287 , 0.41313285, ..., 0.46882132,\n",
|
| 434 |
+
" -0.33309257, -1.2797965 ],\n",
|
| 435 |
+
" [-0.2996795 , -0.64494795, 0.9366044 , ..., 0.56906056,\n",
|
| 436 |
+
" -0.41105643, -1.3466226 ]],\n",
|
| 437 |
+
" \n",
|
| 438 |
+
" [[-1.1238688 , -1.4134488 , -1.7364419 , ..., -1.1350064 ,\n",
|
| 439 |
+
" -1.6918911 , -2.3156018 ],\n",
|
| 440 |
+
" [-1.4914126 , -1.3020718 , -0.99021643, ..., -1.658478 ,\n",
|
| 441 |
+
" -1.7698549 , -1.8812319 ],\n",
|
| 442 |
+
" [-1.5582387 , -0.85656416, -0.43333182, ..., -1.6027895 ,\n",
|
| 443 |
+
" -1.914645 , -1.7698549 ],\n",
|
| 444 |
+
" ...,\n",
|
| 445 |
+
" [ 0.6358867 , 0.9031913 , 1.0702567 , ..., 0.11241514,\n",
|
| 446 |
+
" 0.07900206, 0.34630668],\n",
|
| 447 |
+
" [ 0.70271283, 1.1593583 , 0.9254667 , ..., 0.335169 ,\n",
|
| 448 |
+
" 0.41313285, 0.23492976],\n",
|
| 449 |
+
" [ 0.35744438, 1.1482205 , 0.8697783 , ..., -0.2662664 ,\n",
|
| 450 |
+
" 0.56906056, 0.624749 ]]], dtype=float32)]"
|
| 451 |
+
]
|
| 452 |
+
},
|
| 453 |
+
"execution_count": 8,
|
| 454 |
+
"metadata": {},
|
| 455 |
+
"output_type": "execute_result"
|
| 456 |
+
}
|
| 457 |
+
],
|
| 458 |
+
"source": [
|
| 459 |
+
"patches_top_5"
|
| 460 |
+
]
|
| 461 |
+
},
|
| 462 |
+
{
|
| 463 |
+
"cell_type": "code",
|
| 464 |
+
"execution_count": 11,
|
| 465 |
+
"id": "4edb20e7",
|
| 466 |
+
"metadata": {},
|
| 467 |
+
"outputs": [],
|
| 468 |
+
"source": [
|
| 469 |
+
"import argparse\n",
|
| 470 |
+
"import os\n",
|
| 471 |
+
"import shutil\n",
|
| 472 |
+
"import time\n",
|
| 473 |
+
"import yaml\n",
|
| 474 |
+
"import sys\n",
|
| 475 |
+
"import gdown\n",
|
| 476 |
+
"import numpy as np\n",
|
| 477 |
+
"import torch\n",
|
| 478 |
+
"import torch.distributed as dist\n",
|
| 479 |
+
"import torch.multiprocessing as mp\n",
|
| 480 |
+
"import torch.nn as nn\n",
|
| 481 |
+
"import torch.nn.functional as F\n",
|
| 482 |
+
"from monai.config import KeysCollection\n",
|
| 483 |
+
"from monai.metrics import Cumulative, CumulativeAverage\n",
|
| 484 |
+
"from monai.networks.nets import milmodel, resnet, MILModel\n",
|
| 485 |
+
"from monai.transforms import (\n",
|
| 486 |
+
" Compose,\n",
|
| 487 |
+
" GridPatchd,\n",
|
| 488 |
+
" LoadImaged,\n",
|
| 489 |
+
" MapTransform,\n",
|
| 490 |
+
" RandFlipd,\n",
|
| 491 |
+
" RandGridPatchd,\n",
|
| 492 |
+
" RandRotate90d,\n",
|
| 493 |
+
" ScaleIntensityRanged,\n",
|
| 494 |
+
" SplitDimd,\n",
|
| 495 |
+
" ToTensord,\n",
|
| 496 |
+
" ConcatItemsd, \n",
|
| 497 |
+
" SelectItemsd,\n",
|
| 498 |
+
" EnsureChannelFirstd,\n",
|
| 499 |
+
" RepeatChanneld,\n",
|
| 500 |
+
" DeleteItemsd,\n",
|
| 501 |
+
" EnsureTyped,\n",
|
| 502 |
+
" ClipIntensityPercentilesd,\n",
|
| 503 |
+
" MaskIntensityd,\n",
|
| 504 |
+
" HistogramNormalized,\n",
|
| 505 |
+
" RandBiasFieldd,\n",
|
| 506 |
+
" RandCropByPosNegLabeld,\n",
|
| 507 |
+
" NormalizeIntensityd,\n",
|
| 508 |
+
" SqueezeDimd,\n",
|
| 509 |
+
" CropForegroundd,\n",
|
| 510 |
+
" ScaleIntensityd,\n",
|
| 511 |
+
" SpatialPadd,\n",
|
| 512 |
+
" CenterSpatialCropd,\n",
|
| 513 |
+
" ScaleIntensityd,\n",
|
| 514 |
+
" Transposed,\n",
|
| 515 |
+
" RandWeightedCropd,\n",
|
| 516 |
+
")\n",
|
| 517 |
+
"from sklearn.metrics import cohen_kappa_score\n",
|
| 518 |
+
"from torch.cuda.amp import GradScaler, autocast\n",
|
| 519 |
+
"from torch.utils.data.dataloader import default_collate\n",
|
| 520 |
+
"from torchvision.models.resnet import ResNet50_Weights\n",
|
| 521 |
+
"from src.data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd\n",
|
| 522 |
+
"from torch.utils.data.distributed import DistributedSampler\n",
|
| 523 |
+
"from torch.utils.tensorboard import SummaryWriter\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"import matplotlib.pyplot as plt\n",
|
| 526 |
+
"\n",
|
| 527 |
+
"import wandb\n",
|
| 528 |
+
"import math\n",
|
| 529 |
+
"from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset\n",
|
| 530 |
+
"\n",
|
| 531 |
+
"from src.model.MIL import MILModel_3D\n",
|
| 532 |
+
"from src.model.csPCa_model import csPCa_Model\n",
|
| 533 |
+
"\n",
|
| 534 |
+
"import logging\n",
|
| 535 |
+
"from pathlib import Path"
|
| 536 |
+
]
|
| 537 |
+
},
|
| 538 |
+
{
|
| 539 |
+
"cell_type": "code",
|
| 540 |
+
"execution_count": 13,
|
| 541 |
+
"id": "e42cc132",
|
| 542 |
+
"metadata": {},
|
| 543 |
+
"outputs": [],
|
| 544 |
+
"source": [
|
| 545 |
+
"transform_image = Compose(\n",
|
| 546 |
+
" [\n",
|
| 547 |
+
" LoadImaged(keys=[\"image\", \"mask\"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),\n",
|
| 548 |
+
" ClipMaskIntensityPercentilesd(keys=[\"image\"], lower=0, upper=99.5, mask_key=\"mask\"),\n",
|
| 549 |
+
" NormalizeIntensity_customd(keys=[\"image\"], mask_key=\"mask\", channel_wise=True),\n",
|
| 550 |
+
" EnsureTyped(keys=[\"label\"], dtype=torch.float32),\n",
|
| 551 |
+
" ToTensord(keys=[\"image\", \"label\"]),\n",
|
| 552 |
+
" ]\n",
|
| 553 |
+
")\n",
|
| 554 |
+
"dataset_image = Dataset(data=data_list, transform=transform_image)\n"
|
| 555 |
+
]
|
| 556 |
+
},
|
| 557 |
+
{
|
| 558 |
+
"cell_type": "code",
|
| 559 |
+
"execution_count": 19,
|
| 560 |
+
"id": "bcdddd9e",
|
| 561 |
+
"metadata": {},
|
| 562 |
+
"outputs": [
|
| 563 |
+
{
|
| 564 |
+
"data": {
|
| 565 |
+
"text/plain": [
|
| 566 |
+
"(270, 270, 28)"
|
| 567 |
+
]
|
| 568 |
+
},
|
| 569 |
+
"execution_count": 19,
|
| 570 |
+
"metadata": {},
|
| 571 |
+
"output_type": "execute_result"
|
| 572 |
+
}
|
| 573 |
+
],
|
| 574 |
+
"source": [
|
| 575 |
+
"dataset_image[0]['image'][0].numpy().shape"
|
| 576 |
]
|
| 577 |
},
|
| 578 |
{
|
| 579 |
"cell_type": "code",
|
| 580 |
"execution_count": null,
|
| 581 |
+
"id": "56072a2b",
|
| 582 |
"metadata": {},
|
| 583 |
"outputs": [],
|
| 584 |
"source": []
|