# Synthrad2025 / app.py
# Author: FelixzeroSun
# Commit: ef42723 "Update app.py" (verified)
# app.py
import os
import io
import tempfile
import zipfile
import numpy as np
import SimpleITK as sitk
import streamlit as st
from PIL import Image, ImageDraw
from huggingface_hub import snapshot_download
# Hugging Face model repository that holds the nnUNetv2 weights (and the
# bundled sample cases) for each task, keyed by the task label used in the UI.
HF_REPOS = {
    "Task 1 (MR โ†’ CT)": "aehrc/Synthrad2025_task1",
    "Task 2 (CBCT โ†’ CT)": "aehrc/Synthrad2025_task2",
}
# Absolute local directory each repo snapshot is materialized into:
# weights/task1 for the first task, weights/task2 for the second.
LOCAL_WEIGHTS_DIRS = {
    label: os.path.abspath(f"weights/task{number}")
    for number, label in enumerate(HF_REPOS, start=1)
}
# Optional access token, read from the environment; None when unset.
token = os.getenv("HF_TOKEN")
if token is None:
    print("[Warn] HF_TOKEN not set. If the model repo is private, set it in Settings โ†’ Variables and secrets.")

# Fetch each task's repo into its local weights directory and remember the
# directory snapshot_download reports for it (task label -> path).
REPO_DIRS = {}
for task_name, repo in HF_REPOS.items():
    REPO_DIRS[task_name] = snapshot_download(
        repo_id=repo,
        repo_type="model",
        local_dir=LOCAL_WEIGHTS_DIRS[task_name],
        local_dir_use_symlinks=False,
        token=token,
    )
# Path variables for nnUNetv2, presumably read when the algorithm modules are
# imported (which happens right after this). Values already present in the
# environment are kept; otherwise fall back to local folders.
for _env_var, _env_default in (
    ("nnUNet_raw", "./nnunet_raw"),
    ("nnUNet_preprocessed", "./nnunet_preprocessed"),
):
    os.environ.setdefault(_env_var, _env_default)

# Force OpenBLAS to one thread, overriding any inherited value.
# NOTE(review): numpy is imported at the top of this file, before this line, so
# the OpenBLAS bundled with numpy has likely already read its thread setting;
# this may only affect libraries/processes initialised later. Move it above
# `import numpy` if the limit must apply to numpy too - TODO confirm.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
from process import SynthradAlgorithm2
from process_1 import SynthradAlgorithm1
# Page header. set_page_config is kept as the very first Streamlit call
# (older Streamlit versions require that).
st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
st.title("SynthRad โ€” MRI/CBCT + Mask โ†’ synthetic CT")
st.image("./workflow.png", width=800)

# Task selector: the choice drives the upload label, the nnUNet results
# directory and which algorithm class is used.
TASKS = ["Task 1 (MR โ†’ CT)", "Task 2 (CBCT โ†’ CT)"]
task = st.radio("Select Task", TASKS, index=0, horizontal=True)
vol_label = (
    "MRI volume (.nii/.nii.gz/.mha)"
    if task == "Task 1 (MR โ†’ CT)"
    else "CBCT volume (.nii/.nii.gz/.mha)"
)
# nnUNetv2 looks up trained models under nnUNet_results; point it at the
# downloaded repo of the selected task.
os.environ["nnUNet_results"] = REPO_DIRS[task]
# Per-session values that must survive Streamlit reruns:
#   algos      - task label -> algorithm instance (filled by get_algo)
#   synth_ct   - last synthetic CT image returned by the algorithm
#   orig_meta  - (spacing, origin, direction) of the last input volume
#   vol_np     - float32 array of the last synthetic CT; None means "no result yet"
#   input_vol  - last input volume image
#   input_mask - last input mask image
for _state_key in ("algos", "synth_ct", "orig_meta", "vol_np", "input_vol", "input_mask"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = {} if _state_key == "algos" else None
def get_algo(task_name: str):
    """Return this session's algorithm for *task_name*, constructing it on first use.

    Instances are cached in ``st.session_state.algos`` so later reruns reuse
    the same object. Task 1 uses ``SynthradAlgorithm1``; any other task uses
    ``SynthradAlgorithm2``.
    """
    cache = st.session_state.algos
    if task_name not in cache:
        factory = SynthradAlgorithm1 if task_name == "Task 1 (MR โ†’ CT)" else SynthradAlgorithm2
        cache[task_name] = factory()
    return cache[task_name]
# Algorithm object for the currently selected task (cached per session).
algo = get_algo(task)

# Input section: choose a bundled sample or upload your own volumes.
st.subheader("Input")
src = st.radio(label="Source", options=["Sample", "Upload"], index=0, horizontal=True)
def build_sample_map(task_name: str):
    """Describe the sample cases bundled in the downloaded repo of *task_name*.

    Each region folder inside the repo holds the input volume, its mask and the
    ground-truth CT. Task 1 uses ``mr.mha`` / ``mask1.mha``; any other task uses
    ``cbct.mha`` / ``mask2.mha``. The ground truth is always ``ct.mha``.

    Returns:
        dict mapping the UI label (e.g. ``"Abdomen (sample)"``) to a dict with
        keys ``"region"``, ``"vol"``, ``"mask"`` and ``"gt"`` (file paths; their
        existence is not checked here).
    """
    base_dir = REPO_DIRS[task_name]
    if task_name == "Task 1 (MR โ†’ CT)":
        vol_name, mask_name = "mr.mha", "mask1.mha"
    else:
        vol_name, mask_name = "cbct.mha", "mask2.mha"

    entries = {}
    for ui_label, region in (
        ("Abdomen (sample)", "Abdomen"),
        ("Head and Neck (sample)", "Head and Neck"),
        ("Thorax (sample)", "Thorax"),
    ):
        # The region name doubles as the folder name inside the repo.
        entries[ui_label] = {
            "region": region,
            "vol": os.path.join(base_dir, region, vol_name),
            "mask": os.path.join(base_dir, region, mask_name),
            "gt": os.path.join(base_dir, region, "ct.mha"),  # convention: GT = ct.mha
        }
    return entries
# Bundled sample cases (label -> region and file paths) for the selected task.
SAMPLE_MAP = build_sample_map(task)
def _download_sitk_image(img: sitk.Image, file_name: str, label: str):
    """Render a Streamlit download button that serves *img* as a file.

    The image is written with ``sitk.WriteImage`` to a temporary file, read
    back into memory, and the temporary file is always removed, even when
    writing fails. ITK picks the output format from the extension, so the
    temporary file uses the same extension as ``file_name``. Names without an
    extension fall back to ``.nii.gz``. This keeps the served bytes consistent
    with the offered name.

    Args:
        img: Image to serialize.
        file_name: File name offered to the browser, e.g. ``"synth_ct.nii.gz"``.
        label: Caption of the download button.
    """
    lowered = file_name.lower()
    if lowered.endswith(".nii.gz"):
        suffix = ".nii.gz"  # double extension; splitext would only see ".gz"
    else:
        suffix = os.path.splitext(lowered)[1] or ".nii.gz"

    # Only reserve the name here. The handle is closed before ITK opens the
    # path, because reopening a still-open NamedTemporaryFile is not portable.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sitk.WriteImage(img, tmp_path)
        with open(tmp_path, "rb") as fh:
            payload = fh.read()
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is harmless

    st.download_button(
        label=label,
        data=payload,
        file_name=file_name,
        mime="application/octet-stream",
    )
def _upload_suffix(file_name: str) -> str:
    """Return the lower-cased extension ITK should see for *file_name*.

    ``.nii.gz`` is special-cased because ``os.path.splitext`` would only return
    ``.gz``. Lower-casing makes names such as ``SCAN.NII.GZ`` map to the same
    reader as their lower-case form.
    """
    lowered = file_name.lower()
    if lowered.endswith(".nii.gz"):
        return ".nii.gz"
    return os.path.splitext(lowered)[1]


def _read_sitk_from_uploaded(f):
    """Read an uploaded volume into a SimpleITK image.

    ``sitk.ReadImage`` takes a path, so the upload is spilled to a temporary
    file whose extension matches the uploaded name (ITK chooses the reader by
    extension). The temporary file is removed even if reading fails.

    Args:
        f: Uploaded file object with a ``.name`` and either ``getvalue()``
           (e.g. the object returned by ``st.file_uploader``) or ``read()``.

    Returns:
        The image returned by ``sitk.ReadImage``.
    """
    # getvalue() returns the whole buffer regardless of the current stream
    # position, unlike read(), which returns only what is left after a prior read.
    data = f.getvalue() if hasattr(f, "getvalue") else f.read()
    with tempfile.NamedTemporaryFile(suffix=_upload_suffix(f.name), delete=False) as tmp:
        tmp.write(data)
        path = tmp.name
    try:
        return sitk.ReadImage(path)
    finally:
        try:
            os.remove(path)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is harmless
def _read_sitk_from_path(path):
    """Load a bundled sample volume from *path*.

    If the file does not exist, an error is shown and ``st.stop()`` ends the
    current script run before any read is attempted.
    """
    missing = not os.path.exists(path)
    if missing:
        st.error(f"Sample file missing: {path}")
        st.stop()
    image = sitk.ReadImage(path)
    return image
def _norm2u8(slice2d, lo_pct=1.0, hi_pct=99.0):
    """Map a 2-D slice to 8-bit grey levels using a robust percentile window.

    Values at or below the ``lo_pct`` percentile become 0, values at or above
    the ``hi_pct`` percentile become 255, with a linear ramp in between. This
    keeps a few extreme voxels from flattening the display contrast.

    Args:
        slice2d: Numeric array-like slice.
        lo_pct: Lower percentile (0-100) mapped to black. Default 1.
        hi_pct: Upper percentile (0-100) mapped to white. Default 99.

    Returns:
        ``np.uint8`` array with the same shape as ``slice2d``.
    """
    s = np.asarray(slice2d, dtype=np.float32)
    # One percentile call (the original evaluated three separate percentiles).
    lo, hi = np.percentile(s, [lo_pct, hi_pct])
    # The epsilon avoids division by zero for constant slices, which map to 0.
    s = (s - lo) / (hi - lo + 1e-6)
    return (np.clip(s, 0, 1) * 255).astype(np.uint8)
# Three-column input row: volume | mask | region, followed by the Run button.
c1, c2, c3 = st.columns([2, 2, 1])
if src != "Upload":
    # Sample mode: the chosen bundled case fixes both the files and the region.
    with c1:
        sample_key = st.selectbox("Choose a sample", list(SAMPLE_MAP.keys()))
    region_for_run = SAMPLE_MAP[sample_key]["region"]
    with c2:
        st.markdown("Region (fixed by sample)")
        st.write(f"**{region_for_run}**")
    with c3:
        st.markdown(" ", unsafe_allow_html=True)
    inputs_ready = sample_key is not None
else:
    # Upload mode: the user supplies both volumes and picks the region.
    with c1:
        vol_file = st.file_uploader(vol_label, type=["nii", "nii.gz", "mha"], key="vol")
    with c2:
        mask_file = st.file_uploader("Mask volume (.nii/.nii.gz/.mha)", type=["nii", "nii.gz", "mha"], key="mask")
    with c3:
        region = st.radio("Region", ["Head and Neck", "Abdomen", "Thorax"], index=1)
    region_for_run = region
    inputs_ready = vol_file is not None and mask_file is not None

run_btn = st.button("Run", type="primary", disabled=not inputs_ready)
# --- Run: load inputs, predict, and store everything the viewer needs. ---
if run_btn:
    with st.spinner(f"Running nnUNetv2 {('SynthradAlgorithm1' if task=='Task 1 (MR โ†’ CT)' else 'SynthradAlgorithm2')}..."):
        if src == "Upload":
            in_vol_img = _read_sitk_from_uploaded(vol_file)
            mask_img = _read_sitk_from_uploaded(mask_file)
            run_gt_path = None  # uploads come without a ground-truth CT
        else:
            sample = SAMPLE_MAP[sample_key]
            in_vol_img = _read_sitk_from_path(sample["vol"])
            mask_img = _read_sitk_from_path(sample["mask"])
            run_gt_path = sample.get("gt")
        st.session_state.orig_meta = (
            in_vol_img.GetSpacing(),
            in_vol_img.GetOrigin(),
            in_vol_img.GetDirection(),
        )
        out_img = algo.predict({"image": in_vol_img, "mask": mask_img, "region": region_for_run})
        st.session_state.synth_ct = out_img
        st.session_state.vol_np = sitk.GetArrayFromImage(out_img).astype(np.float32)
        st.session_state.input_vol = in_vol_img
        st.session_state.input_mask = mask_img
        # Record what this result was computed from. The viewer below runs on
        # every rerun, and by then the Source/sample/task widgets may point at
        # different data, so it must not rely on the live widget values.
        st.session_state.run_task = task
        st.session_state.run_gt_path = run_gt_path

# --- Viewer: shown once a result exists in this session. ---
if st.session_state.vol_np is None:
    st.info("Select Upload or Sample, then click Run")
else:
    # Task / GT recorded at Run time (defaults cover sessions without them).
    run_task = st.session_state.get("run_task", task)
    gt_path = st.session_state.get("run_gt_path")

    # Reorient input and synthetic CT to LPS, then resample the synthetic CT
    # onto the input grid so both share one voxel geometry.
    in_lps = sitk.DICOMOrient(st.session_state.input_vol, "LPS")
    out_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")
    res = sitk.ResampleImageFilter()
    res.SetReferenceImage(in_lps)
    res.SetInterpolator(sitk.sitkLinear)
    res.SetOutputPixelType(out_lps.GetPixelID())
    out_on_input = res.Execute(out_lps)

    # Ground-truth CT (sample runs only), resampled onto the same input grid.
    gt_on_input = None
    if gt_path and os.path.exists(gt_path):
        gt_img = sitk.DICOMOrient(sitk.ReadImage(gt_path), "LPS")
        res.SetOutputPixelType(gt_img.GetPixelID())
        gt_on_input = res.Execute(gt_img)

    # numpy views; SimpleITK orders array axes (z, y, x).
    in_vol = sitk.GetArrayFromImage(in_lps).astype(np.float32)
    syn_vol = sitk.GetArrayFromImage(out_on_input).astype(np.float32)
    gt_vol = sitk.GetArrayFromImage(gt_on_input).astype(np.float32) if gt_on_input is not None else None

    st.subheader("Input vs Synthetic CT Viewer (Axial only)")
    n_slices = in_vol.shape[0]
    if n_slices > 1:
        idx = st.slider("Slice index (Axial/Z)", 0, n_slices - 1, n_slices // 2)
    else:
        idx = 0  # avoid a degenerate 0..0 slider for single-slice volumes

    def get_axial(arr, k):
        return arr[k, :, :]

    img_in = _norm2u8(get_axial(in_vol, idx))
    img_syn = _norm2u8(get_axial(syn_vol, idx))
    img_gt = _norm2u8(get_axial(gt_vol, idx)) if gt_vol is not None else None

    overlay_mask = st.checkbox("Overlay mask (red)")
    alpha = st.slider("Mask opacity", 0.0, 1.0, 0.35, 0.05, disabled=not overlay_mask)
    mask_plot = None
    if overlay_mask and st.session_state.input_mask is not None:
        # Nearest-neighbour resampling keeps the mask's original label values.
        mask_lps = sitk.DICOMOrient(st.session_state.input_mask, "LPS")
        res_nn = sitk.ResampleImageFilter()
        res_nn.SetReferenceImage(in_lps)
        res_nn.SetInterpolator(sitk.sitkNearestNeighbor)
        mask_np = sitk.GetArrayFromImage(res_nn.Execute(mask_lps))
        mask_slice = get_axial(mask_np, min(idx, mask_np.shape[0] - 1))
        # NaN cells are left as gaps by the heatmap, so only mask voxels are painted.
        mask_plot = np.where(mask_slice > 0, 1.0, np.nan)

    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    # Axis coordinates in physical units (voxel index * spacing).
    sx, sy, _ = in_lps.GetSpacing()
    xs = np.arange(img_in.shape[1]) * sx
    ys = np.arange(img_in.shape[0]) * sy

    cols = 3 if img_gt is not None else 2
    titles = ["Input (MRI/CBCT)", "Synthetic CT"] + (["Ground-Truth CT"] if cols == 3 else [])
    fig = make_subplots(rows=1, cols=cols, subplot_titles=tuple(titles))
    panels = [img_in, img_syn] + ([img_gt] if cols == 3 else [])
    for c, panel in enumerate(panels, start=1):
        fig.add_trace(go.Heatmap(z=panel, x=xs, y=ys, colorscale="gray",
                                 zmin=0, zmax=255, showscale=False, hoverinfo="skip"), row=1, col=c)
    if mask_plot is not None:
        red_scale = [[0.0, "rgba(255,0,0,1.0)"], [1.0, "rgba(255,0,0,1.0)"]]
        for c in range(1, cols + 1):
            fig.add_trace(go.Heatmap(z=mask_plot, x=xs, y=ys, colorscale=red_scale,
                                     showscale=False, opacity=alpha, hoverinfo="skip"), row=1, col=c)
    for c in range(1, cols + 1):
        fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=c)
        # Heatmaps draw z-row 0 at the bottom by default. In LPS, row 0 is the
        # most anterior, so reverse y to put anterior at the top. Also tie the
        # y scale to the matching x axis so mm units are not stretched.
        fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False,
                         autorange="reversed",
                         scaleanchor="x" if c == 1 else f"x{c}", scaleratio=1,
                         row=1, col=c)
    fig.update_layout(height=600, margin=dict(l=10, r=10, t=40, b=10))
    st.plotly_chart(fig, use_container_width=True)

    if cols == 3:
        st.caption(f"Axial (Z) slice {idx+1}/{n_slices} โ€” All aligned to input geometry; GT only for samples.")
    else:
        st.caption(f"Axial (Z) slice {idx+1}/{n_slices} โ€” Aligned to input geometry.")

    # Downloads of the stored result and inputs.
    col_d1, col_d2, col_d3 = st.columns(3)
    with col_d3:
        _download_sitk_image(st.session_state.synth_ct,
                             file_name="synth_ct.nii.gz",
                             label="Download synthetic CT")
    with col_d1:
        if st.session_state.input_vol is not None:
            # Name the input after the task that produced it, not the task now selected.
            is_task1 = run_task == "Task 1 (MR โ†’ CT)"
            in_name = "input_mr.nii.gz" if is_task1 else "input_cbct.nii.gz"
            in_label = "Download input MRI" if is_task1 else "Download input CBCT"
            _download_sitk_image(st.session_state.input_vol, file_name=in_name, label=in_label)
        else:
            st.button("Download input", disabled=True)
    with col_d2:
        if st.session_state.input_mask is not None:
            _download_sitk_image(st.session_state.input_mask,
                                 file_name="input_mask.nii.gz",
                                 label="Download input Mask")
        else:
            st.button("Download input Mask", disabled=True)