|
|
import numpy as np |
|
|
import cv2 |
|
|
from ultralytics import FastSAM |
|
|
import torch |
|
|
import gc |
|
|
|
|
|
|
|
|
# Available FastSAM checkpoints, keyed by model size.
MODELS = {
    "small": "./models/FastSAM-s.pt",
    "large": "./models/FastSAM-x.pt",
}
|
|
|
|
|
def clear_gpu_memory():
    """Release cached GPU memory held by PyTorch.

    Runs the Python garbage collector first so unreferenced tensors are
    actually freed, then (only when CUDA is available) returns cached
    allocator blocks to the driver and cleans up CUDA IPC handles.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
|
|
|
|
|
def get_model(model_size: str = "large"):
    """Load the FastSAM model for the requested size.

    Args:
        model_size: One of the keys of ``MODELS`` ("small" or "large").

    Returns:
        A freshly constructed ``FastSAM`` instance.

    Raises:
        ValueError: If ``model_size`` is not a known size.
        RuntimeError: If the checkpoint fails to load (original error chained).
    """
    if model_size not in MODELS:
        raise ValueError(f"Invalid model size. Available sizes: {list(MODELS.keys())}")

    try:
        return FastSAM(MODELS[model_size])
    except Exception as e:
        # Chain the original exception so the root cause stays visible
        # in tracebacks instead of being flattened into a message string.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
|
|
|
|
|
def mask_to_points(mask: np.ndarray) -> list:
    """Convert a mask to the contour points of its largest region.

    Args:
        mask: 2D numpy array representing the mask (non-zero = foreground).

    Returns:
        list: Flattened list of points [x1, y1, x2, y2, ...] for the
        largest external contour, or [] when no contour is found.
    """
    # OpenCV expects an 8-bit single-channel image.
    mask_uint8 = mask.astype(np.uint8) * 255

    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        return []

    # Keep only the largest connected region; smaller fragments are ignored.
    contour = max(contours, key=cv2.contourArea)

    # findContours returns shape (N, 1, 2); flatten to [x1, y1, ...] as
    # Python floats in one vectorized call instead of a per-point loop.
    return contour.reshape(-1).astype(float).tolist()
|
|
|
|
|
def segment_image_with_prompt(
    image: np.ndarray,
    model_size: str = "large",
    conf: float = 0.4,
    iou: float = 0.9,
    bboxes: list = None,
    points: list = None,
    labels: list = None,
    texts: str = None
):
    """Segment an image, optionally guided by box/point/text prompts.

    Args:
        image: Input image (numpy.ndarray).
        model_size: Model size ("small" or "large").
        conf: Confidence threshold.
        iou: IoU threshold.
        bboxes: Bounding boxes [x1, y1, x2, y2, ...].
        points: Prompt points [[x1, y1], [x2, y2], ...].
        labels: Labels for each prompt point [0, 1, ...].
        texts: Text prompt.

    Returns:
        dict: {"total_segments": int, "segments": list} where each segment
        is a flat point list as produced by ``mask_to_points``.

    Raises:
        RuntimeError: If the image is invalid or inference fails
        (original error chained as the cause).
    """
    try:
        if image is None:
            raise ValueError("Invalid image input")

        model = get_model(model_size)

        # NOTE(review): inference is pinned to CPU here; confirm this is
        # intentional before deploying on a GPU host.
        model_args = {
            "device": "cpu",
            "retina_masks": True,
            "conf": conf,
            "iou": iou
        }

        # Forward only the prompts the caller actually supplied; point
        # prompts are only meaningful together with their labels.
        if bboxes is not None:
            model_args["bboxes"] = bboxes
        if points is not None and labels is not None:
            model_args["points"] = points
            model_args["labels"] = labels
        if texts is not None:
            model_args["texts"] = texts

        everything_results = model(image, **model_args)

        segments = []
        if everything_results and len(everything_results) > 0:
            result = everything_results[0]
            if hasattr(result, 'masks') and result.masks is not None:
                masks = result.masks.data.cpu().numpy()
                for mask in masks:
                    # Use a distinct name: the original overwrote the
                    # `points` parameter here (shadowing bug).
                    contour_points = mask_to_points(mask)
                    if contour_points:
                        segments.append(contour_points)
            # Cleanup is scoped inside this branch so `result` is always
            # bound when referenced (the original raised NameError when
            # the model returned no results).
            if hasattr(result, 'masks'):
                del result.masks
            del result

        # Drop the large objects and reclaim memory before returning.
        del model
        del everything_results
        clear_gpu_memory()

        return {
            "total_segments": len(segments),
            "segments": segments
        }
    except Exception as e:
        # Best-effort cleanup, then surface the failure with its cause
        # chained so the original traceback is preserved.
        clear_gpu_memory()
        raise RuntimeError(f"Error processing image: {str(e)}") from e
|
|
|