Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import io | |
| import tempfile | |
| import zipfile | |
| import numpy as np | |
| import SimpleITK as sitk | |
| import streamlit as st | |
| from PIL import Image, ImageDraw | |
| from huggingface_hub import snapshot_download | |
# Hugging Face model repositories, keyed by the task label shown in the UI.
HF_REPOS = {
    "Task 1 (MR โ CT)": "aehrc/Synthrad2025_task1",
    "Task 2 (CBCT โ CT)": "aehrc/Synthrad2025_task2",
}

# Absolute local directory that each repository snapshot is downloaded into.
LOCAL_WEIGHTS_DIRS = {
    "Task 1 (MR โ CT)": os.path.abspath("weights/task1"),
    "Task 2 (CBCT โ CT)": os.path.abspath("weights/task2"),
}

# The access token is optional; a missing token only produces a console warning.
token = os.getenv("HF_TOKEN")
if token is None:
    print("[Warn] HF_TOKEN not set. If the model repo is private, set it in Settings โ Variables and secrets.")

# Download every repository up front and remember the directory returned for
# each task label.
REPO_DIRS = {}
for task_name, repo in HF_REPOS.items():
    repo_dir = snapshot_download(
        repo_id=repo,
        repo_type="model",
        token=token,
        local_dir=LOCAL_WEIGHTS_DIRS[task_name],
        local_dir_use_symlinks=False,
    )
    REPO_DIRS[task_name] = repo_dir

# nnU-Net path variables: keep any value already present in the environment.
if "nnUNet_raw" not in os.environ:
    os.environ["nnUNet_raw"] = "./nnunet_raw"
if "nnUNet_preprocessed" not in os.environ:
    os.environ["nnUNet_preprocessed"] = "./nnunet_preprocessed"

# Always overridden, regardless of what the environment already contains.
# NOTE(review): these are set before the `process` modules are imported below,
# presumably because they are read at import time -- confirm.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
| from process import SynthradAlgorithm2 | |
| from process_1 import SynthradAlgorithm1 | |
# Page header.
st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
st.title("SynthRad โ MRI/CBCT + Mask โ synthetic CT")
st.image("./workflow.png",width=800)

# Task selection; the first task is preselected.
TASKS = ["Task 1 (MR โ CT)", "Task 2 (CBCT โ CT)"]
task = st.radio("Select Task", TASKS, index=0, horizontal=True)

# Label for the volume uploader depends on the input modality of the task.
vol_label = (
    "MRI volume (.nii/.nii.gz/.mha)"
    if task == "Task 1 (MR โ CT)"
    else "CBCT volume (.nii/.nii.gz/.mha)"
)

# Point nnU-Net at the weights that belong to the selected task.
os.environ["nnUNet_results"] = REPO_DIRS[task]

# Session-state defaults: a dict of per-task algorithm instances, plus the
# slots that hold the most recent run (all empty until "Run" is pressed).
if "algos" not in st.session_state:
    st.session_state.algos = {}
for _state_key in ("synth_ct", "orig_meta", "vol_np", "input_vol", "input_mask"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def get_algo(task_name: str):
    """Return the algorithm object for ``task_name``, creating it on first use.

    Instances are stored in ``st.session_state.algos`` keyed by task label, so
    each task's algorithm is constructed at most once per session.

    Args:
        task_name: Task label as shown in the task selector.

    Returns:
        A ``SynthradAlgorithm1`` for the Task 1 label, otherwise a
        ``SynthradAlgorithm2``.
    """
    cache = st.session_state.algos
    if task_name in cache:
        return cache[task_name]

    # Any label other than Task 1 falls back to the Task 2 algorithm.
    factory = SynthradAlgorithm1 if task_name == "Task 1 (MR โ CT)" else SynthradAlgorithm2
    cache[task_name] = factory()
    return cache[task_name]
# Algorithm instance for the task currently selected in the UI.
algo = get_algo(task_name=task)

# Input section: choose between a bundled sample (default) and user uploads.
st.subheader("Input")
src = st.radio(label="Source", options=["Sample", "Upload"], index=0, horizontal=True)
def build_sample_map(task_name: str):
    """Build the catalogue of bundled samples for ``task_name``.

    Paths are composed under ``REPO_DIRS[task_name]``; this function does not
    check that the files exist.

    Args:
        task_name: Task label; selects the repository directory and the
            volume/mask file names.

    Returns:
        A dict mapping a display name (``"<Region> (sample)"``) to a dict with
        the keys ``"region"``, ``"vol"``, ``"mask"`` and ``"gt"``.
    """
    repo_dir = REPO_DIRS[task_name]

    # Volume and mask file names differ per task.
    if task_name == "Task 1 (MR โ CT)":
        vol_fname, mask_fname = "mr.mha", "mask1.mha"
    else:
        vol_fname, mask_fname = "cbct.mha", "mask2.mha"

    sample_map = {}
    for region in ("Abdomen", "Head and Neck", "Thorax"):
        region_root = os.path.join(repo_dir, region)
        sample_map[f"{region} (sample)"] = {
            "region": region,
            "vol": os.path.join(region_root, vol_fname),
            "mask": os.path.join(region_root, mask_fname),
            # Convention: the ground truth is always named ct.mha.
            "gt": os.path.join(region_root, "ct.mha"),
        }
    return sample_map
# Sample catalogue for the task currently selected in the UI.
SAMPLE_MAP = build_sample_map(task_name=task)
def _download_sitk_image(img: sitk.Image, file_name: str, label: str):
    """Render a download button that serves ``img`` as a ``.nii.gz`` file.

    The image is written to a temporary file, read back into memory, and the
    temporary file is removed before the button is created.

    Args:
        img: Image to serialize.
        file_name: File name offered to the browser.
        label: Text shown on the button.
    """
    # Reserve a path and close the handle immediately, so the image writer
    # never has to open a file this process still holds open.
    fd, tmp_path = tempfile.mkstemp(suffix=".nii.gz")
    os.close(fd)
    try:
        sitk.WriteImage(img, tmp_path)
        with open(tmp_path, "rb") as fh:
            payload = fh.read()
    finally:
        # Clean up even when writing or reading fails; removal itself stays
        # best-effort, but only filesystem errors are ignored.
        try:
            os.remove(tmp_path)
        except OSError:
            pass

    st.download_button(
        label=label,
        data=payload,
        file_name=file_name,
        mime="application/octet-stream",
    )
def _read_sitk_from_uploaded(f):
    """Read an uploaded file-like object into a SimpleITK image.

    The bytes are spooled to a temporary file whose suffix matches the upload's
    extension, then read with ``sitk.ReadImage``. The temporary file is removed
    whether or not reading succeeds.

    Args:
        f: File-like object exposing ``name`` and ``read()``.

    Returns:
        The image returned by ``sitk.ReadImage``.
    """
    name = f.name
    # ".nii.gz" is a double extension that splitext would cut down to ".gz";
    # match it case-insensitively so names like "SCAN.NII.GZ" are handled too.
    if name.lower().endswith(".nii.gz"):
        suffix = ".nii.gz"
    else:
        suffix = os.path.splitext(name)[1]

    # Defensive rewind: if the object was read before, read() would otherwise
    # return only what is left after the current position.
    rewind = getattr(f, "seek", None)
    if callable(rewind):
        rewind(0)
    payload = f.read()

    fd, path = tempfile.mkstemp(suffix=suffix)
    try:
        # Close the handle before the reader opens the path.
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(payload)
        return sitk.ReadImage(path)
    finally:
        try:
            os.remove(path)
        except OSError:
            pass
def _read_sitk_from_path(path):
    """Read the image at ``path``, reporting a missing file in the UI first.

    When ``path`` does not exist, an error message is shown and ``st.stop()``
    is called before the read is attempted.

    Args:
        path: Filesystem path of the image.

    Returns:
        The image returned by ``sitk.ReadImage``.
    """
    missing = not os.path.exists(path)
    if missing:
        st.error(f"Sample file missing: {path}")
        st.stop()
    image = sitk.ReadImage(path)
    return image
| def _norm2u8(slice2d): | |
| s = slice2d.astype(np.float32) | |
| s = (s - np.percentile(s, 1)) / (np.percentile(s, 99) - np.percentile(s, 1) + 1e-6) | |
| s = np.clip(s, 0, 1) | |
| return (s * 255).astype(np.uint8) | |
# Three columns: volume / mask / region (the last one is narrower).
c1, c2, c3 = st.columns([2, 2, 1])

if src != "Upload":
    # Bundled sample: the region is dictated by the chosen sample.
    sample_key = c1.selectbox("Choose a sample", list(SAMPLE_MAP.keys()))
    c2.markdown("Region (fixed by sample)")
    c2.write(f"**{SAMPLE_MAP[sample_key]['region']}**")
    c3.markdown(" ", unsafe_allow_html=True)
    inputs_ready = (sample_key is not None)
    region_for_run = SAMPLE_MAP[sample_key]["region"]
else:
    # User upload: volume, mask and region are all chosen by the user.
    vol_file = c1.file_uploader(vol_label, type=["nii", "nii.gz", "mha"], key="vol")
    mask_file = c2.file_uploader("Mask volume (.nii/.nii.gz/.mha)", type=["nii", "nii.gz", "mha"], key="mask")
    region = c3.radio("Region", ["Head and Neck", "Abdomen", "Thorax"], index=1)
    # Both files must be present before a run is allowed.
    inputs_ready = not (vol_file is None or mask_file is None)
    region_for_run = region

# The button stays disabled until the inputs for the chosen source are complete.
run_btn = st.button("Run", type="primary", disabled=not inputs_ready)
# ---------------------------------------------------------------------------
# Run inference when requested, then show the most recent result.
#
# The result lives in st.session_state so it survives reruns caused by widget
# interaction. The ground-truth path is stored together with the result so the
# viewer always shows the GT that belongs to the volume that was actually run,
# not to whatever sample/source happens to be selected at display time.
# ---------------------------------------------------------------------------
if run_btn:
    with st.spinner(f"Running nnUNetv2 {('SynthradAlgorithm1' if task=='Task 1 (MR โ CT)' else 'SynthradAlgorithm2')}..."):
        if src == "Upload":
            in_vol_img = _read_sitk_from_uploaded(vol_file)
            mask_img = _read_sitk_from_uploaded(mask_file)
            run_gt_path = None  # uploads have no ground truth
        else:
            sample = SAMPLE_MAP[sample_key]
            in_vol_img = _read_sitk_from_path(sample["vol"])
            mask_img = _read_sitk_from_path(sample["mask"])
            run_gt_path = sample.get("gt", None)

        # Geometry of the input, kept alongside the result.
        st.session_state.orig_meta = (
            in_vol_img.GetSpacing(),
            in_vol_img.GetOrigin(),
            in_vol_img.GetDirection(),
        )

        out_img = algo.predict({"image": in_vol_img, "mask": mask_img, "region": region_for_run})

        # Only publish the new state once prediction has succeeded.
        st.session_state.synth_ct = out_img
        st.session_state.vol_np = sitk.GetArrayFromImage(out_img).astype(np.float32)
        st.session_state.input_vol = in_vol_img
        st.session_state.input_mask = mask_img
        st.session_state.gt_path = run_gt_path

if st.session_state.vol_np is None:
    st.info("Select Upload or Sample, then click Run")
else:
    # Bring everything into LPS orientation and onto the input's voxel grid so
    # that one slice index addresses the same plane in every volume.
    in_lps = sitk.DICOMOrient(st.session_state.input_vol, "LPS")
    out_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")

    res = sitk.ResampleImageFilter()
    res.SetReferenceImage(in_lps)
    res.SetInterpolator(sitk.sitkLinear)
    res.SetOutputPixelType(out_lps.GetPixelID())
    out_on_input = res.Execute(out_lps)

    # Ground truth recorded by the run that produced the stored result.
    gt_on_input = None
    gt_path = st.session_state.get("gt_path")
    if gt_path and os.path.exists(gt_path):
        gt_img = sitk.DICOMOrient(sitk.ReadImage(gt_path), "LPS")
        # Same reference grid and interpolator; only the pixel type changes.
        res.SetOutputPixelType(gt_img.GetPixelID())
        gt_on_input = res.Execute(gt_img)

    # numpy
    in_vol = sitk.GetArrayFromImage(in_lps).astype(np.float32)
    syn_vol = sitk.GetArrayFromImage(out_on_input).astype(np.float32)
    gt_vol = sitk.GetArrayFromImage(gt_on_input).astype(np.float32) if gt_on_input is not None else None

    st.subheader("Input vs Synthetic CT Viewer (Axial only)")
    n_slices = in_vol.shape[0]
    idx = st.slider("Slice index (Axial/Z)", 0, n_slices - 1, n_slices // 2)

    def get_axial(arr, k):
        """Return slice ``k`` along the first axis of ``arr``."""
        return arr[k, :, :]

    sl_in = get_axial(in_vol, idx)
    sl_syn = get_axial(syn_vol, idx)
    img_in = _norm2u8(sl_in)
    img_syn = _norm2u8(sl_syn)
    img_gt = _norm2u8(get_axial(gt_vol, idx)) if gt_vol is not None else None

    overlay_mask = st.checkbox("Overlay mask (red)")
    alpha = st.slider("Mask opacity", 0.0, 1.0, 0.35, 0.05, disabled=not overlay_mask)

    mask_slice = None
    if overlay_mask and st.session_state.input_mask is not None:
        mask_lps = sitk.DICOMOrient(st.session_state.input_mask, "LPS")
        # Nearest-neighbour keeps the mask labels intact while resampling.
        res_nn = sitk.ResampleImageFilter()
        res_nn.SetReferenceImage(in_lps)
        res_nn.SetInterpolator(sitk.sitkNearestNeighbor)
        mask_on_input = res_nn.Execute(mask_lps)
        mask_np = sitk.GetArrayFromImage(mask_on_input)
        mask_slice = get_axial(mask_np, min(idx, mask_np.shape[0]-1))
        # NaN outside the mask leaves those heatmap cells undrawn.
        mask_plot = np.where(mask_slice > 0, 1.0, np.nan)
    else:
        mask_plot = None

    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    # Axis coordinates scaled by the in-plane spacing of the input.
    sx, sy, _ = in_lps.GetSpacing()
    xs = np.arange(img_in.shape[1]) * sx
    ys = np.arange(img_in.shape[0]) * sy

    # A third panel appears only when the stored run has a ground truth.
    cols = 3 if img_gt is not None else 2
    titles = ["Input (MRI/CBCT)", "Synthetic CT"] + (["Ground-Truth CT"] if cols == 3 else [])
    fig = make_subplots(rows=1, cols=cols, subplot_titles=tuple(titles))

    fig.add_trace(go.Heatmap(z=img_in, x=xs, y=ys, colorscale="gray",
                             zmin=0, zmax=255, showscale=False, hoverinfo="skip"), row=1, col=1)
    # synCT
    fig.add_trace(go.Heatmap(z=img_syn, x=xs, y=ys, colorscale="gray",
                             zmin=0, zmax=255, showscale=False, hoverinfo="skip"), row=1, col=2)
    # GT
    if cols == 3:
        fig.add_trace(go.Heatmap(z=img_gt, x=xs, y=ys, colorscale="gray",
                                 zmin=0, zmax=255, showscale=False, hoverinfo="skip"), row=1, col=3)

    # mask overlay
    if mask_plot is not None:
        red_scale = [[0.0, "rgba(255,0,0,1.0)"], [1.0, "rgba(255,0,0,1.0)"]]
        for c in range(1, cols+1):
            fig.add_trace(go.Heatmap(z=mask_plot, x=xs, y=ys, colorscale=red_scale,
                                     showscale=False, opacity=alpha, hoverinfo="skip"), row=1, col=c)

    for c in range(1, cols+1):
        fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=c)
        # A heatmap places row 0 at the bottom by default; reverse the axis so
        # the first array row is drawn at the top, as in an image display.
        fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False,
                         autorange="reversed", row=1, col=c)

    fig.update_layout(height=600, margin=dict(l=10, r=10, t=40, b=10))
    st.plotly_chart(fig, use_container_width=True)

    # Caption
    if cols == 3:
        st.caption(f"Axial (Z) slice {idx+1}/{n_slices} โ All aligned to input geometry; GT only for samples.")
    else:
        st.caption(f"Axial (Z) slice {idx+1}/{n_slices} โ Aligned to input geometry.")

    # Downloads: input | mask | synthetic CT (left to right).
    col_d1, col_d2, col_d3 = st.columns(3)
    with col_d3:
        _download_sitk_image(st.session_state.synth_ct,
                             file_name="synth_ct.nii.gz",
                             label="Download synthetic CT")
    with col_d1:
        if st.session_state.input_vol is not None:
            in_name = "input_mr.nii.gz" if task == "Task 1 (MR โ CT)" else "input_cbct.nii.gz"
            in_label = "Download input MRI" if task == "Task 1 (MR โ CT)" else "Download input CBCT"
            _download_sitk_image(st.session_state.input_vol, file_name=in_name, label=in_label)
        else:
            st.button("Download input", disabled=True)
    with col_d2:
        if st.session_state.input_mask is not None:
            _download_sitk_image(st.session_state.input_mask,
                                 file_name="input_mask.nii.gz",
                                 label="Download input Mask")
        else:
            st.button("Download input Mask", disabled=True)