titi commited on
Commit
0f88329
·
1 Parent(s): 47ee9f3

initial commit

Browse files
Files changed (4) hide show
  1. README.md +26 -14
  2. app.py +150 -0
  3. core/utils.py +70 -0
  4. requirements.txt +4 -0
README.md CHANGED
@@ -1,14 +1,26 @@
1
- ---
2
- title: 3d Lungs Segmentation
3
- emoji: 🔥
4
- colorFrom: gray
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.23.1
8
- app_file: app.py
9
- pinned: false
10
- license: bsd-3-clause
11
- short_description: A web-based application for automated lung segmentation.
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🖥️ Lungs segmentation web application
2
+ A web-based application for automated lung segmentation using deep learning, powered by **Gradio** and **PyTorch**. This tool allows users to upload lung images and obtain segmented outputs efficiently.
3
+
4
+ <p align="center">
5
+ <img src="https://raw.githubusercontent.com/titi1000/lungs-segmentation-app/refs/heads/master/images/app.png" height="700">
6
+ </p>
7
+
8
+ ## Installation
9
+ We recommend performing the installation in a clean Python environment.
10
+
11
+ The code requires `python>=3.10`, as well as `pytorch>=2.0`. Please install Pytorch first and separately following the instructions for your platform on [pytorch.org](https://pytorch.org/get-started/locally/).
12
+
13
+ After that please run the following command:
14
+ ```sh
15
+ pip install -r requirements.txt
16
+ ```
17
+
18
+ ## Usage
19
+ Run:
20
+ ```sh
21
+ python app.py
22
+ ```
23
+ And go to http://localhost:7860/.
24
+
25
+ ## About Lungs Segmentation
26
+ If you are interesten in the package used for segmentation please check the following [GitHub repository](https://github.com/titi1000/lungs-segmentation)!
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from core.utils import *
3
+
4
+ def get_axis_max(volume, axis):
5
+ """Get the maximum index of each axis."""
6
+ if volume is None:
7
+ return 0
8
+ shape = volume.shape
9
+ return shape[{"Z": 0, "Y": 1, "X": 2}[axis]] - 1
10
+
11
+ def reset_app():
12
+ """Reset everything to the initial state."""
13
+ return (
14
+ gr.update(value=None),
15
+ None,
16
+ None,
17
+ gr.update(visible=False),
18
+ gr.update(value=0), gr.update(value=0), gr.update(value=0),
19
+ gr.update(value=None), gr.update(value=None), gr.update(value=None),
20
+ gr.update(visible=False),
21
+ gr.update(value=0), gr.update(value=0), gr.update(value=0),
22
+ gr.update(value=None), gr.update(value=None), gr.update(value=None)
23
+ )
24
+
25
+ with gr.Blocks() as demo:
26
+ gr.Markdown("# 🐭 3D Lungs Segmentation")
27
+ gr.Markdown("### ⚠️ Note: the visualization may take some time to render!")
28
+
29
+ volume_state = gr.State()
30
+ seg_state = gr.State()
31
+
32
+ file_input = gr.File(file_types=[".tif", ".tiff"], label="Upload your 3D TIF or TIFF file")
33
+
34
+ # ---- RAW SLICES VIEWER ----
35
+ with gr.Group(visible=False) as group_input:
36
+ gr.Markdown("### Raw Volume Slices")
37
+ with gr.Row():
38
+ z_slider = gr.Slider(0, 0, step=1, label="Z Slice")
39
+ y_slider = gr.Slider(0, 0, step=1, label="Y Slice")
40
+ x_slider = gr.Slider(0, 0, step=1, label="X Slice")
41
+ with gr.Row():
42
+ z_img = gr.Image(label="Z")
43
+ y_img = gr.Image(label="Y")
44
+ x_img = gr.Image(label="X")
45
+
46
+ segment_btn = gr.Button("Segment", visible=False)
47
+
48
+ # ---- OVERLAY SLICES VIEWER ----
49
+ with gr.Group(visible=False) as group_seg:
50
+ gr.Markdown("### Segmentation Overlay Slices")
51
+ with gr.Row():
52
+ z_slider_seg = gr.Slider(0, 0, step=1, label="Z Slice (Overlay)")
53
+ y_slider_seg = gr.Slider(0, 0, step=1, label="Y Slice (Overlay)")
54
+ x_slider_seg = gr.Slider(0, 0, step=1, label="X Slice (Overlay)")
55
+ with gr.Row():
56
+ z_img_overlay = gr.Image(label="Z + Mask")
57
+ y_img_overlay = gr.Image(label="Y + Mask")
58
+ x_img_overlay = gr.Image(label="X + Mask")
59
+
60
+ reset_btn = gr.Button("Reset")
61
+
62
+ # ---- CALLBACKS ----
63
+
64
+ # A) Load volume
65
+ file_input.change(
66
+ fn=load_volume,
67
+ inputs=file_input,
68
+ outputs=volume_state
69
+ ).then(
70
+ fn=lambda vol: gr.update(visible=(vol is not None)),
71
+ inputs=volume_state,
72
+ outputs=group_input
73
+ ).then(
74
+ fn=lambda vol: gr.update(visible=(vol is not None)),
75
+ inputs=volume_state,
76
+ outputs=segment_btn
77
+ ).then(
78
+ fn=lambda vol: (
79
+ gr.update(maximum=get_axis_max(vol, "Z")),
80
+ gr.update(maximum=get_axis_max(vol, "Y")),
81
+ gr.update(maximum=get_axis_max(vol, "X")),
82
+ ),
83
+ inputs=volume_state,
84
+ outputs=[z_slider, y_slider, x_slider]
85
+ ).then(
86
+ fn=lambda vol: (
87
+ browse_axis("Z", 0, vol),
88
+ browse_axis("Y", 0, vol),
89
+ browse_axis("X", 0, vol),
90
+ ),
91
+ inputs=volume_state,
92
+ outputs=[z_img, y_img, x_img]
93
+ )
94
+
95
+ # B) RAW sliders
96
+ z_slider.change(fn=lambda idx, vol: browse_axis("Z", idx, vol), inputs=[z_slider, volume_state], outputs=z_img)
97
+ y_slider.change(fn=lambda idx, vol: browse_axis("Y", idx, vol), inputs=[y_slider, volume_state], outputs=y_img)
98
+ x_slider.change(fn=lambda idx, vol: browse_axis("X", idx, vol), inputs=[x_slider, volume_state], outputs=x_img)
99
+
100
+ # C) Segment
101
+ segment_btn.click(
102
+ fn=segment_volume,
103
+ inputs=volume_state,
104
+ outputs=seg_state
105
+ ).then(
106
+ fn=lambda s: gr.update(visible=(s is not None)),
107
+ inputs=seg_state,
108
+ outputs=group_seg
109
+ ).then(
110
+ fn=lambda vol: (
111
+ gr.update(maximum=get_axis_max(vol, "Z")),
112
+ gr.update(maximum=get_axis_max(vol, "Y")),
113
+ gr.update(maximum=get_axis_max(vol, "X")),
114
+ ),
115
+ inputs=volume_state,
116
+ outputs=[z_slider_seg, y_slider_seg, x_slider_seg]
117
+ ).then(
118
+ fn=lambda z, y, x, vol, seg: (
119
+ browse_overlay_axis("Z", z, vol, seg),
120
+ browse_overlay_axis("Y", y, vol, seg),
121
+ browse_overlay_axis("X", x, vol, seg),
122
+ ),
123
+ inputs=[z_slider_seg, y_slider_seg, x_slider_seg, volume_state, seg_state],
124
+ outputs=[z_img_overlay, y_img_overlay, x_img_overlay]
125
+ )
126
+
127
+ # D) OVERLAY sliders
128
+ z_slider_seg.change(fn=lambda idx, vol, seg: browse_overlay_axis("Z", idx, vol, seg), inputs=[z_slider_seg, volume_state, seg_state], outputs=z_img_overlay)
129
+ y_slider_seg.change(fn=lambda idx, vol, seg: browse_overlay_axis("Y", idx, vol, seg), inputs=[y_slider_seg, volume_state, seg_state], outputs=y_img_overlay)
130
+ x_slider_seg.change(fn=lambda idx, vol, seg: browse_overlay_axis("X", idx, vol, seg), inputs=[x_slider_seg, volume_state, seg_state], outputs=x_img_overlay)
131
+
132
+ # E) Reset
133
+ reset_btn.click(
134
+ fn=reset_app,
135
+ inputs=[],
136
+ outputs=[
137
+ file_input,
138
+ volume_state,
139
+ seg_state,
140
+ group_input,
141
+ z_slider, y_slider, x_slider,
142
+ z_img, y_img, x_img,
143
+ group_seg,
144
+ z_slider_seg, y_slider_seg, x_slider_seg,
145
+ z_img_overlay, y_img_overlay, x_img_overlay
146
+ ]
147
+ )
148
+
149
+ if __name__ == "__main__":
150
+ demo.launch()
core/utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tifffile
3
+ from PIL import Image
4
+ from unet_lungs_segmentation import LungsPredict
5
+
6
+ model = LungsPredict()
7
+
8
+ def _to_8bit(arr):
9
+ """Convert float/int array to 8-bit [0..255]."""
10
+ arr = arr.astype(np.float32)
11
+ mn, mx = arr.min(), arr.max()
12
+ rng = mx - mn
13
+ if rng < 1e-8:
14
+ rng = 1.0
15
+ norm = (arr - mn) / rng
16
+ return (norm * 255).astype(np.uint8)
17
+
18
+ def load_volume(file_obj):
19
+ """Read the uploaded TIF as a NumPy array (Z, Y, X)."""
20
+ if not file_obj:
21
+ return None
22
+ return tifffile.imread(file_obj.name)
23
+
24
+ def segment_volume(volume):
25
+ """Run segmentation on the loaded volume (return shape (Z, Y, X))."""
26
+ if volume is None:
27
+ return None
28
+ return model.segment_lungs(volume)
29
+
30
+ def browse_axis(axis, idx, volume):
31
+ """Return a single raw slice for the given axis."""
32
+ if volume is None:
33
+ return None
34
+
35
+ if axis == "Z":
36
+ slice_ = volume[idx]
37
+ elif axis == "Y":
38
+ slice_ = volume[:, idx, :]
39
+ elif axis == "X":
40
+ slice_ = volume[:, :, idx]
41
+ else:
42
+ return None
43
+
44
+ return Image.fromarray(_to_8bit(slice_))
45
+
46
+ def browse_overlay_axis(axis, idx, volume, seg):
47
+ """Return a single overlay slice for the given axis."""
48
+ if volume is None or seg is None:
49
+ return None
50
+
51
+ if axis == "Z":
52
+ raw = volume[idx]
53
+ mask = seg[idx]
54
+ elif axis == "Y":
55
+ raw = volume[:, idx, :]
56
+ mask = seg[:, idx, :]
57
+ elif axis == "X":
58
+ raw = volume[:, :, idx]
59
+ mask = seg[:, :, idx]
60
+ else:
61
+ return None
62
+
63
+ raw_8bit = _to_8bit(raw)
64
+ raw_rgb = np.stack([raw_8bit] * 3, axis=-1)
65
+ mask_rgb = np.zeros_like(raw_rgb)
66
+ mask_rgb[..., 0] = (mask * 255).astype(np.uint8)
67
+
68
+ alpha = 0.3
69
+ blended = (1 - alpha) * raw_rgb + alpha * mask_rgb
70
+ return Image.fromarray(blended.astype(np.uint8))
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ unet_lungs_segmentation==1.0.6
2
+ gradio==4.44.1
3
+ torch==2.6.0
4
+ torchvision==0.21.0