import numpy as np
import cv2
from ultralytics import FastSAM
import torch
import gc
# Registry of available FastSAM model checkpoints, keyed by size name.
MODELS = {
    "small": "./models/FastSAM-s.pt",
    "large": "./models/FastSAM-x.pt"
}
def clear_gpu_memory():
    """Release cached GPU memory held by PyTorch.

    Runs Python's garbage collector first so unreachable tensors are
    freed, then — only when CUDA is available — empties PyTorch's CUDA
    cache and collects CUDA IPC shared-memory handles.
    """
    # Free unreachable Python objects before touching the CUDA caches.
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()   # return cached allocator blocks to the driver
    torch.cuda.ipc_collect()   # reclaim CUDA IPC shared memory
def get_model(model_size: str = "large"):
    """Load the FastSAM model checkpoint for the requested size.

    Args:
        model_size: Model size name, one of the keys of ``MODELS``
            ("small" or "large").

    Returns:
        A loaded ``FastSAM`` model instance.

    Raises:
        ValueError: If ``model_size`` is not a known size.
        RuntimeError: If the checkpoint fails to load; the original
            exception is chained as the cause.
    """
    if model_size not in MODELS:
        raise ValueError(f"Invalid model size. Available sizes: {list(MODELS.keys())}")
    try:
        return FastSAM(MODELS[model_size])
    except Exception as e:
        # Chain the original exception so the root cause (missing file,
        # corrupt checkpoint, ...) is preserved in the traceback.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
def mask_to_points(mask: np.ndarray) -> list:
    """Convert a binary mask to the outline of its largest region.

    Args:
        mask: 2D numpy array representing the mask.

    Returns:
        list: Flattened contour coordinates [x1, y1, x2, y2, ...], or an
        empty list when no contour is found.
    """
    # Scale the 0/1 mask up to a 0/255 uint8 image for OpenCV.
    binary = mask.astype(np.uint8) * 255
    found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not found:
        return []
    largest = max(found, key=cv2.contourArea)
    # OpenCV contours have shape (N, 1, 2); interleave x/y as floats.
    return [float(coord) for pt in largest for coord in pt[0]]
def segment_image_with_prompt(
    image: np.ndarray,
    model_size: str = "large",
    conf: float = 0.4,
    iou: float = 0.9,
    bboxes: list = None,
    points: list = None,
    labels: list = None,
    texts: str = None
):
    """Segment an image with optional prompts (boxes, points, text).

    Args:
        image: Input image as a numpy array.
        model_size: Model size, "small" or "large".
        conf: Confidence threshold.
        iou: IoU threshold.
        bboxes: Optional bounding-box prompt [x1, y1, x2, y2, ...].
        points: Optional point prompts [[x1, y1], [x2, y2], ...].
        labels: Labels for the point prompts [0, 1, ...]; only used
            together with ``points``.
        texts: Optional text prompt.

    Returns:
        dict: ``{"total_segments": int, "segments": list}`` where each
        segment is a flattened contour point list from ``mask_to_points``.

    Raises:
        RuntimeError: On any failure (including a ``None`` image); the
            original exception is chained as the cause.
    """
    try:
        if image is None:
            raise ValueError("Invalid image input")

        model = get_model(model_size)

        # NOTE(review): device is hard-coded to "cpu" even though GPU
        # cleanup helpers exist in this module — confirm this is intended.
        model_args = {
            "device": "cpu",
            "retina_masks": True,
            "conf": conf,
            "iou": iou
        }

        # Forward whichever prompts were supplied; point prompts are only
        # passed when their labels are also present.
        if bboxes is not None:
            model_args["bboxes"] = bboxes
        if points is not None and labels is not None:
            model_args["points"] = points
            model_args["labels"] = labels
        if texts is not None:
            model_args["texts"] = texts

        everything_results = model(image, **model_args)

        # Collect each mask's largest contour as a flattened point list.
        # BUGFIX: the loop variable no longer shadows the `points` parameter,
        # and `result` is only deleted inside the branch where it exists
        # (previously a NameError when the model returned no results).
        segments = []
        if everything_results and len(everything_results) > 0:
            result = everything_results[0]
            if hasattr(result, 'masks') and result.masks is not None:
                masks = result.masks.data.cpu().numpy()
                for mask in masks:
                    contour_points = mask_to_points(mask)
                    if contour_points:
                        segments.append(contour_points)
            # Drop tensor references so memory can be reclaimed.
            del result

        del model
        del everything_results

        return {
            "total_segments": len(segments),
            "segments": segments
        }
    except Exception as e:
        # Ensure GPU memory is released even when inference fails, and
        # chain the original exception for easier debugging.
        clear_gpu_memory()
        raise RuntimeError(f"Error processing image: {str(e)}") from e