| import cv2 | |
| import matplotlib.pyplot as plt | |
| from PIL import ImageColor | |
| from pathlib import Path | |
| import os | |
# RGB colour (red) used when a class has no explicit colour configured.
_DEFAULT_BOX_COLOR = (255, 0, 0)


def _to_rgb(color):
    """Return ``color`` as an RGB tuple, accepting a colour string or a channel sequence."""
    if isinstance(color, str):
        return ImageColor.getcolor(color, 'RGB')
    return tuple(int(channel) for channel in color)


def annotate_image_prediction(image_path, yolo_boxes, class_dic, saving_folder, hex_class_colors=None, show=False, true_count=False, saving_image_name=None, put_title=True, box_thickness=3, font_scale=1, font_thickness=5):
    """
    Draw YOLO predicted boxes on a single image and save the annotated result.

    Args:
        image_path (str or os.PathLike): path to the image to label
        yolo_boxes (sequence): predicted boxes; each item must expose ``xyxy[0]``
            (x1, y1, x2, y2 in pixels) and ``cls[0]`` (predicted class id)
        class_dic (dict): dictionary with predicted class id as key and corresponding label as value
        saving_folder (str): folder where to save the annotated image (created if missing)
        hex_class_colors (dict, optional): maps a class label to its colour, either as a
            colour string understood by ``PIL.ImageColor`` (e.g. ``'#FF0000'``) or as an
            RGB tuple. Labels without an entry are drawn in red. Defaults to None (all red).
        show (bool, optional): If you want a window of the annotated image to pop up. Defaults to False.
        true_count (int, optional): true total count to display next to the predicted one.
            ``False`` or ``None`` means "not provided"; ``0`` is a valid count. Defaults to False.
        saving_image_name (str, optional): Name of the annotated image to save.
            Defaults to None, which gives ``annotated_<original file name>``.
        put_title (bool, optional): If you want the counts to be written on the image. Defaults to True.
        box_thickness (int, optional): Thickness of the bounding boxes to plot. Defaults to 3.
        font_scale (int, optional): Font scale of the text of counts to be displayed. Defaults to 1.
        font_thickness (int, optional): Font thickness of the text of counts to be displayed. Defaults to 5.

    Returns:
        str or None: saving path of the annotated image, or None (after printing a
        warning) when ``image_path`` is not an existing file or cannot be decoded.
    """
    if not os.path.isfile(image_path):
        print(f'WARNING: {image_path} does not exists')
        return None

    # cv2.imread signals an undecodable file by returning None instead of raising.
    img = cv2.imread(os.fspath(image_path))
    if img is None:
        print(f'WARNING: {image_path} could not be read as an image')
        return None
    # OpenCV loads BGR; matplotlib (imshow / imsave) expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    dh, dw = img.shape[:2]

    # Class id -> RGB tuple. Missing labels fall back to the default colour.
    class_colors = hex_class_colors or {}
    color_map = {
        class_id: _to_rgb(class_colors.get(label, _DEFAULT_BOX_COLOR))
        for class_id, label in class_dic.items()
    }

    for yolo_box in yolo_boxes:
        x1, y1, x2, y2 = (int(coord) for coord in yolo_box.xyxy[0])
        class_id = int(yolo_box.cls[0])
        box_color = color_map.get(class_id, _DEFAULT_BOX_COLOR)
        cv2.rectangle(img, (x1, y1), (x2, y2), box_color, box_thickness)

    if show:
        # The preview shows the boxes only; the title is added afterwards.
        plt.imshow(img)
        plt.show()

    # The title goes on a copy so the array handed to plt.imshow stays untouched.
    img_copy = img.copy()
    if put_title:
        predicted_count = len(yolo_boxes)
        # Identity checks (not truthiness) so that a true count of 0 is displayed.
        if true_count is not None and true_count is not False:
            title = f'Predicted count: {predicted_count}, true count: {true_count}, delta: {predicted_count - true_count}'
        else:
            title = f'Predicted count: {predicted_count}'
        cv2.putText(
            img=img_copy,
            text=title,
            org=(int(0.1 * dw), int(0.1 * dh)),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=font_scale,
            thickness=font_thickness,
            color=(255, 251, 5),
        )

    if not saving_image_name:
        # basename handles the platform's separators and PathLike inputs.
        saving_image_name = f'annotated_{os.path.basename(image_path)}'
    Path(saving_folder).mkdir(parents=True, exist_ok=True)
    full_saving_path = os.path.join(saving_folder, saving_image_name)
    plt.imsave(full_saving_path, img_copy)
    return full_saving_path