| from __future__ import annotations |
| from ultralytics import YOLO |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
| from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge |
|
|
| from commonforms.utils import BoundingBox, Page, Widget |
| from commonforms.form_creator import PyPdfFormCreator |
| from commonforms.exceptions import EncryptedPdfError |
|
|
| import formalpdf |
| import pypdfium2 |
| import logging |
| import PIL |
|
|
|
|
# NOTE: executed at import time; sets up the root logger at INFO level.
logging.basicConfig(level=logging.INFO)


# Registry of downloadable weights.
#   key:   (upper-cased model name, fast flag)
#   value: (Hugging Face repo id, weights filename), the two arguments the
#          detectors hand to ``hf_hub_download``.
models: dict[tuple[str, bool], tuple[str, str]] = {
    # FFDNet small: ONNX export for the fast path, PyTorch checkpoint otherwise.
    ("FFDNET-S", True): ("jbarrow/FFDNet-S-cpu", "FFDNet-S.onnx"),
    ("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"),
    # FFDNet large: same split between ONNX and PyTorch weights.
    ("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"),
    ("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"),
    # FFDetr only has a non-fast entry.
    ("FFDETR", False): ("jbarrow/FFDetr", "FFDetr.pth"),
}
|
|
|
|
def batch(lst: list, n: int = 8):
    """Yield consecutive slices of *lst*, each holding at most *n* items.

    The final slice is shorter when ``len(lst)`` is not a multiple of *n*;
    an empty *lst* yields nothing.
    """
    for start in range(0, len(lst), n):
        # Slicing clamps at the end of the sequence, so no explicit bound
        # check is needed for the last chunk.
        yield lst[start : start + n]
|
|
|
|
class FFDetrDetector:
    """Form-field detector built on an ``RFDETRMedium`` model.

    Detections are mapped to one of three widget types: ``TextBox``,
    ``ChoiceButton`` or ``Signature``.
    """

    def __init__(self, model_or_path: str, device: int | str = "cpu") -> None:
        """Load the detection model.

        Args:
            model_or_path: The registry name ``"FFDetr"`` (matched
                case-insensitively) or a path to local weights.
            device: Recorded on the instance.
        """
        # NOTE(review): `device` is stored but never handed to RFDETRMedium,
        # so it does not influence where inference runs -- confirm intent.
        self.device = device
        self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path))

        # Model class index -> widget type name.
        self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"}

    def get_model_path(self, model_or_path: str) -> str:
        """Resolve *model_or_path* to a weights path.

        The registry name ``"FFDetr"`` (any casing) is downloaded through
        ``hf_hub_download``; any other value is returned unchanged and
        treated as a path supplied by the caller.
        """
        model_upper = model_or_path.upper()
        if model_upper in ["FFDETR"]:
            repo_id, filename = models[(model_upper, False)]
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        else:
            model_path = model_or_path

        return model_path

    def resize(
        self,
        image: PIL.Image.Image,
        size: tuple[int, int] | int,
    ) -> PIL.Image.Image:
        """Return *image* resized with LANCZOS resampling.

        An integer *size* is expanded to the square ``(size, size)``.
        """
        if isinstance(size, int):
            size = (size, size)

        return image.resize(size, PIL.Image.Resampling.LANCZOS)

    def extract_widgets(
        self,
        pages: list[Page],
        confidence: float = 0.4,
        image_size: int = 1120,
        batch_size: int = 3,
    ) -> dict[int, list[Widget]]:
        """Detect form fields on every page.

        Args:
            pages: Rendered pages; each page's ``image`` is fed to the model.
            confidence: Passed to the model as its detection ``threshold``.
            image_size: Currently unused by this detector; kept so callers
                that pass it keep working.
            batch_size: Number of page images sent to the model per call.

        Returns:
            A dict mapping the page's index in *pages* to its widgets,
            ordered by ``sort_widgets``. Box coordinates are divided by the
            page image's width/height.
        """
        results = []
        for b in batch([p.image for p in pages], n=batch_size):
            predictions = self.model.predict(b, threshold=confidence)

            # predict() may hand back a single result instead of a list;
            # normalise so `results` holds one entry per page.
            if isinstance(predictions, list):
                results.extend(predictions)
            else:
                results.append(predictions)

        widgets = {}

        for page_ix, detections in enumerate(results):
            logging.info(" Page %s: %s fields detected", page_ix, len(detections))
            # Class-agnostic NMS so overlapping boxes of different classes
            # are collapsed as well.
            detections = detections.with_nms(threshold=0.1, class_agnostic=True)
            logging.info("\t\t%s after nms", len(detections))
            widgets[page_ix] = []

            page_image = pages[page_ix].image
            for class_id, box in zip(detections.class_id, detections.xyxy):
                # Convert pixel corners to fractions of the page image size.
                x0, x1 = box[[0, 2]] / page_image.width
                y0, y1 = box[[1, 3]] / page_image.height

                widget_type = self.id_to_cls[int(class_id)]

                widgets[page_ix].append(
                    Widget(
                        widget_type=widget_type,
                        bounding_box=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1),
                        page=page_ix,
                    )
                )

            widgets[page_ix] = sort_widgets(widgets[page_ix])

        return widgets
|
|
|
|
class FFDNetDetector:
    """Form-field detector that runs FFDNet weights through ``YOLO``."""

    def __init__(
        self, model_or_path: str, device: int | str = "cpu", fast: bool = False
    ) -> None:
        """Resolve the weights and build the YOLO detection model.

        Args:
            model_or_path: ``"FFDNet-S"`` / ``"FFDNet-L"`` (any casing) or a
                path to weights.
            device: Device forwarded to ``predict`` on the non-fast path.
            fast: Select the registry's fast (ONNX) weights and the per-page
                prediction path.
        """
        self.device = device
        self.fast = fast

        weights = self.get_model_path(model_or_path, device, fast)
        self.model = YOLO(weights, task="detect")

        # Model class index -> widget type name.
        self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"}

    def get_model_path(
        self, model_or_path: str, device: int | str = "cpu", fast: bool = False
    ) -> str:
        """
        Work out which weights file to load:
            (a) a registry model name is downloaded from the hub, anything
                else is treated as a path and returned unchanged
            (b) for registry models, ``fast`` picks the ONNX entry and
                otherwise the ``.pt`` entry
        """
        registry_name = model_or_path.upper()
        if registry_name not in ["FFDNET-S", "FFDNET-L"]:
            return model_or_path

        repo_id, filename = models[(registry_name, fast)]
        return hf_hub_download(repo_id=repo_id, filename=filename)

    def extract_widgets(
        self, pages: list[Page], confidence: float = 0.3, image_size: int = 1600
    ) -> dict[int, list[Widget]]:
        """Detect form fields on every page.

        Args:
            pages: Rendered pages; each page's ``image`` is fed to the model.
            confidence: Passed to ``predict`` as ``conf``.
            image_size: Passed as ``imgsz`` on the non-fast path; the fast
                path uses a fixed ``imgsz`` of 1216.

        Returns:
            A dict mapping page index to its widgets ordered by
            ``sort_widgets``. Pages whose result (or its ``boxes``) is
            ``None`` get no entry.
        """
        if not self.fast:
            # One call over all page images.
            results = self.model.predict(
                [page.image for page in pages],
                iou=0.1,
                conf=confidence,
                augment=True,
                imgsz=image_size,
                device=self.device,
            )
        else:
            # One call per page with a fixed input size.
            results = []
            for page in pages:
                results.append(
                    self.model.predict(
                        page.image, iou=1, conf=confidence, augment=False, imgsz=1216
                    )
                )

        detected = {}
        for page_ix, result in enumerate(results):
            # Per-page calls return a list; unwrap its first entry.
            if isinstance(result, list):
                result = result[0]

            if result is None or result.boxes is None:
                continue

            page_widgets = []
            for box in result.boxes.cpu().numpy():
                cx, cy, box_w, box_h = box.xywhn[0]
                label = self.id_to_cls[int(box.cls.item())]
                page_widgets.append(
                    Widget(
                        widget_type=label,
                        bounding_box=BoundingBox.from_yolo(
                            cx=cx, cy=cy, w=box_w, h=box_h
                        ),
                        page=page_ix,
                    )
                )

            detected[page_ix] = sort_widgets(page_widgets)

        return detected
|
|
|
|
def sort_widgets(widgets: list[Widget]) -> list[Widget]:
    """
    Order widgets in approximate reading order: rows from top to bottom,
    and left to right inside each row. Widgets whose top edge (``y0``) lies
    within 0.01 of the first widget of a row are treated as the same row.
    """
    row_tolerance = 0.01

    def left_edge(item):
        return item.bounding_box.x0

    # Coarse pass: top-to-bottom on the rounded top edge, ties by left edge.
    by_position = sorted(
        widgets,
        key=lambda item: (round(item.bounding_box.y0, 3), left_edge(item)),
    )

    # Group into rows, always comparing against the row's first member so the
    # tolerance does not drift as the row grows.
    rows = []
    for item in by_position:
        same_row = bool(rows) and (
            abs(item.bounding_box.y0 - rows[-1][0].bounding_box.y0) < row_tolerance
        )
        if same_row:
            rows[-1].append(item)
        else:
            rows.append([item])

    # Left-to-right inside every row, then flatten.
    ordered = []
    for row in rows:
        ordered.extend(sorted(row, key=left_edge))
    return ordered
|
|
|
|
def render_pdf(pdf_path: str) -> list[Page]:
    """Render every page of the PDF at *pdf_path* at 144 dpi.

    Returns one ``Page`` per document page, carrying the rendered image and
    that image's width and height. The underlying document is closed before
    returning, including when rendering raises.
    """
    document = formalpdf.open(pdf_path)
    try:
        # Lazy generator: each page is rendered right before its Page is built.
        rendered = (page.render(dpi=144) for page in document)
        return [
            Page(image=img, width=img.width, height=img.height) for img in rendered
        ]
    finally:
        document.document.close()
|
|
|
|
def prepare_form(
    input_path: str | Path,
    output_path: str | Path,
    *,
    model_or_path: str = "FFDetr",
    keep_existing_fields: bool = False,
    use_signature_fields: bool = False,
    device: int | str = "cpu",
    image_size: int = 1024,
    confidence: float = 0.4,
    fast: bool = False,
    multiline: bool = False,
    batch_size: int = 4,
):
    """Detect form fields in a PDF and write a fillable copy.

    Args:
        input_path: PDF to analyse.
        output_path: Destination of the PDF with the added fields.
        model_or_path: Model name or weights path. Anything containing
            ``"FFDNET"`` (case-insensitive) uses ``FFDNetDetector``;
            everything else uses ``FFDetrDetector``.
        keep_existing_fields: When false, fields already present in the
            input are cleared before new ones are added.
        use_signature_fields: Emit signature fields for ``Signature``
            detections; otherwise they become text boxes.
        device: Device handed to the detector.
        image_size: Forwarded to the detector's ``extract_widgets``.
        confidence: Detection confidence threshold.
        fast: Forwarded to ``FFDNetDetector`` only.
        multiline: Create ``TextBox`` fields as multiline.
        batch_size: Pages per model call; forwarded to ``FFDetrDetector``
            only (``FFDNetDetector.extract_widgets`` has no such parameter).

    Raises:
        EncryptedPdfError: If rendering the input raises ``PdfiumError``.
    """
    extract_kwargs = {"confidence": confidence, "image_size": image_size}
    if "FFDNET" in model_or_path.upper():
        detector = FFDNetDetector(model_or_path, device=device, fast=fast)
    else:
        detector = FFDetrDetector(model_or_path, device=device)
        # Previously accepted but silently dropped.
        extract_kwargs["batch_size"] = batch_size

    try:
        pages = render_pdf(input_path)
    except pypdfium2._helpers.misc.PdfiumError as err:
        # Keep the pdfium error as the cause so the real failure stays visible.
        raise EncryptedPdfError from err

    results = detector.extract_widgets(pages, **extract_kwargs)

    writer = PyPdfFormCreator(input_path)
    try:
        if not keep_existing_fields:
            writer.clear_existing_fields()

        for page_ix, widgets in results.items():
            for i, widget in enumerate(widgets):
                name = f"{widget.widget_type.lower()}_{widget.page}_{i}"

                if widget.widget_type == "TextBox":
                    writer.add_text_box(
                        name, page_ix, widget.bounding_box, multiline=multiline
                    )
                elif widget.widget_type == "ChoiceButton":
                    writer.add_checkbox(name, page_ix, widget.bounding_box)
                elif widget.widget_type == "Signature":
                    if use_signature_fields:
                        writer.add_signature(name, page_ix, widget.bounding_box)
                    else:
                        # Signature fallback is a text box with default options.
                        writer.add_text_box(name, page_ix, widget.bounding_box)

        writer.save(output_path)
    finally:
        # Release the writer even when adding fields or saving fails.
        writer.close()
|
|