# 3DCine / app.py
# Author: MarkWrobel — commit 21790fa ("Update app.py", verified)
import os, sys, subprocess, importlib
# Force pure-Python protobuf to sidestep the C++ ABI mismatch
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
# Ensure protobuf < 3.20 for TF 2.9.1
def _ensure_pb319():
try:
import google.protobuf as _pb
v = _pb.__version__
parts = [int(p) for p in v.split(".")[:2] if p.isdigit()]
if not parts or parts[0] > 3 or (parts[0] == 3 and parts[1] >= 20):
raise ImportError
except Exception:
subprocess.run(
[sys.executable, "-m", "pip", "install", "--no-cache-dir",
"--upgrade", "--force-reinstall", "protobuf==3.19.6"],
check=True
)
importlib.invalidate_caches()
# Run the protobuf version check now, before the project/third-party imports
# below, which presumably load protobuf (e.g. via TensorFlow) — TODO confirm.
_ensure_pb319()
# --- End shim ---
from utils.process_utils import *
import skimage
import streamlit as st
import zipfile
import tempfile
import os
import shutil
import os
from huggingface_hub import snapshot_download
MODEL_REPO = os.getenv("MODEL_REPO")  # e.g. "username/3d-cine-models"
MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "")  # optional subfolder in the repo
PERSIST_BASE = os.getenv("PERSIST_BASE", "/data")  # HF Spaces persistent storage
# Folder name (under PERSIST_BASE) used when no MODEL_REPO is configured and no
# MODELS_BASE name is available yet.
_DEFAULT_LOCAL_MODELS_DIR = "models"


def get_models_base(*, model_repo=None, model_subdir=None, persist_base=None):
    """Return the directory holding the model weights, creating it if needed.

    If a model repo is configured, its snapshot is downloaded (or reused) under
    ``<persist_base>/hf_models`` and the optional subfolder is appended.
    Otherwise a local folder inside ``persist_base`` is used.

    The keyword-only arguments default to the module settings MODEL_REPO,
    MODEL_SUBDIR and PERSIST_BASE. Pass them explicitly to override, e.g. in
    tests.

    Returns:
        str: path of the (existing) models directory.
    """
    repo = MODEL_REPO if model_repo is None else model_repo
    subdir = MODEL_SUBDIR if model_subdir is None else model_subdir
    root = PERSIST_BASE if persist_base is None else persist_base
    # cache models inside persistent storage to avoid re-downloads
    os.makedirs(root, exist_ok=True)
    if repo:
        repo_dir = snapshot_download(
            repo_id=repo,
            repo_type="model",
            local_dir=os.path.join(root, "hf_models"),
            local_dir_use_symlinks=False,
        )
        base = os.path.join(repo_dir, subdir) if subdir else repo_dir
    else:
        # Fallback to a local folder in persistent storage. At module level,
        # MODELS_BASE is assigned from this function's return value. On the
        # first call the name therefore exists only if the star-import from
        # utils.process_utils provides it, so fall back to a default name.
        try:
            local_name = MODELS_BASE
        except NameError:
            local_name = _DEFAULT_LOCAL_MODELS_DIR
        base = os.path.join(root, local_name)
    os.makedirs(base, exist_ok=True)
    return base
# Resolve the model directory. Streamlit re-executes this script on every
# widget interaction, so memoise the result with st.cache_resource. This avoids
# repeating the Hub snapshot check (and any download) on each rerun.
MODELS_BASE = st.cache_resource(get_models_base)()
# Expose only GPU 0 to CUDA-based libraries.
# NOTE(review): this is set after utils.process_utils has been imported. If that
# import already initialised CUDA, the setting may not take effect — confirm,
# or move it above the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# On the first run of a browser session, start from empty output folders.
# NOTE: both paths are relative to the working directory, so they are shared
# by every session of this app rather than being per-user.
if "initialized" not in st.session_state:
    for out_folder in ("./out_dir", "./out_dicoms"):
        if os.path.exists(out_folder):
            shutil.rmtree(out_folder)
        os.makedirs(out_folder, exist_ok=True)
    st.session_state.initialized = True

# --- Session state defaults ---
# Seed each key only when absent, so values survive Streamlit reruns.
for _state_key, _state_default in {
    "volume": None,           # loaded cine data (set by "Process Data")
    "data_processed": False,  # set once all models ran and outputs were saved
    "gif_ready": False,       # set after a GIF has been generated
    "dicom_create": False,    # when True, the DICOM export section runs
    "want_gif": True,         # backs the GIF-preview toggle (key="want_gif")
    "num_phases": None,       # phase count returned by load_cine_any
}.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Title ---
st.title("3D Cine")

# --- Upload ---
st.header("Data Upload")
uploaded_zip = st.file_uploader("Upload ZIP file of MRI folders", type="zip")
_ = st.toggle("Generate a GIF preview after processing", key="want_gif")

# Number of slices per phase the pipeline keeps. Stacks with more slices are
# cropped to their last N_SLICES slices before processing.
N_SLICES = 28


def _process_volume(sag_vols, time_steps, want_gif):
    """Run the model pipeline on a loaded cine and save the results to ./out_dir.

    Stages: cropping/resampling, contrast correction (debanding), respiratory
    correction and super-resolution. For each phase ``i`` the final volume is
    saved as ``./out_dir/3D_cine_{i}.npy``. When ``want_gif`` is set, the raw,
    debanded and respiratory-corrected stages are also saved, upsampled 4x
    along axis 1, for the GIF previews.

    Args:
        sag_vols: numpy array indexed ``[phase, slice, row, col]`` (the loops
            below index it that way), already cropped to N_SLICES slices.
        time_steps: number of phases to process.
        want_gif: whether to also save the intermediate stages.
    """
    with st.spinner("Cropping..."):
        # Downsample 512-pixel inputs by half along the last two axes.
        if sag_vols.shape[2] == 512:
            sag_vols = skimage.transform.rescale(
                sag_vols, (1, 1, 0.5, 0.5), order=3, anti_aliasing=True
            )
        sag_vols_cropped = []
        for j in range(time_steps):
            sag_cropped = np.dstack(
                [resize(sag_vols[j, i, :, :], 256, 128) for i in range(sag_vols.shape[1])]
            )
            # dstack puts the slice index last; the two swaps reorder to
            # (slice, row, col), assuming resize returns 2-D arrays.
            sag_cropped = np.swapaxes(sag_cropped, 0, 1)
            sag_cropped = np.swapaxes(sag_cropped, 0, 2)
            sag_vols_cropped.append(sag_cropped)
        sag_vols_cropped = norm(sag_vols_cropped)
        if want_gif:
            raw_us = skimage.transform.rescale(sag_vols_cropped, (1, 4, 1, 1), order=2)

    with st.spinner("Contrast correction..."):
        debanded = norm(apply_debanding_model(sag_vols_cropped, frames=time_steps))
        if want_gif:
            debanded_us = skimage.transform.rescale(
                debanded[:, 0, ..., 0], (1, 4, 1, 1), order=2
            )

    with st.spinner("Respiratory correction..."):
        _def_fields, resp_cor = apply_resp_model_28(debanded, frames=time_steps)
        resp_cor = norm(resp_cor)
        if want_gif:
            resp_cor_us = skimage.transform.rescale(
                resp_cor[:, 0, ..., 0], (1, 4, 1, 1), order=2
            )

    with st.spinner("Super-resolution..."):
        super_resed_E2E = norm(apply_SR_model(resp_cor, frames=time_steps))[:, 0, ..., 0]

    os.makedirs('./out_dir/', exist_ok=True)
    for i in range(time_steps):
        np.save(f'./out_dir/3D_cine_{i}.npy', super_resed_E2E[i])
        if want_gif:
            np.save(f'./out_dir/resp_cor_{i}.npy', resp_cor_us[i])
            np.save(f'./out_dir/debanded_{i}.npy', debanded_us[i])
            np.save(f'./out_dir/raw_{i}.npy', raw_us[i])


if uploaded_zip is not None and st.button("Process Data"):
    sag_vols = None
    # TemporaryDirectory deletes the extracted upload afterwards. Unlike
    # mkdtemp(), it leaves no folder behind per run.
    with st.spinner("Processing ZIP..."), tempfile.TemporaryDirectory() as temp_dir:
        zip_path = os.path.join(temp_dir, "upload.zip")
        with open(zip_path, "wb") as f:
            # getvalue() returns the whole upload regardless of the buffer's
            # current position, unlike read().
            f.write(uploaded_zip.getvalue())
        extract_zip(zip_path, temp_dir)
        st.session_state.volume, st.session_state.num_phases = load_cine_any(
            temp_dir, number_of_scans=None
        )
        volume = st.session_state.volume
        if volume is not None and len(volume) != 0:
            # Materialise while the extracted files still exist.
            sag_vols = np.array(volume)

    if sag_vols is None:
        st.error("Failed to load volume.")
    elif sag_vols.shape[1] < N_SLICES:
        # Cropping with a negative offset would silently keep only the last
        # few slices, so refuse the input instead.
        st.error(
            f"Expected at least {N_SLICES} slices per phase, got {sag_vols.shape[1]}."
        )
    else:
        extra = sag_vols.shape[1] - N_SLICES
        _process_volume(
            sag_vols[:, extra:], st.session_state.num_phases, st.session_state.want_gif
        )
        st.success("✅ All models complete and data saved!")
        st.session_state.data_processed = True
        st.session_state.gif_ready = False  # Reset gif status
        # New outputs restart the export flow. Go straight to DICOM export only
        # when no GIF preview is wanted, and drop any archive cached from a
        # previous dataset.
        st.session_state.dicom_create = not st.session_state.want_gif
        for stale_key in ("dicom_zip", "dicom_zip_name"):
            if stale_key in st.session_state:
                del st.session_state[stale_key]
# --- GIF Generation Section ---
if st.session_state.want_gif:
    num_phases = st.session_state.num_phases
    if st.session_state.data_processed:
        st.header("GIF Generator")
        axis_option = st.radio(
            "Select axis for slicing",
            options=["Axial", "Coronal"],
            index=0,
            key="axis_selector",
        )
        axis_mapping = {"Axial": 1, "Coronal": 2}
        axis = axis_mapping[axis_option]
        # NOTE(review): the 0-256 range is not derived from the saved arrays'
        # size along `axis`. Confirm it matches what make_gif can slice.
        slice_index = st.number_input("Select slice number", 0, 256, 60, 1)
        # number_input raises when its default lies outside [min, max], so clamp
        # the phase-count default into 1-100 and tolerate a missing count.
        default_fps = min(max(int(num_phases or 1), 1), 100)
        framerate = st.number_input("Framerate", 1, 100, default_fps, 1)
        if st.button("Generate and Show GIF"):
            st.session_state.gif_path = make_gif(
                './out_dir/', timepoints=num_phases, axis=axis,
                slice=slice_index, frame_rate=framerate,
            )
            st.session_state.gif_ready = True
        # Redisplay the latest GIF on every rerun, not only in the run where
        # the button was clicked. This keeps it visible while the user picks
        # the next step.
        if st.session_state.gif_ready and "gif_path" in st.session_state:
            st.image(
                st.session_state.gif_path,
                caption="Generated GIF",
                use_container_width=True,
            )

# --- Next Steps Section ---
if st.session_state.gif_ready:
    next_action = st.radio(
        "What would you like to do next?",
        options=["Generate another GIF", "Proceed to DICOM export"],
        index=0,
    )
    if next_action == "Generate another GIF":
        st.info("Adjust your settings above and click the button again.")
    elif next_action == "Proceed to DICOM export":
        st.session_state.dicom_create = True
# --- DICOM Export Section ---
import time  # archive timestamp; `time` is not imported explicitly elsewhere

if st.session_state.dicom_create:
    num_phases = st.session_state.num_phases
    st.header("DICOM Export")
    src_dir = "./out_dicoms"
    # Streamlit re-executes the script on every interaction, including the
    # download click. So export and zip only when no archive is cached yet.
    if "dicom_zip" not in st.session_state:
        # Start from an empty folder so files left by an earlier export are not
        # packaged. Assumes to_dicom writes into ./out_dicoms — TODO confirm.
        shutil.rmtree(src_dir, ignore_errors=True)
        os.makedirs(src_dir, exist_ok=True)
        to_dicom(num_phases, patient_number=0)
        if os.path.isdir(src_dir) and os.listdir(src_dir):
            st.session_state.dicom_zip = zip_dir_to_memory(src_dir)
            st.session_state.dicom_zip_name = f"dicoms_{time.strftime('%Y%m%d-%H%M%S')}.zip"
    if "dicom_zip" in st.session_state:
        st.success("✅ Created DICOMs.")
        st.download_button(
            label="⬇️ Download DICOMs (ZIP)",
            data=st.session_state.dicom_zip,
            file_name=st.session_state.dicom_zip_name,
            mime="application/zip",
            use_container_width=True,
        )
    else:
        st.warning("No DICOMs found to package.")