File size: 3,926 Bytes
faeea0b
 
 
 
 
 
 
 
f37cd33
 
faeea0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f37cd33
faeea0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f37cd33
faeea0b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import numpy as np
import cv2
from ultralytics import FastSAM
import torch
import gc

# Registry of available FastSAM model checkpoints, keyed by size.
# "small" trades accuracy for speed; "large" is the default elsewhere
# in this module (see get_model / segment_image_with_prompt).
MODELS = {
    "small": "./models/FastSAM-s.pt",
    "large": "./models/FastSAM-x.pt"
}

def clear_gpu_memory():
    """
    Run Python garbage collection, then release cached CUDA memory.

    Safe to call on CPU-only hosts: the CUDA calls are skipped when no
    GPU is available.
    """
    gc.collect()  # drop unreachable Python objects first
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()  # return cached blocks to the driver
    torch.cuda.ipc_collect()  # reclaim CUDA IPC shared memory

def get_model(model_size: str = "large"):
    """
    Load the FastSAM model of the requested size.

    Args:
        model_size: Key into MODELS, either "small" or "large".

    Returns:
        A FastSAM model instance loaded from the matching checkpoint.

    Raises:
        ValueError: If model_size is not a known key in MODELS.
        RuntimeError: If loading the checkpoint fails (original error
            attached via exception chaining).
    """
    if model_size not in MODELS:
        raise ValueError(f"Invalid model size. Available sizes: {list(MODELS.keys())}")

    try:
        return FastSAM(MODELS[model_size])
    except Exception as e:
        # Chain the original exception so the root cause survives in
        # the traceback instead of being flattened into a string.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e

def mask_to_points(mask: np.ndarray) -> list:
    """
    Extract the outline of the largest region in a binary mask.

    Args:
        mask: 2D numpy array representing the mask (nonzero = foreground).

    Returns:
        list: Flattened contour coordinates [x1, y1, x2, y2, ...];
        empty list when the mask contains no contour.
    """
    # OpenCV expects an 8-bit single-channel image.
    binary = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        return []

    # Keep only the contour enclosing the largest area.
    largest = max(contours, key=cv2.contourArea)

    # Contours come back as (N, 1, 2); flatten to (N, 2) then interleave.
    flat = []
    for x, y in largest.reshape(-1, 2):
        flat.extend([float(x), float(y)])
    return flat

def segment_image_with_prompt(
    image: np.ndarray,
    model_size: str = "large",
    conf: float = 0.4,
    iou: float = 0.9,
    bboxes: list = None,
    points: list = None,
    labels: list = None,
    texts: str = None
):
    """
    Segment an image with optional prompts (boxes, points, or text).

    Args:
        image: Input image (numpy.ndarray).
        model_size: Model size, "small" or "large".
        conf: Confidence threshold.
        iou: IoU threshold.
        bboxes: Bounding-box prompt [x1, y1, x2, y2, ...].
        points: Point prompts [[x1, y1], [x2, y2], ...].
        labels: Labels for the points [0, 1, ...] (used only together
            with `points`).
        texts: Text prompt.

    Returns:
        dict: {"total_segments": int, "segments": list of flattened
        contour point lists}.

    Raises:
        RuntimeError: On any failure during segmentation (original
            error attached via exception chaining).
    """
    try:
        if image is None:
            raise ValueError("Invalid image input")

        # Load the model and assemble inference arguments.
        model = get_model(model_size)

        model_args = {
            "device": "cpu",
            "retina_masks": True,
            "conf": conf,
            "iou": iou
        }

        # Attach whichever prompts were supplied.
        if bboxes is not None:
            model_args["bboxes"] = bboxes
        if points is not None and labels is not None:
            model_args["points"] = points
            model_args["labels"] = labels
        if texts is not None:
            model_args["texts"] = texts

        # Run segmentation.
        everything_results = model(image, **model_args)

        # Convert each predicted mask into a flattened contour.
        # BUGFIX: `result` was previously only bound inside the
        # non-empty branch, so the cleanup code below raised NameError
        # (masked as a misleading RuntimeError) when the model returned
        # no results. Initialize it up front and guard the deletes.
        segments = []
        result = None
        if everything_results and len(everything_results) > 0:
            result = everything_results[0]
            if hasattr(result, 'masks') and result.masks is not None:
                masks = result.masks.data.cpu().numpy()

                # Note: a distinct name is used here so the `points`
                # prompt parameter is not shadowed.
                for mask in masks:
                    contour = mask_to_points(mask)
                    if contour:
                        segments.append(contour)

        # Drop large objects before returning to keep peak memory down.
        del model
        del everything_results
        if result is not None:
            if hasattr(result, 'masks'):
                del result.masks
            del result
        # clear_gpu_memory()

        return {
            "total_segments": len(segments),
            "segments": segments
        }
    except Exception as e:
        # Make sure memory is reclaimed even on failure, and preserve
        # the original traceback via chaining.
        clear_gpu_memory()
        raise RuntimeError(f"Error processing image: {str(e)}") from e