# BoSAM / app.py
# Author: ziyanlu
# Commit 9859ea2 (verified): "Upload folder using huggingface_hub"
import os
import numpy as np
import torch
import SimpleITK as sitk
import torch.nn.functional as F
import cv2
from PIL import Image, ImageDraw, ImageOps
import tempfile
import gradio as gr
from segment_anything.build_sam3D import sam_model_registry3D
from utils.click_method import get_next_click3D_torch_ritm, get_next_click3D_torch_2
def build_model(checkpoint_path=os.path.join("ckpt", "BoSAM.pth"), device="cuda"):
    """Build the BoSAM 3D SAM model and load its weights from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint whose ``'model_state_dict'``
            entry holds the model weights. Defaults to ``ckpt/BoSAM.pth``,
            joined with the platform's path separator.
        device: Torch device the checkpoint is mapped to and the model is
            moved to. Defaults to ``'cuda'``, matching the rest of this app.

    Returns:
        The ``'vit_b_ori'`` model from ``sam_model_registry3D`` with the
        checkpoint weights loaded.
    """
    # The path is built with os.path.join: a backslash-separated literal only
    # resolves on Windows (on POSIX the backslash is part of the file name).
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint["model_state_dict"]
    sam_model = sam_model_registry3D["vit_b_ori"](checkpoint=None).to(device)
    sam_model.load_state_dict(state_dict)
    return sam_model
def center_crop_or_pad(image_array, target_shape=(128, 128, 128)):
    """Center-crop and/or zero-pad an array to ``target_shape``.

    Each axis is handled independently. An axis longer than its target is
    cropped around its center. An axis shorter than (or equal to) its target
    is copied into the center of a zero-filled output. When the size
    difference is odd, the extra voxel falls at the end of the axis. Arrays
    of any rank are supported; axes are paired with ``target_shape``
    positionally.

    Args:
        image_array: NumPy array to crop/pad.
        target_shape: Desired output shape.

    Returns:
        If every paired axis is at least as long as its target, a view of
        ``image_array`` (no copy). Otherwise, a new array of ``target_shape``
        with ``image_array``'s dtype and zeros outside the copied region.
    """
    source_slices = []
    target_slices = []
    for size, target in zip(image_array.shape, target_shape):
        if size > target:
            # Crop: keep ``target`` elements centred in the source axis.
            start = (size - target) // 2
            source_slices.append(slice(start, start + target))
            target_slices.append(slice(0, target))
        else:
            # Pad: copy the whole source axis into the centre of the output.
            offset = (target - size) // 2
            source_slices.append(slice(0, size))
            target_slices.append(slice(offset, offset + size))
    source_slices = tuple(source_slices)

    if all(size >= target for size, target in zip(image_array.shape, target_shape)):
        # Pure crop: a plain slice of the input is already the answer.
        return image_array[source_slices]

    result = np.zeros(target_shape, dtype=image_array.dtype)
    result[tuple(target_slices)] = image_array[source_slices]
    return result
def preprocess_image(image_path):
    """Read a volume and return it as a 128x128x128 float tensor on CUDA.

    Args:
        image_path: Path to an image SimpleITK can read (e.g. a NIfTI file).

    Returns:
        A float32 tensor of shape (1, 1, 128, 128, 128) on ``'cuda'``. The
        spatial axes are in SimpleITK's NumPy order ((z, y, x) for a 3D
        scalar volume), center-cropped/padded by ``center_crop_or_pad``.
    """
    image = sitk.ReadImage(image_path)
    image_array = sitk.GetArrayFromImage(image)
    processed_array = center_crop_or_pad(image_array, (128, 128, 128))
    # Cast in NumPy first: older PyTorch releases cannot build tensors from
    # some NumPy dtypes (e.g. uint16) that medical images may be stored in.
    volume = np.ascontiguousarray(processed_array, dtype=np.float32)
    image_tensor = torch.from_numpy(volume)[None, None]
    return image_tensor.to("cuda")
def load_gt3d(image3d_path):
    """Load the ground-truth label volume used to simulate clicks.

    The label is read from one fixed example case and center-cropped/padded
    to 128x128x128, in the same way as ``preprocess_image``.

    NOTE(review): ``image3d_path`` is ignored (``predict`` passes None). The
    same label is returned for every input volume, so clicks for other cases
    are simulated from an unrelated ground truth.

    Args:
        image3d_path: Unused; kept for interface compatibility.

    Returns:
        A float32 tensor of shape (1, 1, 128, 128, 128) on ``'cuda'``.

    Raises:
        FileNotFoundError: If the fixed label file does not exist.
    """
    # Built from components so it resolves on POSIX as well; a
    # backslash-separated literal only works on Windows.
    gt3d_path = os.path.join("examples", "labels", "1.3.6.1.4.1.9328.50.4.0357.nii.gz")
    if not os.path.exists(gt3d_path):
        raise FileNotFoundError(f"The file {gt3d_path} does not exist.")
    label_array = sitk.GetArrayFromImage(sitk.ReadImage(gt3d_path))
    processed_array = center_crop_or_pad(label_array, (128, 128, 128))
    # Cast in NumPy first: older PyTorch releases cannot wrap some integer
    # dtypes (e.g. uint16) that label files may use.
    volume = np.ascontiguousarray(processed_array, dtype=np.float32)
    gt_tensor = torch.from_numpy(volume)[None, None]
    return gt_tensor.to("cuda")
def overlay_mask_on_image(image_slice, mask_slice, alpha=0.6):
    """Overlay a binary mask on a 2D slice as a translucent blue layer.

    Args:
        image_slice: 2D intensity array. Its contrast is stretched between
            the 2nd and 98th percentiles, then PIL autocontrast is applied.
        mask_slice: 2D array with the same shape as ``image_slice``; pixels
            with value > 0 are painted.
        alpha: Opacity of the mask colour, in [0, 1].

    Returns:
        A PIL ``RGBA`` image of the slice with the mask composited on top.
    """
    p2, p98 = np.percentile(image_slice, (2, 98))
    if p98 > p2:
        image_contrast = np.clip((image_slice - p2) / (p98 - p2), 0, 1)
    else:
        # Constant slice (e.g. zero padding): avoid 0/0 -> NaN, whose cast
        # to uint8 is undefined.
        image_contrast = np.zeros(np.shape(image_slice))
    image_contrast = (image_contrast * 255).astype(np.uint8)
    image_rgb = Image.fromarray(image_contrast).convert("RGB")
    # Extra autocontrast, clipping 0.5% of the histogram at each end.
    image_rgba = ImageOps.autocontrast(image_rgb, cutoff=0.5).convert("RGBA")

    mask_layer = Image.new("RGBA", image_rgba.size, (0, 0, 0, 0))
    mask = (np.asarray(mask_slice) > 0).astype(np.uint8) * 255
    # A 2D uint8 array is inferred as mode "L"; passing mode= explicitly is
    # deprecated in recent Pillow releases.
    mask_pil = Image.fromarray(mask)
    # Saturated blue, with alpha controlling opacity.
    ImageDraw.Draw(mask_layer).bitmap((0, 0), mask_pil, fill=(41, 128, 255, int(255 * alpha)))
    return Image.alpha_composite(image_rgba, mask_layer)
def predict(image3D, sam_model, points=None, prev_masks=None, num_clicks=5):
    """Segment a volume with the SAM model, refining over simulated clicks.

    The volume is z-score normalized (on a new tensor; the caller's tensor is
    untouched) and encoded once. It is then refined for ``num_clicks``
    rounds. Each round asks ``get_next_click3D_torch_2`` for a click, based
    on the current prediction and the fixed ground truth from ``load_gt3d``,
    and decodes a new mask.

    Args:
        image3D: Raw intensity tensor, presumably (1, 1, 128, 128, 128) as
            returned by ``preprocess_image`` (TODO confirm for other
            callers). Must be on the same device as ``sam_model``.
        sam_model: Model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder``; switched to eval mode here.
        points: Unused; kept for interface compatibility.
        prev_masks: Optional initial mask logits shaped like ``image3D``;
            zeros when None.
        num_clicks: Number of click/decode rounds.

    Returns:
        ``(medsam_seg, medsam_seg_prob)``: the uint8 mask
        (probability > 0.5) and the sigmoid probabilities, both as squeezed
        NumPy arrays.
    """
    device = image3D.device
    sam_model.eval()

    # z-score normalization. A constant volume has std 0, so only divide
    # when that cannot produce NaN/inf.
    std = image3D.std()
    image3D = image3D - image3D.mean()
    if std > 0:
        image3D = image3D / std

    gt3D = load_gt3d(None).to(device)
    if prev_masks is None:
        prev_masks = torch.zeros_like(image3D)
    prev_masks = prev_masks.to(device)
    # The mask prompt is fed to the prompt encoder at 32^3.
    low_res_masks = F.interpolate(prev_masks.float(), size=(32, 32, 32))

    with torch.no_grad():
        image_embedding = sam_model.image_encoder(image3D)
        for _ in range(num_clicks):
            batch_points, batch_labels = get_next_click3D_torch_2(prev_masks, gt3D)
            points_co = torch.cat(batch_points, dim=0).to(device)
            points_la = torch.cat(batch_labels, dim=0).to(device)
            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=[points_co, points_la],
                boxes=None,
                masks=low_res_masks,
            )
            low_res_masks, _ = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            # Upsample the decoder logits back to the 128^3 working grid.
            prev_masks = F.interpolate(
                low_res_masks, size=[128, 128, 128], mode="trilinear", align_corners=False
            )

    medsam_seg_prob = torch.sigmoid(prev_masks).cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
    return medsam_seg, medsam_seg_prob
def normalize_image(image):
    """Map an array to uint8 [0, 255] using a 2nd-98th percentile window.

    Values at or below the 2nd percentile become 0, and values at or above
    the 98th become 255. In between, the mapping is linear and the uint8
    cast truncates. If the two percentiles coincide, the result is all
    zeros.
    """
    low, high = np.percentile(image, (2, 98))
    span = high - low
    if span == 0:
        scaled = np.zeros_like(image)
    else:
        scaled = np.clip((image - low) / span, 0, 1)
    return (scaled * 255).astype(np.uint8)
def predicts(img_path, sam_model):
    """Load ``img_path`` with ``preprocess_image`` and segment it with ``predict``.

    Returns:
        The ``(mask, probability map)`` pair produced by ``predict``.
    """
    return predict(preprocess_image(img_path), sam_model)
def _paste_into_shape(prediction, original_shape):
    """Undo ``center_crop_or_pad``: place ``prediction`` back into a zero array of ``original_shape``."""
    restored = np.zeros(original_shape, dtype=np.uint8)
    dst_slices = []
    src_slices = []
    for size, target in zip(original_shape, prediction.shape):
        if size > target:
            # The axis was cropped: the prediction covers a centred window.
            start = (size - target) // 2
            dst_slices.append(slice(start, start + target))
            src_slices.append(slice(0, target))
        else:
            # The axis was padded: drop the padding around the real data.
            offset = (target - size) // 2
            dst_slices.append(slice(0, size))
            src_slices.append(slice(offset, offset + size))
    restored[tuple(dst_slices)] = prediction[tuple(src_slices)]
    return restored


def save_nifti(prediction, original_image_path):
    """Save a prediction as a NIfTI mask aligned with the original image.

    The 128^3 prediction was produced from a center crop/pad (not a
    resample) of the input. It is therefore pasted back onto the input's
    voxel grid, and the original origin, spacing and direction are copied,
    so the mask overlays the source volume exactly.

    Args:
        prediction: 3D array in SimpleITK NumPy order ((z, y, x)), as
            returned by ``predict``; nonzero values are kept as uint8.
        original_image_path: Path to the image the prediction was made from.

    Returns:
        Path of a new ``.nii.gz`` temporary file. The caller owns it; it is
        not deleted automatically.
    """
    original_image = sitk.ReadImage(original_image_path)
    # GetSize() is (x, y, z); arrays from SimpleITK are indexed (z, y, x).
    original_shape = tuple(reversed(original_image.GetSize()))
    restored = _paste_into_shape(np.asarray(prediction), original_shape)
    output_image = sitk.GetImageFromArray(restored)
    # Same grid as the input: copies origin, spacing and direction.
    output_image.CopyInformation(original_image)
    # mkstemp + close avoids leaving an open handle (NamedTemporaryFile with
    # delete=False was never closed, and on Windows an open file can block
    # the write).
    fd, temp_filename = tempfile.mkstemp(suffix=".nii.gz")
    os.close(fd)
    sitk.WriteImage(output_image, temp_filename)
    return temp_filename
def gr_interface(img_path, sam_model=None):
    """Gradio generator: segment a volume and stream UI updates.

    Yields 5-tuples in the order of the wired outputs: (original slices,
    status update, overlay slices, mask slices, NIfTI path). Intermediate
    yields only update the status text. The final yield shows up to 32
    central axial slices.

    Args:
        img_path: Path of the uploaded file. Objects carrying the path in a
            ``.name`` attribute (temp-file wrappers used by some Gradio
            versions) are also accepted.
        sam_model: Model to use. When None, the module-level ``sam_model``
            is reused if it exists; otherwise a model is built.
    """
    if sam_model is None:
        # Gradio calls this with only the file input, so sam_model is always
        # None there. Reuse the model loaded at import time instead of
        # reloading the checkpoint on every request.
        sam_model = globals().get("sam_model")
        if sam_model is None:
            sam_model = build_model()
    img_path = getattr(img_path, "name", img_path)
    if img_path is None:
        yield None, gr.update(value="请先上传NIfTI文件"), None, None, None
        return

    yield None, gr.update(value="正在加载数据..."), None, None, None
    processed_img = preprocess_image(img_path)
    yield None, gr.update(value="正在分割..."), None, None, None
    # Reuse the tensor loaded above rather than reading the file again;
    # predict() normalizes into a new tensor, so processed_img keeps raw
    # intensities for display.
    prediction, _ = predict(processed_img, sam_model)
    yield None, gr.update(value="正在生成可视化..."), None, None, None
    nifti_file_path = save_nifti(prediction, img_path)

    # Move the volume to the CPU once instead of once per slice.
    volume = processed_img[0, 0].cpu().numpy()
    depth = volume.shape[0]
    num_slices = min(32, depth)
    start_idx = (depth - num_slices) // 2  # 48 for a 128-slice volume

    processed_slices = []
    combined_slices = []
    predicted_slices = []
    for i in range(start_idx, start_idx + num_slices):
        image_slice = normalize_image(volume[i])
        mask_slice = prediction[i]
        processed_slices.append(image_slice)
        combined_slices.append(overlay_mask_on_image(image_slice, mask_slice))
        # Render the binary mask directly as 0/255. Percentile stretching
        # blanks masks covering less than ~2% (or more than ~98%) of a slice.
        predicted_slices.append((mask_slice > 0).astype(np.uint8) * 255)
    yield processed_slices, gr.update(value="分割完成!"), combined_slices, predicted_slices, nifti_file_path
# Example volumes shipped with the app. Paths are built with os.path.join so
# they resolve on POSIX hosts as well as on Windows.
DEFAULT_EXAMPLE = os.path.join("examples", "1.3.6.1.4.1.9328.50.4.0327.nii.gz")
EXAMPLES = [
    [os.path.join("examples", f"1.3.6.1.4.1.9328.50.4.{case_id}.nii.gz")]
    for case_id in ("0327", "0357", "0477", "0491", "0708", "0719")
]
# Custom CSS passed to gr.Blocks. Classes referenced by the UI code in this
# file: main-title / sub-title (the header HTML), custom-button (the run
# button) and gallery-item (the three result galleries). The fadeIn keyframes
# animate the two titles.
# NOTE(review): `.container` is not referenced explicitly in this file;
# presumably it targets a Gradio-generated container. Confirm.
css = """
body {
background-color: #f8fafc;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.main-title {
text-align: center;
color: #2563eb;
font-size: 2.5rem;
margin-bottom: 1rem;
font-weight: bold;
animation: fadeIn 1.5s ease-in-out;
}
.sub-title {
text-align: center;
color: #1e293b;
margin-bottom: 2rem;
animation: fadeIn 2s ease-in-out;
}
.custom-button {
background-color: #2563eb !important;
color: white !important;
transition: transform 0.3s, box-shadow 0.3s;
}
.custom-button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
}
.gallery-item {
border-radius: 8px;
overflow: hidden;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
transition: transform 0.3s;
}
.gallery-item:hover {
transform: scale(1.02);
box-shadow: 0 6px 12px rgba(0, 0, 0, 0.15);
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
"""
# Load the model once at import time, as a module-level global.
# NOTE(review): the Gradio events pass only the file input, so gr_interface
# receives sam_model=None. It uses this object only if it looks the global
# up itself; confirm the checkpoint is not reloaded per request.
sam_model = build_model()
# Build the Gradio interface: inputs on the left, result galleries on the
# right, footer below; all three triggers run gr_interface.
with gr.Blocks(title="3D医学影像智能分割系统", css=css) as demo:
    gr.HTML("<h1 class='main-title'>3D医学影像智能分割系统</h1>")
    gr.HTML("<p class='sub-title'>基于BoSAM的前沿人工智能自动分割技术,为医学影像分析提供高精度解决方案</p>")
    with gr.Row():
        with gr.Column(scale=1):
            # Input area: file upload, status line and the run button.
            gr.Markdown("### 上传/选择影像")
            input_file = gr.File(label="上传NIfTI文件", value=DEFAULT_EXAMPLE)
            status = gr.Textbox(label="处理状态", value="准备就绪")
            process_btn = gr.Button("开始智能分割", elem_classes=["custom-button"])
            # Example volumes; selecting one fills input_file.
            gr.Markdown("### 示例数据")
            examples = gr.Examples(
                examples=EXAMPLES,
                inputs=[input_file],
            )
            gr.HTML("""
<div style="margin-top: 2rem; padding: 1rem; background-color: rgba(16, 185, 129, 0.1); border-radius: 8px;">
<h3 style="color: #10b981; margin-bottom: 0.5rem;">技术亮点</h3>
<ul style="margin-left: 1.5rem;">
<li>基于最新的Segment Anything Model (SAM) 技术</li>
<li>专为3D医学影像优化的深度学习模型</li>
<li>智能识别解剖结构,无需手动绘制边界</li>
<li>高精度分割结果,提升诊断效率</li>
</ul>
</div>
""")
        with gr.Column(scale=2):
            # Output area: slice galleries and the downloadable NIfTI mask.
            with gr.Row():
                gr.Markdown("## 原始医学影像")
            output_original = gr.Gallery(label="", show_label=False, columns=4, rows=8, height="600px", elem_classes=["gallery-item"])
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## 分割叠加结果")
                    output_combined = gr.Gallery(label="", show_label=False, columns=4, rows=4, height="400px", elem_classes=["gallery-item"])
                with gr.Column():
                    gr.Markdown("## 分割掩码")
                    output_mask = gr.Gallery(label="", show_label=False, columns=4, rows=4, height="400px", elem_classes=["gallery-item"])
            gr.Markdown("## 分割结果下载")
            output_file = gr.File(label="下载完整3D分割结果 (NIFTI格式)")
    gr.HTML("""
<div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid rgba(0, 0, 0, 0.1);">
<p>© 2025 3D医学影像智能分割系统 | 人工智能辅助医学影像分析平台</p>
<p>基于最新的BoSAM模型,为医疗影像分析提供高精度自动分割解决方案</p>
</div>
""")

    # All three triggers use the same handler and output wiring; the output
    # order matches the tuples yielded by gr_interface.
    segmentation_io = dict(
        fn=gr_interface,
        inputs=[input_file],
        outputs=[output_original, status, output_combined, output_mask, output_file],
    )
    process_btn.click(**segmentation_io)
    # NOTE(review): gr.Examples also handles dataset.click to fill
    # input_file. Confirm this handler receives the newly selected file
    # rather than the previous one; gr.Examples(fn=..., outputs=...,
    # run_on_click=True) may be the supported way to chain the two.
    examples.dataset.click(**segmentation_io)
    # Segment the default example as soon as the page loads.
    demo.load(**segmentation_io)
if __name__ == "__main__":
    # Launch the Gradio server. share=True requests a public share link, and
    # debug=True is Gradio's debug mode (the launch call blocks).
    launch_options = {"share": True, "debug": True}
    demo.launch(**launch_options)