Senum2001 commited on
Commit
f301e7a
·
1 Parent(s): de87a47

Implement complete classification with all detection rules - Point/Tiny/Wire/Loose Joint detection

Browse files
Files changed (1) hide show
  1. inference_core.py +200 -10
inference_core.py CHANGED
@@ -111,8 +111,124 @@ def infer_single_image_with_patchcore(image_path: str):
111
  }
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def classify_filtered_image(filtered_img_path: str):
115
- """OpenCV heuristic classification on filtered image"""
 
 
 
 
 
 
 
116
  img = cv2.imread(filtered_img_path)
117
  if img is None:
118
  raise FileNotFoundError(f"Could not read filtered image: {filtered_img_path}")
@@ -123,7 +239,7 @@ def classify_filtered_image(filtered_img_path: str):
123
 
124
  hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
125
 
126
- # Color masks (with slight tolerance adjustments for consistency)
127
  blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
128
  black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
129
  yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))
@@ -147,26 +263,100 @@ def classify_filtered_image(filtered_img_path: str):
147
  label = "Unknown"
148
  box_list, label_list = [], []
149
 
150
- # Simplified classification logic (keeping only essential parts)
151
  if (blue_count + black_count) / total > 0.8:
152
  label = "Normal"
153
- elif (red_count + orange_count + yellow_count) / total > 0.7:
 
 
 
 
 
 
 
154
  label = "Full Wire Overload"
155
  box_list.append((0, 0, img.shape[1], img.shape[0]))
156
  label_list.append(label)
157
  else:
158
- # Point overloads detection (simplified)
159
  min_area_faulty = 120
 
160
  max_area = 0.05 * total
161
-
162
- contours, _ = cv2.findContours(red_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  for cnt in contours:
164
  area = cv2.contourArea(cnt)
165
- if min_area_faulty < area < max_area:
166
  x, y, w, h = cv2.boundingRect(cnt)
167
  box_list.append((x, y, w, h))
168
- label_list.append("Point Overload (Faulty)")
169
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  print(f"[Classification] Final label: {label}, Boxes found: {len(box_list)}")
171
  return label, box_list, label_list, img
172
 
 
111
  }
112
 
113
 
114
+ # Helper functions for classification
115
+ def _iou(boxA, boxB):
116
+ """Calculate Intersection over Union"""
117
+ xA = max(boxA[0], boxB[0])
118
+ yA = max(boxA[1], boxB[1])
119
+ xB = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
120
+ yB = min(boxA[1] + boxA[3], boxB[1] + boxB[3])
121
+ interW = max(0, xB - xA)
122
+ interH = max(0, yB - yA)
123
+ interArea = interW * interH
124
+ boxAArea = boxA[2] * boxA[3]
125
+ boxBArea = boxB[2] * boxB[3]
126
+ return interArea / float(boxAArea + boxBArea - interArea + 1e-6)
127
+
128
+
129
+ def _merge_close_boxes(boxes, labels, dist_thresh=20):
130
+ """Merge boxes that are close to each other"""
131
+ merged, merged_labels = [], []
132
+ used = [False] * len(boxes)
133
+ for i in range(len(boxes)):
134
+ if used[i]:
135
+ continue
136
+ x1, y1, w1, h1 = boxes[i]
137
+ label1 = labels[i]
138
+ x2, y2, w2, h2 = x1, y1, w1, h1
139
+ for j in range(i + 1, len(boxes)):
140
+ if used[j]:
141
+ continue
142
+ bx, by, bw, bh = boxes[j]
143
+ cx1, cy1 = x1 + w1 // 2, y1 + h1 // 2
144
+ cx2, cy2 = bx + bw // 2, by + bh // 2
145
+ if abs(cx1 - cx2) < dist_thresh and abs(cy1 - cy2) < dist_thresh and label1 == labels[j]:
146
+ x2 = min(x2, bx)
147
+ y2 = min(y2, by)
148
+ w2 = max(x1 + w1, bx + bw) - x2
149
+ h2 = max(y1 + h1, by + bh) - y2
150
+ used[j] = True
151
+ merged.append((x2, y2, w2, h2))
152
+ merged_labels.append(label1)
153
+ used[i] = True
154
+ return merged, merged_labels
155
+
156
+
157
+ def _nms_iou(boxes, labels, iou_thresh=0.4):
158
+ """Non-Maximum Suppression based on IOU"""
159
+ if len(boxes) == 0:
160
+ return [], []
161
+ idxs = np.argsort([w * h for (x, y, w, h) in boxes])[::-1]
162
+ keep, keep_labels = [], []
163
+ while len(idxs) > 0:
164
+ i = idxs[0]
165
+ keep.append(boxes[i])
166
+ keep_labels.append(labels[i])
167
+ remove = [0]
168
+ for j in range(1, len(idxs)):
169
+ if _iou(boxes[i], boxes[idxs[j]]) > iou_thresh:
170
+ remove.append(j)
171
+ idxs = np.delete(idxs, remove)
172
+ return keep, keep_labels
173
+
174
+
175
+ def _filter_faulty_inside_potential(boxes, labels):
176
+ """Remove potential boxes that contain faulty boxes"""
177
+ filtered_boxes, filtered_labels = [], []
178
+ for (box, label) in zip(boxes, labels):
179
+ if label == "Point Overload (Potential)":
180
+ keep = True
181
+ x, y, w, h = box
182
+ for (fbox, flabel) in zip(boxes, labels):
183
+ if flabel == "Point Overload (Faulty)":
184
+ fx, fy, fw, fh = fbox
185
+ if fx >= x and fy >= y and fx + fw <= x + w and fy + fh <= y + h:
186
+ keep = False
187
+ break
188
+ if keep:
189
+ filtered_boxes.append(box)
190
+ filtered_labels.append(label)
191
+ else:
192
+ filtered_boxes.append(box)
193
+ filtered_labels.append(label)
194
+ return filtered_boxes, filtered_labels
195
+
196
+
197
+ def _filter_faulty_overlapping_potential(boxes, labels):
198
+ """Remove potential boxes that overlap with faulty boxes"""
199
+ def is_overlapping(boxA, boxB):
200
+ xA = max(boxA[0], boxB[0])
201
+ yA = max(boxA[1], boxB[1])
202
+ xB = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
203
+ yB = min(boxA[1] + boxA[3], boxB[1] + boxB[3])
204
+ return (xB > xA) and (yB > yA)
205
+
206
+ filtered_boxes, filtered_labels = [], []
207
+ for (box, label) in zip(boxes, labels):
208
+ if label == "Point Overload (Potential)":
209
+ keep = True
210
+ for (fbox, flabel) in zip(boxes, labels):
211
+ if flabel == "Point Overload (Faulty)" and is_overlapping(box, fbox):
212
+ keep = False
213
+ break
214
+ if keep:
215
+ filtered_boxes.append(box)
216
+ filtered_labels.append(label)
217
+ else:
218
+ filtered_boxes.append(box)
219
+ filtered_labels.append(label)
220
+ return filtered_boxes, filtered_labels
221
+
222
+
223
  def classify_filtered_image(filtered_img_path: str):
224
+ """
225
+ Runs the heuristic color-based classification on the FILTERED image.
226
+ Returns:
227
+ label: str
228
+ box_list: [(x, y, w, h), ...]
229
+ label_list: [str, ...]
230
+ img_bgr: the filtered image as BGR
231
+ """
232
  img = cv2.imread(filtered_img_path)
233
  if img is None:
234
  raise FileNotFoundError(f"Could not read filtered image: {filtered_img_path}")
 
239
 
240
  hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
241
 
242
+ # Color masks
243
  blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
244
  black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
245
  yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))
 
263
  label = "Unknown"
264
  box_list, label_list = [], []
265
 
266
+ # Full image checks
267
  if (blue_count + black_count) / total > 0.8:
268
  label = "Normal"
269
+ elif (red_count + orange_count) / total > 0.5:
270
+ label = "Full Wire Overload"
271
+ elif (yellow_count) / total > 0.5:
272
+ label = "Full Wire Overload"
273
+
274
+ # Check for full wire overload (dominant warm colors)
275
+ full_wire_thresh = 0.7
276
+ if (red_count + orange_count + yellow_count) / total > full_wire_thresh:
277
  label = "Full Wire Overload"
278
  box_list.append((0, 0, img.shape[1], img.shape[0]))
279
  label_list.append(label)
280
  else:
281
+ # Point overloads (areas + thresholds)
282
  min_area_faulty = 120
283
+ min_area_potential = 1000
284
  max_area = 0.05 * total
285
+
286
+ for mask, spot_label, min_a in [
287
+ (red_mask, "Point Overload (Faulty)", min_area_faulty),
288
+ (yellow_mask, "Point Overload (Potential)", min_area_potential),
289
+ ]:
290
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
291
+ for cnt in contours:
292
+ area = cv2.contourArea(cnt)
293
+ if min_a < area < max_area:
294
+ x, y, w, h = cv2.boundingRect(cnt)
295
+ box_list.append((x, y, w, h))
296
+ label_list.append(spot_label)
297
+
298
+ # Middle area checks (Loose Joint detection)
299
+ h, w = img.shape[:2]
300
+ center = img[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4]
301
+ center_hsv = cv2.cvtColor(center, cv2.COLOR_BGR2HSV)
302
+ center_yellow = cv2.inRange(center_hsv, (20, 130, 130), (35, 255, 255))
303
+ center_orange = cv2.inRange(center_hsv, (10, 100, 100), (25, 255, 255))
304
+ center_red1 = cv2.inRange(center_hsv, (0, 100, 100), (10, 255, 255))
305
+ center_red2 = cv2.inRange(center_hsv, (160, 100, 100), (180, 255, 255))
306
+ center_red = cv2.bitwise_or(center_red1, center_red2)
307
+
308
+ if np.sum(center_red > 0) + np.sum(center_orange > 0) > 0.1 * center.size:
309
+ label = "Loose Joint (Faulty)"
310
+ box_list.append((w // 4, h // 4, w // 2, h // 2))
311
+ label_list.append(label)
312
+ elif np.sum(center_yellow > 0) > 0.1 * center.size:
313
+ label = "Loose Joint (Potential)"
314
+ box_list.append((w // 4, h // 4, w // 2, h // 2))
315
+ label_list.append(label)
316
+
317
+ # Tiny spots (always check)
318
+ min_area_tiny, max_area_tiny = 10, 30
319
+ for mask, spot_label in [
320
+ (red_mask, "Tiny Faulty Spot"),
321
+ (yellow_mask, "Tiny Potential Spot"),
322
+ ]:
323
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
324
  for cnt in contours:
325
  area = cv2.contourArea(cnt)
326
+ if min_area_tiny < area < max_area_tiny:
327
  x, y, w, h = cv2.boundingRect(cnt)
328
  box_list.append((x, y, w, h))
329
+ label_list.append(spot_label)
330
+
331
+ # Detect wire-shaped (long/thin) warm regions
332
+ aspect_ratio_thresh = 5
333
+ min_strip_area = 0.01 * total
334
+ wire_boxes, wire_labels = [], []
335
+ for mask, strip_label in [
336
+ (red_mask, "Wire Overload (Red Strip)"),
337
+ (yellow_mask, "Wire Overload (Yellow Strip)"),
338
+ (orange_mask, "Wire Overload (Orange Strip)"),
339
+ ]:
340
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
341
+ for cnt in contours:
342
+ area = cv2.contourArea(cnt)
343
+ if area > min_strip_area:
344
+ x, y, w, h = cv2.boundingRect(cnt)
345
+ aspect_ratio = max(w, h) / (min(w, h) + 1e-6)
346
+ if aspect_ratio > aspect_ratio_thresh:
347
+ wire_boxes.append((x, y, w, h))
348
+ wire_labels.append(strip_label)
349
+
350
+ # Prioritize wire boxes first
351
+ box_list = wire_boxes[:] + box_list
352
+ label_list = wire_labels[:] + label_list
353
+
354
+ # Final pruning/merging
355
+ box_list, label_list = _nms_iou(box_list, label_list, iou_thresh=0.4)
356
+ box_list, label_list = _filter_faulty_inside_potential(box_list, label_list)
357
+ box_list, label_list = _filter_faulty_overlapping_potential(box_list, label_list)
358
+ box_list, label_list = _merge_close_boxes(box_list, label_list, dist_thresh=100)
359
+
360
  print(f"[Classification] Final label: {label}, Boxes found: {len(box_list)}")
361
  return label, box_list, label_list, img
362