YoungEWBOK commited on
Commit
54e2c3f
·
verified ·
1 Parent(s): 84a11f2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+
5
+ class ImageProcessor:
6
+ @staticmethod
7
+ def process_image(image):
8
+ if image is None:
9
+ print("错误:输入图像为空。")
10
+ return None, 0
11
+
12
+ try:
13
+ # 确保图像格式正确
14
+ if len(image.shape) == 2: # 如果是灰度图
15
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
16
+ elif image.shape[2] == 4: # 如果是RGBA
17
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
18
+
19
+ # 边缘保留滤波EPF 去噪
20
+ blur = cv2.pyrMeanShiftFiltering(image, sp=21, sr=55)
21
+
22
+ # 转成灰度图像
23
+ gray = cv2.cvtColor(blur, cv2.COLOR_BGR2GRAY)
24
+
25
+ # 得到二值图像区间阈值
26
+ ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
27
+
28
+ # 距离变换
29
+ dist = cv2.distanceTransform(binary, cv2.DIST_L2, 3)
30
+ dist_output = cv2.normalize(dist, None, 0, 1.0, cv2.NORM_MINMAX)
31
+ ret, surface = cv2.threshold(dist_output, 0.5*dist_output.max(), 255, cv2.THRESH_BINARY)
32
+
33
+ # 标记连通区域
34
+ ret, markers = cv2.connectedComponents(np.uint8(surface))
35
+ markers = markers + 1
36
+
37
+ # 未知区域标记
38
+ kernel = np.ones((3, 3), np.uint8)
39
+ unknown = cv2.subtract(cv2.dilate(binary, kernel, iterations=1), np.uint8(surface))
40
+ markers[unknown == 255] = 0
41
+
42
+ # 分水岭算法分割
43
+ markers = cv2.watershed(image, markers=markers)
44
+ markers_8u = np.uint8(markers)
45
+
46
+ colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0),
47
+ (255,0,255), (0,255,255), (255,128,0), (255,0,128),
48
+ (128,255,0), (128,0,255), (255,128,128), (128,255,255)]
49
+
50
+ areas = []
51
+ for i in range(2, np.max(markers) + 1):
52
+ mask = cv2.inRange(markers_8u, i, i)
53
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
54
+ if contours:
55
+ areas.append(cv2.contourArea(contours[0]))
56
+
57
+ if not areas:
58
+ print("警告:未检测到任何对象。")
59
+ return image, 0
60
+
61
+ hist, bin_edges = np.histogram(areas, bins=20)
62
+ most_common_bin = np.argmax(hist)
63
+ standard_area = (bin_edges[most_common_bin] + bin_edges[most_common_bin + 1]) / 2
64
+ area_threshold_low = standard_area * 0.7
65
+ area_threshold_high = standard_area * 1.3
66
+
67
+ object_count = 0
68
+ for i in range(2, np.max(markers) + 1):
69
+ mask = cv2.inRange(markers_8u, i, i)
70
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
71
+ if contours:
72
+ area = cv2.contourArea(contours[0])
73
+ if area_threshold_low <= area <= area_threshold_high:
74
+ object_count += 1
75
+ elif area > area_threshold_high:
76
+ num_objects = round(area / standard_area)
77
+ object_count += num_objects
78
+
79
+ color = colors[(i-2)%len(colors)]
80
+ cv2.drawContours(image, contours, -1, color, -1)
81
+
82
+ M = cv2.moments(contours[0])
83
+ if M['m00'] != 0:
84
+ cx = int(M['m10']/M['m00'])
85
+ cy = int(M['m01']/M['m00'])
86
+ cv2.drawMarker(image, (cx,cy), (0,0,255), cv2.MARKER_CROSS, 10, 2)
87
+
88
+ cv2.putText(image, f"数量={object_count}", (20,50), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0,255,0), 3)
89
+
90
+ return image, object_count
91
+
92
+ except Exception as e:
93
+ print(f"图像处理过程中发生错误: {e}")
94
+ return None, 0
95
+
96
+ def process_and_count(input_image):
97
+ if input_image is None:
98
+ return None, "未上传图像"
99
+
100
+ # 转换图像格式
101
+ input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
102
+
103
+ processed_image, count = ImageProcessor.process_image(input_image)
104
+
105
+ if processed_image is None:
106
+ return None, "图像处理错误"
107
+
108
+ # 转换回RGB格式以供Gradio显示
109
+ processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
110
+
111
+ return processed_image, f"检测到的对象数量: {count}"
112
+
113
+ iface = gr.Interface(
114
+ fn=process_and_count,
115
+ inputs=gr.Image(),
116
+ outputs=[
117
+ gr.Image(label="处理后的图像"),
118
+ gr.Textbox(label="数量")
119
+ ],
120
+ title="螺帽管计数器",
121
+ description="上传一张图像或拍照以统计螺帽管个数。程序将处理图像并返回检测到的螺帽管数量。"
122
+ )
123
+
124
+ if __name__ == "__main__":
125
+ iface.launch(share=True)