Spaces:
Build error
Build error
| import os | |
| import re | |
| import gc | |
| import sys | |
| import time | |
| import torch | |
| from PIL import Image, ImageDraw | |
| from torchvision import transforms | |
| from torch.utils.data import DataLoader | |
| sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..')) | |
| from pdf_extract_kit.utils.data_preprocess import load_pdf | |
| from pdf_extract_kit.tasks.ocr.task import OCRTask | |
| from pdf_extract_kit.dataset.dataset import MathDataset | |
| from pdf_extract_kit.registry.registry import TASK_REGISTRY | |
| from pdf_extract_kit.utils.merge_blocks_and_spans import ( | |
| fill_spans_in_blocks, | |
| fix_block_spans, | |
| merge_para_with_text | |
| ) | |
def latex_rm_whitespace(s: str) -> str:
    """Collapse unnecessary whitespace in a LaTeX string.

    Whitespace between two non-letter tokens, or between a letter and a
    non-letter token, is removed; whitespace between two letters is kept
    (removing it would merge identifiers). The explicit spacing command
    ``\\ `` is protected by a negative lookahead. Spaces inside
    ``\\operatorname``/``\\mathrm``/``\\text``/``\\mathbf`` groups are
    stripped first so the generic rules cannot corrupt those commands.

    Args:
        s: LaTeX source, e.g. raw output of a formula-recognition model.

    Returns:
        The cleaned LaTeX string.
    """
    # BUGFIX: the character-class patterns are now raw strings; '[\W_^\d]'
    # in a plain literal contains invalid escape sequences, which raise
    # SyntaxWarning on modern Python (and are slated to become errors).
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = r'[a-zA-Z]'
    noletter = r'[\W_^\d]'
    # Strip inner spaces of the matched commands, then splice them back in
    # one-by-one in their original order.
    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
        if news == s:
            # Fixed point reached: no further whitespace can be removed.
            break
    return s
def crop_img(input_res, input_pil_img, padding_x=0, padding_y=0):
    """Crop the region described by ``input_res['poly']`` out of a PIL image
    and paste it onto a white canvas enlarged by the given padding.

    Args:
        input_res: detection dict; poly[0]/poly[1] are the top-left corner,
            poly[4]/poly[5] the bottom-right corner of the region.
        input_pil_img: source PIL.Image.
        padding_x: white margin added on the left and right of the crop.
        padding_y: white margin added on the top and bottom of the crop.

    Returns:
        tuple: (padded PIL.Image, [padding_x, padding_y, xmin, ymin, xmax,
        ymax, canvas_width, canvas_height]) so callers can map coordinates
        between the padded crop and the original page.
    """
    poly = input_res['poly']
    region_x0, region_y0 = int(poly[0]), int(poly[1])
    region_x1, region_y1 = int(poly[4]), int(poly[5])
    # White canvas, padding_x / padding_y wider on every side than the crop.
    canvas_w = (region_x1 - region_x0) + 2 * padding_x
    canvas_h = (region_y1 - region_y0) + 2 * padding_y
    canvas = Image.new('RGB', (canvas_w, canvas_h), 'white')
    # Cut the region and center it on the canvas, offset by the padding.
    region = input_pil_img.crop((region_x0, region_y0, region_x1, region_y1))
    canvas.paste(region, (padding_x, padding_y))
    meta = [padding_x, padding_y, region_x0, region_y0, region_x1, region_y1,
            canvas_w, canvas_h]
    return canvas, meta
class PDF2MARKDOWN(OCRTask):
    """PDF/image -> layout + formula + OCR extraction -> markdown task.

    Wraps four pluggable models:
        layout_model: layout detection (titles, text, tables, figures, ...).
        mfd_model:    math-formula detection.
        mfr_model:    math-formula recognition; requires ``mfd_model`` since
                      recognition runs on the regions detection found.
        ocr_model:    text OCR.
    Any model may be ``None`` to disable the corresponding stage.
    """

    def __init__(self, layout_model, mfd_model, mfr_model, ocr_model):
        # NOTE(review): OCRTask.__init__ is not called here — presumably the
        # base class needs no per-instance setup; confirm against OCRTask.
        self.layout_model = layout_model
        self.mfd_model = mfd_model
        self.mfr_model = mfr_model
        self.ocr_model = ocr_model
        if self.mfr_model is not None:
            assert self.mfd_model is not None, "formula recognition based on formula detection, mfd_model can not be None."
            # Wrap the recognizer's visual preprocessor so it can be applied
            # as a torchvision transform pipeline by the formula DataLoader.
            self.mfr_transform = transforms.Compose([self.mfr_model.vis_processor, ])
        # RGB color per category, used when visualizing layout results.
        self.color_palette = {
            'title': (255, 64, 255),
            'plain text': (255, 255, 0),
            'abandon': (0, 255, 255),
            'figure': (255, 215, 135),
            'figure_caption': (215, 0, 95),
            'table': (100, 0, 48),
            'table_caption': (0, 175, 0),
            'table_footnote': (95, 0, 95),
            'isolate_formula': (175, 95, 0),
            'formula_caption': (95, 95, 0),
            'inline': (0, 0, 255),
            'isolated': (0, 255, 0),
            'text': (255, 0, 0)
        }
| def convert_format(self, yolo_res, id_to_names, ): | |
| """ | |
| convert yolo format to pdf-extract format. | |
| """ | |
| res_list = [] | |
| for xyxy, conf, cla in zip(yolo_res.boxes.xyxy.cpu(), yolo_res.boxes.conf.cpu(), yolo_res.boxes.cls.cpu()): | |
| xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] | |
| new_item = { | |
| 'category_type': id_to_names[int(cla.item())], | |
| 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], | |
| 'score': round(float(conf.item()), 2), | |
| } | |
| res_list.append(new_item) | |
| return res_list | |
| def process_single_pdf(self, image_list): | |
| """predict on one image, reture text detection and recognition results. | |
| Args: | |
| image_list: List[PIL.Image.Image] | |
| Returns: | |
| List[dict]: list of PDF extract results | |
| Return example: | |
| [ | |
| { | |
| "layout_dets": [ | |
| { | |
| "category_type": "text", | |
| "poly": [ | |
| 380.6792698635707, | |
| 159.85058512958923, | |
| 765.1419999999998, | |
| 159.85058512958923, | |
| 765.1419999999998, | |
| 192.51073013642917, | |
| 380.6792698635707, | |
| 192.51073013642917 | |
| ], | |
| "text": "this is an example text", | |
| "score": 0.97 | |
| }, | |
| ... | |
| ], | |
| "page_info": { | |
| "page_no": 0, | |
| "height": 2339, | |
| "width": 1654, | |
| } | |
| }, | |
| ... | |
| ] | |
| """ | |
| pdf_extract_res = [] | |
| mf_image_list = [] | |
| latex_filling_list = [] | |
| for idx, image in enumerate(image_list): | |
| img_W, img_H = image.size | |
| if self.layout_model is not None: | |
| ori_layout_res = self.layout_model.predict([image], "")[0] | |
| layout_res = self.convert_format(ori_layout_res, self.layout_model.id_to_names) | |
| else: | |
| layout_res = [] | |
| single_page_res = {'layout_dets': layout_res} | |
| single_page_res['page_info'] = dict( | |
| page_no = idx, | |
| height = img_H, | |
| width = img_W | |
| ) | |
| if self.mfd_model is not None: | |
| mfd_res = self.mfd_model.predict([image], "")[0] | |
| for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()): | |
| xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] | |
| new_item = { | |
| 'category_type': self.mfd_model.id_to_names[int(cla.item())], | |
| 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], | |
| 'score': round(float(conf.item()), 2), | |
| 'latex': '', | |
| } | |
| single_page_res['layout_dets'].append(new_item) | |
| if self.mfr_model is not None: | |
| latex_filling_list.append(new_item) | |
| bbox_img = image.crop((xmin, ymin, xmax, ymax)) | |
| mf_image_list.append(bbox_img) | |
| pdf_extract_res.append(single_page_res) | |
| del mfd_res | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| # Formula recognition, collect all formula images in whole pdf file, then batch infer them. | |
| if self.mfr_model is not None: | |
| a = time.time() | |
| dataset = MathDataset(mf_image_list, transform=self.mfr_transform) | |
| dataloader = DataLoader(dataset, batch_size=self.mfr_model.batch_size, num_workers=0) | |
| mfr_res = [] | |
| for imgs in dataloader: | |
| imgs = imgs.to(self.mfr_model.device) | |
| output = self.mfr_model.model.generate({'image': imgs}) | |
| mfr_res.extend(output['pred_str']) | |
| for res, latex in zip(latex_filling_list, mfr_res): | |
| res['latex'] = latex_rm_whitespace(latex) | |
| b = time.time() | |
| print("formula nums:", len(mf_image_list), "mfr time:", round(b-a, 2)) | |
| # ocr_res = self.ocr_model.predict(image) | |
| # ocr and table recognition | |
| for idx, image in enumerate(image_list): | |
| layout_res = pdf_extract_res[idx]['layout_dets'] | |
| pil_img = image.copy() | |
| ocr_res_list = [] | |
| table_res_list = [] | |
| single_page_mfdetrec_res = [] | |
| for res in layout_res: | |
| if res['category_type'] in self.mfd_model.id_to_names.values(): | |
| single_page_mfdetrec_res.append({ | |
| "bbox": [int(res['poly'][0]), int(res['poly'][1]), | |
| int(res['poly'][4]), int(res['poly'][5])], | |
| }) | |
| elif res['category_type'] in [self.layout_model.id_to_names[cid] for cid in [0, 1, 2, 4, 6, 7]]: | |
| ocr_res_list.append(res) | |
| elif res['category_type'] in [self.layout_model.id_to_names[5]]: | |
| table_res_list.append(res) | |
| ocr_start = time.time() | |
| # Process each area that requires OCR processing | |
| for res in ocr_res_list: | |
| new_image, useful_list = crop_img(res, pil_img, padding_x=25, padding_y=25) | |
| paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list | |
| # Adjust the coordinates of the formula area | |
| adjusted_mfdetrec_res = [] | |
| for mf_res in single_page_mfdetrec_res: | |
| mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"] | |
| # Adjust the coordinates of the formula area to the coordinates relative to the cropping area | |
| x0 = mf_xmin - xmin + paste_x | |
| y0 = mf_ymin - ymin + paste_y | |
| x1 = mf_xmax - xmin + paste_x | |
| y1 = mf_ymax - ymin + paste_y | |
| # Filter formula blocks outside the graph | |
| if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]): | |
| continue | |
| else: | |
| adjusted_mfdetrec_res.append({ | |
| "bbox": [x0, y0, x1, y1], | |
| }) | |
| # OCR recognition | |
| ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] | |
| # Integration results | |
| if ocr_res: | |
| for box_ocr_res in ocr_res: | |
| p1, p2, p3, p4 = box_ocr_res[0] | |
| text, score = box_ocr_res[1] | |
| # Convert the coordinates back to the original coordinate system | |
| p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin] | |
| p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin] | |
| p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin] | |
| p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin] | |
| layout_res.append({ | |
| 'category_type': 'text', | |
| 'poly': p1 + p2 + p3 + p4, | |
| 'score': round(score, 2), | |
| 'text': text, | |
| }) | |
| ocr_cost = round(time.time() - ocr_start, 2) | |
| print(f"ocr cost: {ocr_cost}") | |
| return pdf_extract_res | |
| def order_blocks(self, blocks): | |
| def calculate_oder(poly): | |
| xmin, ymin, _, _, xmax, ymax, _, _ = poly | |
| return ymin*3000 + xmin | |
| return sorted(blocks, key=lambda item: calculate_oder(item['poly'])) | |
| def convert2md(self, extract_res): | |
| blocks = [] | |
| spans = [] | |
| for item in extract_res['layout_dets']: | |
| if item['category_type'] in ['inline', 'text', 'isolated']: | |
| text_key = 'text' if item['category_type'] == 'text' else 'latex' | |
| xmin, ymin, _, _, xmax, ymax, _, _ = item['poly'] | |
| spans.append( | |
| { | |
| "type": item['category_type'], | |
| "bbox": [xmin, ymin, xmax, ymax], | |
| "content": item[text_key] | |
| } | |
| ) | |
| if item['category_type'] == "isolated": | |
| item['category_type'] = "isolate_formula" | |
| blocks.append(item) | |
| else: | |
| blocks.append(item) | |
| blocks_types = ["title", "plain text", "figure_caption", "table_caption", "table_footnote", "isolate_formula", "formula_caption"] | |
| need_fix_bbox = [] | |
| final_block = [] | |
| for block in blocks: | |
| block_type = block["category_type"] | |
| if block_type in blocks_types: | |
| need_fix_bbox.append(block) | |
| else: | |
| final_block.append(block) | |
| block_with_spans, spans = fill_spans_in_blocks(need_fix_bbox, spans, 0.6) | |
| fix_blocks = fix_block_spans(block_with_spans) | |
| for para_block in fix_blocks: | |
| result = merge_para_with_text(para_block) | |
| if para_block['type'] == "isolate_formula": | |
| para_block['saved_info']['latex'] = result | |
| else: | |
| para_block['saved_info']['text'] = result | |
| final_block.append(para_block['saved_info']) | |
| final_block = self.order_blocks(final_block) | |
| md_text = "" | |
| for block in final_block: | |
| if block['category_type'] == "title": | |
| md_text += "\n# "+block['text'] +"\n" | |
| elif block['category_type'] in ["isolate_formula"]: | |
| md_text += "\n"+block['latex']+"\n" | |
| elif block['category_type'] in ["plain text", "figure_caption", "table_caption"]: | |
| md_text += " "+block['text']+" " | |
| elif block['category_type'] in ["figure", "table"]: | |
| continue | |
| else: | |
| continue | |
| return md_text | |
| def process(self, input_path, save_dir=None, visualize=False, merge2markdown=False): | |
| file_list = self.prepare_input_files(input_path) | |
| res_list = [] | |
| for fpath in file_list: | |
| basename = os.path.basename(fpath)[:-4] | |
| if fpath.endswith(".pdf") or fpath.endswith(".PDF"): | |
| images = load_pdf(fpath) | |
| else: | |
| images = [Image.open(fpath)] | |
| pdf_extract_res = self.process_single_pdf(images) | |
| res_list.append(pdf_extract_res) | |
| if save_dir: | |
| os.makedirs(save_dir, exist_ok=True) | |
| self.save_json_result(pdf_extract_res, os.path.join(save_dir, f"{basename}.json")) | |
| if merge2markdown: | |
| md_content = [] | |
| for extract_res in pdf_extract_res: | |
| md_text = self.convert2md(extract_res) | |
| md_content.append(md_text) | |
| with open(os.path.join(save_dir, f"{basename}.md"), "w") as f: | |
| f.write("\n\n".join(md_content)) | |
| if visualize: | |
| for image, page_res in zip(images, pdf_extract_res): | |
| self.visualize_image(image, page_res['layout_dets'], cate2color=self.color_palette) | |
| if fpath.endswith(".pdf") or fpath.endswith(".PDF"): | |
| first_page = images.pop(0) | |
| first_page.save(os.path.join(save_dir, f'{basename}.pdf'), 'PDF', resolution=100, save_all=True, append_images=images) | |
| else: | |
| images[0].save(os.path.join(save_dir, f"{basename}.png")) | |
| return res_list | |