File size: 15,345 Bytes
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48355e9
 
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4cade8
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76bf0bf
fd601de
880f41d
 
f1e0db9
fd601de
 
9fbb587
 
 
 
 
 
 
 
 
 
fd601de
 
 
 
659a112
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdfc683
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48355e9
fd601de
48355e9
fd601de
48355e9
fd601de
 
 
 
 
 
 
 
 
 
48355e9
fd601de
 
 
48355e9
 
 
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48355e9
 
 
 
 
 
fd601de
 
 
48355e9
 
 
 
 
fd601de
48355e9
 
 
 
 
fd601de
 
48355e9
 
 
 
 
 
 
 
 
 
 
fd601de
 
48355e9
 
 
 
fd601de
 
 
 
 
 
 
 
 
 
 
 
073df2e
 
3e96d2e
 
 
 
 
 
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
# seg2med_app/app.py -- Streamlit front end of the Frankenstein app.
#
# Launch locally:
#   streamlit run tutorial8_app.py
# Launch reachable from other machines on the network:
#   streamlit run tutorial8_app.py --server.address=0.0.0.0 --server.port=8501
# Machine-specific notes from the original author's setup:
#   venv activation: F:\yang_Environments\torch\venv\Scripts\activate.ps1
#   served at:       http://129.206.168.125:8501 http://169.254.3.1:8501
# NOTE(review): the header names this file seg2med_app/app.py while the launch
#   commands use tutorial8_app.py -- confirm which entry-point name is current.
# Former sys.path tweak, kept disabled:
#import sys
#sys.path.append('./seg2med_app')
import os
# Allow more than one copy of the Intel OpenMP runtime in this process.
# Presumably this works around the "OMP: Error #15 ... already initialized"
# abort seen when several native libraries (e.g. torch and MKL-backed numpy)
# each bundle their own runtime -- TODO confirm. It is set here, before the
# remaining imports, so it is already in place when those libraries load.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# seg2med_app/main.py

import os
import streamlit as st
import zipfile
import hashlib
import pandas as pd
import numpy as np
import nibabel as nib
from seg2med_app.simulation.get_labels import get_labels
from seg2med_app.app_utils.image_utils import (
    show_three_planes,
    show_label_overlay,
    show_three_planes_interactive,
    show_single_planes_interactive,
    show_label_overlay_single,
    generate_color_map,
    load_image_canonical,
    global_slice_slider,
    image_to_base64,
    show_single_slice_image,
    show_single_slice_label,
)
from seg2med_app.ui.simulation_and_display import simulation_controls
from seg2med_app.ui.upload_and_prepare import handle_upload, compute_md5
from dataprocesser.simulation_functions import (
    _merge_seg_tissue,
    _create_body_contour_by_tissue_seg,
    _create_body_contour
)

from seg2med_app.simulation.combine_selected_organs import combine_selected_organs
from seg2med_app.ui.inference_controls import inference_controls
from seg2med_app.ui.inference_gradio import call_gradio_gpu_infer
from seg2med_app.frankenstein.frankenstein import frankenstein_control
from seg2med_app.app_utils.titles import *
# ========== CONFIG ==========
# Base directory for app assets and scratch files. It is a relative path, so it
# resolves against the working directory the app is launched from.
app_root = 'seg2med_app'
_tmp_dir = os.path.join(app_root, "tmp")
os.makedirs(_tmp_dir, exist_ok=True)  # scratch space; no-op when it exists

# ========== UI STRUCTURE ==========
# Page-level settings. Streamlit expects this before other UI commands.
_page_settings = dict(page_title="Frankenstein App", page_icon="🧠", layout="wide")
st.set_page_config(**_page_settings)

# Also stored in session state (presumably read by the imported UI helpers).
st.session_state.update({"app_root": app_root})

import streamlit as st
from PIL import Image
import os

def reset_app():
    """Clear all per-session state and rerun the script, keeping the user logged in.

    Every key in ``st.session_state`` is dropped (uploaded volumes, simulation
    and inference results, widget values). ``authenticated`` is then restored
    to ``True`` so the access-code prompt is not shown again. ``st.rerun()``
    ends the current script run immediately, so nothing after the call runs.
    """
    st.session_state.clear()
    # Restore the login flag. This function is only invoked after the
    # access-code gate has been passed, so keeping the user logged in is safe.
    st.session_state["authenticated"] = True
    # NOTE(review): st.rerun() starts a fresh run right away, so this message
    # is likely replaced before the user gets to see it.
    st.success("App has been reset. Login information is preserved.")
    print("App has been reset. Login information is preserved.")
    st.rerun()
    
# Load the app logo and encode it with image_to_base64. The context manager
# closes the underlying file handle as soon as the encoding is done.
# NOTE(review): the encoded result is discarded, so this currently only checks
# that the logo file can be opened and converted. Confirm whether it was meant
# to be rendered (e.g. next to the title).
with Image.open(os.path.join(app_root, "Frankenstein0.png")) as image:
    image_to_base64(image)


_APP_TITLE = "\U0001F9E0 Frankenstein - multimodal medical image generation"
# Author, affiliation and links shown under the title. The two trailing spaces
# on most lines are Markdown hard line breaks and must be kept.
_CREDITS_MARKDOWN = (
    "\n"
    "**Created by**: Zeyu Yang  \n"
    "PhD Student, Computer-assisted Clinical Medicine  \n"
    "University of Heidelberg  \n"
    "\n"
    "🔗 [GitHub Repository](https://github.com/musetee/frankenstein)  \n"
    "📄 [Preprint on arXiv](https://arxiv.org/abs/2504.09182)  \n"
    "✉️ Contact: [Zeyu.Yang@medma.uni-heidelberg.de](mailto:Zeyu.Yang@medma.uni-heidelberg.de)\n"
)

st.title(_APP_TITLE)
st.markdown(_CREDITS_MARKDOWN)


# ========== ACCESS GATE ==========
import hmac

# The access code may be supplied through the environment, so it does not have
# to live in source control. The historical value remains the fallback.
PASSWORD = os.environ.get("FRANKENSTEIN_APP_PASSWORD", "frankenstein")
# The gate is off by default: new sessions start authenticated. Set
# FRANKENSTEIN_REQUIRE_PASSWORD=1 to require the access code.
_REQUIRE_PASSWORD = os.environ.get("FRANKENSTEIN_REQUIRE_PASSWORD", "0") == "1"

if "authenticated" not in st.session_state:
    st.session_state.authenticated = not _REQUIRE_PASSWORD

if not st.session_state.authenticated:
    st.session_state["app_password"] = st.text_input("Enter access code", type="password")
    # Compare in constant time so response timing does not reveal how much of
    # the code matched. Encode to bytes because compare_digest rejects
    # non-ASCII str input.
    entered = (st.session_state["app_password"] or "").encode("utf-8")
    if hmac.compare_digest(entered, PASSWORD.encode("utf-8")):
        st.session_state.authenticated = True
        st.success("✅ Access granted!")
    else:
        st.warning("🔒 Please enter the correct access code to continue.")
        # Halt this run so nothing below the gate is rendered.
        st.stop()

# ========== SIDEBAR (DATASET LOADER) ==========
# The two pages of the app; the page-selection branch further down compares
# against these exact labels.
_LOAD_METHODS = [
    "\U0001F3AE Random sample & manual draw",
    "\U0001F4C1 Upload segmentation",
]
st.sidebar.title("\U0001F9EC Dataset Loading")
load_method = st.sidebar.radio("Select load method", _LOAD_METHODS)

# Wipe the session state and rerun; reset_app keeps the user logged in.
if st.button("🔄 Reset App"):
    reset_app()

# ========== COLORMAP SELECTION ==========
Begin = "### 🎨 Begin: Choose a colormap to visualize different tissues"

st.write(Begin)
default_cmap = "PiYG"
cmap_options = [default_cmap, "nipy_spectral", "tab20", "Set3", "Paired", "tab10", "gist_rainbow", "custom"]
selected_cmap = st.selectbox("Label colormap", cmap_options, index=0)

# If "custom" is selected, show a text box where the user can type any
# colormap name. A blank entry falls back to the default instead of passing an
# empty name on to generate_color_map.
if selected_cmap == "custom":
    custom_cmap = st.text_input("please type custom colormap name", value=default_cmap)
    selected_cmap = custom_cmap.strip() or default_cmap

st.session_state["selected_cmap"] = selected_cmap

# ========== select color map for visualization segmentation ==============
# Rebuild the label -> color mapping on every run so it follows the current
# colormap choice. This is only possible once label IDs are known.
if "label_ids" in st.session_state:
    st.session_state["label_to_color"] = generate_color_map(st.session_state["label_ids"], cmap=st.session_state["selected_cmap"])
    print('organ label to color: ', list(st.session_state["label_to_color"].items())[:5])

# ========== DISPLAY HELPERS (upload page) ==========
def _show_volume(volume, z_idx, y_idx, x_idx, **kwargs):
    """Show an image volume in the layout picked under "Visualization Type".

    "Three Planes" renders with show_three_planes_interactive; the other
    option renders with show_single_planes_interactive. Extra keyword
    arguments (e.g. ``orientation_type``) are forwarded unchanged.
    """
    if st.session_state["selected_visual"] == "Three Planes":
        show_three_planes_interactive(volume, z_idx, y_idx, x_idx, **kwargs)
    else:
        show_single_planes_interactive(volume, z_idx, y_idx, x_idx, **kwargs)


def _show_labels(label_volume, z_idx, y_idx, x_idx):
    """Show a label volume colored with the current label -> color map.

    Uses show_label_overlay for "Three Planes" and show_label_overlay_single
    otherwise.
    """
    if st.session_state["selected_visual"] == "Three Planes":
        show_label_overlay(label_volume, z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"])
    else:
        show_label_overlay_single(label_volume, z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"])


def _stored_slice_indices():
    """Return the (z, y, x) indices stored by the slice slider, or None if unset."""
    keys = ("z_idx", "y_idx", "x_idx")
    if all(key in st.session_state for key in keys):
        return tuple(st.session_state[key] for key in keys)
    return None


# ========== MAIN: UPLOAD SEGMENTATION ==========
if load_method == "\U0001F4C1 Upload segmentation":
    # ========== FIRST ROW ==========
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        uploaded_file = st.file_uploader("Upload segmentation", type=["zip", "nii.gz", "nii"])
    with col2:
        uploaded_tissue = st.file_uploader("Upload tissue segmentation", type=["zip", "nii.gz", "nii"], key="tissue_upload")
    with col3:
        original_file = st.file_uploader("Upload original image", type=["nii.gz", "nii", "dcm"])
    with col4:
        # Body threshold for contour extraction on the original image. The
        # original note says the default depends on the modality or is entered
        # manually; here it starts at a fixed 0.
        default_body_threshold = 0
        if "body_threshold" not in st.session_state:
            st.session_state["body_threshold"] = default_body_threshold
        user_input_threshold = st.number_input(
            "Body threshold for contour extraction (used on original image)",
            value=st.session_state["body_threshold"],
            step=1
        )

        # Stored for other components. It is not read in this file; presumably
        # consumed by handle_upload / the simulation UI -- TODO confirm.
        use_custom_threshold = st.checkbox("Use custom body threshold", value=False)
        st.session_state["use_custom_threshold"] = use_custom_threshold

        visual_options = ["Only Axial Plane", "Three Planes"]
        st.session_state["selected_visual"] = st.selectbox("Visualization Type", visual_options, index=0)

        # NOTE(review): a threshold of 0 is falsy, so entering 0 neither updates
        # the stored threshold nor recomputes the contour. Any other value
        # recomputes the contour from the original image on every rerun.
        if user_input_threshold:
            st.session_state["body_threshold"] = user_input_threshold
        if user_input_threshold and "orig_img" in st.session_state:
            st.session_state["contour"] = _create_body_contour(st.session_state['orig_img'], st.session_state['body_threshold'], body_mask_value=1)

    # ========== HASH MANAGEMENT ==========
    # NOTE(review): these hashes are not used further in this file. They are
    # kept because compute_md5 may touch the uploaded file objects before
    # handle_upload reads them -- confirm before removing.
    new_upload_hash = compute_md5(uploaded_file) if uploaded_file else None
    cached_upload_hash = st.session_state.get("uploaded_file_hash", None)
    new_tissue_hash = compute_md5(uploaded_tissue) if uploaded_tissue else None
    cached_tissue_hash = st.session_state.get("uploaded_tissue_hash", None)
    new_origin_hash = compute_md5(original_file) if original_file else None
    cached_origin_hash = st.session_state.get("uploaded_origin_hash", None)

    handle_upload(app_root,
            uploaded_file, uploaded_tissue, original_file
        )

    # ========== SIMULATION UI (SHARED) ==========
    simulation_controls(app_root)

    # ========== INFERENCE UI (SHARED) ==========
    inference_controls()

    # ========== visualize ==========
    if "combined_seg" in st.session_state:
        # The slider's indices are stored in session state so the views below
        # (and later reruns) can reuse them.
        z_idx, y_idx, x_idx = global_slice_slider(st.session_state["volume_shape"])
        st.session_state.update({
            "z_idx": z_idx,
            "y_idx": y_idx,
            "x_idx": x_idx,
        })
        _show_volume(st.session_state["contour"], z_idx, y_idx, x_idx)
        _show_labels(st.session_state["combined_seg"], z_idx, y_idx, x_idx)

    # The views below need slice indices. These exist only once the slider
    # above has stored them (in this or an earlier run of the session), so
    # they must not be assumed when no combined segmentation exists yet.
    slice_indices = _stored_slice_indices()
    has_organ_selection = "selected_organs" in st.session_state and len(st.session_state["selected_organs"]) > 0
    needs_slices = (
        has_organ_selection
        or "orig_img" in st.session_state
        or st.session_state.get("processed_img") is not None
    )
    if slice_indices is None and needs_slices:
        st.info("Slice views need a segmentation: upload one to enable slice navigation.")

    if has_organ_selection:
        multi_seg = combine_selected_organs(uploaded_file)
        if slice_indices is not None:
            _show_labels(multi_seg, *slice_indices)

    if "orig_img" in st.session_state and slice_indices is not None:
        _show_volume(st.session_state["orig_img"], *slice_indices)

    if st.session_state.get("processed_img") is not None and slice_indices is not None:
        st.markdown("🔍 View Simulation Result")
        _show_volume(st.session_state["processed_img"], *slice_indices)

    if st.session_state.get("output_img") is not None:
        # Kept for the save section. The transpose plus trailing singleton axis
        # presumably matches the layout of the other saved volumes -- TODO confirm.
        st.session_state["output_volume_to_save"] = np.expand_dims(st.session_state["output_img"].T, axis=-1)
        # The model output is already in the correct orientation, so no reorientation.
        _show_volume(np.expand_dims(st.session_state["output_img"], axis=-1), 0, 0, 0, orientation_type='none')

# ========== RANDOM DRAW PAGE ==========
elif load_method == "\U0001F3AE Random sample & manual draw":
    st.markdown("## 🎮 Frankenstein Interactive creating tool")
    frankenstein_control()

    make_step_renderer(step5_frankenstein)
    simulation_controls(app_root)

    make_step_renderer(step7_frankenstein)
    inference_controls()
    if st.button("⚙️ Run inference by Gradio"):
        processed_img = st.session_state.get("processed_img")
        # Inference needs an image prior and a selected slice. Without them the
        # slicing below would raise, so explain instead.
        if processed_img is None or "z_idx" not in st.session_state:
            st.warning("An image prior and a selected slice are needed before running inference.")
        else:
            st.info("Running inference...")
            modality = st.session_state["modality_idx"]
            image_slice = processed_img[:, :, st.session_state["z_idx"]]
            result = call_gradio_gpu_infer(modality, image_slice)
            st.image(result, caption="Predicted Image")

    import matplotlib.pyplot as plt
    if st.session_state.get("output_img") is not None:
        output_img = st.session_state["output_img"]
        # Save a grayscale snapshot of the model output. os.path.join keeps the
        # path pointing at the same file on every operating system.
        fig, ax = plt.subplots()
        ax.imshow(output_img, cmap="gray")
        ax.grid(False)
        fig.savefig(os.path.join(app_root, "modeloutput.png"))
        plt.close(fig)

    col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
    with col1:
        if st.session_state.get("contour") is not None:
            show_single_slice_image(st.session_state["contour"].squeeze(), title="contour")
    with col2:
        if st.session_state.get("combined_seg") is not None:
            show_single_slice_label(st.session_state["combined_seg"].squeeze(),
                                    st.session_state["label_to_color"],
                                    title="combined segs")
    with col3:
        if st.session_state.get("processed_img") is not None:
            show_single_slice_image(st.session_state["processed_img"].squeeze(), title="image prior")
    with col4:
        if st.session_state.get("output_img") is not None:
            st.session_state["output_volume_to_save"] = np.expand_dims(st.session_state["output_img"].T, axis=-1)
            # No reorientation: the model output is already in display orientation.
            show_single_slice_image(st.session_state["output_img"], title="inference image", orientation_type='none')

make_step_renderer(step8_frankenstein)

# ========== SAVE ==========
def _safe_output_name(name, default):
    """Turn a user-typed file name into a bare NIfTI file name.

    Directory parts are stripped so the text box cannot point outside the
    output folder (e.g. "../../x.nii.gz"). A blank entry falls back to
    *default*. A name without a ".nii"/".nii.gz" suffix gets ".nii.gz"
    appended, because nibabel chooses the file format from the extension.
    """
    # Normalise Windows separators so basename also strips them on POSIX.
    base = os.path.basename(name.strip().replace("\\", "/"))
    if not base:
        return default
    if not base.lower().endswith((".nii", ".nii.gz")):
        base += ".nii.gz"
    return base


def _save_and_offer_download(text_key, default_name, state_key, download_label):
    """Save one session-state volume as NIfTI and offer it for download.

    Draws a filename text box (widget key *text_key*). Whenever
    ``st.session_state[state_key]`` is set, it is written into the output
    folder using ``st.session_state["orig_affine"]``. Any file present at that
    path is then offered through a download button.
    """
    file_name = _safe_output_name(
        st.text_input("Filename (.nii.gz)", value=default_name, key=text_key),
        default_name,
    )
    save_path = os.path.join(output_folder, file_name)

    volume = st.session_state.get(state_key)
    if volume is not None:
        if "orig_affine" not in st.session_state:
            st.warning(f"No reference affine available; '{file_name}' was not saved.")
        else:
            nib.save(nib.Nifti1Image(volume, st.session_state["orig_affine"]), save_path)

    # NOTE(review): the output folder is shared by all sessions, so a file with
    # this name left by another session can be offered here.
    if os.path.exists(save_path):
        mime = "application/gzip" if file_name.lower().endswith(".gz") else "application/octet-stream"
        with open(save_path, "rb") as f:
            st.download_button(
                label=download_label,
                data=f,
                file_name=file_name,
                mime=mime,
            )


output_folder = os.path.join(app_root, 'output')
os.makedirs(output_folder, exist_ok=True)
col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
with col1:
    _save_and_offer_download("filename_contour", "contour.nii.gz", "contour", "⬇️ Download Contour")
with col2:
    _save_and_offer_download("filename_combined", "combined_seg.nii.gz", "combined_seg", "⬇️ Download Combined Segmentation")
with col3:
    _save_and_offer_download("filename_prior", "prior_image.nii.gz", "processed_img", "⬇️ Download Prior Image")
with col4:
    _save_and_offer_download("filename_output", "model_output.nii.gz", "output_volume_to_save", "⬇️ Download Output Image")