Spaces:
Sleeping
Sleeping
Upload 16 files
Browse files- .gitattributes +1 -0
- .streamlit/config.toml +3 -0
- Dockerfile +29 -17
- README.md +16 -15
- app.py +198 -0
- devcontainer.json +59 -0
- docker-compose.yml +18 -0
- requirements.txt +13 -3
- utils/__pycache__/custom_unet_code.cpython-38.pyc +0 -0
- utils/__pycache__/layer_util.cpython-38.pyc +0 -0
- utils/__pycache__/process_utils.cpython-38.pyc +0 -0
- utils/__pycache__/unet3plusnew.cpython-38.pyc +0 -0
- utils/custom_unet_code.py +111 -0
- utils/dicom_headerfile.dcm +3 -0
- utils/layer_util.py +89 -0
- utils/process_utils.py +665 -0
- utils/unet3plusnew.py +186 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
utils/dicom_headerfile.dcm filter=lfs diff=lfs merge=lfs -text
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[server]
|
| 2 |
+
maxUploadSize = 2048 # MB
|
| 3 |
+
maxMessageSize = 2048 # MB
|
Dockerfile
CHANGED
|
@@ -1,21 +1,33 @@
|
|
| 1 |
-
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
RUN apt-get update && apt-get install -y \
|
| 6 |
-
build-essential \
|
| 7 |
-
curl \
|
| 8 |
-
software-properties-common \
|
| 9 |
-
git \
|
| 10 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
RUN
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
# See here for image contents: https://github.com/microsoft/vscode-dev-containers/tree/v0.245.0/containers/python-3/.devcontainer/base.Dockerfile
|
| 2 |
|
| 3 |
+
# [Choice] Python version (use -bullseye variants on local arm64/Apple Silicon): 3, 3.10, 3.9, 3.8, 3.7, 3.6, 3-bullseye, 3.10-bullseye, 3.9-bullseye, 3.8-bullseye, 3.7-bullseye, 3.6-bullseye, 3-buster, 3.10-buster, 3.9-buster, 3.8-buster, 3.7-buster, 3.6-buster
|
| 4 |
+
FROM tensorflow/tensorflow:2.9.1-gpu
|
| 5 |
+
RUN apt-get install libopenexr-dev -y
|
| 6 |
+
RUN pip install tensorflow-mri
|
| 7 |
+
RUN pip install tqdm
|
| 8 |
+
RUN pip install h5py
|
| 9 |
+
RUN pip install tensorflow-addons
|
| 10 |
+
RUN pip install scikit-learn
|
| 11 |
+
RUN pip install scikit-image
|
| 12 |
+
RUN pip install neptune-client
|
| 13 |
+
RUN pip install matplotlib
|
| 14 |
+
RUN pip install scipy
|
| 15 |
+
RUN pip install pydicom
|
| 16 |
+
RUN pip install streamlit
|
| 17 |
+
RUN pip install protobuf==3.20.*
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
# Create non-root user.
|
| 21 |
+
ARG USERNAME=vscode
|
| 22 |
+
ARG USER_UID=1003
|
| 23 |
+
ARG USER_GID=$USER_UID
|
| 24 |
|
| 25 |
+
RUN groupadd --gid $USER_GID $USERNAME && \
|
| 26 |
+
useradd --uid $USER_UID --gid $USER_GID -m $USERNAME && \
|
| 27 |
+
# Add user to sudoers.
|
| 28 |
+
apt-get update && \
|
| 29 |
+
apt-get install -y sudo && \
|
| 30 |
+
echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME && \
|
| 31 |
+
chmod 0440 /etc/sudoers.d/$USERNAME && \
|
| 32 |
+
# Change default shell to bash.
|
| 33 |
+
usermod --shell /bin/bash $USERNAME
|
README.md
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk:
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
- streamlit
|
| 10 |
-
pinned: false
|
| 11 |
-
short_description: Streamlit template space
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
forums](https://discuss.streamlit.io).
|
|
|
|
| 1 |
+
|
| 2 |
---
|
| 3 |
+
title: 3D Cine MRI (Streamlit)
|
| 4 |
+
emoji: 🧠
|
| 5 |
+
colorFrom: indigo
|
| 6 |
+
colorTo: pink
|
| 7 |
+
sdk: streamlit
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: true
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
+
## Setup (Hugging Face Spaces — free CPU)
|
| 13 |
+
1. Create a Space with **SDK: Streamlit**, hardware **CPU Basic**.
|
| 14 |
+
2. In **Settings → Variables**, add:
|
| 15 |
+
- `HF_HOME=/data/.huggingface` (speeds up caching)
|
| 16 |
+
- `MODEL_REPO=your-username/your-model-repo` (where your models live on the Hub)
|
| 17 |
+
- *(optional)* `MODEL_SUBDIR=subfolder-inside-repo`
|
| 18 |
+
- *(optional)* `PERSIST_BASE=/data`
|
| 19 |
|
| 20 |
+
Your models will be downloaded at runtime into persistent storage and reused after restarts.
|
|
|
app.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils.process_utils import *
|
| 2 |
+
import skimage
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import zipfile
|
| 5 |
+
import tempfile
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
from huggingface_hub import snapshot_download
|
| 11 |
+
|
| 12 |
+
MODEL_REPO = os.getenv("MODEL_REPO") # e.g. "username/3d-cine-models"
|
| 13 |
+
MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "") # optional subfolder in the repo
|
| 14 |
+
PERSIST_BASE = os.getenv("PERSIST_BASE", "/data") # HF Spaces persistent storage
|
| 15 |
+
|
| 16 |
+
def get_models_base():
|
| 17 |
+
# cache models inside persistent storage to avoid re-downloads
|
| 18 |
+
os.makedirs(PERSIST_BASE, exist_ok=True)
|
| 19 |
+
if MODEL_REPO:
|
| 20 |
+
repo_dir = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir=os.path.join(PERSIST_BASE, "hf_models"), local_dir_use_symlinks=False)
|
| 21 |
+
base = os.path.join(repo_dir, MODEL_SUBDIR) if MODEL_SUBDIR else repo_dir
|
| 22 |
+
else:
|
| 23 |
+
# fallback to a local folder in persistent storage
|
| 24 |
+
base = os.path.join(PERSIST_BASE, MODELS_BASE)
|
| 25 |
+
os.makedirs(base, exist_ok=True)
|
| 26 |
+
return base
|
| 27 |
+
|
| 28 |
+
MODELS_BASE = get_models_base()
|
| 29 |
+
|
| 30 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 31 |
+
|
| 32 |
+
if "initialized" not in st.session_state:
|
| 33 |
+
for d in ("./out_dir", "./out_dicoms"):
|
| 34 |
+
if os.path.exists(d):
|
| 35 |
+
shutil.rmtree(d)
|
| 36 |
+
os.makedirs(d, exist_ok=True)
|
| 37 |
+
st.session_state.initialized = True
|
| 38 |
+
|
| 39 |
+
# --- Session state defaults ---
|
| 40 |
+
if "volume" not in st.session_state:
|
| 41 |
+
st.session_state.volume = None
|
| 42 |
+
if "data_processed" not in st.session_state:
|
| 43 |
+
st.session_state.data_processed = False
|
| 44 |
+
if "gif_ready" not in st.session_state:
|
| 45 |
+
st.session_state.gif_ready = False
|
| 46 |
+
if "dicom_create" not in st.session_state:
|
| 47 |
+
st.session_state.dicom_create = False
|
| 48 |
+
if "want_gif" not in st.session_state:
|
| 49 |
+
st.session_state.want_gif = True
|
| 50 |
+
if "num_phases" not in st.session_state:
|
| 51 |
+
st.session_state.num_phases = None
|
| 52 |
+
|
| 53 |
+
# --- Title ---
|
| 54 |
+
st.title("3D Cine")
|
| 55 |
+
|
| 56 |
+
# --- Upload ---
|
| 57 |
+
st.header("Data Upload")
|
| 58 |
+
uploaded_zip = st.file_uploader("Upload ZIP file of MRI folders", type="zip")
|
| 59 |
+
|
| 60 |
+
_ = st.toggle("Generate a GIF preview after processing", key="want_gif")
|
| 61 |
+
|
| 62 |
+
if uploaded_zip is not None:
|
| 63 |
+
|
| 64 |
+
if st.button("Process Data"):
|
| 65 |
+
with st.spinner("Processing ZIP..."):
|
| 66 |
+
temp_dir = tempfile.mkdtemp()
|
| 67 |
+
zip_path = os.path.join(temp_dir, "upload.zip")
|
| 68 |
+
with open(zip_path, "wb") as f:
|
| 69 |
+
f.write(uploaded_zip.read())
|
| 70 |
+
|
| 71 |
+
extract_zip(zip_path, temp_dir)
|
| 72 |
+
st.session_state.volume, st.session_state.num_phases = load_cine_any(temp_dir, number_of_scans=None)
|
| 73 |
+
num_phases = st.session_state.num_phases
|
| 74 |
+
if st.session_state.volume is None or len(st.session_state.volume) == 0:
|
| 75 |
+
st.error("Failed to load volume.")
|
| 76 |
+
else:
|
| 77 |
+
with st.spinner("Cropping..."):
|
| 78 |
+
time_steps = num_phases
|
| 79 |
+
sag_vols = np.array(st.session_state.volume)
|
| 80 |
+
|
| 81 |
+
if sag_vols.shape[1] !=28:
|
| 82 |
+
diff = sag_vols.shape[1] -28
|
| 83 |
+
sag_vols = sag_vols[:,diff:,:,:]
|
| 84 |
+
|
| 85 |
+
if sag_vols.shape[2] ==512:
|
| 86 |
+
sag_vols = skimage.transform.rescale(sag_vols,(1,1,0.5,0.5),order =3, anti_aliasing=True)
|
| 87 |
+
|
| 88 |
+
sag_vols_cropped = []
|
| 89 |
+
for j in range(time_steps):
|
| 90 |
+
sag_cropped = []
|
| 91 |
+
for i in range(sag_vols.shape[1]):
|
| 92 |
+
sag_cropped.append(resize(sag_vols[j,i,:,:], 256, 128))
|
| 93 |
+
sag_cropped = np.dstack(sag_cropped)
|
| 94 |
+
sag_cropped = np.swapaxes(sag_cropped, 0, 1)
|
| 95 |
+
sag_cropped = np.swapaxes(sag_cropped, 0, 2)
|
| 96 |
+
sag_vols_cropped.append(sag_cropped)
|
| 97 |
+
|
| 98 |
+
sag_vols_cropped = norm(sag_vols_cropped)
|
| 99 |
+
|
| 100 |
+
if st.session_state.want_gif:
|
| 101 |
+
raw_us = skimage.transform.rescale(sag_vols_cropped, (1,4,1,1), order=2)
|
| 102 |
+
|
| 103 |
+
with st.spinner("Contrast correction..."):
|
| 104 |
+
debanded = apply_debanding_model(sag_vols_cropped, frames=time_steps)
|
| 105 |
+
debanded = norm(debanded)
|
| 106 |
+
debanded_us = debanded[:,0,...,0]
|
| 107 |
+
if st.session_state.want_gif:
|
| 108 |
+
debanded_us = skimage.transform.rescale(debanded_us, (1,4,1,1), order=2)
|
| 109 |
+
|
| 110 |
+
with st.spinner("Respiratory correction..."):
|
| 111 |
+
def_fields, resp_cor = apply_resp_model_28(debanded, frames=time_steps)
|
| 112 |
+
resp_cor = norm(resp_cor)
|
| 113 |
+
resp_cor_us = resp_cor[:,0,...,0]
|
| 114 |
+
if st.session_state.want_gif:
|
| 115 |
+
resp_cor_us = skimage.transform.rescale(resp_cor_us, (1,4,1,1), order=2)
|
| 116 |
+
|
| 117 |
+
with st.spinner("Super-resolution..."):
|
| 118 |
+
super_resed_E2E = apply_SR_model(resp_cor, frames=time_steps)
|
| 119 |
+
super_resed_E2E = norm(super_resed_E2E)
|
| 120 |
+
super_resed_E2E = super_resed_E2E[:,0,...,0]
|
| 121 |
+
|
| 122 |
+
os.makedirs('./out_dir/', exist_ok=True)
|
| 123 |
+
for i in range(time_steps):
|
| 124 |
+
np.save(f'./out_dir/3D_cine_{i}.npy', super_resed_E2E[i])
|
| 125 |
+
if st.session_state.want_gif:
|
| 126 |
+
np.save(f'./out_dir/resp_cor_{i}.npy', resp_cor_us[i])
|
| 127 |
+
np.save(f'./out_dir/debanded_{i}.npy', debanded_us[i])
|
| 128 |
+
np.save(f'./out_dir/raw_{i}.npy', raw_us[i])
|
| 129 |
+
|
| 130 |
+
st.success("✅ All models complete and data saved!")
|
| 131 |
+
st.session_state.data_processed = True
|
| 132 |
+
st.session_state.gif_ready = False # Reset gif status
|
| 133 |
+
|
| 134 |
+
if not st.session_state.want_gif:
|
| 135 |
+
st.session_state.dicom_create = True
|
| 136 |
+
|
| 137 |
+
# --- GIF Generation Section ---
|
| 138 |
+
if st.session_state.want_gif:
|
| 139 |
+
num_phases = st.session_state.num_phases
|
| 140 |
+
if st.session_state.data_processed:
|
| 141 |
+
st.header("GIF Generator")
|
| 142 |
+
|
| 143 |
+
axis_option = st.radio(
|
| 144 |
+
"Select axis for slicing",
|
| 145 |
+
options=["Axial", "Coronal"],
|
| 146 |
+
index=0,
|
| 147 |
+
key="axis_selector"
|
| 148 |
+
)
|
| 149 |
+
axis_mapping = {"Axial": 1, "Coronal": 2}
|
| 150 |
+
axis = axis_mapping[axis_option]
|
| 151 |
+
|
| 152 |
+
slice_index = st.number_input("Select slice number", 0, 256, 60, 1)
|
| 153 |
+
framerate = st.number_input("Framerate", 1, 100, num_phases, 1)
|
| 154 |
+
|
| 155 |
+
if st.button("Generate and Show GIF"):
|
| 156 |
+
gif_path = make_gif('./out_dir/', timepoints=num_phases, axis=axis, slice=slice_index, frame_rate=framerate)
|
| 157 |
+
st.image(gif_path, caption="Generated GIF", use_container_width=True)
|
| 158 |
+
st.session_state.gif_ready = True
|
| 159 |
+
|
| 160 |
+
# --- Next Steps Section ---
|
| 161 |
+
if st.session_state.gif_ready:
|
| 162 |
+
next_action = st.radio(
|
| 163 |
+
"What would you like to do next?",
|
| 164 |
+
options=["Generate another GIF", "Proceed to DICOM export"],
|
| 165 |
+
index=0
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if next_action == "Generate another GIF":
|
| 169 |
+
st.info("Adjust your settings above and click the button again.")
|
| 170 |
+
|
| 171 |
+
elif next_action == "Proceed to DICOM export":
|
| 172 |
+
st.session_state.dicom_create = True
|
| 173 |
+
|
| 174 |
+
# --- DICOM Export Section ---
|
| 175 |
+
if st.session_state.dicom_create:
|
| 176 |
+
num_phases = st.session_state.num_phases
|
| 177 |
+
st.header("DICOM Export")
|
| 178 |
+
to_dicom(num_phases, patient_number=0)
|
| 179 |
+
st.success("✅ Created DICOMs.")
|
| 180 |
+
|
| 181 |
+
src_dir = "./out_dicoms"
|
| 182 |
+
|
| 183 |
+
# build zip once per session so we don't recompress on every rerun
|
| 184 |
+
if "dicom_zip" not in st.session_state:
|
| 185 |
+
if os.path.isdir(src_dir) and any(os.scandir(src_dir)):
|
| 186 |
+
st.session_state.dicom_zip = zip_dir_to_memory(src_dir)
|
| 187 |
+
st.session_state.dicom_zip_name = f"dicoms_{time.strftime('%Y%m%d-%H%M%S')}.zip"
|
| 188 |
+
else:
|
| 189 |
+
st.warning("No DICOMs found to package.")
|
| 190 |
+
|
| 191 |
+
if "dicom_zip" in st.session_state:
|
| 192 |
+
st.download_button(
|
| 193 |
+
label="⬇️ Download DICOMs (ZIP)",
|
| 194 |
+
data=st.session_state.dicom_zip,
|
| 195 |
+
file_name=st.session_state.dicom_zip_name,
|
| 196 |
+
mime="application/zip",
|
| 197 |
+
use_container_width=True
|
| 198 |
+
)
|
devcontainer.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// For format details, see https://aka.ms/devcontainer.json. For config options, see the README at:
|
| 2 |
+
// https://github.com/microsoft/vscode-dev-containers/tree/v0.245.0/containers/python-3
|
| 3 |
+
{
|
| 4 |
+
"name": "Python 3",
|
| 5 |
+
"build": {
|
| 6 |
+
"dockerfile": "Dockerfile",
|
| 7 |
+
"context": ".."
|
| 8 |
+
},
|
| 9 |
+
|
| 10 |
+
// Enable GPUs
|
| 11 |
+
"runArgs": [
|
| 12 |
+
"--gpus=all"
|
| 13 |
+
],
|
| 14 |
+
// Enable plotting.
|
| 15 |
+
"mounts": [
|
| 16 |
+
"type=bind,source=/tmp/.X11-unix,target=/tmp/.X11-unix"
|
| 17 |
+
],
|
| 18 |
+
// Enable plotting.
|
| 19 |
+
"containerEnv": {
|
| 20 |
+
"DISPLAY": "${localEnv:DISPLAY}"
|
| 21 |
+
},
|
| 22 |
+
|
| 23 |
+
// Configure tool-specific properties.
|
| 24 |
+
"customizations": {
|
| 25 |
+
// Configure properties specific to VS Code.
|
| 26 |
+
"vscode": {
|
| 27 |
+
// Set *default* container specific settings.json values on container create.
|
| 28 |
+
"settings": {
|
| 29 |
+
"python.defaultInterpreterPath": "/usr/local/bin/python",
|
| 30 |
+
"python.linting.enabled": true,
|
| 31 |
+
"python.linting.pylintEnabled": true,
|
| 32 |
+
"python.formatting.autopep8Path": "/usr/local/py-utils/bin/autopep8",
|
| 33 |
+
"python.formatting.blackPath": "/usr/local/py-utils/bin/black",
|
| 34 |
+
"python.formatting.yapfPath": "/usr/local/py-utils/bin/yapf",
|
| 35 |
+
"python.linting.banditPath": "/usr/local/py-utils/bin/bandit",
|
| 36 |
+
"python.linting.flake8Path": "/usr/local/py-utils/bin/flake8",
|
| 37 |
+
"python.linting.mypyPath": "/usr/local/py-utils/bin/mypy",
|
| 38 |
+
"python.linting.pycodestylePath": "/usr/local/py-utils/bin/pycodestyle",
|
| 39 |
+
"python.linting.pydocstylePath": "/usr/local/py-utils/bin/pydocstyle",
|
| 40 |
+
"python.linting.pylintPath": "/usr/local/py-utils/bin/pylint"
|
| 41 |
+
},
|
| 42 |
+
|
| 43 |
+
// Add the IDs of extensions you want installed when the container is created.
|
| 44 |
+
"extensions": [
|
| 45 |
+
"ms-python.python",
|
| 46 |
+
"ms-python.vscode-pylance"
|
| 47 |
+
]
|
| 48 |
+
}
|
| 49 |
+
},
|
| 50 |
+
|
| 51 |
+
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
| 52 |
+
// "forwardPorts": [],
|
| 53 |
+
|
| 54 |
+
// Use 'postCreateCommand' to run commands after the container is created.
|
| 55 |
+
// "postCreateCommand": "pip3 install --user -r requirements.txt",
|
| 56 |
+
|
| 57 |
+
// Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.
|
| 58 |
+
"remoteUser": "vscode"
|
| 59 |
+
}
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: "3.9"
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
run_code:
|
| 5 |
+
build:
|
| 6 |
+
context: .
|
| 7 |
+
dockerfile: Dockerfile
|
| 8 |
+
ports:
|
| 9 |
+
- "8501:8501"
|
| 10 |
+
volumes:
|
| 11 |
+
- .:/app
|
| 12 |
+
working_dir: /app
|
| 13 |
+
command: streamlit run app.py --server.address=0.0.0.0
|
| 14 |
+
deploy:
|
| 15 |
+
resources:
|
| 16 |
+
reservations:
|
| 17 |
+
devices:
|
| 18 |
+
- capabilities: [gpu]
|
requirements.txt
CHANGED
|
@@ -1,3 +1,13 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
tensorflow==2.9.1
|
| 3 |
+
tensorflow-mri
|
| 4 |
+
tqdm
|
| 5 |
+
h5py
|
| 6 |
+
tensorflow-addons
|
| 7 |
+
scikit-learn
|
| 8 |
+
scikit-image
|
| 9 |
+
matplotlib
|
| 10 |
+
scipy
|
| 11 |
+
pydicom
|
| 12 |
+
huggingface_hub>=0.21
|
| 13 |
+
numpy<2
|
utils/__pycache__/custom_unet_code.cpython-38.pyc
ADDED
|
Binary file (3.15 kB). View file
|
|
|
utils/__pycache__/layer_util.cpython-38.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
utils/__pycache__/process_utils.cpython-38.pyc
ADDED
|
Binary file (18.9 kB). View file
|
|
|
utils/__pycache__/unet3plusnew.cpython-38.pyc
ADDED
|
Binary file (5.59 kB). View file
|
|
|
utils/custom_unet_code.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import layers, models
|
| 3 |
+
|
| 4 |
+
def time_distributed_conv_block(input_tensor, num_filters):
|
| 5 |
+
x = layers.Conv3D(num_filters, (3, 3, 3), padding="same")(input_tensor)
|
| 6 |
+
x = layers.ReLU()(x)
|
| 7 |
+
|
| 8 |
+
x = layers.Conv3D(num_filters, (3, 3, 3), padding="same")(x)
|
| 9 |
+
x = layers.ReLU()(x)
|
| 10 |
+
return x
|
| 11 |
+
|
| 12 |
+
def time_distributed_encoder_block_resp(input_tensor, num_filters, temporal_maxpool=True):
|
| 13 |
+
x = time_distributed_conv_block(input_tensor, num_filters)
|
| 14 |
+
|
| 15 |
+
p = layers.MaxPooling3D((1, 4, 4))(x)
|
| 16 |
+
|
| 17 |
+
if temporal_maxpool:
|
| 18 |
+
|
| 19 |
+
p = tf.transpose(p, (0,2,3,1,4))
|
| 20 |
+
p2 = layers.TimeDistributed(layers.TimeDistributed(layers.MaxPooling1D((2))))(p)
|
| 21 |
+
p2 = tf.transpose(p2, (0,3,1,2,4))
|
| 22 |
+
return x, p2
|
| 23 |
+
else:
|
| 24 |
+
return x, p
|
| 25 |
+
|
| 26 |
+
def time_distributed_decoder_block_resp(input_tensor, skip_tensor, num_filters):
|
| 27 |
+
|
| 28 |
+
x = layers.TimeDistributed(layers.UpSampling2D(( 4, 4), interpolation='bilinear'))(input_tensor)
|
| 29 |
+
x = layers.Conv3D(num_filters, (3, 3, 3), padding="same")(x)
|
| 30 |
+
x = layers.Conv3D(num_filters, (3, 3, 3), padding="same")(x)
|
| 31 |
+
|
| 32 |
+
x = layers.Concatenate()([x, skip_tensor])
|
| 33 |
+
x = time_distributed_conv_block(x, num_filters)
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
def build_3d_unet_resp(input_shape, num_classes):
|
| 37 |
+
inputs = layers.Input(shape=input_shape)
|
| 38 |
+
|
| 39 |
+
# Encoding path
|
| 40 |
+
s1, p1 = time_distributed_encoder_block_resp(inputs, 32,temporal_maxpool=False)
|
| 41 |
+
s2, p2 = time_distributed_encoder_block_resp(p1, 64,temporal_maxpool=False)
|
| 42 |
+
|
| 43 |
+
# Bridge
|
| 44 |
+
b1 = time_distributed_conv_block(p2, 128)
|
| 45 |
+
|
| 46 |
+
d1 = time_distributed_decoder_block_resp(b1, s2, 64)
|
| 47 |
+
d2 = time_distributed_decoder_block_resp(d1, s1, 32)
|
| 48 |
+
|
| 49 |
+
outputs = layers.Conv3D(num_classes, (1, 1, 1))(d2)
|
| 50 |
+
|
| 51 |
+
model = models.Model(inputs, outputs, name="3D-U-Net-resp")
|
| 52 |
+
return model
|
| 53 |
+
|
| 54 |
+
def time_distributed_encoder_block(input_tensor, num_filters, temporal_maxpool=True):
|
| 55 |
+
x = time_distributed_conv_block(input_tensor, num_filters)
|
| 56 |
+
|
| 57 |
+
p = layers.TimeDistributed(layers.MaxPooling2D((2, 2)))(x)
|
| 58 |
+
if temporal_maxpool:
|
| 59 |
+
|
| 60 |
+
p = tf.transpose(p, (0,2,3,1,4))
|
| 61 |
+
p2 = layers.TimeDistributed(layers.TimeDistributed(layers.MaxPooling1D((2))))(p)
|
| 62 |
+
p2 = tf.transpose(p2, (0,3,1,2,4))
|
| 63 |
+
return x, p2
|
| 64 |
+
else:
|
| 65 |
+
return x, p
|
| 66 |
+
|
| 67 |
+
def time_distributed_decoder_block(input_tensor, skip_tensor, num_filters, temporal_upsamp=True):
|
| 68 |
+
x = layers.TimeDistributed(layers.UpSampling2D(( 2, 2)))(input_tensor)
|
| 69 |
+
x = layers.TimeDistributed(layers.Conv2D(num_filters, (3, 3), padding="same"))(x)
|
| 70 |
+
if temporal_upsamp:
|
| 71 |
+
x = tf.transpose(x, (0,2,3,1,4))
|
| 72 |
+
x = layers.TimeDistributed(layers.TimeDistributed(layers.UpSampling1D((2))))(x)
|
| 73 |
+
x = layers.TimeDistributed(layers.TimeDistributed(layers.Conv1D(num_filters, (2),padding="same")))(x)
|
| 74 |
+
x = tf.transpose(x, (0,3,1,2,4))
|
| 75 |
+
|
| 76 |
+
if x.shape[4] == 64:
|
| 77 |
+
skip_tensor = tf.transpose(skip_tensor, (0,2,3,1,4))
|
| 78 |
+
skip_tensor = layers.TimeDistributed(layers.TimeDistributed(layers.Conv1DTranspose(num_filters,kernel_size=2,strides=2)))(skip_tensor)
|
| 79 |
+
skip_tensor = tf.transpose(skip_tensor, (0,3,1,2,4))
|
| 80 |
+
|
| 81 |
+
if x.shape[4] == 32:
|
| 82 |
+
skip_tensor = tf.transpose(skip_tensor, (0,2,3,1,4))
|
| 83 |
+
skip_tensor = layers.TimeDistributed(layers.TimeDistributed(layers.Conv1DTranspose(num_filters,kernel_size=2,strides=2)))(skip_tensor)
|
| 84 |
+
skip_tensor = layers.TimeDistributed(layers.TimeDistributed(layers.Conv1DTranspose(num_filters,kernel_size=2,strides=2)))(skip_tensor)
|
| 85 |
+
skip_tensor = tf.transpose(skip_tensor, (0,3,1,2,4))
|
| 86 |
+
|
| 87 |
+
x = layers.Concatenate()([x, skip_tensor])
|
| 88 |
+
x = time_distributed_conv_block(x, num_filters)
|
| 89 |
+
return x
|
| 90 |
+
|
| 91 |
+
def build_3d_unet(input_shape, num_classes):
|
| 92 |
+
inputs = layers.Input(shape=input_shape)
|
| 93 |
+
|
| 94 |
+
# Encoding path
|
| 95 |
+
s1, p1 = time_distributed_encoder_block(inputs, 32,temporal_maxpool=False)
|
| 96 |
+
s2, p2 = time_distributed_encoder_block(p1, 64,temporal_maxpool=False)
|
| 97 |
+
s3, p3 = time_distributed_encoder_block(p2, 128,temporal_maxpool=True)
|
| 98 |
+
|
| 99 |
+
# Bridge
|
| 100 |
+
b1 = time_distributed_conv_block(p3, 256)
|
| 101 |
+
|
| 102 |
+
# Decoding path
|
| 103 |
+
d1 = time_distributed_decoder_block(b1, s3, 128,temporal_upsamp=True)
|
| 104 |
+
d2 = time_distributed_decoder_block(d1, s2, 64,temporal_upsamp=True)
|
| 105 |
+
d3 = time_distributed_decoder_block(d2, s1, 32,temporal_upsamp=True)
|
| 106 |
+
|
| 107 |
+
# Output layer
|
| 108 |
+
outputs = layers.Conv3D(num_classes, (1, 1, 1))(d3)
|
| 109 |
+
|
| 110 |
+
model = models.Model(inputs, outputs, name="3D-U-Net")
|
| 111 |
+
return model
|
utils/dicom_headerfile.dcm
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:85bd29cc81a85d8957e34e710c4aff8658768758ead619f98a99372cf1d0319b
|
| 3 |
+
size 187536
|
utils/layer_util.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 University College London. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Layer utilities."""
|
| 16 |
+
|
| 17 |
+
import tensorflow as tf
|
| 18 |
+
|
| 19 |
+
# from tensorflow_mri.python.layers import convolutional
|
| 20 |
+
# from tensorflow_mri.python.layers import signal_layers
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_nd_layer(name, rank):
|
| 24 |
+
"""Get an N-D layer object.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
name: A `str`. The name of the requested layer.
|
| 28 |
+
rank: An `int`. The rank of the requested layer.
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
A `tf.keras.layers.Layer` object.
|
| 32 |
+
|
| 33 |
+
Raises:
|
| 34 |
+
ValueError: If the requested layer is unknown to TFMRI.
|
| 35 |
+
"""
|
| 36 |
+
try:
|
| 37 |
+
return _ND_LAYERS[(name, rank)]
|
| 38 |
+
except KeyError as err:
|
| 39 |
+
raise ValueError(
|
| 40 |
+
f"Could not find a layer with name '{name}' and rank {rank}.") from err
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
_ND_LAYERS = {
|
| 44 |
+
('AveragePooling', 1): tf.keras.layers.AveragePooling1D,
|
| 45 |
+
('AveragePooling', 2): tf.keras.layers.AveragePooling2D,
|
| 46 |
+
('AveragePooling', 3): tf.keras.layers.AveragePooling3D,
|
| 47 |
+
('Conv', 1): tf.keras.layers.Conv1D,
|
| 48 |
+
('Conv', 2): tf.keras.layers.Conv2D,
|
| 49 |
+
('Conv', 3): tf.keras.layers.Conv3D,
|
| 50 |
+
('ConvLSTM', 1): tf.keras.layers.ConvLSTM1D,
|
| 51 |
+
('ConvLSTM', 2): tf.keras.layers.ConvLSTM2D,
|
| 52 |
+
('ConvLSTM', 3): tf.keras.layers.ConvLSTM3D,
|
| 53 |
+
('ConvTranspose', 1): tf.keras.layers.Conv1DTranspose,
|
| 54 |
+
('ConvTranspose', 2): tf.keras.layers.Conv2DTranspose,
|
| 55 |
+
('ConvTranspose', 3): tf.keras.layers.Conv3DTranspose,
|
| 56 |
+
('Cropping', 1): tf.keras.layers.Cropping1D,
|
| 57 |
+
('Cropping', 2): tf.keras.layers.Cropping2D,
|
| 58 |
+
('Cropping', 3): tf.keras.layers.Cropping3D,
|
| 59 |
+
('DepthwiseConv', 1): tf.keras.layers.DepthwiseConv1D,
|
| 60 |
+
('DepthwiseConv', 2): tf.keras.layers.DepthwiseConv2D,
|
| 61 |
+
# ('DWT', 1): signal_layers.DWT1D,
|
| 62 |
+
# ('DWT', 2): signal_layers.DWT2D,
|
| 63 |
+
# ('DWT', 3): signal_layers.DWT3D,
|
| 64 |
+
('GlobalAveragePooling', 1): tf.keras.layers.GlobalAveragePooling1D,
|
| 65 |
+
('GlobalAveragePooling', 2): tf.keras.layers.GlobalAveragePooling2D,
|
| 66 |
+
('GlobalAveragePooling', 3): tf.keras.layers.GlobalAveragePooling3D,
|
| 67 |
+
('GlobalMaxPool', 1): tf.keras.layers.GlobalMaxPool1D,
|
| 68 |
+
('GlobalMaxPool', 2): tf.keras.layers.GlobalMaxPool2D,
|
| 69 |
+
('GlobalMaxPool', 3): tf.keras.layers.GlobalMaxPool3D,
|
| 70 |
+
# ('IDWT', 1): signal_layers.IDWT1D,
|
| 71 |
+
# ('IDWT', 2): signal_layers.IDWT2D,
|
| 72 |
+
# ('IDWT', 3): signal_layers.IDWT3D,
|
| 73 |
+
('LocallyConnected', 1): tf.keras.layers.LocallyConnected1D,
|
| 74 |
+
('LocallyConnected', 2): tf.keras.layers.LocallyConnected2D,
|
| 75 |
+
('MaxPool', 1): tf.keras.layers.MaxPool1D,
|
| 76 |
+
('MaxPool', 2): tf.keras.layers.MaxPool2D,
|
| 77 |
+
('MaxPool', 3): tf.keras.layers.MaxPool3D,
|
| 78 |
+
('SeparableConv', 1): tf.keras.layers.SeparableConv1D,
|
| 79 |
+
('SeparableConv', 2): tf.keras.layers.SeparableConv2D,
|
| 80 |
+
('SpatialDropout', 1): tf.keras.layers.SpatialDropout1D,
|
| 81 |
+
('SpatialDropout', 2): tf.keras.layers.SpatialDropout2D,
|
| 82 |
+
('SpatialDropout', 3): tf.keras.layers.SpatialDropout3D,
|
| 83 |
+
('UpSampling', 1): tf.keras.layers.UpSampling1D,
|
| 84 |
+
('UpSampling', 2): tf.keras.layers.UpSampling2D,
|
| 85 |
+
('UpSampling', 3): tf.keras.layers.UpSampling3D,
|
| 86 |
+
('ZeroPadding', 1): tf.keras.layers.ZeroPadding1D,
|
| 87 |
+
('ZeroPadding', 2): tf.keras.layers.ZeroPadding2D,
|
| 88 |
+
('ZeroPadding', 3): tf.keras.layers.ZeroPadding3D
|
| 89 |
+
}
|
utils/process_utils.py
ADDED
|
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import tensorflow_addons as tfa
|
| 4 |
+
import tensorflow_mri as tfmri
|
| 5 |
+
import tqdm
|
| 6 |
+
import os
|
| 7 |
+
import pydicom as dicom
|
| 8 |
+
import glob
|
| 9 |
+
from utils.unet3plusnew import *
|
| 10 |
+
from utils.custom_unet_code import *
|
| 11 |
+
from pydicom.tag import Tag
|
| 12 |
+
import imageio
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import streamlit as st
|
| 15 |
+
import zipfile
|
| 16 |
+
from collections import defaultdict
|
| 17 |
+
import re
|
| 18 |
+
from pydicom.tag import Tag
|
| 19 |
+
from collections import defaultdict, Counter
|
| 20 |
+
import io, shutil, zipfile, time
|
| 21 |
+
|
| 22 |
+
#Resizes image
|
| 23 |
+
def resize(t1,x,y):
|
| 24 |
+
# Adding new axis for the channels
|
| 25 |
+
t1 = tf.expand_dims(t1, -1)
|
| 26 |
+
|
| 27 |
+
im1 = tf.image.resize_with_crop_or_pad(t1,x,y)
|
| 28 |
+
return (im1)
|
| 29 |
+
|
| 30 |
+
#Function that normalises image
|
| 31 |
+
def norm(t1):
|
| 32 |
+
im1= t1
|
| 33 |
+
im1 = (im1-np.min(im1)) / np.max(im1)
|
| 34 |
+
return (im1)
|
| 35 |
+
|
| 36 |
+
#Applies debanding model to any number of slices
|
| 37 |
+
def apply_debanding_model(input_im,frames =32):
|
| 38 |
+
|
| 39 |
+
debanding_model = "./models_final/Deband_model"
|
| 40 |
+
debanding = tf.keras.models.load_model(debanding_model, compile=False)
|
| 41 |
+
weights = debanding.get_weights()
|
| 42 |
+
|
| 43 |
+
inputs = tf.keras.Input(shape = [None,None,None,1])
|
| 44 |
+
unet = tfmri.models.UNet3D([32,64,128], kernel_size=3, out_channels=1,use_global_residual=False)
|
| 45 |
+
DB = unet(inputs)
|
| 46 |
+
de_banding_model = tf.keras.Model(inputs = inputs, outputs = DB)
|
| 47 |
+
de_banding_model.set_weights(weights)
|
| 48 |
+
|
| 49 |
+
de_banded = []
|
| 50 |
+
for i in range(frames):
|
| 51 |
+
temp = de_banding_model.predict(tf.expand_dims(tf.expand_dims(input_im[i],0),-1),verbose = 0)
|
| 52 |
+
de_banded.append(temp)
|
| 53 |
+
|
| 54 |
+
return de_banded
|
| 55 |
+
|
| 56 |
+
#Function that applies deformations to 28 slice data
|
| 57 |
+
def deformation_28(x):
|
| 58 |
+
|
| 59 |
+
sagittal_deformed = []
|
| 60 |
+
|
| 61 |
+
for i in range(28):
|
| 62 |
+
|
| 63 |
+
input_img = tf.expand_dims(x[0][0,i,:,:], -1)
|
| 64 |
+
dy = tf.expand_dims(tf.expand_dims(x[1][0,i,:,:], -1),0)
|
| 65 |
+
dx = tf.expand_dims(tf.expand_dims(x[2][0,i,:,:], -1),0)
|
| 66 |
+
|
| 67 |
+
displacement = tf.concat((dy[0,...],dx[0,...]), axis=-1)
|
| 68 |
+
|
| 69 |
+
img = tf.image.convert_image_dtype(tf.expand_dims(input_img, 0), tf.dtypes.float32)
|
| 70 |
+
displacement = tf.image.convert_image_dtype(displacement, tf.dtypes.float32)
|
| 71 |
+
dense_img_warp = tfa.image.dense_image_warp(img, displacement)
|
| 72 |
+
im_deformed = tf.squeeze(dense_img_warp, 0)
|
| 73 |
+
sagittal_deformed.append(im_deformed)
|
| 74 |
+
|
| 75 |
+
sagittal_deformed = tf.image.convert_image_dtype(sagittal_deformed, tf.dtypes.float32)
|
| 76 |
+
sagittal_deformed = tf.expand_dims(sagittal_deformed,axis= 0)
|
| 77 |
+
|
| 78 |
+
return sagittal_deformed
|
| 79 |
+
|
| 80 |
+
#Applies respiratory correction model
|
| 81 |
+
def apply_resp_model_28(input_im,frames = 32):
|
| 82 |
+
|
| 83 |
+
inputs = tf.keras.Input(shape = [None,256,128,1])
|
| 84 |
+
unet = build_3d_unet_resp([None,256,128,1],2) # Acts as aa deformation field generator
|
| 85 |
+
deformation_fields = unet(inputs) # Outputs the deformation fields
|
| 86 |
+
lambda_deformation = tf.keras.layers.Lambda(deformation_28)
|
| 87 |
+
out_2 = lambda_deformation([inputs[:,:,:,:,0],deformation_fields[:,:,:,:,0],deformation_fields[:,:,:,:,1]]) # Outputs the deformed volume
|
| 88 |
+
outputs = [deformation_fields,out_2]
|
| 89 |
+
complete_model = tf.keras.Model(inputs = inputs, outputs = outputs)
|
| 90 |
+
complete_model.load_weights('./models_final/Resp_Correction_model/variables/variables')
|
| 91 |
+
|
| 92 |
+
resp_corrected = []
|
| 93 |
+
deformations = []
|
| 94 |
+
for i in range(frames):
|
| 95 |
+
|
| 96 |
+
def_fields, resp_cor = complete_model.predict(input_im[i][:,:,:,:,:],verbose=0)
|
| 97 |
+
resp_corrected.append(resp_cor)
|
| 98 |
+
deformations.append(def_fields)
|
| 99 |
+
|
| 100 |
+
return deformations, resp_corrected
|
| 101 |
+
|
| 102 |
+
#Applies super resolution model
|
| 103 |
+
def apply_SR_model(input_im,frames = 32):
|
| 104 |
+
E2E_model = "./models_final/E2E_SR_model"
|
| 105 |
+
E2E = tf.keras.models.load_model(E2E_model, compile=False)
|
| 106 |
+
weights = E2E.get_weights()
|
| 107 |
+
sr_weights = weights[22:]
|
| 108 |
+
|
| 109 |
+
inputs = tf.keras.Input(shape = [None,None,None,1])
|
| 110 |
+
SR_model = build_3d_unet(input_shape=(None, None,None,1), num_classes=1)
|
| 111 |
+
SR = SR_model(inputs)
|
| 112 |
+
SR_model_done = tf.keras.Model(inputs = inputs, outputs = SR)
|
| 113 |
+
SR_model_done.set_weights(sr_weights)
|
| 114 |
+
|
| 115 |
+
super_resed = []
|
| 116 |
+
for i in range(frames):
|
| 117 |
+
super_resed.append(SR_model_done.predict(input_im[i],verbose=0))
|
| 118 |
+
|
| 119 |
+
return super_resed
|
| 120 |
+
|
| 121 |
+
t = Tag(0x0019, 0x10D7)
|
| 122 |
+
#Reads in example RT sagittal stack
|
| 123 |
+
def load_data_samples(path_to_data):
|
| 124 |
+
sag_volumes = []
|
| 125 |
+
filename=f"{path_to_data}/*"
|
| 126 |
+
if not os.path.exists(path_to_data):
|
| 127 |
+
raise Exception("Error with file path.")
|
| 128 |
+
else:
|
| 129 |
+
|
| 130 |
+
clean_ims_1 = []
|
| 131 |
+
locations_1 = []
|
| 132 |
+
|
| 133 |
+
clean_ims_final =[]
|
| 134 |
+
locations_final = []
|
| 135 |
+
test = sorted(glob.glob(filename))
|
| 136 |
+
for file in test:
|
| 137 |
+
ds = dicom.dcmread(file)
|
| 138 |
+
|
| 139 |
+
locations_1.append(ds.SliceLocation)
|
| 140 |
+
clean_ims_1.append(ds.pixel_array)
|
| 141 |
+
if ds[t].value ==30:
|
| 142 |
+
clean_ims_final.append(np.array(clean_ims_1))
|
| 143 |
+
locations_final.append(locations_1)
|
| 144 |
+
clean_ims_1 = []
|
| 145 |
+
locations_1 = []
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
#clean_ims_1 = [x for _,x in sorted(zip(locations_1,clean_ims_1))]
|
| 149 |
+
#sag_volumes.append(clean_ims_1)
|
| 150 |
+
final = np.array(clean_ims_final)
|
| 151 |
+
final = np.transpose(final, (1,0,2,3))
|
| 152 |
+
|
| 153 |
+
return final
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def load_data_samples_from_folder(base_dir, number_of_scans=32):
|
| 157 |
+
"""
|
| 158 |
+
Recursively find all DICOM files under the first valid subfolder of base_dir,
|
| 159 |
+
group them by InstanceNumber (time), sort by SliceLocation (z), and return
|
| 160 |
+
a NumPy array of shape (time, z, H, W).
|
| 161 |
+
"""
|
| 162 |
+
# 1. Find the real nested folder (skip macOS junk)
|
| 163 |
+
candidates = [
|
| 164 |
+
d for d in os.listdir(base_dir)
|
| 165 |
+
if os.path.isdir(os.path.join(base_dir, d))
|
| 166 |
+
and not d.startswith("._")
|
| 167 |
+
and "__MACOSX" not in d
|
| 168 |
+
]
|
| 169 |
+
if not candidates:
|
| 170 |
+
st.error("No valid data folder found in ZIP.")
|
| 171 |
+
return np.array([])
|
| 172 |
+
nested_base = os.path.join(base_dir, candidates[0])
|
| 173 |
+
|
| 174 |
+
# 2. Recursively collect every file; we'll filter for DICOMs next
|
| 175 |
+
all_paths = glob.glob(os.path.join(nested_base, "**", "*"), recursive=True)
|
| 176 |
+
all_paths = [p for p in all_paths if os.path.isfile(p) and not os.path.basename(p).startswith("._")]
|
| 177 |
+
|
| 178 |
+
# 3. Filter valid DICOMs
|
| 179 |
+
dicom_files = []
|
| 180 |
+
for p in all_paths:
|
| 181 |
+
try:
|
| 182 |
+
ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
|
| 183 |
+
if hasattr(ds, "InstanceNumber"):
|
| 184 |
+
dicom_files.append(p)
|
| 185 |
+
except:
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
st.write(f"🧾 Found {len(dicom_files)} DICOM files.")
|
| 189 |
+
|
| 190 |
+
if not dicom_files:
|
| 191 |
+
st.error("No valid DICOMs found.")
|
| 192 |
+
return np.array([])
|
| 193 |
+
|
| 194 |
+
# 4. Group by InstanceNumber (temporal frames)
|
| 195 |
+
grouped = defaultdict(list)
|
| 196 |
+
for p in dicom_files:
|
| 197 |
+
try:
|
| 198 |
+
ds = dicom.dcmread(p, force=True)
|
| 199 |
+
inst = ds.InstanceNumber
|
| 200 |
+
loc = getattr(ds, "SliceLocation", 0.0)
|
| 201 |
+
grouped[inst].append((loc, ds.pixel_array))
|
| 202 |
+
except:
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
# 5. Build volume up to number_of_scans frames
|
| 206 |
+
vols = []
|
| 207 |
+
for inst in sorted(grouped.keys())[:number_of_scans]:
|
| 208 |
+
slices = grouped[inst]
|
| 209 |
+
# sort along z
|
| 210 |
+
slices.sort(key=lambda x: x[0])
|
| 211 |
+
vols.append([img for _, img in slices])
|
| 212 |
+
|
| 213 |
+
volume = np.array(vols) # shape (T, Z, H, W)
|
| 214 |
+
st.write(f"✅ Found data shape: {volume.shape}")
|
| 215 |
+
return volume
|
| 216 |
+
|
| 217 |
+
def load_cine_any(
|
| 218 |
+
base_dir: str,
|
| 219 |
+
number_of_scans: int = None, # if None, use all detected phases
|
| 220 |
+
private_phase_tag: Tag = Tag(0x0019, 0x10D7),# your private phase tag (if present)
|
| 221 |
+
verbose: bool = True
|
| 222 |
+
):
|
| 223 |
+
"""
|
| 224 |
+
Universal DICOM cine loader (flat or nested folders).
|
| 225 |
+
|
| 226 |
+
Scans recursively from `base_dir`, detects cardiac phases, sorts slices,
|
| 227 |
+
and returns (T, Z, H, W) along with the total number of phases detected.
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
volume: np.ndarray with shape (T, Z, H, W)
|
| 231 |
+
num_phases_detected: int (total phases found in the dataset)
|
| 232 |
+
"""
|
| 233 |
+
def log(msg):
|
| 234 |
+
if verbose:
|
| 235 |
+
try: st.write(msg)
|
| 236 |
+
except Exception: print(msg)
|
| 237 |
+
|
| 238 |
+
if not os.path.isdir(base_dir):
|
| 239 |
+
raise FileNotFoundError(f"No such directory: {base_dir}")
|
| 240 |
+
|
| 241 |
+
# --- Collect candidate files (recursive), skip junk/zips
|
| 242 |
+
candidates = glob.glob(os.path.join(base_dir, "**", "*"), recursive=True)
|
| 243 |
+
candidates = [
|
| 244 |
+
p for p in candidates
|
| 245 |
+
if os.path.isfile(p)
|
| 246 |
+
and "__MACOSX" not in p
|
| 247 |
+
and not os.path.basename(p).startswith("._")
|
| 248 |
+
and not p.lower().endswith(".zip")
|
| 249 |
+
]
|
| 250 |
+
if not candidates:
|
| 251 |
+
log("No files found under the provided directory.")
|
| 252 |
+
return np.array([]), 0
|
| 253 |
+
|
| 254 |
+
# --- Keep only files that parse as DICOM headers
|
| 255 |
+
dicom_files = []
|
| 256 |
+
for p in candidates:
|
| 257 |
+
try:
|
| 258 |
+
_ = dicom.dcmread(p, force=True, stop_before_pixels=True)
|
| 259 |
+
dicom_files.append(p)
|
| 260 |
+
except Exception:
|
| 261 |
+
pass
|
| 262 |
+
|
| 263 |
+
log(f"🧾 Candidate DICOM files: {len(dicom_files)}")
|
| 264 |
+
if not dicom_files:
|
| 265 |
+
log("No valid DICOM files found.")
|
| 266 |
+
return np.array([]), 0
|
| 267 |
+
|
| 268 |
+
# --- NEW: detect flat folder layout (all files in the same directory)
|
| 269 |
+
dicom_dirs = {os.path.dirname(p) for p in dicom_files}
|
| 270 |
+
is_flat = (len(dicom_dirs) == 1)
|
| 271 |
+
|
| 272 |
+
# --- Probe to choose the best phase key
|
| 273 |
+
def _try_get(ds, tag):
|
| 274 |
+
try: return ds[tag].value
|
| 275 |
+
except Exception: return None
|
| 276 |
+
|
| 277 |
+
uniq_priv, uniq_tpi, uniq_inst = set(), set(), set()
|
| 278 |
+
for p in dicom_files[:min(len(dicom_files), 200)]:
|
| 279 |
+
try:
|
| 280 |
+
ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
|
| 281 |
+
v_priv = _try_get(ds, private_phase_tag)
|
| 282 |
+
if v_priv is not None:
|
| 283 |
+
try: uniq_priv.add(int(v_priv))
|
| 284 |
+
except Exception: pass
|
| 285 |
+
if hasattr(ds, "TemporalPositionIdentifier"):
|
| 286 |
+
try: uniq_tpi.add(int(ds.TemporalPositionIdentifier))
|
| 287 |
+
except Exception: pass
|
| 288 |
+
if hasattr(ds, "InstanceNumber"):
|
| 289 |
+
try: uniq_inst.add(int(ds.InstanceNumber))
|
| 290 |
+
except Exception: pass
|
| 291 |
+
except Exception:
|
| 292 |
+
continue
|
| 293 |
+
|
| 294 |
+
if len(uniq_priv) > 1:
|
| 295 |
+
phase_key = ("private", private_phase_tag)
|
| 296 |
+
elif len(uniq_tpi) > 1:
|
| 297 |
+
phase_key = ("tpi", None)
|
| 298 |
+
elif len(uniq_inst) > 1:
|
| 299 |
+
phase_key = ("instance", None)
|
| 300 |
+
else:
|
| 301 |
+
log("Could not determine a phase key (no variation in private/TPI/InstanceNumber).")
|
| 302 |
+
return np.array([]), 0
|
| 303 |
+
|
| 304 |
+
def _get_phase(ds):
|
| 305 |
+
if phase_key[0] == "private":
|
| 306 |
+
v = _try_get(ds, phase_key[1]); return int(v) if v is not None else None
|
| 307 |
+
if phase_key[0] == "tpi":
|
| 308 |
+
return int(getattr(ds, "TemporalPositionIdentifier", None)) \
|
| 309 |
+
if hasattr(ds, "TemporalPositionIdentifier") else None
|
| 310 |
+
if phase_key[0] == "instance":
|
| 311 |
+
return int(getattr(ds, "InstanceNumber", None)) \
|
| 312 |
+
if hasattr(ds, "InstanceNumber") else None
|
| 313 |
+
return None
|
| 314 |
+
|
| 315 |
+
def _get_z(ds):
|
| 316 |
+
z = getattr(ds, "SliceLocation", None)
|
| 317 |
+
if z is None:
|
| 318 |
+
ipp = getattr(ds, "ImagePositionPatient", None)
|
| 319 |
+
if ipp is not None and len(ipp) >= 3:
|
| 320 |
+
try: z = float(ipp[2])
|
| 321 |
+
except Exception: z = 0.0
|
| 322 |
+
else:
|
| 323 |
+
z = 0.0
|
| 324 |
+
return float(z)
|
| 325 |
+
|
| 326 |
+
# --- Group by phase; sort by z
|
| 327 |
+
grouped = defaultdict(list)
|
| 328 |
+
for p in dicom_files:
|
| 329 |
+
try:
|
| 330 |
+
ds = dicom.dcmread(p, force=True)
|
| 331 |
+
ph = _get_phase(ds)
|
| 332 |
+
if ph is None:
|
| 333 |
+
continue
|
| 334 |
+
grouped[int(ph)].append((_get_z(ds), ds.pixel_array))
|
| 335 |
+
except Exception:
|
| 336 |
+
continue
|
| 337 |
+
|
| 338 |
+
if not grouped:
|
| 339 |
+
log("No groups formed (no phase could be read).")
|
| 340 |
+
return np.array([]), 0
|
| 341 |
+
|
| 342 |
+
all_phase_ids = sorted(grouped.keys())
|
| 343 |
+
num_phases_detected = len(all_phase_ids)
|
| 344 |
+
phases_to_use = all_phase_ids if number_of_scans is None else all_phase_ids[:number_of_scans]
|
| 345 |
+
|
| 346 |
+
stacks_T, slice_counts = [], []
|
| 347 |
+
for ph in phases_to_use:
|
| 348 |
+
pairs = grouped[ph]
|
| 349 |
+
if not pairs:
|
| 350 |
+
continue
|
| 351 |
+
pairs.sort(key=lambda x: x[0]) # sort by z
|
| 352 |
+
stack = [img for _, img in pairs] # Z × H × W
|
| 353 |
+
stacks_T.append(stack)
|
| 354 |
+
slice_counts.append(len(stack))
|
| 355 |
+
|
| 356 |
+
if not stacks_T:
|
| 357 |
+
log("Groups existed but none had readable slices.")
|
| 358 |
+
return np.array([]), num_phases_detected
|
| 359 |
+
|
| 360 |
+
# Harmonize Z across phases (trim to the most common Z)
|
| 361 |
+
if len(set(slice_counts)) > 1:
|
| 362 |
+
common_Z = Counter(slice_counts).most_common(1)[0][0]
|
| 363 |
+
stacks_T = [s[:common_Z] for s in stacks_T if len(s) >= common_Z]
|
| 364 |
+
if not stacks_T:
|
| 365 |
+
log("All phases had inconsistent slice counts.")
|
| 366 |
+
return np.array([]), num_phases_detected
|
| 367 |
+
|
| 368 |
+
volume = np.array(stacks_T) # (T, Z, H, W)
|
| 369 |
+
|
| 370 |
+
# --- NEW: flip slice order (Z) if data came from a flat single folder
|
| 371 |
+
if is_flat:
|
| 372 |
+
volume = volume[:, ::-1, :, :]
|
| 373 |
+
|
| 374 |
+
log(f"✅ Final volume shape: {volume[0,...].shape} , Phases detected = {num_phases_detected}")
|
| 375 |
+
return volume, num_phases_detected
|
| 376 |
+
|
| 377 |
+
def load_data_samples_from_flat_folder(
|
| 378 |
+
base_dir: str,
|
| 379 |
+
number_of_scans: int = 32,
|
| 380 |
+
frame_tag: Tag = Tag(0x0019, 0x10D7) # private phase tag (adjust if needed)
|
| 381 |
+
) -> np.ndarray:
|
| 382 |
+
"""
|
| 383 |
+
Robust loader when all DICOMs are under one folder (possibly nested).
|
| 384 |
+
- Steps into the single subfolder if present (ignores upload.zip, macOS junk).
|
| 385 |
+
- Recursively finds DICOMs (even without .dcm extension).
|
| 386 |
+
- Groups by phase from `frame_tag` or fallback (0020,0100).
|
| 387 |
+
- Sorts by SliceLocation/IPPs and returns (Z, T, H, W).
|
| 388 |
+
"""
|
| 389 |
+
if not os.path.isdir(base_dir):
|
| 390 |
+
raise FileNotFoundError(f"No such directory: {base_dir}")
|
| 391 |
+
|
| 392 |
+
# --- Step 1: if there’s exactly one subfolder (plus upload.zip), dive into it
|
| 393 |
+
entries = [e for e in os.listdir(base_dir) if not e.startswith("._")]
|
| 394 |
+
subdirs = [os.path.join(base_dir, e) for e in entries
|
| 395 |
+
if os.path.isdir(os.path.join(base_dir, e)) and "__MACOSX" not in e]
|
| 396 |
+
# If precisely one subdir, prefer that as root; otherwise use base_dir as-is
|
| 397 |
+
root = subdirs[0] if len(subdirs) == 1 else base_dir
|
| 398 |
+
|
| 399 |
+
# --- Step 2: recursively collect candidate files (skip zips and junk)
|
| 400 |
+
candidates = glob.glob(os.path.join(root, "**", "*"), recursive=True)
|
| 401 |
+
candidates = [
|
| 402 |
+
p for p in candidates
|
| 403 |
+
if os.path.isfile(p)
|
| 404 |
+
and not os.path.basename(p).startswith("._")
|
| 405 |
+
and "__MACOSX" not in p
|
| 406 |
+
and not p.lower().endswith(".zip")
|
| 407 |
+
]
|
| 408 |
+
if not candidates:
|
| 409 |
+
st.error("No files found under the provided directory.")
|
| 410 |
+
return np.array([])
|
| 411 |
+
|
| 412 |
+
# --- Step 3: keep only files that parse as DICOM headers
|
| 413 |
+
dicom_files = []
|
| 414 |
+
for p in candidates:
|
| 415 |
+
try:
|
| 416 |
+
ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
|
| 417 |
+
dicom_files.append(p)
|
| 418 |
+
except Exception:
|
| 419 |
+
pass
|
| 420 |
+
|
| 421 |
+
st.write(f"🧾 Candidate DICOM files: {len(dicom_files)}")
|
| 422 |
+
if not dicom_files:
|
| 423 |
+
st.error("No valid DICOM files found.")
|
| 424 |
+
return np.array([])
|
| 425 |
+
|
| 426 |
+
# --- Helper: determine phase index
|
| 427 |
+
def _get_phase(ds):
|
| 428 |
+
# Preferred: private tag (your dataset)
|
| 429 |
+
if frame_tag in ds:
|
| 430 |
+
try:
|
| 431 |
+
return int(ds[frame_tag].value)
|
| 432 |
+
except Exception:
|
| 433 |
+
pass
|
| 434 |
+
# Fallback: standard TemporalPositionIdentifier (0020,0100)
|
| 435 |
+
if hasattr(ds, "TemporalPositionIdentifier"):
|
| 436 |
+
try:
|
| 437 |
+
return int(ds.TemporalPositionIdentifier)
|
| 438 |
+
except Exception:
|
| 439 |
+
pass
|
| 440 |
+
# Last resort: AcquisitionNumber (not always phase, but useful fallback)
|
| 441 |
+
if hasattr(ds, "AcquisitionNumber"):
|
| 442 |
+
try:
|
| 443 |
+
return int(ds.AcquisitionNumber)
|
| 444 |
+
except Exception:
|
| 445 |
+
pass
|
| 446 |
+
return None
|
| 447 |
+
|
| 448 |
+
# --- Step 4: group by phase; sort by z
|
| 449 |
+
grouped = defaultdict(list)
|
| 450 |
+
phase_missing = 0
|
| 451 |
+
for p in dicom_files:
|
| 452 |
+
try:
|
| 453 |
+
ds = dicom.dcmread(p, force=True)
|
| 454 |
+
phase = _get_phase(ds)
|
| 455 |
+
if phase is None:
|
| 456 |
+
phase_missing += 1
|
| 457 |
+
continue
|
| 458 |
+
# z-order: SliceLocation if present else IPP[2] else 0
|
| 459 |
+
z = getattr(ds, "SliceLocation", None)
|
| 460 |
+
if z is None:
|
| 461 |
+
ipp = getattr(ds, "ImagePositionPatient", None)
|
| 462 |
+
if ipp is not None and len(ipp) >= 3:
|
| 463 |
+
z = float(ipp[2])
|
| 464 |
+
else:
|
| 465 |
+
z = 0.0
|
| 466 |
+
img = ds.pixel_array
|
| 467 |
+
grouped[int(phase)].append((z, img))
|
| 468 |
+
except Exception:
|
| 469 |
+
continue
|
| 470 |
+
|
| 471 |
+
if not grouped:
|
| 472 |
+
st.error(
|
| 473 |
+
"Could not determine a phase tag for any files. "
|
| 474 |
+
"Check for (0019,10D7) or (0020,0100) in your dataset."
|
| 475 |
+
)
|
| 476 |
+
st.write(f"Files missing phase: {phase_missing} / {len(dicom_files)}")
|
| 477 |
+
# Optional: show attributes of one file to discover tags
|
| 478 |
+
try:
|
| 479 |
+
ds0 = dicom.dcmread(dicom_files[0], force=True, stop_before_pixels=True)
|
| 480 |
+
st.write("Sample DICOM attributes:", ds0.dir())
|
| 481 |
+
except Exception:
|
| 482 |
+
pass
|
| 483 |
+
return np.array([])
|
| 484 |
+
|
| 485 |
+
# keep up to number_of_scans phases
|
| 486 |
+
phases = sorted(grouped.keys())[:number_of_scans]
|
| 487 |
+
|
| 488 |
+
stacks_T = []
|
| 489 |
+
slice_counts = []
|
| 490 |
+
for ph in phases:
|
| 491 |
+
pairs = grouped[ph]
|
| 492 |
+
if not pairs:
|
| 493 |
+
continue
|
| 494 |
+
pairs.sort(key=lambda x: x[0]) # sort by z
|
| 495 |
+
stack = [img for _, img in pairs] # Z × H × W
|
| 496 |
+
stacks_T.append(stack)
|
| 497 |
+
slice_counts.append(len(stack))
|
| 498 |
+
|
| 499 |
+
if not stacks_T:
|
| 500 |
+
st.error("No phases contained readable slices.")
|
| 501 |
+
return np.array([])
|
| 502 |
+
|
| 503 |
+
# Harmonize Z across phases (trim to the most common slice count)
|
| 504 |
+
if len(set(slice_counts)) > 1:
|
| 505 |
+
common_Z = Counter(slice_counts).most_common(1)[0][0]
|
| 506 |
+
stacks_T = [s[:common_Z] for s in stacks_T if len(s) >= common_Z]
|
| 507 |
+
if not stacks_T:
|
| 508 |
+
st.error("All phases had inconsistent slice counts.")
|
| 509 |
+
return np.array([])
|
| 510 |
+
|
| 511 |
+
vol = np.array(stacks_T) # (T, Z, H, W)
|
| 512 |
+
st.write(f"✅ Final volume shape: {vol.shape} (T, S, H, W)")
|
| 513 |
+
return vol
|
| 514 |
+
|
| 515 |
+
def extract_zip(zip_path, extract_to):
|
| 516 |
+
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
| 517 |
+
# Filter out __MACOSX and dotfiles
|
| 518 |
+
valid_files = [
|
| 519 |
+
f for f in zip_ref.namelist()
|
| 520 |
+
if "__MACOSX" not in f and not os.path.basename(f).startswith("._")
|
| 521 |
+
]
|
| 522 |
+
zip_ref.extractall(extract_to, members=valid_files)
|
| 523 |
+
|
| 524 |
+
def make_gif(path, timepoints, axis=-1, slice=60,frame_rate = 30):
|
| 525 |
+
# 1. Locate all .npy files
|
| 526 |
+
all_files = glob.glob(os.path.join(path, "*.npy"))
|
| 527 |
+
print("Found NPY files:", all_files)
|
| 528 |
+
|
| 529 |
+
# 2. Group them by prefix
|
| 530 |
+
scan_keys = ['raw', 'debanded', 'resp_cor', '3D_cine']
|
| 531 |
+
groups = {k: [] for k in scan_keys}
|
| 532 |
+
|
| 533 |
+
pattern = re.compile(r'(?P<prefix>raw|debanded|resp_cor|3D_cine)_(?P<index>\d+)\.npy')
|
| 534 |
+
|
| 535 |
+
for p in all_files:
|
| 536 |
+
fn = os.path.basename(p)
|
| 537 |
+
match = pattern.match(fn)
|
| 538 |
+
if match:
|
| 539 |
+
prefix = match.group("prefix")
|
| 540 |
+
t_idx = int(match.group("index"))
|
| 541 |
+
groups[prefix].append((t_idx, p))
|
| 542 |
+
|
| 543 |
+
# 3. Sanity check: do all groups exist & have equal lengths?
|
| 544 |
+
Ts = [len(v) for v in groups.values()]
|
| 545 |
+
print("Group counts:", Ts)
|
| 546 |
+
if not all(T == timepoints for T in Ts):
|
| 547 |
+
raise ValueError(f"Mismatch in timepoints across groups. Expected {timepoints}, got {Ts}")
|
| 548 |
+
|
| 549 |
+
for k in groups:
|
| 550 |
+
groups[k].sort(key=lambda x: x[0]) # sort by t_idx
|
| 551 |
+
|
| 552 |
+
# 4. Determine normalization range per group
|
| 553 |
+
stats = {}
|
| 554 |
+
for k in scan_keys:
|
| 555 |
+
mins, maxs = [], []
|
| 556 |
+
for _, p in groups[k]:
|
| 557 |
+
vol = np.load(p)
|
| 558 |
+
if axis == -1:
|
| 559 |
+
slice_ = vol[:, :, slice]
|
| 560 |
+
else:
|
| 561 |
+
slice_ = vol[:, slice, :] if axis == 1 else vol[slice, :, :]
|
| 562 |
+
mins.append(slice_.min())
|
| 563 |
+
maxs.append(slice_.max())
|
| 564 |
+
stats[k] = (min(mins), max(maxs))
|
| 565 |
+
|
| 566 |
+
# 5. Create frames
|
| 567 |
+
frames = []
|
| 568 |
+
for t in range(timepoints):
|
| 569 |
+
imgs_t = []
|
| 570 |
+
for k in scan_keys:
|
| 571 |
+
_, p = groups[k][t]
|
| 572 |
+
vol = np.load(p).astype(np.float32)
|
| 573 |
+
|
| 574 |
+
if axis == 2:
|
| 575 |
+
img = vol[::-1, :, slice]
|
| 576 |
+
elif axis == 1:
|
| 577 |
+
img = vol[:, slice, :]
|
| 578 |
+
elif axis == 0:
|
| 579 |
+
img = vol[slice, :, :]
|
| 580 |
+
img = np.transpose(img[:,::-1])
|
| 581 |
+
|
| 582 |
+
mn, mx = stats[k]
|
| 583 |
+
img = np.clip(img, mn, mx)
|
| 584 |
+
img8 = ((img - mn) / (mx - mn) * 255).astype(np.uint8)
|
| 585 |
+
img8 = img8.T[:, ::-1] # flip + transpose
|
| 586 |
+
imgs_t.append(img8)
|
| 587 |
+
|
| 588 |
+
# Stitch side-by-side
|
| 589 |
+
composite = np.concatenate(imgs_t, axis=1)
|
| 590 |
+
resized = Image.fromarray(composite).resize((composite.shape[1]*3, composite.shape[0]*3), Image.NEAREST)
|
| 591 |
+
frames.append(np.array(resized))
|
| 592 |
+
|
| 593 |
+
# 6. Save and return
|
| 594 |
+
out_path = os.path.join(path, f"temp.gif")
|
| 595 |
+
imageio.mimsave(out_path, frames, duration=1000/frame_rate,loop =0)
|
| 596 |
+
|
| 597 |
+
return out_path
|
| 598 |
+
|
| 599 |
+
def to_dicom(cardiac_frames, patient_number):
|
| 600 |
+
|
| 601 |
+
filename="./utils/dicom_headerfile.dcm"
|
| 602 |
+
for file in glob.glob(filename):
|
| 603 |
+
ds = dicom.read_file(file)
|
| 604 |
+
|
| 605 |
+
for i in range(cardiac_frames):
|
| 606 |
+
|
| 607 |
+
volume = np.load(f'./out_dir/3D_cine_{i}.npy')
|
| 608 |
+
print(f"Volume: {i}")
|
| 609 |
+
for j in range(volume.shape[0]):
|
| 610 |
+
|
| 611 |
+
PixelData = volume[j,:,:]
|
| 612 |
+
|
| 613 |
+
PixelData = (PixelData * 255).astype(np.uint16)
|
| 614 |
+
|
| 615 |
+
Dicoms = ds.copy()
|
| 616 |
+
|
| 617 |
+
Dicoms.InstanceNumber = j
|
| 618 |
+
Dicoms.PatientID = 'Mark'
|
| 619 |
+
Dicoms.PatientName = 'Mark'
|
| 620 |
+
Dicoms.StudyDescription = '3D Cine'
|
| 621 |
+
Dicoms.SeriesDescription = 'HR'
|
| 622 |
+
Dicoms.StudyInstanceUID = '1.3.12.2.1107.5.2.41.169828.3001002301121546102500000000' + str(patient_number)
|
| 623 |
+
|
| 624 |
+
Dicoms.SliceThickness = 1.5
|
| 625 |
+
Dicoms.Rows = 256
|
| 626 |
+
Dicoms.Columns = 128
|
| 627 |
+
Dicoms.AcquisitionMatrix = [0, 256,128, 0]
|
| 628 |
+
Dicoms.ImageOrientationPatient = [1.0 ,0.0, 0.0, 0.0, 0.0, -1.0]
|
| 629 |
+
Dicoms.SliceLocation = -100.0 + ((j-1) * Dicoms.SliceThickness)
|
| 630 |
+
Dicoms.SamplesPerPixel = 1
|
| 631 |
+
|
| 632 |
+
Dicoms.BitsAllocated = 16
|
| 633 |
+
Dicoms.BitsStored = 12
|
| 634 |
+
Dicoms.HighBit = 11
|
| 635 |
+
Dicoms.PixelRepresentation = 0
|
| 636 |
+
Dicoms.AcquisitionNumber = i
|
| 637 |
+
Dicoms.SeriesNumber = i
|
| 638 |
+
Dicoms.PixelSpacing = [1.5,1.5]
|
| 639 |
+
|
| 640 |
+
Dicoms.SmallestImagePixelValue = 0
|
| 641 |
+
Dicoms.LargestImagePixelValue = 255
|
| 642 |
+
|
| 643 |
+
Dicoms.PixelData = PixelData.tobytes()
|
| 644 |
+
|
| 645 |
+
Dicoms.SeriesInstanceUID = '1.3.12.2.1107.5.2.41.169828.300100230112154610250000001' + str(i)
|
| 646 |
+
Dicoms.SOPInstanceUID = dicom.uid.generate_uid()
|
| 647 |
+
Dicoms.AcquisitionTime = str(i)
|
| 648 |
+
Dicoms.SeriesTime = str(i)
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
dicom.filewriter.dcmwrite(filename=f'./out_dicoms/MARK_PATIENT_{patient_number}_VOL_{i}_SLICE_{j}.dcm',dataset=Dicoms)
|
| 652 |
+
|
| 653 |
+
return 42
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def zip_dir_to_memory(dir_path: str) -> io.BytesIO:
|
| 657 |
+
buf = io.BytesIO()
|
| 658 |
+
with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
|
| 659 |
+
for root, _, files in os.walk(dir_path):
|
| 660 |
+
for f in files:
|
| 661 |
+
full = os.path.join(root, f)
|
| 662 |
+
arc = os.path.relpath(full, dir_path) # keep relative paths in zip
|
| 663 |
+
zf.write(full, arc)
|
| 664 |
+
buf.seek(0)
|
| 665 |
+
return buf
|
utils/unet3plusnew.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python file containing unet3plus code used to train segmentation model
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import numpy as np
|
| 5 |
+
import utils.layer_util
|
| 6 |
+
tf.random.set_seed(0)
|
| 7 |
+
|
| 8 |
+
class unet3plus:
|
| 9 |
+
def __init__(self,
|
| 10 |
+
inputs,
|
| 11 |
+
filters = [32,64,128,256,512],
|
| 12 |
+
rank = 2,
|
| 13 |
+
out_channels = 3,
|
| 14 |
+
kernel_initializer=tf.keras.initializers.HeNormal(seed=0),
|
| 15 |
+
bias_initializer=tf.keras.initializers.Zeros(),
|
| 16 |
+
kernel_regularizer=None,
|
| 17 |
+
bias_regularizer=None,
|
| 18 |
+
add_dropout = False,
|
| 19 |
+
padding = 'same',
|
| 20 |
+
dropout_rate = 0.5,
|
| 21 |
+
kernel_size = 3,
|
| 22 |
+
out_kernel_size = 3,
|
| 23 |
+
pool_size = 2,
|
| 24 |
+
encoder_block_depth = 2,
|
| 25 |
+
decoder_block_depth = 1,
|
| 26 |
+
batch_norm = True,
|
| 27 |
+
activation = 'relu',
|
| 28 |
+
out_activation = None,
|
| 29 |
+
skip_batch_norm = True,
|
| 30 |
+
skip_type = 'encoder',
|
| 31 |
+
CGM = False,
|
| 32 |
+
deep_supervision = True):
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
self.inputs = inputs
|
| 36 |
+
self.filters = filters
|
| 37 |
+
self.scales = len(filters)
|
| 38 |
+
self.rank = rank
|
| 39 |
+
self.out_channels = out_channels
|
| 40 |
+
self.encoder_block_depth = encoder_block_depth
|
| 41 |
+
self.decoder_block_depth = decoder_block_depth
|
| 42 |
+
self.kernel_size = kernel_size
|
| 43 |
+
self.add_dropout = add_dropout
|
| 44 |
+
self.dropout_rate = dropout_rate
|
| 45 |
+
self.skip_type = skip_type
|
| 46 |
+
self.skip_batch_norm = skip_batch_norm
|
| 47 |
+
self.batch_norm = batch_norm
|
| 48 |
+
if isinstance(activation, str):
|
| 49 |
+
self.activation = tf.keras.activations.get(activation)
|
| 50 |
+
else:
|
| 51 |
+
self.activation = activation
|
| 52 |
+
if isinstance(out_activation, str):
|
| 53 |
+
self.out_activation = tf.keras.activations.get(out_activation)
|
| 54 |
+
else:
|
| 55 |
+
self.out_activation = out_activation
|
| 56 |
+
# Assign pool size
|
| 57 |
+
if isinstance(pool_size,tuple):
|
| 58 |
+
self.pool_size = pool_size
|
| 59 |
+
else:
|
| 60 |
+
self.pool_size = tuple([pool_size for _ in range(rank)])
|
| 61 |
+
if isinstance(kernel_size,tuple):
|
| 62 |
+
self.kernel_size = kernel_size
|
| 63 |
+
else:
|
| 64 |
+
self.kernel_size = tuple([kernel_size for _ in range(rank)])
|
| 65 |
+
if isinstance(out_kernel_size,tuple):
|
| 66 |
+
self.out_kernel_size = out_kernel_size
|
| 67 |
+
else:
|
| 68 |
+
self.out_kernel_size = tuple([out_kernel_size for _ in range(rank)])
|
| 69 |
+
self.CGM = CGM
|
| 70 |
+
self.deep_supervision = deep_supervision
|
| 71 |
+
self.conv_config = dict(kernel_size = self.kernel_size,
|
| 72 |
+
padding = padding,
|
| 73 |
+
kernel_initializer = kernel_initializer,
|
| 74 |
+
bias_initializer = bias_initializer,
|
| 75 |
+
kernel_regularizer = kernel_regularizer,
|
| 76 |
+
bias_regularizer = bias_regularizer)
|
| 77 |
+
self.out_conv_config = dict(kernel_size = out_kernel_size,
|
| 78 |
+
padding = padding,
|
| 79 |
+
kernel_initializer = kernel_initializer,
|
| 80 |
+
bias_initializer = bias_initializer,
|
| 81 |
+
kernel_regularizer = kernel_regularizer,
|
| 82 |
+
bias_regularizer = bias_regularizer)
|
| 83 |
+
|
| 84 |
+
def aggregate(self, scale_list, scale):
|
| 85 |
+
X = tf.keras.layers.Concatenate(name = f'D{scale}_input', axis = -1)(scale_list)
|
| 86 |
+
X = self.conv_block(X, self.filters[0] * self.scales, num_stacks = self.decoder_block_depth, layer_type = 'Decoder', scale=scale)
|
| 87 |
+
return X
|
| 88 |
+
|
| 89 |
+
def deep_sup(self, inputs, scale):
|
| 90 |
+
conv = layer_util.get_nd_layer('Conv', self.rank)
|
| 91 |
+
upsamp = layer_util.get_nd_layer('UpSampling', self.rank)
|
| 92 |
+
size = tuple(np.array(self.pool_size)** (abs(scale-1)))
|
| 93 |
+
if self.rank == 2:
|
| 94 |
+
upsamp_config = dict(size=size, interpolation='bilinear')
|
| 95 |
+
else:
|
| 96 |
+
upsamp_config = dict(size=size)
|
| 97 |
+
X = inputs
|
| 98 |
+
X = conv(self.out_channels, activation = None, **self.out_conv_config, name = f'deepsup_conv_{scale}')(X)
|
| 99 |
+
if scale != 1:
|
| 100 |
+
X = upsamp(**upsamp_config, name = f'deepsup_upsamp_{scale}')(X)
|
| 101 |
+
#X = tf.keras.layers.Activation(activation = 'sigmoid' if self.out_channels == 1 else 'softmax', name = f'deepsup_activation_{scale}')(X)
|
| 102 |
+
X =self.out_activation(X)
|
| 103 |
+
return X
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def full_scale(self, inputs, to_layer, from_layer):
|
| 108 |
+
conv = layer_util.get_nd_layer('Conv', self.rank)
|
| 109 |
+
layer_diff = from_layer - to_layer
|
| 110 |
+
size = tuple(np.array(self.pool_size)** (abs(layer_diff)))
|
| 111 |
+
maxpool = layer_util.get_nd_layer('MaxPool', self.rank)
|
| 112 |
+
upsamp = layer_util.get_nd_layer('UpSampling', self.rank)
|
| 113 |
+
if self.rank == 2:
|
| 114 |
+
upsamp_config = dict(size=size, interpolation='bilinear')
|
| 115 |
+
else:
|
| 116 |
+
upsamp_config = dict(size=size)
|
| 117 |
+
|
| 118 |
+
X = inputs
|
| 119 |
+
if to_layer < from_layer:
|
| 120 |
+
X = upsamp(**upsamp_config, name = f'Skip_Upsample_{from_layer}_{to_layer}')(X)
|
| 121 |
+
elif to_layer > from_layer:
|
| 122 |
+
X = maxpool(pool_size = size, name = f'Skip_Maxpool_{from_layer}_{to_layer}')(X)
|
| 123 |
+
|
| 124 |
+
if self.skip_batch_norm:
|
| 125 |
+
X = self.conv_block(X, self.filters[0], num_stacks = self.decoder_block_depth, layer_type ='Skip', scale = f'{from_layer}_{to_layer}')
|
| 126 |
+
else:
|
| 127 |
+
X = conv(self.filters[0],**self.conv_config, name = f'Skip_Conv_{from_layer}_{to_layer}')(X)
|
| 128 |
+
|
| 129 |
+
return X
|
| 130 |
+
|
| 131 |
+
def conv_block(self, inputs, filters, num_stacks,layer_type, scale):
|
| 132 |
+
conv = layer_util.get_nd_layer('Conv', self.rank)
|
| 133 |
+
X = inputs
|
| 134 |
+
for i in range(num_stacks):
|
| 135 |
+
X = conv(filters, **self.conv_config, name = f'{layer_type}{scale}_Conv_{i+1}')(X)
|
| 136 |
+
if self.batch_norm:
|
| 137 |
+
X = tf.keras.layers.BatchNormalization(axis=-1, name = f'{layer_type}{scale}_BN_{i+1}')(X)
|
| 138 |
+
#X = tf.keras.layers.LeakyReLU(name = f'{layer_type}{scale}_Activation_{i+1}')(X)
|
| 139 |
+
X = self.activation(X)
|
| 140 |
+
return X
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def encode(self, inputs, scale, num_stacks):
|
| 144 |
+
maxpool = layer_util.get_nd_layer('MaxPool', self.rank)
|
| 145 |
+
scale -= 1 # python index
|
| 146 |
+
filters = self.filters[scale]
|
| 147 |
+
X = inputs
|
| 148 |
+
if scale != 0:
|
| 149 |
+
X = maxpool(pool_size=self.pool_size, name = f'encoding_{scale}_maxpool')(X)
|
| 150 |
+
X = self.conv_block(X, filters, num_stacks, layer_type = 'Encoder', scale = scale+1)
|
| 151 |
+
if scale == (self.scales-1) and self.add_dropout:
|
| 152 |
+
X = tf.keras.layers.Dropout(rate = self.dropout_rate, name = f'Encoder{scale+1}_dropout')(X)
|
| 153 |
+
return X
|
| 154 |
+
|
| 155 |
+
def outputs(self):
|
| 156 |
+
XE = [self.inputs]
|
| 157 |
+
for i in range(self.scales):
|
| 158 |
+
XE.append(self.encode(XE[i], scale = i+1, num_stacks = self.encoder_block_depth))
|
| 159 |
+
XD = [XE[-1]]
|
| 160 |
+
if self.skip_type == 'encoder':
|
| 161 |
+
for decoder_level in range(self.scales-1,0,-1):
|
| 162 |
+
input_contributions = []
|
| 163 |
+
for unet_level in range(1,self.scales+1):
|
| 164 |
+
if unet_level == decoder_level+1:
|
| 165 |
+
input_contributions.append(self.full_scale(XD[-1], decoder_level, unet_level))
|
| 166 |
+
else:
|
| 167 |
+
input_contributions.append(self.full_scale(XE[unet_level], decoder_level, unet_level))
|
| 168 |
+
XD.append(self.aggregate(input_contributions,decoder_level))
|
| 169 |
+
elif self.skip_type == 'decoder':
|
| 170 |
+
for decoder_level in range(self.scales-1,0,-1):
|
| 171 |
+
skip_contributions = []
|
| 172 |
+
# Append skips from encoder
|
| 173 |
+
for encoder_level in range(1,decoder_level+1):
|
| 174 |
+
skip_contributions.append(self.full_scale(XE[encoder_level], decoder_level, encoder_level))
|
| 175 |
+
# Append skips from decoder
|
| 176 |
+
for i in range(len(XD)-1,-1,-1):
|
| 177 |
+
skip_contributions.append(self.full_scale(XD[i], decoder_level, (self.scales-i)))
|
| 178 |
+
XD.append(self.aggregate(skip_contributions,decoder_level))
|
| 179 |
+
else:
|
| 180 |
+
raise ValueError(f"Invalid skip_type")
|
| 181 |
+
if self.deep_supervision == True:
|
| 182 |
+
XD = [self.deep_sup(xd, self.scales-i) for i,xd in enumerate(XD)]
|
| 183 |
+
return XD
|
| 184 |
+
else:
|
| 185 |
+
XD[-1] = self.deep_sup(XD[-1],1)
|
| 186 |
+
return XD[-1]
|