import os
import tempfile

import numpy as np
import torch
import torch.nn.functional as F
import SimpleITK as sitk
import cv2  # noqa: F401 -- unused here; kept in case of external use (review before removing)
import gradio as gr
from PIL import Image, ImageDraw, ImageOps

from segment_anything.build_sam3D import sam_model_registry3D
from utils.click_method import (  # noqa: F401 -- _ritm variant kept for parity with original imports
    get_next_click3D_torch_2,
    get_next_click3D_torch_ritm,
)

# Single knob for the compute device (the original hard-coded 'cuda' at every call site).
DEVICE = 'cuda'


def build_model():
    """Build the BoSAM 3D-SAM model and load its checkpoint weights.

    Returns:
        The ``vit_b_ori`` 3D SAM model on ``DEVICE`` with the checkpoint's
        ``model_state_dict`` loaded.
    """
    # os.path.join instead of a hard-coded backslash path so the demo also runs on POSIX.
    checkpoint_path = os.path.join('ckpt', 'BoSAM.pth')
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    state_dict = checkpoint['model_state_dict']
    sam_model = sam_model_registry3D['vit_b_ori'](checkpoint=None).to(DEVICE)
    sam_model.load_state_dict(state_dict)
    return sam_model


def center_crop_or_pad(image_array, target_shape=(128, 128, 128)):
    """Center-crop or zero-pad a 3D array to ``target_shape``.

    Axes larger than the target are center-cropped; axes smaller are centered
    inside a zero background.  A per-axis mix of crop and pad is supported.
    When every axis is >= its target the function returns a *view* of
    ``image_array`` (no copy), matching the original behavior.
    """
    current_shape = image_array.shape
    # Source-side window: crop offsets for large axes, 0 for axes to be padded.
    start = [(c - t) // 2 if c > t else 0 for c, t in zip(current_shape, target_shape)]
    end = [s + t if c > t else c for s, t, c in zip(start, target_shape, current_shape)]
    # Destination-side window inside the zero-filled result.
    target_start = [0 if c > t else (t - c) // 2 for c, t in zip(current_shape, target_shape)]
    target_end = [t if c > t else ts + c
                  for ts, c, t in zip(target_start, current_shape, target_shape)]

    if all(c >= t for c, t in zip(current_shape, target_shape)):
        # Pure crop: slice directly (no allocation needed).
        return image_array[
            start[0]:start[0] + target_shape[0],
            start[1]:start[1] + target_shape[1],
            start[2]:start[2] + target_shape[2],
        ]

    # At least one axis needs padding: copy the overlapping region into zeros.
    result = np.zeros(target_shape, dtype=image_array.dtype)
    source_slices = tuple(
        slice(0 if c <= t else s, c if c <= t else e)
        for s, e, c, t in zip(start, end, current_shape, target_shape)
    )
    target_slices = tuple(slice(ts, te) for ts, te in zip(target_start, target_end))
    result[target_slices] = image_array[source_slices]
    return result


def preprocess_image(image_path):
    """Read a NIfTI volume and return a (1, 1, 128, 128, 128) float tensor on DEVICE."""
    image = sitk.ReadImage(image_path)
    image_array = sitk.GetArrayFromImage(image)  # SimpleITK arrays are (z, y, x) ordered
    processed_array = center_crop_or_pad(image_array, (128, 128, 128))
    image_tensor = torch.tensor(processed_array).float().unsqueeze(0).unsqueeze(0)
    return image_tensor.to(DEVICE)


def load_gt3d(image3d_path):
    """Load the demo ground-truth label as a (1, 1, 128, 128, 128) tensor on DEVICE.

    NOTE(review): ``image3d_path`` is ignored — the demo always uses one fixed
    label file to simulate user clicks.  The parameter is kept only for
    interface compatibility with existing callers.

    Raises:
        FileNotFoundError: if the fixed label file is missing.
    """
    gt3d_path = os.path.join('examples', 'labels', '1.3.6.1.4.1.9328.50.4.0357.nii.gz')
    if not os.path.exists(gt3d_path):
        raise FileNotFoundError(f"The file {gt3d_path} does not exist.")
    image = sitk.ReadImage(gt3d_path)
    image_array = sitk.GetArrayFromImage(image)
    processed_array = center_crop_or_pad(image_array, (128, 128, 128))
    gt_tensor = torch.tensor(processed_array).float().unsqueeze(0).unsqueeze(0)
    return gt_tensor.to(DEVICE)


def overlay_mask_on_image(image_slice, mask_slice, alpha=0.6):
    """Overlay a binary mask on a 2D image slice with enhanced contrast.

    Args:
        image_slice: 2D grayscale array.
        mask_slice: 2D array; non-zero entries are treated as foreground.
        alpha: opacity of the mask color.

    Returns:
        An RGBA PIL image of the slice with the mask blended in blue.
    """
    # Percentile-based contrast stretch; guard the flat-slice case (p98 == p2)
    # which previously divided by zero (same guard as normalize_image).
    p2, p98 = np.percentile(image_slice, (2, 98))
    if p98 - p2 != 0:
        image_contrast = np.clip((image_slice - p2) / (p98 - p2), 0, 1)
    else:
        image_contrast = np.zeros_like(image_slice, dtype=float)
    image_contrast = (image_contrast * 255).astype(np.uint8)

    # Grayscale -> RGB, mild autocontrast sharpening, then RGBA for compositing.
    image_rgb = Image.fromarray(image_contrast).convert("RGB")
    enhancer = ImageOps.autocontrast(image_rgb, cutoff=0.5)
    image_rgba = enhancer.convert("RGBA")

    # Draw the mask as a high-saturation blue bitmap on a transparent layer.
    mask_image = Image.new('RGBA', image_rgba.size, (0, 0, 0, 0))
    mask_draw = ImageDraw.Draw(mask_image)
    mask = (mask_slice > 0).astype(np.uint8) * 255
    mask_pil = Image.fromarray(mask, mode='L')
    mask_draw.bitmap((0, 0), mask_pil, fill=(41, 128, 255, int(255 * alpha)))

    return Image.alpha_composite(image_rgba, mask_image)


def predict(image3D, sam_model, points=None, prev_masks=None, num_clicks=5):
    """Segment ``image3D`` with SAM using iterative simulated clicks.

    Clicks are generated against a *fixed* ground-truth label (see
    ``load_gt3d``) — a demo-only shortcut, not real user interaction.

    Args:
        image3D: (1, 1, 128, 128, 128) tensor on DEVICE.
        sam_model: model returned by ``build_model``.
        points: unused; kept for interface compatibility.
        prev_masks: optional initial mask tensor shaped like ``image3D``.
        num_clicks: number of click-refinement iterations.

    Returns:
        Tuple ``(mask, prob)``: uint8 binary mask and float probability map,
        both squeezed to (128, 128, 128).
    """
    sam_model.eval()
    # Z-score normalization; guard the degenerate constant volume (std == 0 -> NaNs).
    std = image3D.std()
    if std == 0:
        std = 1.0
    image3D = (image3D - image3D.mean()) / std
    gt3D = load_gt3d(None)
    if prev_masks is None:
        prev_masks = torch.zeros_like(image3D).to(DEVICE)
    low_res_masks = F.interpolate(prev_masks.float(), size=(32, 32, 32))

    with torch.no_grad():
        image_embedding = sam_model.image_encoder(image3D)

    for _ in range(num_clicks):
        with torch.no_grad():
            # Next simulated click from the discrepancy between prediction and GT.
            batch_points, batch_labels = get_next_click3D_torch_2(
                prev_masks.to(DEVICE), gt3D.to(DEVICE))
            points_co = torch.cat(batch_points, dim=0).to(DEVICE)
            points_la = torch.cat(batch_labels, dim=0).to(DEVICE)
            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=[points_co, points_la],
                boxes=None,
                masks=low_res_masks.to(DEVICE),
            )
            low_res_masks, iou_predictions = sam_model.mask_decoder(
                image_embeddings=image_embedding.to(DEVICE),
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            # Upsample the low-res logits back to full volume resolution.
            prev_masks = F.interpolate(
                low_res_masks, size=[128, 128, 128],
                mode='trilinear', align_corners=False)

    medsam_seg_prob = torch.sigmoid(prev_masks)
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
    return medsam_seg, medsam_seg_prob


def normalize_image(image):
    """Contrast-stretch a 2D slice to uint8 [0, 255] using the 2nd/98th percentiles."""
    p2, p98 = np.percentile(image, (2, 98))
    if p98 - p2 != 0:
        image = np.clip((image - p2) / (p98 - p2), 0, 1)
    else:
        # Flat slice: nothing to stretch, return all zeros.
        image = np.zeros_like(image)
    return (image * 255).astype(np.uint8)


def predicts(img_path, sam_model):
    """Convenience wrapper: preprocess ``img_path`` and run ``predict`` on it."""
    img = preprocess_image(img_path)
    prediction, prediction_prob = predict(img, sam_model)
    return prediction, prediction_prob


def save_nifti(prediction, original_image_path):
    """Save a 128^3 prediction as NIfTI, carrying over geometry from the source image.

    Spacing is rescaled so the 128^3 volume covers the same physical extent as
    the original image.  Returns the path of the written temporary file
    (caller is responsible for eventual deletion).
    """
    original_image = sitk.ReadImage(original_image_path)
    output_image = sitk.GetImageFromArray(prediction.astype(np.uint8))
    output_image.SetDirection(original_image.GetDirection())
    output_image.SetOrigin(original_image.GetOrigin())
    original_size = original_image.GetSize()
    original_spacing = original_image.GetSpacing()
    new_spacing = [
        original_spacing[0] * (original_size[0] / 128),
        original_spacing[1] * (original_size[1] / 128),
        original_spacing[2] * (original_size[2] / 128),
    ]
    output_image.SetSpacing(new_spacing)

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz")
    temp_filename = temp_file.name
    # Close our handle before SimpleITK opens the path for writing — an open
    # handle blocks the write on Windows.
    temp_file.close()
    sitk.WriteImage(output_image, temp_filename)
    return temp_filename


def gr_interface(img_path, sam_model=None):
    """Gradio generator callback: yields progress updates, then final results.

    Each yield is a 5-tuple: (input slice gallery, status text update,
    overlay gallery, mask gallery, downloadable NIfTI path).
    """
    if sam_model is None:
        sam_model = build_model()

    # Progress updates (user-facing strings preserved verbatim).
    yield None, gr.update(value="正在加载数据..."), None, None, None
    processed_img = preprocess_image(img_path)
    yield None, gr.update(value="正在分割..."), None, None, None
    prediction, prediction_prob = predicts(img_path, sam_model)
    yield None, gr.update(value="正在生成可视化..."), None, None, None

    processed_slices = []
    combined_slices = []
    predicted_slices = []
    nifti_file_path = save_nifti(prediction, img_path)

    # Visualize only the central 32 axial slices.
    start_idx = (128 - 32) // 2  # 48
    end_idx = start_idx + 32     # 80
    for i in range(start_idx, end_idx):
        # Raw input slice, contrast-normalized for display.
        processed_slice = processed_img[0, 0, i].cpu().numpy()
        processed_slices.append(normalize_image(processed_slice))
        # Predicted mask slice.
        mask_slice = prediction[i]
        normalized_mask = normalize_image(mask_slice)
        # Mask blended over the normalized input slice.
        combined_image = overlay_mask_on_image(processed_slices[-1], mask_slice)
        combined_slices.append(combined_image)
        predicted_slices.append(normalized_mask)

    yield (processed_slices, gr.update(value="分割完成!"), combined_slices,
           predicted_slices, nifti_file_path)


# Example data paths (os.path.join replaces the Windows-only backslash literals).
_EXAMPLE_DIR = 'examples'
DEFAULT_EXAMPLE = os.path.join(_EXAMPLE_DIR, '1.3.6.1.4.1.9328.50.4.0327.nii.gz')
EXAMPLES = [[os.path.join(_EXAMPLE_DIR, name)] for name in (
    '1.3.6.1.4.1.9328.50.4.0327.nii.gz',
    '1.3.6.1.4.1.9328.50.4.0357.nii.gz',
    '1.3.6.1.4.1.9328.50.4.0477.nii.gz',
    '1.3.6.1.4.1.9328.50.4.0491.nii.gz',
    '1.3.6.1.4.1.9328.50.4.0708.nii.gz',
    '1.3.6.1.4.1.9328.50.4.0719.nii.gz',
)]

# Custom CSS to style the interface.
css = """
body { background-color: #f8fafc; }
.container { max-width: 1200px; margin: 0 auto; }
.main-title { text-align: center; color: #2563eb; font-size: 2.5rem; margin-bottom: 1rem; font-weight: bold; animation: fadeIn 1.5s ease-in-out; }
.sub-title { text-align: center; color: #1e293b; margin-bottom: 2rem; animation: fadeIn 2s ease-in-out; }
.custom-button { background-color: #2563eb !important; color: white !important; transition: transform 0.3s, box-shadow 0.3s; }
.custom-button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); }
.gallery-item { border-radius: 8px; overflow: hidden; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); transition: transform 0.3s; }
.gallery-item:hover { transform: scale(1.02); box-shadow: 0 6px 12px rgba(0, 0, 0, 0.15); }
@keyframes fadeIn { from { opacity: 0; transform: translateY(20px); } to { opacity: 1; transform: translateY(0); } }
"""

# Load the model once at import time (module-level global, as in the original).
sam_model = build_model()

# Build the Gradio UI.
with gr.Blocks(title="3D医学影像智能分割系统", css=css) as demo:
    # NOTE(review): the original HTML markup inside this banner may have been
    # lost to line-wrapping; only the visible text is reproduced.
    gr.HTML("基于BoSAM的前沿人工智能自动分割技术,为医学影像分析提供高精度解决方案")
    with gr.Row():
        with gr.Column(scale=1):
            # Input area
            gr.Markdown("### 上传/选择影像")
            input_file = gr.File(label="上传NIfTI文件", value=DEFAULT_EXAMPLE)
            status = gr.Textbox(label="处理状态", value="准备就绪")
            process_btn = gr.Button("开始智能分割", elem_classes=["custom-button"])
            # Examples area
            gr.Markdown("### 示例数据")
            examples = gr.Examples(
                examples=EXAMPLES,
                inputs=[input_file]
            )
    gr.HTML("""© 2025 3D医学影像智能分割系统 | 人工智能辅助医学影像分析平台
基于最新的BoaSAM模型,为医疗影像分析提供高精度自动分割解决方案""")
    # NOTE(review): the source appears truncated here — the output galleries,
    # the process_btn.click(...) wiring to gr_interface, and demo.launch()
    # are not visible in this chunk and have not been reconstructed.