Tanishq Salkar
code updated
9de8051
from __future__ import annotations
from ultralytics import YOLO
from pathlib import Path
from huggingface_hub import hf_hub_download
from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge
from commonforms.utils import BoundingBox, Page, Widget
from commonforms.form_creator import PyPdfFormCreator
from commonforms.exceptions import EncryptedPdfError
import formalpdf
import pypdfium2
import logging
import PIL
logging.basicConfig(level=logging.INFO)
# Mapping from (model_name_upper, fast) to (repo_id, filename) on the Hugging Face Hub.
# Kept simple and declarative like this because we are not expecting to add many
# models.
# Registry of published weights, keyed by (upper-cased model name, fast flag).
# Each value is the (repo_id, filename) pair handed to hf_hub_download.
models = {
    # fast=True entries are ONNX exports; fast=False entries are PyTorch checkpoints.
    ("FFDNET-S", True): ("jbarrow/FFDNet-S-cpu", "FFDNet-S.onnx"),
    ("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"),
    ("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"),
    ("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"),
    # FFDETR only has a fast=False entry; there is no ONNX variant registered.
    ("FFDETR", False): ("jbarrow/FFDetr", "FFDetr.pth"),
}
def batch(lst: list, n: int = 8):
    """Yield consecutive slices of ``lst`` containing at most ``n`` items each.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    An empty ``lst`` yields nothing.
    """
    # Slicing past the end of a sequence is clipped automatically, so no
    # explicit upper-bound clamp is required.
    for start in range(0, len(lst), n):
        yield lst[start : start + n]
class FFDetrDetector:
    """Form-field detector backed by an ``RFDETRMedium`` model.

    Detections are converted into ``Widget`` objects whose bounding boxes are
    expressed as fractions of the page image width and height.
    """

    def __init__(self, model_or_path: str, device: int | str = "cpu") -> None:
        """Load the detection model.

        Args:
            model_or_path: ``"FFDetr"`` (case-insensitive) to use the weights
                registered in ``models``, or a path to a local checkpoint.
            device: Stored on the instance.
                NOTE(review): the value is not forwarded to the model here;
                confirm whether the model constructor should receive it.
        """
        self.device = device
        self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path))
        # Maps the model's integer class ids to widget type names.
        self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"}

    def get_model_path(self, model_or_path: str) -> str:
        """Resolve ``model_or_path`` to a weights file path.

        Args:
            model_or_path: A registered model name or a filesystem path.

        Returns:
            The downloaded (or cached) weights path when the name is
            ``"FFDETR"`` in any casing; otherwise ``model_or_path`` unchanged.
        """
        model_upper = model_or_path.upper()
        if model_upper == "FFDETR":
            # download the model, will just use the cached version if it already exists
            repo_id, filename = models[(model_upper, False)]
            return hf_hub_download(repo_id=repo_id, filename=filename)
        return model_or_path

    def resize(
        self,
        image: PIL.Image.Image,
        size: tuple[int, int] | int,
    ) -> PIL.Image.Image:
        """Return ``image`` resized with LANCZOS resampling.

        Args:
            image: The image to resize.
            size: Target ``(width, height)``; a single int produces a square.
        """
        if isinstance(size, int):
            size = (size, size)
        return image.resize(size, PIL.Image.Resampling.LANCZOS)

    def extract_widgets(
        self,
        pages: list[Page],
        confidence: float = 0.4,
        image_size: int = 1120,
        batch_size: int = 3,
    ) -> dict[int, list[Widget]]:
        """Detect form fields on every page.

        Args:
            pages: Rendered pages; each page's ``image`` is fed to the model.
            confidence: Detection threshold passed to ``model.predict``.
            image_size: Accepted so callers can use the same keyword set as
                ``FFDNetDetector.extract_widgets``; it is not used by this
                detector.
            batch_size: Number of page images sent to the model per call.

        Returns:
            A dict mapping each page index to its widgets, sorted by
            ``sort_widgets``. Every page gets an entry, possibly empty.
        """
        results = []
        for image_batch in batch([p.image for p in pages], n=batch_size):
            predictions = self.model.predict(image_batch, threshold=confidence)
            # predict() is handled both as returning a list of per-image
            # results and as returning one bare result object
            # (presumably for a single-image batch -- TODO confirm).
            if isinstance(predictions, list):
                results.extend(predictions)
            else:
                results.append(predictions)

        widgets = {}
        for page_ix, detections in enumerate(results):
            logging.info(f" Page {page_ix}: {len(detections)} fields detected")
            # Class-agnostic NMS: overlapping boxes suppress each other even
            # when they were predicted as different widget types.
            detections = detections.with_nms(threshold=0.1, class_agnostic=True)
            logging.info(f"\t\t{len(detections)} after nms")

            page_image = pages[page_ix].image
            page_width = page_image.width
            page_height = page_image.height

            page_widgets = []
            for class_id, box in zip(detections.class_id, detections.xyxy):
                # Convert pixel xyxy coordinates into page-relative fractions.
                x0, x1 = box[[0, 2]] / page_width
                y0, y1 = box[[1, 3]] / page_height
                page_widgets.append(
                    Widget(
                        widget_type=self.id_to_cls[class_id],
                        bounding_box=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1),
                        page=page_ix,
                    )
                )
            widgets[page_ix] = sort_widgets(page_widgets)
        return widgets
class FFDNetDetector:
    """Form-field detector backed by an Ultralytics ``YOLO`` model."""

    def __init__(
        self, model_or_path: str, device: int | str = "cpu", fast: bool = False
    ) -> None:
        """Resolve the weights for ``model_or_path`` and load them for detection."""
        self.device = device
        self.fast = fast
        weights = self.get_model_path(model_or_path, device, fast)
        self.model = YOLO(weights, task="detect")
        # Maps the model's integer class ids to widget type names.
        self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"}

    def get_model_path(
        self, model_or_path: str, device: int | str = "cpu", fast: bool = False
    ) -> str:
        """
        Resolve the weights file for the requested model.

        Registered names ("FFDNET-S" / "FFDNET-L", any casing) are fetched from
        the hub: the ONNX file when ``fast`` is set, the .pt file otherwise.
        Anything else is treated as a path and returned unchanged.
        """
        registry_name = model_or_path.upper()
        if registry_name not in ("FFDNET-S", "FFDNET-L"):
            return model_or_path

        # download the model, will just use the cached version if it already exists
        repo_id, filename = models[(registry_name, fast)]
        return hf_hub_download(repo_id=repo_id, filename=filename)

    def extract_widgets(
        self, pages: list[Page], confidence: float = 0.3, image_size: int = 1600
    ) -> dict[int, list[Widget]]:
        """
        Detect form fields on each page.

        Returns a dict from page index to that page's widgets, ordered by
        ``sort_widgets``. Pages whose result (or its boxes) is ``None`` get no
        entry at all.
        """
        if not self.fast:
            results = self.model.predict(
                [p.image for p in pages],
                iou=0.1,
                conf=confidence,
                augment=True,
                imgsz=image_size,
                device=self.device,
            )
        else:
            # the image size is pinned to 1216 here, since that's all ONNX
            # supports; pages are predicted one at a time
            results = []
            for p in pages:
                results.append(
                    self.model.predict(
                        p.image, iou=1, conf=confidence, augment=False, imgsz=1216
                    )
                )

        widgets = {}
        for page_ix, outcome in enumerate(results):
            # per-page predict calls hand back a list; unwrap its first item
            if isinstance(outcome, list):
                outcome = outcome[0]

            # no predictions, skip page
            if outcome is None or outcome.boxes is None:
                continue

            found = []
            for box in outcome.boxes.cpu().numpy():
                cx, cy, width, height = box.xywhn[0]
                label = self.id_to_cls[int(box.cls.item())]
                found.append(
                    Widget(
                        widget_type=label,
                        bounding_box=BoundingBox.from_yolo(
                            cx=cx, cy=cy, w=width, h=height
                        ),
                        page=page_ix,
                    )
                )

            # Reading order matters so that Tab/Shift-Tab moves through the
            # page's fields in a sensible sequence.
            widgets[page_ix] = sort_widgets(found)
        return widgets
def sort_widgets(widgets: list[Widget]) -> list[Widget]:
    """
    Return the widgets in approximate reading order: rows from top to
    bottom, and left to right within each row.

    A widget joins the current row when its top edge lies within 0.01 of the
    top edge of the widget that opened that row.
    """

    def top_edge(w):
        return w.bounding_box.y0

    def left_edge(w):
        return w.bounding_box.x0

    # Coarse ordering first: top edge (rounded so near-identical heights tie),
    # then left edge.
    ordered = list(widgets)
    ordered.sort(key=lambda w: (round(top_edge(w), 3), left_edge(w)))

    same_row_tolerance = 0.01
    rows = []
    for item in ordered:
        # Compare against the widget that opened the row, not its neighbour,
        # so a row cannot drift downward one small step at a time.
        if rows and abs(top_edge(item) - top_edge(rows[-1][0])) < same_row_tolerance:
            rows[-1].append(item)
        else:
            rows.append([item])

    # Order each finished row left to right and flatten.
    flattened = []
    for row in rows:
        flattened.extend(sorted(row, key=left_edge))
    return flattened
def render_pdf(pdf_path: str) -> list[Page]:
    """
    Render every page of the PDF at ``pdf_path`` to an image at 144 DPI.

    Each returned ``Page`` carries the rendered image together with that
    image's width and height. The underlying document is closed before
    returning, including when rendering raises.
    """
    pdf = formalpdf.open(pdf_path)
    try:
        # Lazy generator: each page is rendered right before its Page is built.
        rendered_images = (pdf_page.render(dpi=144) for pdf_page in pdf)
        return [
            Page(image=img, width=img.width, height=img.height)
            for img in rendered_images
        ]
    finally:
        pdf.document.close()
def prepare_form(
    input_path: str | Path,
    output_path: str | Path,
    *,
    model_or_path: str = "FFDetr",
    keep_existing_fields: bool = False,
    use_signature_fields: bool = False,
    device: int | str = "cpu",
    image_size: int = 1024,
    confidence: float = 0.4,
    fast: bool = False,
    multiline: bool = False,
    batch_size: int = 4,
) -> None:
    """Detect form fields in a PDF and write a copy with fillable widgets.

    Args:
        input_path: PDF to analyse.
        output_path: Where the PDF with the added fields is saved.
        model_or_path: Model name or weights path. Names containing
            ``"FFDNET"`` (case-insensitive) select ``FFDNetDetector``;
            anything else selects ``FFDetrDetector``.
        keep_existing_fields: When false, fields already present in the input
            are cleared before new ones are added.
        use_signature_fields: When true, detected signatures become signature
            fields; otherwise they are added as text boxes.
        device: Device passed to the detector constructor.
        image_size: Image size passed to the detector's ``extract_widgets``.
        confidence: Detection confidence threshold.
        fast: Passed to ``FFDNetDetector``; not used for ``FFDetrDetector``.
        multiline: Whether detected text boxes are created as multiline.
        batch_size: Pages per inference call; only forwarded to
            ``FFDetrDetector``, whose ``extract_widgets`` accepts it.

    Raises:
        EncryptedPdfError: If rendering the input raises a pdfium error.
    """
    extract_kwargs = {"confidence": confidence, "image_size": image_size}
    if "FFDNET" in model_or_path.upper():
        detector = FFDNetDetector(model_or_path, device=device, fast=fast)
    else:
        detector = FFDetrDetector(model_or_path, device=device)
        # Previously batch_size was accepted here but silently dropped.
        extract_kwargs["batch_size"] = batch_size

    try:
        pages = render_pdf(input_path)
    except pypdfium2._helpers.misc.PdfiumError as err:
        # Keep the pdfium failure attached as the explicit cause.
        raise EncryptedPdfError from err

    results = detector.extract_widgets(pages, **extract_kwargs)

    writer = PyPdfFormCreator(input_path)
    try:
        if not keep_existing_fields:
            writer.clear_existing_fields()

        for page_ix, widgets in results.items():
            for i, widget in enumerate(widgets):
                name = f"{widget.widget_type.lower()}_{widget.page}_{i}"
                if widget.widget_type == "TextBox":
                    writer.add_text_box(
                        name, page_ix, widget.bounding_box, multiline=multiline
                    )
                elif widget.widget_type == "ChoiceButton":
                    writer.add_checkbox(name, page_ix, widget.bounding_box)
                elif widget.widget_type == "Signature":
                    if use_signature_fields:
                        writer.add_signature(name, page_ix, widget.bounding_box)
                    else:
                        writer.add_text_box(name, page_ix, widget.bounding_box)

        writer.save(output_path)
    finally:
        # Release the writer even when adding fields or saving fails.
        writer.close()