#!/usr/bin/env python3
"""
Parse vendor invoices (LayoutLMv3 FUNSD) or retail receipts (Donut CORD v2).

Usage:
  python3 scripts/parse_vendor_document.py --image /path/to.png [--type invoice|receipt|auto]

Prints a single JSON object to stdout matching ParsedVendorInvoice.
"""

from __future__ import annotations

import argparse
import json
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any

RECEIPT_MODEL = "naver-clova-ix/donut-base-finetuned-cord-v2"
INVOICE_MODEL = "nielsr/layoutlmv3-finetuned-funsd"

INVOICE_HINTS = (
    "invoice",
    "inv #",
    "inv no",
    "bill to",
    "ship to",
    "purchase order",
    "po #",
    "remit to",
    "net 30",
    "del weight",
    "unit price",
    "vendor",
    "food service",
)

RECEIPT_HINTS = (
    "receipt",
    "thank you",
    "subtotal",
    "sub total",
    "change due",
    "cashier",
    "register",
    "visa",
    "mastercard",
    "debit",
    "loyalty",
    "store #",
)

# OCR words whose Tesseract confidence is known (>= 0) but below this are dropped.
MIN_OCR_CONFIDENCE = 35

# Vertical bucket size (pixels) used to group OCR words into text rows.
ROW_BUCKET_PX = 12

# Upper bound on the number of heuristic line items returned for an invoice.
MAX_LINE_ITEMS = 40

# Rows containing any of these substrings are treated as totals/headers, not items.
LINE_ITEM_SKIP_TERMS = (
    "subtotal",
    "sub total",
    "total",
    "tax",
    "balance",
    "thank you",
    "page ",
    "invoice",
    "bill to",
    "ship to",
)

_NUMBER_RE = re.compile(r"\d[\d,]*\.?\d*")
_TRAILING_NUMBERS_RE = re.compile(r"\s+\d[\d,]*\.?\d*.*$")


@dataclass
class OcrWord:
    """A single OCR token with its pixel-space bounding rectangle."""

    text: str
    left: int
    top: int
    width: int
    height: int

    @property
    def box(self) -> list[int]:
        """Return the rectangle as ``[x0, y0, x1, y1]`` in pixels."""
        return [self.left, self.top, self.left + self.width, self.top + self.height]


def eprint(*args: object) -> None:
    """Print to stderr so stdout stays reserved for the JSON payload."""
    print(*args, file=sys.stderr)


def load_image(path: Path):
    """Open the image at *path* and return an RGB copy of it."""
    from PIL import Image

    # ``convert`` returns a new image, so the file handle behind ``source``
    # can be closed as soon as the conversion is done.
    with Image.open(path) as source:
        return source.convert("RGB")


def _parse_confidence(raw: Any) -> int:
    """Return the OCR confidence as an int, or -1 when it is missing/unparseable."""
    try:
        return int(float(raw))
    except (TypeError, ValueError):
        return -1


def ocr_words(image) -> list[OcrWord]:
    """Run Tesseract on *image* and return the non-empty, sufficiently confident words.

    Words with an unknown confidence (-1) are kept; words with a known
    confidence below ``MIN_OCR_CONFIDENCE`` are discarded.
    """
    import pytesseract

    data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    words: list[OcrWord] = []
    columns = zip(
        data["text"],
        data["conf"],
        data["left"],
        data["top"],
        data["width"],
        data["height"],
    )
    for raw_text, raw_conf, left, top, width, height in columns:
        text = (raw_text or "").strip()
        if not text:
            continue
        conf = _parse_confidence(raw_conf)
        if 0 <= conf < MIN_OCR_CONFIDENCE:
            continue
        words.append(
            OcrWord(
                text=text,
                left=int(left),
                top=int(top),
                width=int(width),
                height=int(height),
            )
        )
    return words


def normalize_boxes(words: list[OcrWord], width: int, height: int) -> list[list[int]]:
    """Scale each word box from pixels to integers clamped to the range 0..1000.

    *width* and *height* are the image dimensions the pixel boxes refer to.
    """

    def scale(value: int, extent: int) -> int:
        return min(1000, max(0, int(1000 * value / extent)))

    boxes: list[list[int]] = []
    for word in words:
        x0, y0, x1, y1 = word.box
        boxes.append(
            [
                scale(x0, width),
                scale(y0, height),
                scale(x1, width),
                scale(y1, height),
            ]
        )
    return boxes


def classify_document_type(words: list[OcrWord], forced: str | None) -> str:
    """Return ``"invoice"`` or ``"receipt"`` for the OCR'd document.

    A *forced* value of ``"invoice"`` or ``"receipt"`` is returned as-is.
    Otherwise the lower-cased OCR text is scored against ``INVOICE_HINTS`` and
    ``RECEIPT_HINTS`` (with a +2 bonus for the words "invoice"/"inv " and
    "receipt" respectively).
    """
    if forced in ("invoice", "receipt"):
        return forced

    text = " ".join(word.text for word in words).lower()
    invoice_score = sum(1 for hint in INVOICE_HINTS if hint in text)
    receipt_score = sum(1 for hint in RECEIPT_HINTS if hint in text)
    if "invoice" in text or "inv " in text:
        invoice_score += 2
    if "receipt" in text:
        receipt_score += 2

    # Ties (including "no hints at all") fall back to the invoice parser.
    return "receipt" if receipt_score > invoice_score else "invoice"


def parse_loose_number(value: Any) -> float | None:
    """Parse a number out of loosely formatted text such as ``"$1,234.56"``.

    Ints and floats are returned as floats; non-string values yield ``None``.
    For strings, everything except digits, ``.``, ``,`` and ``-`` is dropped,
    then separators are interpreted as follows:

    * both ``,`` and ``.`` present: the one appearing last is the decimal
      mark, the other is a thousands separator (``1,234.56`` / ``1.234,56``);
    * only ``,`` present: several commas, or a single comma followed by
      exactly three digits, are thousands separators (``1,234`` -> 1234);
      otherwise the comma is a decimal mark (``12,50`` -> 12.5).

    Returns ``None`` when nothing parseable remains.
    """
    if isinstance(value, (int, float)):
        return float(value)
    if not isinstance(value, str):
        return None

    cleaned = re.sub(r"[^0-9.,-]", "", value)
    if not cleaned:
        return None

    last_comma = cleaned.rfind(",")
    last_dot = cleaned.rfind(".")
    if last_comma != -1 and last_dot != -1:
        if last_dot > last_comma:
            cleaned = cleaned.replace(",", "")
        else:
            cleaned = cleaned.replace(".", "").replace(",", ".")
    elif last_comma != -1:
        digits_after_comma = cleaned.rpartition(",")[2]
        if cleaned.count(",") > 1 or len(digits_after_comma) == 3:
            cleaned = cleaned.replace(",", "")
        else:
            cleaned = cleaned.replace(",", ".")

    try:
        return float(cleaned)
    except ValueError:
        return None


def normalize_date(value: str | None) -> str | None:
    """Convert ``M/D/YY`` or ``M/D/YYYY`` to ``YYYY-MM-DD``.

    Values already in ``YYYY-MM-DD`` form are returned unchanged, as is any
    (stripped) value in an unrecognized format. Empty input yields ``None``.
    Two-digit years are assumed to be in the 2000s.
    """
    if not value:
        return None
    value = value.strip()
    if re.match(r"^\d{4}-\d{2}-\d{2}$", value):
        return value
    # Only 2- or 4-digit years are accepted; a 3-digit year is not a date.
    match = re.match(r"^(\d{1,2})/(\d{1,2})/(\d{2}|\d{4})$", value)
    if not match:
        return value
    month, day, year = match.groups()
    if len(year) == 2:
        year = f"20{year}"
    return f"{year}-{month.zfill(2)}-{day.zfill(2)}"


def map_cord_json(cord: dict[str, Any]) -> dict[str, Any]:
    """Map a Donut CORD-style JSON dict onto the parsed-invoice result shape.

    ``menu`` may be a single dict or a list of dicts; entries without a
    usable name are skipped. Confidence is ``"medium"`` when at least one
    line item was found, otherwise ``"low"``.
    """

    def price_field(block: Any, *keys: str) -> float | None:
        # A block may be a nested dict (first matching key wins) or a bare value.
        if isinstance(block, dict):
            for key in keys:
                if key in block:
                    return parse_loose_number(block[key])
        return parse_loose_number(block)

    def text_field(*keys: str) -> str | None:
        # First truthy value among *keys*, stringified and stripped.
        for key in keys:
            candidate = cord.get(key)
            if candidate:
                return str(candidate).strip() or None
        return None

    menu = cord.get("menu")
    if isinstance(menu, list):
        menus = menu
    elif isinstance(menu, dict):
        menus = [menu]
    else:
        menus = []

    line_items: list[dict[str, Any]] = []
    for entry in menus:
        if not isinstance(entry, dict):
            continue
        description = (
            entry.get("nm")
            or entry.get("item")
            or entry.get("name")
            or entry.get("menu.nm")
        )
        if not description or not str(description).strip():
            continue
        line_items.append(
            {
                "description": str(description).strip(),
                "vendorItemNumber": None,
                "quantity": parse_loose_number(entry.get("cnt") or entry.get("num")),
                # NOTE(review): "itemsubtotal" as a fallback for the unit looks
                # odd (it reads like a price field) -- kept as-is, confirm intent.
                "unit": str(entry.get("unit") or entry.get("itemsubtotal") or "").strip()
                or None,
                "unitPrice": parse_loose_number(
                    entry.get("unitprice") or entry.get("price") or entry.get("itemprice")
                ),
                "lineTotal": parse_loose_number(
                    entry.get("price") or entry.get("cntprice") or entry.get("itemprice")
                ),
            }
        )

    sub_total = cord.get("sub_total") or cord.get("subtotal")
    tax = cord.get("tax") or cord.get("tax_price")
    total = cord.get("total") or cord.get("total_price") or cord.get("total_etc")

    tax_value = price_field(tax, "price", "tax_price")
    if tax_value is None and isinstance(sub_total, dict):
        # The tax amount may be nested inside the sub_total group instead of
        # appearing at the top level.
        tax_value = parse_loose_number(sub_total.get("tax_price"))

    return {
        "vendorName": text_field("store", "company", "brand"),
        "invoiceNumber": text_field("receipt_no", "order_no"),
        "invoiceDate": normalize_date(text_field("date", "receipt_date")),
        "subtotal": price_field(sub_total, "price", "subtotal_price", "sub_total_price"),
        "tax": tax_value,
        "total": price_field(total, "total_price", "price", "total"),
        "currency": None,
        "confidence": "medium" if line_items else "low",
        "rawNotes": json.dumps(cord)[:4000] if cord else None,
        "lineItems": line_items,
    }


def parse_receipt(image) -> dict[str, Any]:
    """Run the Donut CORD v2 model on *image* and map its output via ``map_cord_json``."""
    import torch
    from transformers import DonutProcessor, VisionEncoderDecoderModel

    processor = DonutProcessor.from_pretrained(RECEIPT_MODEL)
    model = VisionEncoderDecoderModel.from_pretrained(RECEIPT_MODEL)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    # Donut is prompted with a task start token; an empty prompt would give the
    # decoder no input ids to start generating from.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = (
        sequence.replace(processor.tokenizer.eos_token, "")
        .replace(processor.tokenizer.pad_token, "")
        .strip()
    )
    # Drop the first tag in the sequence (the task start token).
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    cord = processor.token2json(sequence)
    return map_cord_json(cord)


def align_word_labels(
    word_texts: list[str],
    word_ids: list[int | None],
    predictions: list[int],
    id2label: dict,
) -> list[str]:
    """Collapse per-token predictions into one label per word.

    *word_ids* maps each token position to a word index (``None`` for special
    and padding tokens). The prediction of a word's first sub-token decides
    its label; words that received no token keep the label ``"O"``.
    ``id2label`` may be keyed by int or by the string form of the id.
    """
    labels = ["O"] * len(word_texts)
    labelled: set[int] = set()
    for word_id, pred in zip(word_ids, predictions):
        if word_id is None or word_id in labelled:
            # Skip special tokens and continuation sub-tokens: only the first
            # sub-token of a word carries a meaningful prediction.
            continue
        if not 0 <= word_id < len(labels):
            continue
        labelled.add(word_id)
        labels[word_id] = id2label.get(pred, id2label.get(str(pred), "O"))
    return labels


def group_entities(words: list[str], labels: list[str]) -> list[tuple[str, str]]:
    """Merge consecutive words sharing a label into ``(label, text)`` groups.

    ``B-``/``I-`` prefixes are stripped from the label. A ``B-`` prefix or a
    change of base label starts a new group; ``"O"`` closes the current one.
    """
    groups: list[tuple[str, str]] = []
    current_label: str | None = None
    current_tokens: list[str] = []

    def flush() -> None:
        nonlocal current_label, current_tokens
        if current_tokens and current_label:
            groups.append((current_label, " ".join(current_tokens).strip()))
        current_label = None
        current_tokens = []

    for word, label in zip(words, labels):
        if label == "O":
            flush()
            continue
        prefix = label[:2]
        base = label[2:] if prefix in ("B-", "I-") else label
        if prefix == "B-" or current_label != base:
            flush()
            current_label = base
            current_tokens = [word]
        else:
            current_tokens.append(word)
    flush()
    return groups


def extract_qa_pairs(groups: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Pair each ``ANSWER`` group with the most recent unanswered ``QUESTION``.

    ``HEADER`` groups are emitted as ``("HEADER", text)``. Answers with no
    pending question are dropped.
    """
    pairs: list[tuple[str, str]] = []
    pending_question: str | None = None
    for label, text in groups:
        if label.endswith("QUESTION"):
            pending_question = text
        elif label.endswith("ANSWER") and pending_question:
            pairs.append((pending_question, text))
            pending_question = None
        elif label.endswith("HEADER"):
            pairs.append(("HEADER", text))
    return pairs


def extract_line_items_from_ocr(words: list[OcrWord]) -> list[dict[str, Any]]:
    """Heuristically extract line items from raw OCR words.

    Words are grouped into rows by rounding ``top`` to ``ROW_BUCKET_PX``. A
    row becomes a line item when it contains none of ``LINE_ITEM_SKIP_TERMS``,
    has at least two numbers, and has a description of three or more
    characters before the first number. The last number is taken as the line
    total and the one before it as the quantity. At most ``MAX_LINE_ITEMS``
    items are returned.
    """
    if not words:
        return []

    rows: dict[int, list[OcrWord]] = {}
    for word in words:
        bucket = round(word.top / ROW_BUCKET_PX) * ROW_BUCKET_PX
        rows.setdefault(bucket, []).append(word)

    line_items: list[dict[str, Any]] = []
    for _, row_words in sorted(rows.items()):
        ordered = sorted(row_words, key=lambda w: w.left)
        text = " ".join(word.text for word in ordered)
        if len(text) < 4:
            continue
        lower = text.lower()
        if any(term in lower for term in LINE_ITEM_SKIP_TERMS):
            continue

        parsed = (parse_loose_number(match.group()) for match in _NUMBER_RE.finditer(text))
        numbers = [number for number in parsed if number is not None]
        if len(numbers) < 2:
            continue

        # NOTE(review): with three numeric columns (qty, unit price, total) the
        # second-to-last number is the unit price rather than the quantity;
        # the column order cannot be determined from the row alone.
        quantity = numbers[-2]
        line_total = numbers[-1]
        description = _TRAILING_NUMBERS_RE.sub("", text).strip()
        if len(description) < 3:
            continue

        line_items.append(
            {
                "description": description,
                "vendorItemNumber": None,
                "quantity": quantity,
                "unit": None,
                "unitPrice": round(line_total / quantity, 4)
                if quantity and quantity > 0
                else None,
                "lineTotal": line_total,
            }
        )
    return line_items[:MAX_LINE_ITEMS]


def extract_invoice_fields(qa_pairs: list[tuple[str, str]]) -> dict[str, Any]:
    """Map question/answer pairs onto invoice header fields.

    Returns a dict with the keys ``vendorName``, ``invoiceNumber``,
    ``invoiceDate``, ``subtotal``, ``tax`` and ``total`` (``None`` when not
    found). The first ``HEADER`` entry seeds the vendor name; an explicit
    vendor/supplier question overrides it. When several questions map to the
    same field, the last one wins.
    """
    vendor_name = None
    invoice_number = None
    invoice_date = None
    total = None
    tax = None
    subtotal = None

    for question, answer in qa_pairs:
        if question == "HEADER":
            if not vendor_name:
                vendor_name = answer
            continue

        q = question.lower()
        mentions_invoice = "invoice" in q or "inv" in q
        if (mentions_invoice or "bill" in q) and "date" in q:
            invoice_date = normalize_date(answer)
        elif mentions_invoice and any(marker in q for marker in ("no", "num", "#")):
            # "num" and "#" cover "Invoice Number" / "Invoice #", neither of
            # which contains the substring "no".
            invoice_number = answer
        elif "date" in q:
            invoice_date = normalize_date(answer)
        elif "total" in q and "sub" not in q:
            total = parse_loose_number(answer)
        elif "tax" in q:
            tax = parse_loose_number(answer)
        elif "sub" in q and "total" in q:
            # Matches "subtotal", "sub total" and "sub-total".
            subtotal = parse_loose_number(answer)
        elif any(token in q for token in ("vendor", "supplier", "seller", "remit", "from")):
            vendor_name = answer

    return {
        "vendorName": vendor_name,
        "invoiceNumber": invoice_number,
        "invoiceDate": invoice_date,
        "subtotal": subtotal,
        "tax": tax,
        "total": total,
    }


def parse_invoice(image, words: list[OcrWord]) -> dict[str, Any]:
    """Extract invoice header fields with LayoutLMv3 and line items from OCR rows.

    Returns a low-confidence empty result when *words* is empty. Confidence is
    ``"high"`` when line items were found together with an invoice number or
    vendor name, ``"medium"`` with line items only, otherwise ``"low"``.
    """
    if not words:
        # Nothing to classify, so skip loading the model entirely.
        return {
            "vendorName": None,
            "invoiceNumber": None,
            "invoiceDate": None,
            "subtotal": None,
            "tax": None,
            "total": None,
            "currency": None,
            "confidence": "low",
            "rawNotes": None,
            "lineItems": [],
        }

    import torch
    from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor

    processor = LayoutLMv3Processor.from_pretrained(INVOICE_MODEL, apply_ocr=False)
    model = LayoutLMv3ForTokenClassification.from_pretrained(INVOICE_MODEL)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    width, height = image.size
    word_texts = [word.text for word in words]
    boxes = normalize_boxes(words, width, height)

    encoding = processor(
        image,
        word_texts,
        boxes=boxes,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # word_ids() lives on the processor's encoding object; read it before
    # building the plain dict of device tensors, which no longer has it.
    word_ids = encoding.word_ids(batch_index=0)
    model_inputs = {key: value.to(device) for key, value in encoding.items()}

    with torch.no_grad():
        outputs = model(**model_inputs)

    # Single-image batch: take the first row of per-token class ids.
    predictions = outputs.logits.argmax(-1)[0].tolist()

    labels = align_word_labels(word_texts, word_ids, predictions, model.config.id2label)
    groups = group_entities(word_texts, labels)
    fields = extract_invoice_fields(extract_qa_pairs(groups))

    line_items = extract_line_items_from_ocr(words)
    if not line_items:
        confidence = "low"
    elif fields["invoiceNumber"] or fields["vendorName"]:
        confidence = "high"
    else:
        confidence = "medium"

    return {
        "vendorName": fields["vendorName"],
        "invoiceNumber": fields["invoiceNumber"],
        "invoiceDate": fields["invoiceDate"],
        "subtotal": fields["subtotal"],
        "tax": fields["tax"],
        "total": fields["total"],
        "currency": None,
        "confidence": confidence,
        "rawNotes": None,
        "lineItems": line_items,
    }


def main() -> int:
    """CLI entry point: parse the image and print the JSON result.

    Returns 0 on success and 1 when the image is missing or parsing fails
    (the reason is written to stderr).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", required=True, help="Path to a PNG/JPG/WebP image")
    parser.add_argument(
        "--type",
        default="auto",
        choices=("auto", "invoice", "receipt"),
        help="Document type routing",
    )
    args = parser.parse_args()

    image_path = Path(args.image)
    if not image_path.is_file():
        eprint(f"Image not found: {image_path}")
        return 1

    try:
        image = load_image(image_path)
        words = ocr_words(image)
        doc_type = classify_document_type(words, None if args.type == "auto" else args.type)
        if doc_type == "receipt":
            result = parse_receipt(image)
        else:
            result = parse_invoice(image, words)
        payload = {"documentType": doc_type, **result}
        print(json.dumps(payload))
        return 0
    except Exception as error:  # noqa: BLE001
        # Top-level boundary: report any failure on stderr and signal it via
        # the exit code instead of a traceback.
        eprint(f"Document parse failed: {error}")
        return 1


if __name__ == "__main__":
    raise SystemExit(main())