# titi
# now supporting mcp server
# bdd3d9c
import gradio as gr
from core.utils import (
example_file_path,
_load_volume_from_any,
volume_stats,
browse_axis_fast,
browse_overlay_axis_fast,
segment_volume,
APP_TMP_DIR,
clean_temp,
write_mask_tif,
)
import urllib.request
import time, threading, tempfile, os
from typing import Union
from gradio import skip
# Pause, in seconds, between two passes of the background cleanup loop.
CLEAN_EVERY_SEC = 1800 # every 30 min
# Argument handed to clean_temp() on every pass. NOTE(review): presumably the
# age threshold, in hours, past which temp files are removed (not a period) —
# confirm against core.utils.clean_temp.
CLEAN_AGE_HOURS = 12 # presumably a 12-hour age threshold
def _start_cleanup_daemon():
    """Start a daemon thread that calls ``clean_temp`` in an endless loop.

    Each pass calls ``clean_temp(CLEAN_AGE_HOURS)`` and then sleeps for
    ``CLEAN_EVERY_SEC`` seconds. Any exception raised by a pass is printed
    and the loop carries on, so one failed cleanup never stops the thread.
    """

    def _loop():
        while True:
            try:
                clean_temp(CLEAN_AGE_HOURS)
            except Exception as err:
                # Best-effort housekeeping: report the problem and keep going.
                print(f"[cleanup daemon] {err}")
            time.sleep(CLEAN_EVERY_SEC)

    # daemon=True so the thread never keeps the interpreter alive at exit.
    worker = threading.Thread(target=_loop, daemon=True)
    worker.start()


# Started at import time, before the UI is built.
_start_cleanup_daemon()
def get_axis_max(volume, axis):
    """Return the last valid index along one axis of ``volume``.

    Args:
        volume: Object exposing a ``shape`` sequence, or ``None`` when no
            volume is loaded.
        axis: One of ``"Z"``, ``"Y"`` or ``"X"``, mapped to positions 0, 1
            and 2 of ``volume.shape`` respectively.

    Returns:
        ``shape[axis] - 1``, or ``0`` when ``volume`` is ``None``.

    Raises:
        KeyError: If ``axis`` is not one of the three supported letters.
    """
    if volume is None:
        return 0
    extents = volume.shape
    axis_position = {"Z": 0, "Y": 1, "X": 2}[axis]
    return extents[axis_position] - 1
def reset_app():
    """Build the updates that put the UI back into its initial state.

    Returns:
        An 18-item tuple. The order is significant: it must line up with the
        ``outputs`` list of the handler this function is wired to.
    """

    def _zeroed_sliders():
        # Three independent "value=0" updates, one per Z/Y/X slider.
        return tuple(gr.update(value=0) for _ in range(3))

    def _cleared_images():
        # Three independent "value=None" updates, one per Z/Y/X image.
        return tuple(gr.update(value=None) for _ in range(3))

    return (
        gr.update(value=None),     # file_input
        None,                      # volume_state
        None,                      # seg_state
        gr.update(visible=False),  # group_input
        gr.update(visible=False),  # segment_btn
        *_zeroed_sliders(),        # raw sliders
        *_cleared_images(),        # raw images
        gr.update(visible=False),  # group_seg
        *_zeroed_sliders(),        # overlay sliders
        *_cleared_images(),        # overlay images
    )
def segment_api(file_obj: Union[dict, str, bytes]) -> str:
    """Segments a 3D TIF/TIFF volume and returns a server path to a compressed TIF mask.

    Args:
        file_obj: Input passed unchanged to ``_load_volume_from_any``.

    Returns:
        The path returned by ``write_mask_tif`` for the segmentation result.

    Raises:
        gr.Error: If ``segment_volume`` returns ``None``.
    """
    mask = segment_volume(_load_volume_from_any(file_obj))
    if mask is None:
        raise gr.Error("Segmentation failed")
    return write_mask_tif(mask)
def run_seg_with_progress(volume, progress=gr.Progress(track_tqdm=True)):
    """Run ``segment_volume`` while reporting progress to the Gradio UI.

    Args:
        volume: The loaded volume, or ``None`` when nothing is loaded.
        progress: Gradio progress tracker; the default instance is how Gradio
            injects the tracker, so it is intentionally created in the
            signature.

    Returns:
        Whatever ``segment_volume`` returns, or ``None`` when ``volume`` is
        ``None`` (in which case no progress is reported).
    """
    if volume is not None:
        progress(0.1, desc="Preparing model…")
        result = segment_volume(volume)
        progress(1.0, desc="Done")
        return result
    return None
with gr.Blocks(delete_cache=(1800, 21600)) as demo:
    # NOTE(review): the original indentation of this block was lost. The
    # nesting below (segment/reset buttons and footer at Blocks level, outside
    # the two groups) is reconstructed from the fact that those components
    # have their own visibility handling — confirm against the deployed app.

    # Expose ONLY the /segment API/MCP tool
    gr.api(
        segment_api,
        api_name="segment",
        api_description="Accepts a 3D TIF/TIFF (URL, uploaded file, or raw bytes) and returns a path to the compressed TIF mask.",
    )

    # -------- UI --------
    gr.Markdown("# 🐭 3D Lungs Segmentation")
    gr.Markdown("### ⚠️ Note: the visualization may take some time to render!")

    # States
    last_url_state = gr.State("")  # last processed ?file_url
    volume_state = gr.State()      # loaded volume
    seg_state = gr.State()         # segmentation result
    norm_state = gr.State()        # output of volume_stats(), fed to the slice renderers

    file_input = gr.File(
        file_types=[".tif", ".tiff"],
        file_count="single",
        label="Upload your 3D TIF or TIFF file",
    )
    gr.Examples(
        examples=[[example_file_path]],
        inputs=[file_input],
        label="Try an example!",
        examples_per_page=1,
    )

    # Raw-volume browser: hidden until a volume has been loaded.
    with gr.Group(visible=False) as group_input:
        gr.Markdown("### Raw Volume Slices")
        with gr.Row():
            z_slider = gr.Slider(0, 0, step=1, label="Z Slice")
            y_slider = gr.Slider(0, 0, step=1, label="Y Slice")
            x_slider = gr.Slider(0, 0, step=1, label="X Slice")
        with gr.Row():
            z_img = gr.Image(label="Z")
            y_img = gr.Image(label="Y")
            x_img = gr.Image(label="X")

    segment_btn = gr.Button("Segment", visible=False)
    loading_md = gr.Markdown("⏳ **Segmenting…** This can take a bit.", visible=False)

    # Overlay browser: hidden until a segmentation is available.
    with gr.Group(visible=False) as group_seg:
        gr.Markdown("### Segmentation Overlay Slices")
        with gr.Row():
            z_slider_seg = gr.Slider(0, 0, step=1, label="Z Slice (Overlay)")
            y_slider_seg = gr.Slider(0, 0, step=1, label="Y Slice (Overlay)")
            x_slider_seg = gr.Slider(0, 0, step=1, label="X Slice (Overlay)")
        with gr.Row():
            z_img_overlay = gr.Image(label="Z + Mask")
            y_img_overlay = gr.Image(label="Y + Mask")
            x_img_overlay = gr.Image(label="X + Mask")

    reset_btn = gr.Button("Reset")
    gr.Markdown("#### 📝 This work is based on the Bachelor Project of Quentin Chappuis 2024; for more information, consult the [repository](https://github.com/qchapp/lungs-segmentation)!")

    # -------- Callbacks (hidden from API/MCP) --------
    # File selected: load the volume, compute its stats, reveal the raw
    # browser and the Segment button, size the sliders, render slice 0 of
    # each axis. Every step is a no-op (skip) when there is no file/volume,
    # which is the case right after a reset clears the file input.
    file_input.change(
        fn=lambda f: _load_volume_from_any(f) if f is not None else skip(),
        inputs=file_input,
        outputs=volume_state,
        show_api=False,
    ).then(
        fn=lambda vol: volume_stats(vol) if vol is not None else skip(),
        inputs=volume_state,
        outputs=norm_state,
        show_api=False,
    ).then(
        fn=lambda vol: gr.update(visible=True) if vol is not None else skip(),
        inputs=volume_state,
        outputs=group_input,
        show_api=False,
    ).then(
        fn=lambda vol: gr.update(visible=True) if vol is not None else skip(),
        inputs=volume_state,
        outputs=segment_btn,
        show_api=False,
    ).then(
        fn=lambda vol: (
            gr.update(maximum=get_axis_max(vol, "Z")),
            gr.update(maximum=get_axis_max(vol, "Y")),
            gr.update(maximum=get_axis_max(vol, "X")),
        ) if vol is not None else (skip(), skip(), skip()),
        inputs=volume_state,
        outputs=[z_slider, y_slider, x_slider],
        show_api=False,
    ).then(
        fn=lambda vol, st: (
            browse_axis_fast("Z", 0, vol, st),
            browse_axis_fast("Y", 0, vol, st),
            browse_axis_fast("X", 0, vol, st),
        ) if vol is not None else (skip(), skip(), skip()),
        inputs=[volume_state, norm_state],
        outputs=[z_img, y_img, x_img],
        show_api=False,
    )

    # Raw sliders: re-render only the image of the axis that moved.
    z_slider.change(
        fn=lambda idx, vol, st: browse_axis_fast("Z", idx, vol, st),
        inputs=[z_slider, volume_state, norm_state],
        outputs=z_img,
        show_api=False,
    )
    y_slider.change(
        fn=lambda idx, vol, st: browse_axis_fast("Y", idx, vol, st),
        inputs=[y_slider, volume_state, norm_state],
        outputs=y_img,
        show_api=False,
    )
    x_slider.change(
        fn=lambda idx, vol, st: browse_axis_fast("X", idx, vol, st),
        inputs=[x_slider, volume_state, norm_state],
        outputs=x_img,
        show_api=False,
    )

    # Segment: show the loading note and lock the button, run the model,
    # reveal the overlay browser if a result exists, size and render the
    # overlay views, then hide the note and unlock the button again.
    segment_btn.click(
        fn=lambda: (gr.update(visible=True), gr.update(interactive=False)),
        inputs=[],
        outputs=[loading_md, segment_btn],
        show_api=False,
    ).then(
        fn=run_seg_with_progress,
        inputs=volume_state,
        outputs=seg_state,
        show_api=False,
    ).then(
        fn=lambda s: gr.update(visible=(s is not None)),
        inputs=seg_state,
        outputs=group_seg,
        show_api=False,
    ).then(
        fn=lambda vol: (
            gr.update(maximum=get_axis_max(vol, "Z")),
            gr.update(maximum=get_axis_max(vol, "Y")),
            gr.update(maximum=get_axis_max(vol, "X")),
        ),
        inputs=volume_state,
        outputs=[z_slider_seg, y_slider_seg, x_slider_seg],
        show_api=False,
    ).then(
        fn=lambda z, y, x, vol, seg, st: (
            browse_overlay_axis_fast("Z", z, vol, seg, st),
            browse_overlay_axis_fast("Y", y, vol, seg, st),
            browse_overlay_axis_fast("X", x, vol, seg, st),
        ),
        inputs=[z_slider_seg, y_slider_seg, x_slider_seg, volume_state, seg_state, norm_state],
        outputs=[z_img_overlay, y_img_overlay, x_img_overlay],
        show_api=False,
    ).then(
        fn=lambda: (gr.update(visible=False), gr.update(interactive=True)),
        inputs=[],
        outputs=[loading_md, segment_btn],
        show_api=False,
    )

    # Overlay sliders: re-render only the overlay of the axis that moved.
    z_slider_seg.change(
        fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("Z", idx, vol, seg, st),
        inputs=[z_slider_seg, volume_state, seg_state, norm_state],
        outputs=z_img_overlay,
        show_api=False,
    )
    y_slider_seg.change(
        fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("Y", idx, vol, seg, st),
        inputs=[y_slider_seg, volume_state, seg_state, norm_state],
        outputs=y_img_overlay,
        show_api=False,
    )
    x_slider_seg.change(
        fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("X", idx, vol, seg, st),
        inputs=[x_slider_seg, volume_state, seg_state, norm_state],
        outputs=x_img_overlay,
        show_api=False,
    )

    # Reset: the order of `outputs` must match the tuple built by reset_app().
    reset_btn.click(
        fn=reset_app,
        inputs=[],
        outputs=[
            file_input,
            volume_state,
            seg_state,
            group_input,
            segment_btn,
            z_slider, y_slider, x_slider,
            z_img, y_img, x_img,
            group_seg,
            z_slider_seg, y_slider_seg, x_slider_seg,
            z_img_overlay, y_img_overlay, x_img_overlay,
        ],
        show_api=False,
    )

    # -------- URL loader --------
    @demo.load(
        inputs=[last_url_state],
        outputs=[last_url_state, file_input],  # only these two
        show_api=False,
    )
    def load_from_query(prev_url, request: gr.Request):
        """Pre-load the file named by the ``?file_url=`` query parameter.

        Args:
            prev_url: Last URL processed for this session (``last_url_state``).
            request: Incoming Gradio request; only ``query_params`` is read.

        Returns:
            ``[url, file update]`` after a successful download, or two skips
            when there is no ``file_url`` or it equals ``prev_url``.

        Raises:
            gr.Error: If ``file_url`` is not an http(s) URL or the download
                fails.
        """
        # Function-scope import keeps this block self-contained.
        from urllib.parse import urlsplit

        params = request.query_params
        url = params.get("file_url") or ""

        # No URL, or the same URL as last time -> no-op
        if not url or url == prev_url:
            return [gr.skip(), gr.skip()]

        # `file_url` is attacker-controlled. urlretrieve() also accepts
        # file://, ftp:// and data: URLs, which would let a visitor make the
        # server read its own local files — only allow plain web URLs.
        # NOTE(review): this does not stop requests to internal hosts (SSRF);
        # add a host allow-list if the app runs inside a private network.
        try:
            scheme = urlsplit(url).scheme.lower()
        except ValueError:
            scheme = ""  # malformed URL: rejected just below
        if scheme not in ("http", "https"):
            raise gr.Error("file_url must be an http(s) URL")

        # Download to CLOSED temp file and programmatically set the File value.
        fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
        os.close(fd)
        try:
            urllib.request.urlretrieve(url, tmp_path)
        except Exception as e:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup; the download error is the one to report.
                pass
            raise gr.Error(f"Failed to download file_url: {e}") from e

        return [url, gr.update(value=tmp_path)]
if __name__ == "__main__":
    # Configure the queue (concurrency limit 1, at most 16 queued requests),
    # then launch with the MCP server enabled.
    queued_demo = demo.queue(default_concurrency_limit=1, max_size=16)
    queued_demo.launch(mcp_server=True)