Amodit commited on
Commit
f172dad
·
verified ·
1 Parent(s): 7195fcb

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +214 -0
  2. best_metric_model.pth +3 -0
  3. readme.md +35 -0
  4. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import nibabel as nib
4
+ import numpy as np
5
+ from monai.networks.nets import SwinUNETR
6
+ from monai.inferers import sliding_window_inference
7
+ from monai.transforms import (
8
+ Compose, LoadImage, EnsureChannelFirst, Orientation,
9
+ Spacing, ScaleIntensityRange, CropForeground, Resize, EnsureType
10
+ )
11
+ import tempfile
12
+ import matplotlib.pyplot as plt
13
+ from io import BytesIO
14
+ from PIL import Image
15
+ import os
16
+
17
# ============================================================================
# MODEL INITIALIZATION
# ============================================================================
# Runs once at import time: select a device, build the 3D SwinUNETR network,
# and load the trained weights from disk if the checkpoint file is present.
print(" Initializing model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f" Using device: {device}")

# 1 input channel (CT intensity), 2 output channels (background vs. liver).
# use_checkpoint=True enables gradient checkpointing (memory for compute).
# NOTE(review): `patch_size` / `window_size` kwargs belong to newer MONAI
# releases, while requirements.txt pins monai==1.3.0 (whose SwinUNETR
# expects `img_size` instead) — verify the installed MONAI version matches.
model = SwinUNETR(
    in_channels=1,
    out_channels=2,
    patch_size=(2, 2, 2),
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    window_size=7,
    feature_size=48,
    norm_name="instance",
    use_checkpoint=True,
    spatial_dims=3,
).to(device)

# Load trained weights
# NOTE(review): torch.load uses pickle under the hood; loading an untrusted
# checkpoint can execute arbitrary code. Consider weights_only=True when the
# checkpoint is a plain state_dict.
model_path = "best_metric_model.pth"
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(" Model loaded successfully!")
else:
    print(" Model file not found! Please upload best_metric_model.pth")

# Inference-only app: freeze batch-norm/dropout behavior.
model.eval()
47
# ============================================================================
# PREPROCESSING PIPELINE
# ============================================================================
# Array-style (non-dict) MONAI transforms applied to the input NIfTI path:
# load -> add channel dim -> reorient to RAS -> resample to 1.5 mm isotropic
# -> window intensities to [-200, 200] and rescale to [0, 1] -> crop to the
# nonzero foreground -> resize to the model's 128^3 input -> float32 tensor.
transforms = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    Orientation(axcodes="RAS"),
    Spacing(pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
    ScaleIntensityRange(a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    # Fix: `source_key` is a parameter of the dictionary transform
    # CropForegroundd; the array transform CropForeground does not accept it
    # (it would be swallowed into **pad_kwargs and error out at runtime).
    CropForeground(),
    Resize(spatial_size=(128, 128, 128), mode="trilinear"),
    EnsureType(dtype=torch.float32),
])
60
+
61
# ============================================================================
# SEGMENTATION FUNCTION
# ============================================================================
def segment_liver(nifti_file, show_slice):
    """
    Segment the liver from an uploaded NIfTI CT volume.

    Args:
        nifti_file: Uploaded file object exposing a `.name` filesystem path
            to a .nii / .nii.gz file (as provided by gr.File).
        show_slice: Axial slice index to visualize; clamped to the volume
            depth and cast to int (sliders may deliver floats).

    Returns:
        Tuple of (PIL image with CT / mask / overlay panels, path to the
        saved segmentation NIfTI), or (None, None) on missing input or error.
    """
    if nifti_file is None:
        return None, None

    try:
        print(f" Processing file: {nifti_file.name}")

        # Preprocess: yields a channel-first (1, 128, 128, 128) tensor.
        volume = transforms(nifti_file.name)
        volume_input = volume.unsqueeze(0).to(device)  # add batch dim

        # Inference with sliding window
        print(" Running inference...")
        with torch.no_grad():
            output = sliding_window_inference(
                volume_input,
                roi_size=(128, 128, 64),
                sw_batch_size=2,
                predictor=model,
                overlap=0.75
            )
            prediction = torch.argmax(output, dim=1, keepdim=True)

        # Convert to numpy, dropping batch/channel axes so both arrays are
        # plain 3D volumes.
        # Fix: the volume previously kept its channel axis, so the 2D slices
        # below carried a trailing singleton dim after .T, which
        # matplotlib's imshow rejects ("Invalid shape" TypeError).
        vol_np = volume.cpu().numpy()[0]
        pred_np = prediction.cpu().numpy()[0, 0]

        # Normalize volume to [0, 1] for display (epsilon guards flat volumes)
        vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)

        # Clamp slice index to the valid range
        max_slice = vol_np.shape[2] - 1
        slice_idx = min(max(0, int(show_slice)), max_slice)

        # Create visualization (3 views)
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # CT scan
        axes[0].imshow(vol_display[:, :, slice_idx].T, cmap='gray', origin='lower')
        axes[0].set_title(f'CT Scan (Slice {slice_idx}/{max_slice})', fontsize=14)
        axes[0].axis('off')

        # Segmentation only
        axes[1].imshow(pred_np[:, :, slice_idx].T, cmap='Reds', origin='lower')
        axes[1].set_title('Liver Segmentation', fontsize=14)
        axes[1].axis('off')

        # Overlay
        axes[2].imshow(vol_display[:, :, slice_idx].T, cmap='gray', origin='lower')
        axes[2].imshow(pred_np[:, :, slice_idx].T, cmap='Greens', alpha=0.5, origin='lower')
        axes[2].set_title('Overlay (Green = Liver)', fontsize=14)
        axes[2].axis('off')

        plt.tight_layout()

        # Render the figure to an in-memory PNG and close it to avoid
        # leaking figures across requests.
        buf = BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        result_image = Image.open(buf)
        plt.close(fig)

        # Save segmentation as NIfTI (delete=False: Gradio serves the path)
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz')

        # Load original NIfTI to preserve metadata
        original_nii = nib.load(nifti_file.name)

        # Create segmentation NIfTI with the original affine.
        # NOTE(review): the mask lives in the preprocessed 128^3 space
        # (reoriented / resampled / cropped / resized), so pairing it with
        # the original affine misaligns it with the source scan; inverting
        # the preprocessing transforms would be needed for a faithful export.
        pred_nifti = nib.Nifti1Image(pred_np.astype(np.uint8), affine=original_nii.affine)
        nib.save(pred_nifti, output_file.name)

        print("Segmentation complete!")

        return result_image, output_file.name

    except Exception as e:
        # Best-effort UI handler: log the full traceback server-side and
        # return empty outputs rather than crashing the app.
        import traceback
        traceback.print_exc()
        print(f" Error: {str(e)}")
        return None, None
153
+
154
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Web UI: file upload + slice slider on the left, rendered segmentation
# image + downloadable NIfTI mask on the right.
# Fix: repaired mojibake (wrong-encoding artifacts such as "ยฑ" for "±" and
# garbled emoji bytes) in the user-facing markdown and widget labels.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI-Powered Liver Segmentation

        Upload a liver CT scan in NIfTI format to get automatic liver segmentation using deep learning.

        **Model**: SwinUNETR (Vision Transformer) trained on Medical Segmentation Decathlon
        **Performance**: Mean Dice Score = **95.5% ± 2.0%** on test set
        **Input**: NIfTI files (.nii or .nii.gz)
        """
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label=" Upload CT Scan (NIfTI format)",
                file_types=[".nii", ".nii.gz"]
            )
            slice_slider = gr.Slider(
                minimum=0,
                maximum=127,
                value=64,
                step=1,
                label="🔍 Select Slice to Visualize"
            )
            segment_btn = gr.Button("🚀 Segment Liver", variant="primary", size="lg")

        with gr.Column():
            output_image = gr.Image(label=" Segmentation Result")
            output_file = gr.File(label=" Download Segmentation (NIfTI)")

    gr.Markdown(
        """
        ---
        ### How to Use:
        1. Upload a liver CT scan in NIfTI format (.nii.gz)
        2. Adjust the slice slider to view different cross-sections
        3. Click "Segment Liver" to run the AI model
        4. Download the segmentation mask for further analysis

        ### ⚠️ Note:
        - This is a research prototype, not for clinical use
        - Processing time: ~10-30 seconds depending on GPU availability
        - Best results with contrast-enhanced CT scans
        """
    )

    # Connect button to function
    segment_btn.click(
        fn=segment_liver,
        inputs=[file_input, slice_slider],
        outputs=[output_image, output_file]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
best_metric_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cb4121e118e9580eae93dc77b434dc211d45e7084f37fb8893c51ca0f7e6130
3
+ size 256345490
readme.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: AI Liver Segmentation
3
+ emoji: 🫀
4
+ colorFrom: red
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.8.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # 🫀 AI-Powered Liver Segmentation
14
+
15
+ Automatic liver segmentation from CT scans using SwinUNETR (Vision Transformer).
16
+
17
+ ## 📊 Model Performance
18
+ - **Architecture**: SwinUNETR with gradient checkpointing
19
+ - **Dataset**: Medical Segmentation Decathlon (Task03_Liver)
20
+ - **Dice Score**: 95.5% ± 2.0% on test set
21
+
22
+ ## 🚀 Features
23
+ - Upload CT scans in NIfTI format
24
+ - Real-time liver segmentation
25
+ - Interactive slice visualization
26
+ - Download segmentation masks
27
+
28
+ ## ๐Ÿ› ๏ธ Tech Stack
29
+ - PyTorch
30
+ - MONAI
31
+ - Gradio
32
+ - Hugging Face Spaces
33
+
34
+ ## โš ๏ธ Disclaimer
35
+ This is a research prototype and should not be used for clinical diagnosis.
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ monai==1.3.0
4
+ nibabel==5.1.0
5
+ gradio==4.8.0
6
+ matplotlib==3.7.1
7
+ numpy==1.24.3
8
+ Pillow==10.0.0
9
+ scipy==1.11.1