MarkWrobel committed on
Commit
f8317f9
·
verified ·
1 Parent(s): 9549f5f

Upload 16 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ utils/dicom_headerfile.dcm filter=lfs diff=lfs merge=lfs -text
.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [server]
2
+ maxUploadSize = 2048 # MB
3
+ maxMessageSize = 2048 # MB
Dockerfile CHANGED
@@ -1,21 +1,33 @@
1
- FROM python:3.9-slim
2
 
3
- WORKDIR /app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- software-properties-common \
9
- git \
10
- && rm -rf /var/lib/apt/lists/*
11
 
12
- COPY requirements.txt ./
13
- COPY src/ ./src/
 
 
14
 
15
- RUN pip3 install -r requirements.txt
16
-
17
- EXPOSE 8501
18
-
19
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
-
21
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
 
 
1
+ # See here for image contents: https://github.com/microsoft/vscode-dev-containers/tree/v0.245.0/containers/python-3/.devcontainer/base.Dockerfile
2
 
3
+ # [Choice] Python version (use -bullseye variants on local arm64/Apple Silicon): 3, 3.10, 3.9, 3.8, 3.7, 3.6, 3-bullseye, 3.10-bullseye, 3.9-bullseye, 3.8-bullseye, 3.7-bullseye, 3.6-bullseye, 3-buster, 3.10-buster, 3.9-buster, 3.8-buster, 3.7-buster, 3.6-buster
4
+ FROM tensorflow/tensorflow:2.9.1-gpu
5
+ RUN apt-get install libopenexr-dev -y
6
+ RUN pip install tensorflow-mri
7
+ RUN pip install tqdm
8
+ RUN pip install h5py
9
+ RUN pip install tensorflow-addons
10
+ RUN pip install scikit-learn
11
+ RUN pip install scikit-image
12
+ RUN pip install neptune-client
13
+ RUN pip install matplotlib
14
+ RUN pip install scipy
15
+ RUN pip install pydicom
16
+ RUN pip install streamlit
17
+ RUN pip install protobuf==3.20.*
18
 
 
 
 
 
 
 
19
 
20
+ # Create non-root user.
21
+ ARG USERNAME=vscode
22
+ ARG USER_UID=1003
23
+ ARG USER_GID=$USER_UID
24
 
25
+ RUN groupadd --gid $USER_GID $USERNAME && \
26
+ useradd --uid $USER_UID --gid $USER_GID -m $USERNAME && \
27
+ # Add user to sudoers.
28
+ apt-get update && \
29
+ apt-get install -y sudo && \
30
+ echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME && \
31
+ chmod 0440 /etc/sudoers.d/$USERNAME && \
32
+ # Change default shell to bash.
33
+ usermod --shell /bin/bash $USERNAME
README.md CHANGED
@@ -1,19 +1,20 @@
 
1
  ---
2
- title: 3DCine
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: Streamlit template space
12
  ---
13
 
14
- # Welcome to Streamlit!
15
-
16
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
 
 
 
 
17
 
18
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
19
- forums](https://discuss.streamlit.io).
 
1
+
2
  ---
3
+ title: 3D Cine MRI (Streamlit)
4
+ emoji: 🧠
5
+ colorFrom: indigo
6
+ colorTo: pink
7
+ sdk: streamlit
8
+ app_file: app.py
9
+ pinned: true
 
 
 
10
  ---
11
 
12
+ ## Setup (Hugging Face Spaces — free CPU)
13
+ 1. Create a Space with **SDK: Streamlit**, hardware **CPU Basic**.
14
+ 2. In **Settings Variables**, add:
15
+ - `HF_HOME=/data/.huggingface` (speeds up caching)
16
+ - `MODEL_REPO=your-username/your-model-repo` (where your models live on the Hub)
17
+ - *(optional)* `MODEL_SUBDIR=subfolder-inside-repo`
18
+ - *(optional)* `PERSIST_BASE=/data`
19
 
20
+ Your models will be downloaded at runtime into persistent storage and reused after restarts.
 
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.process_utils import *
2
+ import skimage
3
+ import streamlit as st
4
+ import zipfile
5
+ import tempfile
6
+ import os
7
+ import shutil
8
+
9
+ import os
10
+ from huggingface_hub import snapshot_download
11
+
12
+ MODEL_REPO = os.getenv("MODEL_REPO") # e.g. "username/3d-cine-models"
13
+ MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "") # optional subfolder in the repo
14
+ PERSIST_BASE = os.getenv("PERSIST_BASE", "/data") # HF Spaces persistent storage
15
+
16
def get_models_base():
    """Return the directory holding the model weights.

    If the MODEL_REPO env var is set, the repo is snapshot-downloaded from the
    Hugging Face Hub into persistent storage (so restarts reuse the cache) and
    the (optional) MODEL_SUBDIR inside it is returned. Otherwise a local
    fallback folder under PERSIST_BASE is created and returned.

    Returns:
        str: absolute or relative path to the models directory (guaranteed to exist).
    """
    # Cache models inside persistent storage to avoid re-downloads.
    os.makedirs(PERSIST_BASE, exist_ok=True)
    if MODEL_REPO:
        repo_dir = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir=os.path.join(PERSIST_BASE, "hf_models"), local_dir_use_symlinks=False)
        base = os.path.join(repo_dir, MODEL_SUBDIR) if MODEL_SUBDIR else repo_dir
    else:
        # BUG FIX: the original referenced the module global MODELS_BASE here,
        # but MODELS_BASE is only assigned *from the result of this function*,
        # so the fallback branch raised NameError. Use a literal folder name.
        base = os.path.join(PERSIST_BASE, "models")
    os.makedirs(base, exist_ok=True)
    return base
27
+
28
+ MODELS_BASE = get_models_base()
29
+
30
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
31
+
32
+ if "initialized" not in st.session_state:
33
+ for d in ("./out_dir", "./out_dicoms"):
34
+ if os.path.exists(d):
35
+ shutil.rmtree(d)
36
+ os.makedirs(d, exist_ok=True)
37
+ st.session_state.initialized = True
38
+
39
+ # --- Session state defaults ---
40
+ if "volume" not in st.session_state:
41
+ st.session_state.volume = None
42
+ if "data_processed" not in st.session_state:
43
+ st.session_state.data_processed = False
44
+ if "gif_ready" not in st.session_state:
45
+ st.session_state.gif_ready = False
46
+ if "dicom_create" not in st.session_state:
47
+ st.session_state.dicom_create = False
48
+ if "want_gif" not in st.session_state:
49
+ st.session_state.want_gif = True
50
+ if "num_phases" not in st.session_state:
51
+ st.session_state.num_phases = None
52
+
53
+ # --- Title ---
54
+ st.title("3D Cine")
55
+
56
+ # --- Upload ---
57
+ st.header("Data Upload")
58
+ uploaded_zip = st.file_uploader("Upload ZIP file of MRI folders", type="zip")
59
+
60
+ _ = st.toggle("Generate a GIF preview after processing", key="want_gif")
61
+
62
+ if uploaded_zip is not None:
63
+
64
+ if st.button("Process Data"):
65
+ with st.spinner("Processing ZIP..."):
66
+ temp_dir = tempfile.mkdtemp()
67
+ zip_path = os.path.join(temp_dir, "upload.zip")
68
+ with open(zip_path, "wb") as f:
69
+ f.write(uploaded_zip.read())
70
+
71
+ extract_zip(zip_path, temp_dir)
72
+ st.session_state.volume, st.session_state.num_phases = load_cine_any(temp_dir, number_of_scans=None)
73
+ num_phases = st.session_state.num_phases
74
+ if st.session_state.volume is None or len(st.session_state.volume) == 0:
75
+ st.error("Failed to load volume.")
76
+ else:
77
+ with st.spinner("Cropping..."):
78
+ time_steps = num_phases
79
+ sag_vols = np.array(st.session_state.volume)
80
+
81
+ if sag_vols.shape[1] !=28:
82
+ diff = sag_vols.shape[1] -28
83
+ sag_vols = sag_vols[:,diff:,:,:]
84
+
85
+ if sag_vols.shape[2] ==512:
86
+ sag_vols = skimage.transform.rescale(sag_vols,(1,1,0.5,0.5),order =3, anti_aliasing=True)
87
+
88
+ sag_vols_cropped = []
89
+ for j in range(time_steps):
90
+ sag_cropped = []
91
+ for i in range(sag_vols.shape[1]):
92
+ sag_cropped.append(resize(sag_vols[j,i,:,:], 256, 128))
93
+ sag_cropped = np.dstack(sag_cropped)
94
+ sag_cropped = np.swapaxes(sag_cropped, 0, 1)
95
+ sag_cropped = np.swapaxes(sag_cropped, 0, 2)
96
+ sag_vols_cropped.append(sag_cropped)
97
+
98
+ sag_vols_cropped = norm(sag_vols_cropped)
99
+
100
+ if st.session_state.want_gif:
101
+ raw_us = skimage.transform.rescale(sag_vols_cropped, (1,4,1,1), order=2)
102
+
103
+ with st.spinner("Contrast correction..."):
104
+ debanded = apply_debanding_model(sag_vols_cropped, frames=time_steps)
105
+ debanded = norm(debanded)
106
+ debanded_us = debanded[:,0,...,0]
107
+ if st.session_state.want_gif:
108
+ debanded_us = skimage.transform.rescale(debanded_us, (1,4,1,1), order=2)
109
+
110
+ with st.spinner("Respiratory correction..."):
111
+ def_fields, resp_cor = apply_resp_model_28(debanded, frames=time_steps)
112
+ resp_cor = norm(resp_cor)
113
+ resp_cor_us = resp_cor[:,0,...,0]
114
+ if st.session_state.want_gif:
115
+ resp_cor_us = skimage.transform.rescale(resp_cor_us, (1,4,1,1), order=2)
116
+
117
+ with st.spinner("Super-resolution..."):
118
+ super_resed_E2E = apply_SR_model(resp_cor, frames=time_steps)
119
+ super_resed_E2E = norm(super_resed_E2E)
120
+ super_resed_E2E = super_resed_E2E[:,0,...,0]
121
+
122
+ os.makedirs('./out_dir/', exist_ok=True)
123
+ for i in range(time_steps):
124
+ np.save(f'./out_dir/3D_cine_{i}.npy', super_resed_E2E[i])
125
+ if st.session_state.want_gif:
126
+ np.save(f'./out_dir/resp_cor_{i}.npy', resp_cor_us[i])
127
+ np.save(f'./out_dir/debanded_{i}.npy', debanded_us[i])
128
+ np.save(f'./out_dir/raw_{i}.npy', raw_us[i])
129
+
130
+ st.success("✅ All models complete and data saved!")
131
+ st.session_state.data_processed = True
132
+ st.session_state.gif_ready = False # Reset gif status
133
+
134
+ if not st.session_state.want_gif:
135
+ st.session_state.dicom_create = True
136
+
137
+ # --- GIF Generation Section ---
138
+ if st.session_state.want_gif:
139
+ num_phases = st.session_state.num_phases
140
+ if st.session_state.data_processed:
141
+ st.header("GIF Generator")
142
+
143
+ axis_option = st.radio(
144
+ "Select axis for slicing",
145
+ options=["Axial", "Coronal"],
146
+ index=0,
147
+ key="axis_selector"
148
+ )
149
+ axis_mapping = {"Axial": 1, "Coronal": 2}
150
+ axis = axis_mapping[axis_option]
151
+
152
+ slice_index = st.number_input("Select slice number", 0, 256, 60, 1)
153
+ framerate = st.number_input("Framerate", 1, 100, num_phases, 1)
154
+
155
+ if st.button("Generate and Show GIF"):
156
+ gif_path = make_gif('./out_dir/', timepoints=num_phases, axis=axis, slice=slice_index, frame_rate=framerate)
157
+ st.image(gif_path, caption="Generated GIF", use_container_width=True)
158
+ st.session_state.gif_ready = True
159
+
160
+ # --- Next Steps Section ---
161
+ if st.session_state.gif_ready:
162
+ next_action = st.radio(
163
+ "What would you like to do next?",
164
+ options=["Generate another GIF", "Proceed to DICOM export"],
165
+ index=0
166
+ )
167
+
168
+ if next_action == "Generate another GIF":
169
+ st.info("Adjust your settings above and click the button again.")
170
+
171
+ elif next_action == "Proceed to DICOM export":
172
+ st.session_state.dicom_create = True
173
+
174
+ # --- DICOM Export Section ---
175
+ if st.session_state.dicom_create:
176
+ num_phases = st.session_state.num_phases
177
+ st.header("DICOM Export")
178
+ to_dicom(num_phases, patient_number=0)
179
+ st.success("✅ Created DICOMs.")
180
+
181
+ src_dir = "./out_dicoms"
182
+
183
+ # build zip once per session so we don't recompress on every rerun
184
+ if "dicom_zip" not in st.session_state:
185
+ if os.path.isdir(src_dir) and any(os.scandir(src_dir)):
186
+ st.session_state.dicom_zip = zip_dir_to_memory(src_dir)
187
+ st.session_state.dicom_zip_name = f"dicoms_{time.strftime('%Y%m%d-%H%M%S')}.zip"
188
+ else:
189
+ st.warning("No DICOMs found to package.")
190
+
191
+ if "dicom_zip" in st.session_state:
192
+ st.download_button(
193
+ label="⬇️ Download DICOMs (ZIP)",
194
+ data=st.session_state.dicom_zip,
195
+ file_name=st.session_state.dicom_zip_name,
196
+ mime="application/zip",
197
+ use_container_width=True
198
+ )
devcontainer.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at:
2
+ // https://github.com/microsoft/vscode-dev-containers/tree/v0.245.0/containers/python-3
3
+ {
4
+ "name": "Python 3",
5
+ "build": {
6
+ "dockerfile": "Dockerfile",
7
+ "context": ".."
8
+ },
9
+
10
+ // Enable GPUs
11
+ "runArgs": [
12
+ "--gpus=all"
13
+ ],
14
+ // Enable plotting.
15
+ "mounts": [
16
+ "type=bind,source=/tmp/.X11-unix,target=/tmp/.X11-unix"
17
+ ],
18
+ // Enable plotting.
19
+ "containerEnv": {
20
+ "DISPLAY": "${localEnv:DISPLAY}"
21
+ },
22
+
23
+ // Configure tool-specific properties.
24
+ "customizations": {
25
+ // Configure properties specific to VS Code.
26
+ "vscode": {
27
+ // Set *default* container specific settings.json values on container create.
28
+ "settings": {
29
+ "python.defaultInterpreterPath": "/usr/local/bin/python",
30
+ "python.linting.enabled": true,
31
+ "python.linting.pylintEnabled": true,
32
+ "python.formatting.autopep8Path": "/usr/local/py-utils/bin/autopep8",
33
+ "python.formatting.blackPath": "/usr/local/py-utils/bin/black",
34
+ "python.formatting.yapfPath": "/usr/local/py-utils/bin/yapf",
35
+ "python.linting.banditPath": "/usr/local/py-utils/bin/bandit",
36
+ "python.linting.flake8Path": "/usr/local/py-utils/bin/flake8",
37
+ "python.linting.mypyPath": "/usr/local/py-utils/bin/mypy",
38
+ "python.linting.pycodestylePath": "/usr/local/py-utils/bin/pycodestyle",
39
+ "python.linting.pydocstylePath": "/usr/local/py-utils/bin/pydocstyle",
40
+ "python.linting.pylintPath": "/usr/local/py-utils/bin/pylint"
41
+ },
42
+
43
+ // Add the IDs of extensions you want installed when the container is created.
44
+ "extensions": [
45
+ "ms-python.python",
46
+ "ms-python.vscode-pylance"
47
+ ]
48
+ }
49
+ },
50
+
51
+ // Use 'forwardPorts' to make a list of ports inside the container available locally.
52
+ // "forwardPorts": [],
53
+
54
+ // Use 'postCreateCommand' to run commands after the container is created.
55
+ // "postCreateCommand": "pip3 install --user -r requirements.txt",
56
+
57
+ // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.
58
+ "remoteUser": "vscode"
59
+ }
docker-compose.yml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "3.9"
2
+
3
+ services:
4
+ run_code:
5
+ build:
6
+ context: .
7
+ dockerfile: Dockerfile
8
+ ports:
9
+ - "8501:8501"
10
+ volumes:
11
+ - .:/app
12
+ working_dir: /app
13
+ command: streamlit run app.py --server.address=0.0.0.0
14
+ deploy:
15
+ resources:
16
+ reservations:
17
+ devices:
18
+ - capabilities: [gpu]
requirements.txt CHANGED
@@ -1,3 +1,13 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ tensorflow==2.9.1
3
+ tensorflow-mri
4
+ tqdm
5
+ h5py
6
+ tensorflow-addons
7
+ scikit-learn
8
+ scikit-image
9
+ matplotlib
10
+ scipy
11
+ pydicom
12
+ huggingface_hub>=0.21
13
+ numpy<2
utils/__pycache__/custom_unet_code.cpython-38.pyc ADDED
Binary file (3.15 kB). View file
 
utils/__pycache__/layer_util.cpython-38.pyc ADDED
Binary file (2.35 kB). View file
 
utils/__pycache__/process_utils.cpython-38.pyc ADDED
Binary file (18.9 kB). View file
 
utils/__pycache__/unet3plusnew.cpython-38.pyc ADDED
Binary file (5.59 kB). View file
 
utils/custom_unet_code.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import layers, models
3
+
4
def time_distributed_conv_block(input_tensor, num_filters):
    """Apply two successive 3x3x3 Conv3D + ReLU stages to `input_tensor`."""
    out = input_tensor
    for _ in range(2):
        out = layers.Conv3D(num_filters, (3, 3, 3), padding="same")(out)
        out = layers.ReLU()(out)
    return out
11
+
12
def time_distributed_encoder_block_resp(input_tensor, num_filters, temporal_maxpool=True):
    """Encoder stage for the respiratory-correction U-Net.

    Args:
        input_tensor: 5-D tensor (batch, slices, H, W, channels) — assumed
            layout; TODO confirm against build_3d_unet_resp input shape.
        num_filters: number of Conv3D filters in the conv block.
        temporal_maxpool: if True, additionally halve axis 1 with a 1-D max pool.

    Returns:
        (x, pooled): pre-pool features for the skip connection, and the
        spatially pooled tensor.
    """
    x = time_distributed_conv_block(input_tensor, num_filters)

    # Pool H and W by a factor of 4; axis 1 is untouched by this pooling.
    p = layers.MaxPooling3D((1, 4, 4))(x)

    if temporal_maxpool:

        # Move axis 1 to position 3 so MaxPooling1D can act on it,
        # then restore the original axis order afterwards.
        p = tf.transpose(p, (0,2,3,1,4))
        p2 = layers.TimeDistributed(layers.TimeDistributed(layers.MaxPooling1D((2))))(p)
        p2 = tf.transpose(p2, (0,3,1,2,4))
        return x, p2
    else:
        return x, p
25
+
26
def time_distributed_decoder_block_resp(input_tensor, skip_tensor, num_filters):
    """Decoder stage for the respiratory-correction U-Net.

    Upsamples the spatial dims by 4 (bilinear, per-slice), refines with two
    Conv3D layers, concatenates the encoder skip tensor on the channel axis,
    and applies a final conv block.
    """

    # Per-slice 4x bilinear upsampling to mirror the encoder's (1, 4, 4) pool.
    x = layers.TimeDistributed(layers.UpSampling2D(( 4, 4), interpolation='bilinear'))(input_tensor)
    x = layers.Conv3D(num_filters, (3, 3, 3), padding="same")(x)
    x = layers.Conv3D(num_filters, (3, 3, 3), padding="same")(x)

    # Channel-axis concatenation with the matching encoder skip connection.
    x = layers.Concatenate()([x, skip_tensor])
    x = time_distributed_conv_block(x, num_filters)
    return x
35
+
36
def build_3d_unet_resp(input_shape, num_classes):
    """Build the 2-level respiratory-correction U-Net.

    Encoder filters 32/64, 128-filter bridge, mirrored decoder, and a final
    1x1x1 Conv3D producing `num_classes` output channels.
    """
    inputs = layers.Input(shape=input_shape)

    # Encoding path (no temporal pooling at either level).
    skip1, pooled1 = time_distributed_encoder_block_resp(inputs, 32, temporal_maxpool=False)
    skip2, pooled2 = time_distributed_encoder_block_resp(pooled1, 64, temporal_maxpool=False)

    # Bridge.
    bridge = time_distributed_conv_block(pooled2, 128)

    # Decoding path mirrors the encoder.
    up1 = time_distributed_decoder_block_resp(bridge, skip2, 64)
    up2 = time_distributed_decoder_block_resp(up1, skip1, 32)

    outputs = layers.Conv3D(num_classes, (1, 1, 1))(up2)

    return models.Model(inputs, outputs, name="3D-U-Net-resp")
53
+
54
def time_distributed_encoder_block(input_tensor, num_filters, temporal_maxpool=True):
    """Encoder stage for the super-resolution U-Net.

    Args:
        input_tensor: 5-D tensor (batch, slices, H, W, channels) — assumed
            layout; TODO confirm against build_3d_unet input shape.
        num_filters: number of Conv3D filters in the conv block.
        temporal_maxpool: if True, additionally halve axis 1 with a 1-D max pool.

    Returns:
        (x, pooled): pre-pool features for the skip connection, and the
        pooled tensor.
    """
    x = time_distributed_conv_block(input_tensor, num_filters)

    # Per-slice 2x2 spatial max pooling.
    p = layers.TimeDistributed(layers.MaxPooling2D((2, 2)))(x)
    if temporal_maxpool:

        # Move axis 1 to position 3 so MaxPooling1D can act on it,
        # then restore the original axis order afterwards.
        p = tf.transpose(p, (0,2,3,1,4))
        p2 = layers.TimeDistributed(layers.TimeDistributed(layers.MaxPooling1D((2))))(p)
        p2 = tf.transpose(p2, (0,3,1,2,4))
        return x, p2
    else:
        return x, p
66
+
67
def time_distributed_decoder_block(input_tensor, skip_tensor, num_filters, temporal_upsamp=True):
    """Decoder stage for the super-resolution U-Net.

    Upsamples spatially (and optionally along axis 1), brings the skip tensor
    to the same axis-1 length via Conv1DTranspose, concatenates, and applies
    a conv block. The branch on `x.shape[4]` (channel count after the Conv2D)
    selects how many doublings the skip tensor needs at that decoder depth.
    """
    # Per-slice 2x2 spatial upsampling followed by a channel-setting Conv2D.
    x = layers.TimeDistributed(layers.UpSampling2D(( 2, 2)))(input_tensor)
    x = layers.TimeDistributed(layers.Conv2D(num_filters, (3, 3), padding="same"))(x)
    if temporal_upsamp:
        # Transpose so the 1-D upsampling/conv act along original axis 1,
        # then restore the axis order.
        x = tf.transpose(x, (0,2,3,1,4))
        x = layers.TimeDistributed(layers.TimeDistributed(layers.UpSampling1D((2))))(x)
        x = layers.TimeDistributed(layers.TimeDistributed(layers.Conv1D(num_filters, (2),padding="same")))(x)
        x = tf.transpose(x, (0,3,1,2,4))

    # 64 channels => middle decoder level: double the skip's axis-1 length once.
    if x.shape[4] == 64:
        skip_tensor = tf.transpose(skip_tensor, (0,2,3,1,4))
        skip_tensor = layers.TimeDistributed(layers.TimeDistributed(layers.Conv1DTranspose(num_filters,kernel_size=2,strides=2)))(skip_tensor)
        skip_tensor = tf.transpose(skip_tensor, (0,3,1,2,4))

    # 32 channels => last decoder level: double the skip's axis-1 length twice.
    if x.shape[4] == 32:
        skip_tensor = tf.transpose(skip_tensor, (0,2,3,1,4))
        skip_tensor = layers.TimeDistributed(layers.TimeDistributed(layers.Conv1DTranspose(num_filters,kernel_size=2,strides=2)))(skip_tensor)
        skip_tensor = layers.TimeDistributed(layers.TimeDistributed(layers.Conv1DTranspose(num_filters,kernel_size=2,strides=2)))(skip_tensor)
        skip_tensor = tf.transpose(skip_tensor, (0,3,1,2,4))

    x = layers.Concatenate()([x, skip_tensor])
    x = time_distributed_conv_block(x, num_filters)
    return x
90
+
91
def build_3d_unet(input_shape, num_classes):
    """Build the 3-level super-resolution U-Net.

    Encoder filters 32/64/128 (temporal pooling only at the deepest level),
    a 256-filter bridge, a mirrored decoder with temporal upsampling, and a
    final 1x1x1 Conv3D producing `num_classes` output channels.
    """
    inputs = layers.Input(shape=input_shape)

    # Encoding path.
    skip1, pooled1 = time_distributed_encoder_block(inputs, 32, temporal_maxpool=False)
    skip2, pooled2 = time_distributed_encoder_block(pooled1, 64, temporal_maxpool=False)
    skip3, pooled3 = time_distributed_encoder_block(pooled2, 128, temporal_maxpool=True)

    # Bridge.
    bridge = time_distributed_conv_block(pooled3, 256)

    # Decoding path mirrors the encoder.
    up1 = time_distributed_decoder_block(bridge, skip3, 128, temporal_upsamp=True)
    up2 = time_distributed_decoder_block(up1, skip2, 64, temporal_upsamp=True)
    up3 = time_distributed_decoder_block(up2, skip1, 32, temporal_upsamp=True)

    # Output layer.
    outputs = layers.Conv3D(num_classes, (1, 1, 1))(up3)

    return models.Model(inputs, outputs, name="3D-U-Net")
utils/dicom_headerfile.dcm ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85bd29cc81a85d8957e34e710c4aff8658768758ead619f98a99372cf1d0319b
3
+ size 187536
utils/layer_util.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 University College London. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Layer utilities."""
16
+
17
+ import tensorflow as tf
18
+
19
+ # from tensorflow_mri.python.layers import convolutional
20
+ # from tensorflow_mri.python.layers import signal_layers
21
+
22
+
23
def get_nd_layer(name, rank):
    """Look up an N-D Keras layer class by family name and spatial rank.

    Args:
      name: A `str`. The layer family name, e.g. ``'Conv'`` or ``'MaxPool'``.
      rank: An `int`. The spatial rank (1, 2 or 3).

    Returns:
      A `tf.keras.layers.Layer` subclass.

    Raises:
      ValueError: If no layer is registered for ``(name, rank)``.
    """
    try:
        layer_cls = _ND_LAYERS[(name, rank)]
    except KeyError as err:
        raise ValueError(
            f"Could not find a layer with name '{name}' and rank {rank}.") from err
    return layer_cls
41
+
42
+
43
# Registry mapping (layer family name, spatial rank) -> Keras layer class,
# consumed by `get_nd_layer`. DWT/IDWT entries remain commented out because
# the tensorflow_mri signal layers are not imported in this trimmed copy.
_ND_LAYERS = {
    ('AveragePooling', 1): tf.keras.layers.AveragePooling1D,
    ('AveragePooling', 2): tf.keras.layers.AveragePooling2D,
    ('AveragePooling', 3): tf.keras.layers.AveragePooling3D,
    ('Conv', 1): tf.keras.layers.Conv1D,
    ('Conv', 2): tf.keras.layers.Conv2D,
    ('Conv', 3): tf.keras.layers.Conv3D,
    ('ConvLSTM', 1): tf.keras.layers.ConvLSTM1D,
    ('ConvLSTM', 2): tf.keras.layers.ConvLSTM2D,
    ('ConvLSTM', 3): tf.keras.layers.ConvLSTM3D,
    ('ConvTranspose', 1): tf.keras.layers.Conv1DTranspose,
    ('ConvTranspose', 2): tf.keras.layers.Conv2DTranspose,
    ('ConvTranspose', 3): tf.keras.layers.Conv3DTranspose,
    ('Cropping', 1): tf.keras.layers.Cropping1D,
    ('Cropping', 2): tf.keras.layers.Cropping2D,
    ('Cropping', 3): tf.keras.layers.Cropping3D,
    # Note: Keras provides no DepthwiseConv3D, hence only ranks 1-2 here.
    ('DepthwiseConv', 1): tf.keras.layers.DepthwiseConv1D,
    ('DepthwiseConv', 2): tf.keras.layers.DepthwiseConv2D,
    # ('DWT', 1): signal_layers.DWT1D,
    # ('DWT', 2): signal_layers.DWT2D,
    # ('DWT', 3): signal_layers.DWT3D,
    ('GlobalAveragePooling', 1): tf.keras.layers.GlobalAveragePooling1D,
    ('GlobalAveragePooling', 2): tf.keras.layers.GlobalAveragePooling2D,
    ('GlobalAveragePooling', 3): tf.keras.layers.GlobalAveragePooling3D,
    ('GlobalMaxPool', 1): tf.keras.layers.GlobalMaxPool1D,
    ('GlobalMaxPool', 2): tf.keras.layers.GlobalMaxPool2D,
    ('GlobalMaxPool', 3): tf.keras.layers.GlobalMaxPool3D,
    # ('IDWT', 1): signal_layers.IDWT1D,
    # ('IDWT', 2): signal_layers.IDWT2D,
    # ('IDWT', 3): signal_layers.IDWT3D,
    ('LocallyConnected', 1): tf.keras.layers.LocallyConnected1D,
    ('LocallyConnected', 2): tf.keras.layers.LocallyConnected2D,
    ('MaxPool', 1): tf.keras.layers.MaxPool1D,
    ('MaxPool', 2): tf.keras.layers.MaxPool2D,
    ('MaxPool', 3): tf.keras.layers.MaxPool3D,
    ('SeparableConv', 1): tf.keras.layers.SeparableConv1D,
    ('SeparableConv', 2): tf.keras.layers.SeparableConv2D,
    ('SpatialDropout', 1): tf.keras.layers.SpatialDropout1D,
    ('SpatialDropout', 2): tf.keras.layers.SpatialDropout2D,
    ('SpatialDropout', 3): tf.keras.layers.SpatialDropout3D,
    ('UpSampling', 1): tf.keras.layers.UpSampling1D,
    ('UpSampling', 2): tf.keras.layers.UpSampling2D,
    ('UpSampling', 3): tf.keras.layers.UpSampling3D,
    ('ZeroPadding', 1): tf.keras.layers.ZeroPadding1D,
    ('ZeroPadding', 2): tf.keras.layers.ZeroPadding2D,
    ('ZeroPadding', 3): tf.keras.layers.ZeroPadding3D
}
utils/process_utils.py ADDED
@@ -0,0 +1,665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import tensorflow_addons as tfa
4
+ import tensorflow_mri as tfmri
5
+ import tqdm
6
+ import os
7
+ import pydicom as dicom
8
+ import glob
9
+ from utils.unet3plusnew import *
10
+ from utils.custom_unet_code import *
11
+ from pydicom.tag import Tag
12
+ import imageio
13
+ from PIL import Image
14
+ import streamlit as st
15
+ import zipfile
16
+ from collections import defaultdict
17
+ import re
18
+ from pydicom.tag import Tag
19
+ from collections import defaultdict, Counter
20
+ import io, shutil, zipfile, time
21
+
22
+ #Resizes image
23
def resize(t1, x, y):
    """Center-crop or zero-pad a 2-D image to (x, y).

    A trailing channel axis is added before cropping/padding, so the result
    has shape (x, y, 1).
    """
    with_channel = tf.expand_dims(t1, -1)
    return tf.image.resize_with_crop_or_pad(with_channel, x, y)
29
+
30
+ #Function that normalises image
31
def norm(t1):
    """Shift the input to a zero minimum, then divide by the ORIGINAL maximum.

    NOTE(review): this divides by max(t1), not (max - min), so the output
    only spans exactly [0, 1] when min(t1) == 0. Kept as-is because the
    trained models downstream were calibrated on this scaling.
    """
    shifted = t1 - np.min(t1)
    return shifted / np.max(t1)
35
+
36
+ #Applies debanding model to any number of slices
37
#Applies debanding model to any number of slices
def apply_debanding_model(input_im,frames =32):
    """Run the contrast-debanding 3D U-Net over each temporal frame.

    Args:
        input_im: indexable sequence of 3-D volumes; input_im[i] is one frame
            (shape assumed (Z, H, W) — TODO confirm against caller in app.py).
        frames: number of frames to process.

    Returns:
        List of per-frame model predictions (each with batch and channel axes).
    """

    # Load the saved checkpoint only to extract its trained weights...
    debanding_model = "./models_final/Deband_model"
    debanding = tf.keras.models.load_model(debanding_model, compile=False)
    weights = debanding.get_weights()

    # ...then transplant them into a freshly built UNet3D with fully dynamic
    # input shape, so arbitrary volume sizes can be fed at inference time.
    inputs = tf.keras.Input(shape = [None,None,None,1])
    unet = tfmri.models.UNet3D([32,64,128], kernel_size=3, out_channels=1,use_global_residual=False)
    DB = unet(inputs)
    de_banding_model = tf.keras.Model(inputs = inputs, outputs = DB)
    de_banding_model.set_weights(weights)

    de_banded = []
    for i in range(frames):
        # Add batch (front) and channel (back) axes before predicting.
        temp = de_banding_model.predict(tf.expand_dims(tf.expand_dims(input_im[i],0),-1),verbose = 0)
        de_banded.append(temp)

    return de_banded
55
+
56
+ #Function that applies deformations to 28 slice data
57
#Function that applies deformations to 28 slice data
def deformation_28(x):
    """Warp each of 28 sagittal slices by a predicted 2-D deformation field.

    Args:
        x: 3-tuple/list of tensors — x[0] is the image stack, x[1] and x[2]
            are the per-slice dy and dx displacement components (each sliced
            as [0, i, :, :] per slice; batch size 1 is assumed — TODO confirm).

    Returns:
        Tensor of warped slices with a leading batch axis of 1.
    """

    sagittal_deformed = []

    # Hard-coded to the 28-slice acquisition this pipeline crops to.
    for i in range(28):

        input_img = tf.expand_dims(x[0][0,i,:,:], -1)
        dy = tf.expand_dims(tf.expand_dims(x[1][0,i,:,:], -1),0)
        dx = tf.expand_dims(tf.expand_dims(x[2][0,i,:,:], -1),0)

        # dense_image_warp expects displacement channels ordered (dy, dx).
        displacement = tf.concat((dy[0,...],dx[0,...]), axis=-1)

        img = tf.image.convert_image_dtype(tf.expand_dims(input_img, 0), tf.dtypes.float32)
        displacement = tf.image.convert_image_dtype(displacement, tf.dtypes.float32)
        dense_img_warp = tfa.image.dense_image_warp(img, displacement)
        im_deformed = tf.squeeze(dense_img_warp, 0)
        sagittal_deformed.append(im_deformed)

    # Stack the per-slice results and restore a leading batch axis.
    sagittal_deformed = tf.image.convert_image_dtype(sagittal_deformed, tf.dtypes.float32)
    sagittal_deformed = tf.expand_dims(sagittal_deformed,axis= 0)

    return sagittal_deformed
79
+
80
+ #Applies respiratory correction model
81
#Applies respiratory correction model
def apply_resp_model_28(input_im,frames = 32):
    """Run the respiratory-correction model over each temporal frame.

    Builds a model that predicts per-slice deformation fields and applies
    them to the input via `deformation_28`, then loads trained weights from
    './models_final/Resp_Correction_model'.

    Args:
        input_im: indexable sequence; input_im[i] is one 5-D frame
            (batch, 28, 256, 128, 1) — assumed from the Input shape below.
        frames: number of frames to process.

    Returns:
        (deformations, resp_corrected): lists of per-frame deformation fields
        and deformed (corrected) volumes.
    """

    inputs = tf.keras.Input(shape = [None,256,128,1])
    unet = build_3d_unet_resp([None,256,128,1],2) # Acts as aa deformation field generator
    deformation_fields = unet(inputs) # Outputs the deformation fields
    # Lambda wraps the (non-layer) warping function so it fits in the graph.
    lambda_deformation = tf.keras.layers.Lambda(deformation_28)
    out_2 = lambda_deformation([inputs[:,:,:,:,0],deformation_fields[:,:,:,:,0],deformation_fields[:,:,:,:,1]]) # Outputs the deformed volume
    outputs = [deformation_fields,out_2]
    complete_model = tf.keras.Model(inputs = inputs, outputs = outputs)
    complete_model.load_weights('./models_final/Resp_Correction_model/variables/variables')

    resp_corrected = []
    deformations = []
    for i in range(frames):

        def_fields, resp_cor = complete_model.predict(input_im[i][:,:,:,:,:],verbose=0)
        resp_corrected.append(resp_cor)
        deformations.append(def_fields)

    return deformations, resp_corrected
101
+
102
+ #Applies super resolution model
103
#Applies super resolution model
def apply_SR_model(input_im,frames = 32):
    """Run the super-resolution U-Net over each temporal frame.

    Loads an end-to-end checkpoint and transplants only its tail weights into
    a freshly built `build_3d_unet` with fully dynamic input shape.

    Args:
        input_im: indexable sequence; input_im[i] is one batched frame.
        frames: number of frames to process.

    Returns:
        List of per-frame super-resolved predictions.
    """
    E2E_model = "./models_final/E2E_SR_model"
    E2E = tf.keras.models.load_model(E2E_model, compile=False)
    weights = E2E.get_weights()
    # NOTE(review): the first 22 weight arrays presumably belong to an
    # upstream sub-network inside the end-to-end checkpoint; only the
    # remainder is the SR U-Net. Verify if the checkpoint ever changes.
    sr_weights = weights[22:]

    inputs = tf.keras.Input(shape = [None,None,None,1])
    SR_model = build_3d_unet(input_shape=(None, None,None,1), num_classes=1)
    SR = SR_model(inputs)
    SR_model_done = tf.keras.Model(inputs = inputs, outputs = SR)
    SR_model_done.set_weights(sr_weights)

    super_resed = []
    for i in range(frames):
        super_resed.append(SR_model_done.predict(input_im[i],verbose=0))

    return super_resed
120
+
121
# Private vendor DICOM tag used to detect the end of a sagittal stack below.
t = Tag(0x0019, 0x10D7)
#Reads in example RT sagittal stack
def load_data_samples(path_to_data):
    """Read a flat folder of DICOMs and split it into sagittal stacks.

    Files are read in sorted filename order; a new stack is closed whenever
    the private tag (0019,10D7) equals 30 — presumably a per-stack slice
    counter on this scanner; TODO confirm against the acquisition protocol.

    Returns:
        np.ndarray transposed to (time, stack, H, W) — assumes every stack
        has the same number of slices.
    """
    sag_volumes = []
    filename=f"{path_to_data}/*"
    if not os.path.exists(path_to_data):
        raise Exception("Error with file path.")
    else:

        clean_ims_1 = []
        locations_1 = []

        clean_ims_final =[]
        locations_final = []
        test = sorted(glob.glob(filename))
        for file in test:
            ds = dicom.dcmread(file)

            locations_1.append(ds.SliceLocation)
            clean_ims_1.append(ds.pixel_array)
            # Close the current stack when the private counter hits 30.
            if ds[t].value ==30:
                clean_ims_final.append(np.array(clean_ims_1))
                locations_final.append(locations_1)
                clean_ims_1 = []
                locations_1 = []


        #clean_ims_1 = [x for _,x in sorted(zip(locations_1,clean_ims_1))]
        #sag_volumes.append(clean_ims_1)
        final = np.array(clean_ims_final)
        final = np.transpose(final, (1,0,2,3))

        return final
154
+
155
+
156
def load_data_samples_from_folder(base_dir, number_of_scans=32):
    """
    Recursively find all DICOM files under the first valid subfolder of base_dir,
    group them by InstanceNumber (time), sort by SliceLocation (z), and return
    a NumPy array of shape (time, z, H, W).

    Args:
        base_dir: directory containing one extracted data folder.
        number_of_scans: maximum number of temporal frames to keep.

    Returns:
        np.ndarray of shape (T, Z, H, W); empty array on failure (errors are
        reported to the Streamlit UI rather than raised).
    """
    # 1. Find the real nested folder (skip macOS junk such as __MACOSX / ._*)
    candidates = [
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
        and not d.startswith("._")
        and "__MACOSX" not in d
    ]
    if not candidates:
        st.error("No valid data folder found in ZIP.")
        return np.array([])
    nested_base = os.path.join(base_dir, candidates[0])

    # 2. Recursively collect every regular file; we'll filter for DICOMs next
    all_paths = glob.glob(os.path.join(nested_base, "**", "*"), recursive=True)
    all_paths = [p for p in all_paths if os.path.isfile(p) and not os.path.basename(p).startswith("._")]

    # 3. Filter valid DICOMs (headers only, for speed).
    #    FIX: narrowed the original bare `except:` to `except Exception` so
    #    KeyboardInterrupt/SystemExit are no longer swallowed.
    dicom_files = []
    for p in all_paths:
        try:
            ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
            if hasattr(ds, "InstanceNumber"):
                dicom_files.append(p)
        except Exception:
            continue

    st.write(f"🧾 Found {len(dicom_files)} DICOM files.")

    if not dicom_files:
        st.error("No valid DICOMs found.")
        return np.array([])

    # 4. Group by InstanceNumber (temporal frames); full read for pixel data.
    grouped = defaultdict(list)
    for p in dicom_files:
        try:
            ds = dicom.dcmread(p, force=True)
            inst = ds.InstanceNumber
            # Missing SliceLocation sorts as 0.0 rather than crashing.
            loc = getattr(ds, "SliceLocation", 0.0)
            grouped[inst].append((loc, ds.pixel_array))
        except Exception:
            continue

    # 5. Build volume up to number_of_scans frames, slices sorted along z.
    vols = []
    for inst in sorted(grouped.keys())[:number_of_scans]:
        slices = grouped[inst]
        slices.sort(key=lambda x: x[0])
        vols.append([img for _, img in slices])

    volume = np.array(vols)  # shape (T, Z, H, W)
    st.write(f"✅ Found data shape: {volume.shape}")
    return volume
216
+
217
def load_cine_any(
    base_dir: str,
    number_of_scans: int = None,  # if None, use all detected phases
    private_phase_tag: Tag = Tag(0x0019, 0x10D7),  # your private phase tag (if present)
    verbose: bool = True
):
    """
    Universal DICOM cine loader (flat or nested folders).

    Scans recursively from `base_dir`, detects cardiac phases, sorts slices,
    and returns (T, Z, H, W) along with the total number of phases detected.

    Phase detection order: the private tag, then TemporalPositionIdentifier
    (0020,0100), then InstanceNumber — the first key whose values actually
    vary across a probe sample of files is used.

    Returns:
        volume: np.ndarray with shape (T, Z, H, W)
        num_phases_detected: int (total phases found in the dataset)
    """
    # Logger that prefers Streamlit but degrades to print outside the app.
    def log(msg):
        if verbose:
            try: st.write(msg)
            except Exception: print(msg)

    if not os.path.isdir(base_dir):
        raise FileNotFoundError(f"No such directory: {base_dir}")

    # --- Collect candidate files (recursive), skip junk/zips
    candidates = glob.glob(os.path.join(base_dir, "**", "*"), recursive=True)
    candidates = [
        p for p in candidates
        if os.path.isfile(p)
        and "__MACOSX" not in p
        and not os.path.basename(p).startswith("._")
        and not p.lower().endswith(".zip")
    ]
    if not candidates:
        log("No files found under the provided directory.")
        return np.array([]), 0

    # --- Keep only files that parse as DICOM headers
    dicom_files = []
    for p in candidates:
        try:
            _ = dicom.dcmread(p, force=True, stop_before_pixels=True)
            dicom_files.append(p)
        except Exception:
            pass

    log(f"🧾 Candidate DICOM files: {len(dicom_files)}")
    if not dicom_files:
        log("No valid DICOM files found.")
        return np.array([]), 0

    # --- NEW: detect flat folder layout (all files in the same directory)
    dicom_dirs = {os.path.dirname(p) for p in dicom_files}
    is_flat = (len(dicom_dirs) == 1)

    # --- Probe to choose the best phase key
    def _try_get(ds, tag):
        try: return ds[tag].value
        except Exception: return None

    # Sample up to 200 headers and record which candidate keys vary.
    uniq_priv, uniq_tpi, uniq_inst = set(), set(), set()
    for p in dicom_files[:min(len(dicom_files), 200)]:
        try:
            ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
            v_priv = _try_get(ds, private_phase_tag)
            if v_priv is not None:
                try: uniq_priv.add(int(v_priv))
                except Exception: pass
            if hasattr(ds, "TemporalPositionIdentifier"):
                try: uniq_tpi.add(int(ds.TemporalPositionIdentifier))
                except Exception: pass
            if hasattr(ds, "InstanceNumber"):
                try: uniq_inst.add(int(ds.InstanceNumber))
                except Exception: pass
        except Exception:
            continue

    # A key is only usable if it takes more than one value in the sample.
    if len(uniq_priv) > 1:
        phase_key = ("private", private_phase_tag)
    elif len(uniq_tpi) > 1:
        phase_key = ("tpi", None)
    elif len(uniq_inst) > 1:
        phase_key = ("instance", None)
    else:
        log("Could not determine a phase key (no variation in private/TPI/InstanceNumber).")
        return np.array([]), 0

    # Read the chosen phase key from a dataset; None when unreadable.
    def _get_phase(ds):
        if phase_key[0] == "private":
            v = _try_get(ds, phase_key[1]); return int(v) if v is not None else None
        if phase_key[0] == "tpi":
            return int(getattr(ds, "TemporalPositionIdentifier", None)) \
                if hasattr(ds, "TemporalPositionIdentifier") else None
        if phase_key[0] == "instance":
            return int(getattr(ds, "InstanceNumber", None)) \
                if hasattr(ds, "InstanceNumber") else None
        return None

    # z ordering key: SliceLocation, else ImagePositionPatient[2], else 0.0.
    def _get_z(ds):
        z = getattr(ds, "SliceLocation", None)
        if z is None:
            ipp = getattr(ds, "ImagePositionPatient", None)
            if ipp is not None and len(ipp) >= 3:
                try: z = float(ipp[2])
                except Exception: z = 0.0
            else:
                z = 0.0
        return float(z)

    # --- Group by phase; sort by z
    grouped = defaultdict(list)
    for p in dicom_files:
        try:
            ds = dicom.dcmread(p, force=True)
            ph = _get_phase(ds)
            if ph is None:
                continue
            grouped[int(ph)].append((_get_z(ds), ds.pixel_array))
        except Exception:
            continue

    if not grouped:
        log("No groups formed (no phase could be read).")
        return np.array([]), 0

    all_phase_ids = sorted(grouped.keys())
    num_phases_detected = len(all_phase_ids)
    phases_to_use = all_phase_ids if number_of_scans is None else all_phase_ids[:number_of_scans]

    stacks_T, slice_counts = [], []
    for ph in phases_to_use:
        pairs = grouped[ph]
        if not pairs:
            continue
        pairs.sort(key=lambda x: x[0])  # sort by z
        stack = [img for _, img in pairs]  # Z × H × W
        stacks_T.append(stack)
        slice_counts.append(len(stack))

    if not stacks_T:
        log("Groups existed but none had readable slices.")
        return np.array([]), num_phases_detected

    # Harmonize Z across phases (trim to the most common Z)
    if len(set(slice_counts)) > 1:
        common_Z = Counter(slice_counts).most_common(1)[0][0]
        stacks_T = [s[:common_Z] for s in stacks_T if len(s) >= common_Z]
        if not stacks_T:
            log("All phases had inconsistent slice counts.")
            return np.array([]), num_phases_detected

    volume = np.array(stacks_T)  # (T, Z, H, W)

    # --- NEW: flip slice order (Z) if data came from a flat single folder
    # NOTE(review): presumably flat exports come out in reversed z order —
    # confirm against a known dataset.
    if is_flat:
        volume = volume[:, ::-1, :, :]

    # NOTE(review): this logs the per-phase (Z, H, W) shape, not the full
    # (T, Z, H, W) volume shape that is actually returned.
    log(f"✅ Final volume shape: {volume[0,...].shape} , Phases detected = {num_phases_detected}")
    return volume, num_phases_detected
376
+
377
def load_data_samples_from_flat_folder(
    base_dir: str,
    number_of_scans: int = 32,
    frame_tag: Tag = Tag(0x0019, 0x10D7)  # private phase tag (adjust if needed)
) -> np.ndarray:
    """
    Robust loader when all DICOMs are under one folder (possibly nested).
    - Steps into the single subfolder if present (ignores upload.zip, macOS junk).
    - Recursively finds DICOMs (even without .dcm extension).
    - Groups by phase from `frame_tag` or fallback (0020,0100).
    - Sorts by SliceLocation/IPPs and returns an array of shape (T, Z, H, W),
      i.e. phases first, then slices (matches the final st.write below).
    """
    if not os.path.isdir(base_dir):
        raise FileNotFoundError(f"No such directory: {base_dir}")

    # --- Step 1: if there’s exactly one subfolder (plus upload.zip), dive into it
    entries = [e for e in os.listdir(base_dir) if not e.startswith("._")]
    subdirs = [os.path.join(base_dir, e) for e in entries
               if os.path.isdir(os.path.join(base_dir, e)) and "__MACOSX" not in e]
    # If precisely one subdir, prefer that as root; otherwise use base_dir as-is
    root = subdirs[0] if len(subdirs) == 1 else base_dir

    # --- Step 2: recursively collect candidate files (skip zips and junk)
    candidates = glob.glob(os.path.join(root, "**", "*"), recursive=True)
    candidates = [
        p for p in candidates
        if os.path.isfile(p)
        and not os.path.basename(p).startswith("._")
        and "__MACOSX" not in p
        and not p.lower().endswith(".zip")
    ]
    if not candidates:
        st.error("No files found under the provided directory.")
        return np.array([])

    # --- Step 3: keep only files that parse as DICOM headers
    dicom_files = []
    for p in candidates:
        try:
            ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
            dicom_files.append(p)
        except Exception:
            pass

    st.write(f"🧾 Candidate DICOM files: {len(dicom_files)}")
    if not dicom_files:
        st.error("No valid DICOM files found.")
        return np.array([])

    # --- Helper: determine phase index (first readable source wins)
    def _get_phase(ds):
        # Preferred: private tag (your dataset)
        if frame_tag in ds:
            try:
                return int(ds[frame_tag].value)
            except Exception:
                pass
        # Fallback: standard TemporalPositionIdentifier (0020,0100)
        if hasattr(ds, "TemporalPositionIdentifier"):
            try:
                return int(ds.TemporalPositionIdentifier)
            except Exception:
                pass
        # Last resort: AcquisitionNumber (not always phase, but useful fallback)
        if hasattr(ds, "AcquisitionNumber"):
            try:
                return int(ds.AcquisitionNumber)
            except Exception:
                pass
        return None

    # --- Step 4: group by phase; sort by z
    grouped = defaultdict(list)
    phase_missing = 0
    for p in dicom_files:
        try:
            ds = dicom.dcmread(p, force=True)
            phase = _get_phase(ds)
            if phase is None:
                phase_missing += 1
                continue
            # z-order: SliceLocation if present else IPP[2] else 0
            z = getattr(ds, "SliceLocation", None)
            if z is None:
                ipp = getattr(ds, "ImagePositionPatient", None)
                if ipp is not None and len(ipp) >= 3:
                    z = float(ipp[2])
                else:
                    z = 0.0
            img = ds.pixel_array
            grouped[int(phase)].append((z, img))
        except Exception:
            continue

    if not grouped:
        st.error(
            "Could not determine a phase tag for any files. "
            "Check for (0019,10D7) or (0020,0100) in your dataset."
        )
        st.write(f"Files missing phase: {phase_missing} / {len(dicom_files)}")
        # Optional: show attributes of one file to discover tags
        try:
            ds0 = dicom.dcmread(dicom_files[0], force=True, stop_before_pixels=True)
            st.write("Sample DICOM attributes:", ds0.dir())
        except Exception:
            pass
        return np.array([])

    # keep up to number_of_scans phases
    phases = sorted(grouped.keys())[:number_of_scans]

    stacks_T = []
    slice_counts = []
    for ph in phases:
        pairs = grouped[ph]
        if not pairs:
            continue
        pairs.sort(key=lambda x: x[0])  # sort by z
        stack = [img for _, img in pairs]  # Z × H × W
        stacks_T.append(stack)
        slice_counts.append(len(stack))

    if not stacks_T:
        st.error("No phases contained readable slices.")
        return np.array([])

    # Harmonize Z across phases (trim to the most common slice count)
    if len(set(slice_counts)) > 1:
        common_Z = Counter(slice_counts).most_common(1)[0][0]
        stacks_T = [s[:common_Z] for s in stacks_T if len(s) >= common_Z]
        if not stacks_T:
            st.error("All phases had inconsistent slice counts.")
            return np.array([])

    vol = np.array(stacks_T)  # (T, Z, H, W)
    st.write(f"✅ Final volume shape: {vol.shape} (T, S, H, W)")
    return vol
514
+
515
def extract_zip(zip_path, extract_to):
    """Unpack ``zip_path`` into ``extract_to``, skipping macOS metadata.

    Entries under ``__MACOSX`` and AppleDouble files (basenames starting
    with ``._``) are excluded from extraction.
    """
    with zipfile.ZipFile(zip_path, 'r') as archive:
        def _wanted(name):
            # Drop the resource-fork junk Finder adds when zipping on macOS.
            if "__MACOSX" in name:
                return False
            return not os.path.basename(name).startswith("._")

        keep = [name for name in archive.namelist() if _wanted(name)]
        archive.extractall(extract_to, members=keep)
523
+
524
def make_gif(path, timepoints, axis=-1, slice=60, frame_rate=30):
    """Build a side-by-side comparison GIF from saved .npy volumes.

    Expects files named ``{raw|debanded|resp_cor|3D_cine}_{t}.npy`` in
    ``path``, one per timepoint per group. One 2-D plane is cut from each
    volume, normalized per group, the four groups are stitched horizontally,
    and the frames are written to ``<path>/temp.gif``.

    Parameters
    ----------
    path : str
        Directory containing the .npy volumes (also receives the GIF).
    timepoints : int
        Expected number of frames per group; mismatch raises ValueError.
    axis : int
        Axis to slice along: 0, 1, or 2 (-1 is accepted as an alias for 2).
        BUGFIX: previously the default -1 matched none of the frame-building
        branches, leaving ``img`` unbound (UnboundLocalError).
    slice : int
        Index of the plane to cut (name kept for caller compatibility even
        though it shadows the builtin).
    frame_rate : int
        GIF playback rate in frames per second.

    Returns
    -------
    str
        Path of the written GIF.
    """
    # Accept -1 as "last axis" so the default actually works downstream.
    if axis == -1:
        axis = 2

    # 1. Locate all .npy files
    all_files = glob.glob(os.path.join(path, "*.npy"))
    print("Found NPY files:", all_files)

    # 2. Group them by prefix
    scan_keys = ['raw', 'debanded', 'resp_cor', '3D_cine']
    groups = {k: [] for k in scan_keys}

    pattern = re.compile(r'(?P<prefix>raw|debanded|resp_cor|3D_cine)_(?P<index>\d+)\.npy')

    for p in all_files:
        fn = os.path.basename(p)
        match = pattern.match(fn)
        if match:
            prefix = match.group("prefix")
            t_idx = int(match.group("index"))
            groups[prefix].append((t_idx, p))

    # 3. Sanity check: do all groups exist & have equal lengths?
    Ts = [len(v) for v in groups.values()]
    print("Group counts:", Ts)
    if not all(T == timepoints for T in Ts):
        raise ValueError(f"Mismatch in timepoints across groups. Expected {timepoints}, got {Ts}")

    for k in groups:
        groups[k].sort(key=lambda x: x[0])  # sort by t_idx

    # Cut the requested plane; BUGFIX: the stats pass previously sliced a
    # *different* plane than the frame pass for axis == 2, so normalization
    # bounds did not match the displayed image.
    def _cut_plane(vol):
        if axis == 2:
            return vol[:, :, slice]
        if axis == 1:
            return vol[:, slice, :]
        return vol[slice, :, :]

    # 4. Determine normalization range per group (global over all timepoints
    #    so brightness is stable across the animation)
    stats = {}
    for k in scan_keys:
        mins, maxs = [], []
        for _, p in groups[k]:
            vol = np.load(p)
            slice_ = _cut_plane(vol)
            mins.append(slice_.min())
            maxs.append(slice_.max())
        stats[k] = (min(mins), max(maxs))

    # 5. Create frames
    frames = []
    for t in range(timepoints):
        imgs_t = []
        for k in scan_keys:
            _, p = groups[k][t]
            vol = np.load(p).astype(np.float32)

            # Orientation fix-ups per axis (kept exactly as before).
            if axis == 2:
                img = vol[::-1, :, slice]
            elif axis == 1:
                img = vol[:, slice, :]
            else:
                img = vol[slice, :, :]
                img = np.transpose(img[:, ::-1])

            mn, mx = stats[k]
            img = np.clip(img, mn, mx)
            # Guard flat slices: (mx - mn) == 0 would divide by zero.
            span = (mx - mn) if mx > mn else 1.0
            img8 = ((img - mn) / span * 255).astype(np.uint8)
            img8 = img8.T[:, ::-1]  # flip + transpose
            imgs_t.append(img8)

        # Stitch side-by-side and upscale 3x with nearest-neighbour
        composite = np.concatenate(imgs_t, axis=1)
        resized = Image.fromarray(composite).resize((composite.shape[1]*3, composite.shape[0]*3), Image.NEAREST)
        frames.append(np.array(resized))

    # 6. Save and return
    out_path = os.path.join(path, "temp.gif")
    imageio.mimsave(out_path, frames, duration=1000/frame_rate, loop=0)

    return out_path
598
+
599
def to_dicom(cardiac_frames, patient_number):
    """Export the generated 3D cine volumes as single-slice DICOM files.

    Reads ``./out_dir/3D_cine_{i}.npy`` for each cardiac frame, clones a
    template header from ``./utils/dicom_headerfile.dcm`` per slice, and
    writes one file per slice into ``./out_dicoms/``.

    Parameters
    ----------
    cardiac_frames : int
        Number of cine frames (one .npy volume each) to export.
    patient_number : int or str
        Suffix appended to StudyInstanceUID and output filenames.

    Returns
    -------
    int
        Always 42 (placeholder success sentinel).
    """
    # Load the template header; glob + loop leaves `ds` bound to the last
    # match. NOTE(review): dicom.read_file is deprecated in pydicom — confirm
    # a move to dicom.dcmread is safe for this pydicom version.
    filename = "./utils/dicom_headerfile.dcm"
    for file in glob.glob(filename):
        ds = dicom.read_file(file)

    for i in range(cardiac_frames):

        volume = np.load(f'./out_dir/3D_cine_{i}.npy')
        print(f"Volume: {i}")
        for j in range(volume.shape[0]):

            PixelData = volume[j, :, :]

            # NOTE(review): the *255 scaling assumes the volume is normalized
            # to [0, 1] — TODO confirm against the upstream pipeline.
            PixelData = (PixelData * 255).astype(np.uint16)

            # Fresh copy of the template per slice so edits don't accumulate.
            Dicoms = ds.copy()

            Dicoms.InstanceNumber = j
            Dicoms.PatientID = 'Mark'
            Dicoms.PatientName = 'Mark'
            Dicoms.StudyDescription = '3D Cine'
            Dicoms.SeriesDescription = 'HR'
            # Same study UID for every frame of this patient.
            Dicoms.StudyInstanceUID = '1.3.12.2.1107.5.2.41.169828.3001002301121546102500000000' + str(patient_number)

            # Hard-coded output geometry: 256x128 sagittal-style slices at
            # 1.5 mm isotropic spacing.
            Dicoms.SliceThickness = 1.5
            Dicoms.Rows = 256
            Dicoms.Columns = 128
            Dicoms.AcquisitionMatrix = [0, 256, 128, 0]
            Dicoms.ImageOrientationPatient = [1.0, 0.0, 0.0, 0.0, 0.0, -1.0]
            # NOTE(review): (j-1) makes slice 0 start one thickness *below*
            # -100.0 — confirm that offset is intentional.
            Dicoms.SliceLocation = -100.0 + ((j - 1) * Dicoms.SliceThickness)
            Dicoms.SamplesPerPixel = 1

            # NOTE(review): BitsStored=12 but pixel values can reach 255
            # after scaling into uint16; LargestImagePixelValue=255 below
            # matches the scaled data, not a 12-bit range — verify viewers
            # window correctly.
            Dicoms.BitsAllocated = 16
            Dicoms.BitsStored = 12
            Dicoms.HighBit = 11
            Dicoms.PixelRepresentation = 0
            Dicoms.AcquisitionNumber = i
            Dicoms.SeriesNumber = i
            Dicoms.PixelSpacing = [1.5, 1.5]

            Dicoms.SmallestImagePixelValue = 0
            Dicoms.LargestImagePixelValue = 255

            Dicoms.PixelData = PixelData.tobytes()

            # One series per cine frame; fresh SOP UID per slice.
            Dicoms.SeriesInstanceUID = '1.3.12.2.1107.5.2.41.169828.300100230112154610250000001' + str(i)
            Dicoms.SOPInstanceUID = dicom.uid.generate_uid()
            Dicoms.AcquisitionTime = str(i)
            Dicoms.SeriesTime = str(i)

            dicom.filewriter.dcmwrite(filename=f'./out_dicoms/MARK_PATIENT_{patient_number}_VOL_{i}_SLICE_{j}.dcm', dataset=Dicoms)

    return 42
654
+
655
+
656
def zip_dir_to_memory(dir_path: str) -> io.BytesIO:
    """Zip the contents of ``dir_path`` into an in-memory buffer.

    Archive names are stored relative to ``dir_path``. The returned
    BytesIO is rewound to position 0 so it can be streamed directly
    (e.g. to a download button).
    """
    stream = io.BytesIO()
    with zipfile.ZipFile(stream, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for folder, _dirs, filenames in os.walk(dir_path):
            for name in filenames:
                absolute = os.path.join(folder, name)
                # Store relative paths so the zip unpacks cleanly anywhere.
                archive.write(absolute, os.path.relpath(absolute, dir_path))
    stream.seek(0)
    return stream
utils/unet3plusnew.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python file containing unet3plus code used to train segmentation model
2
+
3
import numpy as np
import tensorflow as tf

import utils.layer_util
from utils import layer_util  # class body references the bare name `layer_util`
6
+ tf.random.set_seed(0)
7
+
8
class unet3plus:
    """Functional-style builder for a UNet 3+ segmentation network.

    Construct with the model's input tensor and hyper-parameters, then call
    :meth:`outputs` to obtain the output tensor(s) for a ``tf.keras.Model``.
    Supports 2-D or 3-D convolutions via ``rank``, full-scale skip
    connections from either the encoder or decoder (``skip_type``), and
    optional deep supervision heads at every decoder scale.
    """
    def __init__(self,
                 inputs,
                 filters = [32,64,128,256,512],  # NOTE(review): mutable default; only read here, but verify it is never mutated by callers
                 rank = 2,
                 out_channels = 3,
                 kernel_initializer=tf.keras.initializers.HeNormal(seed=0),
                 bias_initializer=tf.keras.initializers.Zeros(),
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 add_dropout = False,
                 padding = 'same',
                 dropout_rate = 0.5,
                 kernel_size = 3,
                 out_kernel_size = 3,
                 pool_size = 2,
                 encoder_block_depth = 2,
                 decoder_block_depth = 1,
                 batch_norm = True,
                 activation = 'relu',
                 out_activation = None,
                 skip_batch_norm = True,
                 skip_type = 'encoder',
                 CGM = False,  # NOTE(review): stored but never used in this class — confirm intended
                 deep_supervision = True):

        self.inputs = inputs
        self.filters = filters
        # Number of resolution scales equals the number of filter entries.
        self.scales = len(filters)
        self.rank = rank
        self.out_channels = out_channels
        self.encoder_block_depth = encoder_block_depth
        self.decoder_block_depth = decoder_block_depth
        self.kernel_size = kernel_size
        self.add_dropout = add_dropout
        self.dropout_rate = dropout_rate
        self.skip_type = skip_type
        self.skip_batch_norm = skip_batch_norm
        self.batch_norm = batch_norm
        # Accept either an activation name or a callable.
        if isinstance(activation, str):
            self.activation = tf.keras.activations.get(activation)
        else:
            self.activation = activation
        if isinstance(out_activation, str):
            self.out_activation = tf.keras.activations.get(out_activation)
        else:
            self.out_activation = out_activation
        # Assign pool size (broadcast a scalar to one entry per spatial dim)
        if isinstance(pool_size,tuple):
            self.pool_size = pool_size
        else:
            self.pool_size = tuple([pool_size for _ in range(rank)])
        if isinstance(kernel_size,tuple):
            self.kernel_size = kernel_size
        else:
            self.kernel_size = tuple([kernel_size for _ in range(rank)])
        if isinstance(out_kernel_size,tuple):
            self.out_kernel_size = out_kernel_size
        else:
            self.out_kernel_size = tuple([out_kernel_size for _ in range(rank)])
        self.CGM = CGM
        self.deep_supervision = deep_supervision
        # Shared keyword arguments for all hidden conv layers...
        self.conv_config = dict(kernel_size = self.kernel_size,
                                padding = padding,
                                kernel_initializer = kernel_initializer,
                                bias_initializer = bias_initializer,
                                kernel_regularizer = kernel_regularizer,
                                bias_regularizer = bias_regularizer)
        # ...and for the output/deep-supervision conv layers.
        self.out_conv_config = dict(kernel_size = out_kernel_size,
                                    padding = padding,
                                    kernel_initializer = kernel_initializer,
                                    bias_initializer = bias_initializer,
                                    kernel_regularizer = kernel_regularizer,
                                    bias_regularizer = bias_regularizer)

    def aggregate(self, scale_list, scale):
        """Concatenate all skip contributions for decoder `scale` and fuse them with a conv block."""
        X = tf.keras.layers.Concatenate(name = f'D{scale}_input', axis = -1)(scale_list)
        # UNet 3+ fuses to filters[0] channels per contribution, hence filters[0] * scales total.
        X = self.conv_block(X, self.filters[0] * self.scales, num_stacks = self.decoder_block_depth, layer_type = 'Decoder', scale=scale)
        return X

    def deep_sup(self, inputs, scale):
        """Deep-supervision head: project to out_channels, upsample to full resolution, apply the output activation."""
        conv = layer_util.get_nd_layer('Conv', self.rank)
        upsamp = layer_util.get_nd_layer('UpSampling', self.rank)
        # Upsample factor grows geometrically with distance from scale 1.
        size = tuple(np.array(self.pool_size)** (abs(scale-1)))
        if self.rank == 2:
            upsamp_config = dict(size=size, interpolation='bilinear')
        else:
            # 3-D UpSampling has no interpolation argument (nearest only).
            upsamp_config = dict(size=size)
        X = inputs
        X = conv(self.out_channels, activation = None, **self.out_conv_config, name = f'deepsup_conv_{scale}')(X)
        if scale != 1:
            X = upsamp(**upsamp_config, name = f'deepsup_upsamp_{scale}')(X)
        #X = tf.keras.layers.Activation(activation = 'sigmoid' if self.out_channels == 1 else 'softmax', name = f'deepsup_activation_{scale}')(X)
        # NOTE(review): raises TypeError if out_activation is None — confirm
        # callers always pass an activation.
        X =self.out_activation(X)
        return X



    def full_scale(self, inputs, to_layer, from_layer):
        """Resample a feature map from `from_layer` to `to_layer` resolution for a full-scale skip connection."""
        conv = layer_util.get_nd_layer('Conv', self.rank)
        layer_diff = from_layer - to_layer
        # Resampling factor is pool_size ** |level difference|.
        size = tuple(np.array(self.pool_size)** (abs(layer_diff)))
        maxpool = layer_util.get_nd_layer('MaxPool', self.rank)
        upsamp = layer_util.get_nd_layer('UpSampling', self.rank)
        if self.rank == 2:
            upsamp_config = dict(size=size, interpolation='bilinear')
        else:
            upsamp_config = dict(size=size)

        X = inputs
        # Upsample when coming from a coarser level, pool when from a finer one.
        if to_layer < from_layer:
            X = upsamp(**upsamp_config, name = f'Skip_Upsample_{from_layer}_{to_layer}')(X)
        elif to_layer > from_layer:
            X = maxpool(pool_size = size, name = f'Skip_Maxpool_{from_layer}_{to_layer}')(X)

        # Reduce every contribution to filters[0] channels, with or without BN.
        if self.skip_batch_norm:
            X = self.conv_block(X, self.filters[0], num_stacks = self.decoder_block_depth, layer_type ='Skip', scale = f'{from_layer}_{to_layer}')
        else:
            X = conv(self.filters[0],**self.conv_config, name = f'Skip_Conv_{from_layer}_{to_layer}')(X)

        return X

    def conv_block(self, inputs, filters, num_stacks,layer_type, scale):
        """Stacked conv -> (optional) batch-norm -> activation, repeated `num_stacks` times."""
        conv = layer_util.get_nd_layer('Conv', self.rank)
        X = inputs
        for i in range(num_stacks):
            X = conv(filters, **self.conv_config, name = f'{layer_type}{scale}_Conv_{i+1}')(X)
            if self.batch_norm:
                X = tf.keras.layers.BatchNormalization(axis=-1, name = f'{layer_type}{scale}_BN_{i+1}')(X)
            #X = tf.keras.layers.LeakyReLU(name = f'{layer_type}{scale}_Activation_{i+1}')(X)
            X = self.activation(X)
        return X


    def encode(self, inputs, scale, num_stacks):
        """One encoder stage: max-pool (except at the first scale) then a conv block; optional dropout at the bottleneck."""
        maxpool = layer_util.get_nd_layer('MaxPool', self.rank)
        scale -= 1 # python index
        filters = self.filters[scale]
        X = inputs
        if scale != 0:
            X = maxpool(pool_size=self.pool_size, name = f'encoding_{scale}_maxpool')(X)
        X = self.conv_block(X, filters, num_stacks, layer_type = 'Encoder', scale = scale+1)
        # Dropout only at the deepest (bottleneck) scale, when requested.
        if scale == (self.scales-1) and self.add_dropout:
            X = tf.keras.layers.Dropout(rate = self.dropout_rate, name = f'Encoder{scale+1}_dropout')(X)
        return X

    def outputs(self):
        """Build the full graph and return the decoder output(s).

        Returns a list of deep-supervision outputs (coarsest first) when
        ``deep_supervision`` is True, otherwise the single full-resolution
        output tensor.
        """
        # Encoder pass: XE[0] is the input, XE[k] the k-th encoder stage.
        XE = [self.inputs]
        for i in range(self.scales):
            XE.append(self.encode(XE[i], scale = i+1, num_stacks = self.encoder_block_depth))
        # Decoder pass starts from the bottleneck.
        XD = [XE[-1]]
        if self.skip_type == 'encoder':
            for decoder_level in range(self.scales-1,0,-1):
                input_contributions = []
                for unet_level in range(1,self.scales+1):
                    # The level just below the decoder comes from the previous
                    # decoder output; everything else from the encoder.
                    if unet_level == decoder_level+1:
                        input_contributions.append(self.full_scale(XD[-1], decoder_level, unet_level))
                    else:
                        input_contributions.append(self.full_scale(XE[unet_level], decoder_level, unet_level))
                XD.append(self.aggregate(input_contributions,decoder_level))
        elif self.skip_type == 'decoder':
            for decoder_level in range(self.scales-1,0,-1):
                skip_contributions = []
                # Append skips from encoder
                for encoder_level in range(1,decoder_level+1):
                    skip_contributions.append(self.full_scale(XE[encoder_level], decoder_level, encoder_level))
                # Append skips from decoder
                for i in range(len(XD)-1,-1,-1):
                    skip_contributions.append(self.full_scale(XD[i], decoder_level, (self.scales-i)))
                XD.append(self.aggregate(skip_contributions,decoder_level))
        else:
            raise ValueError(f"Invalid skip_type")
        if self.deep_supervision == True:
            # One supervised head per decoder scale (coarsest first).
            XD = [self.deep_sup(xd, self.scales-i) for i,xd in enumerate(XD)]
            return XD
        else:
            # Single head on the finest decoder output.
            XD[-1] = self.deep_sup(XD[-1],1)
            return XD[-1]