# anomaly-detection-api/scripts/classify_single_image_opencv.py
# Senum2001 - "Deploy Anomaly Detection API" (9cf599c)
import cv2
import numpy as np
import os
# Input directory of filtered images to classify, and the output directory
# that receives the annotated copies (created up front so later writes into
# it have a destination).
dir_path, output_dir = (
    'api_inference_filtered_pipeline',
    'api_inference_labeled_boxes_pipeline',
)
os.makedirs(output_dir, exist_ok=True)
# IOU function for non-max suppression
def iou(boxA, boxB):
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
xB = min(boxA[0]+boxA[2], boxB[0]+boxB[2])
yB = min(boxA[1]+boxA[3], boxB[1]+boxB[3])
interW = max(0, xB - xA)
interH = max(0, yB - yA)
interArea = interW * interH
boxAArea = boxA[2] * boxA[3]
boxBArea = boxB[2] * boxB[3]
iou = interArea / float(boxAArea + boxBArea - interArea + 1e-6)
return iou
def merge_close_boxes(boxes, labels, dist_thresh=20):
    """Greedily merge same-label boxes whose centres are close to a seed box.

    Boxes are visited in order. Each not-yet-merged box becomes a seed and
    absorbs every later, unmerged box with the same label whose (integer)
    centre is within ``dist_thresh`` pixels of the seed's centre along both
    axes (Chebyshev distance, strict ``<``). The seed and everything it
    absorbs are replaced by their common bounding box.

    Args:
        boxes: sequence of ``(x, y, w, h)`` boxes.
        labels: sequence of labels, parallel to ``boxes``.
        dist_thresh: per-axis centre distance below which boxes merge.

    Returns:
        ``(merged_boxes, merged_labels)`` as two parallel lists.
    """
    merged = []
    merged_labels = []
    used = [False] * len(boxes)
    for i, (x1, y1, w1, h1) in enumerate(boxes):
        if used[i]:
            continue
        used[i] = True
        label1 = labels[i]
        # Proximity is always measured from the seed's centre (not from the
        # growing union), so merging is not transitive along chains of boxes.
        cx1, cy1 = x1 + w1 // 2, y1 + h1 // 2
        # Track the union as edges so every absorbed box widens the result;
        # the previous code recomputed width/height from the seed alone and
        # lost the extent of earlier merges.
        left, top, right, bottom = x1, y1, x1 + w1, y1 + h1
        for j in range(i + 1, len(boxes)):
            if used[j] or labels[j] != label1:
                continue
            bx, by, bw, bh = boxes[j]
            cx2, cy2 = bx + bw // 2, by + bh // 2
            if abs(cx1 - cx2) < dist_thresh and abs(cy1 - cy2) < dist_thresh:
                left = min(left, bx)
                top = min(top, by)
                right = max(right, bx + bw)
                bottom = max(bottom, by + bh)
                used[j] = True
        merged.append((left, top, right - left, bottom - top))
        merged_labels.append(label1)
    return merged, merged_labels
def non_max_suppression_iou(boxes, labels, iou_thresh=0.4):
    """Greedy, label-agnostic non-max suppression that prefers larger boxes.

    Candidates are ranked by area, largest first. The top candidate is kept,
    and every remaining candidate whose IoU with it exceeds ``iou_thresh`` is
    discarded regardless of label. This repeats until no candidates remain.

    Returns:
        ``(kept_boxes, kept_labels)`` as two parallel lists.
    """
    if len(boxes) == 0:
        return [], []
    areas = [w * h for (x, y, w, h) in boxes]
    # Descending-area order; np.argsort also fixes how ties are ordered.
    pending = list(np.argsort(areas)[::-1])
    kept_boxes, kept_labels = [], []
    while pending:
        best, rest = pending[0], pending[1:]
        kept_boxes.append(boxes[best])
        kept_labels.append(labels[best])
        pending = [k for k in rest if not iou(boxes[best], boxes[k]) > iou_thresh]
    return kept_boxes, kept_labels
def filter_faulty_inside_potential(boxes, labels):
    """Drop 'Point Overload (Potential)' boxes that fully contain a faulty box.

    A 'Point Overload (Faulty)' box counts as inside when all four of its
    edges lie within the potential box (shared edges allowed). Every other
    box passes through unchanged, in its original order.

    Returns:
        ``(boxes, labels)`` as two parallel lists.
    """
    faulty = [fb for fb, fl in zip(boxes, labels) if fl == 'Point Overload (Faulty)']

    def contains_faulty(outer):
        x, y, w, h = outer
        return any(
            fx >= x and fy >= y and fx + fw <= x + w and fy + fh <= y + h
            for fx, fy, fw, fh in faulty
        )

    out_boxes, out_labels = [], []
    for box, label in zip(boxes, labels):
        if label == 'Point Overload (Potential)' and contains_faulty(box):
            continue
        out_boxes.append(box)
        out_labels.append(label)
    return out_boxes, out_labels
def filter_faulty_overlapping_potential(boxes, labels):
    """Drop 'Point Overload (Potential)' boxes that intersect any faulty box.

    Overlap requires a strictly positive intersection area with some
    'Point Overload (Faulty)' box. Boxes that merely share an edge or corner
    do not count. Every other box passes through unchanged, in its original
    order.

    Returns:
        ``(boxes, labels)`` as two parallel lists.
    """
    def _intersects(a, b):
        # The overlap must be non-empty along both axes.
        return (min(a[0] + a[2], b[0] + b[2]) > max(a[0], b[0])
                and min(a[1] + a[3], b[1] + b[3]) > max(a[1], b[1]))

    faulty_boxes = [fb for fb, fl in zip(boxes, labels) if fl == 'Point Overload (Faulty)']
    out_boxes, out_labels = [], []
    for box, label in zip(boxes, labels):
        if label == 'Point Overload (Potential)' and any(_intersects(box, fb) for fb in faulty_boxes):
            continue
        out_boxes.append(box)
        out_labels.append(label)
    return out_boxes, out_labels
def _contour_boxes(mask, min_area, max_area=float('inf')):
    """Return bounding rects of ``mask``'s external contours whose area lies
    strictly between ``min_area`` and ``max_area``."""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return [
        cv2.boundingRect(cnt)
        for cnt in contours
        if min_area < cv2.contourArea(cnt) < max_area
    ]


def _center_joint_label(red_mask, orange_mask, yellow_mask, frac=0.1):
    """Return a 'Loose Joint' label if the central region is hot, else None.

    The central region is the middle half of each axis. It is 'Faulty' when
    red plus orange pixels exceed ``frac`` of the region's pixels. Otherwise
    it is 'Potential' when yellow pixels do.
    """
    h, w = red_mask.shape[:2]
    region = (slice(h // 4, 3 * h // 4), slice(w // 4, 3 * w // 4))
    # Compare pixel counts against the region's *pixel* count. The previous
    # code used the 3-channel BGR crop's ``.size``, which tripled the threshold.
    threshold = frac * red_mask[region].size
    hot = np.count_nonzero(red_mask[region]) + np.count_nonzero(orange_mask[region])
    if hot > threshold:
        return 'Loose Joint (Faulty)'
    if np.count_nonzero(yellow_mask[region]) > threshold:
        return 'Loose Joint (Potential)'
    return None


def classify_image(img_path):
    """Classify one image with colour heuristics and locate hot regions.

    The image is thresholded in HSV into blue, black, yellow, orange and red
    masks. An image-level ``label`` is chosen from mask coverage: 'Normal',
    'Full Wire Overload', 'Loose Joint (Faulty)', 'Loose Joint (Potential)'
    or 'Unknown'. Separately, boxes are collected for long, thin hot regions
    (wire strips) and compact hot spots (point overloads), then de-duplicated
    and merged.

    Args:
        img_path: path of the image to read with ``cv2.imread``.

    Returns:
        ``(label, box_list, label_list, img)``:
        - ``box_list`` holds ``(x, y, w, h)`` boxes.
        - ``label_list`` holds the label for each box.
        - ``img`` is the BGR image that was analysed.

    Raises:
        OSError: if the image cannot be read.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"cv2.imread could not read image: {img_path}")
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    # Colour masks (OpenCV HSV: H in [0, 180], S/V in [0, 255]).
    blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
    black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
    yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))  # stricter S/V
    orange_mask = cv2.inRange(hsv, (10, 100, 100), (25, 255, 255))
    # Red sits at both ends of the hue range, so combine the two bands.
    red_mask = cv2.bitwise_or(
        cv2.inRange(hsv, (0, 100, 100), (10, 255, 255)),
        cv2.inRange(hsv, (160, 100, 100), (180, 255, 255)),
    )

    img_h, img_w = img.shape[:2]
    total = img_h * img_w
    blue_count = np.count_nonzero(blue_mask)
    black_count = np.count_nonzero(black_mask)
    yellow_count = np.count_nonzero(yellow_mask)
    orange_count = np.count_nonzero(orange_mask)
    red_count = np.count_nonzero(red_mask)

    # --- Image-level label ---
    label = 'Unknown'
    if (blue_count + black_count) / total > 0.8:
        label = 'Normal'
    elif (red_count + orange_count) / total > 0.5 or yellow_count / total > 0.5:
        label = 'Full Wire Overload'
    # The hue ranges overlap (e.g. orange and yellow at H 20-25), so this sum
    # can count a pixel more than once.
    full_wire_thresh = 0.7
    if (red_count + orange_count + yellow_count) / total > full_wire_thresh:
        label = 'Full Wire Overload'
    else:
        center_label = _center_joint_label(red_mask, orange_mask, yellow_mask)
        if center_label is not None:
            label = center_label

    # --- Localised detections ---
    # NOTE(review): only wire strips and point overloads are returned as
    # boxes; image-level labels (Full Wire Overload, Loose Joint) get no box.
    box_list = []
    label_list = []

    # Long, thin hot regions -> wire overload strips.
    aspect_ratio_thresh = 5
    min_strip_area = 0.01 * total
    for mask, strip_label in (
        (red_mask, 'Wire Overload (Red Strip)'),
        (yellow_mask, 'Wire Overload (Yellow Strip)'),
        (orange_mask, 'Wire Overload (Orange Strip)'),
    ):
        for x, y, w, h in _contour_boxes(mask, min_strip_area):
            if max(w, h) / (min(w, h) + 1e-6) > aspect_ratio_thresh:
                box_list.append((x, y, w, h))
                label_list.append(strip_label)

    # Compact hot spots of any shape -> point overloads. Yellow needs a much
    # larger area than red, and blobs above 5% of the image are excluded.
    max_area = 0.05 * total
    for mask, spot_label, min_area in (
        (red_mask, 'Point Overload (Faulty)', 120),
        (yellow_mask, 'Point Overload (Potential)', 1000),
    ):
        for box in _contour_boxes(mask, min_area, max_area):
            box_list.append(box)
            label_list.append(spot_label)

    # Clean-up: suppress overlaps, drop potential boxes shadowed by faulty
    # ones, then merge nearby same-label boxes.
    box_list, label_list = non_max_suppression_iou(box_list, label_list, iou_thresh=0.4)
    box_list, label_list = filter_faulty_inside_potential(box_list, label_list)
    box_list, label_list = filter_faulty_overlapping_potential(box_list, label_list)
    box_list, label_list = merge_close_boxes(box_list, label_list, dist_thresh=100)
    return label, box_list, label_list, img
# Batch process all images in the directory. Classification runs on the
# filtered images in ``dir_path``. Boxes are drawn on the matching original
# image from ``orig_dir`` when one can be read, otherwise on the filtered image.
orig_dir = 'api_inference_pred_masks'  # or the directory with original images
for fname in sorted(os.listdir(dir_path)):
    if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue
    try:
        label, box_list, label_list, img = classify_image(os.path.join(dir_path, fname))
    except (OSError, cv2.error) as exc:
        # One unreadable image should not abort the whole batch.
        print(f"{fname}: skipped ({exc})")
        continue

    # Load the original (unfiltered) image for drawing boxes.
    # NOTE(review): assumes it has the same dimensions as the filtered image
    # the boxes were computed on -- confirm.
    orig_img_path = os.path.join(orig_dir, fname)
    draw_img = cv2.imread(orig_img_path) if os.path.exists(orig_img_path) else None
    on_original = draw_img is not None
    if not on_original:
        draw_img = img.copy()

    for (x, y, w, h), box_label in zip(box_list, label_list):
        cv2.rectangle(draw_img, (x, y), (x + w, y + h), (0, 0, 255), 2)
        # Keep captions on-canvas for boxes touching the top edge.
        cv2.putText(draw_img, box_label, (x, max(y - 10, 15)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
    if not box_list:
        cv2.putText(draw_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

    out_path = os.path.join(output_dir, fname)
    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(out_path, draw_img):
        print(f"{fname}: {label} (failed to write {out_path})")
        continue
    source = 'original' if on_original else 'filtered'
    print(f"{fname}: {label} (saved with boxes on {source} image)")