File size: 8,845 Bytes
5facae9
6c80703
da3baa6
9e99f59
 
da3baa6
 
9e99f59
 
 
 
 
da3baa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e99f59
 
 
 
da3baa6
9e99f59
da3baa6
 
 
 
 
9e99f59
da3baa6
9e99f59
da3baa6
 
 
 
 
 
 
 
 
 
 
9e99f59
 
 
da3baa6
 
 
 
 
 
9e99f59
 
 
da3baa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e99f59
da3baa6
 
 
 
 
9e99f59
da3baa6
 
9e99f59
 
 
 
 
 
 
 
 
da3baa6
9e99f59
 
 
 
 
 
 
da3baa6
 
 
 
 
9e99f59
 
 
 
 
 
da3baa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e99f59
 
 
da3baa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e99f59
 
 
 
 
 
da3baa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e99f59
da3baa6
9e99f59
da3baa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e99f59
 
 
 
 
 
 
da3baa6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import numpy.typing as npt
from typing import Dict, List, Tuple, Literal


def get_overlaps(
    boxes: npt.NDArray[np.float64],
    other_boxes: npt.NDArray[np.float64],
    normalize: Literal["box_only", "all"] = "box_only",
) -> npt.NDArray[np.float64]:
    """
    Computes the normalized intersection area between boxes and other boxes.
    Boxes are expected in format (x0, y0, x1, y1).

    Args:
        boxes (np array [4] or [n x 4]): Boxes. A single box is treated as n = 1.
        other_boxes (np array [m x 4]): Other boxes.
        normalize (str, optional): Area used to normalize the intersection.
            - "box_only": area of the box, i.e. only consider box included in other box.
            - "all": smallest of the two areas, i.e. consider box included in
              other box and other box included in box.
            Defaults to "box_only".

    Returns:
        np array [n x m]: Overlaps. The division is not guarded, so zero-area
            boxes follow NumPy division semantics (NaN / inf).

    Raises:
        ValueError: If normalize is neither "box_only" nor "all".
    """
    if boxes.ndim == 1:
        boxes = boxes[None, :]

    # Coordinates of `boxes` are column vectors [n x 1] ...
    x0, y0, x1, y1 = (
        boxes[:, 0][:, None],
        boxes[:, 1][:, None],
        boxes[:, 2][:, None],
        boxes[:, 3][:, None],
    )
    areas = (y1 - y0) * (x1 - x0)

    # ... and coordinates of `other_boxes` are row vectors [1 x m], so that
    # every binary operation below broadcasts to [n x m].
    x0_other, y0_other, x1_other, y1_other = (
        other_boxes[:, 0][None, :],
        other_boxes[:, 1][None, :],
        other_boxes[:, 2][None, :],
        other_boxes[:, 3][None, :],
    )
    areas_other = (y1_other - y0_other) * (x1_other - x0_other)

    # Intersection (clamped to 0 when the boxes do not intersect)
    inter_y0 = np.maximum(y0, y0_other)
    inter_y1 = np.minimum(y1, y1_other)
    inter_x0 = np.maximum(x0, x0_other)
    inter_x1 = np.minimum(x1, x1_other)
    inter_area = np.maximum(0, inter_y1 - inter_y0) * np.maximum(0, inter_x1 - inter_x0)

    # Overlap
    if normalize == "box_only":  # Only consider box included in other box
        overlaps = inter_area / areas
    elif normalize == "all":
        # Consider box included in other box and other box included in box.
        # `areas` is [n x 1] and `areas_other` is already [1 x m]: adding an
        # extra axis here would turn the result into [1 x n x m].
        overlaps = inter_area / np.minimum(areas, areas_other)
    else:
        raise ValueError(f"Invalid normalization: {normalize}")
    return overlaps


def get_distances(
    title_boxes: npt.NDArray[np.float64], other_boxes: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    """
    Computes the distances between title and table/chart boxes.
    Distance is computed as the sum of the vertical and horizontal distances.
    Horizontal distance uses min(boxes center dist, boxes left dist).
    Vertical distance uses min(top_title to bottom_other dists, bottom_title to top_other dists).

    Args:
        title_boxes (np array [n_titles x 4]): Title boxes.
        other_boxes (np array [n_other x 4]): Other boxes.

    Returns:
        np array [n_titles x n_other]: Distances between titles and other boxes.
    """
    x0_title, xc_title, y0_title, y1_title = (
        title_boxes[:, 0],
        (title_boxes[:, 0] + title_boxes[:, 2]) / 2,
        title_boxes[:, 1],
        title_boxes[:, 3],
    )
    x0_other, xc_other, y0_other, y1_other = (
        other_boxes[:, 0],
        (other_boxes[:, 0] + other_boxes[:, 2]) / 2,
        other_boxes[:, 1],
        other_boxes[:, 3],
    )

    x_dists = np.min(
        [
            np.abs(
                xc_title[:, None] - xc_other[None, :]
            ),  # Title center to other center
            np.abs(x0_title[:, None] - x0_other[None, :]),  # Title left to other left
        ],
        axis=0,
    )

    y_dists = np.min(
        [
            np.abs(y1_title[:, None] - y0_other[None, :]),  # Title above other
            np.abs(y0_title[:, None] - y1_other[None, :]),  # Title below other
        ],
        axis=0,
    )

    dists = y_dists + x_dists / 2
    return dists


def find_titles(
    title_boxes: npt.NDArray[np.float64],
    table_boxes: npt.NDArray[np.float64],
    chart_boxes: npt.NDArray[np.float64],
    max_dist: float = 0.1,
) -> Dict[int, Tuple[str, int]]:
    """
    Associates titles to tables and charts.

    Each table, then each chart, greedily takes its closest title if that
    title is closer than max_dist. A title is assigned at most once, and
    tables are processed before charts, so they get first pick.

    Args:
        title_boxes (np array [n_titles x 4]): Title boxes.
        table_boxes (np array [n_tables x 4]): Table boxes.
        chart_boxes (np array [n_charts x 4]): Chart boxes.
        max_dist (float, optional): Maximum distance between title and table/chart. Defaults to 0.1.
            NOTE(review): distances are in box-coordinate units; the default
            presumably assumes normalized coordinates - confirm with callers.

    Returns:
        dict: Dictionary of assigned titles.
            - Keys are the indices of the titles (plain Python ints),
            - Values are tuples of:
                - str: Whether the title is assigned to a "chart" or "table"
                - int: index of the assigned table/chart
    """
    if not len(title_boxes) or not (len(table_boxes) or len(chart_boxes)):
        return {}

    n_titles = len(title_boxes)

    # Title-to-chart distances. A title overlapping a chart by more than 25%
    # of its own area is considered to be at distance 0 from that chart.
    chart_distances = np.ones((n_titles, 0))
    chart_overlaps = None
    if len(chart_boxes):
        chart_distances = get_distances(title_boxes, chart_boxes)
        chart_overlaps = get_overlaps(title_boxes, chart_boxes, normalize="box_only")
        chart_distances = np.where(chart_overlaps > 0.25, 0, chart_distances)

    # Title-to-table distances
    table_distances = np.ones((n_titles, 0))
    if len(table_boxes):
        table_distances = get_distances(title_boxes, table_boxes)
        if chart_overlaps is not None:  # Penalize table titles that are inside charts
            table_distances = np.where(
                chart_overlaps.max(1, keepdims=True) > 0.25,
                table_distances * 10,
                table_distances,
            )

    assigned_titles: Dict[int, Tuple[str, int]] = {}

    # Assign to tables
    for i in range(len(table_boxes)):
        # np.argmin returns a NumPy integer; cast so keys are real ints
        best_match = int(np.argmin(table_distances[:, i]))
        if table_distances[best_match, i] < max_dist:
            assigned_titles[best_match] = ("table", i)
            # Make the title unavailable to any other table or chart
            table_distances[best_match] = np.inf
            chart_distances[best_match] = np.inf

    # Assign to charts
    for i in range(len(chart_boxes)):
        best_match = int(np.argmin(chart_distances[:, i]))
        if chart_distances[best_match, i] < max_dist:
            assigned_titles[best_match] = ("chart", i)
            chart_distances[best_match] = np.inf

    return assigned_titles


def postprocess_included(
    boxes: npt.NDArray[np.float64],
    labels: npt.NDArray[np.int_],
    confs: npt.NDArray[np.float64],
    class_: str = "title",
    classes: List[str] = ["table", "chart", "title", "infographic"],
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
    """
    Post process predictions of a given class (titles by default).
    - Remove boxes of that class that are included in other boxes, i.e. more
      than 90% of their area is covered by a table, infographic or chart box
      (and by a text box for the "title" and "header_footer" classes).

    Inclusion classes that are not listed in `classes` are ignored, and the
    processed class is never compared against itself.

    Args:
        boxes (numpy.ndarray [N, 4]): Array of bounding boxes.
        labels (numpy.ndarray [N]): Array of labels.
        confs (numpy.ndarray [N]): Array of confidences.
        class_ (str, optional): Class to postprocess. Defaults to "title".
        classes (list, optional): Classes. Defaults to ["table", "chart", "title", "infographic"].
            The default list is only read, never mutated.

    Returns:
        boxes (numpy.ndarray): Array of bounding boxes.
        labels (numpy.ndarray): Array of labels, with the dtype of the input labels.
        confs (numpy.ndarray): Array of confidences.
        Boxes of other classes come first, in their original order, followed by
        the kept boxes of `class_` sorted by increasing confidence. If there is
        no box of `class_`, the inputs are returned unchanged.

    Raises:
        ValueError: If class_ is not in classes.
    """
    class_idx = classes.index(class_)
    is_target = labels == class_idx

    boxes_to_pp = boxes[is_target]
    confs_to_pp = confs[is_target]

    if len(boxes_to_pp) == 0:
        return boxes, labels, confs

    order = np.argsort(confs_to_pp)  # least to most confident
    boxes_to_pp, confs_to_pp = boxes_to_pp[order], confs_to_pp[order]

    inclusion_classes = ["table", "infographic", "chart"]
    if class_ in ["header_footer", "title"]:
        inclusion_classes.append("text")

    # Only keep inclusion classes the model actually predicts (the default
    # `classes` has no "text"), and skip the processed class itself: every box
    # fully overlaps itself and would otherwise always be removed.
    inclusion_ids = [
        classes.index(c) for c in inclusion_classes if c in classes and c != class_
    ]
    other_boxes = boxes[np.isin(labels, inclusion_ids)]

    # Remove boxes included in other_boxes
    if len(other_boxes) > 0:
        overlaps = get_overlaps(boxes_to_pp, other_boxes, normalize="box_only")
        # `~(x > 0.9)` rather than `x <= 0.9`: a NaN overlap never exceeds the
        # threshold, so the corresponding box is kept.
        keep = ~(overlaps.max(axis=1) > 0.9)
    else:
        keep = np.ones(len(boxes_to_pp), dtype=bool)

    kept_boxes, kept_confs = boxes_to_pp[keep], confs_to_pp[keep]

    # Aggregate
    boxes_pp = np.concatenate([boxes[~is_target], kept_boxes])
    confs_pp = np.concatenate([confs[~is_target], kept_confs])
    labels_pp = np.concatenate(
        [
            labels[~is_target],
            # Same dtype as the input labels, so integer labels stay integers
            np.full(len(kept_boxes), class_idx, dtype=labels.dtype),
        ]
    )

    return boxes_pp, labels_pp, confs_pp