Spaces:
Running
on
Zero
Running
on
Zero
| # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import numpy as np | |
| import numpy.typing as npt | |
| from typing import Dict, List, Tuple, Literal | |
def get_overlaps(
    boxes: npt.NDArray[np.float64],
    other_boxes: npt.NDArray[np.float64],
    normalize: Literal["box_only", "all"] = "box_only",
) -> npt.NDArray[np.float64]:
    """
    Computes how much each box overlaps with each of the other boxes.
    Boxes are expected in format (x0, y0, x1, y1).

    The overlap is the intersection area divided by a normalizing area:
    - "box_only": area of the box itself, i.e. the fraction of the box
      that is covered by the other box.
    - "all": the smaller area of the (box, other box) pair, so that either
      box being included in the other gives an overlap of 1.

    Pairs whose normalizing area is zero (degenerate boxes) get an overlap
    of 0 instead of NaN.

    Args:
        boxes (np array [4] or [n x 4]): Boxes.
        other_boxes (np array [m x 4]): Other boxes.
        normalize (str, optional): "box_only" or "all". Defaults to "box_only".

    Returns:
        np array [n x m]: Overlaps.

    Raises:
        ValueError: If `normalize` is neither "box_only" nor "all".
    """
    if normalize not in ("box_only", "all"):
        raise ValueError(f"Invalid normalization: {normalize}")

    if boxes.ndim == 1:  # Single box -> [1 x 4]
        boxes = boxes[None, :]

    # Boxes as columns [n x 1], other boxes as rows [1 x m], so that every
    # binary operation below broadcasts to the [n x m] pair grid.
    x0, y0, x1, y1 = (boxes[:, k, None] for k in range(4))
    x0_other, y0_other, x1_other, y1_other = (
        other_boxes[None, :, k] for k in range(4)
    )

    areas = (y1 - y0) * (x1 - x0)  # [n x 1]
    areas_other = (y1_other - y0_other) * (x1_other - x0_other)  # [1 x m]

    # Intersection, clipped at 0 for disjoint boxes
    inter_h = np.maximum(0, np.minimum(y1, y1_other) - np.maximum(y0, y0_other))
    inter_w = np.maximum(0, np.minimum(x1, x1_other) - np.maximum(x0, x0_other))
    inter_area = inter_h * inter_w  # [n x m]

    if normalize == "box_only":
        denom = areas
    else:
        # Both operands are already broadcastable to [n x m]; adding an extra
        # axis here would produce a [1 x n x m] result.
        denom = np.minimum(areas, areas_other)

    # Divide only where the denominator is non-zero, leave 0 elsewhere.
    overlaps = np.zeros(inter_area.shape, dtype=np.float64)
    np.divide(inter_area, denom, out=overlaps, where=denom != 0)
    return overlaps
def get_distances(
    title_boxes: npt.NDArray[np.float64], other_boxes: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    """
    Scores how far every title box is from every table/chart box.
    Boxes are read as (x0, y0, x1, y1).

    The score is `vertical gap + horizontal gap / 2`, where:
    - the horizontal gap is the smaller of the center-to-center and the
      left-edge-to-left-edge distances,
    - the vertical gap is the smaller of |title y1 - other y0| and
      |title y0 - other y1|.

    Args:
        title_boxes (np array [n_titles x 4]): Title boxes.
        other_boxes (np array [n_other x 4]): Other boxes.

    Returns:
        np array [n_titles x n_other]: Distances between titles and other boxes.
    """
    # Titles are laid out as columns [n_titles x 1] and other boxes as rows
    # [1 x n_other], so each subtraction yields the full pairwise grid.
    t_left = title_boxes[:, 0, None]
    t_center = (title_boxes[:, 0, None] + title_boxes[:, 2, None]) / 2
    t_y0 = title_boxes[:, 1, None]
    t_y1 = title_boxes[:, 3, None]

    o_left = other_boxes[None, :, 0]
    o_center = (other_boxes[None, :, 0] + other_boxes[None, :, 2]) / 2
    o_y0 = other_boxes[None, :, 1]
    o_y1 = other_boxes[None, :, 3]

    center_gap = np.abs(t_center - o_center)  # Title center to other center
    left_gap = np.abs(t_left - o_left)  # Title left to other left
    horizontal = np.minimum(center_gap, left_gap)

    gap_when_above = np.abs(t_y1 - o_y0)  # Title above other
    gap_when_below = np.abs(t_y0 - o_y1)  # Title below other
    vertical = np.minimum(gap_when_above, gap_when_below)

    # Horizontal misalignment counts half as much as vertical separation.
    return vertical + horizontal / 2
def find_titles(
    title_boxes: npt.NDArray[np.float64],
    table_boxes: npt.NDArray[np.float64],
    chart_boxes: npt.NDArray[np.float64],
    max_dist: float = 0.1,
) -> Dict[int, Tuple[str, int]]:
    """
    Associates titles to tables and charts.

    Assignment is greedy: tables are processed first, in index order, then
    charts. Each table/chart takes its closest title if that distance is
    below `max_dist`, and a title that has been taken cannot be assigned
    again.

    Args:
        title_boxes (np array [n_titles x 4]): Title boxes.
        table_boxes (np array [n_tables x 4]): Table boxes.
        chart_boxes (np array [n_charts x 4]): Chart boxes.
        max_dist (float, optional): Maximum distance between title and table/chart. Defaults to 0.1.

    Returns:
        dict: Dictionary of assigned titles.
            - Keys are the indices of the titles (plain Python ints),
            - Values are tuples of:
                - str: Whether the title is assigned to a "chart" or "table"
                - int: index of the assigned table/chart
    """
    n_titles = len(title_boxes)
    n_tables = len(table_boxes)
    n_charts = len(chart_boxes)
    if not n_titles or not (n_tables or n_charts):
        return {}

    # [n_titles x 0] placeholders keep the row-wise writes below valid when
    # there are no charts / no tables.
    chart_distances = np.ones((n_titles, 0))
    if n_charts:
        chart_distances = get_distances(title_boxes, chart_boxes)
        chart_overlaps = get_overlaps(title_boxes, chart_boxes, normalize="box_only")
        # A title with more than 25% of its area inside a chart is at distance 0 from it
        chart_distances = np.where(chart_overlaps > 0.25, 0, chart_distances)

    table_distances = np.ones((n_titles, 0))
    if n_tables:
        table_distances = get_distances(title_boxes, table_boxes)
        if n_charts:  # Penalize table titles that are inside charts
            table_distances = np.where(
                chart_overlaps.max(1, keepdims=True) > 0.25,
                table_distances * 10,
                table_distances,
            )

    assigned_titles: Dict[int, Tuple[str, int]] = {}

    # Assign to tables
    for i in range(n_tables):
        # np.argmin returns a NumPy integer; cast so keys are real ints
        best_match = int(np.argmin(table_distances[:, i]))
        if table_distances[best_match, i] < max_dist:
            assigned_titles[best_match] = ("table", i)
            # Take the title out of the pool for both tables and charts
            table_distances[best_match] = np.inf
            chart_distances[best_match] = np.inf

    # Assign to charts
    for i in range(n_charts):
        best_match = int(np.argmin(chart_distances[:, i]))
        if chart_distances[best_match, i] < max_dist:
            assigned_titles[best_match] = ("chart", i)
            chart_distances[best_match] = np.inf

    return assigned_titles
def postprocess_included(
    boxes: npt.NDArray[np.float64],
    labels: npt.NDArray[np.int_],
    confs: npt.NDArray[np.float64],
    class_: str = "title",
    classes: List[str] = ["table", "chart", "title", "infographic"],
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
    """
    Post process predictions of one class.
    - Remove boxes of class `class_` that are included (more than 90% of
      their area) in a table, infographic or chart box - or in a text box
      when `class_` is "title" or "header_footer".

    Inclusion classes that are missing from `classes`, and `class_` itself,
    are skipped. If there is no box of class `class_`, the inputs are
    returned unchanged. Otherwise the outputs contain the boxes of the other
    classes first, followed by the kept `class_` boxes sorted by increasing
    confidence.

    Args:
        boxes (numpy.ndarray [N, 4]): Array of bounding boxes.
        labels (numpy.ndarray [N]): Array of labels.
        confs (numpy.ndarray [N]): Array of confidences.
        class_ (str, optional): Class to postprocess. Defaults to "title".
        classes (list, optional): Classes. Defaults to ["table", "chart", "title", "infographic"].

    Returns:
        boxes (numpy.ndarray): Array of bounding boxes.
        labels (numpy.ndarray): Array of labels, same dtype as the input labels.
        confs (numpy.ndarray): Array of confidences.

    Raises:
        ValueError: If `class_` is not in `classes`.
    """
    # NOTE: `classes` has a list default; it is only read here, never mutated.
    class_idx = classes.index(class_)
    is_target = labels == class_idx

    boxes_to_pp, confs_to_pp = boxes[is_target], confs[is_target]
    if len(boxes_to_pp) == 0:
        return boxes, labels, confs

    order = np.argsort(confs_to_pp)  # least to most confident
    boxes_to_pp, confs_to_pp = boxes_to_pp[order], confs_to_pp[order]

    inclusion_classes = ["table", "infographic", "chart"]
    if class_ in ("header_footer", "title"):
        inclusion_classes.append("text")
    # Skip classes the caller does not have (e.g. "text" with the default
    # `classes`) and the processed class itself: every box fully overlaps
    # itself, which would otherwise remove all boxes of that class.
    inclusion_ids = [
        classes.index(c) for c in inclusion_classes if c in classes and c != class_
    ]
    other_boxes = boxes[np.isin(labels, inclusion_ids)]

    # Remove boxes included in other_boxes
    keep = np.ones(len(boxes_to_pp), dtype=bool)
    if len(other_boxes) > 0:
        overlaps = get_overlaps(boxes_to_pp, other_boxes, normalize="box_only")
        # Written as a negation so that NaN overlaps are kept, not dropped
        keep = ~(overlaps.max(axis=1) > 0.9)
    kept_boxes, kept_confs = boxes_to_pp[keep], confs_to_pp[keep]

    # Aggregate
    boxes_pp = np.concatenate([boxes[~is_target], kept_boxes])
    confs_pp = np.concatenate([confs[~is_target], kept_confs])
    labels_pp = np.concatenate(
        [
            labels[~is_target],
            # Same dtype as the input labels, so they are not promoted to float
            np.full(len(kept_boxes), class_idx, dtype=labels.dtype),
        ]
    )
    return boxes_pp, labels_pp, confs_pp