sunjuice's picture
initial commit
a88f841
def iou(box1, box2):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Each box is indexed as ``(center_x, center_y, width, height)``; only the
    first four elements are read, so longer sequences (e.g. with a trailing
    score) are accepted.

    Args:
        box1: First box, ``(cx, cy, w, h)``.
        box2: Second box, ``(cx, cy, w, h)``.

    Returns:
        Intersection area divided by union area. For boxes with non-negative
        width and height this lies in ``[0, 1]``. When the union area is not
        positive (e.g. both boxes have zero width or height) ``0.0`` is
        returned instead of raising ``ZeroDivisionError``.
    """

    def _corners(box):
        # Convert centre/size form to (x_min, y_min, x_max, y_max).
        half_w = box[2] / 2
        half_h = box[3] / 2
        return box[0] - half_w, box[1] - half_h, box[0] + half_w, box[1] + half_h

    x1_min, y1_min, x1_max, y1_max = _corners(box1)
    x2_min, y2_min, x2_max, y2_max = _corners(box2)

    # Overlap extents; clamp at 0 so disjoint boxes give zero intersection.
    inter_w = max(0.0, min(x1_max, x2_max) - max(x1_min, x2_min))
    inter_h = max(0.0, min(y1_max, y2_max) - max(y1_min, y2_min))
    inter_area = inter_w * inter_h

    area1 = box1[2] * box1[3]
    area2 = box2[2] * box2[3]
    union_area = area1 + area2 - inter_area

    # Degenerate boxes: there is no area to overlap, so report no overlap
    # rather than dividing by zero.
    if union_area <= 0:
        return 0.0
    return inter_area / union_area
def non_max_suppression(boxes, scores, iou_threshold):
    """Greedy non-maximum suppression.

    Repeatedly takes the highest-scoring remaining box, keeps it, and drops
    every other remaining box whose ``iou`` with it is ``>= iou_threshold``.

    Args:
        boxes: Sequence of boxes in the format accepted by ``iou``.
        scores: Per-box scores, indexed in parallel with ``boxes``.
        iou_threshold: Overlap at or above which a lower-scoring box is
            suppressed.

    Returns:
        List of indices into ``boxes`` that survive, in descending score
        order. Equal scores keep their original relative order because the
        sort is stable.
    """
    # Candidate indices, best score first.
    remaining = sorted(range(len(boxes)), key=lambda idx: scores[idx], reverse=True)
    kept = []
    while remaining:
        best, *rest = remaining
        kept.append(best)
        # Survivors are only those that do not overlap the chosen box too much.
        remaining = [
            idx for idx in rest if iou(boxes[best], boxes[idx]) < iou_threshold
        ]
    return kept