from Data.config import *
import logging

logger = logging.getLogger(__name__)


class SortModel:
    """Arrange detector output into a grid and check it against an expected sequence.

    ``self.model`` must be callable on an image path and must expose a ``names``
    mapping from class index to class name. The results it returns must have a
    ``boxes`` attribute whose items carry ``cls``, ``conf`` and ``xyxy``.
    NOTE(review): this matches the Ultralytics YOLO API; confirm against the caller.
    """

    def __init__(self, model, classes_to_delete=classes_to_delete,
                 conf_threshold=det_conf, expected_segments=expected_segments):
        """Store the model and the filtering/validation settings.

        Args:
            model: Callable detection model that also exposes ``names``.
            classes_to_delete: Class names that are counted separately (the
                "GHW" count) and never placed in the grid. Defaults to
                ``classes_to_delete`` from ``Data.config``.
            conf_threshold: Minimum box confidence for a detection to be kept.
                Defaults to ``det_conf`` from ``Data.config``.
            expected_segments: Ordered class names that every grid row must
                follow. Defaults to ``expected_segments`` from ``Data.config``.
        """
        self.model = model
        self.classes_to_delete = classes_to_delete
        self.conf_threshold = conf_threshold
        self.expected_segments = expected_segments

    def get_center(self, detection):
        """Return ``(center_x, center_y)`` of a detection tuple.

        Args:
            detection: ``(class_name, x_min, y_min, x_max, y_max)``.
        """
        _, x_min, y_min, x_max, y_max = detection
        return (x_min + x_max) / 2, (y_min + y_max) / 2

    def sort_and_group_detections(self, detections):
        """Group detections into a grid of class names.

        Detections are first grouped into columns by horizontal center. Each
        column is ordered top-to-bottom, and shorter columns are padded with
        ``''`` at the bottom. The column grid is then transposed, so the
        result is a list of rows, each listing one class name per column
        from left to right.

        Args:
            detections: Iterable of ``(class_name, x_min, y_min, x_max, y_max)``.

        Returns:
            A list of rows (lists of class names). Returns ``[]`` when there
            are no detections.
        """
        # (class_name, center_x, center_y), ordered left-to-right, then top-to-bottom.
        centered = sorted(
            ((d[0], *self.get_center(d)) for d in detections),
            key=lambda item: (item[1], item[2]),
        )
        if not centered:
            return []

        # A detection joins the current column while its center_x stays within
        # this distance of the column's first member.
        # NOTE(review): the tolerance is half the leftmost center_x, so it
        # scales with image position rather than box width. Confirm this
        # heuristic is intended.
        threshold_x = centered[0][1] * 0.5

        columns = []
        current_column = []
        anchor_x = centered[0][1]
        for detection in centered:
            center_x = detection[1]
            # Only close a non-empty column. Otherwise a negative tolerance
            # would emit an empty column of pure '' padding.
            if current_column and abs(center_x - anchor_x) > threshold_x:
                columns.append(sorted(current_column, key=lambda item: item[2]))
                current_column = []
                anchor_x = center_x
            current_column.append(detection)
        if current_column:
            columns.append(sorted(current_column, key=lambda item: item[2]))

        height = max(len(column) for column in columns)
        padded = [
            [d[0] for d in column] + [''] * (height - len(column))
            for column in columns
        ]
        return [list(row) for row in zip(*padded)]

    @staticmethod
    def sequence(matrix, expected_segments):
        """Check that every row walks through ``expected_segments`` in order.

        Within a row, an item may repeat the current segment or advance to the
        next one. Skipping a segment or going backwards fails the check. Each
        row must end on the last segment.

        Args:
            matrix: List of rows of class names, as returned by
                ``sort_and_group_detections``.
            expected_segments: Ordered segment names.

        Returns:
            ``False`` if ``expected_segments`` is empty or any row breaks the
            order. Otherwise ``True``, which includes an empty ``matrix``.
            NOTE(review): the ``''`` padding never matches a segment, so any
            grid with uneven columns is reported invalid. Confirm this is intended.
        """
        if not expected_segments:
            return False
        last_index = len(expected_segments) - 1
        for row in matrix:
            segment_index = 0
            for item in row:
                if item == expected_segments[segment_index]:
                    continue
                if segment_index < last_index and item == expected_segments[segment_index + 1]:
                    segment_index += 1
                else:
                    return False
            if segment_index != last_index:
                return False
        return True

    def _collect_detections(self, results):
        """Split model results into kept detections and a count of deleted classes.

        Returns:
            ``(detections, deleted_count)``. ``detections`` is a list of
            ``(class_name, x_min, y_min, x_max, y_max)`` with confidence at or
            above ``self.conf_threshold``. ``deleted_count`` counts boxes whose
            class is in ``self.classes_to_delete``, regardless of confidence.
        """
        detections = []
        deleted_count = 0
        for result in results:
            for box in result.boxes:
                class_name = self.model.names[int(box.cls[0])]
                if class_name in self.classes_to_delete:
                    deleted_count += 1
                    continue
                if box.conf[0] >= self.conf_threshold:
                    x_min, y_min, x_max, y_max = box.xyxy[0]
                    detections.append(
                        (class_name, x_min.item(), y_min.item(), x_max.item(), y_max.item())
                    )
        return detections, deleted_count

    def process_image(self, image_path, predicted_class):
        """Run detection on one image and summarise the planogram.

        Args:
            image_path: Image passed straight to ``self.model``.
            predicted_class: Value echoed back under ``"Class Name"``.

        Returns:
            A dict with these keys:

            - ``"Planogram Blanks Count"``: number of kept (non-deleted,
              confident) detections.
            - ``"Planogram GHW Count"``: number of deleted-class boxes.
            - ``"Planogram Valid Sequence"``: ``"y"`` or ``"n/a"``.
            - ``"Planogram GHW"``: ``"yes"`` when the two counts are equal,
              otherwise ``"no"``.
            - ``"Class Name"``: ``predicted_class``.
        """
        results = self.model(image_path)
        yolo_detections, planogram_ghw_count = self._collect_detections(results)
        planogram_blanks_count = len(yolo_detections)

        grid_matrix = self.sort_and_group_detections(yolo_detections)
        # Diagnostics go through logging instead of unconditional stdout prints.
        logger.debug("Kept detections: %s", yolo_detections)
        logger.debug("Grid Matrix:\n%s", "\n".join(str(row) for row in grid_matrix))

        planogram_ghw = "yes" if planogram_blanks_count == planogram_ghw_count else "no"
        is_valid = self.sequence(grid_matrix, self.expected_segments)
        # NOTE(review): an invalid sequence is reported as "n/a", not "n"; kept for callers.
        planogram_valid_sequence = "y" if is_valid else "n/a"

        return {
            "Planogram Blanks Count": planogram_blanks_count,
            "Planogram GHW Count": planogram_ghw_count,
            "Planogram Valid Sequence": planogram_valid_sequence,
            "Planogram GHW": planogram_ghw,
            "Class Name": predicted_class,
        }