fresh-catch-parser / scripts /parse_vendor_document.py
stubdude's picture
Add document parser Docker service
fbba60e
#!/usr/bin/env python3
"""
Parse vendor invoices (LayoutLMv3 FUNSD) or retail receipts (Donut CORD v2).
Usage:
python3 scripts/parse_vendor_document.py --image /path/to.png [--type invoice|receipt|auto]
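Prints a single JSON object to stdout: the ParsedVendorInvoice fields plus a
"documentType" key ("invoice" or "receipt"). Errors are reported on stderr
and the process exits with status 1.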
"""
from __future__ import annotations
import argparse
import json
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
# Hugging Face checkpoint identifiers passed to `from_pretrained`.
RECEIPT_MODEL = "naver-clova-ix/donut-base-finetuned-cord-v2"
INVOICE_MODEL = "nielsr/layoutlmv3-finetuned-funsd"

# Lower-case substrings whose presence in the OCR text votes for "invoice".
INVOICE_HINTS = (
    "invoice", "inv #", "inv no",
    "bill to", "ship to", "purchase order", "po #",
    "remit to", "net 30",
    "del weight", "unit price",
    "vendor", "food service",
)

# Lower-case substrings whose presence in the OCR text votes for "receipt".
RECEIPT_HINTS = (
    "receipt", "thank you",
    "subtotal", "sub total", "change due",
    "cashier", "register",
    "visa", "mastercard", "debit",
    "loyalty", "store #",
)
@dataclass
class OcrWord:
    """A single OCR-detected word and its bounding box (origin + size)."""

    text: str
    left: int
    top: int
    width: int
    height: int

    @property
    def box(self) -> list[int]:
        """Return the box as ``[x0, y0, x1, y1]`` corner coordinates."""
        right = self.left + self.width
        bottom = self.top + self.height
        return [self.left, self.top, right, bottom]
def eprint(*args: object) -> None:
    """Print ``args`` to stderr, keeping stdout free for the JSON payload."""
    destination = sys.stderr
    print(*args, file=destination)
def load_image(path: Path):
    """Open the image at ``path`` and return an RGB copy of it.

    The source file is opened in a ``with`` block so its handle is released
    as soon as the RGB copy has been made, instead of relying on Pillow's
    lazy loading / garbage collection to close it.
    """
    from PIL import Image

    with Image.open(path) as source:
        # convert() produces a new image, so the result stays usable after
        # the source file is closed.
        return source.convert("RGB")
def ocr_words(image) -> list[OcrWord]:
    """Run Tesseract on ``image`` and return the usable words with boxes.

    Blank entries are skipped. Entries whose confidence is known (>= 0) but
    below 35 are dropped; entries with no confidence ("-1" or "") are kept.
    """
    import pytesseract

    data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    kept: list[OcrWord] = []
    for index, raw_text in enumerate(data["text"]):
        text = (raw_text or "").strip()
        if not text:
            continue

        raw_conf = data["conf"][index]
        if raw_conf in ("-1", ""):
            conf = -1
        else:
            conf = int(float(raw_conf))
        if 0 <= conf < 35:
            continue

        kept.append(
            OcrWord(
                text=text,
                left=int(data["left"][index]),
                top=int(data["top"][index]),
                width=int(data["width"][index]),
                height=int(data["height"][index]),
            )
        )
    return kept
def normalize_boxes(words: list[OcrWord], width: int, height: int) -> list[list[int]]:
    """Rescale each word's pixel box to the 0-1000 integer range.

    x coordinates are scaled by ``width`` and y coordinates by ``height``;
    every result is clamped to ``[0, 1000]``.
    """

    def scale(coord: int, extent: int) -> int:
        scaled = int(1000 * coord / extent)
        return min(1000, max(0, scaled))

    normalized: list[list[int]] = []
    for word in words:
        x0, y0, x1, y1 = word.box
        normalized.append(
            [scale(x0, width), scale(y0, height), scale(x1, width), scale(y1, height)]
        )
    return normalized
def classify_document_type(words: list[OcrWord], forced: str | None) -> str:
    """Decide whether the OCR text looks like an "invoice" or a "receipt".

    A ``forced`` value of "invoice" or "receipt" is returned unchanged.
    Otherwise each hint found in the lower-cased text scores one point, the
    words "invoice"/"inv " and "receipt" add two bonus points to their side,
    and the receipt wins only with a strictly higher score.
    """
    if forced == "invoice" or forced == "receipt":
        return forced

    haystack = " ".join(word.text for word in words).lower()
    invoice_score = len([hint for hint in INVOICE_HINTS if hint in haystack])
    receipt_score = len([hint for hint in RECEIPT_HINTS if hint in haystack])

    if "invoice" in haystack or "inv " in haystack:
        invoice_score += 2
    if "receipt" in haystack:
        receipt_score += 2

    # Ties (and every invoice lead) resolve to "invoice".
    return "receipt" if receipt_score > invoice_score else "invoice"
def parse_loose_number(value: Any) -> float | None:
    """Best-effort conversion of a number-like value to ``float``.

    Numbers are returned as floats. Strings are stripped of everything except
    digits, ".", "," and "-" and then interpreted as follows:

    * both separators present: the right-most one is the decimal mark and the
      other is a thousands separator ("1,234.56" and "1.234,56" -> 1234.56);
    * one separator repeated in 3-digit groups ("1,234,567"): thousands
      separators;
    * otherwise a comma is treated as the decimal mark ("12,50" -> 12.5).

    Returns ``None`` for non-numeric types, empty strings and anything that
    still cannot be parsed.
    """
    if isinstance(value, (int, float)):
        return float(value)
    if not isinstance(value, str):
        return None

    cleaned = re.sub(r"[^0-9.,-]", "", value)
    if not cleaned:
        return None

    last_comma = cleaned.rfind(",")
    last_dot = cleaned.rfind(".")
    if last_comma != -1 and last_dot != -1:
        if last_dot > last_comma:
            # "1,234.56": commas only group thousands.
            cleaned = cleaned.replace(",", "")
        else:
            # "1.234,56": dots group thousands, the comma is the decimal mark.
            cleaned = cleaned.replace(".", "").replace(",", ".")
    elif re.fullmatch(r"-?\d{1,3}(?:,\d{3}){2,}", cleaned):
        cleaned = cleaned.replace(",", "")
    elif re.fullmatch(r"-?\d{1,3}(?:\.\d{3}){2,}", cleaned):
        cleaned = cleaned.replace(".", "")
    else:
        # A lone comma is ambiguous; keep treating it as a decimal mark.
        cleaned = cleaned.replace(",", ".")

    try:
        return float(cleaned)
    except ValueError:
        return None
def normalize_date(value: str | None) -> str | None:
    """Normalize a date string to ``YYYY-MM-DD`` when it can be recognized.

    * Empty or whitespace-only input returns ``None``.
    * Input already shaped like ``YYYY-MM-DD`` is returned as is.
    * ``M/D/YY`` and ``M/D/YYYY`` are converted; two-digit years are taken as
      20YY. When the first field cannot be a month (> 12) but the second can,
      the fields are read as day/month instead.
    * Anything else, including slash dates that are not real calendar dates,
      is returned stripped but otherwise unchanged, so an impossible date is
      never presented as ISO.
    """
    from datetime import date

    if not value:
        return None
    value = value.strip()
    if not value:
        return None
    if re.match(r"^\d{4}-\d{2}-\d{2}$", value):
        return value

    match = re.match(r"^(\d{1,2})/(\d{1,2})/(\d{2,4})$", value)
    if not match:
        return value

    first, second, year_text = match.groups()
    if len(year_text) == 3:
        # A three-digit year has no sensible interpretation; leave it alone.
        return value
    year = int(year_text) + (2000 if len(year_text) == 2 else 0)
    month, day = int(first), int(second)
    if month > 12 >= day:
        # "25/12/2023" can only be day-first.
        month, day = day, month

    try:
        return date(year, month, day).isoformat()
    except ValueError:
        return value
def map_cord_json(cord: dict[str, Any]) -> dict[str, Any]:
    """Map Donut CORD-v2 output onto the invoice-shaped result dictionary.

    ``cord["menu"]`` may be a single dict or a list of dicts; each entry with
    a non-blank name becomes one line item. Totals are read from the
    ``sub_total`` / ``total`` blocks (dicts or bare values). Confidence is
    "medium" when at least one line item was found, otherwise "low", and
    ``rawNotes`` holds the first 4000 characters of the JSON-encoded input.

    Changes from the previous version:

    * Tax falls back to ``sub_total["tax_price"]``, which is where the CORD
      schema stores it, instead of only looking at top-level keys.
    * ``itemsubtotal`` is a price, so it is now a ``lineTotal`` fallback
      rather than a ``unit`` fallback.
    """

    def price_field(block: Any, *keys: str) -> float | None:
        # For a dict the first key that is present wins; bare values are
        # parsed directly.
        if isinstance(block, dict):
            for key in keys:
                if key in block:
                    return parse_loose_number(block[key])
        return parse_loose_number(block)

    menu = cord.get("menu")
    if isinstance(menu, list):
        menus = menu
    elif isinstance(menu, dict):
        menus = [menu]
    else:
        menus = []

    line_items: list[dict[str, Any]] = []
    for entry in menus:
        if not isinstance(entry, dict):
            continue
        description = (
            entry.get("nm")
            or entry.get("item")
            or entry.get("name")
            or entry.get("menu.nm")
        )
        if not description or not str(description).strip():
            continue
        line_items.append(
            {
                "description": str(description).strip(),
                "vendorItemNumber": None,
                "quantity": parse_loose_number(entry.get("cnt") or entry.get("num")),
                "unit": str(entry.get("unit") or "").strip() or None,
                "unitPrice": parse_loose_number(
                    entry.get("unitprice") or entry.get("price") or entry.get("itemprice")
                ),
                "lineTotal": parse_loose_number(
                    entry.get("price")
                    or entry.get("cntprice")
                    or entry.get("itemsubtotal")
                    or entry.get("itemprice")
                ),
            }
        )

    sub_total = cord.get("sub_total") or cord.get("subtotal")
    tax = cord.get("tax") or cord.get("tax_price")
    if not tax and isinstance(sub_total, dict):
        tax = sub_total.get("tax_price")
    total = cord.get("total") or cord.get("total_price") or cord.get("total_etc")

    return {
        "vendorName": str(cord.get("store") or cord.get("company") or cord.get("brand") or "").strip()
        or None,
        "invoiceNumber": str(cord.get("receipt_no") or cord.get("order_no") or "").strip() or None,
        "invoiceDate": normalize_date(
            str(cord.get("date") or cord.get("receipt_date") or "").strip() or None
        ),
        "subtotal": price_field(sub_total, "price", "subtotal_price", "sub_total_price"),
        "tax": price_field(tax, "price", "tax_price"),
        "total": price_field(total, "total_price", "price", "total"),
        "currency": None,
        "confidence": "medium" if line_items else "low",
        "rawNotes": json.dumps(cord)[:4000] if cord else None,
        "lineItems": line_items,
    }
def parse_receipt(image) -> dict[str, Any]:
    """Parse a receipt image with the Donut checkpoint in ``RECEIPT_MODEL``.

    Loads the processor and model, generates a token sequence from the
    "<s_cord-v2>" task prompt, strips the EOS/PAD tokens and the first tag,
    converts the remainder to JSON with ``processor.token2json`` and maps it
    through ``map_cord_json``.
    """
    import torch
    from transformers import DonutProcessor, VisionEncoderDecoderModel

    processor = DonutProcessor.from_pretrained(RECEIPT_MODEL)
    model = VisionEncoderDecoderModel.from_pretrained(RECEIPT_MODEL)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    tokenizer = processor.tokenizer
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    prompt_ids = tokenizer(
        "<s_cord-v2>", add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    generation_options = {
        "decoder_input_ids": prompt_ids,
        "max_length": model.decoder.config.max_position_embeddings,
        "early_stopping": True,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "use_cache": True,
        "num_beams": 1,
        "bad_words_ids": [[tokenizer.unk_token_id]],
        "return_dict_in_generate": True,
    }
    generated = model.generate(pixel_values, **generation_options)

    decoded = processor.batch_decode(generated.sequences)[0]
    for special_token in (tokenizer.eos_token, tokenizer.pad_token):
        decoded = decoded.replace(special_token, "")
    # Drop the first tag (the task-start token) before converting to JSON.
    decoded = re.sub(r"<.*?>", "", decoded.strip(), count=1).strip()

    return map_cord_json(processor.token2json(decoded))
def align_word_labels(word_texts: list[str], word_ids: list[int | None], predictions: list[int], id2label: dict) -> list[str]:
    """Collapse per-token predictions into one label per word.

    ``word_ids[i]`` is the index of the word that token ``i`` belongs to, or
    ``None`` for special/padding tokens. The label of a word is taken from
    its FIRST sub-token. Previously every later sub-token overwrote it, so
    the last piece of a split word decided the label. Words that have no
    token (for example because of truncation) stay "O".

    ``id2label`` is looked up by the integer id first and by its string form
    second; unknown ids map to "O".
    """
    labels = ["O"] * len(word_texts)
    assigned: set[int] = set()
    for word_id, pred in zip(word_ids, predictions):
        if word_id is None or word_id in assigned:
            continue
        if not 0 <= word_id < len(labels):
            # Defensive: ignore indices that do not refer to a known word.
            continue
        assigned.add(word_id)
        labels[word_id] = id2label.get(pred, id2label.get(str(pred), "O"))
    return labels
def group_entities(words: list[str], labels: list[str]) -> list[tuple[str, str]]:
    """Merge consecutive words with the same entity label into spans.

    Labels may carry a "B-"/"I-" prefix. A new span starts on a "B-" label or
    when the base label changes; an "O" label closes the open span. Returns
    ``(base_label, joined_text)`` tuples in reading order.
    """
    spans: list[tuple[str, str]] = []
    active: str | None = None
    buffer: list[str] = []

    for token, tag in zip(words, labels):
        marker = tag[:2]
        if tag == "O":
            kind = None
        elif marker in ("B-", "I-"):
            kind = tag[2:]
        else:
            kind = tag

        # Same entity continuing: just extend the open span.
        if tag != "O" and marker != "B-" and active == kind:
            buffer.append(token)
            continue

        # Otherwise close whatever is open...
        if buffer and active:
            spans.append((active, " ".join(buffer).strip()))
        # ...and either start a fresh span or reset on "O".
        if tag == "O":
            active, buffer = None, []
        else:
            active, buffer = kind, [token]

    if buffer and active:
        spans.append((active, " ".join(buffer).strip()))
    return spans
def extract_qa_pairs(groups: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Pair each answer span with the most recent unanswered question.

    A later question replaces an earlier unanswered one, answers with no
    open question are dropped, and header spans are emitted as
    ``("HEADER", text)``.
    """
    result: list[tuple[str, str]] = []
    open_question: str | None = None

    for kind, content in groups:
        if kind.endswith("QUESTION"):
            open_question = content
            continue
        if kind.endswith("ANSWER") and open_question:
            result.append((open_question, content))
            open_question = None
            continue
        if kind.endswith("HEADER"):
            result.append(("HEADER", content))

    return result
def extract_line_items_from_ocr(words: list[OcrWord]) -> list[dict[str, Any]]:
    """Heuristically rebuild table rows from OCR words and read line items.

    Words are bucketed into rows by rounding ``top`` to the nearest multiple
    of 12, then ordered left to right. Rows that are too short, mention a
    summary/header keyword, or contain fewer than two parseable numbers are
    ignored. The last number is the line total and the one before it the
    quantity; the description is the text before the first number that
    follows whitespace. At most 40 items are returned.
    """
    if not words:
        return []

    skip_markers = (
        "subtotal",
        "sub total",
        "total",
        "tax",
        "balance",
        "thank you",
        "page ",
        "invoice",
        "bill to",
        "ship to",
    )

    rows: dict[int, list[OcrWord]] = {}
    for word in words:
        rows.setdefault(round(word.top / 12) * 12, []).append(word)

    line_items: list[dict[str, Any]] = []
    for bucket in sorted(rows):
        ordered = sorted(rows[bucket], key=lambda w: w.left)
        text = " ".join(w.text for w in ordered)
        if len(text) < 4:
            continue

        lowered = text.lower()
        if any(marker in lowered for marker in skip_markers):
            continue

        numbers: list[float] = []
        for hit in re.finditer(r"\d[\d,]*\.?\d*", text):
            parsed = parse_loose_number(hit.group())
            if parsed is not None:
                numbers.append(parsed)
        if len(numbers) < 2:
            continue
        quantity, line_total = numbers[-2], numbers[-1]

        description = re.sub(r"\s+\d[\d,]*\.?\d*.*$", "", text).strip()
        if len(description) < 3:
            continue

        if quantity and quantity > 0:
            unit_price = round(line_total / quantity, 4)
        else:
            unit_price = None

        line_items.append(
            {
                "description": description,
                "vendorItemNumber": None,
                "quantity": quantity,
                "unit": None,
                "unitPrice": unit_price,
                "lineTotal": line_total,
            }
        )

    return line_items[:40]
def parse_invoice(image, words: list[OcrWord]) -> dict[str, Any]:
    """Parse an invoice image with the LayoutLMv3 checkpoint in ``INVOICE_MODEL``.

    The OCR ``words`` and their normalized boxes are fed to the token
    classifier; predicted labels are grouped into question/answer/header
    spans, and header fields are picked out by keyword matching on the
    question text. Line items come from ``extract_line_items_from_ocr``.

    With no OCR words an all-empty, low-confidence result is returned
    without importing or loading the model.

    Fixes over the previous version:

    * ``word_ids`` is read from the processor output BEFORE it is turned into
      a plain dict of device tensors; calling ``.word_ids()`` on that dict
      raised ``AttributeError`` for every non-empty document.
    * "Invoice Number" / "Invoice #" are recognized as invoice-number
      questions (previously only text containing "no" matched).
    """
    if not words:
        return {
            "vendorName": None,
            "invoiceNumber": None,
            "invoiceDate": None,
            "subtotal": None,
            "tax": None,
            "total": None,
            "currency": None,
            "confidence": "low",
            "rawNotes": None,
            "lineItems": [],
        }

    # Imported lazily so the empty-input path above stays cheap.
    import torch
    from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor

    processor = LayoutLMv3Processor.from_pretrained(INVOICE_MODEL, apply_ocr=False)
    model = LayoutLMv3ForTokenClassification.from_pretrained(INVOICE_MODEL)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    width, height = image.size
    word_texts = [word.text for word in words]
    boxes = normalize_boxes(words, width, height)
    encoding = processor(
        image,
        word_texts,
        boxes=boxes,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # word_ids() lives on the BatchEncoding object, not on the plain dict
    # built below, so capture it first.
    word_ids = encoding.word_ids(batch_index=0)
    model_inputs = {key: value.to(device) for key, value in encoding.items()}

    with torch.no_grad():
        outputs = model(**model_inputs)
    # Index the single batch element so the result is always a list.
    predictions = outputs.logits.argmax(-1)[0].tolist()

    labels = align_word_labels(word_texts, word_ids, predictions, model.config.id2label)
    groups = group_entities(word_texts, labels)
    qa_pairs = extract_qa_pairs(groups)

    vendor_name = None
    invoice_number = None
    invoice_date = None
    total = None
    tax = None
    subtotal = None
    for question, answer in qa_pairs:
        q = question.lower()
        if question == "HEADER" and not vendor_name:
            # The first header span is used as the vendor name.
            vendor_name = answer
            continue
        if any(token in q for token in ("invoice", "inv", "bill")) and "date" in q:
            invoice_date = normalize_date(answer)
        elif any(token in q for token in ("invoice", "inv")) and any(
            marker in q for marker in ("no", "num", "#")
        ):
            invoice_number = answer
        elif "date" in q:
            invoice_date = normalize_date(answer)
        elif "total" in q and "sub" not in q:
            total = parse_loose_number(answer)
        elif "tax" in q:
            tax = parse_loose_number(answer)
        elif "subtotal" in q or "sub total" in q:
            subtotal = parse_loose_number(answer)
        elif any(token in q for token in ("vendor", "supplier", "seller", "remit", "from")):
            vendor_name = answer

    line_items = extract_line_items_from_ocr(words)
    if line_items and (invoice_number or vendor_name):
        confidence = "high"
    elif line_items:
        confidence = "medium"
    else:
        confidence = "low"

    return {
        "vendorName": vendor_name,
        "invoiceNumber": invoice_number,
        "invoiceDate": invoice_date,
        "subtotal": subtotal,
        "tax": tax,
        "total": total,
        "currency": None,
        "confidence": confidence,
        "rawNotes": None,
        "lineItems": line_items,
    }
def main() -> int:
    """CLI entry point: parse one image and print the result as JSON.

    Returns 0 on success and 1 when the image is missing or any step of the
    pipeline raises (the error is reported on stderr).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--image", required=True, help="Path to a PNG/JPG/WebP image")
    cli.add_argument(
        "--type",
        default="auto",
        choices=("auto", "invoice", "receipt"),
        help="Document type routing",
    )
    options = cli.parse_args()

    image_path = Path(options.image)
    if not image_path.exists():
        eprint(f"Image not found: {image_path}")
        return 1

    # "auto" means: let the classifier decide.
    forced_type = None if options.type == "auto" else options.type

    try:
        image = load_image(image_path)
        words = ocr_words(image)
        doc_type = classify_document_type(words, forced_type)
        if doc_type == "receipt":
            result = parse_receipt(image)
        else:
            result = parse_invoice(image, words)
        print(json.dumps({"documentType": doc_type, **result}))
    except Exception as error:  # noqa: BLE001
        # Top-level boundary: report on stderr and signal failure via exit code.
        eprint(f"Document parse failed: {error}")
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())