Spaces:
Paused
Paused
| """ | |
| Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py | |
| But accepts preloaded model to avoid slowness in use and CUDA forking issues | |
| Loader that uses H2O DocTR OCR models to extract text from images | |
| """ | |
| from typing import List, Union, Any, Tuple, Optional | |
| import requests | |
| from langchain.docstore.document import Document | |
| from langchain.document_loaders import ImageCaptionLoader | |
| import numpy as np | |
| from utils import get_device, clear_torch_cache | |
| from doctr.utils.common_types import AbstractFile | |
class H2OOCRLoader(ImageCaptionLoader):
    """Loader that extracts text from images (and PDFs) with H2O DocTR OCR models."""

    def __init__(self, path_images: Union[str, List[str]] = None, layout_aware=False):
        """
        Args:
            path_images: single path or list of paths passed to the base loader
            layout_aware: if True, reconstruct page text preserving the 2-D
                layout (columns/indentation) instead of a flat word stream
        """
        super().__init__(path_images)
        self._ocr_model = None  # lazily created / moved to device by load_model()
        self.layout_aware = layout_aware

    def set_context(self):
        """Choose the torch device: 'cuda' when a usable GPU exists, else 'cpu'."""
        if get_device() == 'cuda':
            import torch
            # Bug fix: the original evaluated `torch.cuda.is_available` (the
            # function object, which is always truthy) instead of calling it.
            n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
            if n_gpus > 0:
                self.context_class = torch.device
                self.device = 'cuda'
            else:
                self.device = 'cpu'
        else:
            self.device = 'cpu'

    def load_model(self):
        """Build the DocTR OCR predictor (or move an existing one to the device).

        Returns:
            self, so calls can be chained.
        Raises:
            ValueError: if the `doctr` package is not installed.
        """
        try:
            from weasyprint import HTML  # to avoid warning
            from doctr.models.zoo import ocr_predictor
        except ImportError:
            raise ValueError(
                "`doctr` package not found, please install with "
                "`pip install git+https://github.com/h2oai/doctr.git`."
            )
        if self._ocr_model:
            # Model already built: just ensure it sits on the chosen device.
            self._ocr_model = self._ocr_model.to(self.device)
            return self
        self.set_context()
        self._ocr_model = ocr_predictor(det_arch="db_resnet50", reco_arch="crnn_efficientnetv2_mV2",
                                        pretrained=True).to(self.device)
        return self

    def unload_model(self):
        """Move the OCR sub-models back to CPU and release CUDA memory."""
        if self._ocr_model is None:
            # Robustness fix: the original raised AttributeError when called
            # before load_model(); unloading a never-loaded model is a no-op.
            return
        if hasattr(self._ocr_model.det_predictor.model, 'cpu'):
            self._ocr_model.det_predictor.model.cpu()
            clear_torch_cache()
        if hasattr(self._ocr_model.reco_predictor.model, 'cpu'):
            self._ocr_model.reco_predictor.model.cpu()
            clear_torch_cache()
        if hasattr(self._ocr_model, 'cpu'):
            self._ocr_model.cpu()
            clear_torch_cache()

    def set_document_paths(self, document_paths: Union[str, List[str]]):
        """
        Load from a list of image files
        """
        if isinstance(document_paths, str):
            self.document_paths = [document_paths]
        else:
            self.document_paths = document_paths

    def load(self, prompt=None) -> List[Document]:
        """OCR every path in self.document_paths into one Document each.

        Args:
            prompt: unused; kept for interface compatibility with the base loader.
        """
        if self._ocr_model is None:
            self.load_model()
        results = []
        for document_path in self.document_paths:
            page_texts, metadata = self._get_captions_and_metadata(
                model=self._ocr_model, document_path=document_path
            )
            doc = Document(page_content=" \n".join(page_texts), metadata=metadata)
            results.append(doc)
        return results

    def _get_captions_and_metadata(
            self, model: Any, document_path: str) -> Tuple[List[str], dict]:
        """OCR a single document.

        Returns:
            (per-page extracted text strings, metadata dict with the source path).
            Note: annotation fixed from Tuple[str, dict] — a list of page
            strings is returned, as load() expects for " \\n".join(...).
        Raises:
            ValueError: if `doctr` is missing or the file cannot be decoded.
        """
        try:
            from doctr.io import DocumentFile
        except ImportError:
            raise ValueError(
                "`doctr` package not found, please install with "
                "`pip install git+https://github.com/h2oai/doctr.git`."
            )
        try:
            if document_path.lower().endswith(".pdf"):
                # load at roughly 300 dpi
                images = read_pdf(document_path)
            else:
                images = DocumentFile.from_images(document_path)
        except Exception:
            raise ValueError(f"Could not get image data for {document_path}")
        document_words = []
        for image in images:
            ocr_output = model([image])
            page_words = []
            page_boxes = []
            for block_num, block in enumerate(ocr_output.pages[0].blocks):
                for line_num, line in enumerate(block.lines):
                    for word_num, word in enumerate(line.words):
                        if not (word.value or "").strip():
                            continue  # skip whitespace-only detections
                        page_words.append(word.value)
                        page_boxes.append(
                            [word.geometry[0][0], word.geometry[0][1], word.geometry[1][0], word.geometry[1][1]])
            if self.layout_aware:
                # Re-order words top-to-bottom, then render with gap markers.
                ids = boxes_sort(page_boxes)
                texts = [page_words[i] for i in ids]
                text_boxes = [page_boxes[i] for i in ids]
                page_words = space_layout(texts=texts, boxes=text_boxes)
            else:
                page_words = " ".join(page_words)
            document_words.append(page_words)
        metadata: dict = {"image_path": document_path}
        return document_words, metadata
def boxes_sort(boxes):
    """Return the indices that order *boxes* from top to bottom (by y1).

    Params:
        boxes: [[x1, y1, x2, y2], [x1, y1, x2, y2], ...]
    """
    ranked = sorted(enumerate(boxes), key=lambda pair: pair[1][1])
    return [index for index, _ in ranked]
def is_same_line(box1, box2):
    """True when each box's vertical midpoint lies strictly inside the other's y-range.

    Params:
        box1: [x1, y1, x2, y2]
        box2: [x1, y1, x2, y2]
    """
    mid1 = (box1[1] + box1[3]) / 2
    mid2 = (box2[1] + box2[3]) / 2
    return box2[1] < mid1 < box2[3] and box1[1] < mid2 < box1[3]
def union_box(box1, box2):
    """Smallest axis-aligned box enclosing both inputs.

    Params:
        box1: [x1, y1, x2, y2]
        box2: [x1, y1, x2, y2]
    """
    top_left = [min(a, b) for a, b in zip(box1[:2], box2[:2])]
    bottom_right = [max(a, b) for a, b in zip(box1[2:], box2[2:])]
    return top_left + bottom_right
def space_layout(texts, boxes):
    """Render OCR words as layout-preserving text.

    Words are grouped into visual lines (all boxes whose vertical midpoint
    falls inside the current first box's y-range), sorted left-to-right,
    then each word is prefixed with a compressed gap marker `` <n> `` so
    that column structure survives in the flat string.

    Params:
        texts: list of word strings
        boxes: list of [x1, y1, x2, y2], same length as texts
    Returns:
        one string, each visual line terminated by '\n'
    """
    if len(boxes) == 0:
        # Robustness fix: the original raised ZeroDivisionError (0 / 0)
        # when a page had no words.
        return ""
    line_boxes = []
    line_texts = []
    max_line_char_num = 0
    line_width = 0
    boxes = np.array(boxes)
    texts = np.array(texts)
    while len(boxes) > 0:
        box = boxes[0]
        # Words whose vertical midpoint lies inside the current box's y-range.
        mid = (boxes[:, 3] + boxes[:, 1]) / 2
        inline_boxes = np.logical_and(mid > box[1], mid < box[3])
        # Robustness fix: a degenerate box (y1 == y2) fails its own strict
        # midpoint test and would never be consumed -> infinite loop.
        inline_boxes[0] = True
        sorted_xs = np.argsort(boxes[inline_boxes][:, 0], axis=0)
        line_box = boxes[inline_boxes][sorted_xs]
        line_text = texts[inline_boxes][sorted_xs]
        boxes = boxes[~inline_boxes]
        texts = texts[~inline_boxes]
        line_boxes.append(line_box.tolist())
        line_texts.append(line_text.tolist())
        if len(" ".join(line_texts[-1])) > max_line_char_num:
            # Track the widest line (in characters) and its pixel width.
            max_line_char_num = len(" ".join(line_texts[-1]))
            line_width = np.array(line_boxes[-1])
            line_width = line_width[:, 2].max() - line_width[:, 0].min()
    # Estimate one character's width from the widest line; guard against 0.
    char_width = (line_width / max_line_char_num)
    if char_width == 0:
        char_width = 1
    space_line_texts = []
    for i, line_box in enumerate(line_boxes):
        space_line_text = ""
        for j, box in enumerate(line_box):
            left_char_num = int(box[0] / char_width)
            left_char_num = max((left_char_num - len(space_line_text)), 1)
            # minified layout: gaps wider than one char become " <n> "
            if left_char_num > 1:
                space_line_text += f" <{left_char_num}> "
            else:
                space_line_text += " "
            space_line_text += line_texts[i][j]
        space_line_texts.append(space_line_text + "\n")
    return "".join(space_line_texts)
def read_pdf(
    file: AbstractFile,
    scale: float = 300 / 72,
    rgb_mode: bool = True,
    password: Optional[str] = None,
    **kwargs: Any,
) -> List[np.ndarray]:
    """Rasterize a PDF document into one numpy image per page.

    >>> from doctr.documents import read_pdf
    >>> doc = read_pdf("path/to/your/doc.pdf")

    Args:
        file: the path to the PDF file
        scale: rendering scale (1 corresponds to 72dpi)
        rgb_mode: if True, the output will be RGB, otherwise BGR
        password: a password to unlock the document, if encrypted
        kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`
    Returns:
        the list of pages decoded as numpy ndarray of shape H x W x C
    """
    # Rasterise pages to numpy ndarrays with pypdfium2
    import pypdfium2 as pdfium
    document = pdfium.PdfDocument(file, password=password, autoclose=True)
    rendered_pages = []
    for page in document:
        image = page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs)
        rendered_pages.append(image.to_numpy())
    return rendered_pages