FelixzeroSun commited on
Commit
ef42723
·
verified ·
1 Parent(s): 46a770f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -44
app.py CHANGED
@@ -45,9 +45,6 @@ from process import SynthradAlgorithm2
45
 
46
  from process_1 import SynthradAlgorithm1
47
 
48
- # =========================
49
- # Streamlit UI
50
- # =========================
51
  st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
52
  st.title("SynthRad — MRI/CBCT + Mask → synthetic CT")
53
  st.image("./workflow.png",width=800)
@@ -94,32 +91,26 @@ src = st.radio("Source", ["Sample", "Upload"], index=0, horizontal=True)
94
  def build_sample_map(task_name: str):
95
  repo_dir = REPO_DIRS[task_name]
96
  if task_name == "Task 1 (MR → CT)":
97
- vol_key = "mri"
98
- vol_fname = "mr.mha"
99
- mask_fname = "mask1.mha"
100
  else:
101
- vol_key = "cbct"
102
- vol_fname = "cbct.mha"
103
- mask_fname = "mask2.mha"
 
 
 
 
 
 
104
  sample_map = {
105
- "Abdomen (sample)": {
106
- "region": "Abdomen",
107
- "vol": os.path.join(repo_dir, "Abdomen", vol_fname),
108
- "mask": os.path.join(repo_dir, "Abdomen", mask_fname),
109
- },
110
- "Head and Neck (sample)": {
111
- "region": "Head and Neck",
112
- "vol": os.path.join(repo_dir, "Head and Neck", vol_fname),
113
- "mask": os.path.join(repo_dir, "Head and Neck", mask_fname),
114
- },
115
- "Thorax (sample)": {
116
- "region": "Thorax",
117
- "vol": os.path.join(repo_dir, "Thorax", vol_fname),
118
- "mask": os.path.join(repo_dir, "Thorax", mask_fname),
119
- },
120
  }
121
  return sample_map
122
 
 
123
  SAMPLE_MAP = build_sample_map(task)
124
 
125
 
@@ -214,40 +205,116 @@ if run_btn:
214
  if st.session_state.vol_np is None:
215
  st.info("Select Upload or Sample, then click Run")
216
  else:
217
-
218
  out_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")
219
- vol = sitk.GetArrayFromImage(out_lps).astype(np.float32)
220
- D, H, W = vol.shape
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  col_d1, col_d2, col_d3 = st.columns(3)
223
 
224
-
225
  with col_d3:
226
- _download_sitk_image(
227
- st.session_state.synth_ct,
228
- file_name="synth_ct.nii.gz",
229
- label="Download synthetic CT",
230
- )
231
-
232
 
233
  with col_d1:
234
  if st.session_state.input_vol is not None:
235
  in_name = "input_mr.nii.gz" if task == "Task 1 (MR → CT)" else "input_cbct.nii.gz"
236
  in_label = "Download input MRI" if task == "Task 1 (MR → CT)" else "Download input CBCT"
237
- _download_sitk_image(
238
- st.session_state.input_vol,
239
- file_name=in_name,
240
- label=in_label,
241
- )
242
  else:
243
  st.button("Download input", disabled=True)
244
 
245
  with col_d2:
246
  if st.session_state.input_mask is not None:
247
- _download_sitk_image(
248
- st.session_state.input_mask,
249
- file_name="input_mask.nii.gz",
250
- label="Download input Mask",
251
- )
252
  else:
253
  st.button("Download input Mask", disabled=True)
 
45
 
46
  from process_1 import SynthradAlgorithm1
47
 
 
 
 
48
  st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
49
  st.title("SynthRad — MRI/CBCT + Mask → synthetic CT")
50
  st.image("./workflow.png",width=800)
 
91
  def build_sample_map(task_name: str):
92
  repo_dir = REPO_DIRS[task_name]
93
  if task_name == "Task 1 (MR → CT)":
94
+ vol_fname = "mr.mha"
95
+ mask_fname = "mask1.mha"
 
96
  else:
97
+ vol_fname = "cbct.mha"
98
+ mask_fname = "mask2.mha"
99
+
100
+ def pack(region_dir):
101
+ vol_path = os.path.join(repo_dir, region_dir, vol_fname)
102
+ mask_path = os.path.join(repo_dir, region_dir, mask_fname)
103
+ gt_path = os.path.join(repo_dir, region_dir, "ct.mha") # 约定:GT=ct.mha
104
+ return {"vol": vol_path, "mask": mask_path, "gt": gt_path}
105
+
106
  sample_map = {
107
+ "Abdomen (sample)": {"region": "Abdomen", **pack("Abdomen")},
108
+ "Head and Neck (sample)": {"region": "Head and Neck", **pack("Head and Neck")},
109
+ "Thorax (sample)": {"region": "Thorax", **pack("Thorax")},
 
 
 
 
 
 
 
 
 
 
 
 
110
  }
111
  return sample_map
112
 
113
+
114
  SAMPLE_MAP = build_sample_map(task)
115
 
116
 
 
205
  if st.session_state.vol_np is None:
206
  st.info("Select Upload or Sample, then click Run")
207
  else:
208
+ in_lps = sitk.DICOMOrient(st.session_state.input_vol, "LPS")
209
  out_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")
 
 
210
 
211
+ res = sitk.ResampleImageFilter()
212
+ res.SetReferenceImage(in_lps)
213
+ res.SetInterpolator(sitk.sitkLinear)
214
+ res.SetOutputPixelType(out_lps.GetPixelID())
215
+ out_on_input = res.Execute(out_lps)
216
+
217
+ gt_on_input = None
218
+ if src == "Sample":
219
+ gt_path = SAMPLE_MAP[sample_key].get("gt", None)
220
+ if gt_path and os.path.exists(gt_path):
221
+ gt_img = sitk.DICOMOrient(sitk.ReadImage(gt_path), "LPS")
222
+ res.SetReferenceImage(in_lps)
223
+ res.SetInterpolator(sitk.sitkLinear)
224
+ res.SetOutputPixelType(gt_img.GetPixelID())
225
+ gt_on_input = res.Execute(gt_img)
226
+
227
+ # numpy
228
+ in_vol = sitk.GetArrayFromImage(in_lps).astype(np.float32)
229
+ syn_vol = sitk.GetArrayFromImage(out_on_input).astype(np.float32)
230
+ gt_vol = sitk.GetArrayFromImage(gt_on_input).astype(np.float32) if gt_on_input is not None else None
231
+
232
+ st.subheader("Input vs Synthetic CT Viewer (Axial only)")
233
+ n_slices = in_vol.shape[0]
234
+ idx = st.slider("Slice index (Axial/Z)", 0, n_slices - 1, n_slices // 2)
235
+
236
+ def get_axial(arr, k):
237
+ return arr[k, :, :]
238
+
239
+ sl_in = get_axial(in_vol, idx)
240
+ sl_syn = get_axial(syn_vol, idx)
241
+ img_in = _norm2u8(sl_in)
242
+ img_syn = _norm2u8(sl_syn)
243
+ img_gt = _norm2u8(get_axial(gt_vol, idx)) if gt_vol is not None else None
244
+
245
+ overlay_mask = st.checkbox("Overlay mask (red)")
246
+ alpha = st.slider("Mask opacity", 0.0, 1.0, 0.35, 0.05, disabled=not overlay_mask)
247
+ mask_slice = None
248
+ if overlay_mask and st.session_state.input_mask is not None:
249
+ mask_lps = sitk.DICOMOrient(st.session_state.input_mask, "LPS")
250
+ res_nn = sitk.ResampleImageFilter()
251
+ res_nn.SetReferenceImage(in_lps)
252
+ res_nn.SetInterpolator(sitk.sitkNearestNeighbor)
253
+ mask_on_input = res_nn.Execute(mask_lps)
254
+ mask_np = sitk.GetArrayFromImage(mask_on_input)
255
+ mask_slice = get_axial(mask_np, min(idx, mask_np.shape[0]-1))
256
+ mask_plot = np.where(mask_slice > 0, 1.0, np.nan)
257
+ else:
258
+ mask_plot = None
259
+
260
+ import plotly.graph_objects as go
261
+ from plotly.subplots import make_subplots
262
+ sx, sy, _ = in_lps.GetSpacing()
263
+ xs = np.arange(img_in.shape[1]) * sx
264
+ ys = np.arange(img_in.shape[0]) * sy
265
+
266
+ cols = 3 if (src == "Sample" and img_gt is not None) else 2
267
+ titles = ["Input (MRI/CBCT)", "Synthetic CT"] + (["Ground-Truth CT"] if cols == 3 else [])
268
+ fig = make_subplots(rows=1, cols=cols, subplot_titles=tuple(titles))
269
+
270
+ fig.add_trace(go.Heatmap(z=img_in, x=xs, y=ys, colorscale="gray",
271
+ zmin=0, zmax=255, showscale=False, hoverinfo="skip"), row=1, col=1)
272
+ # synCT
273
+ fig.add_trace(go.Heatmap(z=img_syn, x=xs, y=ys, colorscale="gray",
274
+ zmin=0, zmax=255, showscale=False, hoverinfo="skip"), row=1, col=2)
275
+ # GT
276
+ if cols == 3:
277
+ fig.add_trace(go.Heatmap(z=img_gt, x=xs, y=ys, colorscale="gray",
278
+ zmin=0, zmax=255, showscale=False, hoverinfo="skip"), row=1, col=3)
279
+
280
+ # mask overlay
281
+ if mask_plot is not None:
282
+ red_scale = [[0.0, "rgba(255,0,0,1.0)"], [1.0, "rgba(255,0,0,1.0)"]]
283
+ for c in range(1, cols+1):
284
+ fig.add_trace(go.Heatmap(z=mask_plot, x=xs, y=ys, colorscale=red_scale,
285
+ showscale=False, opacity=alpha, hoverinfo="skip"), row=1, col=c)
286
+
287
+ for c in range(1, cols+1):
288
+ fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=c)
289
+ fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=c)
290
+
291
+ fig.update_layout(height=600, margin=dict(l=10, r=10, t=40, b=10))
292
+ st.plotly_chart(fig, use_container_width=True)
293
+
294
+ # Caption
295
+ if cols == 3:
296
+ st.caption(f"Axial (Z) slice {idx+1}/{n_slices} — All aligned to input geometry; GT only for samples.")
297
+ else:
298
+ st.caption(f"Axial (Z) slice {idx+1}/{n_slices} — Aligned to input geometry.")
299
  col_d1, col_d2, col_d3 = st.columns(3)
300
 
 
301
  with col_d3:
302
+ _download_sitk_image(st.session_state.synth_ct,
303
+ file_name="synth_ct.nii.gz",
304
+ label="Download synthetic CT")
 
 
 
305
 
306
  with col_d1:
307
  if st.session_state.input_vol is not None:
308
  in_name = "input_mr.nii.gz" if task == "Task 1 (MR → CT)" else "input_cbct.nii.gz"
309
  in_label = "Download input MRI" if task == "Task 1 (MR → CT)" else "Download input CBCT"
310
+ _download_sitk_image(st.session_state.input_vol, file_name=in_name, label=in_label)
 
 
 
 
311
  else:
312
  st.button("Download input", disabled=True)
313
 
314
  with col_d2:
315
  if st.session_state.input_mask is not None:
316
+ _download_sitk_image(st.session_state.input_mask,
317
+ file_name="input_mask.nii.gz",
318
+ label="Download input Mask")
 
 
319
  else:
320
  st.button("Download input Mask", disabled=True)