import numpy as np
from PIL import Image
from pathlib import Path

import supervision as sv

# from ultralytics import YOLO
from doclayout_yolo import YOLOv10

from src.config import YOLO_MODEL_PATH, log
from src.utils import binarize

log.debug("loading YOLO Model")
LAYOUT_MODEL = YOLOv10(YOLO_MODEL_PATH)

# Class names emitted by the layout model, indexed by class_id.
CLASSES = [
    "Caption",
    "Footnote",
    "Formula",
    "List-item",
    "Page-footer",
    "Page-header",
    "Picture",
    "Section-header",
    "Table",
    "Text",
    "Title",
    "Unknown",
]

# Define a custom color palette, one color per class (parallel to CLASSES)
CLASS_COLORS = [
    sv.Color(255, 0, 0),  # Red for "Caption"
    sv.Color(0, 255, 0),  # Green for "Footnote"
    sv.Color(0, 0, 255),  # Blue for "Formula"
    sv.Color(255, 255, 0),  # Yellow for "List-item"
    sv.Color(255, 0, 255),  # Magenta for "Page-footer"
    sv.Color(0, 255, 255),  # Cyan for "Page-header"
    sv.Color(128, 0, 128),  # Purple for "Picture"
    sv.Color(128, 128, 0),  # Olive for "Section-header"
    sv.Color(128, 128, 128),  # Gray for "Table"
    sv.Color(0, 128, 128),  # Teal for "Text"
    sv.Color(128, 0, 0),  # Maroon for "Title"
    sv.Color(255, 255, 255),  # White for "Unknown"
]

# Initialize the BoxAnnotator with the custom color palette and increased thickness
box_annotator = sv.BoxAnnotator(
    color=sv.ColorPalette(CLASS_COLORS),
    thickness=2,  # Increased thickness for bounding boxes
)

# Initialize the LabelAnnotator with custom background and text colors
label_annotator = sv.LabelAnnotator(
    color=sv.ColorPalette(CLASS_COLORS),  # Background colors matching bounding boxes
    text_color=sv.Color(255, 255, 255),  # White text for better readability
)


def detect(img, conf=0.2, iou=0.8, labels=False, plot=False, model=LAYOUT_MODEL):
    """Run layout detection on an image and return the detections.

    Args:
        img: Input image (PIL image or anything the model accepts).
        conf: Minimum confidence threshold for detections.
        iou: IoU threshold used for NMS.
        labels: If True, draw class-name labels on the annotated image.
        plot: If True, display the annotated image via ``sv.plot_image``.
        model: Layout model to run (defaults to the module-level LAYOUT_MODEL).

    Returns:
        sv.Detections for the input image.
    """
    # Object detection on image
    results = model(img, conf=conf, iou=iou, verbose=False)[0]
    # Convert results to detections
    detections = sv.Detections.from_ultralytics(results)
    # Only draw annotations when they will actually be used; this also avoids
    # mutating the caller's image as a side effect when neither flag is set.
    if labels or plot:
        annotated_image = box_annotator.annotate(scene=img, detections=detections)
        if labels:
            # Annotate the image with labels
            annotated_image = label_annotator.annotate(
                scene=annotated_image, detections=detections
            )
        if plot:
            sv.plot_image(annotated_image)
    return detections


def sort_coords(xyxy):
    """Sort detection boxes left-to-right by their x1 coordinate.

    Args:
        xyxy: (N, 4) array of boxes as (x1, y1, x2, y2).

    Returns:
        The rows of ``xyxy`` reordered by ascending x1.
    """
    return xyxy[np.argsort(xyxy[:, 0])]


def get_labels_with_confidence(detections):
    """Return (class_name, confidence) pairs for every detection."""
    return list(
        zip(detections.data["class_name"].tolist(), detections.confidence.tolist())
    )


def rectangle_area(*coords):
    """Area of an axis-aligned rectangle given as (x1, y1, x2, y2)."""
    x1, y1 = coords[:2]
    x2, y2 = coords[2:]
    width = abs(x2 - x1)
    height = abs(y2 - y1)
    return width * height


def get_mask(img_h, img_w, detections, class_colors=CLASS_COLORS, inc=50):
    """Render detections as a flat-color mask image of size (img_w, img_h).

    Boxes are painted in order of decreasing area so that smaller regions
    end up on top of larger overlapping ones.

    Args:
        img_h: Image height in pixels.
        img_w: Image width in pixels.
        detections: sv.Detections with ``xyxy`` and ``class_id``.
        class_colors: Palette indexed by class_id (defaults to CLASS_COLORS).
        inc: Accepted for interface compatibility; currently unused.

    Returns:
        PIL.Image RGB mask.
    """
    size = (img_h, img_w)
    # Paint largest boxes first so smaller boxes overwrite them.
    areas = np.array([float(rectangle_area(*row)) for row in detections.xyxy])
    order = np.argsort(areas)[::-1]
    mask_arr = np.zeros((*size, 3), dtype=np.uint8)
    coords = np.round(detections.xyxy).astype(np.int32)
    for row, class_id in zip(coords[order], detections.class_id[order]):
        x1, y1, x2, y2 = row.tolist()
        # Flip y so painting happens bottom-up; the final vertical flip
        # below restores normal image orientation.
        y1, y2 = size[0] - y1, size[0] - y2
        rgb = class_colors[class_id].as_rgb()
        mask_arr[y2 : y1 + 1, x1 : x2 + 1, :] = rgb
    mask = Image.fromarray(mask_arr[::-1])
    return mask


def merge_masks(masks, weights=None):
    """Blend several equally-sized mask images into one weighted average.

    Args:
        masks: Sequence of PIL images (at least two required).
        weights: Optional per-mask weights; defaults to a uniform average.

    Returns:
        PIL.Image of the weighted blend.

    Raises:
        ValueError: If only a single mask is supplied.
    """
    if len(masks) == 1:
        raise ValueError("more than 1 mask needed to merge")
    arrays = [np.array(mask) for mask in masks]
    if weights is None:
        weights = [1 / len(arrays)] * len(arrays)
    final_mask = np.zeros_like(arrays[0], dtype=float)
    for arr, weight in zip(arrays, weights):
        final_mask += arr * weight
    return Image.fromarray(final_mask.astype("uint8"))


def get_merged_mask(
    fp: Path = None,
    img: Image = None,
    conf=0.2,
    iou=0.8,
    inc=50,
    mask_weights=None,
    resize_dim=512,
    model=LAYOUT_MODEL,
    binarized=False,
):
    """Detect layout on an image (and a grayscale variant), merge the
    per-variant class masks, and return the result resized to a square.

    Exactly one of ``fp`` or ``img`` must be given.

    Args:
        fp: Path to an image file to load.
        img: Already-loaded PIL image.
        conf: Detection confidence threshold, forwarded to ``detect``.
        iou: NMS IoU threshold, forwarded to ``detect``.
        inc: Forwarded to ``get_mask`` (currently unused there).
        mask_weights: Optional per-variant weights; normalized to sum to 1.
        resize_dim: Side length of the square output mask.
        model: Layout model to use.
        binarized: If True, threshold the mask to {0, 255}.

    Returns:
        PIL.Image of shape (resize_dim, resize_dim).

    Raises:
        ValueError: If both or neither of ``fp``/``img`` are provided.
    """
    if (fp is None) == (img is None):
        raise ValueError("only one of `fp` or `img` is required")
    if fp:
        log.debug(f"getting merged mask for file {fp}")
        img = Image.open(fp)
    # Run detection on the original plus a grayscale variant; merging the
    # resulting masks smooths over per-variant detection noise.
    images, masks = [], []
    images.extend(
        [
            img,
            img.convert("L"),
            # any other extra transformations
        ]
    )
    if mask_weights is not None:
        # Accept any sequence, not only numpy arrays, and normalize to sum 1.
        mask_weights = np.asarray(mask_weights, dtype=float)
        mask_weights = mask_weights / mask_weights.sum()
    for image in images:
        dc = detect(image, conf=conf, iou=iou, plot=False, model=model)
        # image.size is (w, h); get_mask expects (h, w).
        masks.append(get_mask(*image.size[::-1], dc, inc=inc))
    merged_masks = merge_masks(masks, weights=mask_weights)
    if binarized:
        # BUG FIX: the old code boolean-indexed the PIL Image directly
        # (merged_masks[merged_masks > 0] = 255), which raises TypeError;
        # threshold through a numpy array instead.
        arr = np.array(merged_masks)
        arr[arr > 0] = 255
        merged_masks = Image.fromarray(arr)
    final_mask = merged_masks.resize((resize_dim, resize_dim))
    if binarized:
        # Resizing reintroduces intermediate gray values; binarize again.
        final_mask_arr = binarize(np.array(final_mask)) * 255
        final_mask = Image.fromarray(final_mask_arr)
    return final_mask