Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| import supervision as sv | |
| # from ultralytics import YOLO | |
| from doclayout_yolo import YOLOv10 | |
| from src.config import YOLO_MODEL_PATH, log | |
| from src.utils import binarize | |
# Load the document-layout detector once at import time; the functions below
# use it as their default ``model``.
log.debug("loading YOLO Model")
LAYOUT_MODEL = YOLOv10(YOLO_MODEL_PATH)

# Layout class names. NOTE(review): presumably listed in the model's class-id
# order (CLASS_COLORS below is paired with it by position) — TODO confirm
# against the model's ``names`` mapping.
CLASSES = [
    "Caption",
    "Footnote",
    "Formula",
    "List-item",
    "Page-footer",
    "Page-header",
    "Picture",
    "Section-header",
    "Table",
    "Text",
    "Title",
    "Unknown",
]

# One colour per entry of CLASSES, in the same order. Used both by the
# annotators below and by get_mask, which indexes it with a detection's class id.
CLASS_COLORS = [
    sv.Color(255, 0, 0),  # Red for "Caption"
    sv.Color(0, 255, 0),  # Green for "Footnote"
    sv.Color(0, 0, 255),  # Blue for "Formula"
    sv.Color(255, 255, 0),  # Yellow for "List-item"
    sv.Color(255, 0, 255),  # Magenta for "Page-footer"
    sv.Color(0, 255, 255),  # Cyan for "Page-header"
    sv.Color(128, 0, 128),  # Purple for "Picture"
    sv.Color(128, 128, 0),  # Olive for "Section-header"
    sv.Color(128, 128, 128),  # Gray for "Table"
    sv.Color(0, 128, 128),  # Teal for "Text"
    sv.Color(128, 0, 0),  # Maroon for "Title"
    sv.Color(255, 255, 255),  # White for "Unknown"
]

# Box annotator using the per-class palette, with 2-pixel box outlines.
box_annotator = sv.BoxAnnotator(
    color=sv.ColorPalette(CLASS_COLORS),
    thickness=2,
)

# Label annotator whose label backgrounds match the box colours, with white
# text for readability.
label_annotator = sv.LabelAnnotator(
    color=sv.ColorPalette(CLASS_COLORS),
    text_color=sv.Color(255, 255, 255),
)
def detect(img, conf=0.2, iou=0.8, labels=False, plot=False, model=LAYOUT_MODEL):
    """Run layout detection on ``img`` and optionally display the result.

    Args:
        img: Image handed straight to ``model`` (e.g. a PIL image or numpy
            array). It is never drawn on.
        conf: Confidence threshold forwarded to the model.
        iou: IoU threshold forwarded to the model.
        labels: When ``plot`` is set, also draw class labels on the boxes.
        plot: Show the annotated image with ``sv.plot_image``.
        model: Detection model; defaults to the module-level ``LAYOUT_MODEL``.

    Returns:
        ``sv.Detections`` built from the first result returned by the model.
    """
    results = model(img, conf=conf, iou=iou, verbose=False)[0]
    detections = sv.Detections.from_ultralytics(results)
    if plot:
        # Annotate a copy: supervision draws into numpy scenes in place (and,
        # in some versions, PIL scenes too), which would corrupt the caller's
        # image. Annotation is skipped entirely when nothing will be shown.
        annotated_image = box_annotator.annotate(
            scene=img.copy(), detections=detections
        )
        if labels:
            annotated_image = label_annotator.annotate(
                scene=annotated_image, detections=detections
            )
        sv.plot_image(annotated_image)
    return detections
def sort_coords(xyxy):
    """Return the rows of a 2-D box array ordered by their first column.

    Rows are presumably ``x1, y1, x2, y2`` boxes, so this sorts them left to
    right by ``x1``.
    """
    left_edges = xyxy[:, 0]
    return xyxy[left_edges.argsort()]
def get_labels_with_confidence(detections):
    """Pair each detection's class name with its confidence score.

    Returns:
        List of ``(class_name, confidence)`` tuples in detection order, built
        from ``detections.data["class_name"]`` and ``detections.confidence``.
    """
    names = detections.data["class_name"].tolist()
    scores = detections.confidence.tolist()
    return [(name, score) for name, score in zip(names, scores)]
def rectangle_area(*coords):
    """Return the area of an axis-aligned box given as ``x1, y1, x2, y2``.

    Absolute differences are used, so the order of the two corners does not
    matter. Passing anything other than four values raises ``ValueError``.
    """
    x1, y1, x2, y2 = coords
    return abs(x2 - x1) * abs(y2 - y1)
def get_mask(img_h, img_w, detections, class_colors=CLASS_COLORS, inc=50):
    """Render detections as a colour-coded RGB mask.

    Every box is filled with its class colour. Boxes are painted from the
    largest area to the smallest, so smaller boxes lying on top of larger ones
    stay visible.

    Args:
        img_h: Mask height in pixels.
        img_w: Mask width in pixels.
        detections: Object with ``xyxy`` (rows of ``x1, y1, x2, y2`` pixel
            coordinates) and ``class_id`` (one integer per row), such as
            ``sv.Detections``.
        class_colors: Sequence indexed by class id; each item must provide
            ``as_rgb()``.
        inc: Unused; kept for backward compatibility.

    Returns:
        ``PIL.Image`` built from an ``(img_h, img_w, 3)`` uint8 array.
    """
    mask_arr = np.zeros((img_h, img_w, 3), dtype=np.uint8)
    xyxy = np.asarray(detections.xyxy, dtype=float).reshape(-1, 4)

    # Largest box first so smaller (typically nested) boxes overwrite it.
    widths = np.abs(xyxy[:, 2] - xyxy[:, 0])
    heights = np.abs(xyxy[:, 3] - xyxy[:, 1])
    order = np.argsort(widths * heights)[::-1]

    # Clamp at 0 so a slightly negative coordinate cannot wrap around through
    # negative slicing; overshooting the far edge is harmless for slices.
    coords = np.clip(np.round(xyxy), 0, None).astype(np.int32)
    class_ids = np.asarray(detections.class_id)

    for (x1, y1, x2, y2), class_id in zip(coords[order].tolist(), class_ids[order]):
        # Rows are indexed by y directly; both ranges are inclusive so rows and
        # columns are treated the same way.
        mask_arr[y1 : y2 + 1, x1 : x2 + 1] = class_colors[class_id].as_rgb()

    return Image.fromarray(mask_arr)
def merge_masks(masks, weights=None):
    """Blend several masks into one image by a weighted element-wise sum.

    Args:
        masks: Two or more images/arrays of the same shape (any iterable).
        weights: Optional weights, one per mask. Defaults to an equal
            ``1 / len(masks)`` share for each mask.

    Returns:
        ``PIL.Image`` of the weighted sum, clipped to [0, 255] and cast
        (truncated) to uint8.

    Raises:
        ValueError: If fewer than two masks are given, or the number of weights
            differs from the number of masks.
    """
    arrays = [np.array(mask) for mask in masks]
    if len(arrays) < 2:
        raise ValueError("more than 1 mask needed to merge")
    if weights is None:
        weights = [1 / len(arrays)] * len(arrays)
    else:
        weights = list(weights)
        if len(weights) != len(arrays):
            # zip() would otherwise silently drop the unmatched masks/weights.
            raise ValueError(f"expected {len(arrays)} weights, got {len(weights)}")
    final_mask = np.zeros_like(arrays[0], dtype=float)
    for arr, weight in zip(arrays, weights):
        final_mask += arr * weight
    # Clip before casting: out-of-range floats do not convert to uint8 predictably.
    return Image.fromarray(np.clip(final_mask, 0, 255).astype("uint8"))
def get_merged_mask(
    fp: Path = None,
    img: Image.Image = None,
    conf=0.2,
    iou=0.8,
    inc=50,
    mask_weights=None,
    resize_dim=512,
    model=LAYOUT_MODEL,
    binarized=False,
):
    """Build a class-colour layout mask for a single page image.

    Detection runs on the image and on a grayscale copy of it. One mask is
    rendered per view with ``get_mask``, the masks are blended with
    ``merge_masks``, and the result is resized to a ``resize_dim`` square.

    Args:
        fp: Path of an image file to load. Give exactly one of ``fp``/``img``.
        img: An already-loaded PIL image. Give exactly one of ``fp``/``img``.
        conf: Confidence threshold forwarded to ``detect``.
        iou: IoU threshold forwarded to ``detect``.
        inc: Forwarded to ``get_mask``.
        mask_weights: Optional weights for the two views (original, grayscale).
            Any array-like is accepted; a normalised copy is used, so the
            caller's object is not modified.
        resize_dim: Side length of the square output mask.
        model: Detection model forwarded to ``detect``.
        binarized: If true, every non-zero channel value becomes 255 before
            resizing, and the resized mask is re-binarised with ``binarize``.

    Returns:
        The merged mask as a ``PIL.Image`` of size ``(resize_dim, resize_dim)``.

    Raises:
        ValueError: If both or neither of ``fp`` and ``img`` are given.
    """
    if (fp is None) == (img is None):
        raise ValueError("only one of `fp` or `img` is required")
    if fp is not None:
        log.debug(f"getting merged mask for file {fp}")
        # copy() forces a full load so the file handle can be released here.
        with Image.open(fp) as opened:
            img = opened.copy()

    views = [
        img,
        img.convert("L"),
        # any other extra transformations
    ]

    if mask_weights is not None:
        # Normalise a float copy: the previous in-place `/=` mutated the
        # caller's array and failed for lists and integer arrays.
        mask_weights = np.asarray(mask_weights, dtype=float)
        mask_weights = mask_weights / mask_weights.sum()

    masks = []
    for view in views:
        dc = detect(view, conf=conf, iou=iou, plot=False, model=model)
        # PIL size is (width, height); get_mask expects (height, width).
        masks.append(get_mask(*view.size[::-1], dc, inc=inc))
    merged = merge_masks(masks, weights=mask_weights)

    if binarized:
        # PIL images do not support boolean-mask assignment; threshold in numpy.
        merged_arr = np.array(merged)
        merged_arr[merged_arr > 0] = 255
        merged = Image.fromarray(merged_arr)

    final_mask = merged.resize((resize_dim, resize_dim))
    if binarized:
        # Resizing interpolates, so values are binarised again afterwards.
        # NOTE(review): assumes `binarize` returns 0/1 values — TODO confirm.
        # The uint8 cast keeps the array in a dtype PIL can build an image from.
        final_mask_arr = (binarize(np.array(final_mask)) * 255).astype(np.uint8)
        final_mask = Image.fromarray(final_mask_arr)
    return final_mask