Spaces:
Sleeping
Sleeping
Senum2001 commited on
Commit ·
f301e7a
1
Parent(s): de87a47
Implement complete classification with all detection rules - Point/Tiny/Wire/Loose Joint detection
Browse files- inference_core.py +200 -10
inference_core.py
CHANGED
|
@@ -111,8 +111,124 @@ def infer_single_image_with_patchcore(image_path: str):
|
|
| 111 |
}
|
| 112 |
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
def classify_filtered_image(filtered_img_path: str):
|
| 115 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
img = cv2.imread(filtered_img_path)
|
| 117 |
if img is None:
|
| 118 |
raise FileNotFoundError(f"Could not read filtered image: {filtered_img_path}")
|
|
@@ -123,7 +239,7 @@ def classify_filtered_image(filtered_img_path: str):
|
|
| 123 |
|
| 124 |
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
| 125 |
|
| 126 |
-
# Color masks
|
| 127 |
blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
|
| 128 |
black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
|
| 129 |
yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))
|
|
@@ -147,26 +263,100 @@ def classify_filtered_image(filtered_img_path: str):
|
|
| 147 |
label = "Unknown"
|
| 148 |
box_list, label_list = [], []
|
| 149 |
|
| 150 |
-
#
|
| 151 |
if (blue_count + black_count) / total > 0.8:
|
| 152 |
label = "Normal"
|
| 153 |
-
elif (red_count + orange_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
label = "Full Wire Overload"
|
| 155 |
box_list.append((0, 0, img.shape[1], img.shape[0]))
|
| 156 |
label_list.append(label)
|
| 157 |
else:
|
| 158 |
-
# Point overloads
|
| 159 |
min_area_faulty = 120
|
|
|
|
| 160 |
max_area = 0.05 * total
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
for cnt in contours:
|
| 164 |
area = cv2.contourArea(cnt)
|
| 165 |
-
if
|
| 166 |
x, y, w, h = cv2.boundingRect(cnt)
|
| 167 |
box_list.append((x, y, w, h))
|
| 168 |
-
label_list.append(
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
print(f"[Classification] Final label: {label}, Boxes found: {len(box_list)}")
|
| 171 |
return label, box_list, label_list, img
|
| 172 |
|
|
|
|
| 111 |
}
|
| 112 |
|
| 113 |
|
| 114 |
+
# Helper functions for classification
|
| 115 |
+
def _iou(boxA, boxB):
|
| 116 |
+
"""Calculate Intersection over Union"""
|
| 117 |
+
xA = max(boxA[0], boxB[0])
|
| 118 |
+
yA = max(boxA[1], boxB[1])
|
| 119 |
+
xB = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
|
| 120 |
+
yB = min(boxA[1] + boxA[3], boxB[1] + boxB[3])
|
| 121 |
+
interW = max(0, xB - xA)
|
| 122 |
+
interH = max(0, yB - yA)
|
| 123 |
+
interArea = interW * interH
|
| 124 |
+
boxAArea = boxA[2] * boxA[3]
|
| 125 |
+
boxBArea = boxB[2] * boxB[3]
|
| 126 |
+
return interArea / float(boxAArea + boxBArea - interArea + 1e-6)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _merge_close_boxes(boxes, labels, dist_thresh=20):
|
| 130 |
+
"""Merge boxes that are close to each other"""
|
| 131 |
+
merged, merged_labels = [], []
|
| 132 |
+
used = [False] * len(boxes)
|
| 133 |
+
for i in range(len(boxes)):
|
| 134 |
+
if used[i]:
|
| 135 |
+
continue
|
| 136 |
+
x1, y1, w1, h1 = boxes[i]
|
| 137 |
+
label1 = labels[i]
|
| 138 |
+
x2, y2, w2, h2 = x1, y1, w1, h1
|
| 139 |
+
for j in range(i + 1, len(boxes)):
|
| 140 |
+
if used[j]:
|
| 141 |
+
continue
|
| 142 |
+
bx, by, bw, bh = boxes[j]
|
| 143 |
+
cx1, cy1 = x1 + w1 // 2, y1 + h1 // 2
|
| 144 |
+
cx2, cy2 = bx + bw // 2, by + bh // 2
|
| 145 |
+
if abs(cx1 - cx2) < dist_thresh and abs(cy1 - cy2) < dist_thresh and label1 == labels[j]:
|
| 146 |
+
x2 = min(x2, bx)
|
| 147 |
+
y2 = min(y2, by)
|
| 148 |
+
w2 = max(x1 + w1, bx + bw) - x2
|
| 149 |
+
h2 = max(y1 + h1, by + bh) - y2
|
| 150 |
+
used[j] = True
|
| 151 |
+
merged.append((x2, y2, w2, h2))
|
| 152 |
+
merged_labels.append(label1)
|
| 153 |
+
used[i] = True
|
| 154 |
+
return merged, merged_labels
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _nms_iou(boxes, labels, iou_thresh=0.4):
|
| 158 |
+
"""Non-Maximum Suppression based on IOU"""
|
| 159 |
+
if len(boxes) == 0:
|
| 160 |
+
return [], []
|
| 161 |
+
idxs = np.argsort([w * h for (x, y, w, h) in boxes])[::-1]
|
| 162 |
+
keep, keep_labels = [], []
|
| 163 |
+
while len(idxs) > 0:
|
| 164 |
+
i = idxs[0]
|
| 165 |
+
keep.append(boxes[i])
|
| 166 |
+
keep_labels.append(labels[i])
|
| 167 |
+
remove = [0]
|
| 168 |
+
for j in range(1, len(idxs)):
|
| 169 |
+
if _iou(boxes[i], boxes[idxs[j]]) > iou_thresh:
|
| 170 |
+
remove.append(j)
|
| 171 |
+
idxs = np.delete(idxs, remove)
|
| 172 |
+
return keep, keep_labels
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _filter_faulty_inside_potential(boxes, labels):
|
| 176 |
+
"""Remove potential boxes that contain faulty boxes"""
|
| 177 |
+
filtered_boxes, filtered_labels = [], []
|
| 178 |
+
for (box, label) in zip(boxes, labels):
|
| 179 |
+
if label == "Point Overload (Potential)":
|
| 180 |
+
keep = True
|
| 181 |
+
x, y, w, h = box
|
| 182 |
+
for (fbox, flabel) in zip(boxes, labels):
|
| 183 |
+
if flabel == "Point Overload (Faulty)":
|
| 184 |
+
fx, fy, fw, fh = fbox
|
| 185 |
+
if fx >= x and fy >= y and fx + fw <= x + w and fy + fh <= y + h:
|
| 186 |
+
keep = False
|
| 187 |
+
break
|
| 188 |
+
if keep:
|
| 189 |
+
filtered_boxes.append(box)
|
| 190 |
+
filtered_labels.append(label)
|
| 191 |
+
else:
|
| 192 |
+
filtered_boxes.append(box)
|
| 193 |
+
filtered_labels.append(label)
|
| 194 |
+
return filtered_boxes, filtered_labels
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _filter_faulty_overlapping_potential(boxes, labels):
|
| 198 |
+
"""Remove potential boxes that overlap with faulty boxes"""
|
| 199 |
+
def is_overlapping(boxA, boxB):
|
| 200 |
+
xA = max(boxA[0], boxB[0])
|
| 201 |
+
yA = max(boxA[1], boxB[1])
|
| 202 |
+
xB = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
|
| 203 |
+
yB = min(boxA[1] + boxA[3], boxB[1] + boxB[3])
|
| 204 |
+
return (xB > xA) and (yB > yA)
|
| 205 |
+
|
| 206 |
+
filtered_boxes, filtered_labels = [], []
|
| 207 |
+
for (box, label) in zip(boxes, labels):
|
| 208 |
+
if label == "Point Overload (Potential)":
|
| 209 |
+
keep = True
|
| 210 |
+
for (fbox, flabel) in zip(boxes, labels):
|
| 211 |
+
if flabel == "Point Overload (Faulty)" and is_overlapping(box, fbox):
|
| 212 |
+
keep = False
|
| 213 |
+
break
|
| 214 |
+
if keep:
|
| 215 |
+
filtered_boxes.append(box)
|
| 216 |
+
filtered_labels.append(label)
|
| 217 |
+
else:
|
| 218 |
+
filtered_boxes.append(box)
|
| 219 |
+
filtered_labels.append(label)
|
| 220 |
+
return filtered_boxes, filtered_labels
|
| 221 |
+
|
| 222 |
+
|
| 223 |
def classify_filtered_image(filtered_img_path: str):
|
| 224 |
+
"""
|
| 225 |
+
Runs the heuristic color-based classification on the FILTERED image.
|
| 226 |
+
Returns:
|
| 227 |
+
label: str
|
| 228 |
+
box_list: [(x, y, w, h), ...]
|
| 229 |
+
label_list: [str, ...]
|
| 230 |
+
img_bgr: the filtered image as BGR
|
| 231 |
+
"""
|
| 232 |
img = cv2.imread(filtered_img_path)
|
| 233 |
if img is None:
|
| 234 |
raise FileNotFoundError(f"Could not read filtered image: {filtered_img_path}")
|
|
|
|
| 239 |
|
| 240 |
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
| 241 |
|
| 242 |
+
# Color masks
|
| 243 |
blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
|
| 244 |
black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
|
| 245 |
yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))
|
|
|
|
| 263 |
label = "Unknown"
|
| 264 |
box_list, label_list = [], []
|
| 265 |
|
| 266 |
+
# Full image checks
|
| 267 |
if (blue_count + black_count) / total > 0.8:
|
| 268 |
label = "Normal"
|
| 269 |
+
elif (red_count + orange_count) / total > 0.5:
|
| 270 |
+
label = "Full Wire Overload"
|
| 271 |
+
elif (yellow_count) / total > 0.5:
|
| 272 |
+
label = "Full Wire Overload"
|
| 273 |
+
|
| 274 |
+
# Check for full wire overload (dominant warm colors)
|
| 275 |
+
full_wire_thresh = 0.7
|
| 276 |
+
if (red_count + orange_count + yellow_count) / total > full_wire_thresh:
|
| 277 |
label = "Full Wire Overload"
|
| 278 |
box_list.append((0, 0, img.shape[1], img.shape[0]))
|
| 279 |
label_list.append(label)
|
| 280 |
else:
|
| 281 |
+
# Point overloads (areas + thresholds)
|
| 282 |
min_area_faulty = 120
|
| 283 |
+
min_area_potential = 1000
|
| 284 |
max_area = 0.05 * total
|
| 285 |
+
|
| 286 |
+
for mask, spot_label, min_a in [
|
| 287 |
+
(red_mask, "Point Overload (Faulty)", min_area_faulty),
|
| 288 |
+
(yellow_mask, "Point Overload (Potential)", min_area_potential),
|
| 289 |
+
]:
|
| 290 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 291 |
+
for cnt in contours:
|
| 292 |
+
area = cv2.contourArea(cnt)
|
| 293 |
+
if min_a < area < max_area:
|
| 294 |
+
x, y, w, h = cv2.boundingRect(cnt)
|
| 295 |
+
box_list.append((x, y, w, h))
|
| 296 |
+
label_list.append(spot_label)
|
| 297 |
+
|
| 298 |
+
# Middle area checks (Loose Joint detection)
|
| 299 |
+
h, w = img.shape[:2]
|
| 300 |
+
center = img[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4]
|
| 301 |
+
center_hsv = cv2.cvtColor(center, cv2.COLOR_BGR2HSV)
|
| 302 |
+
center_yellow = cv2.inRange(center_hsv, (20, 130, 130), (35, 255, 255))
|
| 303 |
+
center_orange = cv2.inRange(center_hsv, (10, 100, 100), (25, 255, 255))
|
| 304 |
+
center_red1 = cv2.inRange(center_hsv, (0, 100, 100), (10, 255, 255))
|
| 305 |
+
center_red2 = cv2.inRange(center_hsv, (160, 100, 100), (180, 255, 255))
|
| 306 |
+
center_red = cv2.bitwise_or(center_red1, center_red2)
|
| 307 |
+
|
| 308 |
+
if np.sum(center_red > 0) + np.sum(center_orange > 0) > 0.1 * center.size:
|
| 309 |
+
label = "Loose Joint (Faulty)"
|
| 310 |
+
box_list.append((w // 4, h // 4, w // 2, h // 2))
|
| 311 |
+
label_list.append(label)
|
| 312 |
+
elif np.sum(center_yellow > 0) > 0.1 * center.size:
|
| 313 |
+
label = "Loose Joint (Potential)"
|
| 314 |
+
box_list.append((w // 4, h // 4, w // 2, h // 2))
|
| 315 |
+
label_list.append(label)
|
| 316 |
+
|
| 317 |
+
# Tiny spots (always check)
|
| 318 |
+
min_area_tiny, max_area_tiny = 10, 30
|
| 319 |
+
for mask, spot_label in [
|
| 320 |
+
(red_mask, "Tiny Faulty Spot"),
|
| 321 |
+
(yellow_mask, "Tiny Potential Spot"),
|
| 322 |
+
]:
|
| 323 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 324 |
for cnt in contours:
|
| 325 |
area = cv2.contourArea(cnt)
|
| 326 |
+
if min_area_tiny < area < max_area_tiny:
|
| 327 |
x, y, w, h = cv2.boundingRect(cnt)
|
| 328 |
box_list.append((x, y, w, h))
|
| 329 |
+
label_list.append(spot_label)
|
| 330 |
+
|
| 331 |
+
# Detect wire-shaped (long/thin) warm regions
|
| 332 |
+
aspect_ratio_thresh = 5
|
| 333 |
+
min_strip_area = 0.01 * total
|
| 334 |
+
wire_boxes, wire_labels = [], []
|
| 335 |
+
for mask, strip_label in [
|
| 336 |
+
(red_mask, "Wire Overload (Red Strip)"),
|
| 337 |
+
(yellow_mask, "Wire Overload (Yellow Strip)"),
|
| 338 |
+
(orange_mask, "Wire Overload (Orange Strip)"),
|
| 339 |
+
]:
|
| 340 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 341 |
+
for cnt in contours:
|
| 342 |
+
area = cv2.contourArea(cnt)
|
| 343 |
+
if area > min_strip_area:
|
| 344 |
+
x, y, w, h = cv2.boundingRect(cnt)
|
| 345 |
+
aspect_ratio = max(w, h) / (min(w, h) + 1e-6)
|
| 346 |
+
if aspect_ratio > aspect_ratio_thresh:
|
| 347 |
+
wire_boxes.append((x, y, w, h))
|
| 348 |
+
wire_labels.append(strip_label)
|
| 349 |
+
|
| 350 |
+
# Prioritize wire boxes first
|
| 351 |
+
box_list = wire_boxes[:] + box_list
|
| 352 |
+
label_list = wire_labels[:] + label_list
|
| 353 |
+
|
| 354 |
+
# Final pruning/merging
|
| 355 |
+
box_list, label_list = _nms_iou(box_list, label_list, iou_thresh=0.4)
|
| 356 |
+
box_list, label_list = _filter_faulty_inside_potential(box_list, label_list)
|
| 357 |
+
box_list, label_list = _filter_faulty_overlapping_potential(box_list, label_list)
|
| 358 |
+
box_list, label_list = _merge_close_boxes(box_list, label_list, dist_thresh=100)
|
| 359 |
+
|
| 360 |
print(f"[Classification] Final label: {label}, Boxes found: {len(box_list)}")
|
| 361 |
return label, box_list, label_list, img
|
| 362 |
|