# BAT / note.py
# Author: aiyubali
# Commit cb90fd0: "planogram updated"
from Data.config import *
import logging
class SortModel:
    """Arrange detector output into a grid and check it against a planogram.

    Boxes whose class name is in ``classes_to_delete`` are only counted; the
    remaining boxes at or above ``conf_threshold`` are bucketed into columns by
    horizontal position, ordered top to bottom inside each column, and the
    result is transposed so each matrix row reads left to right.
    """

    # Used instead of print() for the diagnostic dumps in process_image.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        model,
        classes_to_delete=classes_to_delete,
        conf_threshold=det_conf,
        expected_segments=expected_segments,
    ):
        """Store the detector and the planogram-checking configuration.

        Args:
            model: Detector callable. ``model(image_path)`` must return an
                iterable of results exposing ``.boxes`` (each box with
                ``cls``, ``conf`` and ``xyxy``), and ``model.names`` must map
                class ids to class names -- presumably an Ultralytics YOLO
                model; confirm against the caller.
            classes_to_delete: Class names that are counted into
                "Planogram GHW Count" and kept out of the grid.
            conf_threshold: Minimum confidence for any other box to be kept.
            expected_segments: Ordered class names each grid row must follow
                (see ``sequence``).

        The defaults are module-level names star-imported at the top of the
        file (presumably from ``Data.config``).
        """
        self.model = model
        self.classes_to_delete = classes_to_delete
        self.conf_threshold = conf_threshold
        self.expected_segments = expected_segments

    def get_center(self, detection):
        """Return ``(center_x, center_y)`` of a
        ``(class_name, x_min, y_min, x_max, y_max)`` detection tuple."""
        _, x_min, y_min, x_max, y_max = detection
        return (x_min + x_max) / 2, (y_min + y_max) / 2

    def sort_and_group_detections(self, detections):
        """Arrange detections into a row-major grid of class names.

        Detections are sorted by center x (then center y) and split into
        columns: a new column starts whenever a center x differs from the x of
        the current column's first detection by more than ``threshold_x``.
        Each column is ordered top to bottom, shorter columns are padded with
        ``''``, and the columns are transposed so ``matrix[r][c]`` is the r-th
        item from the top of the c-th column from the left.

        Args:
            detections: Iterable of ``(class_name, x_min, y_min, x_max, y_max)``.

        Returns:
            List of rows (lists of class names); ``[]`` when there are no
            detections.
        """
        centered = sorted(
            ((d[0], *self.get_center(d)) for d in detections),
            key=lambda item: (item[1], item[2]),
        )
        if not centered:
            return []

        # NOTE(review): the split tolerance is half of the left-most center x,
        # so it depends on where the shelf sits in the frame (a left-most box
        # near x=0 turns every detection into its own column). Kept as-is;
        # confirm this is intended.
        threshold_x = centered[0][1] * 0.5

        columns = []
        current = []
        anchor_x = centered[0][1]
        for item in centered:
            center_x = item[1]
            if abs(center_x - anchor_x) > threshold_x:
                columns.append(sorted(current, key=lambda d: d[2]))
                current = []
                anchor_x = center_x
            current.append(item)
        if current:
            columns.append(sorted(current, key=lambda d: d[2]))

        height = max((len(col) for col in columns), default=0)
        padded = [
            [d[0] for d in col] + [''] * (height - len(col))
            for col in columns
        ]
        return [list(row) for row in zip(*padded)]

    @staticmethod
    def sequence(matrix, expected_segments):
        """Return True if every row of ``matrix`` follows ``expected_segments``.

        Each row is scanned left to right with a cursor on
        ``expected_segments``. An item must equal the current segment, or the
        next one (which advances the cursor); any other item fails the check.
        A row passes only if the cursor ends on the last segment. As a result
        segment 0 may be absent from a row, but every later segment must
        appear, in order.

        Returns False when ``expected_segments`` is empty.

        NOTE(review): an empty ``matrix`` (no detections) returns True, and
        the ``''`` padding produced by ``sort_and_group_detections`` fails any
        row it lands in unless ``''`` is itself a segment. Confirm both are
        intended.
        """
        if not expected_segments:
            return False
        last = len(expected_segments) - 1
        for row in matrix:
            segment_index = 0
            for item in row:
                if item == expected_segments[segment_index]:
                    continue
                if segment_index < last and item == expected_segments[segment_index + 1]:
                    segment_index += 1
                else:
                    return False
            if segment_index != last:
                return False
        return True

    def process_image(self, image_path, predicted_class):
        """Run the detector on one image and summarise the planogram checks.

        Args:
            image_path: Passed unchanged to ``self.model``.
            predicted_class: Echoed back under "Class Name".

        Returns:
            Dict with:
              - "Planogram Blanks Count": number of kept boxes (class not in
                ``classes_to_delete`` and confidence >= ``conf_threshold``).
              - "Planogram GHW Count": number of boxes whose class is in
                ``classes_to_delete`` (no confidence filter is applied).
              - "Planogram Valid Sequence": "y" if ``sequence`` accepts the
                grid of kept boxes, otherwise "n/a".
              - "Planogram GHW": "yes" if the two counts above are equal,
                otherwise "no".
              - "Class Name": ``predicted_class``.
        """
        results = self.model(image_path)

        planogram_ghw_count = 0
        yolo_detections = []
        for result in results:
            for box in result.boxes:
                class_name = self.model.names[int(box.cls[0])]
                if class_name in self.classes_to_delete:
                    # Counted before the confidence filter and never gridded.
                    planogram_ghw_count += 1
                    continue
                if box.conf[0] >= self.conf_threshold:
                    x_min, y_min, x_max, y_max = box.xyxy[0]
                    yolo_detections.append(
                        (class_name, x_min.item(), y_min.item(), x_max.item(), y_max.item())
                    )
        planogram_blanks_count = len(yolo_detections)

        grid_matrix = self.sort_and_group_detections(yolo_detections)
        if self._logger.isEnabledFor(logging.DEBUG):
            self._logger.debug("Detections: %s", yolo_detections)
            self._logger.debug(
                "Grid Matrix:\n%s", "\n".join(str(row) for row in grid_matrix)
            )

        is_valid_sequence = self.sequence(grid_matrix, self.expected_segments)
        return {
            "Planogram Blanks Count": planogram_blanks_count,
            "Planogram GHW Count": planogram_ghw_count,
            "Planogram Valid Sequence": "y" if is_valid_sequence else "n/a",
            "Planogram GHW": "yes" if planogram_blanks_count == planogram_ghw_count else "no",
            "Class Name": predicted_class,
        }