# invoice-processor-ml / src/ml_extraction.py
# Author: GSoumyajit2005
# feat: added bulk processing, html reporting, and geometric table extraction
# commit: 90dbe20
# src/ml_extraction.py
import os
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from huggingface_hub import snapshot_download
from PIL import Image
from typing import List, Dict, Any, Tuple
import re
import numpy as np
from src.extraction import extract_invoice_number, extract_total, extract_address
from src.table_extraction import extract_table_items
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
# --- CONFIGURATION ---
# Local directory where the fine-tuned LayoutLMv3 snapshot is cached.
LOCAL_MODEL_PATH = "./models/layoutlmv3-doctr-trained"
# Hugging Face Hub repo id used to download the model when no local copy exists.
HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-doctr-invoice-processor"
# --- Load LayoutLMv3 Model ---
# --- Load LayoutLMv3 Model ---
def load_model_and_processor(model_path, hub_id):
    """Load the fine-tuned LayoutLMv3 token-classification model and processor.

    The processor always comes from the ``microsoft/layoutlmv3-base``
    checkpoint with its built-in OCR disabled, because words and boxes are
    supplied externally by DocTR. The model is loaded from the local
    snapshot when present; if the local copy is missing, empty, fails to
    download, or is corrupt, we fall back to loading directly from the Hub.

    Args:
        model_path: Local directory expected to contain the model snapshot.
        hub_id: Hugging Face Hub repo id for download / direct fallback.

    Returns:
        Tuple ``(model, processor)`` ready for inference.
    """
    print("Loading processor from microsoft/layoutlmv3-base...")
    # apply_ocr=False: bounding boxes and words are provided by DocTR, not Tesseract.
    processor = LayoutLMv3Processor.from_pretrained(
        "microsoft/layoutlmv3-base", apply_ocr=False
    )
    try:
        if not os.path.exists(model_path) or not os.listdir(model_path):
            print(f"Downloading model from Hub: {hub_id}...")
            # NOTE: local_dir_use_symlinks is deprecated and ignored by recent
            # huggingface_hub releases, so it is no longer passed here.
            snapshot_download(repo_id=hub_id, local_dir=model_path)
        model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
    except Exception:
        # Covers a failed download as well as a corrupt/incomplete local copy;
        # the original code crashed if snapshot_download itself raised.
        print(f"Fallback: Loading directly from Hub {hub_id}...")
        model = LayoutLMv3ForTokenClassification.from_pretrained(hub_id)
    return model, processor
# --- Load DocTR OCR Predictor ---
# --- Load DocTR OCR Predictor ---
def load_doctr_predictor():
    """Construct the DocTR OCR predictor, preferring GPU execution if available."""
    print("Loading DocTR OCR predictor...")
    ocr = ocr_predictor(
        det_arch='db_resnet50',
        reco_arch='crnn_vgg16_bn',
        pretrained=True,
    )
    gpu_available = torch.cuda.is_available()
    if gpu_available:
        print("🚀 Moving DocTR to GPU (CUDA)...")
        ocr.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (slow).")
    print("DocTR OCR predictor is ready.")
    return ocr
# --- Module-level initialization (runs once at import time) ---
# Loading here means every importer shares one model instance, but it also
# means importing this module triggers downloads / GPU allocation as a side
# effect.
MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)
DOCTR_PREDICTOR = load_doctr_predictor()
if MODEL and PROCESSOR:
    # LayoutLMv3 runs on GPU when available; DocTR was moved separately above.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(DEVICE)
    MODEL.eval()  # inference mode: disables dropout etc.
    print(f"ML Model is ready on device: {DEVICE}")
else:
    # NOTE(review): load_model_and_processor raises on failure rather than
    # returning None, so this branch appears unreachable as written — confirm.
    DEVICE = None
    print("❌ Could not load ML model.")
def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]], List[List[int]]]:
    """Flatten DocTR's nested result into word/box lists for LayoutLMv3.

    DocTR reports each word's geometry as ((x_min, y_min), (x_max, y_max))
    in relative 0-1 coordinates. Two box encodings are produced per word:

      * pixel boxes ``[x, y, width, height]`` for visualization overlays, and
      * 0-1000 corner boxes ``[x0, y0, x1, y1]`` as LayoutLMv3 expects.

    Args:
        doctr_result: Output document from the DocTR predictor.
        img_width: Source image width in pixels.
        img_height: Source image height in pixels.

    Returns:
        ``(words, pixel_boxes, model_boxes)`` — three parallel lists.
    """
    words: List[str] = []
    pixel_boxes: List[List[int]] = []
    model_boxes: List[List[int]] = []

    def clamp1000(value: float) -> int:
        # LayoutLMv3 requires box coordinates to lie within [0, 1000].
        return max(0, min(1000, int(value * 1000)))

    # Walk the hierarchy: Document -> Page -> Block -> Line -> Word.
    for page in doctr_result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    text = word.value
                    if not text.strip():
                        continue  # drop empty / whitespace-only detections
                    (x_min, y_min), (x_max, y_max) = word.geometry
                    # Absolute pixel corners for the visualization overlay.
                    left = int(x_min * img_width)
                    top = int(y_min * img_height)
                    right = int(x_max * img_width)
                    bottom = int(y_max * img_height)
                    words.append(text)
                    pixel_boxes.append([left, top, right - left, bottom - top])
                    model_boxes.append([
                        clamp1000(x_min),
                        clamp1000(y_min),
                        clamp1000(x_max),
                        clamp1000(y_max),
                    ])
    return words, pixel_boxes, model_boxes
def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
word_ids = encoding.word_ids(batch_index=0)
word_level_preds = {}
for idx, word_id in enumerate(word_ids):
if word_id is not None:
label_id = predictions[idx]
if label_id != -100:
if word_id not in word_level_preds:
word_level_preds[word_id] = id2label[label_id]
entities = {}
for word_idx, label in word_level_preds.items():
if label == 'O': continue
entity_type = label[2:]
word = words[word_idx]
if label.startswith('B-'):
entities[entity_type] = {"text": word, "bbox": [unnormalized_boxes[word_idx]]}
elif label.startswith('I-') and entity_type in entities:
entities[entity_type]['text'] += " " + word
entities[entity_type]['bbox'].append(unnormalized_boxes[word_idx])
for entity in entities.values():
entity['text'] = entity['text'].strip()
return entities
def extract_ml_based(image_path: str) -> Dict[str, Any]:
    """Run the full OCR + LayoutLMv3 extraction pipeline on one invoice image.

    Pipeline: DocTR OCR -> word/box parsing -> LayoutLMv3 token
    classification -> entity assembly, followed by regex/spatial fallbacks
    that backfill fields (and their bounding boxes) the model missed, and a
    geometric table heuristic for line items.

    Args:
        image_path: Filesystem path to the invoice image.

    Returns:
        Dict with keys: vendor, date, address, receipt_number, bill_to,
        total_amount, items, raw_text, raw_predictions (per-entity text
        and bbox data used by the UI overlay).

    Raises:
        RuntimeError: If the module-level model/processor failed to load.
    """
    if not MODEL or not PROCESSOR:
        raise RuntimeError("ML model is not loaded.")
    # 1. Load Image
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    # 2. Run DocTR OCR
    doc = DocumentFile.from_images(image_path)
    doctr_result = DOCTR_PREDICTOR(doc)
    # 3. Parse DocTR output to get words and boxes
    words, unnormalized_boxes, normalized_boxes = parse_doctr_output(
        doctr_result, width, height
    )
    # Reconstructs lines so regex can work line-by-line.
    # A word starts a new line when its vertical offset from the current line
    # exceeds half the larger of the two word heights.
    lines = []
    current_line = []
    if len(unnormalized_boxes) > 0:
        # Initialize with first word's Y and Height
        current_y = unnormalized_boxes[0][1]
        current_h = unnormalized_boxes[0][3]
        for i, word in enumerate(words):
            y = unnormalized_boxes[i][1]
            h = unnormalized_boxes[i][3]
            # If vertical gap > 50% of line height, it's a new line
            if abs(y - current_y) > max(current_h, h) / 2:
                lines.append(" ".join(current_line))
                current_line = []
                current_y = y
                current_h = h
            current_line.append(word)
        # Append the last line
        if current_line:
            lines.append(" ".join(current_line))
    raw_text = "\n".join(lines)
    # Handle empty OCR result: return the full schema with empty values so
    # callers never have to special-case a missing key.
    if not words:
        return {
            "vendor": None,
            "date": None,
            "address": None,
            "receipt_number": None,
            "bill_to": None,
            "total_amount": None,
            "items": [],
            "raw_text": "",
            "raw_predictions": {}
        }
    # 4. Inference with LayoutLMv3 (truncated to the model's 512-token limit;
    # words beyond that are silently dropped from classification).
    encoding = PROCESSOR(
        image, text=words, boxes=normalized_boxes,
        truncation=True, max_length=512, return_tensors="pt"
    )
    # Move tensors to device for inference, but keep original encoding for word_ids()
    model_inputs = {k: v.to(DEVICE) for k, v in encoding.items()}
    with torch.no_grad():
        outputs = MODEL(**model_inputs)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)
    # 5. Construct Output. Entity label names (COMPANY, DATE, ...) come from
    # the fine-tuned model's id2label mapping.
    final_output = {
        "vendor": extracted_entities.get("COMPANY", {}).get("text"),
        "date": extracted_entities.get("DATE", {}).get("text"),
        "address": extracted_entities.get("ADDRESS", {}).get("text"),
        "receipt_number": extracted_entities.get("INVOICE_NO", {}).get("text"),
        "bill_to": extracted_entities.get("BILL_TO", {}).get("text"),
        "total_amount": None,
        "items": [],
        "raw_text": raw_text,
        "raw_predictions": extracted_entities  # Contains text and bbox data for each entity
    }
    # 6. Vendor Fallback (Spatial Heuristic)
    # If ML failed to find a vendor, assume the largest text at the top is the vendor
    if not final_output["vendor"] and unnormalized_boxes:
        # Filter for words in the top 20% of the image
        top_words_indices = [
            i for i, box in enumerate(unnormalized_boxes)
            if box[1] < height * 0.2
        ]
        if top_words_indices:
            # Find the word with the largest height (font size)
            largest_idx = max(top_words_indices, key=lambda i: unnormalized_boxes[i][3])
            final_output["vendor"] = words[largest_idx]
    # --- ADDRESS FALLBACK ---
    if not final_output["address"]:
        # We pass the extracted (or fallback) Vendor Name to help anchor the search
        # Use the raw text and the known vendor to find the address spatially
        fallback_address = extract_address(raw_text, vendor_name=final_output["vendor"])
        if fallback_address:
            final_output["address"] = fallback_address
    # Backfill Bounding Boxes for Address Fallback
    # If Regex found the address but ML didn't, find its boxes in the OCR data
    if final_output["address"] and "ADDRESS" not in final_output["raw_predictions"]:
        address_text = final_output["address"]
        address_boxes = []
        # The address may span multiple words, so we search for each word
        # Split by comma first (assumes extract_address joins lines with ", " —
        # TODO confirm against src/extraction.py)
        address_parts = [part.strip() for part in address_text.split(",")]
        for part in address_parts:
            part_words = part.split()
            for target_word in part_words:
                for i, word in enumerate(words):
                    # Case-insensitive match (substring match allows punctuation-glued OCR words)
                    if target_word.lower() == word.lower() or target_word.lower() in word.lower():
                        address_boxes.append(unnormalized_boxes[i])
                        break  # Only match once per target word
        # If we found any boxes, inject into raw_predictions
        if address_boxes:
            final_output["raw_predictions"]["ADDRESS"] = {
                "text": address_text,
                "bbox": address_boxes
            }
    # Fallbacks for the total: prefer the ML entity, parse it to float, and
    # fall back to the regex extractor when parsing fails or nothing was found.
    ml_total = extracted_entities.get("TOTAL", {}).get("text")
    if ml_total:
        try:
            # Strip currency symbols/letters; treat "," as a decimal separator.
            cleaned = re.sub(r'[^\d.,]', '', ml_total).replace(',', '.')
            final_output["total_amount"] = float(cleaned)
        except (ValueError, TypeError):
            pass
    if final_output["total_amount"] is None:
        final_output["total_amount"] = extract_total(raw_text)
    if not final_output["receipt_number"]:
        final_output["receipt_number"] = extract_invoice_number(raw_text)
    # Backfill Bounding Boxes for Regex Results
    # If Regex found the number but ML didn't, we must find its box
    # in the OCR data so the UI can draw it.
    if final_output["receipt_number"] and "INVOICE_NO" not in final_output["raw_predictions"]:
        target_val = final_output["receipt_number"].strip()
        found_box = None
        # 1. Try finding the exact word in the OCR list
        # 'words' and 'unnormalized_boxes' are available from step 3
        for i, word in enumerate(words):
            # Check for exact match or if the word contains the target (e.g. "Inv#123")
            if target_val == word or (len(target_val) > 3 and target_val in word):
                found_box = unnormalized_boxes[i]
                break
        # 2. If found, inject it into raw_predictions
        if found_box:
            # The UI expects a list of boxes
            final_output["raw_predictions"]["INVOICE_NO"] = {
                "text": target_val,
                "bbox": [found_box]
            }
    # --- TABLE EXTRACTION (Geometric Heuristic) ---
    # Use the geometric fallback to extract line items from table region
    if words and unnormalized_boxes:
        extracted_items = extract_table_items(words, unnormalized_boxes)
        if extracted_items:
            final_output["items"] = extracted_items
    return final_output