|
|
|
|
|
|
|
|
import os
|
|
|
import torch
|
|
|
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
|
|
|
from huggingface_hub import snapshot_download
|
|
|
from PIL import Image
|
|
|
from typing import List, Dict, Any, Tuple
|
|
|
import re
|
|
|
import numpy as np
|
|
|
from src.extraction import extract_invoice_number, extract_total, extract_address
|
|
|
from src.table_extraction import extract_table_items
|
|
|
from doctr.io import DocumentFile
|
|
|
from doctr.models import ocr_predictor
|
|
|
|
|
|
|
|
|
# Directory where the fine-tuned LayoutLMv3 weights are cached locally.
LOCAL_MODEL_PATH = "./models/layoutlmv3-doctr-trained"

# Hugging Face Hub repository used to fetch the weights when absent locally.
HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-doctr-invoice-processor"
|
|
|
|
|
|
|
|
|
def load_model_and_processor(model_path, hub_id):
    """Load the LayoutLMv3 processor and the fine-tuned token-classification model.

    The processor always comes from the base checkpoint, with built-in OCR
    disabled because words and boxes are supplied externally (by DocTR).
    Model weights are loaded from ``model_path`` when present; otherwise they
    are downloaded from the Hub into ``model_path`` first.  If the local load
    still fails (e.g. corrupt/partial download), the model is loaded directly
    from the Hub as a last resort.

    Args:
        model_path: Local directory expected to hold the trained weights.
        hub_id: Hugging Face Hub repo id used as download source / fallback.

    Returns:
        Tuple of ``(model, processor)``.
    """
    print("Loading processor from microsoft/layoutlmv3-base...")
    # apply_ocr=False: bounding boxes and words come from DocTR, not Tesseract.
    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

    # Fetch the weights when the local directory is missing or empty.
    if not os.path.exists(model_path) or not os.listdir(model_path):
        print(f"Downloading model from Hub: {hub_id}...")
        # `local_dir_use_symlinks` was dropped: it is deprecated and ignored
        # by current huggingface_hub releases (real files are always written).
        snapshot_download(repo_id=hub_id, local_dir=model_path)

    try:
        model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
    except Exception:
        # Local load failed -- pull the model straight from the Hub instead.
        print(f"Fallback: Loading directly from Hub {hub_id}...")
        model = LayoutLMv3ForTokenClassification.from_pretrained(hub_id)

    return model, processor
|
|
|
|
|
|
|
|
|
def load_doctr_predictor():
    """Build the DocTR OCR predictor, placing it on the GPU when one exists."""
    print("Loading DocTR OCR predictor...")
    predictor = ocr_predictor(
        det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True
    )

    # GPU inference is dramatically faster; fall back to CPU otherwise.
    if torch.cuda.is_available():
        print("🚀 Moving DocTR to GPU (CUDA)...")
        predictor.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (slow).")

    print("DocTR OCR predictor is ready.")
    return predictor
|
|
|
|
|
|
# Module import side effects: load all models once so later calls reuse them.
MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)

DOCTR_PREDICTOR = load_doctr_predictor()

# NOTE(review): load_model_and_processor either returns real objects or raises,
# so the else-branch below looks effectively unreachable -- kept for safety.
if MODEL and PROCESSOR:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(DEVICE)
    # Inference mode: disables dropout / batch-norm updates.
    MODEL.eval()
    print(f"ML Model is ready on device: {DEVICE}")
else:
    DEVICE = None
    print("❌ Could not load ML model.")
|
|
|
|
|
|
|
|
|
def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]], List[List[int]]]:
    """Flatten DocTR's Page -> Block -> Line -> Word hierarchy for LayoutLMv3.

    DocTR reports word geometry on a relative 0-1.0 scale.  Two box formats
    are derived from it:

    - pixel boxes ``[x, y, width, height]`` for visualization
    - 0-1000 scale ``[x0, y0, x1, y1]`` boxes for LayoutLMv3

    Args:
        doctr_result: Output of the DocTR predictor.
        img_width: Source image width in pixels.
        img_height: Source image height in pixels.

    Returns:
        ``(words, pixel_boxes, layoutlm_boxes)`` as parallel lists.
    """
    texts: List[str] = []
    pixel_boxes: List[List[int]] = []
    layoutlm_boxes: List[List[int]] = []

    def _clip(v: float) -> int:
        # Scale a relative coordinate to the 0-1000 grid LayoutLMv3 expects.
        return max(0, min(1000, int(v * 1000)))

    for page in doctr_result.pages:
        for blk in page.blocks:
            for ln in blk.lines:
                for wd in ln.words:
                    # Skip OCR artifacts with no visible characters.
                    if not wd.value.strip():
                        continue

                    texts.append(wd.value)

                    (rel_x0, rel_y0), (rel_x1, rel_y1) = wd.geometry

                    # Relative -> absolute pixel corners.
                    left = int(rel_x0 * img_width)
                    top = int(rel_y0 * img_height)
                    right = int(rel_x1 * img_width)
                    bottom = int(rel_y1 * img_height)

                    # Visualization format: origin plus extent.
                    pixel_boxes.append([left, top, right - left, bottom - top])

                    # LayoutLMv3 format: clamped corner coordinates.
                    layoutlm_boxes.append(
                        [_clip(rel_x0), _clip(rel_y0), _clip(rel_x1), _clip(rel_y1)]
                    )

    return texts, pixel_boxes, layoutlm_boxes
|
|
|
|
|
|
|
|
|
def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
|
|
|
word_ids = encoding.word_ids(batch_index=0)
|
|
|
word_level_preds = {}
|
|
|
for idx, word_id in enumerate(word_ids):
|
|
|
if word_id is not None:
|
|
|
label_id = predictions[idx]
|
|
|
if label_id != -100:
|
|
|
if word_id not in word_level_preds:
|
|
|
word_level_preds[word_id] = id2label[label_id]
|
|
|
|
|
|
entities = {}
|
|
|
for word_idx, label in word_level_preds.items():
|
|
|
if label == 'O': continue
|
|
|
entity_type = label[2:]
|
|
|
word = words[word_idx]
|
|
|
|
|
|
if label.startswith('B-'):
|
|
|
entities[entity_type] = {"text": word, "bbox": [unnormalized_boxes[word_idx]]}
|
|
|
elif label.startswith('I-') and entity_type in entities:
|
|
|
entities[entity_type]['text'] += " " + word
|
|
|
entities[entity_type]['bbox'].append(unnormalized_boxes[word_idx])
|
|
|
|
|
|
for entity in entities.values():
|
|
|
entity['text'] = entity['text'].strip()
|
|
|
|
|
|
return entities
|
|
|
|
|
|
def _group_words_into_lines(words, boxes):
    """Reconstruct reading-order text lines from word boxes.

    A new line starts when a word's top edge jumps by more than half the
    taller of the two word heights.

    Args:
        words: Word strings in OCR order.
        boxes: Parallel ``[x, y, width, height]`` pixel boxes.

    Returns:
        Raw text: words joined with spaces, lines separated by newlines.
    """
    if not boxes:
        return ""

    lines = []
    current_line = []
    current_y = boxes[0][1]
    current_h = boxes[0][3]

    for i, word in enumerate(words):
        y = boxes[i][1]
        h = boxes[i][3]
        # Vertical jump larger than half a word height => new visual line.
        if abs(y - current_y) > max(current_h, h) / 2:
            lines.append(" ".join(current_line))
            current_line = []
            current_y = y
            current_h = h
        current_line.append(word)

    if current_line:
        lines.append(" ".join(current_line))

    return "\n".join(lines)


def _empty_result():
    """Result skeleton returned when OCR finds no words."""
    return {
        "vendor": None,
        "date": None,
        "address": None,
        "receipt_number": None,
        "bill_to": None,
        "total_amount": None,
        "items": [],
        "raw_text": "",
        "raw_predictions": {}
    }


def _fallback_vendor(words, boxes, page_height):
    """Heuristic vendor guess: the tallest word in the top 20% of the page."""
    top_indices = [i for i, box in enumerate(boxes) if box[1] < page_height * 0.2]
    if not top_indices:
        return None
    tallest = max(top_indices, key=lambda i: boxes[i][3])
    return words[tallest]


def _locate_address_boxes(address_text, words, boxes):
    """Best-effort mapping of a regex-extracted address back onto OCR boxes.

    For every word of every comma-separated address part, the first OCR word
    containing it (case-insensitive) contributes its box.  Duplicate boxes
    are possible by design.
    """
    found = []
    for part in (p.strip() for p in address_text.split(",")):
        for target in part.split():
            t_lower = target.lower()
            for i, word in enumerate(words):
                # Substring test subsumes the equality test.
                if t_lower in word.lower():
                    found.append(boxes[i])
                    break
    return found


def _locate_invoice_box(target_val, words, boxes):
    """Find the pixel box of the OCR word matching an invoice number, or None."""
    for i, word in enumerate(words):
        # Exact hit, or substring hit for reasonably long values (avoids
        # matching short numbers inside unrelated tokens).
        if target_val == word or (len(target_val) > 3 and target_val in word):
            return boxes[i]
    return None


def _parse_amount(text):
    """Convert an ML-extracted money string to float; None when unparsable."""
    try:
        cleaned = re.sub(r'[^\d.,]', '', text).replace(',', '.')
        return float(cleaned)
    except (ValueError, TypeError):
        return None


def extract_ml_based(image_path: str) -> Dict[str, Any]:
    """Run the full ML extraction pipeline on one invoice image.

    Pipeline: DocTR OCR -> LayoutLMv3 token classification -> regex/heuristic
    fallbacks for fields the model missed -> table-item extraction.

    Args:
        image_path: Path to the invoice image file.

    Returns:
        Dict with keys ``vendor``, ``date``, ``address``, ``receipt_number``,
        ``bill_to``, ``total_amount``, ``items``, ``raw_text`` and
        ``raw_predictions`` (entity type -> text + pixel boxes).

    Raises:
        RuntimeError: If the module-level model/processor failed to load.
    """
    if not MODEL or not PROCESSOR:
        raise RuntimeError("ML model is not loaded.")

    image = Image.open(image_path).convert("RGB")
    width, height = image.size

    # OCR with DocTR, then flatten to parallel word/box lists.
    doc = DocumentFile.from_images(image_path)
    doctr_result = DOCTR_PREDICTOR(doc)
    words, unnormalized_boxes, normalized_boxes = parse_doctr_output(
        doctr_result, width, height
    )

    raw_text = _group_words_into_lines(words, unnormalized_boxes)

    if not words:
        return _empty_result()

    # Token classification with LayoutLMv3.
    encoding = PROCESSOR(
        image, text=words, boxes=normalized_boxes,
        truncation=True, max_length=512, return_tensors="pt"
    )
    model_inputs = {k: v.to(DEVICE) for k, v in encoding.items()}
    with torch.no_grad():
        outputs = MODEL(**model_inputs)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    # squeeze() collapses a length-1 sequence to a scalar int; renormalize.
    if isinstance(predictions, int):
        predictions = [predictions]
    extracted_entities = _process_predictions(
        words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label
    )

    final_output = {
        "vendor": extracted_entities.get("COMPANY", {}).get("text"),
        "date": extracted_entities.get("DATE", {}).get("text"),
        "address": extracted_entities.get("ADDRESS", {}).get("text"),
        "receipt_number": extracted_entities.get("INVOICE_NO", {}).get("text"),
        "bill_to": extracted_entities.get("BILL_TO", {}).get("text"),
        "total_amount": None,
        "items": [],
        "raw_text": raw_text,
        "raw_predictions": extracted_entities
    }

    # --- Heuristic fallbacks for fields the model missed -------------------

    if not final_output["vendor"] and unnormalized_boxes:
        vendor_guess = _fallback_vendor(words, unnormalized_boxes, height)
        if vendor_guess:
            final_output["vendor"] = vendor_guess

    if not final_output["address"]:
        fallback_address = extract_address(raw_text, vendor_name=final_output["vendor"])
        if fallback_address:
            final_output["address"] = fallback_address

    # Attach boxes so a regex-found address can be highlighted in the UI.
    if final_output["address"] and "ADDRESS" not in final_output["raw_predictions"]:
        address_boxes = _locate_address_boxes(final_output["address"], words, unnormalized_boxes)
        if address_boxes:
            final_output["raw_predictions"]["ADDRESS"] = {
                "text": final_output["address"],
                "bbox": address_boxes
            }

    # Prefer the ML total; fall back to the regex extractor when unparsable.
    ml_total = extracted_entities.get("TOTAL", {}).get("text")
    if ml_total:
        final_output["total_amount"] = _parse_amount(ml_total)
    if final_output["total_amount"] is None:
        final_output["total_amount"] = extract_total(raw_text)

    if not final_output["receipt_number"]:
        final_output["receipt_number"] = extract_invoice_number(raw_text)

    # Attach a box for a regex-found invoice number as well.
    if final_output["receipt_number"] and "INVOICE_NO" not in final_output["raw_predictions"]:
        target_val = final_output["receipt_number"].strip()
        found_box = _locate_invoice_box(target_val, words, unnormalized_boxes)
        if found_box:
            final_output["raw_predictions"]["INVOICE_NO"] = {
                "text": target_val,
                "bbox": [found_box]
            }

    # Line-item extraction from the detected word layout.
    if words and unnormalized_boxes:
        extracted_items = extract_table_items(words, unnormalized_boxes)
        if extracted_items:
            final_output["items"] = extracted_items

    return final_output