|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
import SimpleITK as sitk |
|
|
import torch.nn.functional as F |
|
|
import cv2 |
|
|
from PIL import Image, ImageDraw, ImageOps |
|
|
import tempfile |
|
|
import gradio as gr |
|
|
from segment_anything.build_sam3D import sam_model_registry3D |
|
|
from utils.click_method import get_next_click3D_torch_ritm, get_next_click3D_torch_2 |
|
|
|
|
|
|
|
|
def build_model(checkpoint_path=os.path.join('ckpt', 'BoSAM.pth'), device='cuda'):
    """Build the 3-D SAM ``vit_b_ori`` model and load BoSAM weights into it.

    Args:
        checkpoint_path: checkpoint file containing a ``'model_state_dict'``
            entry. The default (``ckpt/BoSAM.pth``) is built with
            ``os.path.join`` so it resolves on every OS.
        device: device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The model on ``device`` with the checkpoint's state dict loaded.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict']

    # Build an uninitialised architecture, then load the trained weights.
    sam_model = sam_model_registry3D['vit_b_ori'](checkpoint=None).to(device)
    sam_model.load_state_dict(state_dict)
    return sam_model
|
|
|
|
|
|
|
|
def center_crop_or_pad(image_array, target_shape=(128, 128, 128)):
    """Center-crop and/or zero-pad ``image_array`` to ``target_shape``.

    Each axis is handled independently:
    - An axis longer than its target is cropped around its centre.
    - A shorter axis is placed in the centre of a zero-filled array with
      ``image_array``'s dtype.

    If no axis needs padding, the cropped slice of the input is returned (a
    view, not a copy). Otherwise a new array of shape ``target_shape`` is
    returned.
    """
    src_slices = []
    dst_slices = []
    for size, want in zip(image_array.shape, target_shape):
        if size > want:
            # Crop: take the central `want` elements, fill the whole target axis.
            offset = (size - want) // 2
            src_slices.append(slice(offset, offset + want))
            dst_slices.append(slice(0, want))
        else:
            # Pad (or exact fit): take everything, centre it in the target axis.
            offset = (want - size) // 2
            src_slices.append(slice(0, size))
            dst_slices.append(slice(offset, offset + size))

    needs_padding = any(size < want for size, want in zip(image_array.shape, target_shape))
    if not needs_padding:
        # Pure crop of the three leading axes.
        return image_array[src_slices[0], src_slices[1], src_slices[2]]

    padded = np.zeros(target_shape, dtype=image_array.dtype)
    padded[tuple(dst_slices)] = image_array[tuple(src_slices)]
    return padded
|
|
|
|
|
|
|
|
def preprocess_image(image_path, device='cuda'):
    """Read a volume and centre-crop/pad it to 128x128x128.

    Args:
        image_path: path of an image file readable by SimpleITK (NIfTI here).
        device: device for the returned tensor.

    Returns:
        float32 tensor of shape (1, 1, 128, 128, 128) on ``device``. The
        spatial axes follow SimpleITK's ``GetArrayFromImage`` axis order.
    """
    image = sitk.ReadImage(image_path)
    image_array = sitk.GetArrayFromImage(image)

    processed_array = center_crop_or_pad(image_array, (128, 128, 128))

    # Cast in NumPy first. torch.tensor() rejects some NumPy dtypes (e.g.
    # uint16 on older PyTorch releases), and astype() + from_numpy() makes
    # one copy instead of two.
    volume = torch.from_numpy(processed_array.astype(np.float32))

    # Add batch and channel dimensions.
    return volume.unsqueeze(0).unsqueeze(0).to(device)
|
|
|
|
|
|
|
|
def load_gt3d(image3d_path, device='cuda'):
    """Load a ground-truth label volume, centre-cropped/padded to 128x128x128.

    The label is looked up as ``<directory of image3d_path>/labels/<same file
    name>``, the layout of the bundled ``examples`` folder. When
    ``image3d_path`` is None, the bundled label of example case
    ``1.3.6.1.4.1.9328.50.4.0357`` is used.

    Args:
        image3d_path: path of the image whose label should be loaded, or None.
        device: device for the returned tensor.

    Returns:
        float32 tensor of shape (1, 1, 128, 128, 128) on ``device``.

    Raises:
        FileNotFoundError: if the label file does not exist.
    """
    if image3d_path is None:
        gt3d_path = os.path.join('examples', 'labels', '1.3.6.1.4.1.9328.50.4.0357.nii.gz')
    else:
        gt3d_path = os.path.join(
            os.path.dirname(image3d_path), 'labels', os.path.basename(image3d_path)
        )

    if not os.path.exists(gt3d_path):
        raise FileNotFoundError(f"The file {gt3d_path} does not exist.")

    image_array = sitk.GetArrayFromImage(sitk.ReadImage(gt3d_path))
    processed_array = center_crop_or_pad(image_array, (128, 128, 128))

    # Cast in NumPy so label dtypes torch.tensor() may reject (e.g. uint16)
    # still work.
    gt_tensor = torch.from_numpy(processed_array.astype(np.float32))
    return gt_tensor.unsqueeze(0).unsqueeze(0).to(device)
|
|
|
|
|
|
|
|
def overlay_mask_on_image(image_slice, mask_slice, alpha=0.6):
    """Composite a translucent blue mask over a contrast-enhanced 2-D slice.

    The slice is first stretched to its 2nd-98th percentile range, then
    auto-contrasted with a 0.5% cutoff. Every pixel where ``mask_slice > 0``
    is then blended with the colour (41, 128, 255) at opacity ``alpha``.

    Args:
        image_slice: 2-D array of intensities.
        mask_slice: 2-D array with the same shape as ``image_slice``; pixels
            greater than zero are highlighted.
        alpha: mask opacity in [0, 1].

    Returns:
        ``PIL.Image.Image`` in RGBA mode.
    """
    p2, p98 = np.percentile(image_slice, (2, 98))
    if p98 - p2 != 0:
        image_contrast = np.clip((image_slice - p2) / (p98 - p2), 0, 1)
    else:
        # A constant slice (e.g. a fully padded one) would otherwise be 0/0 =
        # NaN, and casting NaN to uint8 is undefined. Render it black instead,
        # matching normalize_image.
        image_contrast = np.zeros_like(image_slice, dtype=np.float64)
    image_contrast = (image_contrast * 255).astype(np.uint8)

    image_rgb = Image.fromarray(image_contrast).convert("RGB")
    image_rgba = ImageOps.autocontrast(image_rgb, cutoff=0.5).convert("RGBA")

    # Transparent layer; the mask colour is painted only where the mask is set.
    mask_image = Image.new('RGBA', image_rgba.size, (0, 0, 0, 0))
    mask = (mask_slice > 0).astype(np.uint8) * 255
    # A 2-D uint8 array is converted to mode 'L' automatically (the explicit
    # `mode` argument of Image.fromarray is deprecated).
    mask_pil = Image.fromarray(mask)
    ImageDraw.Draw(mask_image).bitmap((0, 0), mask_pil, fill=(41, 128, 255, int(255 * alpha)))

    return Image.alpha_composite(image_rgba, mask_image)
|
|
|
|
|
|
|
|
def predict(image3D, sam_model, points=None, prev_masks=None, num_clicks=5, gt3D=None):
    """Segment a volume with the 3-D SAM model using simulated click prompts.

    No user clicks are involved. On each of ``num_clicks`` iterations:
    1. ``get_next_click3D_torch_2`` derives the next point prompt(s) from the
       current mask and a ground-truth volume.
    2. The prompt encoder and mask decoder refine the mask.
    3. The low-resolution logits are upsampled to 128^3.

    NOTE(review): by default the ground truth is ``load_gt3d(None)``, i.e. a
    label file selected in code rather than one matching ``image3D``. For
    other inputs the simulated clicks therefore do not come from that input's
    own label. Pass ``gt3D`` to supply the matching label.

    Args:
        image3D: input volume tensor. Assumed to be (1, 1, 128, 128, 128) float
            on the same device as ``sam_model``, as produced by
            ``preprocess_image`` — TODO confirm for other callers.
        sam_model: model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder``.
        points: unused; kept for backward compatibility.
        prev_masks: optional initial mask logits shaped like ``image3D``;
            zeros when None.
        num_clicks: number of click/refinement iterations.
        gt3D: optional ground-truth tensor used for click simulation; defaults
            to ``load_gt3d(None)``.

    Returns:
        ``(mask, prob)``. ``prob`` is the sigmoid of the final logits as a
        squeezed NumPy array; ``mask`` is ``prob > 0.5`` as ``uint8``.
    """
    sam_model.eval()
    device = image3D.device

    # Inference only: without no_grad around the decoder loop, autograd would
    # record a graph for every iteration. The final .numpy() would then be
    # called on a tensor that requires grad.
    with torch.no_grad():
        # Z-score normalisation. The clamp only affects a constant volume,
        # where std == 0 would otherwise produce NaN everywhere.
        image3D = (image3D - image3D.mean()) / image3D.std().clamp_min(1e-8)

        if gt3D is None:
            gt3D = load_gt3d(None)
        gt3D = gt3D.to(device)

        if prev_masks is None:
            prev_masks = torch.zeros_like(image3D)
        prev_masks = prev_masks.to(device)

        # The prompt encoder takes the mask prompt at 32^3.
        low_res_masks = F.interpolate(prev_masks.float(), size=(32, 32, 32))

        # The image embedding does not depend on the prompts; compute it once.
        image_embedding = sam_model.image_encoder(image3D)

        for _ in range(num_clicks):
            batch_points, batch_labels = get_next_click3D_torch_2(prev_masks, gt3D)

            points_co = torch.cat(batch_points, dim=0).to(device)
            points_la = torch.cat(batch_labels, dim=0).to(device)

            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=[points_co, points_la],
                boxes=None,
                masks=low_res_masks,
            )

            low_res_masks, _ = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            prev_masks = F.interpolate(
                low_res_masks, size=[128, 128, 128], mode='trilinear', align_corners=False
            )

        medsam_seg_prob = torch.sigmoid(prev_masks).cpu().numpy().squeeze()

    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
    return medsam_seg, medsam_seg_prob
|
|
|
|
|
|
|
|
def normalize_image(image):
    """Map ``image`` to uint8 by stretching its 2nd-98th percentile range.

    Values below the 2nd percentile become 0 and values above the 98th
    become 255. If both percentiles coincide (e.g. a constant slice), the
    result is all zeros.
    """
    low, high = np.percentile(image, (2, 98))
    span = high - low

    if span == 0:
        scaled = np.zeros_like(image)
    else:
        scaled = np.clip((image - low) / span, 0, 1)

    return (scaled * 255).astype(np.uint8)
|
|
|
|
|
|
|
|
def predicts(img_path, sam_model):
    """Read and preprocess the volume at ``img_path``, then segment it.

    Returns the ``(binary_mask, probability)`` pair produced by ``predict``.
    """
    volume = preprocess_image(img_path)
    segmentation, probability = predict(volume, sam_model)
    return segmentation, probability
|
|
|
|
|
|
|
|
def save_nifti(prediction, original_image_path):
    """Write a prediction to a temporary NIfTI file aligned with its source image.

    ``prediction`` is expected on the grid produced by ``center_crop_or_pad``
    (a centre crop/pad of the source voxel grid, in SimpleITK array order).
    No resampling happened, so:
    - the voxel spacing and direction are copied from the source image;
    - only the origin moves, to the physical position of the first voxel of
      the cropped/padded grid.

    Args:
        prediction: 3-D array of label values.
        original_image_path: path of the image the prediction was computed from.

    Returns:
        Path of a newly created ``.nii.gz`` file (not deleted automatically;
        the caller owns it).
    """
    original_image = sitk.ReadImage(original_image_path)
    prediction = np.asarray(prediction)

    output_image = sitk.GetImageFromArray(prediction.astype(np.uint8))
    output_image.SetDirection(original_image.GetDirection())
    # Crop/pad keeps the voxel size. Rescaling spacing by size/128 would only
    # be correct for a resample.
    output_image.SetSpacing(original_image.GetSpacing())

    # SimpleITK sizes are ordered opposite to NumPy array shapes.
    original_size = original_image.GetSize()
    target_size = prediction.shape[::-1]
    # Index in the source grid of output voxel 0 along each axis, using the
    # same arithmetic as center_crop_or_pad:
    #   crop -> the crop start;
    #   pad  -> minus the pad offset (a position outside the source grid).
    first_index = [
        float((size - target) // 2) if size > target else -float((target - size) // 2)
        for size, target in zip(original_size, target_size)
    ]
    output_image.SetOrigin(original_image.TransformContinuousIndexToPhysicalPoint(first_index))

    # Only the name is needed. Closing the handle immediately avoids leaking a
    # file descriptor per call; delete=False keeps the file on disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz") as temp_file:
        temp_filename = temp_file.name

    sitk.WriteImage(output_image, temp_filename)
    return temp_filename
|
|
|
|
|
|
|
|
def gr_interface(img_path, sam_model=None):
    """Gradio event handler: segment a NIfTI volume and stream progress/results.

    This is a generator. Every yield is a 5-tuple in the order of the UI
    outputs: (original-slice gallery, status-text update, overlay gallery,
    mask gallery, path of the downloadable NIfTI file). Intermediate yields
    only update the status text; the final yield carries all results. Only
    the 32 central slices along the first array axis are rendered.

    Args:
        img_path: value of the ``gr.File`` input.
            - A path string is expected.
            - An object exposing the path as ``.name`` is also accepted
              (presumably what some Gradio versions pass — TODO confirm).
            - When None (no file selected), only a status message is yielded.
        sam_model: model to run. When None, the module-level ``sam_model`` is
            reused if it exists; otherwise one is built with ``build_model()``.
    """
    if img_path is None:
        yield None, gr.update(value="请先上传NIfTI文件"), None, None, None
        return
    if not isinstance(img_path, (str, os.PathLike)):
        img_path = img_path.name

    if sam_model is None:
        # The UI handlers pass only the file. Without this lookup every request
        # would reload the checkpoint and allocate another model, even though
        # one is already built at module level.
        sam_model = globals().get("sam_model")
        if sam_model is None:
            sam_model = build_model()

    yield None, gr.update(value="正在加载数据..."), None, None, None

    processed_img = preprocess_image(img_path)

    yield None, gr.update(value="正在分割..."), None, None, None

    # Segment the tensor already loaded above rather than calling predicts(),
    # which would read and preprocess the same file a second time. predict()
    # builds a new normalised tensor and does not modify its input.
    prediction, _ = predict(processed_img, sam_model)

    yield None, gr.update(value="正在生成可视化..."), None, None, None

    nifti_file_path = save_nifti(prediction, img_path)

    # One device->host copy of the whole volume instead of one per slice.
    volume = processed_img[0, 0].cpu().numpy()

    num_slices = 32
    start_idx = (prediction.shape[0] - num_slices) // 2

    processed_slices = []
    combined_slices = []
    predicted_slices = []
    for i in range(start_idx, start_idx + num_slices):
        image_slice = normalize_image(volume[i])
        mask_slice = prediction[i]

        processed_slices.append(image_slice)
        combined_slices.append(overlay_mask_on_image(image_slice, mask_slice))
        # Render the binary mask directly as 0/255. Percentile normalisation
        # (normalize_image) collapses a mask covering <2% (or >98%) of the
        # slice to a constant and hence to an all-black image.
        predicted_slices.append((mask_slice > 0).astype(np.uint8) * 255)

    yield processed_slices, gr.update(value="分割完成!"), combined_slices, predicted_slices, nifti_file_path
|
|
|
|
|
|
|
|
|
|
|
# Bundled example volumes, one gr.Examples row per file. Paths are built with
# os.path.join so they resolve on every OS; a literal "examples\\..." path
# only resolves on Windows.
_EXAMPLE_FILES = [
    "1.3.6.1.4.1.9328.50.4.0327.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0357.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0477.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0491.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0708.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0719.nii.gz",
]

EXAMPLES = [[os.path.join("examples", file_name)] for file_name in _EXAMPLE_FILES]

# File preloaded into the upload widget: the first example.
DEFAULT_EXAMPLE = EXAMPLES[0][0]
|
|
|
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...).
# - main-title, sub-title, custom-button and gallery-item are used by the UI
#   definition in this file.
# - fadeIn drives the title/subtitle entrance animation.
css = """
body {
    background-color: #f8fafc;
}

.container {
    max-width: 1200px;
    margin: 0 auto;
}

.main-title {
    text-align: center;
    color: #2563eb;
    font-size: 2.5rem;
    margin-bottom: 1rem;
    font-weight: bold;
    animation: fadeIn 1.5s ease-in-out;
}

.sub-title {
    text-align: center;
    color: #1e293b;
    margin-bottom: 2rem;
    animation: fadeIn 2s ease-in-out;
}

.custom-button {
    background-color: #2563eb !important;
    color: white !important;
    transition: transform 0.3s, box-shadow 0.3s;
}

.custom-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
}

.gallery-item {
    border-radius: 8px;
    overflow: hidden;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    transition: transform 0.3s;
}

.gallery-item:hover {
    transform: scale(1.02);
    box-shadow: 0 6px 12px rgba(0, 0, 0, 0.15);
}

@keyframes fadeIn {
    from { opacity: 0; transform: translateY(20px); }
    to { opacity: 1; transform: translateY(0); }
}
"""
|
|
|
|
|
|
|
|
# Load the model once at import time.
sam_model = build_model()

with gr.Blocks(title="3D医学影像智能分割系统", css=css) as demo:
    gr.HTML("<h1 class='main-title'>3D医学影像智能分割系统</h1>")
    gr.HTML("<p class='sub-title'>基于BoSAM的前沿人工智能自动分割技术,为医学影像分析提供高精度解决方案</p>")

    with gr.Row():
        # Left column: input selection, processing status and controls.
        with gr.Column(scale=1):
            gr.Markdown("### 上传/选择影像")
            input_file = gr.File(label="上传NIfTI文件", value=DEFAULT_EXAMPLE)

            status = gr.Textbox(label="处理状态", value="准备就绪")
            process_btn = gr.Button("开始智能分割", elem_classes=["custom-button"])

            gr.Markdown("### 示例数据")
            examples = gr.Examples(
                examples=EXAMPLES,
                inputs=[input_file]
            )

            gr.HTML("""
            <div style="margin-top: 2rem; padding: 1rem; background-color: rgba(16, 185, 129, 0.1); border-radius: 8px;">
                <h3 style="color: #10b981; margin-bottom: 0.5rem;">技术亮点</h3>
                <ul style="margin-left: 1.5rem;">
                    <li>基于最新的Segment Anything Model (SAM) 技术</li>
                    <li>专为3D医学影像优化的深度学习模型</li>
                    <li>智能识别解剖结构,无需手动绘制边界</li>
                    <li>高精度分割结果,提升诊断效率</li>
                </ul>
            </div>
            """)

        # Right column: result galleries and download.
        with gr.Column(scale=2):
            with gr.Row():
                gr.Markdown("## 原始医学影像")
            output_original = gr.Gallery(label="", show_label=False, columns=4, rows=8, height="600px", elem_classes=["gallery-item"])

            with gr.Row():
                with gr.Column():
                    gr.Markdown("## 分割叠加结果")
                    output_combined = gr.Gallery(label="", show_label=False, columns=4, rows=4, height="400px", elem_classes=["gallery-item"])

                with gr.Column():
                    gr.Markdown("## 分割掩码")
                    output_mask = gr.Gallery(label="", show_label=False, columns=4, rows=4, height="400px", elem_classes=["gallery-item"])

            gr.Markdown("## 分割结果下载")
            output_file = gr.File(label="下载完整3D分割结果 (NIFTI格式)")

    # Footer; the model name matches the subtitle and checkpoint ("BoSAM").
    gr.HTML("""
    <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid rgba(0, 0, 0, 0.1);">
        <p>© 2025 3D医学影像智能分割系统 | 人工智能辅助医学影像分析平台</p>
        <p>基于最新的BoSAM模型,为医疗影像分析提供高精度自动分割解决方案</p>
    </div>
    """)

    # All three triggers stream the same 5 outputs from the gr_interface generator.
    process_btn.click(
        fn=gr_interface,
        inputs=[input_file],
        outputs=[output_original, status, output_combined, output_mask, output_file]
    )

    # NOTE(review): gr.Examples registers its own click handler that loads the
    # example into input_file. This second handler on the same click may read
    # input_file before it is updated — verify, or consider
    # gr.Examples(fn=..., run_on_click=True).
    examples.dataset.click(
        fn=gr_interface,
        inputs=[input_file],
        outputs=[output_original, status, output_combined, output_mask, output_file]
    )

    # Runs a full segmentation of DEFAULT_EXAMPLE on every page load.
    demo.load(
        fn=gr_interface,
        inputs=[input_file],
        outputs=[output_original, status, output_combined, output_mask, output_file]
    )

if __name__ == "__main__":
    # NOTE(security): share=True publishes a public gradio.live URL.
    demo.launch(debug=True, share=True)