robot2no1 committed on
Commit
faeea0b
·
verified ·
1 Parent(s): 3ce1368

Create sam_segment.py

Browse files
Files changed (1) hide show
  1. sam_segment.py +136 -0
sam_segment.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from ultralytics import FastSAM
4
+ import torch
5
+ import gc
6
+
7
# Available FastSAM checkpoints, keyed by size alias ("small"/"large").
# NOTE(review): absolute paths — assumes the deployment host has /disk2/models; verify.
MODELS = {
    "small": "/disk2/models/FastSAM-s.pt",
    "large": "/disk2/models/FastSAM-x.pt"
}
12
+
13
def clear_gpu_memory():
    """Release cached GPU memory after a segmentation run.

    Runs the Python garbage collector first so dead tensors become
    collectable, then drops PyTorch's CUDA caches when a GPU is present.
    No-op on CPU-only hosts.
    """
    gc.collect()  # reclaim unreachable Python objects first
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()   # release PyTorch's cached CUDA allocations
    torch.cuda.ipc_collect()   # reclaim CUDA IPC shared memory
21
+
22
def get_model(model_size: str = "large"):
    """Load the FastSAM checkpoint for the requested model size.

    Args:
        model_size: One of the keys of MODELS ("small" or "large").

    Returns:
        A freshly constructed FastSAM model instance.

    Raises:
        ValueError: If model_size is not a known size alias.
        RuntimeError: If the checkpoint fails to load.
    """
    if model_size not in MODELS:
        raise ValueError(f"Invalid model size. Available sizes: {list(MODELS.keys())}")

    try:
        return FastSAM(MODELS[model_size])
    except Exception as e:
        # Chain the original exception so the root cause of the load
        # failure is preserved for callers (fix: original dropped it).
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
33
+
34
def mask_to_points(mask: np.ndarray) -> list:
    """Convert a binary mask into the outline of its largest region.

    Args:
        mask: 2D numpy array representing the mask

    Returns:
        list: Flattened list of points [x1, y1, x2, y2, ...]
    """
    # findContours wants an 8-bit image; scale the {0,1} mask to {0,255}.
    binary = mask.astype(np.uint8) * 255
    found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not found:
        return []

    # Keep only the contour enclosing the largest area.
    largest = max(found, key=cv2.contourArea)
    # Contours come back shaped (N, 1, 2); flatten to [x1, y1, x2, y2, ...]
    # as Python floats, matching the original element-by-element loop.
    return largest.reshape(-1).astype(float).tolist()
59
+
60
def segment_image_with_prompt(
    image: np.ndarray,
    model_size: str = "large",
    conf: float = 0.4,
    iou: float = 0.9,
    bboxes: list = None,
    points: list = None,
    labels: list = None,
    texts: str = None
):
    """Segment an image with FastSAM, optionally guided by prompts.

    Args:
        image: Input image (numpy.ndarray).
        model_size: Model size ("small" or "large").
        conf: Confidence threshold.
        iou: IoU threshold.
        bboxes: Bounding-box prompt [x1, y1, x2, y2, ...].
        points: Point prompts [[x1, y1], [x2, y2], ...].
        labels: Labels for the point prompts [0, 1, ...].
        texts: Text prompt.

    Returns:
        dict with "total_segments" (int) and "segments" (list of flattened
        contour point lists, one per detected mask).

    Raises:
        RuntimeError: On any failure; GPU memory is still cleaned up.
    """
    # Pre-bind everything the cleanup touches so the finally block is safe
    # even when inference fails early (the original could NameError on an
    # unbound `result` when no results came back).
    model = None
    everything_results = None
    result = None
    try:
        if image is None:
            raise ValueError("Invalid image input")

        # Load the requested model and assemble its inference arguments.
        model = get_model(model_size)
        model_args = {
            "retina_masks": True,
            "conf": conf,
            "iou": iou
        }

        # Attach optional prompts; point prompts require matching labels.
        if bboxes is not None:
            model_args["bboxes"] = bboxes
        if points is not None and labels is not None:
            model_args["points"] = points
            model_args["labels"] = labels
        if texts is not None:
            model_args["texts"] = texts

        # Run segmentation.
        everything_results = model(image, **model_args)

        # Convert each predicted mask into a flattened contour point list.
        segments = []
        if everything_results and len(everything_results) > 0:
            result = everything_results[0]
            if hasattr(result, 'masks') and result.masks is not None:
                masks = result.masks.data.cpu().numpy()
                for mask in masks:
                    # Distinct local name so the `points` prompt parameter
                    # is not clobbered (fix: original shadowed it here).
                    contour_points = mask_to_points(mask)
                    if contour_points:
                        segments.append(contour_points)

        return {
            "total_segments": len(segments),
            "segments": segments
        }
    except Exception as e:
        # Chain the cause so callers can inspect the original failure.
        raise RuntimeError(f"Error processing image: {str(e)}") from e
    finally:
        # Always drop references and free GPU memory, on success and failure
        # alike (the original only freed the CUDA caches on the error path).
        if result is not None and hasattr(result, 'masks'):
            del result.masks
        del result
        del everything_results
        del model
        clear_gpu_memory()