Spaces:
Runtime error
Runtime error
File size: 15,345 Bytes
fd601de 48355e9 fd601de c4cade8 fd601de 76bf0bf fd601de 880f41d f1e0db9 fd601de 9fbb587 fd601de 659a112 fd601de cdfc683 fd601de 48355e9 fd601de 48355e9 fd601de 48355e9 fd601de 48355e9 fd601de 48355e9 fd601de 48355e9 fd601de 48355e9 fd601de 48355e9 fd601de 48355e9 fd601de 48355e9 fd601de 073df2e 3e96d2e fd601de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 |
# seg2med_app/app.py
# streamlit run tutorial8_app.py
# F:\yang_Environments\torch\venv\Scripts\activate.ps1
# streamlit run tutorial8_app.py --server.address=0.0.0.0 --server.port=8501
# http://129.206.168.125:8501 http://169.254.3.1:8501
#import sys
#sys.path.append('./seg2med_app')
import os
# NOTE(review): presumably a workaround for the OpenMP "duplicate runtime" abort
# (two copies of libiomp loaded by different native libraries). It is set here,
# before numpy/nibabel and the project modules below are imported, so that any
# OpenMP runtime they load sees it. Confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# seg2med_app/main.py
import os
import streamlit as st
import zipfile
import hashlib
import pandas as pd
import numpy as np
import nibabel as nib
from seg2med_app.simulation.get_labels import get_labels
from seg2med_app.app_utils.image_utils import (
show_three_planes,
show_label_overlay,
show_three_planes_interactive,
show_single_planes_interactive,
show_label_overlay_single,
generate_color_map,
load_image_canonical,
global_slice_slider,
image_to_base64,
show_single_slice_image,
show_single_slice_label,
)
from seg2med_app.ui.simulation_and_display import simulation_controls
from seg2med_app.ui.upload_and_prepare import handle_upload, compute_md5
from dataprocesser.simulation_functions import (
_merge_seg_tissue,
_create_body_contour_by_tissue_seg,
_create_body_contour
)
from seg2med_app.simulation.combine_selected_organs import combine_selected_organs
from seg2med_app.ui.inference_controls import inference_controls
from seg2med_app.ui.inference_gradio import call_gradio_gpu_infer
from seg2med_app.frankenstein.frankenstein import frankenstein_control
from seg2med_app.app_utils.titles import *
# ========== CONFIG ==========
# Root folder of the app's assets; scratch files are written to <app_root>/tmp.
app_root = 'seg2med_app'
_tmp_dir = os.path.join(app_root, "tmp")
os.makedirs(_tmp_dir, exist_ok=True)

# ========== UI STRUCTURE ==========
# Page configuration is issued before any other Streamlit call in this script.
_page_config = {
    "page_title": "Frankenstein App",
    "page_icon": "🧠",
    "layout": "wide",
}
st.set_page_config(**_page_config)

# Publish app_root through session state (re-set on every rerun, so it also
# survives reset_app() clearing the state).
st.session_state.update({"app_root": app_root})
import streamlit as st
from PIL import Image
import os
def reset_app():
    """Wipe the Streamlit session state but keep the user logged in, then rerun.

    Every value stored in ``st.session_state`` (uploads, simulation results,
    widget values, ...) is discarded. Only the ``authenticated`` flag is
    restored, so the access-code gate is not shown again after the reset.
    """
    st.session_state.clear()
    # Restore the login flag (a single write; attribute and key access hit the same entry).
    st.session_state["authenticated"] = True
    st.success("App has been reset. Login information is preserved.")
    print("App has been reset. Login information is preserved.")
    # Abort the current run and start a fresh one with the cleared state.
    st.rerun()
# Load the banner image inside a context manager so the underlying file handle
# is closed once it has been processed (Image.open keeps the file open otherwise).
with Image.open(os.path.join(app_root, "Frankenstein0.png")) as image:
    # NOTE(review): the value returned by image_to_base64 is discarded here;
    # confirm whether it was meant to be rendered (e.g. in the header below).
    image_to_base64(image)

st.title("\U0001F9E0 Frankenstein - multimodal medical image generation")
st.markdown("""
**Created by**: Zeyu Yang
PhD Student, Computer-assisted Clinical Medicine
University of Heidelberg
🔗 [GitHub Repository](https://github.com/musetee/frankenstein)
📄 [Preprint on arXiv](https://arxiv.org/abs/2504.09182)
✉️ Contact: [Zeyu.Yang@medma.uni-heidelberg.de](mailto:Zeyu.Yang@medma.uni-heidelberg.de)
""")
# ========== ACCESS GATE ==========
import hmac

# Access code. An environment variable may override the built-in default so the
# real code does not have to be committed to the repository.
PASSWORD = os.environ.get("FRANKENSTEIN_ACCESS_CODE", "frankenstein")

if "authenticated" not in st.session_state:
    # The gate is currently DISABLED: change this to False to require the access code.
    st.session_state.authenticated = True

if not st.session_state.authenticated:
    st.session_state["app_password"] = st.text_input("Enter access code", type="password")
    entered_code = st.session_state["app_password"] or ""
    # compare_digest gives a constant-time comparison; comparing UTF-8 bytes
    # (not str) keeps it from raising TypeError on non-ASCII input.
    if hmac.compare_digest(entered_code.encode("utf-8"), PASSWORD.encode("utf-8")):
        st.session_state.authenticated = True
        st.success("✅ Access granted!")
    else:
        st.warning("🔒 Please enter the correct access code to continue.")
        # Stop rendering the rest of the app until the correct code is entered.
        st.stop()
# ========== SIDEBAR (DATASET LOADER) ==========
st.sidebar.title("\U0001F9EC Dataset Loading")
# The two pages of the app; the chosen label drives the if/elif page dispatch below.
_load_choices = [
    "\U0001F3AE Random sample & manual draw",
    "\U0001F4C1 Upload segmentation",
]
load_method = st.sidebar.radio("Select load method", _load_choices)

# Reset button (main area, not the sidebar): clears session state via reset_app().
_reset_clicked = st.button("🔄 Reset App")
if _reset_clicked:
    reset_app()
Begin = "### 🎨 Begin: Choose a colormap to visualize different tissues"
st.write(Begin)
default_cmap = "PiYG"
cmap_options = [default_cmap, "nipy_spectral", "tab20", "Set3", "Paired", "tab10", "gist_rainbow", "custom"]
selected_cmap = st.selectbox("Label colormap", cmap_options, index=0)
# If "custom" is selected, show a text box so the user can type any colormap name.
if selected_cmap == "custom":
    custom_cmap = st.text_input("please type custom colormap name", value=default_cmap)
    # Fall back to the default when the box is cleared, so an empty name never
    # reaches generate_color_map.
    selected_cmap = custom_cmap.strip() or default_cmap
st.session_state.update({"selected_cmap": selected_cmap})

# ========== select color map for visualization segmentation ==============
# Rebuilt on every rerun so a colormap change is applied immediately; only
# possible once label ids exist in the session.
if "label_ids" in st.session_state:
    st.session_state["label_to_color"] = generate_color_map(
        st.session_state["label_ids"], cmap=st.session_state["selected_cmap"]
    )
    print('organ label to color: ', list(st.session_state["label_to_color"].items())[:5])
# ========== MAIN: UPLOAD SEGMENTATION ==========
if load_method == "\U0001F4C1 Upload segmentation":
    # ========== FIRST ROW ==========
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        uploaded_file = st.file_uploader("Upload segmentation", type=["zip", "nii.gz", "nii"])
    with col2:
        uploaded_tissue = st.file_uploader("Upload tissue segmentation", type=["zip", "nii.gz", "nii"], key="tissue_upload")
    with col3:
        original_file = st.file_uploader("Upload original image", type=["nii.gz", "nii", "dcm"])
    with col4:
        # Body threshold: starts at a default value and can be set manually by the user.
        default_body_threshold = 0
        if "body_threshold" not in st.session_state:
            st.session_state["body_threshold"] = default_body_threshold
        user_input_threshold = st.number_input(
            "Body threshold for contour extraction (used on original image)",
            value=st.session_state["body_threshold"],
            step=1,
        )
        use_custom_threshold = st.checkbox("Use custom body threshold", value=False)
        st.session_state["use_custom_threshold"] = use_custom_threshold
        visual_options = ["Only Axial Plane", "Three Planes"]
        st.session_state["selected_visual"] = st.selectbox("Visualization Type", visual_options, index=0)

    # NOTE(review): a threshold of 0 is falsy, so entering 0 neither updates the
    # stored threshold nor recomputes the contour; confirm this is intended.
    if user_input_threshold:
        st.session_state["body_threshold"] = user_input_threshold
        if "orig_img" in st.session_state:
            st.session_state["contour"] = _create_body_contour(
                st.session_state['orig_img'], st.session_state['body_threshold'], body_mask_value=1
            )

    # ========== HASH MANAGEMENT ==========
    # NOTE(review): these hashes are computed but not read anywhere in this
    # file; kept in case compute_md5 has side effects on the uploaded buffers.
    new_upload_hash = compute_md5(uploaded_file) if uploaded_file else None
    cached_upload_hash = st.session_state.get("uploaded_file_hash", None)
    new_tissue_hash = compute_md5(uploaded_tissue) if uploaded_tissue else None
    cached_tissue_hash = st.session_state.get("uploaded_tissue_hash", None)
    new_origin_hash = compute_md5(original_file) if original_file else None
    cached_origin_hash = st.session_state.get("uploaded_origin_hash", None)
    handle_upload(app_root, uploaded_file, uploaded_tissue, original_file)

    # ========== SIMULATION UI (SHARED) ==========
    simulation_controls(app_root)
    # ========== INFERENCE UI (SHARED) ==========
    inference_controls()

    # ========== visualize ==========
    if "combined_seg" in st.session_state:
        z_idx, y_idx, x_idx = global_slice_slider(st.session_state["volume_shape"])
        st.session_state.update({
            "z_idx": z_idx,
            "y_idx": y_idx,
            "x_idx": x_idx,
        })
        if st.session_state["selected_visual"] == "Three Planes":
            show_three_planes_interactive(st.session_state["contour"], z_idx, y_idx, x_idx)
            show_label_overlay(st.session_state["combined_seg"], z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"])
        else:
            show_single_planes_interactive(st.session_state["contour"], z_idx, y_idx, x_idx)
            show_label_overlay_single(st.session_state["combined_seg"], z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"])

        if "selected_organs" in st.session_state and len(st.session_state["selected_organs"]) > 0:
            multi_seg = combine_selected_organs(uploaded_file)
            if st.session_state["selected_visual"] == "Three Planes":
                show_label_overlay(multi_seg, z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"])
            else:
                show_label_overlay_single(multi_seg, z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"])

        if "orig_img" in st.session_state:
            if st.session_state["selected_visual"] == "Three Planes":
                show_three_planes_interactive(st.session_state["orig_img"], z_idx, y_idx, x_idx)
            else:
                show_single_planes_interactive(st.session_state["orig_img"], z_idx, y_idx, x_idx)

    if st.session_state.get("processed_img") is not None:
        st.markdown("🔍 View Simulation Result")
        if st.session_state["selected_visual"] == "Three Planes":
            show_three_planes_interactive(st.session_state["processed_img"],
                                          st.session_state["z_idx"],
                                          st.session_state["y_idx"],
                                          st.session_state["x_idx"])
        else:
            show_single_planes_interactive(st.session_state["processed_img"],
                                           st.session_state["z_idx"],
                                           st.session_state["y_idx"],
                                           st.session_state["x_idx"])

    if st.session_state.get("output_img") is not None:
        st.session_state["output_volume_to_save"] = np.expand_dims(st.session_state["output_img"].T, axis=-1)
        # Model output is already in the correct orientation, hence orientation_type='none'.
        if st.session_state["selected_visual"] == "Three Planes":
            show_three_planes_interactive(np.expand_dims(st.session_state["output_img"], axis=-1), 0, 0, 0, orientation_type='none')
        else:
            show_single_planes_interactive(np.expand_dims(st.session_state["output_img"], axis=-1), 0, 0, 0, orientation_type='none')

# ========== RANDOM DRAW PAGE PLACEHOLDER ==========
elif load_method == "\U0001F3AE Random sample & manual draw":
    st.markdown("## 🎮 Frankenstein Interactive creating tool")
    frankenstein_control()
    make_step_renderer(step5_frankenstein)
    simulation_controls(app_root)
    make_step_renderer(step7_frankenstein)
    inference_controls()

    if st.button("⚙️ Run inference by Gradio"):
        processed_img = st.session_state.get("processed_img")
        # Guard against clicking before a prior image / slice / modality exists,
        # which previously crashed with KeyError or TypeError.
        if (
            processed_img is None
            or "z_idx" not in st.session_state
            or "modality_idx" not in st.session_state
        ):
            st.warning("Please create the prior image and choose a modality before running inference.")
        else:
            st.info("Running inference...")
            modality = st.session_state["modality_idx"]
            image_slice = processed_img[:, :, st.session_state["z_idx"]]
            result = call_gradio_gpu_infer(modality, image_slice)
            st.image(result, caption="Predicted Image")

    import matplotlib.pyplot as plt
    # Use the same None-check as the other output_img consumers so imshow never gets None.
    if st.session_state.get("output_img") is not None:
        output_img = st.session_state["output_img"]
        fig = plt.figure()
        try:
            plt.imshow(output_img, cmap="gray")
            plt.grid(False)
            # Portable path: the former r'seg2med_app\modeloutput.png' produced a
            # file literally named with a backslash outside app_root on Linux.
            plt.savefig(os.path.join(app_root, "modeloutput.png"))
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)

    col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
    with col1:
        if st.session_state.get("contour") is not None:
            show_single_slice_image(st.session_state["contour"].squeeze(), title="contour")
    with col2:
        if st.session_state.get("combined_seg") is not None:
            show_single_slice_label(st.session_state["combined_seg"].squeeze(),
                                    st.session_state["label_to_color"],
                                    title="combined segs")
    with col3:
        if st.session_state.get("processed_img") is not None:
            show_single_slice_image(st.session_state["processed_img"].squeeze(), title="image prior")
    with col4:
        if st.session_state.get("output_img") is not None:
            st.session_state["output_volume_to_save"] = np.expand_dims(st.session_state["output_img"].T, axis=-1)
            # No orientation change: the model output is already correctly oriented.
            show_single_slice_image(st.session_state["output_img"], title="inference image", orientation_type='none')
    make_step_renderer(step8_frankenstein)
# ========== SAVE ==========
output_folder = os.path.join(app_root, 'output')
os.makedirs(output_folder, exist_ok=True)


def _save_nifti_with_download(data, default_name, widget_key, button_label):
    """Render a filename box, save *data* as NIfTI into ``output_folder`` and offer a download.

    Args:
        data: Array to save, or None when this session has nothing to save yet.
            Nothing is saved or offered in that case, so a file left in the
            shared output folder by an earlier run/session is never served.
        default_name: Filename pre-filled in the text box (and used as fallback).
        widget_key: Unique Streamlit key for the filename text box.
        button_label: Label shown on the download button.
    """
    filename = st.text_input("Filename (.nii.gz)", value=default_name, key=widget_key)
    # The name is user input: keep only its last path component so it cannot
    # write outside output_folder (e.g. "../app.py"), and fall back to the
    # default when it is empty or a bare "." / "..".
    filename = os.path.basename((filename or "").strip())
    if filename in ("", ".", ".."):
        filename = default_name
    if data is None:
        return
    save_path = os.path.join(output_folder, filename)
    img_to_save = nib.Nifti1Image(data, st.session_state["orig_affine"])
    nib.save(img_to_save, save_path)
    with open(save_path, "rb") as f:
        st.download_button(
            label=button_label,
            data=f,
            file_name=filename,
            mime="application/gzip",
        )


col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
with col1:
    _save_nifti_with_download(st.session_state.get("contour"),
                              "contour.nii.gz", "filename_contour",
                              "⬇️ Download Contour")
with col2:
    _save_nifti_with_download(st.session_state.get("combined_seg"),
                              "combined_seg.nii.gz", "filename_combined",
                              "⬇️ Download Combined Segmentation")
with col3:
    _save_nifti_with_download(st.session_state.get("processed_img"),
                              "prior_image.nii.gz", "filename_prior",
                              "⬇️ Download Prior Image")
with col4:
    _save_nifti_with_download(st.session_state.get("output_volume_to_save"),
                              "model_output.nii.gz", "filename_output",
                              "⬇️ Download Output Image")
|