titi commited on
Commit
bdd3d9c
·
1 Parent(s): eff9da0

now supporting mcp server

Browse files
Files changed (4) hide show
  1. README.md +6 -9
  2. app.py +138 -140
  3. core/utils.py +86 -23
  4. requirements.txt +1 -1
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🖥️
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.23.1
8
  app_file: app.py
9
  pinned: false
10
  ---
@@ -12,7 +12,7 @@ pinned: false
12
  A web-based application for automated lung segmentation using deep learning, powered by **Gradio** and **PyTorch**. This tool allows users to upload lung images and obtain segmented outputs efficiently.
13
 
14
  <p align="center">
15
- <img src="images/app.png" height="700">
16
  </p>
17
 
18
  ---
@@ -28,13 +28,10 @@ You can also provide a `.tif` file hosted online using a URL parameter.
28
 
29
  To do so, simply append `?file_url=...` to your app's URL.
30
 
31
- ##### Example (local):
32
- `http://localhost:7860/?file_url=https://zenodo.org/record/8099852/files/lungs_ct.tif`
33
-
34
  ##### Example (hosted on Hugging Face):
35
  `https://huggingface.co/spaces/qchapp/3d-lungs-segmentation/?file_url=https://zenodo.org/record/8099852/files/lungs_ct.tif`
36
 
37
- The application will automatically download the file and load it into the viewer.
38
 
39
  ---
40
 
@@ -55,7 +52,7 @@ Run:
55
  ```sh
56
  python app.py
57
  ```
58
- And go to http://localhost:7860/.
59
 
60
  ---
61
 
@@ -66,7 +63,7 @@ from pathlib import Path
66
  import shutil
67
  from gradio_client import Client, handle_file
68
 
69
- client = Client("https://huggingface.co/spaces/qchapp/3d-lungs-segmentation/")
70
  result_path = client.predict(
71
  file_obj=handle_file("https://zenodo.org/record/8099852/files/lungs_ct.tif?download=1"),
72
  api_name="/segment",
@@ -81,6 +78,6 @@ print("Saved the mask in:", dest.resolve())
81
  ---
82
 
83
  ## About Lungs Segmentation
84
- If you are interesten in the package used for segmentation please check the following [GitHub repository](https://github.com/qchapp/lungs-segmentation)!
85
 
86
  ---
 
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
12
  A web-based application for automated lung segmentation using deep learning, powered by **Gradio** and **PyTorch**. This tool allows users to upload lung images and obtain segmented outputs efficiently.
13
 
14
  <p align="center">
15
+ <img src="https://raw.githubusercontent.com/qchapp/lungs-segmentation-app/refs/heads/master/images/app.png" height="700">
16
  </p>
17
 
18
  ---
 
28
 
29
  To do so, simply append `?file_url=...` to your app's URL.
30
 
 
 
 
31
  ##### Example (hosted on Hugging Face):
32
  `https://huggingface.co/spaces/qchapp/3d-lungs-segmentation/?file_url=https://zenodo.org/record/8099852/files/lungs_ct.tif`
33
 
34
+ The application will automatically download the file and load it into the viewer (the operation can take some time).
35
 
36
  ---
37
 
 
52
  ```sh
53
  python app.py
54
  ```
55
+ And go to the indicated local URL.
56
 
57
  ---
58
 
 
63
  import shutil
64
  from gradio_client import Client, handle_file
65
 
66
+ client = Client("qchapp/3d-lungs-segmentation")
67
  result_path = client.predict(
68
  file_obj=handle_file("https://zenodo.org/record/8099852/files/lungs_ct.tif?download=1"),
69
  api_name="/segment",
 
78
  ---
79
 
80
  ## About Lungs Segmentation
81
+ If you are interested in the package used for segmentation please check the following [GitHub repository](https://github.com/qchapp/lungs-segmentation)!
82
 
83
  ---
app.py CHANGED
@@ -1,13 +1,23 @@
1
  import gradio as gr
2
- from core.utils import *
3
-
 
 
 
 
 
 
 
 
 
4
  import urllib.request
5
- import tempfile
6
- import os, time, threading, atexit
7
- from core.utils import APP_TMP_DIR, clean_temp, write_mask_tif
 
8
 
9
  CLEAN_EVERY_SEC = 1800 # every 30 min
10
- CLEAN_AGE_HOURS = 6 # every 6 hours
11
 
12
  def _start_cleanup_daemon():
13
  def _loop():
@@ -20,8 +30,6 @@ def _start_cleanup_daemon():
20
  threading.Thread(target=_loop, daemon=True).start()
21
 
22
  _start_cleanup_daemon()
23
- atexit.register(lambda: clean_temp(0))
24
-
25
 
26
  def get_axis_max(volume, axis):
27
  """Get the maximum index of each axis."""
@@ -33,67 +41,60 @@ def get_axis_max(volume, axis):
33
  def reset_app():
34
  """Reset everything to the initial state."""
35
  return (
36
- gr.update(value=None),
37
- None,
38
- None,
39
- gr.update(visible=False),
 
40
  gr.update(value=0), gr.update(value=0), gr.update(value=0),
41
  gr.update(value=None), gr.update(value=None), gr.update(value=None),
42
- gr.update(visible=False),
43
  gr.update(value=0), gr.update(value=0), gr.update(value=0),
44
  gr.update(value=None), gr.update(value=None), gr.update(value=None)
45
  )
46
 
47
- def segment_api(file_obj):
48
- """
49
- Accepts a TIF/TIFF via API, returns a TIF mask file path.
50
- """
51
- if not file_obj:
52
- raise gr.Error("No file provided")
53
-
54
- # Read volume (and let load_volume clean the temp upload)
55
- volume = load_volume(file_obj)
56
- seg = segment_volume(volume) # uses your existing model wrapper
57
  if seg is None:
58
  raise gr.Error("Segmentation failed")
59
-
60
- # Write compressed TIF to app temp; return file path
61
  out_path = write_mask_tif(seg)
62
  return out_path
63
 
64
  def run_seg_with_progress(volume, progress=gr.Progress(track_tqdm=True)):
65
- """
66
- Thin wrapper to surface a progress bar in Gradio while the model runs.
67
- """
68
  if volume is None:
69
  return None
70
  progress(0.1, desc="Preparing model…")
71
- seg = segment_volume(volume) # existing function from utils.py
72
  progress(1.0, desc="Done")
73
  return seg
74
 
75
  with gr.Blocks(delete_cache=(1800, 21600)) as demo:
76
- # ---- API (hidden) ----
77
- _api_in = gr.File(file_types=[".tif", ".tiff"], visible=False)
78
- _api_out = gr.File(visible=False)
79
- gr.Button(visible=False).click(
80
- fn=segment_api,
81
- inputs=_api_in,
82
- outputs=_api_out,
83
- api_name="segment"
84
  )
85
 
86
- # ---- UI ----
87
  gr.Markdown("# 🐭 3D Lungs Segmentation")
88
  gr.Markdown("### ⚠️ Note: the visualization may take some time to render!")
89
 
 
 
90
  volume_state = gr.State()
91
  seg_state = gr.State()
92
  norm_state = gr.State()
93
 
94
- file_input = gr.File(file_types=[".tif", ".tiff"], label="Upload your 3D TIF or TIFF file")
 
 
 
 
95
 
96
- # ---- Example loader ----
97
  gr.Examples(
98
  examples=[[example_file_path]],
99
  inputs=[file_input],
@@ -101,7 +102,6 @@ with gr.Blocks(delete_cache=(1800, 21600)) as demo:
101
  examples_per_page=1
102
  )
103
 
104
- # ---- RAW SLICES VIEWER ----
105
  with gr.Group(visible=False) as group_input:
106
  gr.Markdown("### Raw Volume Slices")
107
  with gr.Row():
@@ -114,10 +114,8 @@ with gr.Blocks(delete_cache=(1800, 21600)) as demo:
114
  x_img = gr.Image(label="X")
115
 
116
  segment_btn = gr.Button("Segment", visible=False)
117
-
118
  loading_md = gr.Markdown("⏳ **Segmenting…** This can take a bit.", visible=False)
119
 
120
- # ---- OVERLAY SLICES VIEWER ----
121
  with gr.Group(visible=False) as group_seg:
122
  gr.Markdown("### Segmentation Overlay Slices")
123
  with gr.Row():
@@ -133,61 +131,81 @@ with gr.Blocks(delete_cache=(1800, 21600)) as demo:
133
 
134
  gr.Markdown("#### 📝 This work is based on the Bachelor Project of Quentin Chappuis 2024; for more information, consult the [repository](https://github.com/qchapp/lungs-segmentation)!")
135
 
136
- # ---- CALLBACKS ----
137
-
138
- # A) Load volume
139
  file_input.change(
140
- fn=load_volume,
141
  inputs=file_input,
142
- outputs=volume_state
 
143
  ).then(
144
- fn=volume_stats,
145
  inputs=volume_state,
146
- outputs=norm_state
 
147
  ).then(
148
- fn=lambda vol: gr.update(visible=(vol is not None)),
149
  inputs=volume_state,
150
- outputs=group_input
 
151
  ).then(
152
- fn=lambda vol: gr.update(visible=(vol is not None)),
153
  inputs=volume_state,
154
- outputs=segment_btn
 
155
  ).then(
156
  fn=lambda vol: (
157
  gr.update(maximum=get_axis_max(vol, "Z")),
158
  gr.update(maximum=get_axis_max(vol, "Y")),
159
  gr.update(maximum=get_axis_max(vol, "X")),
160
- ),
161
  inputs=volume_state,
162
- outputs=[z_slider, y_slider, x_slider]
 
163
  ).then(
164
  fn=lambda vol, st: (
165
  browse_axis_fast("Z", 0, vol, st),
166
  browse_axis_fast("Y", 0, vol, st),
167
  browse_axis_fast("X", 0, vol, st),
168
- ),
169
  inputs=[volume_state, norm_state],
170
- outputs=[z_img, y_img, x_img]
 
171
  )
172
 
173
- # B) RAW sliders
174
- z_slider.change(fn=lambda idx, vol, st: browse_axis_fast("Z", idx, vol, st), inputs=[z_slider, volume_state, norm_state], outputs=z_img)
175
- y_slider.change(fn=lambda idx, vol, st: browse_axis_fast("Y", idx, vol, st), inputs=[y_slider, volume_state, norm_state], outputs=y_img)
176
- x_slider.change(fn=lambda idx, vol, st: browse_axis_fast("X", idx, vol, st), inputs=[x_slider, volume_state, norm_state], outputs=x_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- # C) Segment
179
  segment_btn.click(
180
  fn=lambda: (gr.update(visible=True), gr.update(interactive=False)),
181
  inputs=[],
182
- outputs=[loading_md, segment_btn]
 
183
  ).then(
184
- fn=run_seg_with_progress, # <— shows a progress bar
185
  inputs=volume_state,
186
- outputs=seg_state
 
187
  ).then(
188
  fn=lambda s: gr.update(visible=(s is not None)),
189
  inputs=seg_state,
190
- outputs=group_seg
 
191
  ).then(
192
  fn=lambda vol: (
193
  gr.update(maximum=get_axis_max(vol, "Z")),
@@ -195,7 +213,8 @@ with gr.Blocks(delete_cache=(1800, 21600)) as demo:
195
  gr.update(maximum=get_axis_max(vol, "X")),
196
  ),
197
  inputs=volume_state,
198
- outputs=[z_slider_seg, y_slider_seg, x_slider_seg]
 
199
  ).then(
200
  fn=lambda z, y, x, vol, seg, st: (
201
  browse_overlay_axis_fast("Z", z, vol, seg, st),
@@ -203,19 +222,34 @@ with gr.Blocks(delete_cache=(1800, 21600)) as demo:
203
  browse_overlay_axis_fast("X", x, vol, seg, st),
204
  ),
205
  inputs=[z_slider_seg, y_slider_seg, x_slider_seg, volume_state, seg_state, norm_state],
206
- outputs=[z_img_overlay, y_img_overlay, x_img_overlay]
 
207
  ).then(
208
  fn=lambda: (gr.update(visible=False), gr.update(interactive=True)),
209
  inputs=[],
210
- outputs=[loading_md, segment_btn]
 
211
  )
212
 
213
- # D) OVERLAY sliders
214
- z_slider_seg.change(fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("Z", idx, vol, seg, st), inputs=[z_slider_seg, volume_state, seg_state, norm_state], outputs=z_img_overlay)
215
- y_slider_seg.change(fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("Y", idx, vol, seg, st), inputs=[y_slider_seg, volume_state, seg_state, norm_state], outputs=y_img_overlay)
216
- x_slider_seg.change(fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("X", idx, vol, seg, st), inputs=[x_slider_seg, volume_state, seg_state, norm_state], outputs=x_img_overlay)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- # E) Reset
219
  reset_btn.click(
220
  fn=reset_app,
221
  inputs=[],
@@ -224,85 +258,49 @@ with gr.Blocks(delete_cache=(1800, 21600)) as demo:
224
  volume_state,
225
  seg_state,
226
  group_input,
 
227
  z_slider, y_slider, x_slider,
228
  z_img, y_img, x_img,
229
  group_seg,
230
  z_slider_seg, y_slider_seg, x_slider_seg,
231
  z_img_overlay, y_img_overlay, x_img_overlay
232
- ]
 
233
  )
234
 
235
- # ---- HANDLE QUERY PARAMETERS ----
 
236
  @demo.load(
237
- outputs=[
238
- file_input,
239
- volume_state,
240
- norm_state,
241
- group_input,
242
- segment_btn,
243
- z_slider, y_slider, x_slider,
244
- z_img, y_img, x_img
245
- ]
246
  )
247
- def load_from_query(request: gr.Request):
248
  params = request.query_params
 
249
 
250
- if "file_url" in params:
251
- try:
252
- # A) Download the file from the URL to a managed temporary path
253
- url = params["file_url"]
254
- fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
255
- os.close(fd)
256
- urllib.request.urlretrieve(url, tmp_path)
257
-
258
- # B) Open the file as a binary object
259
- with open(tmp_path, "rb") as f:
260
- volume = load_volume(f)
261
 
262
- # Remove downloaded temp file now that it's in memory
263
- try:
264
- os.remove(tmp_path)
265
- except Exception as e:
266
- print(f"[load_from_query] couldn't remove {tmp_path}: {e}")
267
 
268
- # C) Return values for all components
269
- stats = volume_stats(volume)
270
- return [
271
- gr.update(value=None),
272
- volume,
273
- stats,
274
- gr.update(visible=True),
275
- gr.update(visible=True),
276
- gr.update(maximum=get_axis_max(volume, "Z")),
277
- gr.update(maximum=get_axis_max(volume, "Y")),
278
- gr.update(maximum=get_axis_max(volume, "X")),
279
- browse_axis_fast("Z", 0, volume, stats),
280
- browse_axis_fast("Y", 0, volume, stats),
281
- browse_axis_fast("X", 0, volume, stats),
282
- ]
283
-
284
- except Exception as e:
285
- print(f"[Error loading file_url] {e}")
286
 
287
- # Fallback if no file_url or failure
288
- return [
289
- None,
290
- None,
291
- (0.0, 1.0),
292
- gr.update(visible=False),
293
- gr.update(visible=False),
294
- gr.update(maximum=0),
295
- gr.update(maximum=0),
296
- gr.update(maximum=0),
297
- None, None, None
298
- ]
299
 
300
 
301
  if __name__ == "__main__":
302
- try:
303
- demo.queue(concurrency_count=1, max_size=16).launch()
304
- except TypeError:
305
- try:
306
- demo.queue(max_size=16).launch()
307
- except TypeError:
308
- demo.queue().launch()
 
1
  import gradio as gr
2
+ from core.utils import (
3
+ example_file_path,
4
+ _load_volume_from_any,
5
+ volume_stats,
6
+ browse_axis_fast,
7
+ browse_overlay_axis_fast,
8
+ segment_volume,
9
+ APP_TMP_DIR,
10
+ clean_temp,
11
+ write_mask_tif,
12
+ )
13
  import urllib.request
14
+ import time, threading, tempfile, os
15
+ from typing import Union
16
+ from gradio import skip
17
+
18
 
19
  CLEAN_EVERY_SEC = 1800 # every 30 min
20
+ CLEAN_AGE_HOURS = 12 # every 12 hours
21
 
22
  def _start_cleanup_daemon():
23
  def _loop():
 
30
  threading.Thread(target=_loop, daemon=True).start()
31
 
32
  _start_cleanup_daemon()
 
 
33
 
34
  def get_axis_max(volume, axis):
35
  """Get the maximum index of each axis."""
 
41
  def reset_app():
42
  """Reset everything to the initial state."""
43
  return (
44
+ gr.update(value=None), # file_input
45
+ None, # volume_state
46
+ None, # seg_state
47
+ gr.update(visible=False),# group_input
48
+ gr.update(visible=False),# segment_btn
49
  gr.update(value=0), gr.update(value=0), gr.update(value=0),
50
  gr.update(value=None), gr.update(value=None), gr.update(value=None),
51
+ gr.update(visible=False),# group_seg
52
  gr.update(value=0), gr.update(value=0), gr.update(value=0),
53
  gr.update(value=None), gr.update(value=None), gr.update(value=None)
54
  )
55
 
56
+ def segment_api(file_obj: Union[dict, str, bytes]) -> str:
57
+ """Segments a 3D TIF/TIFF volume and returns a server path to a compressed TIF mask."""
58
+ volume = _load_volume_from_any(file_obj)
59
+ seg = segment_volume(volume)
 
 
 
 
 
 
60
  if seg is None:
61
  raise gr.Error("Segmentation failed")
 
 
62
  out_path = write_mask_tif(seg)
63
  return out_path
64
 
65
  def run_seg_with_progress(volume, progress=gr.Progress(track_tqdm=True)):
66
+ """Surface a progress bar in Gradio while the model runs."""
 
 
67
  if volume is None:
68
  return None
69
  progress(0.1, desc="Preparing model…")
70
+ seg = segment_volume(volume)
71
  progress(1.0, desc="Done")
72
  return seg
73
 
74
  with gr.Blocks(delete_cache=(1800, 21600)) as demo:
75
+ # Expose ONLY the /segment API/MCP tool
76
+ gr.api(
77
+ segment_api,
78
+ api_name="segment",
79
+ api_description="Accepts a 3D TIF/TIFF (URL, uploaded file, or raw bytes) and returns a path to the compressed TIF mask."
 
 
 
80
  )
81
 
82
+ # -------- UI --------
83
  gr.Markdown("# 🐭 3D Lungs Segmentation")
84
  gr.Markdown("### ⚠️ Note: the visualization may take some time to render!")
85
 
86
+ # States
87
+ last_url_state = gr.State("") # last processed ?file_url
88
  volume_state = gr.State()
89
  seg_state = gr.State()
90
  norm_state = gr.State()
91
 
92
+ file_input = gr.File(
93
+ file_types=[".tif", ".tiff"],
94
+ file_count="single",
95
+ label="Upload your 3D TIF or TIFF file"
96
+ )
97
 
 
98
  gr.Examples(
99
  examples=[[example_file_path]],
100
  inputs=[file_input],
 
102
  examples_per_page=1
103
  )
104
 
 
105
  with gr.Group(visible=False) as group_input:
106
  gr.Markdown("### Raw Volume Slices")
107
  with gr.Row():
 
114
  x_img = gr.Image(label="X")
115
 
116
  segment_btn = gr.Button("Segment", visible=False)
 
117
  loading_md = gr.Markdown("⏳ **Segmenting…** This can take a bit.", visible=False)
118
 
 
119
  with gr.Group(visible=False) as group_seg:
120
  gr.Markdown("### Segmentation Overlay Slices")
121
  with gr.Row():
 
131
 
132
  gr.Markdown("#### 📝 This work is based on the Bachelor Project of Quentin Chappuis 2024; for more information, consult the [repository](https://github.com/qchapp/lungs-segmentation)!")
133
 
134
+ # -------- Callbacks (hidden from API/MCP) --------
 
 
135
  file_input.change(
136
+ fn=lambda f: _load_volume_from_any(f) if f is not None else skip(),
137
  inputs=file_input,
138
+ outputs=volume_state,
139
+ show_api=False
140
  ).then(
141
+ fn=lambda vol: volume_stats(vol) if vol is not None else skip(),
142
  inputs=volume_state,
143
+ outputs=norm_state,
144
+ show_api=False
145
  ).then(
146
+ fn=lambda vol: gr.update(visible=True) if vol is not None else skip(),
147
  inputs=volume_state,
148
+ outputs=group_input,
149
+ show_api=False
150
  ).then(
151
+ fn=lambda vol: gr.update(visible=True) if vol is not None else skip(),
152
  inputs=volume_state,
153
+ outputs=segment_btn,
154
+ show_api=False
155
  ).then(
156
  fn=lambda vol: (
157
  gr.update(maximum=get_axis_max(vol, "Z")),
158
  gr.update(maximum=get_axis_max(vol, "Y")),
159
  gr.update(maximum=get_axis_max(vol, "X")),
160
+ ) if vol is not None else (skip(), skip(), skip()),
161
  inputs=volume_state,
162
+ outputs=[z_slider, y_slider, x_slider],
163
+ show_api=False
164
  ).then(
165
  fn=lambda vol, st: (
166
  browse_axis_fast("Z", 0, vol, st),
167
  browse_axis_fast("Y", 0, vol, st),
168
  browse_axis_fast("X", 0, vol, st),
169
+ ) if vol is not None else (skip(), skip(), skip()),
170
  inputs=[volume_state, norm_state],
171
+ outputs=[z_img, y_img, x_img],
172
+ show_api=False
173
  )
174
 
175
+ z_slider.change(
176
+ fn=lambda idx, vol, st: browse_axis_fast("Z", idx, vol, st),
177
+ inputs=[z_slider, volume_state, norm_state],
178
+ outputs=z_img,
179
+ show_api=False
180
+ )
181
+ y_slider.change(
182
+ fn=lambda idx, vol, st: browse_axis_fast("Y", idx, vol, st),
183
+ inputs=[y_slider, volume_state, norm_state],
184
+ outputs=y_img,
185
+ show_api=False
186
+ )
187
+ x_slider.change(
188
+ fn=lambda idx, vol, st: browse_axis_fast("X", idx, vol, st),
189
+ inputs=[x_slider, volume_state, norm_state],
190
+ outputs=x_img,
191
+ show_api=False
192
+ )
193
 
 
194
  segment_btn.click(
195
  fn=lambda: (gr.update(visible=True), gr.update(interactive=False)),
196
  inputs=[],
197
+ outputs=[loading_md, segment_btn],
198
+ show_api=False
199
  ).then(
200
+ fn=run_seg_with_progress,
201
  inputs=volume_state,
202
+ outputs=seg_state,
203
+ show_api=False
204
  ).then(
205
  fn=lambda s: gr.update(visible=(s is not None)),
206
  inputs=seg_state,
207
+ outputs=group_seg,
208
+ show_api=False
209
  ).then(
210
  fn=lambda vol: (
211
  gr.update(maximum=get_axis_max(vol, "Z")),
 
213
  gr.update(maximum=get_axis_max(vol, "X")),
214
  ),
215
  inputs=volume_state,
216
+ outputs=[z_slider_seg, y_slider_seg, x_slider_seg],
217
+ show_api=False
218
  ).then(
219
  fn=lambda z, y, x, vol, seg, st: (
220
  browse_overlay_axis_fast("Z", z, vol, seg, st),
 
222
  browse_overlay_axis_fast("X", x, vol, seg, st),
223
  ),
224
  inputs=[z_slider_seg, y_slider_seg, x_slider_seg, volume_state, seg_state, norm_state],
225
+ outputs=[z_img_overlay, y_img_overlay, x_img_overlay],
226
+ show_api=False
227
  ).then(
228
  fn=lambda: (gr.update(visible=False), gr.update(interactive=True)),
229
  inputs=[],
230
+ outputs=[loading_md, segment_btn],
231
+ show_api=False
232
  )
233
 
234
+ z_slider_seg.change(
235
+ fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("Z", idx, vol, seg, st),
236
+ inputs=[z_slider_seg, volume_state, seg_state, norm_state],
237
+ outputs=z_img_overlay,
238
+ show_api=False
239
+ )
240
+ y_slider_seg.change(
241
+ fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("Y", idx, vol, seg, st),
242
+ inputs=[y_slider_seg, volume_state, seg_state, norm_state],
243
+ outputs=y_img_overlay,
244
+ show_api=False
245
+ )
246
+ x_slider_seg.change(
247
+ fn=lambda idx, vol, seg, st: browse_overlay_axis_fast("X", idx, vol, seg, st),
248
+ inputs=[x_slider_seg, volume_state, seg_state, norm_state],
249
+ outputs=x_img_overlay,
250
+ show_api=False
251
+ )
252
 
 
253
  reset_btn.click(
254
  fn=reset_app,
255
  inputs=[],
 
258
  volume_state,
259
  seg_state,
260
  group_input,
261
+ segment_btn,
262
  z_slider, y_slider, x_slider,
263
  z_img, y_img, x_img,
264
  group_seg,
265
  z_slider_seg, y_slider_seg, x_slider_seg,
266
  z_img_overlay, y_img_overlay, x_img_overlay
267
+ ],
268
+ show_api=False
269
  )
270
 
271
+
272
+ # -------- URL loader --------
273
  @demo.load(
274
+ inputs=[last_url_state],
275
+ outputs=[last_url_state, file_input], # only these two
276
+ show_api=False
 
 
 
 
 
 
277
  )
278
+ def load_from_query(prev_url, request: gr.Request):
279
  params = request.query_params
280
+ url = params.get("file_url") or ""
281
 
282
+ # No URL -> no-op
283
+ if not url:
284
+ return [gr.skip(), gr.skip()]
 
 
 
 
 
 
 
 
285
 
286
+ # 🔧 Short-circuit: same URL as last time -> no-op
287
+ if url == prev_url:
288
+ return [gr.skip(), gr.skip()]
 
 
289
 
290
+ # Download to CLOSED temp file and programmatically set the File value.
291
+ fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
292
+ os.close(fd)
293
+ try:
294
+ urllib.request.urlretrieve(url, tmp_path)
295
+ except Exception as e:
296
+ try:
297
+ os.remove(tmp_path)
298
+ except Exception:
299
+ pass
300
+ raise gr.Error(f"Failed to download file_url: {e}")
 
 
 
 
 
 
 
301
 
302
+ return [url, gr.update(value=tmp_path)]
 
 
 
 
 
 
 
 
 
 
 
303
 
304
 
305
  if __name__ == "__main__":
306
+ demo.queue(default_concurrency_limit=1, max_size=16).launch(mcp_server=True)
 
 
 
 
 
 
core/utils.py CHANGED
@@ -7,12 +7,24 @@ from PIL import Image
7
  from pathlib import Path
8
  import time, uuid, atexit
9
  from unet_lungs_segmentation import LungsPredict
 
10
 
11
  model = LungsPredict()
12
 
13
  APP_TMP_DIR = Path(tempfile.gettempdir()) / "lungs_seg_tmp"
14
  APP_TMP_DIR.mkdir(parents=True, exist_ok=True)
15
 
 
 
 
 
 
 
 
 
 
 
 
16
  def new_tmp_path(basename: str = "tmp.tif") -> str:
17
  """Return a unique path inside the app temp dir."""
18
  uid = uuid.uuid4().hex[:8]
@@ -39,29 +51,92 @@ def write_mask_tif(mask: np.ndarray) -> str:
39
  tifffile.imwrite(out_path, mask.astype(np.uint8), compression="zlib")
40
  return out_path
41
 
42
- def load_volume(file_obj):
43
- if not file_obj:
44
- return None
45
- path = getattr(file_obj, "name", None) or getattr(file_obj, "path", None) or file_obj
46
  arr = tifffile.imread(path)
47
-
48
  try:
49
  if path and os.path.exists(path):
50
- src = Path(path).resolve()
51
- if src not in PROTECTED_PATHS:
52
- os.remove(src)
53
  except Exception as e:
54
  print(f"[load_volume] couldn't remove temp file {path}: {e}")
55
-
56
  return arr
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def segment_volume(volume):
59
  """Run segmentation on the loaded volume (return shape (Z, Y, X))."""
60
  if volume is None:
61
  return None
62
  return model.segment_lungs(volume)
63
 
64
- # Optimization for faster processing
65
  def volume_stats(volume):
66
  """Return (min, max) as floats for global 8-bit scaling."""
67
  if volume is None:
@@ -103,20 +178,8 @@ def browse_overlay_axis_fast(axis, idx, volume, seg, stats, alpha=0.35):
103
 
104
  raw8 = _to_8bit_stats(raw, mn, mx)
105
  rgb = np.repeat(raw8[..., None], 3, axis=-1)
106
- # color mask in red channel
107
  mask_rgb = np.zeros_like(rgb)
108
  mask_rgb[..., 0] = (mask.astype(np.uint8) * 255)
109
 
110
  blended = rgb.astype(np.float32) * (1 - alpha) + mask_rgb.astype(np.float32) * alpha
111
- return Image.fromarray(blended.astype(np.uint8))
112
-
113
- # Example file
114
- def get_example_file():
115
- url = "https://zenodo.org/record/8099852/files/lungs_ct.tif?download=1"
116
- tmp_path = APP_TMP_DIR / "example_lungs.tif"
117
- if not tmp_path.exists():
118
- urllib.request.urlretrieve(url, tmp_path)
119
- return str(tmp_path)
120
-
121
- example_file_path = get_example_file()
122
- PROTECTED_PATHS = {Path(example_file_path).resolve()}
 
7
  from pathlib import Path
8
  import time, uuid, atexit
9
  from unet_lungs_segmentation import LungsPredict
10
+ import gradio as gr
11
 
12
  model = LungsPredict()
13
 
14
  APP_TMP_DIR = Path(tempfile.gettempdir()) / "lungs_seg_tmp"
15
  APP_TMP_DIR.mkdir(parents=True, exist_ok=True)
16
 
17
+ # ---------- Example file ----------
18
+ def get_example_file():
19
+ url = "https://zenodo.org/record/8099852/files/lungs_ct.tif?download=1"
20
+ tmp_path = APP_TMP_DIR / "example_lungs.tif"
21
+ if not tmp_path.exists():
22
+ urllib.request.urlretrieve(url, tmp_path)
23
+ return str(tmp_path)
24
+
25
+ example_file_path = get_example_file()
26
+ PROTECTED_PATHS = {Path(example_file_path).resolve()}
27
+
28
  def new_tmp_path(basename: str = "tmp.tif") -> str:
29
  """Return a unique path inside the app temp dir."""
30
  uid = uuid.uuid4().hex[:8]
 
51
  tifffile.imwrite(out_path, mask.astype(np.uint8), compression="zlib")
52
  return out_path
53
 
54
+ # ---------- Reading helpers ----------
55
+ def _read_tif_from_path(path: str):
56
+ """Read a tif from a local filesystem path; only auto-delete files in APP_TMP_DIR (not protected)."""
 
57
  arr = tifffile.imread(path)
 
58
  try:
59
  if path and os.path.exists(path):
60
+ rp = Path(path).resolve()
61
+ if (rp not in PROTECTED_PATHS) and (APP_TMP_DIR in rp.parents):
62
+ os.remove(rp)
63
  except Exception as e:
64
  print(f"[load_volume] couldn't remove temp file {path}: {e}")
 
65
  return arr
66
 
67
+ def load_volume(file_obj):
68
+ """
69
+ Backward-compatible wrapper used by older code that passes in a path-like object.
70
+ Prefer _load_volume_from_any() in new code.
71
+ """
72
+ if not file_obj:
73
+ return None
74
+ path = getattr(file_obj, "name", None) or getattr(file_obj, "path", None) or file_obj
75
+ if isinstance(path, (str, os.PathLike)):
76
+ return _read_tif_from_path(str(path))
77
+ # If a dict/FileData slipped through, delegate to the robust path:
78
+ return _load_volume_from_any(file_obj)
79
+
80
+ def _load_volume_from_any(file_obj):
81
+ """
82
+ Normalize different inputs to a real filesystem path and read via _read_tif_from_path.
83
+ Accepts:
84
+ - dict with 'path' or 'url' (Gradio FileData / programmatic)
85
+ - str local path or URL
86
+ - bytes / bytearray
87
+ - file-like object with .read()
88
+ """
89
+ try:
90
+ # Gradio FileData-like dict
91
+ if isinstance(file_obj, dict):
92
+ path = file_obj.get("path") or file_obj.get("url")
93
+ if not path:
94
+ raise gr.Error("Invalid file object (missing 'path' or 'url').")
95
+ if isinstance(path, str) and (path.startswith("http://") or path.startswith("https://")):
96
+ fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
97
+ os.close(fd)
98
+ urllib.request.urlretrieve(path, tmp_path)
99
+ return _read_tif_from_path(tmp_path)
100
+ return _read_tif_from_path(path)
101
+
102
+ # String path or URL
103
+ if isinstance(file_obj, (str, os.PathLike)):
104
+ s = str(file_obj)
105
+ if s.startswith("http://") or s.startswith("https://"):
106
+ fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
107
+ os.close(fd)
108
+ urllib.request.urlretrieve(s, tmp_path)
109
+ return _read_tif_from_path(tmp_path)
110
+ return _read_tif_from_path(s)
111
+
112
+ # Raw bytes
113
+ if isinstance(file_obj, (bytes, bytearray)):
114
+ fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
115
+ os.close(fd)
116
+ with open(tmp_path, "wb") as w:
117
+ w.write(file_obj)
118
+ return _read_tif_from_path(tmp_path)
119
+
120
+ # File-like object
121
+ if hasattr(file_obj, "read"):
122
+ data = file_obj.read()
123
+ fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
124
+ os.close(fd)
125
+ with open(tmp_path, "wb") as w:
126
+ w.write(data)
127
+ return _read_tif_from_path(tmp_path)
128
+
129
+ raise gr.Error(f"Unsupported input type for file_obj: {type(file_obj)}")
130
+ except Exception as e:
131
+ raise gr.Error(f"Failed to read input file: {e}")
132
+
133
+ # ---------- Model + viz ----------
134
  def segment_volume(volume):
135
  """Run segmentation on the loaded volume (return shape (Z, Y, X))."""
136
  if volume is None:
137
  return None
138
  return model.segment_lungs(volume)
139
 
 
140
  def volume_stats(volume):
141
  """Return (min, max) as floats for global 8-bit scaling."""
142
  if volume is None:
 
178
 
179
  raw8 = _to_8bit_stats(raw, mn, mx)
180
  rgb = np.repeat(raw8[..., None], 3, axis=-1)
 
181
  mask_rgb = np.zeros_like(rgb)
182
  mask_rgb[..., 0] = (mask.astype(np.uint8) * 255)
183
 
184
  blended = rgb.astype(np.float32) * (1 - alpha) + mask_rgb.astype(np.float32) * alpha
185
+ return Image.fromarray(blended.astype(np.uint8))
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  unet_lungs_segmentation
2
- gradio==5.25.1
3
  torch==2.6.0
4
  torchvision==0.21.0
 
1
  unet_lungs_segmentation
2
+ gradio[mcp]==5.49.1
3
  torch==2.6.0
4
  torchvision==0.21.0