Spaces:
Running
Running
| import os, sys, subprocess, importlib | |
# Select the pure-Python protobuf backend (unless the environment already
# chose one) to sidestep the C++ extension ABI mismatch.
if "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION" not in os.environ:
    os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Ensure protobuf < 3.20 for TF 2.9.1
def _ensure_pb319():
    """Make sure an importable protobuf release older than 3.20 is installed.

    If ``google.protobuf`` is missing, fails to import, has an unparsable
    version, or is 3.20 or newer, ``protobuf==3.19.6`` is force-reinstalled
    with pip. Already-imported protobuf modules are then evicted from
    ``sys.modules`` so later imports load the freshly installed files.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
    """
    try:
        import google.protobuf as _pb
        major, minor = (int(p) for p in _pb.__version__.split(".")[:2])
        if (major, minor) < (3, 20):
            return
    except Exception:
        # Deliberately broad: any failure to import or parse the version
        # means we cannot trust the install, so fall through and reinstall.
        pass
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--no-cache-dir",
         "--upgrade", "--force-reinstall", "protobuf==3.19.6"],
        check=True,
    )
    # invalidate_caches() only refreshes the import finders. Modules already
    # in sys.modules (e.g. the too-new protobuf imported by the check above)
    # would keep being served and get mixed with the reinstalled files, so
    # evict them to force a clean re-import.
    stale = [
        name for name in sys.modules
        if name == "google.protobuf" or name.startswith("google.protobuf.")
    ]
    for name in stale:
        del sys.modules[name]
    importlib.invalidate_caches()


# Runs at import time, before TensorFlow is pulled in by the imports below.
_ensure_pb319()
# --- End shim ---
| from utils.process_utils import * | |
| import skimage | |
| import streamlit as st | |
| import zipfile | |
| import tempfile | |
| import os | |
| import shutil | |
| import os | |
| from huggingface_hub import snapshot_download | |
MODEL_REPO = os.getenv("MODEL_REPO")  # e.g. "username/3d-cine-models"
MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "")  # optional subfolder in the repo
PERSIST_BASE = os.getenv("PERSIST_BASE", "/data")  # HF Spaces persistent storage


def get_models_base():
    """Return the directory that holds the model files, creating it if needed.

    If MODEL_REPO is set, the repo is fetched with ``snapshot_download`` into
    ``PERSIST_BASE/hf_models``, kept in persistent storage to avoid
    re-downloads. The returned path is that folder, or its MODEL_SUBDIR when
    one is configured. Otherwise a folder under PERSIST_BASE is used, named
    after any MODELS_BASE already bound at module level.
    """
    os.makedirs(PERSIST_BASE, exist_ok=True)
    if MODEL_REPO:
        repo_dir = snapshot_download(
            repo_id=MODEL_REPO,
            repo_type="model",
            local_dir=os.path.join(PERSIST_BASE, "hf_models"),
            # NOTE(review): newer huggingface_hub releases deprecate/ignore
            # this flag; kept for older versions — confirm the pinned version.
            local_dir_use_symlinks=False,
        )
        base = os.path.join(repo_dir, MODEL_SUBDIR) if MODEL_SUBDIR else repo_dir
    else:
        # MODELS_BASE is assigned in this module only *after* this function
        # returns, so it exists here only if the star import from
        # utils.process_utils supplied it; reading it unguarded raised
        # NameError otherwise. "models" is an assumed fallback — TODO confirm.
        try:
            local_name = MODELS_BASE
        except NameError:
            local_name = "models"
        base = os.path.join(PERSIST_BASE, local_name)
    os.makedirs(base, exist_ok=True)
    return base


# NOTE(review): this rebinds MODELS_BASE in this module only; it does not
# change utils.process_utils.MODELS_BASE, so model loaders defined there may
# not see this path — confirm how they locate their weights.
MODELS_BASE = get_models_base()
# Make only the first GPU visible to CUDA.
# NOTE(review): set after utils.process_utils was star-imported above; it only
# takes effect if that import did not already initialize CUDA — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def _reset_output_dir(path):
    """Delete *path* if it exists, then recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


# First run of this browser session: start from empty output folders.
# NOTE(review): these relative paths are shared by every session served by the
# process, so a new session also wipes files another session may be using.
if "initialized" not in st.session_state:
    _reset_output_dir("./out_dir")
    _reset_output_dir("./out_dicoms")
    st.session_state["initialized"] = True
# --- Session state defaults ---
# Keys seeded once per session; existing values survive reruns untouched.
_SESSION_DEFAULTS = {
    "volume": None,           # loaded cine data from the last upload
    "data_processed": False,  # True once the model pipeline has run
    "gif_ready": False,       # True once a GIF preview has been shown
    "dicom_create": False,    # True when the DICOM export section is unlocked
    "want_gif": True,         # backing value of the "want_gif" toggle widget
    "num_phases": None,       # phase count returned by load_cine_any
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Title and upload widgets ---
st.title("3D Cine")

st.header("Data Upload")
uploaded_zip = st.file_uploader("Upload ZIP file of MRI folders", type="zip")
# The toggle's state is read back later through st.session_state.want_gif.
st.toggle("Generate a GIF preview after processing", key="want_gif")
# --- Processing pipeline ---
# Number of slices per phase fed to the models (apply_resp_model_28 below).
N_SLICES = 28


def _upsample_slices(vol):
    """Return *vol* upsampled 4x along axis 1 (order-2 spline) for GIF previews."""
    return skimage.transform.rescale(vol, (1, 4, 1, 1), order=2)


def _crop_and_normalize(sag_vols, time_steps):
    """Keep the last N_SLICES entries of axis 1, resize every 2-D slice, normalize.

    Assumes *sag_vols* is 4-D, indexed (phase, slice, row, col), with at least
    N_SLICES slices — the caller enforces the slice count.
    """
    # Drop leading slices beyond N_SLICES (no-op when there are exactly N_SLICES).
    sag_vols = sag_vols[:, sag_vols.shape[1] - N_SLICES:, :, :]
    if sag_vols.shape[2] == 512:
        # Downscale axes 2 and 3 by half for 512-pixel inputs.
        sag_vols = skimage.transform.rescale(
            sag_vols, (1, 1, 0.5, 0.5), order=3, anti_aliasing=True
        )
    phases = []
    for j in range(time_steps):
        slices = [resize(sag_vols[j, i, :, :], 256, 128) for i in range(sag_vols.shape[1])]
        # dstack stacks along the last axis; the two swaps move it to the front.
        phase = np.dstack(slices)
        phase = np.swapaxes(phase, 0, 1)
        phase = np.swapaxes(phase, 0, 2)
        phases.append(phase)
    return norm(phases)


def _run_pipeline(sag_vols, time_steps, want_gif):
    """Run every model stage on *sag_vols* and save per-phase .npy files.

    Always writes ``./out_dir/3D_cine_{i}.npy``; when *want_gif* is true it
    also writes 4x slice-upsampled ``resp_cor_``, ``debanded_`` and ``raw_``
    arrays used by the GIF generator.
    """
    with st.spinner("Cropping..."):
        sag_vols_cropped = _crop_and_normalize(sag_vols, time_steps)
        if want_gif:
            raw_us = _upsample_slices(sag_vols_cropped)
    with st.spinner("Contrast correction..."):
        debanded = norm(apply_debanding_model(sag_vols_cropped, frames=time_steps))
        debanded_us = debanded[:, 0, ..., 0]
        if want_gif:
            debanded_us = _upsample_slices(debanded_us)
    with st.spinner("Respiratory correction..."):
        # The first return value (def_fields) is not used in this file.
        _def_fields, resp_cor = apply_resp_model_28(debanded, frames=time_steps)
        resp_cor = norm(resp_cor)
        resp_cor_us = resp_cor[:, 0, ..., 0]
        if want_gif:
            resp_cor_us = _upsample_slices(resp_cor_us)
    with st.spinner("Super-resolution..."):
        super_resed_E2E = norm(apply_SR_model(resp_cor, frames=time_steps))[:, 0, ..., 0]
    os.makedirs('./out_dir/', exist_ok=True)
    for i in range(time_steps):
        np.save(f'./out_dir/3D_cine_{i}.npy', super_resed_E2E[i])
        if want_gif:
            np.save(f'./out_dir/resp_cor_{i}.npy', resp_cor_us[i])
            np.save(f'./out_dir/debanded_{i}.npy', debanded_us[i])
            np.save(f'./out_dir/raw_{i}.npy', raw_us[i])


if uploaded_zip is not None and st.button("Process Data"):
    want_gif = st.session_state.want_gif
    with st.spinner("Processing ZIP..."):
        # TemporaryDirectory deletes the uploaded archive and the extracted
        # files once loading finishes; mkdtemp() left every upload on disk.
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, "upload.zip")
            with open(zip_path, "wb") as f:
                # getvalue() returns the full upload regardless of the buffer
                # position, unlike read() on an already-consumed buffer.
                f.write(uploaded_zip.getvalue())
            extract_zip(zip_path, temp_dir)
            st.session_state.volume, st.session_state.num_phases = load_cine_any(
                temp_dir, number_of_scans=None
            )
    num_phases = st.session_state.num_phases
    volume = st.session_state.volume

    if volume is None or len(volume) == 0:
        st.error("Failed to load volume.")
    else:
        sag_vols = np.array(volume)
        if sag_vols.shape[1] < N_SLICES:
            # Previously a negative offset silently kept only the last
            # (N_SLICES - n) slices and fed that to the models.
            st.error(
                f"Expected at least {N_SLICES} slices per phase, "
                f"found {sag_vols.shape[1]}."
            )
        else:
            _run_pipeline(sag_vols, num_phases, want_gif)
            st.success("✅ All models complete and data saved!")
            st.session_state.data_processed = True
            st.session_state.gif_ready = False  # Reset gif status
            # New data invalidates any DICOM archive cached from an earlier run.
            st.session_state.pop("dicom_zip", None)
            st.session_state.pop("dicom_zip_name", None)
            # Without a GIF step, go straight to export; otherwise wait for
            # the user to choose export after previewing.
            st.session_state.dicom_create = not want_gif
# --- GIF Generation Section ---
# Shown only when the GIF option is on and the pipeline has produced output.
if st.session_state.want_gif:
    num_phases = st.session_state.num_phases
    if st.session_state.data_processed:
        st.header("GIF Generator")
        axis_option = st.radio(
            "Select axis for slicing",
            options=["Axial", "Coronal"],
            index=0,
            key="axis_selector",
        )
        axis_mapping = {"Axial": 1, "Coronal": 2}
        axis = axis_mapping[axis_option]
        slice_index = st.number_input("Select slice number", 0, 256, 60, 1)
        # st.number_input raises when its default lies outside [min, max];
        # clamp the phase count into 1..100 and pass a plain int so it matches
        # the other integer arguments.
        default_framerate = min(max(int(num_phases), 1), 100)
        framerate = st.number_input("Framerate", 1, 100, default_framerate, 1)
        if st.button("Generate and Show GIF"):
            gif_path = make_gif(
                './out_dir/',
                timepoints=num_phases,
                axis=axis,
                slice=slice_index,
                frame_rate=framerate,
            )
            st.image(gif_path, caption="Generated GIF", use_container_width=True)
            st.session_state.gif_ready = True
# --- Next Steps Section ---
# After a GIF has been shown, either keep iterating on GIF settings or unlock
# the DICOM export section.
if st.session_state.gif_ready:
    next_step_choices = ["Generate another GIF", "Proceed to DICOM export"]
    chosen_step = st.radio(
        "What would you like to do next?",
        options=next_step_choices,
        index=0,
    )
    if chosen_step == next_step_choices[1]:
        st.session_state.dicom_create = True
    elif chosen_step == next_step_choices[0]:
        st.info("Adjust your settings above and click the button again.")
# --- DICOM Export Section ---
# Explicit import: `time` was previously only reachable via the star import.
import time

if st.session_state.dicom_create:
    num_phases = st.session_state.num_phases
    st.header("DICOM Export")
    src_dir = "./out_dicoms"
    # Export and zip once per processed dataset. Streamlit reruns this script
    # on every widget interaction, and to_dicom() used to run on each rerun
    # even though the ZIP was already cached in the session.
    if "dicom_zip" not in st.session_state:
        # Start from an empty folder so files from an earlier export cannot
        # end up in this archive.
        shutil.rmtree(src_dir, ignore_errors=True)
        os.makedirs(src_dir, exist_ok=True)
        with st.spinner("Creating DICOMs..."):
            to_dicom(num_phases, patient_number=0)
        has_dicoms = False
        if os.path.isdir(src_dir):
            # Context manager closes the directory iterator promptly.
            with os.scandir(src_dir) as entries:
                has_dicoms = any(entries)
        if has_dicoms:
            st.session_state.dicom_zip = zip_dir_to_memory(src_dir)
            st.session_state.dicom_zip_name = f"dicoms_{time.strftime('%Y%m%d-%H%M%S')}.zip"
        else:
            st.warning("No DICOMs found to package.")
    if "dicom_zip" in st.session_state:
        st.success("✅ Created DICOMs.")
        st.download_button(
            label="⬇️ Download DICOMs (ZIP)",
            data=st.session_state.dicom_zip,
            file_name=st.session_state.dicom_zip_name,
            mime="application/zip",
            use_container_width=True,
        )