File size: 13,834 Bytes
9859ea2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
import os
import numpy as np
import torch
import SimpleITK as sitk
import torch.nn.functional as F
import cv2
from PIL import Image, ImageDraw, ImageOps
import tempfile
import gradio as gr
from segment_anything.build_sam3D import sam_model_registry3D
from utils.click_method import get_next_click3D_torch_ritm, get_next_click3D_torch_2


def build_model(checkpoint_path=None, device='cuda'):
    """Build the 3D SAM model and load the fine-tuned BoSAM weights into it.

    Args:
        checkpoint_path: path of the checkpoint file. Defaults to
            ``ckpt/BoSAM.pth`` joined with the platform separator. The loaded
            object must contain a ``'model_state_dict'`` entry.
        device: torch device used both as ``map_location`` and for the model.

    Returns:
        The ``'vit_b_ori'`` model from ``sam_model_registry3D`` with the
        checkpoint's state dict loaded, placed on ``device``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    if checkpoint_path is None:
        # os.path.join keeps the default portable; a backslash literal only
        # resolves as a directory separator on Windows.
        checkpoint_path = os.path.join('ckpt', 'BoSAM.pth')
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"The file {checkpoint_path} does not exist.")

    # weights_only=False lets torch.load unpickle arbitrary objects; only use
    # it with checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict']

    # Build the architecture without weights, then load the fine-tuned ones.
    sam_model = sam_model_registry3D['vit_b_ori'](checkpoint=None).to(device)
    sam_model.load_state_dict(state_dict)
    return sam_model


def center_crop_or_pad(image_array, target_shape=(128, 128, 128)):
    """Center-crop and/or zero-pad an array to ``target_shape`` (no resampling).

    Each axis is handled independently. An axis longer than its target is
    cropped around its center; for an odd difference the extra element is
    dropped from the end. An axis shorter than its target is zero-padded
    around its center; for an odd difference the extra padding goes at the
    end.

    Args:
        image_array: N-dimensional numpy array.
        target_shape: desired output shape, one entry per axis of
            ``image_array``.

    Returns:
        If every axis is at least as long as its target, a view into
        ``image_array``. Otherwise a new zero-filled array of ``target_shape``
        (same dtype) with the centered content copied in.

    Raises:
        ValueError: if ``target_shape`` does not have one entry per axis.
    """
    target_shape = tuple(target_shape)
    if len(target_shape) != image_array.ndim:
        raise ValueError(
            f"target_shape {target_shape} must have {image_array.ndim} entries "
            f"to match an input of shape {image_array.shape}"
        )

    source_slices = []
    dest_slices = []
    for size, target in zip(image_array.shape, target_shape):
        if size > target:
            # Crop: take `target` elements starting at the centered offset.
            offset = (size - target) // 2
            source_slices.append(slice(offset, offset + target))
            dest_slices.append(slice(0, target))
        else:
            # Pad (or exact fit): copy the whole axis into the centered window.
            offset = (target - size) // 2
            source_slices.append(slice(0, size))
            dest_slices.append(slice(offset, offset + size))
    source_slices = tuple(source_slices)

    if all(size >= target for size, target in zip(image_array.shape, target_shape)):
        # Pure crop: no padding needed, so return a view without copying.
        return image_array[source_slices]

    result = np.zeros(target_shape, dtype=image_array.dtype)
    result[tuple(dest_slices)] = image_array[source_slices]
    return result


def preprocess_image(image_path, target_shape=(128, 128, 128), device='cuda'):
    """Read a volume and convert it into a model-ready tensor.

    The image is read with SimpleITK, center-cropped and/or zero-padded (no
    resampling) to ``target_shape`` via ``center_crop_or_pad``, and returned as
    a float32 tensor of shape ``(1, 1, *target_shape)`` on ``device``.

    Args:
        image_path: path of an image file readable by ``sitk.ReadImage``.
        target_shape: spatial shape in numpy (z, y, x) order.
        device: torch device for the returned tensor.
    """
    image = sitk.ReadImage(image_path)
    image_array = sitk.GetArrayFromImage(image)

    processed_array = center_crop_or_pad(image_array, target_shape)

    # Cast in numpy first: torch.tensor() rejects some dtypes SimpleITK can
    # return (e.g. uint16 on older torch versions); float32 is what .float()
    # produced anyway.
    volume = np.array(processed_array, dtype=np.float32)
    image_tensor = torch.from_numpy(volume)[None, None]  # add batch and channel axes

    return image_tensor.to(device)


def load_gt3d(image3d_path=None, gt3d_path=None):
    """Load a ground-truth label volume as a 128x128x128 CUDA tensor.

    NOTE(review): ``image3d_path`` is ignored. Unless ``gt3d_path`` is given,
    the same fixed label file
    (``examples/labels/1.3.6.1.4.1.9328.50.4.0357.nii.gz``) is loaded for every
    input image. Confirm this is intended.

    Args:
        image3d_path: unused; accepted for interface compatibility.
        gt3d_path: optional explicit label path that overrides the fixed
            default.

    Returns:
        float32 tensor of shape (1, 1, 128, 128, 128) on CUDA, cropped and/or
        padded by ``center_crop_or_pad``.

    Raises:
        FileNotFoundError: if the label file does not exist.
    """
    if gt3d_path is None:
        # Fixed GT; os.path.join keeps the path valid on non-Windows systems.
        gt3d_path = os.path.join('examples', 'labels', '1.3.6.1.4.1.9328.50.4.0357.nii.gz')
    if not os.path.exists(gt3d_path):
        raise FileNotFoundError(f"The file {gt3d_path} does not exist.")

    label_array = sitk.GetArrayFromImage(sitk.ReadImage(gt3d_path))
    processed_array = center_crop_or_pad(label_array, (128, 128, 128))

    # Convert via numpy float32 so label dtypes torch.tensor() may reject
    # (e.g. uint16) still load.
    gt_tensor = torch.from_numpy(np.array(processed_array, dtype=np.float32))[None, None]

    return gt_tensor.to('cuda')


def overlay_mask_on_image(image_slice, mask_slice, alpha=0.6):
    """Blend a binary mask over a grayscale slice for display.

    Args:
        image_slice: 2-D array of intensities.
        mask_slice: 2-D array of the same shape; pixels > 0 are painted.
        alpha: opacity of the mask colour, as a fraction of 255.

    Returns:
        RGBA ``PIL.Image``: the contrast-stretched slice with the mask drawn in
        blue on top.
    """
    # Same 2nd/98th-percentile stretch as normalize_image, which also maps a
    # constant slice (p98 == p2) to black instead of dividing by zero.
    image_contrast = normalize_image(image_slice)

    image_rgb = Image.fromarray(image_contrast).convert("RGB")
    # Extra autocontrast pass that clips 0.5% at each end of the histogram.
    image_rgba = ImageOps.autocontrast(image_rgb, cutoff=0.5).convert("RGBA")

    # Transparent layer that receives the coloured mask.
    mask_layer = Image.new('RGBA', image_rgba.size, (0, 0, 0, 0))

    # A 2-D uint8 array is inferred as mode 'L'; the explicit mode= argument
    # of fromarray is deprecated in recent Pillow releases.
    mask_pil = Image.fromarray((mask_slice > 0).astype(np.uint8) * 255)

    # Paint every non-zero mask pixel in saturated blue at the requested opacity.
    ImageDraw.Draw(mask_layer).bitmap((0, 0), mask_pil, fill=(41, 128, 255, int(255 * alpha)))

    return Image.alpha_composite(image_rgba, mask_layer)


def predict(image3D, sam_model, points=None, prev_masks=None, num_clicks=5):
    """Run iterative click-prompted SAM inference on a preprocessed volume.

    Clicks are not supplied by the caller. Each iteration asks
    ``get_next_click3D_torch_2`` for a new point derived from the current mask
    logits and the label volume returned by ``load_gt3d(None)``.
    NOTE(review): that label volume is a fixed file, so for other inputs the
    clicks are simulated against an unrelated ground truth. Confirm this is
    intended for the demo.

    Args:
        image3D: input tensor; presumably (1, 1, 128, 128, 128) as produced by
            ``preprocess_image`` (TODO confirm for other callers). It is
            z-score normalized into a new tensor, not modified in place.
        sam_model: model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder``; it is switched to eval mode.
        points: unused; kept for interface compatibility.
        prev_masks: optional initial mask logits shaped like ``image3D``;
            defaults to zeros.
        num_clicks: number of click/refine iterations; must be >= 1.

    Returns:
        ``(mask, prob)`` numpy arrays squeezed from the final 128^3 prediction:
        ``prob`` holds sigmoid probabilities and ``mask`` is ``prob > 0.5`` as
        uint8.

    Raises:
        ValueError: if ``num_clicks`` < 1, since no prediction would be made.
    """
    if num_clicks < 1:
        raise ValueError(f"num_clicks must be >= 1, got {num_clicks}")

    sam_model.eval()

    # Z-score normalization; clamping std keeps a constant volume from
    # producing NaNs.
    image3D = (image3D - image3D.mean()) / image3D.std().clamp_min(1e-8)

    gt3D = load_gt3d(None).to('cuda')

    if prev_masks is None:
        prev_masks = torch.zeros_like(image3D).to('cuda')

    # The prompt encoder takes the previous mask at 32^3 resolution.
    low_res_masks = F.interpolate(prev_masks.float(), size=(32, 32, 32))

    with torch.no_grad():
        image_embedding = sam_model.image_encoder(image3D).to('cuda')

        for _ in range(num_clicks):
            batch_points, batch_labels = get_next_click3D_torch_2(prev_masks.to('cuda'), gt3D)

            points_co = torch.cat(batch_points, dim=0).to('cuda')
            points_la = torch.cat(batch_labels, dim=0).to('cuda')

            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=[points_co, points_la],
                boxes=None,
                masks=low_res_masks.to('cuda'),
            )

            low_res_masks, _ = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            # Full-resolution logits: they seed the next click and the final output.
            prev_masks = F.interpolate(low_res_masks, size=[128, 128, 128], mode='trilinear', align_corners=False)

        # Only the last iteration is returned, so convert once after the loop.
        medsam_seg_prob = torch.sigmoid(prev_masks).cpu().numpy().squeeze()

    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
    return medsam_seg, medsam_seg_prob


def normalize_image(image):
    """Contrast-stretch an array into uint8 using its 2nd/98th percentiles.

    Values at or below the 2nd percentile map to 0, values at or above the
    98th percentile map to 255, and values in between are scaled linearly.
    If both percentiles coincide (e.g. a flat image) the result is all zeros.
    """
    low, high = np.percentile(image, (2, 98))
    spread = high - low
    if spread == 0:
        # Nothing to stretch: render the whole array black.
        return np.zeros(np.shape(image), dtype=np.uint8)
    stretched = np.clip((image - low) / spread, 0, 1)
    return (stretched * 255).astype(np.uint8)


def predicts(img_path, sam_model):
    """Load the volume at ``img_path`` and segment it with ``sam_model``.

    Returns the ``(mask, probability)`` pair produced by ``predict``.
    """
    return predict(preprocess_image(img_path), sam_model)


def save_nifti(prediction, original_image_path):
    """Write a predicted mask to a temporary NIfTI file aligned with the input image.

    ``prediction`` is assumed to be on the grid produced by running
    ``center_crop_or_pad`` on the original image's array. That function only
    crops or pads and never resamples, so the voxel spacing stays the
    original one. Only the origin moves, by the crop/pad offset, so the mask
    overlays the original image in physical space.

    Args:
        prediction: numpy mask in SimpleITK array order (z, y, x); cast to uint8.
        original_image_path: path of the image the prediction was made from.

    Returns:
        Path of the written ``.nii.gz`` temporary file (not deleted
        automatically).
    """
    original_image = sitk.ReadImage(original_image_path)

    output_image = sitk.GetImageFromArray(prediction.astype(np.uint8))

    # numpy arrays from SimpleITK are (z, y, x) while sizes/indices are (x, y, z).
    original_size_zyx = original_image.GetSize()[::-1]
    # Original-image index of output voxel 0 along each axis, mirroring
    # center_crop_or_pad: a positive crop start, or a negative pad offset.
    offset_zyx = [
        (size - target) // 2 if size > target else -((target - size) // 2)
        for size, target in zip(original_size_zyx, prediction.shape)
    ]
    new_origin = original_image.TransformContinuousIndexToPhysicalPoint(
        [float(o) for o in offset_zyx[::-1]]
    )

    output_image.SetSpacing(original_image.GetSpacing())
    output_image.SetDirection(original_image.GetDirection())
    output_image.SetOrigin(new_origin)

    # Close the handle right away; only the reserved filename is needed, and
    # an open handle leaks (and can block writing on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz") as temp_file:
        temp_filename = temp_file.name

    sitk.WriteImage(output_image, temp_filename)

    return temp_filename


def gr_interface(img_path, sam_model=None):
    """Gradio generator: segment one volume and stream progress to the UI.

    Every yield is a 5-tuple matching the wired outputs:
    (original slices, status update, overlay slices, mask slices, NIfTI path).
    Intermediate yields only update the status text; the final yield fills
    every output.

    Args:
        img_path: path of the input volume as delivered by the ``gr.File``
            component; ``None`` when nothing is uploaded.
        sam_model: model to use. When ``None``, the module-level ``sam_model``
            is reused if it exists; otherwise a model is built.
    """
    if img_path is None:
        # No file selected: report it instead of raising inside the generator.
        yield None, gr.update(value="请先上传NIfTI文件"), None, None, None
        return

    if sam_model is None:
        # The parameter shadows the module-level model loaded at import time;
        # reuse that instance instead of reloading the checkpoint per request.
        sam_model = globals().get("sam_model")
        if sam_model is None:
            sam_model = build_model()

    # Show progress information
    yield None, gr.update(value="正在加载数据..."), None, None, None

    processed_img = preprocess_image(img_path)

    yield None, gr.update(value="正在分割..."), None, None, None

    # Reuse the tensor already loaded above; predict() normalizes into a new
    # tensor and does not modify processed_img.
    prediction, _ = predict(processed_img, sam_model)

    yield None, gr.update(value="正在生成可视化..."), None, None, None

    nifti_file_path = save_nifti(prediction, img_path)

    # Move the volume to the CPU once instead of once per slice.
    volume = processed_img[0, 0].cpu().numpy()

    # Show up to 32 central slices along the first array axis.
    depth = prediction.shape[0]
    num_slices = min(32, depth)
    start_idx = (depth - num_slices) // 2
    end_idx = start_idx + num_slices

    processed_slices = []
    combined_slices = []
    predicted_slices = []
    for i in range(start_idx, end_idx):
        image_slice = normalize_image(volume[i])
        mask_slice = prediction[i]

        processed_slices.append(image_slice)
        combined_slices.append(overlay_mask_on_image(image_slice, mask_slice))
        # Binary masks must not go through percentile normalization: a slice
        # whose foreground covers <2% (or >98%) of pixels would render blank.
        predicted_slices.append((mask_slice > 0).astype(np.uint8) * 255)

    yield processed_slices, gr.update(value="分割完成!"), combined_slices, predicted_slices, nifti_file_path


# Example volumes shipped with the demo, relative to the working directory.
# os.path.join keeps them valid on any OS; backslash literals only resolve on Windows.
_EXAMPLE_DIR = "examples"
_EXAMPLE_FILES = [
    "1.3.6.1.4.1.9328.50.4.0327.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0357.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0477.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0491.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0708.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0719.nii.gz",
]
# gr.Examples expects one inner list per example (one entry per input component).
EXAMPLES = [[os.path.join(_EXAMPLE_DIR, name)] for name in _EXAMPLE_FILES]
# The file preloaded into the upload box.
DEFAULT_EXAMPLE = EXAMPLES[0][0]

# Custom CSS for the Gradio page: centered title/subtitle with a fade-in
# animation, a blue action button with a hover lift, and gallery hover effects.
# The class names are referenced via elem_classes / HTML class attributes in the UI.
css = """
body {
    background-color: #f8fafc;
}

.container {
    max-width: 1200px;
    margin: 0 auto;
}

.main-title {
    text-align: center;
    color: #2563eb;
    font-size: 2.5rem;
    margin-bottom: 1rem;
    font-weight: bold;
    animation: fadeIn 1.5s ease-in-out;
}

.sub-title {
    text-align: center;
    color: #1e293b;
    margin-bottom: 2rem;
    animation: fadeIn 2s ease-in-out;
}

.custom-button {
    background-color: #2563eb !important;
    color: white !important;
    transition: transform 0.3s, box-shadow 0.3s;
}

.custom-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
}

.gallery-item {
    border-radius: 8px;
    overflow: hidden;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    transition: transform 0.3s;
}

.gallery-item:hover {
    transform: scale(1.02);
    box-shadow: 0 6px 12px rgba(0, 0, 0, 0.15);
}

@keyframes fadeIn {
    from { opacity: 0; transform: translateY(20px); }
    to { opacity: 1; transform: translateY(0); }
}
"""

# Load the model once as a module-level global. This runs build_model()
# (checkpoint load plus model construction) as a side effect of importing
# this module.
sam_model = build_model()

# Build the Gradio UI: input/examples column on the left, result galleries on the right.
with gr.Blocks(title="3D医学影像智能分割系统", css=css) as demo:
    gr.HTML("<h1 class='main-title'>3D医学影像智能分割系统</h1>")
    gr.HTML("<p class='sub-title'>基于BoSAM的前沿人工智能自动分割技术,为医学影像分析提供高精度解决方案</p>")
    
    with gr.Row():
        with gr.Column(scale=1):
            # Input area
            gr.Markdown("### 上传/选择影像")
            input_file = gr.File(label="上传NIfTI文件", value=DEFAULT_EXAMPLE)
            
            status = gr.Textbox(label="处理状态", value="准备就绪")
            process_btn = gr.Button("开始智能分割", elem_classes=["custom-button"])
            
            # Example data
            gr.Markdown("### 示例数据")
            examples = gr.Examples(
                examples=EXAMPLES,
                inputs=[input_file]
            )
            
            gr.HTML("""
            <div style="margin-top: 2rem; padding: 1rem; background-color: rgba(16, 185, 129, 0.1); border-radius: 8px;">
                <h3 style="color: #10b981; margin-bottom: 0.5rem;">技术亮点</h3>
                <ul style="margin-left: 1.5rem;">
                    <li>基于最新的Segment Anything Model (SAM) 技术</li>
                    <li>专为3D医学影像优化的深度学习模型</li>
                    <li>智能识别解剖结构,无需手动绘制边界</li>
                    <li>高精度分割结果,提升诊断效率</li>
                </ul>
            </div>
            """)

        with gr.Column(scale=2):
            # Output area
            with gr.Row():
                gr.Markdown("## 原始医学影像")
                output_original = gr.Gallery(label="", show_label=False, columns=4, rows=8, height="600px", elem_classes=["gallery-item"])
            
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## 分割叠加结果")
                    output_combined = gr.Gallery(label="", show_label=False, columns=4, rows=4, height="400px", elem_classes=["gallery-item"])
                
                with gr.Column():
                    gr.Markdown("## 分割掩码")
                    output_mask = gr.Gallery(label="", show_label=False, columns=4, rows=4, height="400px", elem_classes=["gallery-item"])
            
            gr.Markdown("## 分割结果下载")
            output_file = gr.File(label="下载完整3D分割结果 (NIFTI格式)")
    
    gr.HTML("""
    <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid rgba(0, 0, 0, 0.1);">
        <p>© 2025 3D医学影像智能分割系统 | 人工智能辅助医学影像分析平台</p>
        <p>基于最新的BoSAM模型,为医疗影像分析提供高精度自动分割解决方案</p>
    </div>
    """)
    
    # Event wiring: each trigger runs gr_interface on the current file. The
    # order must match the 5-tuples gr_interface yields.
    result_outputs = [output_original, status, output_combined, output_mask, output_file]

    process_btn.click(
        fn=gr_interface,
        inputs=[input_file],
        outputs=result_outputs
    )
    
    examples.dataset.click(
        fn=gr_interface,
        inputs=[input_file],
        outputs=result_outputs
    )
    
    # Also segments the preloaded default example as soon as the page loads.
    demo.load(
        fn=gr_interface,
        inputs=[input_file],
        outputs=result_outputs
    )

if __name__ == "__main__":
    demo.launch(debug=True, share = True)