Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Parse vendor invoices (LayoutLMv3 FUNSD) or retail receipts (Donut CORD v2). | |
| Usage: | |
| python3 scripts/parse_vendor_document.py --image /path/to.png [--type invoice|receipt|auto] | |
| Prints a single JSON object to stdout matching ParsedVendorInvoice. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import re | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
# Hugging Face checkpoint ids for the two parsing back-ends.
RECEIPT_MODEL = "naver-clova-ix/donut-base-finetuned-cord-v2"
INVOICE_MODEL = "nielsr/layoutlmv3-finetuned-funsd"

# Lower-case phrases; each one found in the OCR text adds a vote for "invoice".
INVOICE_HINTS = (
    "invoice", "inv #", "inv no",
    "bill to", "ship to", "purchase order", "po #", "remit to",
    "net 30", "del weight", "unit price",
    "vendor", "food service",
)

# Lower-case phrases; each one found in the OCR text adds a vote for "receipt".
RECEIPT_HINTS = (
    "receipt", "thank you",
    "subtotal", "sub total", "change due",
    "cashier", "register",
    "visa", "mastercard", "debit",
    "loyalty", "store #",
)
@dataclass
class OcrWord:
    """A single OCR-recognised word together with its pixel rectangle.

    Declared as a dataclass so it can be built with keyword arguments
    (``OcrWord(text=..., left=..., ...)``), which is how the rest of this
    file constructs it.
    """

    # The recognised text of the word.
    text: str
    # Top-left corner of the rectangle.
    left: int
    top: int
    # Extent of the rectangle.
    width: int
    height: int

    @property
    def box(self) -> list[int]:
        """Return the rectangle as ``[x0, y0, x1, y1]``.

        Exposed as a property because callers read it as an attribute
        (``x0, y0, x1, y1 = word.box``) rather than calling it.
        """
        return [self.left, self.top, self.left + self.width, self.top + self.height]
def eprint(*args: object) -> None:
    """Print ``args`` to standard error using the normal ``print`` formatting."""
    stream = sys.stderr
    print(*args, file=stream)
def load_image(path: Path):
    """Open the image at ``path`` and return it converted to RGB mode.

    The file is opened inside a context manager so the underlying file
    handle is released as soon as the pixel data has been converted; the
    returned image is the converted copy and does not depend on the handle.
    """
    # Imported lazily so the module can be imported without Pillow installed.
    from PIL import Image

    with Image.open(path) as source:
        return source.convert("RGB")
def ocr_words(image) -> list[OcrWord]:
    """Run Tesseract on ``image`` and return the words worth keeping.

    Entries whose text is empty after stripping are skipped. Entries with a
    confidence from 0 up to (but excluding) 35 are dropped; entries whose
    confidence is reported as ``"-1"`` or ``""`` are kept.
    """
    import pytesseract

    table = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    kept: list[OcrWord] = []
    for index, raw_text in enumerate(table["text"]):
        token = (raw_text or "").strip()
        if not token:
            continue

        raw_conf = table["conf"][index]
        # "-1" / "" mean "no confidence available"; treat them as -1.
        score = -1 if raw_conf in ("-1", "") else int(float(raw_conf))
        if 0 <= score < 35:
            continue

        kept.append(
            OcrWord(
                text=token,
                left=int(table["left"][index]),
                top=int(table["top"][index]),
                width=int(table["width"][index]),
                height=int(table["height"][index]),
            )
        )
    return kept
def normalize_boxes(words: list[OcrWord], width: int, height: int) -> list[list[int]]:
    """Rescale each word's pixel box onto a 0..1000 grid.

    x coordinates are scaled by ``width`` and y coordinates by ``height``;
    every result is truncated to an int and clamped into ``[0, 1000]``.
    """

    def to_grid(pixel: int, extent: int) -> int:
        # Truncate first, then clamp, so out-of-image boxes stay in range.
        return min(1000, max(0, int(1000 * pixel / extent)))

    return [
        [to_grid(x0, width), to_grid(y0, height), to_grid(x1, width), to_grid(y1, height)]
        for x0, y0, x1, y1 in (word.box for word in words)
    ]
def classify_document_type(words: list[OcrWord], forced: str | None) -> str:
    """Decide whether the OCR text looks like an ``"invoice"`` or a ``"receipt"``.

    A ``forced`` value of ``"invoice"`` or ``"receipt"`` is returned as is.
    Otherwise each hint phrase found in the lower-cased text scores one vote
    for its side, with a two-vote bonus for "invoice"/"inv " and for
    "receipt". The result is "receipt" only when it strictly out-scores
    invoice; ties and every other case resolve to "invoice".
    """
    if forced in ("invoice", "receipt"):
        return forced

    haystack = " ".join(word.text for word in words).lower()
    invoice_votes = sum(hint in haystack for hint in INVOICE_HINTS)
    receipt_votes = sum(hint in haystack for hint in RECEIPT_HINTS)

    if "invoice" in haystack or "inv " in haystack:
        invoice_votes += 2
    if "receipt" in haystack:
        receipt_votes += 2

    # "invoice" is the fallback, so only a strict receipt lead changes the answer.
    return "receipt" if receipt_votes > invoice_votes else "invoice"
def parse_loose_number(value: Any) -> float | None:
    """Best-effort conversion of a loosely formatted amount to ``float``.

    Ints and floats are returned as floats. Strings are stripped of every
    character other than digits, ``.``, ``,`` and ``-`` and then parsed.
    Anything else, or a string that still cannot be parsed, yields ``None``.

    Separator handling:
    * both ``,`` and ``.`` present: whichever occurs last is the decimal
      mark and the other is a grouping separator ("1,234.56" and
      "1.234,56" both give 1234.56);
    * only commas, in strict groups of three ("1,234", "1,234,567"): the
      commas are grouping separators;
    * only commas otherwise ("12,50"): the comma is a decimal mark.
    """
    if isinstance(value, (int, float)):
        return float(value)
    if not isinstance(value, str):
        return None

    cleaned = re.sub(r"[^0-9.,-]", "", value)
    if not cleaned:
        return None

    last_comma = cleaned.rfind(",")
    last_dot = cleaned.rfind(".")
    if last_comma != -1 and last_dot != -1:
        if last_dot > last_comma:
            # "1,234.56": commas group thousands.
            cleaned = cleaned.replace(",", "")
        else:
            # "1.234,56": dots group thousands, comma is the decimal mark.
            cleaned = cleaned.replace(".", "").replace(",", ".")
    elif last_comma != -1:
        if re.fullmatch(r"-?\d{1,3}(?:,\d{3})+", cleaned):
            cleaned = cleaned.replace(",", "")
        else:
            cleaned = cleaned.replace(",", ".")

    try:
        return float(cleaned)
    except ValueError:
        return None
def normalize_date(value: str | None) -> str | None:
    """Convert ``M/D/YY`` or ``M/D/YYYY`` dates to ``YYYY-MM-DD``.

    Falsy input gives ``None``. A value already shaped like ``YYYY-MM-DD``
    is returned stripped, as is any value that matches neither pattern.
    Two-digit years are prefixed with ``20``.
    """
    if not value:
        return None

    value = value.strip()
    if re.match(r"^\d{4}-\d{2}-\d{2}$", value):
        return value

    slashed = re.match(r"^(\d{1,2})/(\d{1,2})/(\d{2,4})$", value)
    if slashed is None:
        # Unknown format: hand the stripped text back untouched.
        return value

    month, day, year = slashed.groups()
    year = f"20{year}" if len(year) == 2 else year
    return f"{year}-{month.zfill(2)}-{day.zfill(2)}"
def map_cord_json(cord: dict[str, Any]) -> dict[str, Any]:
    """Map a CORD-style receipt dict onto the invoice payload shape.

    Args:
        cord: Parsed model output, read with keys such as ``menu``,
            ``sub_total`` and ``total``.

    Returns:
        A dict with ``vendorName``, ``invoiceNumber``, ``invoiceDate``,
        ``subtotal``, ``tax``, ``total``, ``currency`` (always ``None``),
        ``confidence`` ("medium" when at least one line item was found,
        otherwise "low"), ``rawNotes`` (the JSON dump of ``cord`` cut to
        4000 characters) and ``lineItems``.
    """

    def price_field(block: Any, *keys: str) -> float | None:
        # For a dict, use the first listed key that is present; otherwise
        # try to parse the block itself as a number.
        if isinstance(block, dict):
            for key in keys:
                if key in block:
                    return parse_loose_number(block[key])
        return parse_loose_number(block)

    def first_text(*candidates: Any) -> str | None:
        # First truthy candidate as stripped text, or None when blank.
        chosen = ""
        for candidate in candidates:
            if candidate:
                chosen = candidate
                break
        return str(chosen).strip() or None

    menu = cord.get("menu")
    if isinstance(menu, list):
        menus = menu
    elif isinstance(menu, dict):
        menus = [menu]
    else:
        menus = []

    line_items: list[dict[str, Any]] = []
    for entry in menus:
        if not isinstance(entry, dict):
            continue
        description = (
            entry.get("nm")
            or entry.get("item")
            or entry.get("name")
            or entry.get("menu.nm")
        )
        if not description or not str(description).strip():
            continue
        line_items.append(
            {
                "description": str(description).strip(),
                "vendorItemNumber": None,
                "quantity": parse_loose_number(entry.get("cnt") or entry.get("num")),
                # NOTE(review): falling back to "itemsubtotal" for the unit
                # looks unintended; kept as is - confirm with the author.
                "unit": str(entry.get("unit") or entry.get("itemsubtotal") or "").strip() or None,
                "unitPrice": parse_loose_number(
                    entry.get("unitprice") or entry.get("price") or entry.get("itemprice")
                ),
                "lineTotal": parse_loose_number(
                    entry.get("price") or entry.get("cntprice") or entry.get("itemprice")
                ),
            }
        )

    sub_total = cord.get("sub_total") or cord.get("subtotal")
    total = cord.get("total") or cord.get("total_price") or cord.get("total_etc")
    tax = cord.get("tax") or cord.get("tax_price")
    if not tax and isinstance(sub_total, dict):
        # The CORD v2 schema nests tax inside the sub_total group as
        # "tax_price"; without this fallback tax is never populated.
        tax = sub_total.get("tax_price")

    return {
        "vendorName": first_text(cord.get("store"), cord.get("company"), cord.get("brand")),
        "invoiceNumber": first_text(cord.get("receipt_no"), cord.get("order_no")),
        "invoiceDate": normalize_date(first_text(cord.get("date"), cord.get("receipt_date"))),
        "subtotal": price_field(sub_total, "price", "subtotal_price", "sub_total_price"),
        "tax": price_field(tax, "price", "tax_price"),
        "total": price_field(total, "total_price", "price", "total"),
        "currency": None,
        "confidence": "medium" if line_items else "low",
        "rawNotes": json.dumps(cord)[:4000] if cord else None,
        "lineItems": line_items,
    }
def parse_receipt(image) -> dict[str, Any]:
    """Run the Donut receipt checkpoint on ``image`` and map its output.

    Loads the processor and model named by ``RECEIPT_MODEL``, generates a
    token sequence from the ``<s_cord-v2>`` prompt, converts it to JSON and
    passes it through ``map_cord_json``.
    """
    import torch
    from transformers import DonutProcessor, VisionEncoderDecoderModel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = DonutProcessor.from_pretrained(RECEIPT_MODEL)
    model = VisionEncoderDecoderModel.from_pretrained(RECEIPT_MODEL)
    model.to(device)
    model.eval()

    tokenizer = processor.tokenizer
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    task_prompt = "<s_cord-v2>"
    prompt_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    generation_options = {
        "decoder_input_ids": prompt_ids,
        "max_length": model.decoder.config.max_position_embeddings,
        "early_stopping": True,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "use_cache": True,
        "num_beams": 1,
        "bad_words_ids": [[tokenizer.unk_token_id]],
        "return_dict_in_generate": True,
    }
    outputs = model.generate(pixel_values, **generation_options)

    decoded = processor.batch_decode(outputs.sequences)[0]
    for special in (tokenizer.eos_token, tokenizer.pad_token):
        decoded = decoded.replace(special, "")
    decoded = decoded.strip()
    # Drop only the first <...> tag before converting to JSON.
    decoded = re.sub(r"<.*?>", "", decoded, count=1).strip()

    return map_cord_json(processor.token2json(decoded))
def align_word_labels(word_texts: list[str], word_ids: list[int | None], predictions: list[int], id2label: dict) -> list[str]:
    """Collapse per-token predictions into one label per word.

    Args:
        word_texts: The words, used only for their count.
        word_ids: For each token, the index of the word it belongs to, or
            ``None`` for tokens that belong to no word.
        predictions: Predicted class id for each token.
        id2label: Class id to label name; looked up by the id itself and,
            failing that, by its string form.

    Returns:
        One label per word, defaulting to ``"O"``.

    A word split into several tokens takes the prediction of its FIRST
    token. Token-classification checkpoints are normally fine-tuned with
    labels on first sub-tokens only, so predictions on continuation tokens
    are unreliable and must not overwrite the first one.
    """
    labels = ["O"] * len(word_texts)
    assigned: set[int] = set()
    for word_id, pred in zip(word_ids, predictions):
        if word_id is None or word_id in assigned:
            continue
        assigned.add(word_id)
        # Ignore indices that do not refer to one of the given words.
        if not 0 <= word_id < len(labels):
            continue
        labels[word_id] = id2label.get(pred, id2label.get(str(pred), "O"))
    return labels
def group_entities(words: list[str], labels: list[str]) -> list[tuple[str, str]]:
    """Merge consecutive tagged words into ``(label, text)`` groups.

    A leading ``B-`` or ``I-`` is removed from each label. A word extends
    the current group only when it directly follows it, is not tagged
    ``B-`` and carries the same base label; an ``"O"`` label ends the
    current group. Groups whose base label is empty are discarded.
    """
    runs: list[tuple[str, list[str]]] = []
    run_open = False
    for token, tag in zip(words, labels):
        if tag == "O":
            run_open = False
            continue

        marker = tag[:2]
        entity = tag[2:] if marker in ("B-", "I-") else tag
        if run_open and marker != "B-" and runs[-1][0] == entity:
            runs[-1][1].append(token)
        else:
            runs.append((entity, [token]))
            run_open = True

    return [(entity, " ".join(tokens).strip()) for entity, tokens in runs if entity]
def extract_qa_pairs(groups: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Pair each question group with the next answer group.

    A group whose label ends in ``QUESTION`` becomes the pending question,
    replacing any earlier one. The next ``ANSWER`` group is paired with it
    and clears it; answers with no pending question are ignored. ``HEADER``
    groups are emitted as ``("HEADER", text)`` and leave the pending
    question in place.
    """
    collected: list[tuple[str, str]] = []
    open_question: str | None = None
    for kind, content in groups:
        if kind.endswith("QUESTION"):
            open_question = content
            continue
        if kind.endswith("ANSWER") and open_question:
            collected.append((open_question, content))
            open_question = None
            continue
        if kind.endswith("HEADER"):
            collected.append(("HEADER", content))
    return collected
def extract_line_items_from_ocr(words: list[OcrWord]) -> list[dict[str, Any]]:
    """Heuristically rebuild line items from positioned OCR words.

    Words are bucketed into rows by rounding ``top`` to the nearest
    multiple of 12 and ordered left to right within a row. A row is kept
    when it has at least 4 characters, contains none of the summary/header
    phrases, and holds at least two parseable numbers: the last is taken as
    the line total and the one before it as the quantity. The description
    is the row text up to the first whitespace-preceded number and must be
    at least 3 characters. At most 40 items are returned.
    """
    if not words:
        return []

    skip_markers = (
        "subtotal",
        "sub total",
        "total",
        "tax",
        "balance",
        "thank you",
        "page ",
        "invoice",
        "bill to",
        "ship to",
    )

    bands: dict[int, list[OcrWord]] = {}
    for word in words:
        band = round(word.top / 12) * 12
        if band not in bands:
            bands[band] = []
        bands[band].append(word)

    items: list[dict[str, Any]] = []
    for band in sorted(bands):
        ordered = sorted(bands[band], key=lambda w: w.left)
        line = " ".join(w.text for w in ordered)
        if len(line) < 4:
            continue

        lowered = line.lower()
        if any(marker in lowered for marker in skip_markers):
            continue

        amounts: list[float] = []
        for hit in re.finditer(r"\d[\d,]*\.?\d*", line):
            parsed = parse_loose_number(hit.group())
            if parsed is not None:
                amounts.append(parsed)
        if len(amounts) < 2:
            continue

        quantity, line_total = amounts[-2], amounts[-1]
        description = re.sub(r"\s+\d[\d,]*\.?\d*.*$", "", line).strip()
        if len(description) < 3:
            continue

        # Unit price is derived only when the quantity is a positive number.
        if quantity and quantity > 0:
            unit_price = round(line_total / quantity, 4)
        else:
            unit_price = None

        items.append(
            {
                "description": description,
                "vendorItemNumber": None,
                "quantity": quantity,
                "unit": None,
                "unitPrice": unit_price,
                "lineTotal": line_total,
            }
        )
    return items[:40]
def parse_invoice(image, words: list[OcrWord]) -> dict[str, Any]:
    """Extract invoice header fields and line items from an OCR'd image.

    Args:
        image: The page image; its ``size`` is used to normalise boxes and
            it is passed to the LayoutLMv3 processor.
        words: OCR words with pixel boxes, as produced by ``ocr_words``.

    Returns:
        A dict with ``vendorName``, ``invoiceNumber``, ``invoiceDate``,
        ``subtotal``, ``tax``, ``total``, ``currency`` (always ``None``),
        ``confidence``, ``rawNotes`` (always ``None``) and ``lineItems``.
        With no words, every field is empty and confidence is "low".

    Header fields come from question/answer pairs predicted by the token
    classifier; line items come from the row heuristic in
    ``extract_line_items_from_ocr``.
    """
    if not words:
        # Nothing to classify: answer before paying for the heavy imports.
        return {
            "vendorName": None,
            "invoiceNumber": None,
            "invoiceDate": None,
            "subtotal": None,
            "tax": None,
            "total": None,
            "currency": None,
            "confidence": "low",
            "rawNotes": None,
            "lineItems": [],
        }

    import torch
    from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification

    processor = LayoutLMv3Processor.from_pretrained(INVOICE_MODEL, apply_ocr=False)
    model = LayoutLMv3ForTokenClassification.from_pretrained(INVOICE_MODEL)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    width, height = image.size
    word_texts = [word.text for word in words]
    boxes = normalize_boxes(words, width, height)
    encoding = processor(
        image,
        word_texts,
        boxes=boxes,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # word_ids() lives on the processor's encoding object, so read it before
    # the tensors are copied into a plain dict (which has no such method).
    word_ids = encoding.word_ids(batch_index=0)
    model_inputs = {key: value.to(device) for key, value in encoding.items()}

    with torch.no_grad():
        outputs = model(**model_inputs)
    # Single-image batch: take row 0 so the result is always a flat list.
    predictions = outputs.logits.argmax(-1)[0].tolist()

    id2label = model.config.id2label
    labels = align_word_labels(word_texts, word_ids, predictions, id2label)
    groups = group_entities(word_texts, labels)
    qa_pairs = extract_qa_pairs(groups)

    vendor_name = None
    invoice_number = None
    invoice_date = None
    total = None
    tax = None
    subtotal = None
    for question, answer in qa_pairs:
        q = question.lower()
        # The first header seen is taken as the vendor name.
        if question == "HEADER" and not vendor_name:
            vendor_name = answer
            continue
        if any(token in q for token in ("invoice", "inv", "bill")) and "date" in q:
            invoice_date = normalize_date(answer)
        elif any(token in q for token in ("invoice", "inv")) and "no" in q:
            invoice_number = answer
        elif "date" in q:
            invoice_date = normalize_date(answer)
        elif "total" in q and "sub" not in q:
            total = parse_loose_number(answer)
        elif "tax" in q:
            tax = parse_loose_number(answer)
        elif "subtotal" in q or "sub total" in q:
            subtotal = parse_loose_number(answer)
        elif any(token in q for token in ("vendor", "supplier", "seller", "remit", "from")):
            vendor_name = answer

    line_items = extract_line_items_from_ocr(words)
    if line_items and (invoice_number or vendor_name):
        confidence = "high"
    elif line_items:
        confidence = "medium"
    else:
        confidence = "low"

    return {
        "vendorName": vendor_name,
        "invoiceNumber": invoice_number,
        "invoiceDate": invoice_date,
        "subtotal": subtotal,
        "tax": tax,
        "total": total,
        "currency": None,
        "confidence": confidence,
        "rawNotes": None,
        "lineItems": line_items,
    }
def main() -> int:
    """Command-line entry point; returns the process exit status.

    Parses ``--image`` and ``--type``, OCRs the image, routes it to the
    receipt or invoice parser and prints one JSON object to stdout.
    Returns 0 on success and 1 when the image is missing or parsing fails
    (the reason is written to stderr).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", required=True, help="Path to a PNG/JPG/WebP image")
    parser.add_argument(
        "--type",
        default="auto",
        choices=("auto", "invoice", "receipt"),
        help="Document type routing",
    )
    args = parser.parse_args()

    image_path = Path(args.image)
    if not image_path.exists():
        eprint(f"Image not found: {image_path}")
        return 1

    # "auto" means "let the classifier decide".
    forced_type = None if args.type == "auto" else args.type
    try:
        image = load_image(image_path)
        words = ocr_words(image)
        doc_type = classify_document_type(words, forced_type)
        if doc_type == "receipt":
            result = parse_receipt(image)
        else:
            result = parse_invoice(image, words)
        print(json.dumps({"documentType": doc_type, **result}))
    except Exception as error:  # noqa: BLE001
        # Top-level boundary: report any failure and signal it via exit code.
        eprint(f"Document parse failed: {error}")
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())