import os
import sys
import subprocess
import importlib

# Force pure-Python protobuf to sidestep the C++ ABI mismatch.
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")


def _ensure_pb319():
    """Ensure an importable protobuf older than 3.20 (needed for TF 2.9.1).

    If ``google.protobuf`` cannot be imported, its version cannot be parsed, or
    it is >= 3.20, pip force-reinstalls ``protobuf==3.19.6`` into the running
    interpreter. Protobuf modules already imported by the version probe are
    then evicted from ``sys.modules`` so the next import loads the reinstalled
    package instead of the stale in-memory copy.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
    """
    try:
        import google.protobuf as _pb

        version = _pb.__version__
        parts = [int(p) for p in version.split(".")[:2] if p.isdigit()]
        if len(parts) < 2 or tuple(parts) >= (3, 20):
            raise ImportError(f"protobuf {version} is not < 3.20")
    except Exception:  # missing, unparsable or too new -> reinstall
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install", "--no-cache-dir",
                "--upgrade", "--force-reinstall", "protobuf==3.19.6",
            ],
            check=True,
        )
        importlib.invalidate_caches()
        # invalidate_caches() does not unload modules: without this, the
        # google.protobuf imported by the probe above stays cached at the old
        # version and is what TensorFlow would end up using.
        for mod_name in list(sys.modules):
            if mod_name == "google.protobuf" or mod_name.startswith("google.protobuf."):
                del sys.modules[mod_name]


_ensure_pb319()
# --- End shim ---

# Pin to GPU 0. Set before importing the processing utilities: CUDA reads this
# variable when it initialises, and the utils may import TensorFlow.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Presumably supplies the pipeline helpers used below (extract_zip,
# load_cine_any, resize, norm, apply_*_model*, make_gif, to_dicom,
# zip_dir_to_memory) — none of them are defined in this file.
from utils.process_utils import *  # noqa: E402,F403

# Explicit imports come after the star import so these bindings win.
import shutil  # noqa: E402
import tempfile  # noqa: E402
import time  # noqa: E402
import zipfile  # noqa: E402,F401  (kept from the original; currently unused)

import numpy as np  # noqa: E402
import skimage.transform  # noqa: E402  (plain `import skimage` may not load .transform)
import streamlit as st  # noqa: E402
from huggingface_hub import snapshot_download  # noqa: E402

MODEL_REPO = os.getenv("MODEL_REPO")  # e.g. "username/3d-cine-models"
MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "")  # optional subfolder in the repo
PERSIST_BASE = os.getenv("PERSIST_BASE", "/data")  # HF Spaces persistent storage

OUT_DIR = "./out_dir"  # per-phase .npy outputs read by make_gif
DICOM_DIR = "./out_dicoms"  # directory zipped for the DICOM download
N_SLICES = 28  # slices kept per phase; the pipeline feeds apply_resp_model_28


def get_models_base():
    """Return the directory holding the model weights, creating it if needed.

    When ``MODEL_REPO`` is set, the repo is fetched with ``snapshot_download``
    into ``<PERSIST_BASE>/hf_models``. Persistent storage is used to avoid
    re-downloads. The optional ``MODEL_SUBDIR`` inside it is returned.
    Otherwise a local folder under ``PERSIST_BASE`` is used.
    """
    os.makedirs(PERSIST_BASE, exist_ok=True)
    if MODEL_REPO:
        repo_dir = snapshot_download(
            repo_id=MODEL_REPO,
            repo_type="model",
            local_dir=os.path.join(PERSIST_BASE, "hf_models"),
            local_dir_use_symlinks=False,
        )
        base = os.path.join(repo_dir, MODEL_SUBDIR) if MODEL_SUBDIR else repo_dir
    else:
        # The original read MODELS_BASE here before this module assigns it,
        # which only works if the star import provides MODELS_BASE. Fall back
        # to "models" instead of raising NameError (TODO confirm folder name).
        local_name = globals().get("MODELS_BASE", "models")
        base = os.path.join(PERSIST_BASE, local_name)
    os.makedirs(base, exist_ok=True)
    return base


MODELS_BASE = get_models_base()


def _reset_dir(path):
    """Delete *path* if it exists, then recreate it empty."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _dir_has_entries(path):
    """Return True if *path* is a directory containing at least one entry."""
    if not os.path.isdir(path):
        return False
    with os.scandir(path) as entries:
        return any(entries)


def _start_new_run():
    """Clear on-disk outputs and per-dataset session state from a previous upload."""
    for d in (OUT_DIR, DICOM_DIR):
        _reset_dir(d)
    st.session_state.data_processed = False
    st.session_state.gif_ready = False
    st.session_state.gif_inputs_saved = False
    st.session_state.dicom_create = False
    st.session_state.dicoms_written = False
    # The cached ZIP belongs to the previous dataset; force a rebuild.
    st.session_state.pop("dicom_zip", None)
    st.session_state.pop("dicom_zip_name", None)


def _crop_and_resize(sag_vols, time_steps):
    """Crop to the last N_SLICES slices, resample each slice, and normalise.

    Assumes *sag_vols* is a 4-D array indexed (phase, slice, row, col) — TODO
    confirm against load_cine_any. If axis 2 has length 512, axes 2 and 3 are
    first rescaled by 0.5. Every slice then goes through ``resize(..., 256, 128)``.
    The per-phase stack is reordered so the slice axis comes first. Returns
    ``norm`` applied to the list of per-phase stacks.
    """
    if sag_vols.shape[1] > N_SLICES:
        sag_vols = sag_vols[:, sag_vols.shape[1] - N_SLICES:, :, :]
    if sag_vols.shape[2] == 512:
        sag_vols = skimage.transform.rescale(
            sag_vols, (1, 1, 0.5, 0.5), order=3, anti_aliasing=True
        )
    cropped = []
    for j in range(time_steps):
        slices = [resize(sag_vols[j, i, :, :], 256, 128) for i in range(sag_vols.shape[1])]
        stacked = np.dstack(slices)  # slice index on the last axis
        stacked = np.swapaxes(stacked, 0, 1)
        stacked = np.swapaxes(stacked, 0, 2)  # ...now on axis 0
        cropped.append(stacked)
    return norm(cropped)


def _run_pipeline(sag_vols, time_steps, want_gif):
    """Run crop -> debanding -> respiratory -> super-resolution and save to OUT_DIR.

    For each phase ``i`` this writes ``3D_cine_{i}.npy``. When *want_gif* is
    true it also writes ``raw_``, ``debanded_`` and ``resp_cor_`` previews, each
    upsampled 4x along axis 1, for make_gif.
    """
    with st.spinner("Cropping..."):
        sag_vols_cropped = _crop_and_resize(sag_vols, time_steps)
        if want_gif:
            raw_us = skimage.transform.rescale(sag_vols_cropped, (1, 4, 1, 1), order=2)

    with st.spinner("Contrast correction..."):
        debanded = norm(apply_debanding_model(sag_vols_cropped, frames=time_steps))
        if want_gif:
            debanded_us = skimage.transform.rescale(
                debanded[:, 0, ..., 0], (1, 4, 1, 1), order=2
            )

    with st.spinner("Respiratory correction..."):
        _def_fields, resp_cor = apply_resp_model_28(debanded, frames=time_steps)
        resp_cor = norm(resp_cor)
        if want_gif:
            resp_cor_us = skimage.transform.rescale(
                resp_cor[:, 0, ..., 0], (1, 4, 1, 1), order=2
            )

    with st.spinner("Super-resolution..."):
        super_resed_E2E = norm(apply_SR_model(resp_cor, frames=time_steps))[:, 0, ..., 0]

    os.makedirs(OUT_DIR, exist_ok=True)
    for i in range(time_steps):
        np.save(os.path.join(OUT_DIR, f"3D_cine_{i}.npy"), super_resed_E2E[i])
        if want_gif:
            np.save(os.path.join(OUT_DIR, f"resp_cor_{i}.npy"), resp_cor_us[i])
            np.save(os.path.join(OUT_DIR, f"debanded_{i}.npy"), debanded_us[i])
            np.save(os.path.join(OUT_DIR, f"raw_{i}.npy"), raw_us[i])


# NOTE(review): OUT_DIR and DICOM_DIR are fixed relative paths shared by every
# browser session served by this process, so concurrent users can overwrite
# each other's results.
if "initialized" not in st.session_state:
    for d in (OUT_DIR, DICOM_DIR):
        _reset_dir(d)
    st.session_state.initialized = True

# --- Session state defaults ---
_SESSION_DEFAULTS = {
    "volume": None,
    "data_processed": False,
    "gif_ready": False,
    "dicom_create": False,
    "want_gif": True,
    "num_phases": None,
    "gif_inputs_saved": False,  # GIF preview .npy files exist for the current data
    "dicoms_written": False,  # to_dicom already ran for the current data
}
for key, value in _SESSION_DEFAULTS.items():
    if key not in st.session_state:
        st.session_state[key] = value

# --- Title ---
st.title("3D Cine")

# --- Upload ---
st.header("Data Upload")
uploaded_zip = st.file_uploader("Upload ZIP file of MRI folders", type="zip")
_ = st.toggle("Generate a GIF preview after processing", key="want_gif")

if uploaded_zip is not None and st.button("Process Data"):
    with st.spinner("Processing ZIP..."):
        # A self-deleting temp dir keeps uploaded scans from piling up on disk.
        # Assumes load_cine_any returns in-memory arrays, not lazy file handles.
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "upload.zip")
            with open(zip_path, "wb") as f:
                # getvalue() returns the whole upload regardless of read position.
                f.write(uploaded_zip.getvalue())
            extract_zip(zip_path, temp_dir)
            st.session_state.volume, st.session_state.num_phases = load_cine_any(
                temp_dir, number_of_scans=None
            )

    volume = st.session_state.volume
    num_phases = st.session_state.num_phases
    if volume is None or len(volume) == 0:
        st.error("Failed to load volume.")
    else:
        sag_vols = np.array(volume)
        if sag_vols.shape[1] < N_SLICES:
            # The original silently sliced with a negative offset here, which
            # kept even fewer slices.
            st.error(
                f"Expected at least {N_SLICES} slices per phase, "
                f"got {sag_vols.shape[1]}."
            )
        else:
            want_gif = st.session_state.want_gif
            _start_new_run()
            _run_pipeline(sag_vols, num_phases, want_gif)
            st.success("✅ All models complete and data saved!")
            st.session_state.data_processed = True
            st.session_state.gif_inputs_saved = want_gif
            st.session_state.gif_ready = False  # Reset gif status
            # Without a GIF preview, go straight to DICOM export.
            st.session_state.dicom_create = not want_gif

# --- GIF Generation Section ---
if st.session_state.want_gif and st.session_state.data_processed:
    num_phases = st.session_state.num_phases
    st.header("GIF Generator")
    if not st.session_state.gif_inputs_saved:
        # The toggle was switched on after processing, so the preview files
        # make_gif reads were never written.
        st.info("GIF preview data was not saved for this upload. "
                "Re-process with the GIF option enabled to preview.")
    else:
        axis_option = st.radio(
            "Select axis for slicing",
            options=["Axial", "Coronal"],
            index=0,
            key="axis_selector",
        )
        axis = {"Axial": 1, "Coronal": 2}[axis_option]
        # NOTE(review): the upper bound 256 is hard-coded; confirm it matches
        # the saved volume's extent along the chosen axis.
        slice_index = st.number_input("Select slice number", 0, 256, 60, 1)
        # number_input rejects a default outside [min, max], so clamp it.
        default_fps = min(max(int(num_phases), 1), 100)
        framerate = st.number_input("Framerate", 1, 100, default_fps, 1)
        if st.button("Generate and Show GIF"):
            gif_path = make_gif(
                f"{OUT_DIR}/",
                timepoints=num_phases,
                axis=axis,
                slice=slice_index,
                frame_rate=framerate,
            )
            st.image(gif_path, caption="Generated GIF", use_container_width=True)
            st.session_state.gif_ready = True

# --- Next Steps Section ---
if st.session_state.gif_ready:
    next_action = st.radio(
        "What would you like to do next?",
        options=["Generate another GIF", "Proceed to DICOM export"],
        index=0,
    )
    if next_action == "Generate another GIF":
        st.info("Adjust your settings above and click the button again.")
    elif next_action == "Proceed to DICOM export":
        st.session_state.dicom_create = True

# --- DICOM Export Section ---
if st.session_state.dicom_create:
    num_phases = st.session_state.num_phases
    st.header("DICOM Export")
    # Export once per processed dataset; previously to_dicom ran on every
    # rerun, including the one triggered by clicking the download button.
    if not st.session_state.dicoms_written:
        to_dicom(num_phases, patient_number=0)
        st.session_state.dicoms_written = True
    st.success("✅ Created DICOMs.")

    src_dir = DICOM_DIR
    # Build the zip once per dataset so we don't recompress on every rerun
    # (_start_new_run drops it when new data is processed).
    if "dicom_zip" not in st.session_state:
        if _dir_has_entries(src_dir):
            st.session_state.dicom_zip = zip_dir_to_memory(src_dir)
            st.session_state.dicom_zip_name = f"dicoms_{time.strftime('%Y%m%d-%H%M%S')}.zip"
        else:
            st.warning("No DICOMs found to package.")

    if "dicom_zip" in st.session_state:
        st.download_button(
            label="⬇️ Download DICOMs (ZIP)",
            data=st.session_state.dicom_zip,
            file_name=st.session_state.dicom_zip_name,
            mime="application/zip",
            use_container_width=True,
        )