File size: 4,241 Bytes
cb90fd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
from Data.config import *
import logging
class SortModel:
    """Arrange object detections into a grid and evaluate planogram checks.

    The wrapped ``model`` is called with an image path. Its results are read
    through ``result.boxes`` (``cls``, ``conf``, ``xyxy``) and ``model.names``.
    This is presumably an Ultralytics YOLO model; confirm against the caller.
    """

    def __init__(self, model, classes_to_delete=classes_to_delete,
                 conf_threshold=det_conf, expected_segments=expected_segments):
        """Store the detector and the filtering/validation settings.

        Args:
            model: callable detector; see ``process_image`` for how it is used.
            classes_to_delete: class names that are counted separately and
                excluded from the grid. Defaults to the config value.
            conf_threshold: minimum confidence for a detection to be kept in
                the grid. Defaults to ``det_conf`` from the config.
            expected_segments: ordered class names that each grid row is
                checked against (see ``sequence``). Defaults to the config value.
        """
        self.model = model
        self.classes_to_delete = classes_to_delete
        self.conf_threshold = conf_threshold
        self.expected_segments = expected_segments

    def get_center(self, detection):
        """Return the ``(x, y)`` centre of a ``(class_name, x_min, y_min, x_max, y_max)`` tuple."""
        _, x_min, y_min, x_max, y_max = detection
        return (x_min + x_max) / 2, (y_min + y_max) / 2

    def sort_and_group_detections(self, detections):
        """Arrange detections into a grid of class names.

        The steps are:

        1. Sort the detections by centre x, then centre y.
        2. Cluster them into vertical groups. A detection starts a new group
           when its centre x differs from the first centre x of the current
           group by more than half of the leftmost detection's centre x.
        3. Order each group by centre y (smallest first).
        4. Right-pad shorter groups with ``''``.
        5. Transpose the result, so that ``matrix[i]`` holds the i-th item of
           every group, left to right.

        Args:
            detections: iterable of ``(class_name, x_min, y_min, x_max, y_max)``.

        Returns:
            A list of lists of class names, with ``''`` marking padding.
            Returns ``[]`` when there are no detections.
        """
        centred = sorted(
            ((d[0], *self.get_center(d)) for d in detections),
            key=lambda item: (item[1], item[2]),
        )
        if not centred:
            return []

        # NOTE(review): the grouping tolerance scales with the absolute x
        # position of the leftmost detection, not with box size, so it
        # depends on image framing. Kept as-is to preserve existing results.
        threshold_x = centred[0][1] * 0.5

        groups = []
        current_group = []
        group_x = centred[0][1]
        for item in centred:
            if abs(item[1] - group_x) > threshold_x:
                groups.append(sorted(current_group, key=lambda d: d[2]))
                current_group = []
                group_x = item[1]
            current_group.append(item)
        if current_group:
            groups.append(sorted(current_group, key=lambda d: d[2]))

        width = max(len(group) for group in groups)
        padded = [[d[0] for d in group] + [''] * (width - len(group)) for group in groups]
        return [list(row) for row in zip(*padded)]

    @staticmethod
    def sequence(matrix, expected_segments):
        """Check every row of ``matrix`` against the ordered ``expected_segments``.

        Each row is scanned left to right, starting at segment 0. An item equal
        to the current segment is accepted. An item equal to the next segment
        advances by exactly one segment. Any other item rejects the row,
        including ``''`` padding unless ``''`` is itself a segment. A row must
        also end on the last segment.

        Returns:
            ``False`` if ``expected_segments`` is empty or any row is rejected.
            Otherwise ``True``, which includes an empty ``matrix``.
        """
        if not expected_segments:
            return False
        last_index = len(expected_segments) - 1
        for row in matrix:
            segment_index = 0
            for item in row:
                if item == expected_segments[segment_index]:
                    continue
                if segment_index < last_index and item == expected_segments[segment_index + 1]:
                    segment_index += 1
                else:
                    return False
            if segment_index != last_index:
                return False
        return True

    def process_image(self, image_path, predicted_class):
        """Run the detector on one image and summarise the planogram checks.

        Detection handling:

        - Boxes whose class is in ``classes_to_delete`` are counted into
          "Planogram GHW Count", regardless of confidence, and are dropped.
        - Other boxes with confidence >= ``conf_threshold`` are kept.
        - "Planogram Blanks Count" is the number of kept boxes.

        Args:
            image_path: input passed straight to ``self.model``.
            predicted_class: value echoed back under "Class Name".

        Returns:
            A dict with these keys:

            - "Planogram Blanks Count"
            - "Planogram GHW Count"
            - "Planogram Valid Sequence": "y" or "n/a"
            - "Planogram GHW": "yes" when the two counts are equal, else "no"
            - "Class Name"
        """
        logger = logging.getLogger(__name__)
        results = self.model(image_path)

        planogram_ghw_count = 0
        yolo_detections = []
        for result in results:
            for box in result.boxes:
                class_name = self.model.names[int(box.cls[0])]
                if class_name in self.classes_to_delete:
                    planogram_ghw_count += 1
                    continue
                if box.conf[0] >= self.conf_threshold:
                    x_min, y_min, x_max, y_max = box.xyxy[0]
                    # .item() converts the tensor scalars to plain Python numbers.
                    yolo_detections.append(
                        (class_name, x_min.item(), y_min.item(), x_max.item(), y_max.item())
                    )
        planogram_blanks_count = len(yolo_detections)

        grid_matrix = self.sort_and_group_detections(yolo_detections)
        # Diagnostics go to the logger instead of stray prints on stdout.
        logger.debug("Kept detections: %s", yolo_detections)
        logger.debug("Grid matrix:\n%s", "\n".join(str(row) for row in grid_matrix))

        planogram_ghw = "yes" if planogram_blanks_count == planogram_ghw_count else "no"
        is_valid_sequence = self.sequence(grid_matrix, self.expected_segments)
        return {
            "Planogram Blanks Count": planogram_blanks_count,
            "Planogram GHW Count": planogram_ghw_count,
            "Planogram Valid Sequence": "y" if is_valid_sequence else "n/a",
            "Planogram GHW": planogram_ghw,
            "Class Name": predicted_class,
        }
|