File size: 4,958 Bytes
f7fbe80
 
df3522f
f7fbe80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df3522f
 
f7fbe80
 
 
 
 
 
 
 
df3522f
 
 
 
 
 
f7fbe80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df3522f
f7fbe80
df3522f
 
f7fbe80
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import numpy as np
from PIL import Image
from pathlib import Path
import supervision as sv

# from ultralytics import YOLO
from doclayout_yolo import YOLOv10
from src.config import YOLO_MODEL_PATH, log
from src.utils import binarize

# Load the layout-detection model once, at import time; detect() and
# get_merged_mask() use it as their default `model` argument.
log.debug("loading YOLO Model")
LAYOUT_MODEL = YOLOv10(YOLO_MODEL_PATH)

# Layout category names. Index i is presumably the name of class_id i in the
# detector's output — NOTE(review): confirm against the model's own class names.
CLASSES = [
    "Caption", "Footnote", "Formula", "List-item",
    "Page-footer", "Page-header", "Picture", "Section-header",
    "Table", "Text", "Title", "Unknown",
]

# Per-class colour palette, one entry per name in CLASSES and in the same order.
# get_mask() indexes it directly with a detection's class_id, and the
# annotators wrap it in sv.ColorPalette, so entry i must be the colour for
# class_id i. NOTE(review): this assumes the model's class ids follow the
# CLASSES order — confirm against the model's names.
CLASS_COLORS = [
    sv.Color(255, 0, 0),  # Red for "Caption"
    sv.Color(0, 255, 0),  # Green for "Footnote"
    sv.Color(0, 0, 255),  # Blue for "Formula"
    sv.Color(255, 255, 0),  # Yellow for "List-item"
    sv.Color(255, 0, 255),  # Magenta for "Page-footer"
    sv.Color(0, 255, 255),  # Cyan for "Page-header"
    sv.Color(128, 0, 128),  # Purple for "Picture"
    sv.Color(128, 128, 0),  # Olive for "Section-header"
    sv.Color(128, 128, 128),  # Gray for "Table"
    sv.Color(0, 128, 128),  # Teal for "Text"
    sv.Color(128, 0, 0),  # Maroon for "Title"
    sv.Color(255, 255, 255),  # White for "Unknown"
]

# Box drawer: 2 px outlines, coloured per class from CLASS_COLORS.
box_annotator = sv.BoxAnnotator(
    thickness=2,
    color=sv.ColorPalette(CLASS_COLORS),
)

# Label drawer: the label background reuses the same per-class palette as the
# boxes, and the label text is always drawn in white.
label_annotator = sv.LabelAnnotator(
    text_color=sv.Color(255, 255, 255),
    color=sv.ColorPalette(CLASS_COLORS),
)


def detect(img, conf=0.2, iou=0.8, labels=False, plot=False, model=LAYOUT_MODEL):
    """Run layout detection on one image and optionally display the result.

    Args:
        img: Image handed to ``model``. When ``plot`` is true it is also the
            drawing canvas; a copy is annotated, so the caller's image is not
            modified.
        conf: Confidence threshold forwarded to ``model``.
        iou: IoU threshold forwarded to ``model``.
        labels: When plotting, also draw class labels on top of the boxes.
            Has no effect when ``plot`` is false.
        plot: If true, show the annotated image with ``sv.plot_image``.
        model: Detector called as ``model(img, conf=..., iou=..., verbose=False)``;
            its first result is converted with ``sv.Detections.from_ultralytics``.

    Returns:
        The ``sv.Detections`` built from the first result returned by ``model``.
    """
    results = model(img, conf=conf, iou=iou, verbose=False)[0]
    detections = sv.Detections.from_ultralytics(results)

    # Drawing is only useful for display, so skip it otherwise. Supervision
    # annotators can draw onto the scene they are given (in place for numpy
    # arrays), hence annotate a copy rather than the caller's image.
    if plot:
        annotated_image = box_annotator.annotate(
            scene=img.copy(), detections=detections
        )
        if labels:
            annotated_image = label_annotator.annotate(
                scene=annotated_image, detections=detections
            )
        sv.plot_image(annotated_image)

    return detections


def sort_coords(xyxy):
    """Return the rows of a 2-D box array ordered by their first column.

    For xyxy boxes (presumably the intended input) that is the left edge, x1.
    """
    left_edges = xyxy[:, 0]
    return xyxy[left_edges.argsort()]


def get_labels_with_confidence(detections):
    """Pair every detection's class name with its confidence score.

    Returns:
        List of ``(class_name, confidence)`` tuples of plain Python values, in
        detection order.
    """
    names = detections.data["class_name"].tolist()
    scores = detections.confidence.tolist()
    return [(name, score) for name, score in zip(names, scores)]


def rectangle_area(*coords):
    """Return the area of an axis-aligned box given as ``x1, y1, x2, y2``.

    Corner order does not matter: each side length is an absolute difference.
    """
    (left, top), (right, bottom) = coords[:2], coords[2:]
    return abs(right - left) * abs(bottom - top)


def get_mask(img_h, img_w, detections, class_colors=CLASS_COLORS, inc=50):
    """Rasterise detection boxes into a solid-colour RGB mask.

    Each box is filled with ``class_colors[class_id].as_rgb()``. Boxes are
    painted from largest to smallest area, so smaller (e.g. nested) boxes end
    up on top of the larger ones that contain them.

    Args:
        img_h: Mask height in pixels.
        img_w: Mask width in pixels.
        detections: Object exposing ``xyxy`` box coordinates in pixels and the
            matching ``class_id`` values. Coordinates are assumed to use the
            usual image convention (origin top-left, y growing downwards).
        class_colors: Sequence indexed by class id; each item must provide
            ``as_rgb()`` returning an ``(r, g, b)`` triple.
        inc: Unused; kept so existing callers that pass it keep working.

    Returns:
        PIL image of size ``(img_w, img_h)``. Pixels outside every box are black.
    """
    mask_arr = np.zeros((img_h, img_w, 3), dtype=np.uint8)

    xyxy = np.asarray(detections.xyxy, dtype=float).reshape(-1, 4)
    areas = np.abs(xyxy[:, 2] - xyxy[:, 0]) * np.abs(xyxy[:, 3] - xyxy[:, 1])
    order = np.argsort(areas)[::-1]  # largest first, so small boxes paint last

    coords = np.round(xyxy).astype(np.int32)
    # Clip to the canvas: a rounded -1 would otherwise become a negative slice
    # start, which numpy interprets as counting from the end of the axis.
    coords[:, [0, 2]] = np.clip(coords[:, [0, 2]], 0, img_w)
    coords[:, [1, 3]] = np.clip(coords[:, [1, 3]], 0, img_h)
    class_ids = np.asarray(detections.class_id)

    for (x1, y1, x2, y2), class_id in zip(
        coords[order].tolist(), class_ids[order].tolist()
    ):
        # Both axes use inclusive edges [x1, x2] x [y1, y2].
        mask_arr[y1 : y2 + 1, x1 : x2 + 1, :] = class_colors[class_id].as_rgb()

    return Image.fromarray(mask_arr)


def merge_masks(masks, weights=None):
    """Blend same-sized masks into one image by a per-pixel weighted sum.

    Args:
        masks: Sequence of at least two images (anything ``np.array`` accepts)
            with identical shapes.
        weights: Optional sequence with exactly one weight per mask. Defaults
            to equal weights ``1 / len(masks)``.

    Returns:
        PIL image of the weighted sum, rounded to the nearest integer and
        clipped to 0..255 (uint8).

    Raises:
        ValueError: if fewer than two masks are given, or if ``weights`` does
            not have one entry per mask.
    """
    if len(masks) < 2:
        raise ValueError("more than 1 mask needed to merge")
    arrays = [np.array(mask) for mask in masks]
    if weights is None:
        weights = [1 / len(arrays)] * len(arrays)
    else:
        weights = list(weights)
        if len(weights) != len(arrays):
            # zip() would otherwise silently drop the unmatched masks/weights.
            raise ValueError(f"expected {len(arrays)} weights, got {len(weights)}")

    final_mask = np.zeros_like(arrays[0], dtype=float)
    for arr, weight in zip(arrays, weights):
        final_mask += arr * weight

    # Round rather than truncate, so float noise such as 254.9999 stays 255.
    # Clip so that weights summing above 1 saturate instead of wrapping around
    # in the uint8 cast.
    return Image.fromarray(np.clip(np.rint(final_mask), 0, 255).astype(np.uint8))


def get_merged_mask(
    fp: Path = None,
    img: Image.Image = None,
    conf=0.2,
    iou=0.8,
    inc=50,
    mask_weights=None,
    resize_dim=512,
    model=LAYOUT_MODEL,
    binarized=False,
):
    """Detect layout regions on a page and return a square merged mask.

    Detection runs on the page as given and on a grayscale ("L") copy. Each
    run is rasterised with get_mask(), and the masks are blended with
    merge_masks().

    Args:
        fp: Path of an image file to load. Exactly one of ``fp`` / ``img``
            must be given.
        img: Already-loaded PIL image.
        conf: Confidence threshold forwarded to detect().
        iou: IoU threshold forwarded to detect().
        inc: Forwarded to get_mask().
        mask_weights: Optional weights, one per view (original, then
            grayscale). They are normalised to sum to 1 on a float copy; the
            caller's object is not modified.
        resize_dim: Side length of the returned square mask.
        model: Detector forwarded to detect().
        binarized: If true, every non-zero channel value of the merged mask is
            raised to 255 before resizing. The resized mask is then passed
            through ``binarize`` and scaled by 255.

    Returns:
        PIL image of size ``(resize_dim, resize_dim)``.

    Raises:
        ValueError: if both or neither of ``fp`` and ``img`` are given.
    """
    if (fp is None) == (img is None):
        raise ValueError("only one of `fp` or `img` is required")
    if fp is not None:
        log.debug(f"getting merged mask for file {fp}")
        # Load the pixels fully and close the file, instead of keeping a
        # lazily-read image whose file handle is never released.
        with Image.open(fp) as opened:
            img = opened.copy()

    views = [
        img,
        img.convert("L"),
        # any other extra transformations
    ]

    weights = None
    if mask_weights is not None:
        # Normalise a float copy: this works for lists and integer arrays and
        # leaves the caller's mask_weights untouched.
        weights = np.asarray(mask_weights, dtype=float)
        weights = weights / weights.sum()

    masks = []
    for view in views:
        detections = detect(view, conf=conf, iou=iou, plot=False, model=model)
        width, height = view.size
        masks.append(get_mask(height, width, detections, inc=inc))

    merged_mask = merge_masks(masks, weights=weights)
    if binarized:
        # PIL images do not support `>` comparisons; threshold in numpy.
        merged_arr = np.array(merged_mask)
        merged_arr[merged_arr > 0] = 255
        merged_mask = Image.fromarray(merged_arr)

    final_mask = merged_mask.resize((resize_dim, resize_dim))
    if binarized:
        # NOTE(review): assumes binarize() returns a 0/1 (or boolean) array —
        # confirm in src.utils. The uint8 cast lets PIL build an image from it.
        final_mask_arr = (binarize(np.array(final_mask)) * 255).astype(np.uint8)
        final_mask = Image.fromarray(final_mask_arr)

    return final_mask