Anirudh Balaraman commited on
Commit
8c2b158
·
1 Parent(s): a4ef78c

update json dump

Browse files
Files changed (3) hide show
  1. run_inference.py +13 -9
  2. src/utils.py +2 -2
  3. temp.ipynb +27 -1
run_inference.py CHANGED
@@ -148,7 +148,7 @@ if __name__ == "__main__":
148
  pirads_model.eval()
149
  cspca_risk_list = []
150
  cspca_model.eval()
151
- top5_patches = []
152
  with torch.no_grad():
153
  for idx, batch_data in enumerate(loader):
154
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
@@ -176,18 +176,22 @@ if __name__ == "__main__":
176
  for i in range(5):
177
  patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()
178
  patches_top_5.append(patch_temp)
 
 
 
 
179
 
180
- parent_image = get_parent_image(args)
181
-
182
- coords = get_patch_coordinate(patches_top_5, parent_image, args)
183
 
184
  for i,j in enumerate(files):
185
  logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
186
 
187
- output_dict = {
188
- 'Predicted PIRAD Score': pirads_list[i] + 2.0,
189
- 'csPCa risk': cspca_risk_list[i],
190
- 'Top left coordinate of top 5 patches(x,y,z)': coords,
191
- }
192
  with open(os.path.join(args.output_dir, "results.json"), 'w') as f:
193
  json.dump(output_dict, f, indent=4)
 
148
  pirads_model.eval()
149
  cspca_risk_list = []
150
  cspca_model.eval()
151
+ patches_top_5_list = []
152
  with torch.no_grad():
153
  for idx, batch_data in enumerate(loader):
154
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
 
176
  for i in range(5):
177
  patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()
178
  patches_top_5.append(patch_temp)
179
+ patches_top_5_list.append(patches_top_5)
180
+ coords_list = []
181
+ for i in args.data_list:
182
+ parent_image = get_parent_image([i], args)
183
 
184
+ coords = get_patch_coordinate(patches_top_5, parent_image, args)
185
+ coords_list.append(coords)
186
+ output_dict = {}
187
 
188
  for i,j in enumerate(files):
189
  logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
190
 
191
+ output_dict[j] = {
192
+ 'Predicted PIRAD Score': pirads_list[i] + 2.0,
193
+ 'csPCa risk': cspca_risk_list[i],
194
+ 'Top left coordinate of top 5 patches(x,y,z)': coords_list[i],
195
+ }
196
  with open(os.path.join(args.output_dir, "results.json"), 'w') as f:
197
  json.dump(output_dict, f, indent=4)
src/utils.py CHANGED
@@ -162,7 +162,7 @@ def get_patch_coordinate(patches_top_5, parent_image, args):
162
  return coords
163
 
164
 
165
- def get_parent_image(args):
166
  transform_image = Compose(
167
  [
168
  LoadImaged(keys=["image", "mask"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
@@ -172,7 +172,7 @@ def get_parent_image(args):
172
  ToTensord(keys=["image", "label"]),
173
  ]
174
  )
175
- dataset_image = Dataset(data=args.data_list, transform=transform_image)
176
  return dataset_image[0]['image'][0].numpy()
177
 
178
  '''
 
162
  return coords
163
 
164
 
165
+ def get_parent_image(temp_data_list, args):
166
  transform_image = Compose(
167
  [
168
  LoadImaged(keys=["image", "mask"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
 
172
  ToTensord(keys=["image", "label"]),
173
  ]
174
  )
175
+ dataset_image = Dataset(data=temp_data_list, transform=transform_image)
176
  return dataset_image[0]['image'][0].numpy()
177
 
178
  '''
temp.ipynb CHANGED
@@ -577,9 +577,35 @@
577
  },
578
  {
579
  "cell_type": "code",
580
- "execution_count": null,
581
  "id": "56072a2b",
582
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
  "outputs": [],
584
  "source": []
585
  }
 
577
  },
578
  {
579
  "cell_type": "code",
580
+ "execution_count": 20,
581
  "id": "56072a2b",
582
  "metadata": {},
583
+ "outputs": [
584
+ {
585
+ "data": {
586
+ "text/plain": [
587
+ "[{'image': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/t2_histmatched/1009449_11049598.nrrd',\n",
588
+ " 'dwi': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/DWI_histmatched/1009449_11049598.nrrd',\n",
589
+ " 'adc': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/ADC_histmatched/1009449_11049598.nrrd',\n",
590
+ " 'heatmap': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/heatmaps/1009449_11049598.nrrd',\n",
591
+ " 'mask': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/prostate_mask/1009449_11049598.nrrd',\n",
592
+ " 'label': 0}]"
593
+ ]
594
+ },
595
+ "execution_count": 20,
596
+ "metadata": {},
597
+ "output_type": "execute_result"
598
+ }
599
+ ],
600
+ "source": [
601
+ "data_list"
602
+ ]
603
+ },
604
+ {
605
+ "cell_type": "code",
606
+ "execution_count": null,
607
+ "id": "db1163d2",
608
+ "metadata": {},
609
  "outputs": [],
610
  "source": []
611
  }