Spaces:
Sleeping
Sleeping
Anirudh Balaraman commited on
Commit ·
8c2b158
1
Parent(s): a4ef78c
update json dump
Browse files- run_inference.py +13 -9
- src/utils.py +2 -2
- temp.ipynb +27 -1
run_inference.py
CHANGED
|
@@ -148,7 +148,7 @@ if __name__ == "__main__":
|
|
| 148 |
pirads_model.eval()
|
| 149 |
cspca_risk_list = []
|
| 150 |
cspca_model.eval()
|
| 151 |
-
|
| 152 |
with torch.no_grad():
|
| 153 |
for idx, batch_data in enumerate(loader):
|
| 154 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
|
@@ -176,18 +176,22 @@ if __name__ == "__main__":
|
|
| 176 |
for i in range(5):
|
| 177 |
patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()
|
| 178 |
patches_top_5.append(patch_temp)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
|
| 184 |
for i,j in enumerate(files):
|
| 185 |
logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
with open(os.path.join(args.output_dir, "results.json"), 'w') as f:
|
| 193 |
json.dump(output_dict, f, indent=4)
|
|
|
|
| 148 |
pirads_model.eval()
|
| 149 |
cspca_risk_list = []
|
| 150 |
cspca_model.eval()
|
| 151 |
+
patches_top_5_list = []
|
| 152 |
with torch.no_grad():
|
| 153 |
for idx, batch_data in enumerate(loader):
|
| 154 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
|
|
|
| 176 |
for i in range(5):
|
| 177 |
patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()
|
| 178 |
patches_top_5.append(patch_temp)
|
| 179 |
+
patches_top_5_list.append(patches_top_5)
|
| 180 |
+
coords_list = []
|
| 181 |
+
for i in args.data_list:
|
| 182 |
+
parent_image = get_parent_image([i], args)
|
| 183 |
|
| 184 |
+
coords = get_patch_coordinate(patches_top_5, parent_image, args)
|
| 185 |
+
coords_list.append(coords)
|
| 186 |
+
output_dict = {}
|
| 187 |
|
| 188 |
for i,j in enumerate(files):
|
| 189 |
logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
|
| 190 |
|
| 191 |
+
output_dict[j] = {
|
| 192 |
+
'Predicted PIRAD Score': pirads_list[i] + 2.0,
|
| 193 |
+
'csPCa risk': cspca_risk_list[i],
|
| 194 |
+
'Top left coordinate of top 5 patches(x,y,z)': coords_list[i],
|
| 195 |
+
}
|
| 196 |
with open(os.path.join(args.output_dir, "results.json"), 'w') as f:
|
| 197 |
json.dump(output_dict, f, indent=4)
|
src/utils.py
CHANGED
|
@@ -162,7 +162,7 @@ def get_patch_coordinate(patches_top_5, parent_image, args):
|
|
| 162 |
return coords
|
| 163 |
|
| 164 |
|
| 165 |
-
def get_parent_image(args):
|
| 166 |
transform_image = Compose(
|
| 167 |
[
|
| 168 |
LoadImaged(keys=["image", "mask"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
|
|
@@ -172,7 +172,7 @@ def get_parent_image(args):
|
|
| 172 |
ToTensord(keys=["image", "label"]),
|
| 173 |
]
|
| 174 |
)
|
| 175 |
-
dataset_image = Dataset(data=
|
| 176 |
return dataset_image[0]['image'][0].numpy()
|
| 177 |
|
| 178 |
'''
|
|
|
|
| 162 |
return coords
|
| 163 |
|
| 164 |
|
| 165 |
+
def get_parent_image(temp_data_list, args):
|
| 166 |
transform_image = Compose(
|
| 167 |
[
|
| 168 |
LoadImaged(keys=["image", "mask"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
|
|
|
|
| 172 |
ToTensord(keys=["image", "label"]),
|
| 173 |
]
|
| 174 |
)
|
| 175 |
+
dataset_image = Dataset(data=temp_data_list, transform=transform_image)
|
| 176 |
return dataset_image[0]['image'][0].numpy()
|
| 177 |
|
| 178 |
'''
|
temp.ipynb
CHANGED
|
@@ -577,9 +577,35 @@
|
|
| 577 |
},
|
| 578 |
{
|
| 579 |
"cell_type": "code",
|
| 580 |
-
"execution_count":
|
| 581 |
"id": "56072a2b",
|
| 582 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
"outputs": [],
|
| 584 |
"source": []
|
| 585 |
}
|
|
|
|
| 577 |
},
|
| 578 |
{
|
| 579 |
"cell_type": "code",
|
| 580 |
+
"execution_count": 20,
|
| 581 |
"id": "56072a2b",
|
| 582 |
"metadata": {},
|
| 583 |
+
"outputs": [
|
| 584 |
+
{
|
| 585 |
+
"data": {
|
| 586 |
+
"text/plain": [
|
| 587 |
+
"[{'image': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/t2_histmatched/1009449_11049598.nrrd',\n",
|
| 588 |
+
" 'dwi': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/DWI_histmatched/1009449_11049598.nrrd',\n",
|
| 589 |
+
" 'adc': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/ADC_histmatched/1009449_11049598.nrrd',\n",
|
| 590 |
+
" 'heatmap': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/heatmaps/1009449_11049598.nrrd',\n",
|
| 591 |
+
" 'mask': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/prostate_mask/1009449_11049598.nrrd',\n",
|
| 592 |
+
" 'label': 0}]"
|
| 593 |
+
]
|
| 594 |
+
},
|
| 595 |
+
"execution_count": 20,
|
| 596 |
+
"metadata": {},
|
| 597 |
+
"output_type": "execute_result"
|
| 598 |
+
}
|
| 599 |
+
],
|
| 600 |
+
"source": [
|
| 601 |
+
"data_list"
|
| 602 |
+
]
|
| 603 |
+
},
|
| 604 |
+
{
|
| 605 |
+
"cell_type": "code",
|
| 606 |
+
"execution_count": null,
|
| 607 |
+
"id": "db1163d2",
|
| 608 |
+
"metadata": {},
|
| 609 |
"outputs": [],
|
| 610 |
"source": []
|
| 611 |
}
|