Spaces:
Runtime error
Runtime error
| # seg2med_app/app.py | |
| # streamlit run tutorial8_app.py | |
| # F:\yang_Environments\torch\venv\Scripts\activate.ps1 | |
| # streamlit run tutorial8_app.py --server.address=0.0.0.0 --server.port=8501 | |
| # http://129.206.168.125:8501 http://169.254.3.1:8501 | |
| #import sys | |
| #sys.path.append('./seg2med_app') | |
| import os | |
| os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" | |
| # seg2med_app/main.py | |
| import os | |
| import streamlit as st | |
| import zipfile | |
| import hashlib | |
| import pandas as pd | |
| import numpy as np | |
| import nibabel as nib | |
| from seg2med_app.simulation.get_labels import get_labels | |
| from seg2med_app.app_utils.image_utils import ( | |
| show_three_planes, | |
| show_label_overlay, | |
| show_three_planes_interactive, | |
| show_single_planes_interactive, | |
| show_label_overlay_single, | |
| generate_color_map, | |
| load_image_canonical, | |
| global_slice_slider, | |
| image_to_base64, | |
| show_single_slice_image, | |
| show_single_slice_label, | |
| ) | |
| from seg2med_app.ui.simulation_and_display import simulation_controls | |
| from seg2med_app.ui.upload_and_prepare import handle_upload, compute_md5 | |
| from dataprocesser.simulation_functions import ( | |
| _merge_seg_tissue, | |
| _create_body_contour_by_tissue_seg, | |
| _create_body_contour | |
| ) | |
| from seg2med_app.simulation.combine_selected_organs import combine_selected_organs | |
| from seg2med_app.ui.inference_controls import inference_controls | |
| from seg2med_app.ui.inference_gradio import call_gradio_gpu_infer | |
| from seg2med_app.frankenstein.frankenstein import frankenstein_control | |
| from seg2med_app.app_utils.titles import * | |
# ========== CONFIG ==========
import os

import streamlit as st
from PIL import Image

app_root = 'seg2med_app'
# Scratch space under the app root; created up front so later steps can write into it.
os.makedirs(os.path.join(app_root, "tmp"), exist_ok=True)

# ========== UI STRUCTURE ==========
# Kept ahead of every other st.* call: older Streamlit releases require
# set_page_config to be the first Streamlit command of the script.
st.set_page_config(
    layout="wide",
    page_title="Frankenstein App",
    page_icon="🧠",
)
# Expose the app root through session state (presumably read by the ui helper
# modules, which receive no module globals from this script — confirm).
st.session_state["app_root"] = app_root
def reset_app():
    """Drop all per-session state, keep the user logged in, and restart the script.

    Every key in ``st.session_state`` is removed; only the ``authenticated``
    flag is put back so the access-code gate does not reappear. The function
    does not return: ``st.rerun()`` ends the current script run.
    """
    st.session_state.clear()
    # Attribute-style and item-style access address the same session key,
    # so a single assignment restores the login flag.
    st.session_state["authenticated"] = True
    notice = "App has been reset. Login information is preserved."
    st.success(notice)
    print(notice)
    # NOTE(review): st.rerun() restarts the script right away, so the success
    # banner above is likely gone before the user can see it — confirm.
    st.rerun()
# ========== HEADER ==========
# Pillow opens files lazily and keeps the handle until the image is loaded or
# closed; the context manager releases it deterministically.
with Image.open(os.path.join(app_root, "Frankenstein0.png")) as image:
    # NOTE(review): the return value is discarded, so this call only matters if
    # image_to_base64 has a side effect (e.g. rendering) — confirm or remove.
    image_to_base64(image)

st.title("\U0001F9E0 Frankenstein - multimodal medical image generation")
st.markdown("""
**Created by**: Zeyu Yang
PhD Student, Computer-assisted Clinical Medicine
University of Heidelberg
🔗 [GitHub Repository](https://github.com/musetee/frankenstein)
📄 [Preprint on arXiv](https://arxiv.org/abs/2504.09182)
✉️ Contact: [Zeyu.Yang@medma.uni-heidelberg.de](mailto:Zeyu.Yang@medma.uni-heidelberg.de)
""")
# ========== ACCESS GATE ==========
# The access code can be supplied via the FRANKENSTEIN_PASSWORD environment
# variable (e.g. a deployment secret) instead of living in source; the
# previous literal stays as the fallback so existing setups keep working.
PASSWORD = os.environ.get("FRANKENSTEIN_PASSWORD", "frankenstein")

if "authenticated" not in st.session_state:
    # The gate is disabled by default. Set this to False to require the
    # access code before the rest of the app is rendered.
    st.session_state.authenticated = True

if not st.session_state.authenticated:
    import hmac

    st.session_state["app_password"] = st.text_input("Enter access code", type="password")
    # Constant-time comparison. Compare UTF-8 bytes: compare_digest rejects
    # non-ASCII str arguments, and user input may contain any character.
    entered = st.session_state["app_password"].encode("utf-8")
    if hmac.compare_digest(entered, PASSWORD.encode("utf-8")):
        st.session_state.authenticated = True
        st.success("✅ Access granted!")
    else:
        st.warning("🔒 Please enter the correct access code to continue.")
        st.stop()
# ========== SIDEBAR (DATASET LOADER) ==========
st.sidebar.title("\U0001F9EC Dataset Loading")
load_method = st.sidebar.radio("Select load method", ["\U0001F3AE Random sample & manual draw", "\U0001F4C1 Upload segmentation"])

if st.button("🔄 Reset App"):
    reset_app()

Begin = "### 🎨 Begin: Choose a colormap to visualize different tissues"
st.write(Begin)

default_cmap = "PiYG"
cmap_options = [default_cmap, "nipy_spectral", "tab20", "Set3", "Paired", "tab10", "gist_rainbow", "custom"]
selected_cmap = st.selectbox("Label colormap", cmap_options, index=0)
# If "custom" is selected, show a text box where the user can type any
# colormap name. An empty entry falls back to the default instead of being
# passed on as a colormap name.
if selected_cmap == "custom":
    custom_cmap = st.text_input("please type custom colormap name", value=default_cmap)
    selected_cmap = custom_cmap.strip() or default_cmap
st.session_state["selected_cmap"] = selected_cmap

# ========== select color map for visualization segmentation ==============
# Rebuilt on every script run, so a colormap change applies immediately.
if "label_ids" in st.session_state:
    st.session_state["label_to_color"] = generate_color_map(st.session_state["label_ids"], cmap=st.session_state["selected_cmap"])
    print('organ label to color: ', list(st.session_state["label_to_color"].items())[:5])
# ========== MAIN: PAGE RENDERERS ==========
def _show_volume(volume, z_idx, y_idx, x_idx, **kwargs):
    """Show an intensity volume in the layout picked under "Visualization Type".

    "Three Planes" uses the three-plane viewer; any other choice shows a
    single plane. Extra keyword arguments (e.g. ``orientation_type``) are
    forwarded unchanged to the viewer.
    """
    if st.session_state["selected_visual"] == "Three Planes":
        show_three_planes_interactive(volume, z_idx, y_idx, x_idx, **kwargs)
    else:
        show_single_planes_interactive(volume, z_idx, y_idx, x_idx, **kwargs)


def _show_labels(seg, z_idx, y_idx, x_idx):
    """Show a label map colored with ``st.session_state["label_to_color"]``.

    Uses the same "Three Planes" vs. single-plane choice as ``_show_volume``.
    """
    colors = st.session_state["label_to_color"]
    if st.session_state["selected_visual"] == "Three Planes":
        show_label_overlay(seg, z_idx, y_idx, x_idx, label_colors=colors)
    else:
        show_label_overlay_single(seg, z_idx, y_idx, x_idx, label_colors=colors)


def _render_upload_page():
    """Upload-segmentation flow: collect uploads, run the shared steps, visualize.

    Reads and writes ``st.session_state`` (body threshold, contour, slice
    indices, ``output_volume_to_save``). Loading is delegated to
    ``handle_upload``; simulation and inference use the shared controls.
    """
    # ---------- first row: uploaders and display options ----------
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        uploaded_file = st.file_uploader("Upload segmentation", type=["zip", "nii.gz", "nii"])
    with col2:
        uploaded_tissue = st.file_uploader("Upload tissue segmentation", type=["zip", "nii.gz", "nii"], key="tissue_upload")
    with col3:
        original_file = st.file_uploader("Upload original image", type=["nii.gz", "nii", "dcm"])
    with col4:
        # Threshold applied to the original image to extract the body contour.
        # With 0 (the default) the contour is not re-extracted here.
        if "body_threshold" not in st.session_state:
            st.session_state["body_threshold"] = 0
        user_input_threshold = st.number_input(
            "Body threshold for contour extraction (used on original image)",
            value=st.session_state["body_threshold"],
            step=1,
        )
        st.session_state["use_custom_threshold"] = st.checkbox("Use custom body threshold", value=False)
        visual_options = ["Only Axial Plane", "Three Planes"]
        st.session_state["selected_visual"] = st.selectbox("Visualization Type", visual_options, index=0)

    # Always store the value the widget shows. Storing only truthy values meant
    # that setting the threshold back to 0 kept the previous non-zero threshold
    # stored and re-applied on every rerun.
    st.session_state["body_threshold"] = user_input_threshold
    if user_input_threshold and "orig_img" in st.session_state:
        st.session_state["contour"] = _create_body_contour(
            st.session_state["orig_img"], user_input_threshold, body_mask_value=1
        )
    # NOTE(review): with a threshold of 0 the currently stored contour is left
    # as is (whatever handle_upload or a previous threshold produced).

    handle_upload(app_root, uploaded_file, uploaded_tissue, original_file)

    # ---------- shared pipeline steps ----------
    simulation_controls(app_root)
    inference_controls()

    # ---------- visualization ----------
    if "combined_seg" in st.session_state:
        z_idx, y_idx, x_idx = global_slice_slider(st.session_state["volume_shape"])
        st.session_state.update({"z_idx": z_idx, "y_idx": y_idx, "x_idx": x_idx})
        _show_volume(st.session_state["contour"], z_idx, y_idx, x_idx)
        _show_labels(st.session_state["combined_seg"], z_idx, y_idx, x_idx)

        # These views reuse the slider indices, so they sit under the same guard.
        if "selected_organs" in st.session_state and len(st.session_state["selected_organs"]) > 0:
            _show_labels(combine_selected_organs(uploaded_file), z_idx, y_idx, x_idx)
        if "orig_img" in st.session_state:
            _show_volume(st.session_state["orig_img"], z_idx, y_idx, x_idx)

    if st.session_state.get("processed_img") is not None:
        st.markdown("🔍 View Simulation Result")
        # .get(): a prior image can presumably exist before the slider above has
        # run (e.g. produced on the other page), leaving the indices unset.
        _show_volume(
            st.session_state["processed_img"],
            st.session_state.get("z_idx", 0),
            st.session_state.get("y_idx", 0),
            st.session_state.get("x_idx", 0),
        )

    output_img = st.session_state.get("output_img")
    if output_img is not None:
        st.session_state["output_volume_to_save"] = np.expand_dims(output_img.T, axis=-1)
        # Model output is already in display orientation, hence orientation_type='none'.
        _show_volume(np.expand_dims(output_img, axis=-1), 0, 0, 0, orientation_type='none')


def _render_frankenstein_page():
    """Random sample & manual draw flow: frankenstein controls, shared steps, previews."""
    st.markdown("## 🎮 Frankenstein Interactive creating tool")
    frankenstein_control()
    make_step_renderer(step5_frankenstein)
    simulation_controls(app_root)
    make_step_renderer(step7_frankenstein)
    inference_controls()

    if st.button("⚙️ Run inference by Gradio"):
        processed_img = st.session_state.get("processed_img")
        if processed_img is None or "modality_idx" not in st.session_state:
            st.warning("Inference needs a simulated prior image and a selected modality.")
        else:
            st.info("Running inference...")
            # In this file z_idx is only set by the upload page's slider, so
            # fall back to the first slice instead of raising KeyError.
            image_slice = processed_img[:, :, st.session_state.get("z_idx", 0)]
            result = call_gradio_gpu_infer(st.session_state["modality_idx"], image_slice)
            st.image(result, caption="Predicted Image")

    # .get(...) is not None: the key can exist with a None value.
    output_img = st.session_state.get("output_img")
    if output_img is not None:
        # Use a standalone Figure rather than pyplot: pyplot keeps global state
        # that is unsafe when several Streamlit sessions render concurrently.
        # os.path.join keeps the path portable (on Linux a literal backslash
        # is part of the file name, not a separator).
        from matplotlib.figure import Figure

        fig = Figure()
        ax = fig.subplots()
        ax.imshow(output_img, cmap="gray")
        ax.grid(False)
        fig.savefig(os.path.join(app_root, "modeloutput.png"))

    # ---------- previews ----------
    col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
    with col1:
        if "contour" in st.session_state:
            show_single_slice_image(st.session_state["contour"].squeeze(), title="contour")
    with col2:
        if "combined_seg" in st.session_state:
            show_single_slice_label(
                st.session_state["combined_seg"].squeeze(),
                st.session_state["label_to_color"],
                title="combined segs",
            )
    with col3:
        if st.session_state.get("processed_img") is not None:
            print(np.unique(st.session_state["processed_img"]))
            show_single_slice_image(st.session_state["processed_img"].squeeze(), title="image prior")
    with col4:
        if output_img is not None:
            st.session_state["output_volume_to_save"] = np.expand_dims(output_img.T, axis=-1)
            # No reorientation: the model output is already correctly oriented.
            show_single_slice_image(output_img, title="inference image", orientation_type='none')
    make_step_renderer(step8_frankenstein)


# ========== MAIN: DISPATCH ==========
if load_method == "\U0001F4C1 Upload segmentation":
    _render_upload_page()
elif load_method == "\U0001F3AE Random sample & manual draw":
    _render_frankenstein_page()
# ========== SAVE ==========
output_folder = os.path.join(app_root, 'output')
os.makedirs(output_folder, exist_ok=True)


def _nifti_save_column(column, state_key, default_name, widget_key, download_label):
    """Render one "save as NIfTI + download" column for a session-state array.

    Args:
        column: Streamlit column container to render into.
        state_key: ``st.session_state`` key holding the array to save.
        default_name: Initial value of the filename box; also used when the
            typed name is empty or is not a usable file name.
        widget_key: Widget key of the filename text box.
        download_label: Label of the download button.

    When the array is present, it is written into ``output_folder`` (with the
    session's ``orig_affine`` if one is set) and offered for download. When it
    is absent, nothing is offered. ``output_folder`` is a fixed path shared by
    every session, so stale files from earlier runs are never served.
    """
    with column:
        typed_name = st.text_input("Filename (.nii.gz)", value=default_name, key=widget_key)
        # The name is user input: keep only its final path component so it
        # cannot point outside output_folder (e.g. "../app.py").
        filename = os.path.basename(typed_name.strip())
        if filename in ("", ".", ".."):
            filename = default_name

        data = st.session_state.get(state_key)
        if data is None:
            return

        save_path = os.path.join(output_folder, filename)
        # orig_affine may be missing; nibabel accepts affine=None, whereas the
        # direct st.session_state["orig_affine"] lookup raised KeyError.
        nifti = nib.Nifti1Image(data, st.session_state.get("orig_affine"))
        nib.save(nifti, save_path)
        with open(save_path, "rb") as f:
            st.download_button(
                label=download_label,
                data=f,
                file_name=filename,
                mime="application/gzip",
            )


col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
_nifti_save_column(col1, "contour", "contour.nii.gz", "filename_contour", "⬇️ Download Contour")
_nifti_save_column(col2, "combined_seg", "combined_seg.nii.gz", "filename_combined", "⬇️ Download Combined Segmentation")
_nifti_save_column(col3, "processed_img", "prior_image.nii.gz", "filename_prior", "⬇️ Download Prior Image")
_nifti_save_column(col4, "output_volume_to_save", "model_output.nii.gz", "filename_output", "⬇️ Download Output Image")