Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from core.utils import ( | |
| example_file_path, | |
| _load_volume_from_any, | |
| volume_stats, | |
| browse_axis_fast, | |
| browse_overlay_axis_fast, | |
| segment_volume, | |
| APP_TMP_DIR, | |
| clean_temp, | |
| write_mask_tif, | |
| ) | |
| import urllib.request | |
| import time, threading, tempfile, os | |
| from typing import Union | |
| from gradio import skip | |
| CLEAN_EVERY_SEC = 1800 # every 30 min | |
| CLEAN_AGE_HOURS = 12 # every 12 hours | |
| def _start_cleanup_daemon(): | |
| def _loop(): | |
| while True: | |
| try: | |
| clean_temp(CLEAN_AGE_HOURS) | |
| except Exception as e: | |
| print(f"[cleanup daemon] {e}") | |
| time.sleep(CLEAN_EVERY_SEC) | |
| threading.Thread(target=_loop, daemon=True).start() | |
| _start_cleanup_daemon() | |
| def get_axis_max(volume, axis): | |
| """Get the maximum index of each axis.""" | |
| if volume is None: | |
| return 0 | |
| shape = volume.shape | |
| return shape[{"Z": 0, "Y": 1, "X": 2}[axis]] - 1 | |
| def reset_app(): | |
| """Reset everything to the initial state.""" | |
| return ( | |
| gr.update(value=None), # file_input | |
| None, # volume_state | |
| None, # seg_state | |
| gr.update(visible=False),# group_input | |
| gr.update(visible=False),# segment_btn | |
| gr.update(value=0), gr.update(value=0), gr.update(value=0), | |
| gr.update(value=None), gr.update(value=None), gr.update(value=None), | |
| gr.update(visible=False),# group_seg | |
| gr.update(value=0), gr.update(value=0), gr.update(value=0), | |
| gr.update(value=None), gr.update(value=None), gr.update(value=None) | |
| ) | |
| def segment_api(file_obj: Union[dict, str, bytes]) -> str: | |
| """Segments a 3D TIF/TIFF volume and returns a server path to a compressed TIF mask.""" | |
| volume = _load_volume_from_any(file_obj) | |
| seg = segment_volume(volume) | |
| if seg is None: | |
| raise gr.Error("Segmentation failed") | |
| out_path = write_mask_tif(seg) | |
| return out_path | |
| def run_seg_with_progress(volume, progress=gr.Progress(track_tqdm=True)): | |
| """Surface a progress bar in Gradio while the model runs.""" | |
| if volume is None: | |
| return None | |
| progress(0.1, desc="Preparing model…") | |
| seg = segment_volume(volume) | |
| progress(1.0, desc="Done") | |
| return seg | |
| with gr.Blocks(delete_cache=(1800, 21600)) as demo: | |
| # Expose ONLY the /segment API/MCP tool | |
| gr.api( | |
| segment_api, | |
| api_name="segment", | |
| api_description="Accepts a 3D TIF/TIFF (URL, uploaded file, or raw bytes) and returns a path to the compressed TIF mask." | |
| ) | |
| # -------- UI -------- | |
| gr.Markdown("# 🐭 3D Lungs Segmentation") | |
| gr.Markdown("### ⚠️ Note: the visualization may take some time to render!") | |
| # States | |
| last_url_state = gr.State("") # last processed ?file_url | |
| volume_state = gr.State() | |
| seg_state = gr.State() | |
| norm_state = gr.State() | |
| file_input = gr.File( | |
| file_types=[".tif", ".tiff"], | |
| file_count="single", | |
| label="Upload your 3D TIF or TIFF file" | |
| ) | |
| gr.Examples( | |
| examples=[[example_file_path]], | |
| inputs=[file_input], | |
| label="Try an example!", | |
| examples_per_page=1 | |
| ) | |
| with gr.Group(visible=False) as group_input: | |
| gr.Markdown("### Raw Volume Slices") | |
| with gr.Row(): | |
| z_slider = gr.Slider(0, 0, step=1, label="Z Slice") | |
| y_slider = gr.Slider(0, 0, step=1, label="Y Slice") | |
| x_slider = gr.Slider(0, 0, step=1, label="X Slice") | |
| with gr.Row(): | |
| z_img = gr.Image(label="Z") | |
| y_img = gr.Image(label="Y") | |
| x_img = gr.Image(label="X") | |
| segment_btn = gr.Button("Segment", visible=False) | |
| loading_md = gr.Markdown("⏳ **Segmenting…** This can take a bit.", visible=False) | |
| with gr.Group(visible=False) as group_seg: | |
| gr.Markdown("### Segmentation Overlay Slices") | |
| with gr.Row(): | |
| z_slider_seg = gr.Slider(0, 0, step=1, label="Z Slice (Overlay)") | |
| y_slider_seg = gr.Slider(0, 0, step=1, label="Y Slice (Overlay)") | |
| x_slider_seg = gr.Slider(0, 0, step=1, label="X Slice (Overlay)") | |
| with gr.Row(): | |
| z_img_overlay = gr.Image(label="Z + Mask") | |
| y_img_overlay = gr.Image(label="Y + Mask") | |
| x_img_overlay = gr.Image(label="X + Mask") | |
| reset_btn = gr.Button("Reset") | |
| gr.Markdown("#### 📝 This work is based on the Bachelor Project of Quentin Chappuis 2024; for more information, consult the [repository](https://github.com/qchapp/lungs-segmentation)!") | |
| # -------- Callbacks (hidden from API/MCP) -------- | |
| file_input.change( | |
| fn=lambda f: _load_volume_from_any(f) if f is not None else skip(), | |
| inputs=file_input, | |
| outputs=volume_state, | |
| show_api=False | |
| ).then( | |
| fn=lambda vol: volume_stats(vol) if vol is not None else skip(), | |
| inputs=volume_state, | |
| outputs=norm_state, | |
| show_api=False | |
| ).then( | |
| fn=lambda vol: gr.update(visible=True) if vol is not None else skip(), | |
| inputs=volume_state, | |
| outputs=group_input, | |
| show_api=False | |
| ).then( | |
| fn=lambda vol: gr.update(visible=True) if vol is not None else skip(), | |
| inputs=volume_state, | |
| outputs=segment_btn, | |
| show_api=False | |
| ).then( | |
| fn=lambda vol: ( | |
| gr.update(maximum=get_axis_max(vol, "Z")), | |
| gr.update(maximum=get_axis_max(vol, "Y")), | |
| gr.update(maximum=get_axis_max(vol, "X")), | |
| ) if vol is not None else (skip(), skip(), skip()), | |
| inputs=volume_state, | |
| outputs=[z_slider, y_slider, x_slider], | |
| show_api=False | |
| ).then( | |
| fn=lambda vol, st: ( | |
| browse_axis_fast("Z", 0, vol, st), | |
| browse_axis_fast("Y", 0, vol, st), | |
| browse_axis_fast("X", 0, vol, st), | |
| ) if vol is not None else (skip(), skip(), skip()), | |
| inputs=[volume_state, norm_state], | |
| outputs=[z_img, y_img, x_img], | |
| show_api=False | |
| ) | |
| z_slider.change( | |
| fn=lambda idx, vol, st: browse_axis_fast("Z", idx, vol, st), | |
| inputs=[z_slider, volume_state, norm_state], | |
| outputs=z_img, | |
| show_api=False | |
| ) | |
| y_slider.change( | |
| fn=lambda idx, vol, st: browse_axis_fast("Y", idx, vol, st), | |
| inputs=[y_slider, volume_state, norm_state], | |
| outputs=y_img, | |
| show_api=False | |
| ) | |
| x_slider.change( | |
| fn=lambda idx, vol, st: browse_axis_fast("X", idx, vol, st), | |
| inputs=[x_slider, volume_state, norm_state], | |
| outputs=x_img, | |
| show_api=False | |
| ) | |
| segment_btn.click( | |
| fn=lambda: (gr.update(visible=True), gr.update(interactive=False)), | |
| inputs=[], | |
| outputs=[loading_md, segment_btn], | |
| show_api=False | |
| ).then( | |
| fn=run_seg_with_progress, | |
| inputs=volume_state, | |
| outputs=seg_state, | |
| show_api=False | |
| ).then( | |
| fn=lambda s: gr.update(visible=(s is not None)), | |
| inputs=seg_state, | |
| outputs=group_seg, | |
| show_api=False | |
| ).then( | |
| fn=lambda vol: ( | |
| gr.update(maximum=get_axis_max(vol, "Z")), | |
| gr.update(maximum=get_axis_max(vol, "Y")), | |
| gr.update(maximum=get_axis_max(vol, "X")), | |
| ), | |
| inputs=volume_state, | |
| outputs=[z_slider_seg, y_slider_seg, x_slider_seg], | |
| show_api=False | |
| ).then( | |
| fn=lambda z, y, x, vol, seg, st: ( | |
| browse_overlay_axis_fast("Z", z, vol, seg, st), | |
| browse_overlay_axis_fast("Y", y, vol, seg, st), | |
| browse_overlay_axis_fast("X", x, vol, seg, st), | |
| ), | |
| inputs=[z_slider_seg, y_slider_seg, x_slider_seg, volume_state, seg_state, norm_state], | |
| outputs=[z_img_overlay, y_img_overlay, x_img_overlay], | |
| show_api=False | |
| ).then( | |
| fn=lambda: (gr.update(visible=False), gr.update(interactive=True)), | |
| inputs=[], | |
| outputs=[loading_md, segment_btn], | |
| show_api=False | |
| ) | |
| z_slider_seg.change( | |
| fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("Z", idx, vol, seg, st), | |
| inputs=[z_slider_seg, volume_state, seg_state, norm_state], | |
| outputs=z_img_overlay, | |
| show_api=False | |
| ) | |
| y_slider_seg.change( | |
| fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("Y", idx, vol, seg, st), | |
| inputs=[y_slider_seg, volume_state, seg_state, norm_state], | |
| outputs=y_img_overlay, | |
| show_api=False | |
| ) | |
| x_slider_seg.change( | |
| fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("X", idx, vol, seg, st), | |
| inputs=[x_slider_seg, volume_state, seg_state, norm_state], | |
| outputs=x_img_overlay, | |
| show_api=False | |
| ) | |
| reset_btn.click( | |
| fn=reset_app, | |
| inputs=[], | |
| outputs=[ | |
| file_input, | |
| volume_state, | |
| seg_state, | |
| group_input, | |
| segment_btn, | |
| z_slider, y_slider, x_slider, | |
| z_img, y_img, x_img, | |
| group_seg, | |
| z_slider_seg, y_slider_seg, x_slider_seg, | |
| z_img_overlay, y_img_overlay, x_img_overlay | |
| ], | |
| show_api=False | |
| ) | |
| # -------- URL loader -------- | |
| def load_from_query(prev_url, request: gr.Request): | |
| params = request.query_params | |
| url = params.get("file_url") or "" | |
| # No URL -> no-op | |
| if not url: | |
| return [gr.skip(), gr.skip()] | |
| # 🔧 Short-circuit: same URL as last time -> no-op | |
| if url == prev_url: | |
| return [gr.skip(), gr.skip()] | |
| # Download to CLOSED temp file and programmatically set the File value. | |
| fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR)) | |
| os.close(fd) | |
| try: | |
| urllib.request.urlretrieve(url, tmp_path) | |
| except Exception as e: | |
| try: | |
| os.remove(tmp_path) | |
| except Exception: | |
| pass | |
| raise gr.Error(f"Failed to download file_url: {e}") | |
| return [url, gr.update(value=tmp_path)] | |
| if __name__ == "__main__": | |
| demo.queue(default_concurrency_limit=1, max_size=16).launch(mcp_server=True) |