Amodit committed on
Commit
d791fee
·
verified ·
1 Parent(s): f172dad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -214
app.py CHANGED
@@ -1,214 +1,217 @@
1
- import gradio as gr
2
- import torch
3
- import nibabel as nib
4
- import numpy as np
5
- from monai.networks.nets import SwinUNETR
6
- from monai.inferers import sliding_window_inference
7
- from monai.transforms import (
8
- Compose, LoadImage, EnsureChannelFirst, Orientation,
9
- Spacing, ScaleIntensityRange, CropForeground, Resize, EnsureType
10
- )
11
- import tempfile
12
- import matplotlib.pyplot as plt
13
- from io import BytesIO
14
- from PIL import Image
15
- import os
16
-
17
- # ============================================================================
18
- # MODEL INITIALIZATION
19
- # ============================================================================
20
- print(" Initializing model...")
21
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- print(f" Using device: {device}")
23
-
24
- model = SwinUNETR(
25
- in_channels=1,
26
- out_channels=2,
27
- patch_size=(2, 2, 2),
28
- depths=(2, 2, 2, 2),
29
- num_heads=(3, 6, 12, 24),
30
- window_size=7,
31
- feature_size=48,
32
- norm_name="instance",
33
- use_checkpoint=True,
34
- spatial_dims=3,
35
- ).to(device)
36
-
37
- # Load trained weights
38
- model_path = "best_metric_model.pth"
39
- if os.path.exists(model_path):
40
- model.load_state_dict(torch.load(model_path, map_location=device))
41
- print(" Model loaded successfully!")
42
- else:
43
- print(" Model file not found! Please upload best_metric_model.pth")
44
-
45
- model.eval()
46
-
47
- # ============================================================================
48
- # PREPROCESSING PIPELINE
49
- # ============================================================================
50
- transforms = Compose([
51
- LoadImage(image_only=True),
52
- EnsureChannelFirst(),
53
- Orientation(axcodes="RAS"),
54
- Spacing(pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
55
- ScaleIntensityRange(a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
56
- CropForeground(source_key="image"),
57
- Resize(spatial_size=(128, 128, 128), mode="trilinear"),
58
- EnsureType(dtype=torch.float32),
59
- ])
60
-
61
- # ============================================================================
62
- # SEGMENTATION FUNCTION
63
- # ============================================================================
64
- def segment_liver(nifti_file, show_slice):
65
- """
66
- Segment liver from uploaded NIfTI file
67
-
68
- Args:
69
- nifti_file: Uploaded .nii.gz file
70
- show_slice: Which slice to visualize (0-127)
71
-
72
- Returns:
73
- Visualization image and downloadable segmentation file
74
- """
75
- if nifti_file is None:
76
- return None, None
77
-
78
- try:
79
- print(f" Processing file: {nifti_file.name}")
80
-
81
- # Preprocess
82
- volume = transforms(nifti_file.name)
83
- volume_input = volume.unsqueeze(0).to(device)
84
-
85
- # Inference with sliding window
86
- print(" Running inference...")
87
- with torch.no_grad():
88
- output = sliding_window_inference(
89
- volume_input,
90
- roi_size=(128, 128, 64),
91
- sw_batch_size=2,
92
- predictor=model,
93
- overlap=0.75
94
- )
95
- prediction = torch.argmax(output, dim=1, keepdim=True)
96
-
97
- # Convert to numpy
98
- vol_np = volume.cpu().numpy()
99
- pred_np = prediction.cpu().numpy()[0, 0]
100
-
101
- # Normalize volume for display
102
- vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)
103
-
104
- # Clamp slice index
105
- max_slice = vol_np.shape[2] - 1
106
- slice_idx = min(max(0, show_slice), max_slice)
107
-
108
- # Create visualization (3 views)
109
- fig, axes = plt.subplots(1, 3, figsize=(18, 6))
110
-
111
- # CT scan
112
- axes[0].imshow(vol_display[:, :, slice_idx].T, cmap='gray', origin='lower')
113
- axes[0].set_title(f'CT Scan (Slice {slice_idx}/{max_slice})', fontsize=14)
114
- axes[0].axis('off')
115
-
116
- # Segmentation only
117
- axes[1].imshow(pred_np[:, :, slice_idx].T, cmap='Reds', origin='lower')
118
- axes[1].set_title('Liver Segmentation', fontsize=14)
119
- axes[1].axis('off')
120
-
121
- # Overlay
122
- axes[2].imshow(vol_display[:, :, slice_idx].T, cmap='gray', origin='lower')
123
- axes[2].imshow(pred_np[:, :, slice_idx].T, cmap='Greens', alpha=0.5, origin='lower')
124
- axes[2].set_title('Overlay (Green = Liver)', fontsize=14)
125
- axes[2].axis('off')
126
-
127
- plt.tight_layout()
128
-
129
- # Convert plot to image
130
- buf = BytesIO()
131
- plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
132
- buf.seek(0)
133
- result_image = Image.open(buf)
134
- plt.close()
135
-
136
- # Save segmentation as NIfTI
137
- output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz')
138
-
139
- # Load original NIfTI to preserve metadata
140
- original_nii = nib.load(nifti_file.name)
141
-
142
- # Create segmentation NIfTI with original affine
143
- pred_nifti = nib.Nifti1Image(pred_np.astype(np.uint8), affine=original_nii.affine)
144
- nib.save(pred_nifti, output_file.name)
145
-
146
- print("Segmentation complete!")
147
-
148
- return result_image, output_file.name
149
-
150
- except Exception as e:
151
- print(f" Error: {str(e)}")
152
- return None, None
153
-
154
- # ============================================================================
155
- # GRADIO INTERFACE
156
- # ============================================================================
157
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
158
- gr.Markdown(
159
- """
160
- # AI-Powered Liver Segmentation
161
-
162
- Upload a liver CT scan in NIfTI format to get automatic liver segmentation using deep learning.
163
-
164
- **Model**: SwinUNETR (Vision Transformer) trained on Medical Segmentation Decathlon
165
- **Performance**: Mean Dice Score = **95.5% ± 2.0%** on test set
166
- **Input**: NIfTI files (.nii or .nii.gz)
167
- """
168
- )
169
-
170
- with gr.Row():
171
- with gr.Column():
172
- file_input = gr.File(
173
- label=" Upload CT Scan (NIfTI format)",
174
- file_types=[".nii", ".nii.gz"]
175
- )
176
- slice_slider = gr.Slider(
177
- minimum=0,
178
- maximum=127,
179
- value=64,
180
- step=1,
181
- label="🔍 Select Slice to Visualize"
182
- )
183
- segment_btn = gr.Button("🚀 Segment Liver", variant="primary", size="lg")
184
-
185
- with gr.Column():
186
- output_image = gr.Image(label=" Segmentation Result")
187
- output_file = gr.File(label=" Download Segmentation (NIfTI)")
188
-
189
- gr.Markdown(
190
- """
191
- ---
192
- ### How to Use:
193
- 1. Upload a liver CT scan in NIfTI format (.nii.gz)
194
- 2. Adjust the slice slider to view different cross-sections
195
- 3. Click "Segment Liver" to run the AI model
196
- 4. Download the segmentation mask for further analysis
197
-
198
- ### ⚠️ Note:
199
- - This is a research prototype, not for clinical use
200
- - Processing time: ~10-30 seconds depending on GPU availability
201
- - Best results with contrast-enhanced CT scans
202
- """
203
- )
204
-
205
- # Connect button to function
206
- segment_btn.click(
207
- fn=segment_liver,
208
- inputs=[file_input, slice_slider],
209
- outputs=[output_image, output_file]
210
- )
211
-
212
- # Launch the app
213
- if __name__ == "__main__":
214
- demo.launch()
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import nibabel as nib
4
+ import numpy as np
5
+ from monai.networks.nets import SwinUNETR
6
+ from monai.inferers import sliding_window_inference
7
+ from monai.transforms import (
8
+ Compose, LoadImage, EnsureChannelFirst, Orientation,
9
+ Spacing, ScaleIntensityRange, CropForeground, Resize, EnsureType
10
+ )
11
+ import tempfile
12
+ import matplotlib.pyplot as plt
13
+ from io import BytesIO
14
+ from PIL import Image
15
+ import os
16
+
17
# ============================================================================
# MODEL INITIALIZATION
# ============================================================================
# Module-level setup: the network is built and its weights loaded once at
# import time so every Gradio request reuses the same model instance.
print(" Initializing model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f" Using device: {device}")

# Initialize SwinUNETR with EXACT same config as your training
# (architecture hyper-parameters must match the checkpoint exactly, otherwise
# load_state_dict below fails with parameter-shape mismatches).
model = SwinUNETR(
    in_channels=1,        # single-channel CT volume
    out_channels=2,       # background vs. liver
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    feature_size=48,
    norm_name="instance",
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    normalize=True,
    use_checkpoint=True,  # gradient checkpointing; has no effect at eval time
    spatial_dims=3,
).to(device)

# Load trained weights
model_path = "best_metric_model.pth"
if os.path.exists(model_path):
    # NOTE(review): weights_only=False enables full pickle deserialization,
    # which executes arbitrary code from the file — acceptable only because
    # this checkpoint ships with the app; never load untrusted files this way.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=False))
    print(" Model loaded successfully!")
else:
    # App still starts without weights so the UI stays reachable, but any
    # prediction then comes from a randomly initialized network.
    print(" Model file not found! Please upload best_metric_model.pth")

model.eval()
49
+
50
# ============================================================================
# PREPROCESSING PIPELINE
# ============================================================================
# Array-based (non-dict) MONAI transforms; output is a (1, 128, 128, 128)
# float32 tensor in RAS orientation with intensities scaled to [0, 1].
transforms = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    Orientation(axcodes="RAS"),
    Spacing(pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
    # Soft-tissue window typical for liver CT, rescaled to [0, 1].
    ScaleIntensityRange(a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    # FIX: the array-version CropForeground takes no `source_key` argument
    # (that parameter belongs to the dict transform CropForegroundd); the
    # stray kwarg falls through to the padding kwargs and raises at call time.
    CropForeground(),
    Resize(spatial_size=(128, 128, 128), mode="trilinear"),
    EnsureType(dtype=torch.float32),
])
63
+
64
# ============================================================================
# SEGMENTATION FUNCTION
# ============================================================================
def segment_liver(nifti_file, show_slice):
    """
    Segment liver from an uploaded NIfTI file.

    Args:
        nifti_file: Uploaded .nii / .nii.gz file object (exposes a ``.name`` path).
        show_slice: Axial slice index to visualize (0-127); clamped to range.

    Returns:
        Tuple of (PIL.Image visualization, path to downloadable NIfTI mask),
        or (None, None) when no file was given or processing failed.
    """
    if nifti_file is None:
        return None, None

    try:
        print(f" Processing file: {nifti_file.name}")

        # Preprocess: resample/normalize to a (1, 128, 128, 128) tensor.
        volume = transforms(nifti_file.name)
        volume_input = volume.unsqueeze(0).to(device)  # add batch dimension

        # Inference with sliding window
        print(" Running inference...")
        with torch.no_grad():
            output = sliding_window_inference(
                volume_input,
                roi_size=(96, 96, 96),  # Smaller ROI for CPU
                sw_batch_size=1,        # Batch size 1 for CPU
                predictor=model,
                overlap=0.5
            )
            prediction = torch.argmax(output, dim=1, keepdim=True)

        # Convert to numpy, dropping the channel/batch axes so both arrays
        # are plain (H, W, D) volumes. FIX: without the [0] the volume keeps
        # its leading channel axis, `vol_np.shape[2]` indexes the wrong
        # dimension, and the 2-D slices below carry a trailing singleton
        # dimension that matplotlib's imshow rejects.
        vol_np = volume.cpu().numpy()[0]
        pred_np = prediction.cpu().numpy()[0, 0]

        # Normalize volume to [0, 1] for display (epsilon guards against a
        # constant volume producing a zero denominator).
        vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)

        # Clamp slice index; the Gradio slider may deliver a float.
        max_slice = vol_np.shape[2] - 1
        slice_idx = min(max(0, int(show_slice)), max_slice)

        # Create visualization (3 views)
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # CT scan
        axes[0].imshow(vol_display[:, :, slice_idx].T, cmap='gray', origin='lower')
        axes[0].set_title(f'CT Scan (Slice {slice_idx}/{max_slice})', fontsize=14)
        axes[0].axis('off')

        # Segmentation only
        axes[1].imshow(pred_np[:, :, slice_idx].T, cmap='Reds', origin='lower')
        axes[1].set_title('Liver Segmentation', fontsize=14)
        axes[1].axis('off')

        # Overlay
        axes[2].imshow(vol_display[:, :, slice_idx].T, cmap='gray', origin='lower')
        axes[2].imshow(pred_np[:, :, slice_idx].T, cmap='Greens', alpha=0.5, origin='lower')
        axes[2].set_title('Overlay (Green = Liver)', fontsize=14)
        axes[2].axis('off')

        plt.tight_layout()

        # Convert plot to an in-memory PNG, then close the exact figure we
        # created so repeated requests don't accumulate matplotlib state.
        buf = BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        result_image = Image.open(buf)
        plt.close(fig)

        # Save segmentation as NIfTI
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz')

        # Create segmentation NIfTI. NOTE(review): identity affine — the mask
        # lives in the resampled 128^3 space, so the original scan's affine
        # would not match either; downstream tools must resample to compare.
        pred_nifti = nib.Nifti1Image(pred_np.astype(np.uint8), affine=np.eye(4))
        nib.save(pred_nifti, output_file.name)

        print(" Segmentation complete!")

        return result_image, output_file.name

    except Exception as e:
        # Broad catch keeps the UI responsive for any bad upload; the full
        # traceback goes to the server log for debugging.
        print(f" Error: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None
155
+
156
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# NOTE: inside a gr.Blocks context, component creation ORDER defines the page
# layout — do not reorder these statements.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header / model card summary.
    gr.Markdown(
        """
        # 🫀 AI-Powered Liver Segmentation

        Upload a liver CT scan in NIfTI format to get automatic liver segmentation using deep learning.

        **Model**: SwinUNETR (Vision Transformer) trained on Medical Segmentation Decathlon
        **Performance**: Mean Dice Score = **95.5% ± 2.0%** on test set
        **Input**: NIfTI files (.nii or .nii.gz)
        """
    )

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            file_input = gr.File(
                label=" Upload CT Scan (NIfTI format)",
                file_types=[".nii", ".nii.gz"]
            )
            # Slice range matches the 128^3 resampled volume (indices 0-127).
            slice_slider = gr.Slider(
                minimum=0,
                maximum=127,
                value=64,
                step=1,
                label=" Select Slice to Visualize"
            )
            segment_btn = gr.Button(" Segment Liver", variant="primary", size="lg")

        # Right column: rendered result and downloadable mask.
        with gr.Column():
            output_image = gr.Image(label=" Segmentation Result")
            output_file = gr.File(label=" Download Segmentation (NIfTI)")

    # Usage instructions and disclaimer.
    gr.Markdown(
        """
        ---
        ### How to Use:
        1. Upload a liver CT scan in NIfTI format (.nii.gz)
        2. Adjust the slice slider to view different cross-sections
        3. Click "Segment Liver" to run the AI model
        4. Download the segmentation mask for further analysis

        ### ⚠️ Note:
        - This is a research prototype, not for clinical use
        - Processing time: ~30-60 seconds on CPU
        - Best results with contrast-enhanced CT scans
        """
    )

    # Connect button to function
    segment_btn.click(
        fn=segment_liver,
        inputs=[file_input, slice_slider],
        outputs=[output_image, output_file]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
217
+