Spaces:
Running
Running
| from collections import defaultdict | |
| from concurrent.futures import ProcessPoolExecutor | |
| from typing import List, Optional | |
| from PIL import Image | |
| import numpy as np | |
| from surya.detection import batch_detection | |
| from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes | |
| from surya.schema import LayoutResult, LayoutBox, TextDetectionResult | |
| from surya.settings import settings | |
def _zero_out_vertical_lines(logits: np.ndarray, vertical_line_bboxes, vertical_line_width: int) -> None:
    """Zero every class channel of ``logits`` (in place) under each vertical column line.

    Each line's right edge is widened by ``vertical_line_width`` pixels, clamped to the
    heatmap's x extent.
    """
    width = logits.shape[2]  # logits is indexed [class, y, x]
    for line in vertical_line_bboxes:
        x0, y0, x1, y1 = line.bbox[0], line.bbox[1], line.bbox[2], line.bbox[3]
        # Clamp against the x extent (axis 2), not the y extent.
        x1 = min(width, x1 + vertical_line_width)
        logits[:, y0:y1, x0:x1] = 0


def _detect_class_boxes(logits: np.ndarray, id2label) -> List[LayoutBox]:
    """Return a LayoutBox for each region detected in every non-blank class channel."""
    height, width = logits.shape[1], logits.shape[2]
    boxes = []
    for class_idx in range(1, len(id2label)):  # Skip the blank class
        heatmap = logits[class_idx]
        if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
            continue
        for region in get_detected_boxes(heatmap):
            if region.area <= 25:  # discard very small regions
                continue
            region.fit_to_bounds([0, 0, width - 1, height - 1])
            boxes.append(LayoutBox(polygon=region.polygon, label=id2label[class_idx], confidence=1))
    # All confidences are currently 1, so this stable sort preserves detection order.
    return sorted(boxes, key=lambda b: b.confidence, reverse=True)


def _assign_lines_to_boxes(boxes, line_bboxes, thresholds=(.5, .4)):
    """Assign each text line to at most one layout box.

    Rounds run with decreasing ``intersection_pct`` thresholds: the first (0.5) is a
    majority intersection, later rounds relax it for lines still unassigned.

    Returns:
        ``(box_lines, used_lines)``: a plain dict mapping box index to the list of
        covered line ``bbox`` values (only boxes with at least one line appear), and
        the set of assigned line indices.
    """
    box_lines = defaultdict(list)
    used_lines = set()
    for thresh in thresholds:
        for box_idx, box in enumerate(boxes):
            for line_idx, line in enumerate(line_bboxes):
                if line_idx in used_lines:
                    continue
                if line.intersection_pct(box) > thresh:
                    box_lines[box_idx].append(line.bbox)
                    used_lines.add(line_idx)
    # A plain dict, so membership tests cannot be skewed by defaultdict auto-insertion.
    return dict(box_lines), used_lines


def _snap_boxes_to_lines(boxes, box_lines) -> List[LayoutBox]:
    """Snap boxes to the text lines they cover and drop boxes that should not be kept.

    - Pictures with area < 200 are dropped.
    - Boxes covering no lines are dropped, except Pictures and Formulas.
    - Non-Picture boxes covering lines become the lines' bounding rectangle; Tables
      and Formulas also keep their own extent, since text is only part of their area.
    - Pictures that cover text lines are relabelled "Figure".
    """
    kept = []
    for box_idx, box in enumerate(boxes):
        if box.label == "Picture" and box.area < 200:  # Remove very small figures
            continue
        covered_lines = box_lines.get(box_idx, [])
        if not covered_lines and box.label not in ["Picture", "Formula"]:
            continue

        if covered_lines and box.label not in ["Picture"]:
            min_x = min(line[0] for line in covered_lines)
            min_y = min(line[1] for line in covered_lines)
            max_x = max(line[2] for line in covered_lines)
            max_y = max(line[3] for line in covered_lines)
            if box.label in ["Table", "Formula"]:
                min_x = min(min_x, min(p[0] for p in box.polygon))
                min_y = min(min_y, min(p[1] for p in box.polygon))
                max_x = max(max_x, max(p[0] for p in box.polygon))
                max_y = max(max_y, max(p[1] for p in box.polygon))
            box.polygon = [
                [min_x, min_y],
                [max_x, min_y],
                [max_x, max_y],
                [min_x, max_y],
            ]

        # Only Pictures that actually contain text lines become Figures.
        if covered_lines and box.label in ["Picture"]:
            box.label = "Figure"
        kept.append(box)
    return kept


def _merge_same_type_boxes(boxes, mergeable_types=("Table", "Picture", "Figure")) -> List[LayoutBox]:
    """Merge overlapping boxes of the same mergeable type.

    Sometimes one table column is detected as a separate table. For each type,
    a box absorbs every other box of that type whose
    ``intersection_pct(..., x_margin=.25)`` with it exceeds 0.1.
    """
    for ftype in mergeable_types:
        absorbed = set()
        for i, box in enumerate(boxes):
            if box.label != ftype or i in absorbed:
                continue
            for j, other in enumerate(boxes):
                if other.label != ftype or j in absorbed or i == j:
                    continue
                if box.intersection_pct(other, x_margin=.25) > .1:
                    box.merge(other)
                    absorbed.add(j)
        boxes = [b for idx, b in enumerate(boxes) if idx not in absorbed]
    return boxes


def _remove_contained_boxes(boxes, containment=.95, keep_labels=("Caption",)) -> List[LayoutBox]:
    """Drop boxes lying (almost) entirely inside another box, except ``keep_labels``.

    If two non-caption boxes each contain the other (near-duplicates), only the later
    one is dropped. Otherwise each would remove the other and the region would vanish.
    """
    removed = set()
    for i, outer in enumerate(boxes):
        for j, inner in enumerate(boxes):
            if i == j or inner.label in keep_labels:
                continue
            if inner.intersection_pct(outer) < containment:
                continue
            # Mutual containment with an earlier box: that box already removed ``outer``
            # on its own turn as the outer box, so keep it.
            if j < i and outer.label not in keep_labels and outer.intersection_pct(inner) >= containment:
                continue
            removed.add(j)
    return [b for idx, b in enumerate(boxes) if idx not in removed]


def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
    """Turn per-class layout heatmaps into layout boxes snapped to detected text lines.

    Args:
        detection_result: Text detection for the same page. Its ``vertical_lines`` and
            ``bboxes`` are expected in ``orig_size`` coordinates. It is not modified.
        heatmaps: One 2D heatmap per class, all the same shape; index 0 is the blank
            class. Not modified.
        orig_size: Target page size, presumably (width, height) to match
            ``reversed(heatmap.shape)``.
        id2label: Maps heatmap/class index to label name.
        segment_assignment: Per-pixel winning class index, same shape as each heatmap.
        vertical_line_width: Pixels added to the right edge of each vertical line before
            it is masked out of the logits.

    Returns:
        Layout boxes rescaled to ``orig_size``. Text lines not covered by any box are
        added as "Text" boxes with confidence 0.5.
    """
    import copy

    logits = np.stack(heatmaps, axis=0)  # [class, y, x]; a new array, heatmaps untouched
    heatmap_size = list(reversed(heatmaps[0].shape))

    # Rescale copies so the caller's detection result is not mutated. In the serial
    # path of batch_layout_detection, detection_result is the caller's own object.
    vertical_line_bboxes = copy.deepcopy(detection_result.vertical_lines)
    line_bboxes = copy.deepcopy(detection_result.bboxes)

    # Scale back to processor (heatmap) size
    for line in vertical_line_bboxes:
        line.rescale_bbox(orig_size, heatmap_size)
    for line in line_bboxes:
        line.rescale(orig_size, heatmap_size)

    _zero_out_vertical_lines(logits, vertical_line_bboxes, vertical_line_width)
    logits[:, logits[0] >= .5] = 0  # zero out where blanks are
    # Keep each pixel only in the channel of the class that won it
    for class_idx in range(logits.shape[0]):
        logits[class_idx, segment_assignment != class_idx] = 0

    detected_boxes = _detect_class_boxes(logits, id2label)
    box_lines, used_lines = _assign_lines_to_boxes(detected_boxes, line_bboxes)
    new_boxes = _snap_boxes_to_lines(detected_boxes, box_lines)
    new_boxes = _merge_same_type_boxes(new_boxes)

    # Ensure we account for all text lines in the layout
    for line_idx, line in enumerate(line_bboxes):
        if line_idx not in used_lines:
            new_boxes.append(LayoutBox(polygon=line.polygon, label="Text", confidence=.5))

    for box in new_boxes:
        box.rescale(heatmap_size, orig_size)
    new_boxes = [box for box in new_boxes if box.area > 16]
    return _remove_contained_boxes(new_boxes)
def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
    """Extract layout boxes directly from per-class heatmaps (no text-line snapping).

    For each non-blank class, the steps are:

    1. Mask a copy of the class heatmap to the pixels that ``segment_assignment``
       assigns to that class.
    2. Skip the class if the masked peak is below ``settings.DETECTOR_BLANK_THRESHOLD``.
    3. Otherwise convert the masked heatmap to boxes scaled to ``orig_size``.

    The combined list is then filtered by ``keep_largest_boxes``.

    ``heatmaps`` is not modified.
    """
    bboxes = []
    for class_idx in range(1, len(id2label)):  # Skip the blank class
        # Mask a copy. Masking heatmaps[class_idx] in place would zero the caller's
        # arrays, and parallel_get_regions returns those arrays in LayoutResult.heatmaps.
        heatmap = heatmaps[class_idx].copy()
        assert heatmap.shape == segment_assignment.shape
        heatmap[segment_assignment != class_idx] = 0  # zero out where another segment is
        # Skip processing empty labels
        if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
            continue
        regions = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size)
        bboxes.extend(LayoutBox(polygon=region.polygon, label=id2label[class_idx]) for region in regions)

    return keep_largest_boxes(bboxes)
def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
    """Post-process one page's class heatmaps into a LayoutResult.

    The segmentation map assigns each pixel to the class whose heatmap is largest there.
    ``detection_results`` is a single page's TextDetectionResult, despite the plural name.
    When it is supplied, regions are snapped to its text lines; otherwise they are taken
    from the heatmaps alone.
    """
    segment_assignment = np.stack(heatmaps, axis=0).argmax(axis=0)

    if detection_results is None:
        bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment)
    else:
        bboxes = get_regions_from_detection_result(
            detection_results, heatmaps, orig_size, id2label, segment_assignment
        )

    return LayoutResult(
        bboxes=bboxes,
        segmentation_map=Image.fromarray(segment_assignment.astype(np.uint8)),
        heatmaps=heatmaps,
        image_bbox=[0, 0, orig_size[0], orig_size[1]],
    )
def _iter_page_batches(layout_generator, id2label, detection_results):
    """Yield, per model batch, a list of ``parallel_get_regions`` argument tuples.

    Images are numbered across batches, so ``detection_results[i]`` is paired with
    the i-th image produced by ``layout_generator``.
    """
    img_idx = 0
    for preds, orig_sizes in layout_generator:
        batch_args = []
        for pred, orig_size in zip(preds, orig_sizes):
            page_detection = detection_results[img_idx] if detection_results else None
            batch_args.append((pred, orig_size, id2label, page_detection))
            img_idx += 1
        yield batch_args


def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
    """Run layout detection on ``images`` and post-process each page into a LayoutResult.

    Post-processing runs in a ProcessPoolExecutor unless ``settings.IN_STREAMLIT`` is set
    or there are fewer than ``settings.DETECTOR_MIN_PARALLEL_THRESH`` images.

    Args:
        images: Page images passed to ``batch_detection``.
        model: Layout model; ``model.config.id2label`` names each heatmap channel.
        processor: Passed through to ``batch_detection``.
        detection_results: Optional per-image text detection results, in the same
            order as ``images``. When given, boxes are snapped to detected text lines.
        batch_size: Passed through to ``batch_detection``.

    Returns:
        One LayoutResult per image, in input order.
    """
    # Nothing to process. Returning here also avoids ProcessPoolExecutor(max_workers=0),
    # which raises ValueError, when the parallel threshold is <= 0.
    if len(images) == 0:
        return []

    layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
    id2label = model.config.id2label

    max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    results = []
    if parallelize:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            for batch_args in _iter_page_batches(layout_generator, id2label, detection_results):
                # Submit one model batch at a time and collect it in submission order,
                # so only a single batch of heatmaps is in flight.
                futures = [executor.submit(parallel_get_regions, *args) for args in batch_args]
                results.extend(future.result() for future in futures)
    else:
        for batch_args in _iter_page_batches(layout_generator, id2label, detection_results):
            results.extend(parallel_get_regions(*args) for args in batch_args)
    return results