# fastSAM / sam_segment.py
# Uploaded by: robot2no1
# Commit message: Update sam_segment.py
# Commit: f37cd33 (verified)
import numpy as np
import cv2
from ultralytics import FastSAM
import torch
import gc
# Available models: maps a size name to the path of its checkpoint file.
MODELS = {
    "small": "./models/FastSAM-s.pt",
    "large": "./models/FastSAM-x.pt",
}
def clear_gpu_memory():
    """Free unreferenced memory held by Python and, if present, by CUDA.

    Always runs the Python garbage collector. When CUDA is available it
    also empties the PyTorch CUDA cache and collects CUDA IPC memory.
    """
    # Run the Python garbage collector first.
    gc.collect()

    # Nothing else to do on machines without CUDA.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()  # empty the PyTorch CUDA cache
    torch.cuda.ipc_collect()  # collect CUDA IPC memory
def get_model(model_size: str = "large"):
    """Load the FastSAM model registered under ``model_size``.

    Args:
        model_size: Key into ``MODELS`` selecting which checkpoint to load.

    Returns:
        The ``FastSAM`` instance built from the checkpoint path stored in
        ``MODELS[model_size]``.

    Raises:
        ValueError: If ``model_size`` is not a key of ``MODELS``.
        RuntimeError: If constructing the model raises; the underlying
            exception is attached as ``__cause__``.
    """
    if model_size not in MODELS:
        raise ValueError(f"Invalid model size. Available sizes: {list(MODELS.keys())}")
    try:
        return FastSAM(MODELS[model_size])
    except Exception as e:
        # Chain explicitly so the traceback shows the original load failure
        # as the direct cause of this error.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
def mask_to_points(mask: np.ndarray) -> list:
    """Convert a mask into the flattened outline of its largest region.

    Args:
        mask: 2D array (or array-like) representing the mask. It is cast to
            ``uint8``; entries that are non-zero after the cast are treated
            as foreground.

    Returns:
        list: Flattened contour points ``[x1, y1, x2, y2, ...]`` as floats
        for the external contour with the largest area, or ``[]`` when the
        mask contains no foreground.
    """
    # Cast to uint8 and scale so foreground pixels are non-zero for OpenCV.
    binary = np.asarray(mask).astype(np.uint8) * 255

    # An empty or all-background mask has no contours; return early instead
    # of handing it to OpenCV.
    if not binary.any():
        return []

    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []

    # Keep only the largest contour by area.
    largest = max(contours, key=cv2.contourArea)

    # Flatten the contour's (x, y) pairs into one flat list of Python floats,
    # x before y for each point.
    return largest.reshape(-1).astype(float).tolist()
def segment_image_with_prompt(
    image: np.ndarray,
    model_size: str = "large",
    conf: float = 0.4,
    iou: float = 0.9,
    bboxes: list = None,
    points: list = None,
    labels: list = None,
    texts: str = None,
):
    """Segment an image with FastSAM, optionally guided by prompts.

    Args:
        image: Input image (numpy.ndarray).
        model_size: Model size ("small" or "large").
        conf: Confidence threshold.
        iou: IoU threshold.
        bboxes: Bounding-box prompt [x1, y1, x2, y2, ...]; forwarded to the
            model unchanged when not None.
        points: Point prompts [[x1, y1], [x2, y2], ...]. Forwarded only when
            ``labels`` is also given.
        labels: Labels for the points [0, 1, ...]. Forwarded only when
            ``points`` is also given.
        texts: Text prompt; forwarded to the model when not None.

    Returns:
        dict: ``{"total_segments": int, "segments": list}`` where each
        segment is the flattened contour ``[x1, y1, x2, y2, ...]`` produced
        by ``mask_to_points`` for one mask. Masks that yield no contour are
        skipped. When the model returns no results or no masks, the result
        has zero segments.

    Raises:
        RuntimeError: On any failure (including a ``None`` image or an
            invalid ``model_size``); the original exception is attached as
            ``__cause__``.
    """
    try:
        if image is None:
            raise ValueError("Invalid image input")

        # Load the requested model.
        model = get_model(model_size)

        # Base inference arguments; inference is pinned to the CPU.
        model_args = {
            "device": "cpu",
            "retina_masks": True,
            "conf": conf,
            "iou": iou,
        }

        # Add whichever prompts the caller supplied.
        if bboxes is not None:
            model_args["bboxes"] = bboxes
        if points is not None and labels is not None:
            model_args["points"] = points
            model_args["labels"] = labels
        if texts is not None:
            model_args["texts"] = texts

        # Run segmentation.
        everything_results = model(image, **model_args)

        # Only the first result is used. Bind ``result`` unconditionally so
        # an empty result list produces zero segments instead of an
        # unbound-variable error during cleanup.
        result = everything_results[0] if everything_results else None
        result_masks = getattr(result, "masks", None)

        segments = []
        if result_masks is not None:
            for mask in result_masks.data.cpu().numpy():
                # Use a distinct name so the ``points`` argument is not shadowed.
                contour_points = mask_to_points(mask)
                if contour_points:
                    segments.append(contour_points)

        # Drop references to the model and raw results before returning.
        # NOTE(review): the success path does not call clear_gpu_memory();
        # that call was commented out in the original and is left disabled.
        del model, everything_results, result, result_masks

        return {
            "total_segments": len(segments),
            "segments": segments,
        }
    except Exception as e:
        # Make sure memory is cleaned up when an error occurs as well.
        clear_gpu_memory()
        raise RuntimeError(f"Error processing image: {str(e)}") from e