Spaces:
Build error
Build error
| # Copyright (c) 2022, Lawrence Livermore National Security, LLC. | |
| # All rights reserved. | |
| # See the top-level LICENSE and NOTICE files for details. | |
| # LLNL-CODE-838964 | |
| # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception | |
| import cv2 | |
| from pathlib import Path | |
| import torch | |
| import json | |
| from detectron2.config import CfgNode as CN | |
| from detectron2.config import get_cfg | |
| from detectron2.utils.visualizer import ColorMode, Visualizer | |
| from detectron2.data import MetadataCatalog | |
| from detectron2.engine import DefaultPredictor | |
| from pdf2image import convert_from_path | |
| from PIL import Image | |
| import numpy as np | |
| from dit_object_detection.ditod import add_vit_config | |
| import base_utils | |
| from pdfminer.layout import LTTextLineHorizontal, LTTextBoxHorizontal, LTAnno, LTChar | |
| from tokenizers.pre_tokenizers import Whitespace | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
# --- Model setup (runs once at import time) ---------------------------------
# Location of the DiT object-detection package that ships the config files.
dit_path = Path('DiT_Extractor/dit_object_detection')

# Start from detectron2's defaults, register the ViT/DiT-specific keys, then
# layer the PubLayNet cascade configuration on top.
cfg = get_cfg()
add_vit_config(cfg)
cfg.merge_from_file(dit_path / "publaynet_configs/cascade/cascade_dit_base.yaml")

# Weights fine-tuned on PubLayNet; run on the GPU when one is visible.
cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
if torch.cuda.is_available():
    cfg.MODEL.DEVICE = "cuda"
else:
    cfg.MODEL.DEVICE = "cpu"

predictor = DefaultPredictor(cfg)

# Class names in the order the model emits them; thing_map is the inverse
# lookup (class name -> class index).
thing_classes = ["text", "title", "list", "table", "figure"]
thing_map = {name: index for index, name in enumerate(thing_classes)}

# Attach the class names to the test dataset's metadata for the visualizer.
md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
md.set(thing_classes=thing_classes)
def get_pdf_image(pdf_file, page):
    """Rasterize a single page of a PDF at 200 dpi.

    Args:
        pdf_file: Path of the PDF to render.
        page: Page number, used as both first and last page of the render
            range, so exactly that page is requested.

    Returns:
        Whatever ``convert_from_path`` returns for that one-page range
        (callers in this file index element ``[0]``).
    """
    # Restricting first_page/last_page to the same value renders only `page`.
    return convert_from_path(pdf_file, dpi=200, first_page=page, last_page=page)
def get_characters(subelement):
    """Collect ``(bbox, text)`` pairs for each character of a horizontal text line.

    Args:
        subelement: A layout object. Only ``LTTextLineHorizontal`` instances
            are inspected; anything else yields an empty list.

    Returns:
        A list of ``(bbox, text)`` tuples in iteration order. ``LTChar``
        items contribute their own bounding box. ``LTAnno`` items (which
        carry no geometry) get a zero-width box placed at the right edge of
        the previous entry, or at the left edge of the line's own bounding
        box when no entry precedes them.
    """
    all_chars = []
    if not isinstance(subelement, LTTextLineHorizontal):
        return all_chars

    for char in subelement:
        if isinstance(char, LTChar):
            all_chars.append((char.bbox, char.get_text()))
        elif isinstance(char, LTAnno):
            # No bbox, just a space, so make a thin slice after previous text.
            if all_chars:
                prev = all_chars[-1][0]
                bbox = (prev[2], prev[1], prev[2], prev[3])
            else:
                # Nothing precedes this annotation: indexing all_chars[-1]
                # would raise IndexError, so anchor the slice on the left
                # edge of the enclosing line instead.
                line_box = subelement.bbox
                bbox = (line_box[0], line_box[1], line_box[0], line_box[3])
            all_chars.append((bbox, char.get_text()))
    return all_chars
def get_dit_preds(pdf, score_threshold=0.5):
    """Detect layout regions on every page of a PDF and export its text blocks.

    For each page the function renders an image, runs the module-level
    ``predictor`` on it, draws the predictions for inspection, and then
    groups the page's words into one ``content_block`` per detected "text"
    region. The grouped blocks for all pages are written to
    ``<pdf stem>.json`` in the current working directory.

    Args:
        pdf: Path of the PDF to process.
        score_threshold: Detections scoring below this value are ignored.

    Returns:
        A list with one PIL image per page showing the drawn predictions.

    Notes:
        - Words come from ``base_utils.get_pdf_words``; each word is assumed
          to be a sequence whose first four entries are its page-space
          bounding box and whose fifth entry is a layout object exposing
          ``get_text()`` -- TODO confirm against ``base_utils``.
        - ``page_sizes[i][2:]`` is assumed to be the page's (width, height)
          -- TODO confirm against ``base_utils.get_page_sizes``.
        - A word is assigned to the first qualifying region it overlaps and
          is then unavailable to later regions on the same page.
    """
    page_count = base_utils.get_pdf_page_count(pdf)
    page_sizes = base_utils.get_page_sizes(pdf)
    sections = {}
    viz_images = []
    page_words = base_utils.get_pdf_words(pdf)
    text_class = thing_map['text']

    for page in range(1, page_count + 1):
        page_idx = page - 1

        # The predictor takes a numpy array of the rendered PIL image.
        image = np.array(get_pdf_image(pdf, page)[0])
        output = predictor(image)["instances"].to('cpu')

        # Visualize predictions. The channel axis is reversed going into the
        # visualizer and reversed back on the way out.
        visualizer = Visualizer(image[:, :, ::-1],
                                md,
                                scale=1.0,
                                instance_mode=ColorMode.SEGMENTATION)
        drawn = visualizer.draw_instance_predictions(output)
        viz_images.append(Image.fromarray(drawn.get_image()[:, :, ::-1]))

        # Words are in page coordinates, so build the image -> page scale.
        words = page_words[page_idx]
        pdf_dimensions = page_sizes[page_idx][2:]
        # image_size is (height, width); swap to (width, height).
        pdf_image_size = (output.image_size[1], output.image_size[0])
        scale = np.array(pdf_dimensions) / np.array(pdf_image_size)
        scale_box = np.hstack((scale, scale))

        block_id = 0
        sections[page_idx] = []
        detections = zip(output.get('pred_boxes'),
                         output.get('pred_classes'),
                         output.get('scores'))
        for box_t, clazz, score in detections:
            # Only confident "text" regions produce content blocks; filter
            # before doing any geometry.
            if score < score_threshold or clazz != text_class:
                continue

            # Tensor.numpy() shares storage with the tensor; copy so the
            # flip below does not rewrite the predictor's own boxes.
            box = box_t.numpy().copy()
            # Flip along Y axis (image origin is top-left, page origin is
            # bottom-left).
            box[1] = pdf_image_size[1] - box[1]
            box[3] = pdf_image_size[1] - box[3]
            scaled = box * scale_box
            # After the flip the Y values are swapped; reorder so the
            # smaller Y comes first.
            region = [scaled[0], scaled[3], scaled[2], scaled[1]]

            # Single pass: words overlapping the region go into the block,
            # all others stay available for later regions.
            block = None
            text_parts = []
            remaining = []
            for word in words:
                if not base_utils.partial_overlaps(word[0:4], region):
                    remaining.append(word)
                    continue
                if block is None:
                    block_id += 1
                    block = {
                        'coord': word[0:4],
                        'subelements': [],
                        'type': 'content_block',
                        'id': block_id,
                        'text': '',
                    }
                block['coord'] = base_utils.union(block['coord'], word[0:4])
                text_parts.append(word[4].get_text())
                block['subelements'].append(get_characters(word[4]))
            words = remaining

            if block is not None:
                block['text'] = ''.join(text_parts)
                sections[page_idx].append(block)

    # Write final annotation. Path.stem drops the extension whatever its
    # length, where slicing off four characters assumed ".pdf".
    out_name = Path(pdf).stem + ".json"
    with open(out_name, 'w', encoding='utf8') as json_out:
        json.dump(sections, json_out, ensure_ascii=False, indent=4)
    return viz_images