Spaces:
Paused
Paused
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| import os | |
| import re | |
| from collections import Counter | |
| from copy import deepcopy | |
| import numpy as np | |
| from huggingface_hub import snapshot_download | |
| from api.utils.file_utils import get_project_base_directory | |
| from deepdoc.vision import Recognizer | |
class LayoutRecognizer(Recognizer):
    """Detect page layout regions and tag OCR text boxes with their layout type.

    Builds on the generic ``Recognizer`` from ``deepdoc.vision`` with the
    layout model's label set, and adds the logic that assigns each OCR box to
    the layout region it overlaps (dropping page noise along the way).
    """

    labels = [
        "_background_",
        "Text",
        "Title",
        "Figure",
        "Figure caption",
        "Table",
        "Table caption",
        "Header",
        "Footer",
        "Reference",
        "Equation",
    ]

    # OCR boxes whose text matches any of these patterns are dropped as page
    # noise: bullet-only lines, copyright/disclaimer/address lines, runs of
    # three or more dots, "N / M" and "N of M" counters, long bare http URLs,
    # "source:" lines, e-mail addresses and "(cid:N)" glyph artifacts.
    # Compiled once here instead of on every box.
    _GARBAGE_PATTERNS = [
        re.compile(p) for p in (
            r"^•+$",
            # "[:：]" accepts both the ASCII colon and the full-width colon
            # used in CJK text (the character class previously listed the
            # ASCII colon twice).
            r"(版权归©|免责条款|地址[:：])",
            r"\.{3,}",
            r"^[0-9]{1,2} / ?[0-9]{1,2}$",
            r"^[0-9]{1,2} of [0-9]{1,2}$",
            r"^http://[^ ]{12,}",
            r"(资料|数据)来源[:：]",
            r"[0-9a-z._-]+@[a-z0-9-]+\.[a-z]{2,3}",
            r"\(cid *: *[0-9]+ *\)",
        )
    ]

    # Layout types are matched against OCR boxes in this order; a box that
    # has already been tagged is skipped by later passes.
    _TAG_ORDER = ("footer", "header", "reference", "figure caption",
                  "table caption", "title", "table", "text", "figure", "equation")

    def __init__(self, domain):
        """Load the layout model for ``domain``.

        The model is loaded from ``<project base>/rag/res/deepdoc``. If the
        base ``Recognizer`` fails to initialise from there, the
        ``InfiniFlow/deepdoc`` snapshot is downloaded into that same
        directory and initialisation is retried once.
        """
        model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
        try:
            super().__init__(self.labels, domain, model_dir)
        except Exception:
            # Deliberate best-effort fallback (presumably for missing model
            # files): download the snapshot, then retry.
            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
                                          local_dir=model_dir,
                                          local_dir_use_symlinks=False)
            super().__init__(self.labels, domain, model_dir)
        # Lower-cased layout types whose boxes are treated as page furniture.
        self.garbage_layouts = ["footer", "header", "reference"]

    def __call__(self, image_list, ocr_res, scale_factor=3,
                 thr=0.2, batch_size=16, drop=True):
        """Detect layouts on each page and tag that page's OCR boxes.

        Args:
            image_list: One image per page. ``.size[1]`` is read as the page
                height, so PIL-style images are presumably expected.
            ocr_res: One list of OCR box dicts per page. Each box needs at
                least ``text``, ``top`` and ``bottom``. The lists are
                mutated in place (boxes are popped and tagged).
            scale_factor: Divisor applied to the detected layout bboxes (and
                to the page height) to bring them into OCR-box coordinates.
            thr: Score threshold passed to the base recognizer.
            batch_size: Batch size passed to the base recognizer.
            drop: If true, boxes falling in header/footer/reference layouts
                are removed, except header/footer boxes lying outside the
                top/bottom 10% of the page.

        Returns:
            A tuple ``(boxes, page_layout)``:

            - ``boxes`` is the flat list of surviving OCR boxes across all
              pages. Each carries ``layout_type`` and, when matched,
              ``layoutno``.
            - ``page_layout`` is the cleaned list of layout dicts for each
              page.

        Raises:
            AssertionError: if ``image_list`` has a different length from
                ``ocr_res`` or from the detected layouts.
        """
        # Validate before running inference so a mismatch fails fast.
        assert len(image_list) == len(ocr_res)
        layouts = super().__call__(image_list, thr, batch_size)
        assert len(image_list) == len(layouts)

        boxes = []
        garbages = {}  # layout type -> texts of boxes dropped as that type
        page_layout = []
        for pn, lts in enumerate(layouts):
            bxs = ocr_res[pn]
            # Rescale detections to OCR coordinates. Low-confidence
            # (< 0.8) header/footer/reference detections are discarded.
            # bbox is presumably (x0, top, x1, bottom).
            lts = [{"type": b["type"],
                    "score": float(b["score"]),
                    "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
                    "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
                    "page_number": pn,
                    } for b in lts
                   if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts]
            # Row tolerance for the y-first sort is half the mean layout
            # height. np.mean([]) would give NaN plus a RuntimeWarning, so
            # empty pages use 0 (there is nothing to sort anyway).
            heights = [lt["bottom"] - lt["top"] for lt in lts]
            lts = self.sort_Y_firstly(lts, np.mean(heights) / 2 if heights else 0)
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)

            page_height = image_list[pn].size[1]
            for ty in self._TAG_ORDER:
                self._tag_boxes_with_layout(ty, bxs, lts, page_height,
                                            scale_factor, drop, garbages)
            self._add_unvisited_figures(bxs, lts)
            boxes.extend(bxs)

        return self._drop_repeated_garbage(boxes, garbages), page_layout

    def _is_garbage_box(self, b):
        """Return True if the OCR box's text matches any noise pattern."""
        return any(p.search(b["text"]) for p in self._GARBAGE_PATTERNS)

    def _tag_boxes_with_layout(self, ty, bxs, lts, page_height, scale_factor,
                               drop, garbages):
        """Tag the still-untagged boxes in ``bxs`` that overlap a ``ty`` layout.

        Mutates ``bxs`` in place:

        - Noise boxes are popped.
        - Boxes in dropped header/footer/reference layouts are popped, and
          their text is recorded in ``garbages``.
        - Matched boxes get ``layoutno``/``layout_type``.
        - Unmatched boxes get ``layout_type = ""`` (falsy), so later types
          may still claim them.

        Matched layouts in ``lts`` are marked ``visited``.
        """
        lts_ = [lt for lt in lts if lt["type"] == ty]
        i = 0
        while i < len(bxs):
            b = bxs[i]
            if b.get("layout_type"):
                i += 1
                continue
            if self._is_garbage_box(b):
                bxs.pop(i)
                continue
            ii = self.find_overlapped_with_threashold(b, lts_, thr=0.4)
            if ii is None:  # overlaps no layout of this type
                b["layout_type"] = ""
                i += 1
                continue
            lt = lts_[ii]
            lt["visited"] = True
            # Header/footer boxes outside the top/bottom 10% of the page are
            # kept even when dropping.
            keep = (
                (lt["type"] == "footer"
                 and b["bottom"] < page_height * 0.9 / scale_factor)
                or (lt["type"] == "header"
                    and b["top"] > page_height * 0.1 / scale_factor)
            )
            if drop and lt["type"] in self.garbage_layouts and not keep:
                garbages.setdefault(lt["type"], []).append(b["text"])
                bxs.pop(i)
                continue
            b["layoutno"] = f"{ty}-{ii}"
            # Equation regions are tagged with layout type "figure".
            b["layout_type"] = lt["type"] if lt["type"] != "equation" else "figure"
            i += 1

    @staticmethod
    def _add_unvisited_figures(bxs, lts):
        """Append an empty-text box for each figure/equation layout no OCR box matched.

        ``layoutno`` is numbered per layout type, exactly as
        ``_tag_boxes_with_layout`` numbers it. A synthetic box therefore
        never shares a ``layoutno`` with the text boxes of a different
        layout. Boxes are appended in the order the layouts appear in
        ``lts``.
        """
        counters = {"figure": 0, "equation": 0}
        for lt in lts:
            ty = lt["type"]
            if ty not in counters:
                continue
            idx = counters[ty]
            counters[ty] += 1
            if lt.get("visited"):
                continue
            box = deepcopy(lt)  # keep the page_layout entry untouched
            del box["type"]
            box["text"] = ""
            box["layout_type"] = "figure"
            box["layoutno"] = f"{ty}-{idx}"
            bxs.append(box)

    @staticmethod
    def _drop_repeated_garbage(boxes, garbages):
        """Remove boxes whose text was dropped more than once as the same garbage type.

        Text that was discarded at least twice as header/footer/reference
        (presumably running headers/footers) is removed from every
        remaining box as well. Texts are compared after ``strip()`` on both
        sides. Empty text is never treated as garbage, so synthetic figure
        boxes (which have ``text == ""``) survive.
        """
        repeated = {
            text
            for texts in garbages.values()
            for text, cnt in Counter(t.strip() for t in texts).items()
            if cnt > 1 and text
        }
        return [b for b in boxes if b["text"].strip() not in repeated]