Anirudh Balaraman commited on
Commit
37d4614
·
1 Parent(s): 30f1102
Files changed (5) hide show
  1. app.py +150 -159
  2. pyproject.toml +1 -0
  3. run_inference.py +0 -2
  4. src/utils.py +19 -26
  5. visualisation.ipynb +13 -19
app.py CHANGED
@@ -1,15 +1,17 @@
1
- import streamlit as st
2
- import subprocess
3
  import os
4
  import shutil
5
- from huggingface_hub import hf_hub_download
6
- import nrrd
 
7
  import matplotlib.pyplot as plt
 
8
  import numpy as np
9
- import matplotlib.patches as patches
10
- import json
11
  import plotly.graph_objects as go
12
- import base64
 
 
13
 
14
  def render_clickable_image(image_path, link_url, width=100):
15
  """
@@ -18,7 +20,7 @@ def render_clickable_image(image_path, link_url, width=100):
18
  # 1. Read the image file and encode it to base64
19
  with open(image_path, "rb") as f:
20
  data = base64.b64encode(f.read()).decode("utf-8")
21
-
22
  # 2. Create the HTML string
23
  # target="_blank" opens the link in a new tab
24
  html_code = f"""
@@ -26,16 +28,13 @@ def render_clickable_image(image_path, link_url, width=100):
26
  <img src="data:image/png;base64,{data}" width="{width}" style="border-radius: 5px;">
27
  </a>
28
  """
29
-
30
  # 3. Render it
31
  st.markdown(html_code, unsafe_allow_html=True)
32
 
33
 
34
  st.set_page_config(
35
- page_title="Prostate Scoring",
36
- page_icon="🩺",
37
- layout="wide",
38
- initial_sidebar_state="expanded"
39
  )
40
 
41
 
@@ -45,10 +44,11 @@ def load_nrrd(file_path):
45
  data, header = nrrd.read(file_path)
46
  return data, header
47
 
 
48
  def display_slicer(scan_paths, mask_path=None, bboxes=None, title="Scan Viewer", key_suffix=""):
49
  """
50
  Displays slicer with Multi-Background Support, Mask Overlay, and Bounding Box Multiselect.
51
-
52
  Args:
53
  scan_paths: Dict of {Label: FilePath}. Example: {"T2W": "path/to/t2.nrrd", "ADC": "..."}
54
  """
@@ -58,24 +58,23 @@ def display_slicer(scan_paths, mask_path=None, bboxes=None, title="Scan Viewer",
58
  # --- CONTROLS SECTION (Right Column) ---
59
  with c_controls:
60
  st.write(f"**{title} Controls**")
61
-
62
  # A. Background Selection
63
  # We assume the first key in the dict is the default
64
  available_scans = list(scan_paths.keys())
65
- selected_scan_name = st.radio("Background Image", available_scans, index=0, key=f"bg_{key_suffix}")
 
 
66
  current_file_path = scan_paths[selected_scan_name]
67
 
68
  # B. Lesion Selection (Multiselect)
69
  box_labels = []
70
  selected_labels = []
71
  if bboxes:
72
- box_labels = [f"Lesion {i+1}" for i in range(len(bboxes))]
73
- st.write("---") # Divider
74
  selected_labels = st.multiselect(
75
- "Select Lesions",
76
- options=box_labels,
77
- default=box_labels,
78
- key=f"multi_{key_suffix}"
79
  )
80
 
81
  # C. Toggles
@@ -92,7 +91,7 @@ def display_slicer(scan_paths, mask_path=None, bboxes=None, title="Scan Viewer",
92
 
93
  # Load the selected background image
94
  data, _ = load_nrrd(current_file_path)
95
-
96
  if len(data.shape) != 3:
97
  st.warning("Data is not 3D.")
98
  return
@@ -109,17 +108,19 @@ def display_slicer(scan_paths, mask_path=None, bboxes=None, title="Scan Viewer",
109
  start_slice = int(b[2] + (b[5] // 2))
110
  start_slice = max(0, min(start_slice, total_slices - 1))
111
 
112
- slice_idx = st.slider("Select Slice (Z-Axis)", 0, total_slices - 1, start_slice, key=f"sl_{key_suffix}")
 
 
113
 
114
  # E. Plotting
115
  img_slice = data[:, :, slice_idx]
116
-
117
  # Normalize Image (0-1)
118
  img_slice = img_slice.astype(float)
119
 
120
  fig, ax = plt.subplots(figsize=(5, 5))
121
  ax.imshow(img_slice, cmap="gray", origin="upper")
122
-
123
  # 1. Overlay Mask
124
  if show_mask:
125
  # Load mask on the fly (or cache it if slow)
@@ -133,72 +134,78 @@ def display_slicer(scan_paths, mask_path=None, bboxes=None, title="Scan Viewer",
133
  else:
134
  # Fallback warning if mask dims don't match selected background
135
  # (Common if ADC resolution != T2 resolution)
136
- ax.text(5, 5, "Mask shape mismatch", color='red', fontsize=8)
137
-
138
 
139
  # 2. Overlay Bounding Boxes
140
  if bboxes:
141
  for i, box in enumerate(bboxes):
142
- label = f"Lesion {i+1}"
143
  if label not in selected_labels:
144
- continue
145
 
146
  bx, by, bz, bw, bh, bd = box
147
 
148
  # Visibility check
149
  if bz <= slice_idx < (bz + bd):
150
  rect = patches.Rectangle(
151
- (bx, by), bw, bh,
152
- linewidth=2, edgecolor='yellow', facecolor='none'
153
  )
154
  ax.add_patch(rect)
155
- ax.text(bx, by-5, f"L{i+1}", color='yellow', fontsize=9, fontweight='bold')
156
 
157
  ax.axis("off")
158
  st.pyplot(fig, use_container_width=False)
159
 
 
160
  @st.cache_resource
161
  def download_all_models():
162
  # 1. Ensure the 'models' directory exists
163
- models_dir = os.path.join(os.getcwd(), 'models')
164
- os.makedirs(models_dir, exist_ok=True)
165
 
166
  for filename in FILENAMES:
167
  try:
168
  # 2. Download from Hugging Face (to cache)
169
  cached_path = hf_hub_download(repo_id=REPO_ID, filename=filename)
170
-
171
  # 3. Define where we want it to live locally
172
  destination_path = os.path.join(models_dir, filename)
173
-
174
  # 4. Copy only if it's not already there
175
  if not os.path.exists(destination_path):
176
  shutil.copy(cached_path, destination_path)
177
-
178
  except Exception as e:
179
  st.error(f"Failed to download {filename}: {e}")
180
  st.stop()
181
 
 
182
  with st.container():
183
  col1, col2, col3, col4 = st.columns(4)
184
 
185
  with col1:
186
- render_clickable_image("deployment_images/logo1.png", "https://www.comfort-ai.eu/", width=220)
 
 
187
  with col2:
188
  render_clickable_image("deployment_images/logo2.png", "https://www.charite.de/", width=220)
189
  with col3:
190
  render_clickable_image("deployment_images/logo3.png", "https://mri.tum.de/de", width=220)
191
  with col4:
192
- render_clickable_image("deployment_images/logo4.png", "https://ai-assisted-healthcare.com/", width=220)
 
 
193
 
194
- st.write("")
195
  st.write("")
196
  st.title("PI-RADS and csPCa Risk Prediction from bpMRI")
197
  # --- TRIGGER THE DOWNLOAD STARTUP ---
198
- st.markdown("💡 This application utilizes a weakly supervised, attention-based multiple-instance learning (MIL) model to predict scan-level PI-RADS scores and clinically significant prostate cancer (csPCa) risk from axial biparametric MRI (bpMRI) sequences (T2W, ADC, and DWI). Users may upload their own bpMRI scans as NRRD or select a provided sample case to evaluate the tool. Following inference, outcomes are detailed in the Results & Downloads section. The Visualization module allows users to inspect the prostate mask and the top five salient patches overlaid on the bpMRI sequences. The salient patches are displayed only when the predicted PI-RADS score is 3 or more. For execution details, refer to the log file; for methodology, please visit our [Project Page](https://anirudhbalaraman.github.io/WSAttention-Prostate/) or read the [Paper]. For more implementation details, check our [Github](https://github.com/anirudhbalaraman/WSAttention-Prostate/tree/main)")
 
 
199
  st.markdown("***NOTE*** Required NRRD dimension format: Height x Width x Depth. ")
200
 
201
- #--- CONSTANTS ---
202
  REPO_ID = "anirudh0410/WSAttention-Prostate"
203
  FILENAMES = ["pirads.pt", "prostate_segmentation_model.pt", "cspca_model.pth"]
204
  with st.spinner("Initializing..."):
@@ -208,22 +215,22 @@ with st.spinner("Initializing..."):
208
  # --- CONFIGURATION ---
209
  # Base paths
210
  BASE_DIR = os.getcwd()
211
- INPUT_BASE = os.path.join(BASE_DIR, "temp_data" )
212
  OUTPUT_DIR = os.path.join(BASE_DIR, "temp_data", "processed")
213
- SAMPLES_BASE_DIR = os.path.join(BASE_DIR, "dataset","samples")
214
  SAMPLE_CASES = {
215
  "Sample 1": {
216
  "path": os.path.join(SAMPLES_BASE_DIR, "sample1"),
217
- "files": {"t2": "t2.nrrd", "adc": "adc.nrrd", "dwi": "dwi.nrrd"}
218
  },
219
  "Sample 2": {
220
  "path": os.path.join(SAMPLES_BASE_DIR, "sample2"),
221
- "files": {"t2": "t2.nrrd", "adc": "adc.nrrd", "dwi": "dwi.nrrd"}
222
  },
223
  "Sample 3": {
224
  "path": os.path.join(SAMPLES_BASE_DIR, "sample3"),
225
- "files": {"t2": "t2.nrrd", "adc": "adc.nrrd", "dwi": "dwi.nrrd"}
226
- }
227
  }
228
 
229
  # Create specific sub-directories for each input type
@@ -242,9 +249,7 @@ with st.sidebar:
242
  st.header("Data Selection")
243
  # Dropdown to choose mode
244
  data_source = st.radio(
245
- "Choose Data Source:",
246
- ["Upload My Own Files", "Sample 1", "Sample 2", "Sample 3"],
247
- index=0
248
  )
249
 
250
  # --- 2. INPUT HANDLING ---
@@ -260,15 +265,15 @@ if is_demo_mode:
260
  # Verify files exist
261
  base_path = selected_sample["path"]
262
  f_names = selected_sample["files"]
263
-
264
  missing = []
265
- for key, fname in f_names.items():
266
  if not os.path.exists(os.path.join(base_path, fname)):
267
  missing.append(os.path.join(base_path, fname))
268
 
269
  if missing:
270
  st.error(f"Error: The following sample files are missing in the repo:\n{missing}")
271
-
272
  else:
273
  # Visual feedback
274
  c1, c2, c3 = st.columns(3)
@@ -297,9 +302,8 @@ if "inference_done" not in st.session_state:
297
  if "logs" not in st.session_state:
298
  st.session_state.logs = ""
299
  ready_to_run = (not is_demo_mode and t2_file and adc_file and dwi_file) or is_demo_mode
300
- if ready_to_run:
301
  if st.button("Run Inference", type="primary"):
302
-
303
  st.session_state.inference_done = False
304
  st.session_state.logs = ""
305
  # --- A. CLEANUP & SAVE ---
@@ -307,35 +311,35 @@ if ready_to_run:
307
  # (Optional but recommended for a clean state)
308
  for folder in [T2_DIR, ADC_DIR, DWI_DIR, OUTPUT_DIR]:
309
  for f in os.listdir(folder):
310
- if os.path.isfile(os.path.join(folder,f)):
311
  os.remove(os.path.join(folder, f))
312
- elif os.path.isdir(os.path.join(folder,f)):
313
- shutil.rmtree(os.path.join(folder,f))
314
-
315
-
316
 
317
  if is_demo_mode:
318
-
319
-
320
-
321
  # Copy from the specific sample folder
322
  src = SAMPLE_CASES[data_source]
323
- shutil.copy(os.path.join(src["path"], src["files"]["t2"]), os.path.join(T2_DIR, "sample.nrrd"))
324
- shutil.copy(os.path.join(src["path"], src["files"]["adc"]), os.path.join(ADC_DIR, "sample.nrrd"))
325
- shutil.copy(os.path.join(src["path"], src["files"]["dwi"]), os.path.join(DWI_DIR, "sample.nrrd"))
 
 
 
 
 
 
326
  st.write(f"Loaded data from {data_source}...")
327
 
328
  else:
329
-
330
  # Save T2
331
  # We save it inside the T2_DIR folder
332
  with open(os.path.join(T2_DIR, t2_file.name), "wb") as f:
333
  shutil.copyfileobj(t2_file, f)
334
-
335
  # Save ADC
336
  with open(os.path.join(ADC_DIR, t2_file.name), "wb") as f:
337
  shutil.copyfileobj(adc_file, f)
338
-
339
  # Save DWI
340
  with open(os.path.join(DWI_DIR, t2_file.name), "wb") as f:
341
  shutil.copyfileobj(dwi_file, f)
@@ -346,14 +350,20 @@ if ready_to_run:
346
  # --- B. CONSTRUCT COMMAND ---
347
  # We pass the FOLDER paths, not file paths, matching your argument names
348
  command = [
349
- "python", "run_inference.py",
350
- "--t2_dir", T2_DIR,
351
- "--dwi_dir", DWI_DIR,
352
- "--adc_dir", ADC_DIR,
353
- "--output_dir", OUTPUT_DIR,
354
- "--project_dir", BASE_DIR
 
 
 
 
 
 
355
  ]
356
-
357
  # DEBUG: Show the exact command being run (helpful for troubleshooting)
358
  st.code(" ".join(command), language="bash")
359
 
@@ -361,31 +371,23 @@ if ready_to_run:
361
  with st.spinner("Running Inference... (This may take a moment)"):
362
  try:
363
  # Run the script and capture output
364
- result = subprocess.run(
365
- command,
366
- capture_output=True,
367
- text=True,
368
- check=True
369
- )
370
-
371
 
372
  st.session_state.inference_done = True
373
  st.session_state.logs = result.stdout
374
-
375
  except subprocess.CalledProcessError as e:
376
  st.error("Script Execution Failed.")
377
  st.error("Error Output:")
378
  st.code(e.stderr)
379
-
380
  # --- D. SHOW OUTPUT FILES ---
381
  if st.session_state.inference_done:
382
  st.success("Pipeline Execution Successful!")
383
-
384
-
385
 
386
  st.divider()
387
  with st.expander("📊 Results & Downloads", expanded=True):
388
- if st.session_state.get("logs"): # Show Logs
389
  with st.expander("View Execution Logs"):
390
  st.code(st.session_state.logs)
391
  # List everything in the output directory
@@ -396,53 +398,55 @@ if st.session_state.inference_done:
396
  file_path = os.path.join(OUTPUT_DIR, file_name)
397
  if not os.path.isfile(file_path):
398
  continue
399
-
400
 
401
  with open(file_path, "rb") as f:
402
  st.download_button(
403
- label=f"⬇️ Download {file_name}",
404
- data=f.read(),
405
- file_name=file_name
406
  )
407
  if file_name == "results.json":
408
- with open(file_path, "r") as f:
409
  temp_data = json.load(f)
410
  first_case = next(iter(temp_data.values()))
411
  st.session_state.pirads = first_case.get("Predicted PIRAD Score")
412
  st.session_state.risk = first_case.get("csPCa risk")
413
- st.session_state.coords = first_case.get("Top left coordinate of top 5 patches(x,y,z)")
414
-
 
415
 
416
  else:
417
  st.warning("Script finished but no files were found in output_dir.")
418
 
419
  with st.expander("🩺 Results", expanded=True):
420
  if "risk" in st.session_state and "pirads" in st.session_state:
421
- #st.metric("csPCa Risk Score", f"{st.session_state.risk:.2f}")
422
  risk = st.session_state.get("risk")
423
  z = np.linspace(0, 1, 100).reshape(1, -1) # 1 row, 100 columns
424
  col_chart, col_spacer = st.columns([1, 1])
425
  with col_chart:
426
  fig = go.Figure()
427
- fig.add_trace(go.Heatmap(
428
- z=z, # one row, two columns
429
- x=np.linspace(0, 1, 100), # 0 to 1 scale
430
- y=[0, 1],
431
- showscale=False,
432
- colorscale="RdYlGn_r",
433
- hoverinfo='none'
434
- ))
435
- fig.add_trace(go.Scatter(
436
- x=[risk],
437
- y=[0.1],
438
- mode="markers+text",
439
- marker=dict(symbol="triangle-down", size=16, color="black"),
440
- text=[f"csPCa Risk: {risk:.2f}"],
441
- textposition="top center",
442
- textfont=dict(color="black", size=16),
443
- showlegend=False,
444
- cliponaxis=False
445
- ))
 
 
 
 
446
 
447
  # Layout adjustments
448
  fig.update_layout(
@@ -456,23 +460,17 @@ if st.session_state.inference_done:
456
  showgrid=False,
457
  ticks="outside",
458
  ticklen=4,
459
- tickfont=dict(
460
- size=16,
461
- color="black"
462
- ),
463
  ticklabelposition="inside bottom",
464
  showline=False,
465
  zeroline=False,
466
  mirror=False,
467
- side="bottom"
468
  ),
469
  yaxis=dict(
470
- range=[0, 1],
471
- showticklabels=False,
472
- showgrid=False,
473
- showline=False
474
  ),
475
- plot_bgcolor="white"
476
  )
477
 
478
  st.plotly_chart(fig, use_container_width=False)
@@ -489,7 +487,7 @@ if st.session_state.inference_done:
489
 
490
  for s in range(2, 6):
491
  config = score_config[s]
492
-
493
  # Define styles cleanly without newlines/indentation to prevent HTML errors
494
  if s == int(pirads):
495
  # Selected: Thick border, full opacity
@@ -508,21 +506,21 @@ if st.session_state.inference_done:
508
  # distinct styling properties are joined by semicolons
509
  html_circles += f"""
510
  <div style="
511
- width: 60px;
512
- height: 60px;
513
- background-color: {config['bg']};
514
- color: {config['text']};
515
- border-radius: 50%;
516
- display: flex;
517
- align-items: center;
518
- justify-content: center;
519
- font-size: 24px;
520
- font-weight: bold;
521
- font-family: Arial, sans-serif;
522
  margin-right: 15px;
523
- border: {border};
524
- opacity: {opacity};
525
- transform: {transform};
526
  box-shadow: {box_shadow};">
527
  {s}
528
  </div>
@@ -536,7 +534,7 @@ if st.session_state.inference_done:
536
  {html_circles}
537
  </div>
538
  """,
539
- unsafe_allow_html=True
540
  )
541
  else:
542
  st.info("Results not available.")
@@ -556,25 +554,25 @@ if st.session_state.inference_done:
556
  if os.path.exists(adc_vis_dir) and len(os.listdir(adc_vis_dir)) > 0:
557
  files_in_dir = os.listdir(adc_vis_dir)[0]
558
  adc_vis_path = os.path.join(adc_vis_dir, files_in_dir)
559
-
560
  dwi_vis_dir = os.path.join(OUTPUT_DIR, "DWI_registered")
561
  if os.path.exists(dwi_vis_dir) and len(os.listdir(dwi_vis_dir)) > 0:
562
  files_in_dir = os.listdir(dwi_vis_dir)[0]
563
  dwi_vis_path = os.path.join(dwi_vis_dir, files_in_dir)
564
-
565
  mask_vis_dir = os.path.join(OUTPUT_DIR, "prostate_mask")
566
  if os.path.exists(mask_vis_dir) and len(os.listdir(mask_vis_dir)) > 0:
567
  files_in_maskdir = os.listdir(mask_vis_dir)[0]
568
  mask_vis_path = os.path.join(mask_vis_dir, files_in_maskdir)
569
- print('mask_vis_path')
570
  else:
571
- print('No mask dir')
572
-
573
  roi_bbox = None
574
  if "coords" in st.session_state:
575
  detected_boxes = []
576
  for i in st.session_state.coords:
577
- indi_box = [i[1],i[0],i[2],64,64,3]
578
  detected_boxes.append(indi_box)
579
 
580
  scan_dict = {}
@@ -591,7 +589,7 @@ if st.session_state.inference_done:
591
  mask_path=mask_vis_path,
592
  bboxes=detected_boxes,
593
  title="Salient Patch Viewer",
594
- key_suffix="main_viz"
595
  )
596
  elif scan_dict:
597
  display_slicer(
@@ -599,12 +597,5 @@ if st.session_state.inference_done:
599
  mask_path=mask_vis_path,
600
  bboxes=None,
601
  title="Salient Patch Viewer",
602
- key_suffix="main_viz"
603
- )
604
-
605
-
606
-
607
-
608
-
609
-
610
-
 
1
+ import base64
2
+ import json
3
  import os
4
  import shutil
5
+ import subprocess
6
+
7
+ import matplotlib.patches as patches
8
  import matplotlib.pyplot as plt
9
+ import nrrd
10
  import numpy as np
 
 
11
  import plotly.graph_objects as go
12
+ import streamlit as st
13
+ from huggingface_hub import hf_hub_download
14
+
15
 
16
  def render_clickable_image(image_path, link_url, width=100):
17
  """
 
20
  # 1. Read the image file and encode it to base64
21
  with open(image_path, "rb") as f:
22
  data = base64.b64encode(f.read()).decode("utf-8")
23
+
24
  # 2. Create the HTML string
25
  # target="_blank" opens the link in a new tab
26
  html_code = f"""
 
28
  <img src="data:image/png;base64,{data}" width="{width}" style="border-radius: 5px;">
29
  </a>
30
  """
31
+
32
  # 3. Render it
33
  st.markdown(html_code, unsafe_allow_html=True)
34
 
35
 
36
  st.set_page_config(
37
+ page_title="Prostate Scoring", page_icon="🩺", layout="wide", initial_sidebar_state="expanded"
 
 
 
38
  )
39
 
40
 
 
44
  data, header = nrrd.read(file_path)
45
  return data, header
46
 
47
+
48
  def display_slicer(scan_paths, mask_path=None, bboxes=None, title="Scan Viewer", key_suffix=""):
49
  """
50
  Displays slicer with Multi-Background Support, Mask Overlay, and Bounding Box Multiselect.
51
+
52
  Args:
53
  scan_paths: Dict of {Label: FilePath}. Example: {"T2W": "path/to/t2.nrrd", "ADC": "..."}
54
  """
 
58
  # --- CONTROLS SECTION (Right Column) ---
59
  with c_controls:
60
  st.write(f"**{title} Controls**")
61
+
62
  # A. Background Selection
63
  # We assume the first key in the dict is the default
64
  available_scans = list(scan_paths.keys())
65
+ selected_scan_name = st.radio(
66
+ "Background Image", available_scans, index=0, key=f"bg_{key_suffix}"
67
+ )
68
  current_file_path = scan_paths[selected_scan_name]
69
 
70
  # B. Lesion Selection (Multiselect)
71
  box_labels = []
72
  selected_labels = []
73
  if bboxes:
74
+ box_labels = [f"Lesion {i + 1}" for i in range(len(bboxes))]
75
+ st.write("---") # Divider
76
  selected_labels = st.multiselect(
77
+ "Select Lesions", options=box_labels, default=box_labels, key=f"multi_{key_suffix}"
 
 
 
78
  )
79
 
80
  # C. Toggles
 
91
 
92
  # Load the selected background image
93
  data, _ = load_nrrd(current_file_path)
94
+
95
  if len(data.shape) != 3:
96
  st.warning("Data is not 3D.")
97
  return
 
108
  start_slice = int(b[2] + (b[5] // 2))
109
  start_slice = max(0, min(start_slice, total_slices - 1))
110
 
111
+ slice_idx = st.slider(
112
+ "Select Slice (Z-Axis)", 0, total_slices - 1, start_slice, key=f"sl_{key_suffix}"
113
+ )
114
 
115
  # E. Plotting
116
  img_slice = data[:, :, slice_idx]
117
+
118
  # Normalize Image (0-1)
119
  img_slice = img_slice.astype(float)
120
 
121
  fig, ax = plt.subplots(figsize=(5, 5))
122
  ax.imshow(img_slice, cmap="gray", origin="upper")
123
+
124
  # 1. Overlay Mask
125
  if show_mask:
126
  # Load mask on the fly (or cache it if slow)
 
134
  else:
135
  # Fallback warning if mask dims don't match selected background
136
  # (Common if ADC resolution != T2 resolution)
137
+ ax.text(5, 5, "Mask shape mismatch", color="red", fontsize=8)
 
138
 
139
  # 2. Overlay Bounding Boxes
140
  if bboxes:
141
  for i, box in enumerate(bboxes):
142
+ label = f"Lesion {i + 1}"
143
  if label not in selected_labels:
144
+ continue
145
 
146
  bx, by, bz, bw, bh, bd = box
147
 
148
  # Visibility check
149
  if bz <= slice_idx < (bz + bd):
150
  rect = patches.Rectangle(
151
+ (bx, by), bw, bh, linewidth=2, edgecolor="yellow", facecolor="none"
 
152
  )
153
  ax.add_patch(rect)
154
+ ax.text(bx, by - 5, f"L{i + 1}", color="yellow", fontsize=9, fontweight="bold")
155
 
156
  ax.axis("off")
157
  st.pyplot(fig, use_container_width=False)
158
 
159
+
160
  @st.cache_resource
161
  def download_all_models():
162
  # 1. Ensure the 'models' directory exists
163
+ models_dir = os.path.join(os.getcwd(), "models")
164
+ os.makedirs(models_dir, exist_ok=True)
165
 
166
  for filename in FILENAMES:
167
  try:
168
  # 2. Download from Hugging Face (to cache)
169
  cached_path = hf_hub_download(repo_id=REPO_ID, filename=filename)
170
+
171
  # 3. Define where we want it to live locally
172
  destination_path = os.path.join(models_dir, filename)
173
+
174
  # 4. Copy only if it's not already there
175
  if not os.path.exists(destination_path):
176
  shutil.copy(cached_path, destination_path)
177
+
178
  except Exception as e:
179
  st.error(f"Failed to download {filename}: {e}")
180
  st.stop()
181
 
182
+
183
  with st.container():
184
  col1, col2, col3, col4 = st.columns(4)
185
 
186
  with col1:
187
+ render_clickable_image(
188
+ "deployment_images/logo1.png", "https://www.comfort-ai.eu/", width=220
189
+ )
190
  with col2:
191
  render_clickable_image("deployment_images/logo2.png", "https://www.charite.de/", width=220)
192
  with col3:
193
  render_clickable_image("deployment_images/logo3.png", "https://mri.tum.de/de", width=220)
194
  with col4:
195
+ render_clickable_image(
196
+ "deployment_images/logo4.png", "https://ai-assisted-healthcare.com/", width=220
197
+ )
198
 
199
+ st.write("")
200
  st.write("")
201
  st.title("PI-RADS and csPCa Risk Prediction from bpMRI")
202
  # --- TRIGGER THE DOWNLOAD STARTUP ---
203
+ st.markdown(
204
+ "💡 This application utilizes a weakly supervised, attention-based multiple-instance learning (MIL) model to predict scan-level PI-RADS scores and clinically significant prostate cancer (csPCa) risk from axial biparametric MRI (bpMRI) sequences (T2W, ADC, and DWI). Users may upload their own bpMRI scans as NRRD or select a provided sample case to evaluate the tool. Following inference, outcomes are detailed in the Results & Downloads section. The Visualization module allows users to inspect the prostate mask and the top five salient patches overlaid on the bpMRI sequences. The salient patches are displayed only when the predicted PI-RADS score is 3 or more. For execution details, refer to the log file; for methodology, please visit our [Project Page](https://anirudhbalaraman.github.io/WSAttention-Prostate/) or read the [Paper]. For more implementation details, check our [Github](https://github.com/anirudhbalaraman/WSAttention-Prostate/tree/main)"
205
+ )
206
  st.markdown("***NOTE*** Required NRRD dimension format: Height x Width x Depth. ")
207
 
208
+ # --- CONSTANTS ---
209
  REPO_ID = "anirudh0410/WSAttention-Prostate"
210
  FILENAMES = ["pirads.pt", "prostate_segmentation_model.pt", "cspca_model.pth"]
211
  with st.spinner("Initializing..."):
 
215
  # --- CONFIGURATION ---
216
  # Base paths
217
  BASE_DIR = os.getcwd()
218
+ INPUT_BASE = os.path.join(BASE_DIR, "temp_data")
219
  OUTPUT_DIR = os.path.join(BASE_DIR, "temp_data", "processed")
220
+ SAMPLES_BASE_DIR = os.path.join(BASE_DIR, "dataset", "samples")
221
  SAMPLE_CASES = {
222
  "Sample 1": {
223
  "path": os.path.join(SAMPLES_BASE_DIR, "sample1"),
224
+ "files": {"t2": "t2.nrrd", "adc": "adc.nrrd", "dwi": "dwi.nrrd"},
225
  },
226
  "Sample 2": {
227
  "path": os.path.join(SAMPLES_BASE_DIR, "sample2"),
228
+ "files": {"t2": "t2.nrrd", "adc": "adc.nrrd", "dwi": "dwi.nrrd"},
229
  },
230
  "Sample 3": {
231
  "path": os.path.join(SAMPLES_BASE_DIR, "sample3"),
232
+ "files": {"t2": "t2.nrrd", "adc": "adc.nrrd", "dwi": "dwi.nrrd"},
233
+ },
234
  }
235
 
236
  # Create specific sub-directories for each input type
 
249
  st.header("Data Selection")
250
  # Dropdown to choose mode
251
  data_source = st.radio(
252
+ "Choose Data Source:", ["Upload My Own Files", "Sample 1", "Sample 2", "Sample 3"], index=0
 
 
253
  )
254
 
255
  # --- 2. INPUT HANDLING ---
 
265
  # Verify files exist
266
  base_path = selected_sample["path"]
267
  f_names = selected_sample["files"]
268
+
269
  missing = []
270
+ for _, fname in f_names.items():
271
  if not os.path.exists(os.path.join(base_path, fname)):
272
  missing.append(os.path.join(base_path, fname))
273
 
274
  if missing:
275
  st.error(f"Error: The following sample files are missing in the repo:\n{missing}")
276
+
277
  else:
278
  # Visual feedback
279
  c1, c2, c3 = st.columns(3)
 
302
  if "logs" not in st.session_state:
303
  st.session_state.logs = ""
304
  ready_to_run = (not is_demo_mode and t2_file and adc_file and dwi_file) or is_demo_mode
305
+ if ready_to_run:
306
  if st.button("Run Inference", type="primary"):
 
307
  st.session_state.inference_done = False
308
  st.session_state.logs = ""
309
  # --- A. CLEANUP & SAVE ---
 
311
  # (Optional but recommended for a clean state)
312
  for folder in [T2_DIR, ADC_DIR, DWI_DIR, OUTPUT_DIR]:
313
  for f in os.listdir(folder):
314
+ if os.path.isfile(os.path.join(folder, f)):
315
  os.remove(os.path.join(folder, f))
316
+ elif os.path.isdir(os.path.join(folder, f)):
317
+ shutil.rmtree(os.path.join(folder, f))
 
 
318
 
319
  if is_demo_mode:
 
 
 
320
  # Copy from the specific sample folder
321
  src = SAMPLE_CASES[data_source]
322
+ shutil.copy(
323
+ os.path.join(src["path"], src["files"]["t2"]), os.path.join(T2_DIR, "sample.nrrd")
324
+ )
325
+ shutil.copy(
326
+ os.path.join(src["path"], src["files"]["adc"]), os.path.join(ADC_DIR, "sample.nrrd")
327
+ )
328
+ shutil.copy(
329
+ os.path.join(src["path"], src["files"]["dwi"]), os.path.join(DWI_DIR, "sample.nrrd")
330
+ )
331
  st.write(f"Loaded data from {data_source}...")
332
 
333
  else:
 
334
  # Save T2
335
  # We save it inside the T2_DIR folder
336
  with open(os.path.join(T2_DIR, t2_file.name), "wb") as f:
337
  shutil.copyfileobj(t2_file, f)
338
+
339
  # Save ADC
340
  with open(os.path.join(ADC_DIR, t2_file.name), "wb") as f:
341
  shutil.copyfileobj(adc_file, f)
342
+
343
  # Save DWI
344
  with open(os.path.join(DWI_DIR, t2_file.name), "wb") as f:
345
  shutil.copyfileobj(dwi_file, f)
 
350
  # --- B. CONSTRUCT COMMAND ---
351
  # We pass the FOLDER paths, not file paths, matching your argument names
352
  command = [
353
+ "python",
354
+ "run_inference.py",
355
+ "--t2_dir",
356
+ T2_DIR,
357
+ "--dwi_dir",
358
+ DWI_DIR,
359
+ "--adc_dir",
360
+ ADC_DIR,
361
+ "--output_dir",
362
+ OUTPUT_DIR,
363
+ "--project_dir",
364
+ BASE_DIR,
365
  ]
366
+
367
  # DEBUG: Show the exact command being run (helpful for troubleshooting)
368
  st.code(" ".join(command), language="bash")
369
 
 
371
  with st.spinner("Running Inference... (This may take a moment)"):
372
  try:
373
  # Run the script and capture output
374
+ result = subprocess.run(command, capture_output=True, text=True, check=True)
 
 
 
 
 
 
375
 
376
  st.session_state.inference_done = True
377
  st.session_state.logs = result.stdout
378
+
379
  except subprocess.CalledProcessError as e:
380
  st.error("Script Execution Failed.")
381
  st.error("Error Output:")
382
  st.code(e.stderr)
383
+
384
  # --- D. SHOW OUTPUT FILES ---
385
  if st.session_state.inference_done:
386
  st.success("Pipeline Execution Successful!")
 
 
387
 
388
  st.divider()
389
  with st.expander("📊 Results & Downloads", expanded=True):
390
+ if st.session_state.get("logs"): # Show Logs
391
  with st.expander("View Execution Logs"):
392
  st.code(st.session_state.logs)
393
  # List everything in the output directory
 
398
  file_path = os.path.join(OUTPUT_DIR, file_name)
399
  if not os.path.isfile(file_path):
400
  continue
 
401
 
402
  with open(file_path, "rb") as f:
403
  st.download_button(
404
+ label=f"⬇️ Download {file_name}", data=f.read(), file_name=file_name
 
 
405
  )
406
  if file_name == "results.json":
407
+ with open(file_path) as f:
408
  temp_data = json.load(f)
409
  first_case = next(iter(temp_data.values()))
410
  st.session_state.pirads = first_case.get("Predicted PIRAD Score")
411
  st.session_state.risk = first_case.get("csPCa risk")
412
+ st.session_state.coords = first_case.get(
413
+ "Top left coordinate of top 5 patches(x,y,z)"
414
+ )
415
 
416
  else:
417
  st.warning("Script finished but no files were found in output_dir.")
418
 
419
  with st.expander("🩺 Results", expanded=True):
420
  if "risk" in st.session_state and "pirads" in st.session_state:
421
+ # st.metric("csPCa Risk Score", f"{st.session_state.risk:.2f}")
422
  risk = st.session_state.get("risk")
423
  z = np.linspace(0, 1, 100).reshape(1, -1) # 1 row, 100 columns
424
  col_chart, col_spacer = st.columns([1, 1])
425
  with col_chart:
426
  fig = go.Figure()
427
+ fig.add_trace(
428
+ go.Heatmap(
429
+ z=z, # one row, two columns
430
+ x=np.linspace(0, 1, 100), # 0 to 1 scale
431
+ y=[0, 1],
432
+ showscale=False,
433
+ colorscale="RdYlGn_r",
434
+ hoverinfo="none",
435
+ )
436
+ )
437
+ fig.add_trace(
438
+ go.Scatter(
439
+ x=[risk],
440
+ y=[0.1],
441
+ mode="markers+text",
442
+ marker=dict(symbol="triangle-down", size=16, color="black"),
443
+ text=[f"csPCa Risk: {risk:.2f}"],
444
+ textposition="top center",
445
+ textfont=dict(color="black", size=16),
446
+ showlegend=False,
447
+ cliponaxis=False,
448
+ )
449
+ )
450
 
451
  # Layout adjustments
452
  fig.update_layout(
 
460
  showgrid=False,
461
  ticks="outside",
462
  ticklen=4,
463
+ tickfont=dict(size=16, color="black"),
 
 
 
464
  ticklabelposition="inside bottom",
465
  showline=False,
466
  zeroline=False,
467
  mirror=False,
468
+ side="bottom",
469
  ),
470
  yaxis=dict(
471
+ range=[0, 1], showticklabels=False, showgrid=False, showline=False
 
 
 
472
  ),
473
+ plot_bgcolor="white",
474
  )
475
 
476
  st.plotly_chart(fig, use_container_width=False)
 
487
 
488
  for s in range(2, 6):
489
  config = score_config[s]
490
+
491
  # Define styles cleanly without newlines/indentation to prevent HTML errors
492
  if s == int(pirads):
493
  # Selected: Thick border, full opacity
 
506
  # distinct styling properties are joined by semicolons
507
  html_circles += f"""
508
  <div style="
509
+ width: 60px;
510
+ height: 60px;
511
+ background-color: {config["bg"]};
512
+ color: {config["text"]};
513
+ border-radius: 50%;
514
+ display: flex;
515
+ align-items: center;
516
+ justify-content: center;
517
+ font-size: 24px;
518
+ font-weight: bold;
519
+ font-family: Arial, sans-serif;
520
  margin-right: 15px;
521
+ border: {border};
522
+ opacity: {opacity};
523
+ transform: {transform};
524
  box-shadow: {box_shadow};">
525
  {s}
526
  </div>
 
534
  {html_circles}
535
  </div>
536
  """,
537
+ unsafe_allow_html=True,
538
  )
539
  else:
540
  st.info("Results not available.")
 
554
  if os.path.exists(adc_vis_dir) and len(os.listdir(adc_vis_dir)) > 0:
555
  files_in_dir = os.listdir(adc_vis_dir)[0]
556
  adc_vis_path = os.path.join(adc_vis_dir, files_in_dir)
557
+
558
  dwi_vis_dir = os.path.join(OUTPUT_DIR, "DWI_registered")
559
  if os.path.exists(dwi_vis_dir) and len(os.listdir(dwi_vis_dir)) > 0:
560
  files_in_dir = os.listdir(dwi_vis_dir)[0]
561
  dwi_vis_path = os.path.join(dwi_vis_dir, files_in_dir)
562
+
563
  mask_vis_dir = os.path.join(OUTPUT_DIR, "prostate_mask")
564
  if os.path.exists(mask_vis_dir) and len(os.listdir(mask_vis_dir)) > 0:
565
  files_in_maskdir = os.listdir(mask_vis_dir)[0]
566
  mask_vis_path = os.path.join(mask_vis_dir, files_in_maskdir)
567
+ print("mask_vis_path")
568
  else:
569
+ print("No mask dir")
570
+
571
  roi_bbox = None
572
  if "coords" in st.session_state:
573
  detected_boxes = []
574
  for i in st.session_state.coords:
575
+ indi_box = [i[1], i[0], i[2], 64, 64, 3]
576
  detected_boxes.append(indi_box)
577
 
578
  scan_dict = {}
 
589
  mask_path=mask_vis_path,
590
  bboxes=detected_boxes,
591
  title="Salient Patch Viewer",
592
+ key_suffix="main_viz",
593
  )
594
  elif scan_dict:
595
  display_slicer(
 
597
  mask_path=mask_vis_path,
598
  bboxes=None,
599
  title="Salient Patch Viewer",
600
+ key_suffix="main_viz",
601
+ )
 
 
 
 
 
 
 
pyproject.toml CHANGED
@@ -27,6 +27,7 @@ disable_error_code = ["override", "import-untyped"]
27
  mypy_path = "."
28
  pretty = true
29
  show_error_codes = true
 
30
 
31
  [[tool.mypy.overrides]]
32
  # These settings apply specifically to these external libraries
 
27
  mypy_path = "."
28
  pretty = true
29
  show_error_codes = true
30
+ exclude = ["^app\\.py$"]
31
 
32
  [[tool.mypy.overrides]]
33
  # These settings apply specifically to these external libraries
run_inference.py CHANGED
@@ -4,7 +4,6 @@ import logging
4
  import os
5
  from pathlib import Path
6
 
7
-
8
  import torch
9
  import yaml
10
  from monai.data import Dataset
@@ -46,7 +45,6 @@ def parse_args():
46
  return args
47
 
48
 
49
-
50
  if __name__ == "__main__":
51
  args = parse_args()
52
  if args.project_dir is None:
 
4
  import os
5
  from pathlib import Path
6
 
 
7
  import torch
8
  import yaml
9
  from monai.data import Dataset
 
45
  return args
46
 
47
 
 
48
  if __name__ == "__main__":
49
  args = parse_args()
50
  if args.project_dir is None:
src/utils.py CHANGED
@@ -4,9 +4,10 @@ import os
4
  import sys
5
  from pathlib import Path
6
  from typing import Any, Union
7
- import matplotlib.pyplot as plt
8
- import matplotlib.patches as patches
9
  import cv2
 
 
10
  import numpy as np
11
  import torch
12
  from monai.data import Dataset
@@ -174,8 +175,7 @@ def get_parent_image(temp_data_list, args: argparse.Namespace) -> np.ndarray:
174
  return dataset_image[0]["image"][0].numpy()
175
 
176
 
177
-
178
- def visualise_patches(coords, image, tile_size = 64, depth=3):
179
  """
180
  Visualize 3D image patches with their locations marked by bounding rectangles.
181
  This function creates a grid of subplot visualizations where each row represents
@@ -201,52 +201,45 @@ def visualise_patches(coords, image, tile_size = 64, depth=3):
201
 
202
  rows, _, _, slices = (len(coords), tile_size, tile_size, depth)
203
  fig, axes = plt.subplots(
204
- nrows=rows,
205
- ncols=slices,
206
- figsize=(slices * 3, rows * 3),
207
- squeeze=False
208
  )
209
 
210
  for i, x in enumerate(coords):
211
  for j in range(slices):
212
-
213
  ax = axes[i, j]
214
 
215
  slice_id = x[2] + j
216
- ax.imshow(image[:, :, slice_id], cmap='gray')
217
 
218
  rect = patches.Rectangle(
219
- (x[1], x[0]),
220
- tile_size,
221
- tile_size,
222
- linewidth=2,
223
- edgecolor='red',
224
- facecolor='none'
225
  )
226
  ax.add_patch(rect)
227
 
228
  # ---- slice ID text (every image) ----
229
  ax.text(
230
- 0.02, 0.98,
 
231
  f"z={slice_id}",
232
  transform=ax.transAxes,
233
  fontsize=10,
234
- color='white',
235
- va='top',
236
- ha='left',
237
- bbox=dict(facecolor='black', alpha=0.4, pad=2)
238
  )
239
 
240
- ax.axis('off')
241
 
242
  # Row label
243
  axes[i, 0].text(
244
- -0.08, 0.5,
245
- f"Patch {i+1}",
 
246
  transform=axes[i, 0].transAxes,
247
  fontsize=12,
248
- va='center',
249
- ha='right'
250
  )
251
 
252
  plt.subplots_adjust(left=0.06)
 
4
  import sys
5
  from pathlib import Path
6
  from typing import Any, Union
7
+
 
8
  import cv2
9
+ import matplotlib.patches as patches
10
+ import matplotlib.pyplot as plt
11
  import numpy as np
12
  import torch
13
  from monai.data import Dataset
 
175
  return dataset_image[0]["image"][0].numpy()
176
 
177
 
178
+ def visualise_patches(coords, image, tile_size=64, depth=3):
 
179
  """
180
  Visualize 3D image patches with their locations marked by bounding rectangles.
181
  This function creates a grid of subplot visualizations where each row represents
 
201
 
202
  rows, _, _, slices = (len(coords), tile_size, tile_size, depth)
203
  fig, axes = plt.subplots(
204
+ nrows=rows, ncols=slices, figsize=(slices * 3, rows * 3), squeeze=False
 
 
 
205
  )
206
 
207
  for i, x in enumerate(coords):
208
  for j in range(slices):
 
209
  ax = axes[i, j]
210
 
211
  slice_id = x[2] + j
212
+ ax.imshow(image[:, :, slice_id], cmap="gray")
213
 
214
  rect = patches.Rectangle(
215
+ (x[1], x[0]), tile_size, tile_size, linewidth=2, edgecolor="red", facecolor="none"
 
 
 
 
 
216
  )
217
  ax.add_patch(rect)
218
 
219
  # ---- slice ID text (every image) ----
220
  ax.text(
221
+ 0.02,
222
+ 0.98,
223
  f"z={slice_id}",
224
  transform=ax.transAxes,
225
  fontsize=10,
226
+ color="white",
227
+ va="top",
228
+ ha="left",
229
+ bbox=dict(facecolor="black", alpha=0.4, pad=2),
230
  )
231
 
232
+ ax.axis("off")
233
 
234
  # Row label
235
  axes[i, 0].text(
236
+ -0.08,
237
+ 0.5,
238
+ f"Patch {i + 1}",
239
  transform=axes[i, 0].transAxes,
240
  fontsize=12,
241
+ va="center",
242
+ ha="right",
243
  )
244
 
245
  plt.subplots_adjust(left=0.06)
visualisation.ipynb CHANGED
@@ -7,14 +7,14 @@
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
10
- "import os\n",
11
  "import json\n",
12
- "import matplotlib.pyplot as plt \n",
13
- "import nrrd\n",
14
- "import matplotlib.patches as patches\n",
15
- "from src.utils import visualise_patches\n",
16
  "import ipywidgets as widgets\n",
17
- "from IPython.display import display\n"
 
 
 
18
  ]
19
  },
20
  {
@@ -28,8 +28,8 @@
28
  "output_dir = \"/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed\"\n",
29
  "tile_size = 64\n",
30
  "depth = 3\n",
31
- "json_path = os.path.join(output_dir, 'results.json')\n",
32
- "with open(json_path, \"r\") as f:\n",
33
  " data = json.load(f)"
34
  ]
35
  },
@@ -72,21 +72,16 @@
72
  "files = data.keys()\n",
73
  "\n",
74
  "dropdown = widgets.Dropdown(\n",
75
- " options=files,\n",
76
- " description='Choose file:',\n",
77
- " style={'description_width': 'initial'}\n",
78
  ")\n",
79
- "sequences = {'T2W':'t2_registered', 'DWI':'DWI_registered', 'ADC':'ADC_registered'}\n",
80
  "\n",
81
  "dropdown_seq = widgets.Dropdown(\n",
82
- " options=sequences.keys(),\n",
83
- " description='Choose sequence:',\n",
84
- " style={'description_width': 'initial'}\n",
85
  ")\n",
86
  "\n",
87
  "display(dropdown)\n",
88
- "display(dropdown_seq)\n",
89
- "\n"
90
  ]
91
  },
92
  {
@@ -107,7 +102,6 @@
107
  }
108
  ],
109
  "source": [
110
- "\n",
111
  "# Access selection with\n",
112
  "key = dropdown.value\n",
113
  "seq = sequences[dropdown_seq.value]\n",
@@ -115,7 +109,7 @@
115
  "t2_path = os.path.join(output_dir, seq, key)\n",
116
  "t2, _ = nrrd.read(t2_path)\n",
117
  "visualise_patches(coords, t2)\n",
118
- "#The slice id is displayed on the top left corner of each patch"
119
  ]
120
  },
121
  {
 
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
 
10
  "import json\n",
11
+ "import os\n",
12
+ "\n",
 
 
13
  "import ipywidgets as widgets\n",
14
+ "import nrrd\n",
15
+ "from IPython.display import display\n",
16
+ "\n",
17
+ "from src.utils import visualise_patches"
18
  ]
19
  },
20
  {
 
28
  "output_dir = \"/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed\"\n",
29
  "tile_size = 64\n",
30
  "depth = 3\n",
31
+ "json_path = os.path.join(output_dir, \"results.json\")\n",
32
+ "with open(json_path) as f:\n",
33
  " data = json.load(f)"
34
  ]
35
  },
 
72
  "files = data.keys()\n",
73
  "\n",
74
  "dropdown = widgets.Dropdown(\n",
75
+ " options=files, description=\"Choose file:\", style={\"description_width\": \"initial\"}\n",
 
 
76
  ")\n",
77
+ "sequences = {\"T2W\": \"t2_registered\", \"DWI\": \"DWI_registered\", \"ADC\": \"ADC_registered\"}\n",
78
  "\n",
79
  "dropdown_seq = widgets.Dropdown(\n",
80
+ " options=sequences.keys(), description=\"Choose sequence:\", style={\"description_width\": \"initial\"}\n",
 
 
81
  ")\n",
82
  "\n",
83
  "display(dropdown)\n",
84
+ "display(dropdown_seq)"
 
85
  ]
86
  },
87
  {
 
102
  }
103
  ],
104
  "source": [
 
105
  "# Access selection with\n",
106
  "key = dropdown.value\n",
107
  "seq = sequences[dropdown_seq.value]\n",
 
109
  "t2_path = os.path.join(output_dir, seq, key)\n",
110
  "t2, _ = nrrd.read(t2_path)\n",
111
  "visualise_patches(coords, t2)\n",
112
+ "# The slice id is displayed on the top left corner of each patch"
113
  ]
114
  },
115
  {