dxcanh commited on
Commit
94cfd9f
·
verified ·
1 Parent(s): 198d582

Upload test.py

Browse files
Files changed (1) hide show
  1. test.py +211 -0
test.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
def extract_all(image: np.ndarray, area_threshold: int = 100, lower_thresh: int = 100, upper_thresh: int = 200) -> dict:
    """Segment an image into connected "islands" using Canny edges + morphology.

    Args:
        image: BGR color or single-channel grayscale image.
        area_threshold: minimum contour area (in pixels) for a region to be kept.
        lower_thresh: lower Canny hysteresis threshold.
        upper_thresh: upper Canny hysteresis threshold.

    Returns:
        Dict mapping the first (row, col) pixel of each kept region to the
        list of all (row, col) pixels inside that region's filled contour.
    """
    # Work on a single-channel copy.
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image.copy()

    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blurred, lower_thresh, upper_thresh)

    # Close gaps in the edge map so each object becomes one solid blob:
    # close -> dilate -> close, with progressively chosen kernels.
    closed_edges = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8), iterations=3)
    closed_edges = cv2.dilate(closed_edges, np.ones((5, 5), np.uint8), iterations=1)
    closed_edges = cv2.morphologyEx(closed_edges, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8), iterations=2)

    # NOTE: the original wrote a debug image ("canny_binary.jpg") to the
    # working directory here on every call; removed as a leftover debug
    # side effect unsafe for a served app.

    contours, _ = cv2.findContours(closed_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    real_islands = {}
    for contour in contours:
        if cv2.contourArea(contour) > area_threshold:
            # Rasterize the filled contour and collect its pixel coordinates.
            mask = np.zeros_like(gray)
            cv2.drawContours(mask, [contour], -1, 255, thickness=cv2.FILLED)
            pixels = list(zip(*np.where(mask == 255)))
            real_islands[(pixels[0][0], pixels[0][1])] = pixels

    print(f"Detected {len(real_islands)} islands from {len(contours)} contours")
    return real_islands
38
+
39
def extract_object(image: np.ndarray, island: list[tuple]) -> np.ndarray:
    """Crop one island's pixels into a tight bounding-box image.

    Args:
        image: source image, (H, W, C) color or (H, W) grayscale.
        island: list of (row, col) pixel coordinates belonging to the island.

    Returns:
        uint8 array of shape (h, w, C) — C == 1 for grayscale input — holding
        the island's pixels; positions inside the box but outside the island
        stay zero (black).
    """
    coords = np.array(island)
    min_y, min_x = coords.min(axis=0)
    max_y, max_x = coords.max(axis=0)

    height, width = max_y - min_y + 1, max_x - min_x + 1
    num_channels = image.shape[2] if len(image.shape) == 3 else 1
    result = np.zeros((height, width, num_channels), dtype=np.uint8)

    values = image[coords[:, 0], coords[:, 1]]
    if values.ndim == 1:
        # Grayscale source: add a channel axis so the (n,) values broadcast
        # into the (n, 1) destination (the original assignment raised a
        # broadcasting ValueError here for islands with more than one pixel).
        values = values[:, None]
    result[coords[:, 0] - min_y, coords[:, 1] - min_x] = values

    return result
52
+
53
def draw_bound(img: np.ndarray, top: int, down: int, left: int, right: int, size: int, color=(0, 255, 0)) -> np.ndarray:
    """Return a copy of `img` with a rectangular frame of thickness `size`.

    The frame spans rows [top, down] and columns [left, right] and is drawn
    as four filled bands (top, bottom, left, right edges).
    """
    out = img.copy()
    bands = (
        ((left, top), (right, top + size)),        # top edge
        ((left, down - size), (right, down)),      # bottom edge
        ((left, top), (left + size, down)),        # left edge
        ((right - size, top), (right, down)),      # right edge
    )
    for corner_a, corner_b in bands:
        cv2.rectangle(out, corner_a, corner_b, color, thickness=-1)
    return out
60
+
61
def compute_template_matching(img: np.ndarray, template: np.ndarray, method, mask: np.ndarray):
    """Run a masked template match of `template` over `img`.

    Args:
        img: search image (cast to uint8).
        template: patch to locate (cast to uint8).
        method: an OpenCV matching method (e.g. cv2.TM_CCOEFF_NORMED).
        mask: per-pixel template mask passed through to cv2.matchTemplate.

    Returns:
        The matching score map with all non-finite entries zeroed.

    Raises:
        ValueError: if the image or template is constant (zero standard
        deviation), where normalized match scores are undefined.
    """
    n_img = img.astype(np.uint8)
    n_template = template.astype(np.uint8)

    if np.std(n_template) == 0:
        raise ValueError("template has zero standard deviation; cannot match")
    if np.std(n_img) == 0:
        raise ValueError("image has zero standard deviation; cannot match")

    result = cv2.matchTemplate(n_img, n_template, method, mask=mask)
    # Masked normalized matching can produce inf AND NaN where the masked
    # patch has no variance; zero everything non-finite. (The original only
    # filtered inf with np.isinf, letting NaN reach cv2.minMaxLoc.)
    result = np.where(np.isfinite(result), result, 0)

    return result
74
+
75
def process_single_object_loop(img: np.ndarray, template: np.ndarray, method, mask: np.ndarray):
    """Match `template` in `img` once and annotate the best hit.

    Returns:
        (best_score, score_map, annotated_image, (row, col)) where
        (row, col) is the top-left corner of the best-match box.
    """
    score_map = compute_template_matching(img, template, method, mask)
    _, best_score, _, best_xy = cv2.minMaxLoc(score_map)

    col, row = best_xy  # minMaxLoc reports (x, y); convert to (row, col)
    t_height = template.shape[0]
    t_width = template.shape[1]
    annotated = draw_bound(img, row, row + t_height, col, col + t_width, 8, (0, 255, 0))

    return best_score, score_map, annotated, (row, col)
91
+
92
def process_template_at_scale(source: np.ndarray, template: np.ndarray, method, scale: float):
    """Match `template` against `source` at one scale factor.

    Builds a binary mask of the template's non-black pixels (median blur and
    erosion suppress stray edge pixels first), resizes mask and template by
    `scale`, then runs a masked template match.

    Args:
        source: search image.
        template: object crop to locate.
        method: OpenCV matching method.
        scale: resize factor applied to the template/mask.

    Returns:
        (best_score, score_map, annotated_source, masked_template, (row, col)).
    """
    masked_template = template.copy().astype(np.uint8)
    temp = cv2.medianBlur(masked_template.copy(), 5)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    temp = cv2.erode(temp, kernel, iterations=1)
    _, mask = cv2.threshold(temp, 1, 255, cv2.THRESH_BINARY)

    # Clamp the target size to at least 1x1: for tiny templates at small
    # scales, int(dim * scale) can reach 0 and cv2.resize would throw.
    new_w = max(1, int(mask.shape[1] * scale))
    new_h = max(1, int(mask.shape[0] * scale))
    mask = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST_EXACT)
    masked_template = cv2.resize(masked_template, (mask.shape[1], mask.shape[0]), interpolation=cv2.INTER_NEAREST_EXACT)

    # NOTE(review): cv2.matchTemplate still raises if the (scaled) template
    # is larger than `source` — confirm upstream sizes if that can occur.
    local_max, result, bound_image, pos = process_single_object_loop(source.copy(), masked_template, method, mask.astype(np.uint8))

    # Keep only the masked-in pixels of the template for the caller.
    max_template = np.zeros_like(masked_template)
    max_template[mask.astype(bool)] = masked_template[mask.astype(bool)]

    return local_max, result, bound_image, max_template, pos
108
+
109
def process_images(source_img, objects_img, confidence_threshold=0.7):
    """Detect each object from `objects_img` inside `source_img`.

    Pipeline: segment `objects_img` into islands, crop each island as a
    template, then template-match every template against `source_img` over a
    range of scales, drawing a numbered bounding box for each match whose
    confidence reaches `confidence_threshold`.

    Args:
        source_img: image to search in (ndarray, or PIL-like convertible).
        objects_img: image containing the objects to find.
        confidence_threshold: minimum TM_CCOEFF_NORMED score to accept a match.

    Returns:
        Copy of the source image with green bounding boxes and labels drawn
        on accepted matches.
    """
    # NOTE(review): Gradio with type="numpy" passes an RGB ndarray, so the
    # else-branch (explicit RGB -> BGR) only runs for non-ndarray inputs;
    # ndarray inputs are used as-is — confirm intended channel order.
    if isinstance(source_img, np.ndarray):
        source = source_img
    else:
        source = np.array(source_img)[:, :, ::-1]  # RGB -> BGR

    if isinstance(objects_img, np.ndarray):
        objects = objects_img
    else:
        objects = np.array(objects_img)[:, :, ::-1]  # RGB -> BGR

    # Denoise before segmentation; crops are still taken from the original.
    object_img = cv2.medianBlur(objects.copy(), 3)
    islands = extract_all(object_img, area_threshold=100, lower_thresh=100, upper_thresh=200)
    objects_extracted = []
    for island in islands.values():
        object_image = extract_object(objects, island)
        objects_extracted.append(object_image)

    result_image = source.copy()
    method = cv2.TM_CCOEFF_NORMED

    print("\nProcessing object detection...")
    print(f"Confidence threshold: {confidence_threshold}")
    print(f"Total objects to detect: {len(objects_extracted)}\n")

    for i, template in enumerate(objects_extracted):
        print(f"\nProcessing object {i+1}/{len(objects_extracted)}")
        max_val = 0
        max_pos = None
        max_template = None

        # Try 20 scales from 1/4 size up to full size; stop early once the
        # best confidence so far reaches the threshold.
        scale_steps = np.linspace(0.25, 1.0, 20)
        for scale in scale_steps:
            local_max, _, temp_bound_image, local_template, pos = process_template_at_scale(
                source, template, method, scale
            )
            print(f"Scale {scale:.2f}: Confidence = {local_max:.4f}")

            if local_max > max_val:
                max_val = local_max
                max_template = local_template
                max_pos = pos

            if max_val >= confidence_threshold:
                print(f"Stopping at scale {scale:.2f} as confidence {max_val:.4f} >= threshold")
                break

        print(f"Final confidence for object {i+1}: {max_val:.4f}")
        if max_pos is not None and max_val >= confidence_threshold:
            # Draw the box at the best match position using the scaled
            # template's size, then label it with the object's 1-based index.
            h, w = max_template.shape[:2]
            result_image = draw_bound(
                result_image,
                max_pos[0],
                max_pos[0] + h,
                max_pos[1],
                max_pos[1] + w,
                8,
                (0, 255, 0)
            )
            cv2.putText(
                result_image,
                f"{i+1}",
                (max_pos[1], max_pos[0]-10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.9,
                (0, 255, 0),
                2
            )
            print(f"Object {i+1} detected at position ({max_pos[0]}, {max_pos[1]}) with size ({h}x{w})")
        else:
            print(f"Object {i+1} not detected (confidence {max_val:.4f} < threshold {confidence_threshold})")

    print("\nDetection completed!")
    return result_image
183
+
184
# Gradio UI: two image uploads plus a confidence slider wired to process_images.
with gr.Blocks(title="Object Detection in Images") as demo:
    gr.Markdown("# Object Detection in Images")
    gr.Markdown("Upload a source image and an objects image to detect and draw bounding boxes around matching objects.")

    with gr.Row():
        with gr.Column():
            src_image = gr.Image(label="Source Image", type="numpy")
            obj_image = gr.Image(label="Objects Image", type="numpy")
            conf_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.7, step=0.01,
                label="Confidence Threshold",
            )
            run_button = gr.Button("Detect Objects")

        with gr.Column():
            result_view = gr.Image(label="Result with Bounding Boxes", type="numpy")

    run_button.click(
        fn=process_images,
        inputs=[src_image, obj_image, conf_slider],
        outputs=result_view,
    )

demo.launch()