# wanhanisah's picture
# Update app.py
# a2e7718 verified
# App.py
import os
import io
import math
import zipfile
import tempfile
import streamlit.components.v1 as components
from textwrap import dedent
from pathlib import Path
from typing import Optional, List, Dict
import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from matplotlib import animation
from PIL import Image
from skimage.measure import label as cc_label, regionprops
import streamlit as st
from textwrap import dedent
import base64 # kept for reference; not required when using st.download_button
# ---- Custom layer used by model
from utils.layer_util import ResizeAndConcatenate
# =========================
# Config / paths
# =========================
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Uncomment next line to force CPU only
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
MODEL_PATH = "./models_final/SEG459.h5"
OUTPUT_ROOT = "/tmp/NIFTI_OUTPUTS"
GIFS_DIR = os.path.join(OUTPUT_ROOT, "GIFs")
CSV_DIR = os.path.join(OUTPUT_ROOT, "CSV")
os.makedirs(GIFS_DIR, exist_ok=True)
os.makedirs(CSV_DIR, exist_ok=True)
# Inference & metrics settings
SIZE_X, SIZE_Y = 256, 256
TARGET_HW = (SIZE_Y, SIZE_X)
N_CLASSES = 3
BATCH_SIZE = 16
MYO_DENSITY = 1.05 # g/mL β†’ mg via mm^3 * g/mL
# Island removal (3D)
ENABLE_ISLAND_REMOVAL = True
ISLAND_MIN_SLICE_SPAN = 2
ISLAND_MIN_AREA_PER_SLICE = 10
ISLAND_CENTROID_DIST_THRESH = 40
# Orientation / display
ORIENT_TARGET = None # None (keep native), "LPS", or "RAS"
DISPLAY_MATCH_DICOM = False
DISPLAY_RULES = {
'LPS': dict(rot90_cw=True, flip_ud=True, flip_lr=False),
'RAS': dict(rot90_cw=True, flip_ud=False, flip_lr=True),
None: dict(rot90_cw=False, flip_ud=False, flip_lr=False),
}
CURRENT_DISPLAY_ORIENT = ORIENT_TARGET
# ED/ES robust selection (mid-slices subset)
USE_MID_SLICES_FOR_ED_ES = True
MID_K = 4
MID_MIN_VALID_FRAC = 0.7
MID_A_BLOOD_MIN = 30
MID_A_MYO_MIN = 30
GIF_FPS = 2
GIF_DPI = 300
# =========================
# Branding / UI
# =========================
LOGO_URL = "https://raw.githubusercontent.com/whanisa/Segmentation/main/icon/logo.png"
LOGO_LINK = "https://github.com/whanisa/Segmentation"
LOGO_HEIGHT_PX = 120
SAFE_INSET_PX = 18
# --- CSS
def _inject_layout_css():
    """Inject the app-wide CSS (layout widths, fixed-edge logo, tab styling).

    All geometry is exposed as CSS custom properties; the three Python
    constants below are the intended tuning knobs.
    """
    # Tune these three knobs as needed
    CONTENT_MEASURE_PX = 920  # width of the hero + paragraphs column
    LEFT_OFFSET_PX = 40  # push BOTH text and uploader to the right (aligns with left gutter)
    UPLOAD_WIDTH_PX = 420  # make uploader column narrower
    st.markdown(f"""
    <style>
    :root {{
        --content-measure: {CONTENT_MEASURE_PX}px;
        --left-offset: {LEFT_OFFSET_PX}px;
        --upload-width: {UPLOAD_WIDTH_PX}px;
        /* Fixed-edge logo tunables */
        --logo-height: {LOGO_HEIGHT_PX}px;
        --edge-x: max(12px, env(safe-area-inset-left));
        /* Height of Streamlit/HF header; adjust to 40-56–64px if needed */
        --header-clear: 40px;
        --edge-y: calc(env(safe-area-inset-top) + var(--header-clear) + 0px);
        /* Gap below the logo before the tabs appear */
        --tabs-top-gap: calc(var(--logo-height) + 16px);
        /* Shift the tabs a bit to the right */
        --tabs-left-shift: 32px;
        /* Accent color for the active tab underline/text */
        --accent: #ef4444; /* red-500 */
    }}
    /* Ensure the fixed logo isn't clipped by Streamlit containers */
    .stApp, .appview-container, .main {{ overflow: visible !important; }}
    /* Compact page padding */
    .appview-container .main .block-container {{
        padding-top: 0.75rem;
        padding-bottom: 1rem;
    }}
    /* Outer wrapper */
    .content-wrap {{
        width: min(1300px, 100%);
        margin: 0 auto;
        padding: 0 18px;
        box-sizing: border-box;
    }}
    /* Text column (hero + paragraphs): same width + same left offset as uploader */
    .measure-wrap {{
        max-width: var(--content-measure);
        margin-left: var(--left-offset);
        margin-right: auto;
    }}
    /* ---- Fixed, far-left edge logo ---- */
    #fixed-edge-logo {{
        position: fixed;
        left: var(--edge-x);
        top: var(--edge-y);
        z-index: 1000;
        pointer-events: none;
    }}
    #fixed-edge-logo img {{
        height: var(--logo-height);
        width: auto;
        display: block;
    }}
    /* Spacer so tabs/hero don't hide under the fixed logo */
    .edge-logo-spacer {{ height: var(--tabs-top-gap); }}
    /* Title: justify both sides */
    .hero-title {{
        font-size: 40px;
        line-height: 1.25;
        font-weight: 800;
        margin: 0 0 20px; /* adjust to control gap to first paragraph */
        text-align: justify; /* both edges */
        text-justify: inter-word;
    }}
    /* Optional subtitle inside the H1 */
    .hero-title .sub {{
        display: block;
        font-size: 28px;
        line-height: 1.25;
        margin: 0;
    }}
    /* Body paragraphs: justified, but last line remains ragged-right */
    .text-wrap p {{
        margin: 0 0 14px 0; /* paragraph spacing */
        font-size: 17px;
        line-height: 1.5;
        text-align: justify;
        text-justify: inter-word;
        color: #333;
        hyphens: auto;
        -webkit-hyphens: auto;
        -ms-hyphens: auto;
    }}
    /* Let long URLs wrap so they don’t wreck the right edge */
    .text-wrap p a {{
        overflow-wrap: anywhere;
        word-break: break-word;
    }}
    /* Link aesthetics (optional) */
    .text-wrap a {{
        color: #0066cc;
        text-decoration: underline;
        text-underline-offset: 2px;
        text-decoration-thickness: 1.5px;
    }}
    /* Note under last paragraph */
    .note-text {{
        font-size: 14px; /* smaller than normal text */
        color: #333; /* optional: softer gray */
        line-height: 1.4; /* a bit tighter spacing */
        margin-top: 4px; /* space above note */
    }}
    /* Uploader block alignment */
    #upload-wrap {{
        max-width: var(--upload-width);
        margin-left: var(--left-offset);
        margin-right: auto;
    }}
    #upload-wrap [data-testid="stFileUploader"] {{
        width: 100% !important;
        margin-left: 0 !important;
        margin-right: 0 !important;
    }}
    #upload-wrap [data-testid="stFileUploaderDropzone"] {{
        padding-left: 0 !important;
        padding-right: 44px !important;
    }}
    /* ---- REAL Streamlit tabs styling and alignment ---- */
    /* Shift the tab strip to align with content */
    div[data-testid="stTabs"] > div[role="tablist"],
    div[data-baseweb="tab-list"],
    .stTabs [role="tablist"] {{
        margin-left: calc(var(--left-offset) + var(--tabs-left-shift)) !important;
        margin-right: 18px !important;
        border-bottom: 0; /* remove gray baseline */
        padding-bottom: 6px;
    }}
    /* Tab buttons */
    div[data-baseweb="tab-list"] button[role="tab"],
    .stTabs [role="tab"] {{
        color: #374151; /* gray-700 */
        background: transparent;
        border: none;
        outline: none;
        padding: 6px 14px 10px 14px;
        margin: 0 4px;
        font-weight: 600;
    }}
    /* Active tab: keep ONLY our single orange underline */
    div[data-baseweb="tab-list"] button[aria-selected="true"],
    .stTabs [role="tab"][aria-selected="true"] {{
        color: var(--accent) !important;
        border-bottom: 3px solid var(--accent) !important;
    }}
    /* Hide BaseWeb's moving highlight to avoid double orange lines */
    div[data-baseweb="tab-highlight"] {{ display: none !important; }}
    /* Small screens */
    @media (max-width: 480px) {{
        :root {{
            --logo-height: {max(48, int(LOGO_HEIGHT_PX*0.7))}px;
            --header-clear: 64px;
            --tabs-left-shift: 16px;
        }}
    }}
    </style>
    """, unsafe_allow_html=True)
# =========================
# Small utilities
# =========================
def log(msg: str):
    """Print an INFO-prefixed message to stdout."""
    print("[INFO] " + msg)
def normalize_images(x):
    """Min-max normalize each image in a batch to [0, 1].

    Accepts an array shaped (N, H, W[, C]) and scales every sample by its own
    min/max over the spatial axes (1, 2). Constant (zero-range) images map to
    all zeros rather than dividing by zero.

    Pure NumPy implementation: the original round-tripped through TensorFlow
    tensors for what is a simple elementwise computation; this avoids the
    host/device copies while producing identical values.

    Returns:
        np.ndarray of float32 with the same shape as the input.
    """
    x = np.asarray(x, dtype=np.float32)
    mn = x.min(axis=(1, 2), keepdims=True)
    mx = x.max(axis=(1, 2), keepdims=True)
    rng = mx - mn
    # Substitute 1 where the range is zero to avoid divide-by-zero warnings;
    # those positions are forced to 0 by the outer where anyway.
    safe_rng = np.where(rng > 0.0, rng, 1.0)
    return np.where(rng > 0.0, (x - mn) / safe_rng, 0.0).astype(np.float32)
def _tf_resize_bilinear(img, *, target_h=SIZE_Y, target_w=SIZE_X):
    """Resize a 2-D image to (target_h, target_w) using antialiased bilinear
    interpolation; returns float32."""
    batched = img.astype(np.float32)[np.newaxis, :, :, np.newaxis]
    resized = tf.image.resize(batched, [target_h, target_w],
                              method='bilinear', antialias=True)
    return np.squeeze(resized.numpy()).astype(np.float32)
def _resize_nn(img, new_h, new_w):
    """Nearest-neighbour resize of a 2-D array, preserving the input dtype
    (suitable for label masks)."""
    src_dtype = img.dtype
    batched = img.astype(np.float32)[None, :, :, None]
    resized = tf.image.resize(batched, [new_h, new_w], method='nearest')
    return np.squeeze(resized.numpy()).astype(src_dtype)
def display_xform(img2d, orient_target=None, enable=DISPLAY_MATCH_DICOM):
    """Apply display-only flips/rotations per DISPLAY_RULES.

    Non-destructive: only used for on-screen rendering, never for metrics.
    When `orient_target` is None the module-level CURRENT_DISPLAY_ORIENT is
    used; when `enable` is falsy the image is returned untouched.
    """
    if orient_target is None:
        orient_target = CURRENT_DISPLAY_ORIENT
    if not enable:
        return img2d
    rule = DISPLAY_RULES.get(orient_target, DISPLAY_RULES[None])
    steps = (
        ('rot90_cw', lambda a: np.rot90(a, k=-1)),  # clockwise quarter turn
        ('flip_ud', np.flipud),
        ('flip_lr', np.fliplr),
    )
    result = img2d
    for flag, op in steps:
        if rule.get(flag):
            result = op(result)
    return result
# =========================
# NIfTI I/O + spacing
# =========================
def _reorient_nifti(img: nib.Nifti1Image, target: Optional[str]):
    """Reorient a NIfTI image to the requested axis codes.

    Returns (image, orientation_transform). The transform is None when no
    target was requested; when the image already matches the target the
    original object is returned with the (identity) transform.

    Raises:
        ValueError: if `target` is neither None, 'LPS', nor 'RAS'.
    """
    if not target:
        return img, None
    code = target.upper()
    if code not in ("LPS", "RAS"):
        raise ValueError("ORIENT_TARGET must be None, 'LPS', or 'RAS'")
    current = nib.orientations.io_orientation(img.affine)
    desired = nib.orientations.axcodes2ornt(tuple(code))
    transform = nib.orientations.ornt_transform(current, desired)
    # Identity orientation transform: nothing to do.
    if np.allclose(transform, np.array([[0, 1], [1, 1], [2, 1]])):
        return img, transform
    reoriented = nib.orientations.apply_orientation(img.get_fdata(), transform)
    new_affine = img.affine @ nib.orientations.inv_ornt_aff(transform, img.shape)
    return nib.Nifti1Image(reoriented, new_affine, header=img.header), transform
def load_nifti_4d(path, orient_target: Optional[str] = ORIENT_TARGET):
    """Load a NIfTI file as a (H, W, S, F) float32 array plus spacing/affine.

    Single-frame (3-D) volumes are promoted to 4-D with one frame. Spacing is
    a dict with row/col pixel sizes (mm), slice thickness (mm), and frame
    time (ms, or None if absent).
    """
    img, _ = _reorient_nifti(nib.load(path), orient_target)
    vol = img.get_fdata(dtype=np.float32)  # native order: (X, Y, Z[, T])
    if vol.ndim == 3:
        vol = vol[..., np.newaxis]  # promote to 4-D
    data_4d = np.transpose(vol, (1, 0, 2, 3)).astype(np.float32)  # -> (H, W, S, F)
    zooms = img.header.get_zooms()

    def _zoom(i, fallback):
        # Header zoom at index i, or the fallback when absent.
        return float(zooms[i]) if len(zooms) > i else fallback

    spacing = {
        'row_mm': _zoom(1, 1.0),
        'col_mm': _zoom(0, 1.0),
        'slice_thickness_mm': _zoom(2, 1.0),
        'frame_time_ms': _zoom(3, None),
    }
    return data_4d, spacing, img.affine
# =========================
# Inference
# =========================
def nifti_to_model_batches(data_4d):
    """Flatten a (H, W, S, F) volume into a model-ready (N, 256, 256, 1) batch.

    Returns the batch array, a frame-major list of (slice, frame) indices
    matching the batch order, and the original 4-D shape tuple.
    """
    H, W, S, F = data_4d.shape
    slices, index = [], []
    for frame in range(F):
        for slc in range(S):
            resized = _tf_resize_bilinear(data_4d[..., slc, frame],
                                          target_h=SIZE_Y, target_w=SIZE_X)
            slices.append(resized[..., None])  # add channel axis
            index.append((slc, frame))
    return np.stack(slices, axis=0).astype(np.float32), index, (H, W, S, F)
def _ensure_logits_last(preds):
if isinstance(preds, (list, tuple)):
preds = preds[-1]
return preds
def predict_nifti_4d(model, data_4d, batch_size=None):
    """Segment every (slice, frame) of a 4-D cine and return a
    (256, 256, S, F) uint8 label volume (argmax over class logits)."""
    batch, index, (_, _, S, F) = nifti_to_model_batches(data_4d)
    batch = normalize_images(batch)
    logits = _ensure_logits_last(model.predict(batch, verbose=0, batch_size=batch_size))
    labels = np.argmax(logits, axis=-1).astype(np.uint8)  # (N, 256, 256)
    preds_4d = np.zeros((SIZE_Y, SIZE_X, S, F), dtype=np.uint8)
    for pos, (slc, frame) in enumerate(index):
        preds_4d[..., slc, frame] = labels[pos]
    return preds_4d
def resize_masks_to_native(preds_4d_256, native_h, native_w):
    """Nearest-neighbour upsample each (slice, frame) mask from model
    resolution back to the native in-plane grid."""
    _, _, S, F = preds_4d_256.shape
    native = np.zeros((native_h, native_w, S, F), dtype=preds_4d_256.dtype)
    for frame in range(F):
        for slc in range(S):
            native[..., slc, frame] = _resize_nn(preds_4d_256[..., slc, frame],
                                                 native_h, native_w)
    return native
# =========================
# Cleaning + ED/ES + metrics
# =========================
def clean_predictions_per_frame_3d(mask_4d):
    """Remove spurious 3-D connected components ("islands") per frame.

    For each frame and each foreground class (1 = myocardium, 2 = blood),
    keep the largest connected component, plus any secondary component that
    spans enough slices, has a sufficient median per-slice area, and lies
    close enough to the dominant component's centroid (thresholds are the
    module-level ISLAND_* constants). All other voxels of that class are
    zeroed. Returns a cleaned copy; the input is not modified.
    """
    H, W, S, F = mask_4d.shape
    out = mask_4d.copy()
    for f in range(F):
        vol_f = out[:, :, :, f]  # view into `out`; in-place edits persist
        for cls in (1, 2):
            m = (vol_f == cls)
            if not m.any():
                continue
            cc = cc_label(m, connectivity=1)  # 6-connected 3-D labelling
            props = regionprops(cc)
            if not props:
                continue
            # Dominant component = largest voxel count.
            dom = max(props, key=lambda r: r.area)
            dom_centroid = np.array(dom.centroid)
            keep = {dom.label}
            for r in props:
                if r.label == dom.label:
                    continue
                zmin, zmax = r.bbox[2], r.bbox[5]  # half-open slice extent
                slice_span = zmax - zmin
                # Per-slice area of this component over the slices it covers.
                areas = [np.count_nonzero(cc[:, :, z] == r.label) for z in range(zmin, zmax)]
                median_area = np.median(areas) if areas else 0
                dist = np.linalg.norm(np.array(r.centroid) - dom_centroid)
                if (slice_span >= ISLAND_MIN_SLICE_SPAN) and (median_area >= ISLAND_MIN_AREA_PER_SLICE) and (dist <= ISLAND_CENTROID_DIST_THRESH):
                    keep.add(r.label)
            # Zero every voxel of this class outside the kept components.
            drop = (cc > 0) & (~np.isin(cc, list(keep)))
            vol_f[drop] = 0
        out[:, :, :, f] = vol_f
    return out
def compute_per_frame_metrics(preds_4d, spacing, labels=None):
    """Per-frame LV blood-pool volume (µL) and myocardial mass (mg).

    Args:
        preds_4d: (H, W, S, F) integer label volume.
        spacing: dict with 'row_mm', 'col_mm', 'slice_thickness_mm'.
        labels: optional mapping of class names to label values; defaults to
            {"myo": 1, "blood": 2}. (Was a mutable default argument, which is
            a Python anti-pattern even when only read.)

    Returns:
        pd.DataFrame with columns Frame (0-based), Volume_uL, MyocardiumMass_mg.
    """
    if labels is None:
        labels = {"myo": 1, "blood": 2}
    voxel_mm3 = (float(spacing["row_mm"]) * float(spacing["col_mm"])
                 * float(spacing["slice_thickness_mm"]))  # 1 mm^3 == 1 µL
    F = preds_4d.shape[3]
    blood_counts = (preds_4d == labels["blood"]).sum(axis=(0, 1, 2))
    myo_counts = (preds_4d == labels["myo"]).sum(axis=(0, 1, 2))
    return pd.DataFrame({
        "Frame": np.arange(F, dtype=int),
        "Volume_uL": blood_counts * voxel_mm3,
        "MyocardiumMass_mg": myo_counts * voxel_mm3 * MYO_DENSITY,  # density in g/mL
    })
def slice_validity_matrix(preds_4d, A_blood_min=30, A_myo_min=30):
    """Per-(slice, frame) pixel areas of blood/myocardium and validity flags.

    Returns (has_blood, has_myo, areas_blood, areas_myo), each shaped (S, F);
    a slice is "valid" for a class when its pixel area meets the threshold.
    """
    H, W, S, F = preds_4d.shape
    flat = preds_4d.reshape(H * W, S, F)
    areas_blood = (flat == 2).sum(axis=0)  # blood-pool pixels per (S, F)
    areas_myo = (flat == 1).sum(axis=0)    # myocardium pixels per (S, F)
    return areas_blood >= A_blood_min, areas_myo >= A_myo_min, areas_blood, areas_myo
def choose_mid_slices(has_blood, has_myo, K=4, min_frac=0.7):
    """Pick K contiguous slices maximizing validity, with a slight bias toward
    the stack centre.

    Scores every length-K window by mean valid fraction minus a small penalty
    proportional to its distance from the middle slice. When the best
    window's mean valid fraction falls below `min_frac`, falls back to the K
    individually most-valid slices (sorted). Returns a list of slice indices.
    """
    S, F = has_blood.shape
    valid_frac = (has_blood & has_myo).sum(axis=1) / max(F, 1)
    centre = int(S // 2)
    best_score, best_block = -np.inf, None
    for start in range(max(S - K + 1, 1)):
        block = list(range(start, min(start + K, S)))
        centre_penalty = 0.01 * np.mean([abs(s - centre) for s in block])
        score = valid_frac[block].mean() - centre_penalty
        if score > best_score:  # strict '>' keeps the earliest best window
            best_score, best_block = score, block
    if np.mean(valid_frac[best_block]) < min_frac:
        ranked = np.argsort(-valid_frac)
        best_block = sorted(ranked[:K].tolist())
    return best_block
def frame_volumes_subset_uL(preds_4d, spacing, slice_indices):
    """Blood-pool volume (µL) per frame, counted over a subset of slices only.

    Used for robust ED/ES selection from mid-stack slices. Returns a float32
    vector of length F.
    """
    voxel_mm3 = (float(spacing["row_mm"]) * float(spacing["col_mm"])
                 * float(spacing["slice_thickness_mm"]))
    n_frames = preds_4d.shape[3]
    volumes = np.zeros(n_frames, dtype=np.float32)
    for frame in range(n_frames):
        subset = preds_4d[:, :, slice_indices, frame]
        volumes[frame] = np.count_nonzero(subset == 2) * voxel_mm3
    return volumes
def pick_ed_es_from_volumes(vols_uL, prefer_frame0=True, rel_tol=0.05, min_sep=1):
    """Choose ED (max volume) and ES (min volume >= `min_sep` frames from ED).

    When frame 0 is within `rel_tol` (relative) of the maximum it is preferred
    as ED, matching the usual cine acquisition convention. Returns
    (ed_index, es_index) as plain ints.
    """
    ed = int(np.argmax(vols_uL))
    if prefer_frame0:
        if abs(vols_uL[0] - vols_uL[ed]) <= rel_tol * max(vols_uL[ed], 1e-6):
            ed = 0
    ascending = np.argsort(vols_uL)
    es = int(ascending[0])
    for candidate in ascending:
        if abs(int(candidate) - ed) >= min_sep:
            es = int(candidate)
            break
    return ed, es
# =========================
# GIF (ED vs ES)
# =========================
def gif_animation_for_patient_pred_only(images_4d, preds_4d, patient_id, ed_idx, es_idx, output_dir):
    """Write an animated GIF stepping through slices at ED and ES side by side.

    Each GIF frame shows the grayscale slice with the predicted myocardium
    (class 1, blue colormap) and blood pool (class 2, jet colormap) overlaid,
    at the ED frame (left panel) and ES frame (right panel). Returns the path
    of the written GIF.
    """
    os.makedirs(output_dir, exist_ok=True)

    def overlay(ax, img, pred, alpha_myo=0.45, alpha_blood=0.45):
        # Apply the display-only orientation transform to image and masks alike
        # so the overlays stay registered.
        base = display_xform(img)
        myo_mask = display_xform((pred == 1).astype(np.uint8)).astype(bool)
        blood_mask = display_xform((pred == 2).astype(np.uint8)).astype(bool)
        ax.imshow(base, cmap='gray', interpolation="none")
        if myo_mask.any():
            ax.imshow(np.ma.masked_where(~myo_mask, myo_mask),
                      cmap="Blues", alpha=alpha_myo, vmin=0, vmax=1, interpolation="none")
        if blood_mask.any():
            ax.imshow(np.ma.masked_where(~blood_mask, blood_mask),
                      cmap="jet", alpha=alpha_blood, vmin=0, vmax=1, interpolation="none")
        ax.axis('off')

    H, W, S, F = images_4d.shape
    fig, axarr = plt.subplots(1, 2, figsize=(8, 4))
    plt.tight_layout(rect=[0, 0, 1, 0.92])  # leave headroom for the suptitle

    def update(slice_idx):
        # FuncAnimation callback: redraw both panels for this slice.
        axarr[0].clear()
        axarr[1].clear()
        overlay(axarr[0], images_4d[:, :, slice_idx, ed_idx], preds_4d[:, :, slice_idx, ed_idx])
        axarr[0].set_title(f'ED (frame {ed_idx+1}) | Slice {slice_idx+1}')
        overlay(axarr[1], images_4d[:, :, slice_idx, es_idx], preds_4d[:, :, slice_idx, es_idx])
        axarr[1].set_title(f'ES (frame {es_idx+1}) | Slice {slice_idx+1}')
        fig.suptitle(f'Mouse ID: {patient_id}', fontsize=14, y=0.98)

    anim = animation.FuncAnimation(fig, update, frames=S, interval=700)
    out_path = os.path.join(output_dir, f"{patient_id}_pred.gif")
    anim.save(out_path, writer='pillow', fps=GIF_FPS)
    plt.close(fig)  # release the figure so repeated calls don't accumulate
    return out_path
# =========================
# CSV writer
# =========================
def write_all_in_one_csv(rows, per_frame_rows, csv_dir):
    """Merge per-frame metrics with per-case summaries into one Results.csv.

    Args:
        rows: list of per-case summary dicts (may be empty).
        per_frame_rows: list of per-frame DataFrames (may be empty).
        csv_dir: output directory; Results.csv is written there.

    Frame indices are converted to 1-based and numeric columns are formatted
    to two decimals before writing. Returns the path of the written CSV.
    """
    df_summary = pd.DataFrame(rows)
    # Ensure summary columns exist even if rows is empty (prevents KeyError)
    summary_cols = [
        'Patient_ID',
        'EDV_uL', 'ESV_uL', 'SV_uL', 'EF_%',
        'MyocardiumMass_ED_mg', 'MyocardiumMass_ES_mg',
        'ED_frame_index', 'ES_frame_index',
        'PixelSpacing_row_mm', 'PixelSpacing_col_mm', 'SliceThickness_mm'
    ]
    df_summary = df_summary.reindex(columns=summary_cols)
    if per_frame_rows:
        df_perframe = pd.concat(per_frame_rows, ignore_index=True)
    else:
        df_perframe = pd.DataFrame(columns=['Patient_ID', 'Frame', 'Volume_uL', 'MyocardiumMass_mg'])
    # Convert indices to 1-based for output
    if not df_perframe.empty:
        df_perframe["Frame"] = df_perframe["Frame"].astype(int) + 1
    if not df_summary.empty:
        for c in ("ED_frame_index", "ES_frame_index"):
            if c in df_summary.columns:
                # Nullable Int64 keeps missing indices as <NA> instead of crashing.
                df_summary[c] = df_summary[c].astype('Int64') + 1
        for c in ("EF_%", "EDV_uL", "ESV_uL", "SV_uL",
                  "MyocardiumMass_ED_mg", "MyocardiumMass_ES_mg",
                  "PixelSpacing_row_mm", "PixelSpacing_col_mm", "SliceThickness_mm"):
            if c in df_summary.columns:
                df_summary[c] = df_summary[c].astype(float).map(lambda x: f"{x:.2f}")
    for c in ("Volume_uL", "MyocardiumMass_mg"):
        if c in df_perframe.columns and not df_perframe.empty:
            df_perframe[c] = df_perframe[c].astype(float).map(lambda x: f"{x:.2f}")
    # Merge per-frame + summary (safe even if df_summary is empty)
    all_in_one = df_perframe.merge(
        df_summary[summary_cols],
        on='Patient_ID',
        how='left'
    )[
        [
            'Patient_ID', 'ED_frame_index', 'ES_frame_index',
            'EDV_uL', 'ESV_uL', 'SV_uL', 'EF_%',
            'MyocardiumMass_ED_mg', 'MyocardiumMass_ES_mg',
            'Frame', 'Volume_uL', 'MyocardiumMass_mg',
            'PixelSpacing_row_mm', 'PixelSpacing_col_mm', 'SliceThickness_mm'
        ]
    ]
    os.makedirs(csv_dir, exist_ok=True)
    out_csv = os.path.join(csv_dir, 'Results.csv')
    all_in_one.to_csv(out_csv, index=False)
    log(f"CSV written: {out_csv}")
    return out_csv
# =========================
# Same-tab download (no /media, no new tab)
# Styled like Streamlit buttons
# =========================
def _same_tab_download_button(label: str, data_bytes: bytes, file_name: str, mime: str = "text/csv", *, key: Optional[str] = None):
    """
    Streamlit-like download button that:
      - hovers as white bg + red text/border
      - turns solid red with white text while pressed
      - downloads in the SAME TAB (Blob + programmatic click)

    The payload is inlined as base64 on the anchor element, so no Streamlit
    /media route is involved and the link survives reruns.
    """
    import html, hashlib, base64
    import streamlit as st
    import streamlit.components.v1 as components
    b64 = base64.b64encode(data_bytes).decode("ascii")
    # Stable DOM id derived from the key/file name (hash suffix avoids clashes).
    btn_id = f"dl_{(key or file_name)}_{hashlib.sha256((key or file_name).encode()).hexdigest()[:8]}"
    # CSS: normal -> hover -> pressed (solid red)
    st.markdown(f"""
    <style>
    a#{btn_id} {{
        appearance: none;
        display: inline-flex; align-items: center; justify-content: center;
        padding: 0.5rem 0.75rem;
        border-radius: 0.5rem;
        border: 1px solid rgba(49,51,63,.2);
        background: var(--background-color);
        color: var(--text-color);
        font-weight: 600; text-decoration: none !important;
        box-shadow: 0 1px 2px rgba(0,0,0,0.04);
        transition: color .15s ease, border-color .15s ease,
                    box-shadow .15s ease, transform .05s ease, background-color .15s;
        cursor: pointer;
        user-select: none;
        -webkit-tap-highlight-color: transparent;
    }}
    /* Hover: white bg, red border/text */
    a#{btn_id}:hover, a#{btn_id}:focus {{
        background: var(--background-color);
        color: var(--accent) !important;
        border-color: var(--accent) !important;
        box-shadow: 0 2px 6px rgba(239,68,68,0.20);
        transform: translateY(-1px);
    }}
    /* Active/pressed: solid red with white text */
    a#{btn_id}:active,
    a#{btn_id}.pressed {{
        background: var(--accent) !important;
        border-color: var(--accent) !important;
        color: #fff !important;
        box-shadow: 0 3px 10px rgba(239,68,68,0.35);
        transform: translateY(0);
    }}
    a#{btn_id}:focus-visible {{
        outline: none;
        box-shadow: 0 0 0 0.2rem rgba(239,68,68,0.35);
    }}
    </style>
    """, unsafe_allow_html=True)
    # Render the button (no navigation in href)
    st.markdown(
        f'<a id="{btn_id}" href="#" '
        f' data-b64="{b64}" data-mime="{html.escape(mime)}" data-fname="{html.escape(file_name)}">{html.escape(label)}</a>',
        unsafe_allow_html=True
    )
    # JS: add a temporary "pressed" class on mousedown/touch, then same-tab download via Blob
    components.html(f"""
    <script>
    (function () {{
    try {{
        const doc = window.parent.document;
        const a = doc.getElementById("{btn_id}");
        if (!a) return;
        const pressOn = () => a.classList.add("pressed");
        const pressOff = () => a.classList.remove("pressed");
        a.addEventListener("mousedown", pressOn, true);
        a.addEventListener("mouseup", pressOff, true);
        a.addEventListener("mouseleave",pressOff, true);
        a.addEventListener("touchstart",pressOn, {{passive:true}});
        a.addEventListener("touchend", pressOff, true);
        a.addEventListener("touchcancel",pressOff, true);
        a.addEventListener("click", function(ev) {{
            ev.preventDefault();
            ev.stopImmediatePropagation();
            const b64 = a.getAttribute("data-b64");
            const mime = a.getAttribute("data-mime") || "application/octet-stream";
            const fname = a.getAttribute("data-fname") || "download";
            // base64 β†’ Blob
            const bstr = atob(b64);
            const len = bstr.length;
            const u8 = new Uint8Array(len);
            for (let i = 0; i < len; i++) u8[i] = bstr.charCodeAt(i);
            const blob = new Blob([u8], {{ type: mime }});
            const url = URL.createObjectURL(blob);
            // programmatic same-tab download
            const tmp = doc.createElement("a");
            tmp.href = url;
            tmp.download = fname;
            tmp.style.display = "none";
            doc.body.appendChild(tmp);
            tmp.click();
            // keep the red state briefly so it's visible, then clean up
            setTimeout(() => {{
                URL.revokeObjectURL(url);
                tmp.remove();
                pressOff();
            }}, 150);
        }}, true);
    }} catch (err) {{
        console.debug("download handler error:", err);
    }}
    }})();
    </script>
    """, height=0)
# =========================
# Per-file processing
# =========================
def process_nifti_case(nifti_path: str, model, rows_acc: List[Dict], per_frame_rows_acc: List[pd.DataFrame]):
    """Segment one NIfTI case end-to-end and accumulate its metrics.

    Loads the 4-D cine, predicts 256x256 masks, resizes them back to the
    native grid, optionally removes islands, derives ED/ES indices and
    volumes, appends a summary dict to `rows_acc` and a per-frame DataFrame
    to `per_frame_rows_acc`, then renders an ED-vs-ES GIF and displays it.

    Side effects: mutates the module-level CURRENT_DISPLAY_ORIENT, writes a
    GIF to GIFS_DIR, and draws into the current Streamlit container.
    """
    pid = Path(nifti_path).stem
    log(f"Using NIfTI input: {nifti_path}")
    imgs_4d, spacing, final_aff = load_nifti_4d(nifti_path, orient_target=ORIENT_TARGET)
    # Decide display baseline for non-destructive on-screen transforms
    global CURRENT_DISPLAY_ORIENT
    if ORIENT_TARGET is None:
        try:
            axc = nib.aff2axcodes(final_aff)  # affine -> axcodes
            if tuple(axc[:3]) == ('L','P','S'):
                CURRENT_DISPLAY_ORIENT = 'LPS'
            elif tuple(axc[:3]) == ('R','A','S'):
                CURRENT_DISPLAY_ORIENT = 'RAS'
            else:
                CURRENT_DISPLAY_ORIENT = 'LPS'  # safe default for other codes
        except Exception:
            CURRENT_DISPLAY_ORIENT = 'LPS'  # safe default when axcodes fail
    else:
        CURRENT_DISPLAY_ORIENT = ORIENT_TARGET
    native_h, native_w, S, F = imgs_4d.shape
    # Predict (256x256) then resize back to native for metrics
    preds_4d_256 = predict_nifti_4d(model, imgs_4d, batch_size=BATCH_SIZE)
    preds_4d = resize_masks_to_native(preds_4d_256, native_h, native_w)
    # Clean islands
    if ENABLE_ISLAND_REMOVAL:
        preds_4d = clean_predictions_per_frame_3d(preds_4d)
    # Robust ED/ES via mid-slice subset; then full-stack EDV/ESV at those indices
    voxel_mm3 = spacing["row_mm"] * spacing["col_mm"] * spacing["slice_thickness_mm"]
    if USE_MID_SLICES_FOR_ED_ES:
        has_blood, has_myo, _, _ = slice_validity_matrix(preds_4d, A_blood_min=MID_A_BLOOD_MIN, A_myo_min=MID_A_MYO_MIN)
        mid_slices = choose_mid_slices(has_blood, has_myo, K=min(MID_K, preds_4d.shape[2]), min_frac=MID_MIN_VALID_FRAC)
        vols_subset = frame_volumes_subset_uL(preds_4d, spacing, mid_slices)
        ed_idx, es_idx = pick_ed_es_from_volumes(vols_subset, prefer_frame0=True, rel_tol=0.05, min_sep=1)
        # Volumes for reporting always use the full slice stack.
        vols_full = np.array([(preds_4d[..., f] == 2).sum() * voxel_mm3 for f in range(F)], dtype=np.float32)
        EDV_uL = float(vols_full[ed_idx]); ESV_uL = float(vols_full[es_idx])
    else:
        vols_full = np.array([(preds_4d[..., f] == 2).sum() * voxel_mm3 for f in range(F)], dtype=np.float32)
        ed_idx, es_idx = pick_ed_es_from_volumes(vols_full, prefer_frame0=True, rel_tol=0.05, min_sep=1)
        EDV_uL = float(vols_full[ed_idx]); ESV_uL = float(vols_full[es_idx])
    SV_uL = EDV_uL - ESV_uL
    EF_pct = (SV_uL / EDV_uL * 100.0) if EDV_uL > 0 else 0.0
    # Per-frame myocardium mass (mg)
    per_frame_df = compute_per_frame_metrics(preds_4d, spacing)
    myo_mass_ED_mg = float(per_frame_df.loc[per_frame_df["Frame"] == ed_idx, "MyocardiumMass_mg"].values[0])
    myo_mass_ES_mg = float(per_frame_df.loc[per_frame_df["Frame"] == es_idx, "MyocardiumMass_mg"].values[0])
    per_frame_df.insert(0, "Patient_ID", pid)
    per_frame_rows_acc.append(per_frame_df)
    # -------- Append summary row BEFORE any UI drawing
    rows_acc.append({
        'Patient_ID': pid,
        'EDV_uL': EDV_uL,
        'ESV_uL': ESV_uL,
        'SV_uL' : SV_uL,
        'EF_%' : EF_pct,
        'MyocardiumMass_ED_mg': myo_mass_ED_mg,
        'MyocardiumMass_ES_mg': myo_mass_ES_mg,
        'ED_frame_index': int(ed_idx),
        'ES_frame_index': int(es_idx),
        'SliceThickness_mm': spacing['slice_thickness_mm'],
        'PixelSpacing_row_mm': spacing['row_mm'],
        'PixelSpacing_col_mm': spacing['col_mm'],
    })
    # GIF: ED vs ES (slices animate)
    gif_path = gif_animation_for_patient_pred_only(imgs_4d, preds_4d, pid, ed_idx, es_idx, GIFS_DIR)
    try:
        st.image(gif_path, caption="Generated GIF", use_column_width=True)
    except TypeError:
        # Newer Streamlit removed `use_column_width`; fall back without it.
        st.image(gif_path, caption="Generated GIF")
    log(f"GIF saved: {gif_path}")
# =========================
# UI
# =========================
def main():
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
_inject_layout_css()
# ---- Fixed, far-left logo; spacer prevents overlap with tabs ----
st.markdown(
f'''
<div id="fixed-edge-logo" aria-hidden="true" role="presentation">
<img src="{LOGO_URL}" alt="Pre-Clinical Cardiac MRI Segmentation">
</div>
<div class="edge-logo-spacer"></div>
''',
unsafe_allow_html=True,
)
# ---------- REAL tabs ----------
tab1, tab2, tab3 = st.tabs(["Segmentation App", "Dataset", "NIfTI converter"])
# ===== Tab 1: Segmentation App =====
with tab1:
# ---------- REMOVE PAPERCLIP ----------
st.markdown(
"""
<style>
/* Hide Streamlit's hover anchor/paperclip on all headings */
[data-testid="stHeading"] a,
h1 a[href^="#"],
h2 a[href^="#"],
h3 a[href^="#"] {
display: none !important;
visibility: hidden !important;
}
</style>
""",
unsafe_allow_html=True
)
# ---------- HERO ----------
HERO_HTML = dedent("""\
<div class="content-wrap">
<div class="measure-wrap">
<div class="text-wrap">
<h1 class="hero-title">
Open-Source Pre-Clinical Image Segmentation:<br>
Mouse cardiac MRI datasets with a deep learning segmentation framework
</h1>
</div>
<div class="text-wrap">
<p>We present the first publicly-available pre-clinical cardiac MRI dataset, along with an open-source DL segmentation model (both available on GitHub:
<a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git" target="_blank" rel="noopener">https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git</a>) and this web-based interface for easy deployment.</p>
<p>The dataset comprises complete cine short-axis cardiac MRI images from 130 mice with diverse phenotypes. It also contains expert manual segmentations of left ventricular (LV) blood pool and myocardium at end-diastole, end-systole, as well as additional timeframes with artefacts to improve robustness.</p>
<p>Using this resource, we developed an open-source DL segmentation model based on the UNet3+ architecture.</p>
<p>This Streamlit application consists of the inference model to provide an easy-to-use interface for our DL segmentation model, without the need for local installation. The application requires the complete SAX cine image data to be uploaded in NIfTI format, as a ZIP file using the simple file browser below.</p>
<p>Pre-processing and inference are then performed on all 2D images. The resulting blood-pool and myocardial volumes are combined across all slices at each timeframe and output to a .csv file. The blood-pool volumes are used to identify ED and ES, and these volumes are displayed as a GIF with the segmentations overlaid.</p>
<p class="note-text">(Note: This Hugging Face model was developed as part of a manuscript submitted to the <em>Journal of Cardiovascular Magnetic Resonance</em>)</p>
</div>
</div>
</div>
""")
st.markdown(HERO_HTML, unsafe_allow_html=True)
# HERO_HTML = dedent("""
# <div class="content-wrap">
# <div class="measure-wrap">
# <div class="hero-wrap">
# <h1 class="hero-title">
# Open-Source Pre-Clinical Image Segmentation:<br>
# Mouse cardiac MRI datasets with a deep learning segmentation framework
# </h1>
# </div>
# <div class="text-wrap">
# <p>We present the first publicly-available pre-clinical cardiac MRI dataset, along with an open-source DL segmentation model (both available on GitHub:
# <a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git" target="_blank" rel="noopener">https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git</a>) and this web-based interface for easy deployment.</p>
# <p>The dataset comprises complete cine short-axis cardiac MRI images from 130 mice with diverse phenotypes. It also contains expert manual segmentations of left ventricular (LV) blood pool and myocardium at end-diastole, end-systole, as well as additional timeframes with artefacts to improve robustness.</p>
# <p>Using this resource, we developed an open-source DL segmentation model based on the UNet3+ architecture.</p>
# <p>This Streamlit application consists of the inference model to provide an easy-to-use interface for our DL segmentation model, without the need for local installation. The application requires the complete SAX cine image data to be uploaded in NIfTI format, as a ZIP file using the simple file browser below.</p>
# <p>Pre-processing and inference are then performed on all 2D images. The resulting blood-pool and myocardial volumes are combined across all slices at each timeframe and output to a .csv file. The blood-pool volumes are used to identify ED and ES, and these volumes are displayed as a GIF with the segmentations overlaid.</p>
# <p>
# Pre-processing and inference are then performed on all 2D images. The resulting blood-pool and myocardial volumes are combined across all slices at each timeframe and output to a .csv file. The blood-pool volumes are used to identify ED and ES, and these volumes are displayed as a GIF with the segmentations overlaid.
# <br>
# (Note: This Hugging Face model was developed as part of a manuscript submitted to the <em>Journal of Cardiovascular Magnetic Resonance</em>).
# </p>
# </div>
# </div>
# </div>
# """)
# st.markdown(HERO_HTML, unsafe_allow_html=True)
# ---------- DATA UPLOAD (aligned) ----------
st.markdown('<div class="content-wrap"><div class="measure-wrap" id="upload-wrap">', unsafe_allow_html=True)
st.markdown(
"""
<h2 style='margin-bottom:0.2rem;'>
Data Upload <span style='font-size:33px;'>πŸ“€</span>
</h2>
""",
unsafe_allow_html=True
)
uploaded_zip = st.file_uploader(
"Upload ZIP of NIfTI files 🐭",
type="zip",
label_visibility="visible"
)
st.markdown(
"""
<p style="margin-top:0.3rem; font-size:15px; color:#444;">
Or download our <a href="https://huggingface.co/spaces/mrphys/Pre-clinical_DL_segmentation/tree/main/NIfTI_dataset" target="_blank" rel="noopener">
sample NIfTI dataset</a> to try it out!
</p>
""",
unsafe_allow_html=True
)
st.markdown('</div></div>', unsafe_allow_html=True)
# ---- Clear stale CSV when a new ZIP is picked ----
if uploaded_zip is not None:
if st.session_state.get("_last_zip_name") != uploaded_zip.name:
st.session_state.pop("csv_bytes", None)
st.session_state.pop("csv_name", None)
st.session_state.pop("rows_count", None)
st.session_state.pop("_dl_token", None) # reset download-button identity
st.session_state["_last_zip_name"] = uploaded_zip.name
# ---- Extract helper ----
def extract_zip(zip_path, extract_to):
    """Extract a ZIP archive into *extract_to*, skipping macOS metadata.

    Entries under ``__MACOSX`` and AppleDouble resource files (basename
    starting with ``._``) are ignored; every other member is extracted.
    The destination directory is created if it does not exist.
    """
    os.makedirs(extract_to, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        wanted = []
        for member in archive.namelist():
            # Skip Finder metadata folders and AppleDouble sidecar files.
            if "__MACOSX" in member:
                continue
            if os.path.basename(member).startswith("._"):
                continue
            wanted.append(member)
        archive.extractall(extract_to, members=wanted)
# ---------- PROCESS ----------
# Full pipeline for one uploaded archive: unpack, locate NIfTI volumes,
# load the segmentation model once, run per-case inference, and stash the
# combined CSV in session_state so it survives Streamlit reruns.
if uploaded_zip and st.button("Process Data"):
    zip_label = uploaded_zip.name or "ZIP"
    with st.spinner(f"Processing {zip_label}..."):
        work_dir = tempfile.mkdtemp()
        archive_path = os.path.join(work_dir, uploaded_zip.name)
        with open(archive_path, "wb") as fh:
            fh.write(uploaded_zip.read())
        extract_zip(archive_path, work_dir)

        # Recursively collect every NIfTI volume shipped inside the ZIP.
        nii_files: List[str] = [
            os.path.join(root, fn)
            for root, _, files in os.walk(work_dir)
            for fn in files
            if fn.lower().endswith((".nii", ".nii.gz"))
        ]

        if not nii_files:
            st.error("No NIfTI files (.nii / .nii.gz) found in the uploaded ZIP.")
        else:
            # Load model. Losses/metrics are not needed for inference, so
            # they are stubbed with None and the model is not compiled;
            # only the custom layer must resolve to a real class.
            model = keras.models.load_model(
                MODEL_PATH,
                custom_objects={
                    'focal_tversky_loss': None,
                    'dice_coef_no_bkg': None,
                    'ResizeAndConcatenate': ResizeAndConcatenate,
                    'dice_myo': None,
                    'dice_blood': None,
                    'dice': None
                },
                compile=False
            )
            log("Model loaded.")

            rows: List[Dict] = []
            per_frame_rows: List[pd.DataFrame] = []
            # One case per file; a failure in one case must not abort the
            # batch, so each is wrapped and surfaced as a warning.
            for fp in sorted(nii_files):
                try:
                    process_nifti_case(fp, model, rows, per_frame_rows)
                except Exception as e:
                    st.warning(f"Failed: {Path(fp).name} β€” {e}")

            # ---- BELOW the GIF(s): write CSV & persist bytes/name ----
            csv_path = write_all_in_one_csv(rows, per_frame_rows, CSV_DIR)
            csv_download_name = f"{Path(zip_label).stem}_Results.csv"
            with open(csv_path, "rb") as fh:
                csv_bytes = fh.read()
            st.session_state["csv_bytes"] = csv_bytes
            st.session_state["csv_name"] = csv_download_name
            st.session_state["rows_count"] = len(rows)
# ---- Re-render success + robust download on EVERY run (same tab, no 404) ----
# Re-emit the success banner and download button from session_state on every
# rerun, so the download keeps working after Streamlit re-executes the script.
session = st.session_state
if "csv_bytes" in session and "csv_name" in session:
    st.success(f"Processed {session.get('rows_count', 0)} NIfTI file(s).")
    # Same-tab data:URI button styled like Streamlit's native download button
    # (turns red on hover), avoiding the new-tab/404 behaviour.
    _same_tab_download_button(
        label="Download CSV",
        data_bytes=session["csv_bytes"],
        file_name=session["csv_name"],
        mime="text/csv",
        key="results"
    )
# ===== Tab 2: Dataset =====
# Static informational tab: a full-width dark hero image plus links to the
# GitHub repository and the public datasets. Pure HTML/CSS rendered in one
# st.markdown call; the HTML string is kept at column 0 so Markdown does not
# treat indented lines as code blocks.
with tab2:
    st.markdown(
        """
<style>
/* --- Full-width dark hero section --- */
.ds-hero-section {
background: #082c3a;
padding: 30px 10px;
text-align: center;
margin-left: -100vw;
margin-right: -100vw;
left: 0; right: 0; position: relative;
}
.ds-hero-section-inner { max-width: 1100px; margin: 0 auto; }
/* --- Hero image (centered) --- */
.ds-heroimg {
max-width: 1000px; width: 100%; height: auto;
border-radius: 10px; box-shadow: 0 8px 24px rgba(0,0,0,.25);
display: block; margin: 0 auto;
}
/* --- Caption (light on dark) --- */
.ds-caption {
text-align: center; color: #e0f2f1;
font-size: 18px; line-height: 1.5;
margin: 14px 6px 0; font-style: italic;
}
/* --- Thicker orange divider --- */
.ds-hr {
height: 8px;
border: 0; background: #ea580c;
margin: 24px 0 20px;
border-radius: 3px;
}
/* --- White background lower content --- */
.ds-wrap {
max-width: var(--content-measure, 920px);
margin-left: var(--left-offset, 40px);
margin-right: auto;
background: #fff; padding: 16px 24px; border-radius: 6px;
}
/* --- Section headers --- */
.ds-section h2 {
font-size: 20px; font-weight: 700;
margin: 0 0 2px;
color: #082c3a;
}
/* --- Text content --- */
.ds-section p {
font-size: 16px; line-height: 1.6; color: #333;
margin: 0 0 6px;
}
.ds-section ul {
margin: 2px 0 8px 18px;
padding: 0;
}
.ds-section li {
font-size: 16px; line-height: 1.6; color: #333;
margin-bottom: 10px;
}
.ds-section a {
color: #0b66c3 !important; text-decoration: underline !important;
}
/* --- Remove Streamlit paperclip/anchor on headings --- */
h2 a, [data-testid="stHeading"] a { display: none !important; }
</style>
<!-- Full-width dark top section -->
<div class="ds-hero-section">
<div class="ds-hero-section-inner">
<img class="ds-heroimg"
src="https://raw.githubusercontent.com/whanisa/Segmentation/main/icon/open_source.png"
alt="Illustration of mouse with heart representing open-source pre-clinical cardiac MRI dataset" />
<p class="ds-caption">
The first publicly-available pre-clinical cardiac MRI dataset,<br/>
with an open-source segmentation model and an easy-to-use web app.
</p>
</div>
</div>
<hr class="ds-hr"/>
<!-- White lower content -->
<div class="ds-wrap">
<div class="ds-section">
<h2>Repository & Paper Resources</h2>
<p>GitHub:
<a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git" target="_blank">
Open-Source_Pre-Clinical_Segmentation
</a>
</p>
<h2>πŸ“Š Dataset Availability</h2>
<ul>
<li>
<strong>Full dataset (130 mice, HDF5 format):</strong><br/>
Available in our
<a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation/tree/master/Data" target="_blank">
GitHub repository
</a>.<br/>
Each .h5 file contains the complete cine SAX MRI and expert manual segmentations.
</li>
<li>
<strong>Sample datasets (3 mice, NIfTI format):</strong><br/>
Available here:
<a href="https://huggingface.co/spaces/mrphys/Pre-clinical_DL_segmentation/tree/main/NIfTI_dataset" target="_blank">
NIfTI Sample Dataset
</a>.<br/>
We provide 3 example NIfTI datasets for quick download and direct use within the app.
</li>
</ul>
</div>
<hr class="ds-hr"/>
<div class="ds-section">
<h2>Notes</h2>
<ul>
<li>Complete SAX cine MRI for 130 mice with expert LV blood & myocardium labels (ED/ES).</li>
</ul>
</div>
</div>
""",
        unsafe_allow_html=True
    )
# ===== Tab 3: NIfTI converter =====
# Static help tab: instructions for converting Agilent fid data to NIfTI-1
# with the fid2niix command-line tool. Markdown string is kept at column 0
# so indented lines are not rendered as code blocks.
with tab3:
    st.subheader("NIfTI converter")
    st.markdown(
        """
**Working with Agilent data?**
Easily convert your fid files to NIfTI using our **fid2niix**.
```bash
fid2niix -z y -o /path/to/out -f "%p_%s" /path/to/fid_folder
```
**πŸ’‘ Tips**
- Upload a ZIP file that includes both `fid` and `procpar`.
- Conversion outputs **NIfTI-1** format, ready to use with our web app.
"""
    )
# Standard script entry-point guard: run the app's main() only when this
# file is executed directly (not when imported).
if __name__ == "__main__":
    main()