File size: 4,241 Bytes
cb90fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from Data.config import *
import logging


class SortModel:
    """Arrange object detections into a shelf grid and summarise the planogram.

    ``model`` is called directly with an image path; each result it yields
    must expose ``boxes`` whose items carry ``cls``, ``conf`` and ``xyxy``
    (indexable, element-wise ``.item()``), and the model must expose a
    ``names`` mapping from integer class id to class name.
    NOTE(review): this matches an Ultralytics YOLO model — confirm with the caller.
    """

    # Diagnostics go through logging (the module already imports it) rather
    # than unconditional prints to stdout.
    _logger = logging.getLogger(__name__)

    def __init__(self, model, classes_to_delete=classes_to_delete, conf_threshold=det_conf, expected_segments=expected_segments):
        """Store the model and the filtering/validation settings.

        Args:
            model: callable detection model (see the class docstring).
            classes_to_delete: class names that are excluded from the grid
                and counted as "GHW" instead.
            conf_threshold: minimum ``box.conf`` a non-deleted detection
                needs to be kept.
            expected_segments: ordered class names every grid row must follow
                (see ``sequence``).

        The defaults are presumably supplied by the ``Data.config`` star
        import at the top of the module — confirm there.
        """
        self.model = model
        self.classes_to_delete = classes_to_delete
        self.conf_threshold = conf_threshold
        self.expected_segments = expected_segments

    def get_center(self, detection):
        """Return ``(center_x, center_y)`` of a detection.

        ``detection`` is a ``(class_name, x_min, y_min, x_max, y_max)`` tuple.
        """
        _, x_min, y_min, x_max, y_max = detection
        return (x_min + x_max) / 2, (y_min + y_max) / 2

    def sort_and_group_detections(self, detections, x_tolerance_ratio=0.5):
        """Arrange detections into a grid of class names.

        Detections are sorted by center x (then center y) and clustered into
        vertical columns: a new column starts whenever a center x differs from
        the first center x of the current column by more than
        ``x_tolerance_ratio`` times the smallest center x. Each column is
        ordered by center y and padded with ``''`` to the height of the
        tallest column. The result is transposed, so the i-th returned list
        holds the i-th item (by y) of every column.

        Args:
            detections: iterable of ``(class_name, x_min, y_min, x_max, y_max)``.
            x_tolerance_ratio: fraction of the smallest center x used as the
                column-break tolerance; the default ``0.5`` preserves the
                original behaviour.

        Returns:
            A list of lists of class names (``''`` marks padding), or ``[]``
            when there are no detections.
        """
        with_centers = sorted(
            ((d[0], *self.get_center(d)) for d in detections),
            key=lambda item: (item[1], item[2]),
        )
        if not with_centers:
            return []

        # NOTE(review): the tolerance scales with the leftmost detection's x,
        # so it shrinks to 0 when that detection sits at the image's left edge.
        threshold_x = with_centers[0][1] * x_tolerance_ratio
        columns = []
        current_column = []
        anchor_x = with_centers[0][1]

        for detection in with_centers:
            center_x = detection[1]
            if abs(center_x - anchor_x) > threshold_x:
                # Guard: never emit an empty column (possible when the very
                # first comparison trips, e.g. with a negative tolerance).
                if current_column:
                    columns.append(sorted(current_column, key=lambda item: item[2]))
                current_column = []
                anchor_x = center_x
            current_column.append(detection)

        if current_column:
            columns.append(sorted(current_column, key=lambda item: item[2]))

        height = max(len(column) for column in columns)
        padded = [
            [item[0] for item in column] + [''] * (height - len(column))
            for column in columns
        ]
        return [list(row) for row in zip(*padded)]

    @staticmethod
    def sequence(matrix, expected_segments):
        """Check that every row of ``matrix`` follows ``expected_segments``.

        Each row is scanned left to right with a position into
        ``expected_segments`` starting at 0: an item must equal the segment at
        the current position or the next one (which advances the position);
        any other value — including the ``''`` padding produced by
        ``sort_and_group_detections`` — fails the check. A row must also end
        with the position on the last segment.

        Returns:
            False if ``expected_segments`` is empty or any row fails; True
            otherwise (including for an empty ``matrix``).
        """
        if not expected_segments:
            return False

        last_index = len(expected_segments) - 1
        for row in matrix:
            segment_index = 0
            for item in row:
                if item == expected_segments[segment_index]:
                    continue
                if segment_index < last_index and item == expected_segments[segment_index + 1]:
                    segment_index += 1
                else:
                    return False
            if segment_index != last_index:
                return False
        return True

    def _collect_detections(self, results):
        """Split raw model results into kept detections and a deleted-class count.

        Returns:
            ``(detections, deleted_count)`` where ``detections`` is a list of
            ``(class_name, x_min, y_min, x_max, y_max)`` tuples for boxes whose
            class is not in ``classes_to_delete`` and whose confidence is at
            least ``conf_threshold``, and ``deleted_count`` is the number of
            boxes whose class is in ``classes_to_delete``.
        """
        detections = []
        deleted_count = 0
        for result in results:
            for box in result.boxes:
                class_name = self.model.names[int(box.cls[0])]

                if class_name in self.classes_to_delete:
                    # Counted regardless of confidence, unlike kept detections.
                    deleted_count += 1
                    continue

                if box.conf[0] >= self.conf_threshold:
                    x_min, y_min, x_max, y_max = box.xyxy[0]
                    detections.append((class_name, x_min.item(), y_min.item(), x_max.item(), y_max.item()))
        return detections, deleted_count

    def process_image(self, image_path, predicted_class):
        """Run the model on ``image_path`` and summarise the planogram.

        Args:
            image_path: passed unchanged to ``self.model``.
            predicted_class: echoed back under ``"Class Name"``.

        Returns:
            A dict with:
              - ``"Planogram Blanks Count"``: number of kept detections;
              - ``"Planogram GHW Count"``: number of ``classes_to_delete`` boxes;
              - ``"Planogram Valid Sequence"``: ``"y"`` if ``sequence`` accepts
                the grid, else ``"n/a"``;
              - ``"Planogram GHW"``: ``"yes"`` if the two counts are equal,
                else ``"no"``;
              - ``"Class Name"``: ``predicted_class``.
        """
        results = self.model(image_path)
        yolo_detections, planogram_ghw_count = self._collect_detections(results)
        planogram_blanks_count = len(yolo_detections)

        grid_matrix = self.sort_and_group_detections(yolo_detections)

        if self._logger.isEnabledFor(logging.DEBUG):
            self._logger.debug("Detections: %s", yolo_detections)
            self._logger.debug(
                "Grid Matrix:\n%s", "\n".join(str(row) for row in grid_matrix)
            )

        planogram_ghw = "yes" if planogram_blanks_count == planogram_ghw_count else "no"
        is_valid_sequence = self.sequence(grid_matrix, self.expected_segments)

        return {
            "Planogram Blanks Count": planogram_blanks_count,
            "Planogram GHW Count": planogram_ghw_count,
            "Planogram Valid Sequence": "y" if is_valid_sequence else "n/a",
            "Planogram GHW": planogram_ghw,
            "Class Name": predicted_class,
        }