robot2no1 committed on
Commit
840f4e2
·
verified ·
1 Parent(s): 43425a5

sam_segment.py

Browse files

import gradio as gr
import cv2
import numpy as np
from sam_segment import segment_image_with_prompt

# Predefined overlay palette. Each entry is a (fill_color, stroke_color)
# pair of 3-component colour tuples; in this palette both members of a pair
# are the same colour, so the pairs are generated from a single colour list.
SEGMENT_COLORS = [
    (rgb, rgb)
    for rgb in (
        (255, 99, 71),    # tomato (orange-red)
        (65, 105, 225),   # royal blue
        (50, 205, 50),    # lime green
        (255, 215, 0),    # gold
        (238, 130, 238),  # violet
        (0, 191, 255),    # deep sky blue
        (255, 165, 0),    # orange
        (106, 90, 205),   # slate blue
    )
]

def segment_image(input_image, model_size, conf_threshold, iou_threshold):
    """Segment an image and return it with coloured segment overlays.

    The image is passed to ``segment_image_with_prompt``; every polygon in
    the returned ``results["segments"]`` is filled with a semi-transparent
    colour from ``SEGMENT_COLORS`` (cycled by segment index) and outlined.
    Pixels already claimed by an earlier segment are not filled again.

    Args:
        input_image: Image array to segment. Only ``.copy()`` and
            ``.shape[:2]`` (read as height, width) are used directly here.
            Assumed to be a 3-channel uint8 array as delivered by
            ``gr.Image`` -- TODO confirm. May be ``None`` when nothing has
            been uploaded.
        model_size: Forwarded to the model helper as ``model_size``.
        conf_threshold: Forwarded to the model helper as ``conf``.
        iou_threshold: Forwarded to the model helper as ``iou``.

    Returns:
        A new image with the overlays blended in. If ``input_image`` is
        ``None`` or any step fails, ``input_image`` is returned unchanged
        (failures are reported on stdout).
    """
    if input_image is None:
        # Nothing to segment; avoid calling the model just to fail.
        return input_image

    try:
        # Run the prediction.
        results = segment_image_with_prompt(
            image=input_image,
            model_size=model_size,
            conf=conf_threshold,
            iou=iou_threshold,
        )

        # Work on a copy so the caller's array is never modified.
        output_image = input_image.copy()
        h, w = output_image.shape[:2]

        # final_mask collects all coloured overlays; accumulated_mask records
        # which pixels have already been assigned to a segment.
        final_mask = np.zeros_like(output_image)
        accumulated_mask = np.zeros((h, w), dtype=np.uint8)

        for idx, points in enumerate(results["segments"]):
            flat_points = np.asarray(points)

            # Skip degenerate polygons (empty, fewer than 3 vertices, or an
            # odd number of coordinates). An exception raised here would be
            # caught by the outer handler and throw away every overlay.
            if flat_points.size < 6 or flat_points.size % 2:
                continue

            # Convert the point list into an OpenCV contour of (x, y) rows.
            contour_points = flat_points.reshape(-1, 2).astype(np.int32)

            # Rasterise the polygon as a 0/255 binary mask.
            mask = np.zeros((h, w), dtype=np.uint8)
            cv2.fillPoly(mask, [contour_points], 255)

            # Remove pixels already owned by earlier segments, then mark the
            # remaining ones as taken.
            mask = cv2.bitwise_and(mask, cv2.bitwise_not(accumulated_mask))
            accumulated_mask = cv2.bitwise_or(accumulated_mask, mask)

            # Pick the colours for this segment, cycling through the palette.
            color_idx = idx % len(SEGMENT_COLORS)
            fill_color, stroke_color = SEGMENT_COLORS[color_idx]

            # Add the fill at 30% strength. Skipped when no pixel is left
            # for this segment, since blending an all-zero layer is a no-op.
            if mask.any():
                fill_mask = np.zeros_like(output_image)
                fill_mask[mask > 0] = fill_color
                final_mask = cv2.addWeighted(
                    final_mask, 1.0, fill_mask, 0.3, 0
                )

            # Draw the outline (2 px) on the overlay layer.
            cv2.drawContours(
                final_mask, [contour_points], -1, stroke_color, 2
            )

        # Blend the overlay layer onto the image at 50% strength.
        output_image = cv2.addWeighted(output_image, 1.0, final_mask, 0.5, 0)

        return output_image

    except Exception as e:
        # Best-effort UI callback: report the failure and show the original.
        print(f"分割过程中出错: {str(e)}")
        return input_image

# Build the Gradio interface. Components are created in the same order as
# they are passed to gr.Interface: the four inputs first, then the output.
image_input = gr.Image(label="输入图片")

# Model size selector.
model_size_radio = gr.Radio(
    choices=["small", "large"],
    value="large",
    label="模型大小",
    info="small: 更快但精度较低, large: 更慢但精度更高",
)

# Confidence threshold, passed to segment_image as conf_threshold.
conf_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.4,
    step=0.1,
    label="置信度阈值",
    info="值越高,检测越严格",
)

# IoU threshold, passed to segment_image as iou_threshold. The default was
# lowered to 0.3 by the original author to show more regions.
# NOTE(review): the help text says lower values keep MORE overlapping
# regions; for a conventional NMS IoU threshold the relationship is the
# opposite. Confirm against segment_image_with_prompt before changing it.
iou_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.3,
    step=0.1,
    label="IoU阈值",
    info="值越低则保留更多重叠区域,值越高则保留更少重叠区域",
)

result_image = gr.Image(label="分割结果")

# One example row: image URL, model size, confidence threshold, IoU threshold.
example_row = [
    "https://3vj-render.3vjia.com//UpFile_Render/C00006070/PMC/DesignSchemeRenderFile/20240831/592351213526335564/43f8d835b3a54869a34167ed7f2a27aa.jpg?x-oss-process=image/resize,m_fill,h_730,w_1220",
    "large",
    0.4,
    0.3,
]

demo = gr.Interface(
    fn=segment_image,
    inputs=[image_input, model_size_radio, conf_slider, iou_slider],
    outputs=result_image,
    title="FastSAM图像分割演示",
    description="上传一张图片,调整参数,模型将对图片中的对象进行分割。",
    examples=[example_row],
)

# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")

Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from sam_segment import segment_image_with_prompt
5
+
6
+ # 预定义分割颜色组
7
+ SEGMENT_COLORS = [
8
+ ((255, 99, 71), (255, 99, 71)), # 红橙色
9
+ ((65, 105, 225), (65, 105, 225)), # 皇家蓝
10
+ ((50, 205, 50), (50, 205, 50)), # 酸橙绿
11
+ ((255, 215, 0), (255, 215, 0)), # 金色
12
+ ((238, 130, 238), (238, 130, 238)), # 紫罗兰
13
+ ((0, 191, 255), (0, 191, 255)), # 深天蓝
14
+ ((255, 165, 0), (255, 165, 0)), # 橙色
15
+ ((106, 90, 205), (106, 90, 205)), # 石板蓝
16
+ ]
17
+
18
+ def segment_image(input_image, model_size, conf_threshold, iou_threshold):
19
+ """
20
+ 使用FastSAM模型对输入图片进行分割
21
+ """
22
+ try:
23
+ # 进行预测
24
+ results = segment_image_with_prompt(
25
+ image=input_image,
26
+ model_size=model_size,
27
+ conf=conf_threshold,
28
+ iou=iou_threshold,
29
+ )
30
+
31
+ # 创建输出图像的副本
32
+ output_image = input_image.copy()
33
+
34
+ # 获取图像尺寸
35
+ h, w = output_image.shape[:2]
36
+
37
+ # 创建一个总的遮罩层和一个累积掩码
38
+ final_mask = np.zeros_like(output_image)
39
+ accumulated_mask = np.zeros((h, w), dtype=np.uint8)
40
+
41
+ # 为每个分割结果创建掩码
42
+ for idx, points in enumerate(results["segments"]):
43
+ # 将点列表转换为轮廓格式
44
+ contour_points = np.array(points).reshape(-1, 2).astype(np.int32)
45
+
46
+ # 创建空白掩码
47
+ mask = np.zeros((h, w), dtype=np.uint8)
48
+
49
+ # 填充轮廓
50
+ cv2.fillPoly(mask, [contour_points], 1)
51
+
52
+ # 更新累积掩码(避免重叠区域重复计算)
53
+ mask = cv2.bitwise_and(mask, cv2.bitwise_not(accumulated_mask))
54
+ accumulated_mask = cv2.bitwise_or(accumulated_mask, mask)
55
+
56
+ # 使用预定义的颜色(循环使用)
57
+ color_idx = idx % len(SEGMENT_COLORS)
58
+ fill_color, stroke_color = SEGMENT_COLORS[color_idx]
59
+
60
+ # 创建填充区域(半透明)
61
+ fill_mask = np.zeros_like(output_image)
62
+ fill_mask[mask > 0] = fill_color
63
+ final_mask = cv2.addWeighted(final_mask, 1.0, fill_mask, 0.3, 0)
64
+
65
+ # 绘制轮廓线
66
+ cv2.drawContours(final_mask, [contour_points], -1, stroke_color, 2)
67
+
68
+ # 混合原图和掩码
69
+ output_image = cv2.addWeighted(output_image, 1.0, final_mask, 0.5, 0)
70
+
71
+ return output_image
72
+
73
+ except Exception as e:
74
+ print(f"分割过程中出错: {str(e)}")
75
+ return input_image
76
+
77
+ # 创建Gradio界面
78
+ demo = gr.Interface(
79
+ fn=segment_image,
80
+ inputs=[
81
+ gr.Image(label="输入图片"),
82
+ gr.Radio(
83
+ choices=["small", "large"],
84
+ value="large",
85
+ label="模型大小",
86
+ info="small: 更快但精度较低, large: 更慢但精度更高"
87
+ ),
88
+ gr.Slider(
89
+ minimum=0.1,
90
+ maximum=1.0,
91
+ value=0.4,
92
+ step=0.1,
93
+ label="置信度阈值",
94
+ info="值越高,检测越严格"
95
+ ),
96
+ gr.Slider(
97
+ minimum=0.1,
98
+ maximum=1.0,
99
+ value=0.3, # 降低默认值,使其能显示更多区域
100
+ step=0.1,
101
+ label="IoU阈值",
102
+ info="值越低则保留更多重叠区域,值越高则保留更少重叠区域"
103
+ )
104
+ ],
105
+ outputs=gr.Image(label="分割结果"),
106
+ title="FastSAM图像分割演示",
107
+ description="上传一张图片,调整参数,模型将对图片中的对象进行分割。",
108
+ examples=[
109
+ [
110
+ "https://3vj-render.3vjia.com//UpFile_Render/C00006070/PMC/DesignSchemeRenderFile/20240831/592351213526335564/43f8d835b3a54869a34167ed7f2a27aa.jpg?x-oss-process=image/resize,m_fill,h_730,w_1220", # 图片路径
111
+ "large", # 模型大小
112
+ 0.4, # 置信度阈值
113
+ 0.3 # IoU阈值,降低默认值
114
+ ]
115
+ ]
116
+ )
117
+
118
+ # 启动应用
119
+ if __name__ == "__main__":
120
+ demo.launch(server_name="0.0.0.0")