# NOTE(review): removed git-blame extraction artifacts (file-size header,
# commit-hash column, line-number column) that made this file unparseable.
# src/ml_extraction.py
import os
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from huggingface_hub import snapshot_download
from PIL import Image
from typing import List, Dict, Any, Tuple
import re
import numpy as np
from src.extraction import extract_invoice_number, extract_total, extract_address
from src.table_extraction import extract_table_items
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
# --- CONFIGURATION ---
# Local cache directory where the fine-tuned LayoutLMv3 weights are stored.
LOCAL_MODEL_PATH = "./models/layoutlmv3-doctr-trained"
# Hugging Face Hub repo id used both for the initial download and as a
# fallback load target if the local copy is unusable.
HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-doctr-invoice-processor"
# --- Load LayoutLMv3 Model ---
def load_model_and_processor(model_path, hub_id):
    """Load the LayoutLMv3 processor and fine-tuned token-classification model.

    The processor always comes from the base checkpoint with apply_ocr=False,
    because OCR is performed separately by DocTR. The model is loaded from a
    local directory, downloading a Hub snapshot first when the directory is
    missing or empty; if the local load still fails, we fall back to loading
    directly from the Hub.

    Args:
        model_path: Local directory holding (or to receive) the model files.
        hub_id: Hugging Face Hub repo id for download and fallback loading.

    Returns:
        (model, processor) tuple ready for inference.
    """
    print("Loading processor from microsoft/layoutlmv3-base...")
    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
    # Download only when the local copy is absent or empty.
    if not os.path.exists(model_path) or not os.listdir(model_path):
        print(f"Downloading model from Hub: {hub_id}...")
        # NOTE: `local_dir_use_symlinks` is deprecated and ignored by recent
        # huggingface_hub releases, so it has been dropped here.
        snapshot_download(repo_id=hub_id, local_dir=model_path)
    try:
        model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
    except Exception as exc:
        # Surface the actual local-load failure instead of swallowing it
        # silently, then retry straight from the Hub.
        print(f"Local load failed ({exc}). Fallback: Loading directly from Hub {hub_id}...")
        model = LayoutLMv3ForTokenClassification.from_pretrained(hub_id)
    return model, processor
# --- Load DocTR OCR Predictor ---
def load_doctr_predictor():
    """Build the DocTR OCR predictor and place it on the GPU when available."""
    print("Loading DocTR OCR predictor...")
    ocr = ocr_predictor(
        det_arch='db_resnet50',
        reco_arch='crnn_vgg16_bn',
        pretrained=True
    )
    # Prefer CUDA for OCR throughput; otherwise warn and stay on CPU.
    if torch.cuda.is_available():
        print("🚀 Moving DocTR to GPU (CUDA)...")
        ocr.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (slow).")
    print("DocTR OCR predictor is ready.")
    return ocr
# Module-level singletons created once at import time and shared by every
# call to extract_ml_based().
MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)
DOCTR_PREDICTOR = load_doctr_predictor()
if MODEL and PROCESSOR:
    # Move the classifier to GPU when available and switch to eval mode
    # (disables dropout) since this module only ever runs inference.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(DEVICE)
    MODEL.eval()
    print(f"ML Model is ready on device: {DEVICE}")
else:
    # Defensive branch: load_model_and_processor raises on failure rather
    # than returning None, so this is unlikely to be reached in practice.
    DEVICE = None
    print("❌ Could not load ML model.")
def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]], List[List[int]]]:
    """
    Flatten DocTR's hierarchical result (Page -> Block -> Line -> Word) into
    parallel flat lists for LayoutLMv3 and the visualization layer.

    DocTR reports word geometry in relative 0-1.0 coordinates; two box
    representations are produced from it:
      - unnormalized boxes: [x, y, width, height] in pixels (for overlays)
      - normalized boxes:   [x0, y0, x1, y1] on a 0-1000 grid (for LayoutLMv3)

    Args:
        doctr_result: Output from the DocTR predictor.
        img_width: Original image width in pixels.
        img_height: Original image height in pixels.

    Returns:
        (words, unnormalized_boxes, normalized_boxes) as parallel lists.
    """
    words: List[str] = []
    pixel_boxes: List[List[int]] = []
    scaled_boxes: List[List[int]] = []

    def clamp(value: float) -> int:
        # Keep 0-1000 coordinates strictly inside LayoutLMv3's valid range.
        return max(0, min(1000, int(value * 1000)))

    for page in doctr_result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    # Ignore whitespace-only tokens.
                    if not word.value.strip():
                        continue
                    # DocTR geometry: ((x_min, y_min), (x_max, y_max)), 0-1 scale.
                    (x_min, y_min), (x_max, y_max) = word.geometry
                    left = int(x_min * img_width)
                    top = int(y_min * img_height)
                    right = int(x_max * img_width)
                    bottom = int(y_max * img_height)
                    words.append(word.value)
                    # Pixel-space [x, y, w, h] for drawing overlays.
                    pixel_boxes.append([left, top, right - left, bottom - top])
                    # 0-1000 [x0, y0, x1, y1] for the model input.
                    scaled_boxes.append(
                        [clamp(x_min), clamp(y_min), clamp(x_max), clamp(y_max)]
                    )
    return words, pixel_boxes, scaled_boxes
def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
word_ids = encoding.word_ids(batch_index=0)
word_level_preds = {}
for idx, word_id in enumerate(word_ids):
if word_id is not None:
label_id = predictions[idx]
if label_id != -100:
if word_id not in word_level_preds:
word_level_preds[word_id] = id2label[label_id]
entities = {}
for word_idx, label in word_level_preds.items():
if label == 'O': continue
entity_type = label[2:]
word = words[word_idx]
if label.startswith('B-'):
entities[entity_type] = {"text": word, "bbox": [unnormalized_boxes[word_idx]]}
elif label.startswith('I-') and entity_type in entities:
entities[entity_type]['text'] += " " + word
entities[entity_type]['bbox'].append(unnormalized_boxes[word_idx])
for entity in entities.values():
entity['text'] = entity['text'].strip()
return entities
def extract_ml_based(image_path: str) -> Dict[str, Any]:
    """Run the full invoice-extraction pipeline on a single image.

    Pipeline: DocTR OCR -> LayoutLMv3 token classification -> spatial/regex
    fallbacks for fields the model missed -> geometric table extraction.

    Args:
        image_path: Path to the invoice/receipt image file.

    Returns:
        Dict with keys: vendor, date, address, receipt_number, bill_to,
        total_amount, items, raw_text, and raw_predictions (per-entity text
        plus pixel bounding boxes for UI overlays).

    Raises:
        RuntimeError: If the module-level model/processor failed to load.
    """
    if not MODEL or not PROCESSOR:
        raise RuntimeError("ML model is not loaded.")
    # 1. Load Image
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    # 2. Run DocTR OCR
    doc = DocumentFile.from_images(image_path)
    doctr_result = DOCTR_PREDICTOR(doc)
    # 3. Parse DocTR output to get words and boxes
    words, unnormalized_boxes, normalized_boxes = parse_doctr_output(
        doctr_result, width, height
    )
    # Reconstructs lines so regex can work line-by-line.
    # Words are grouped by vertical position: a new line starts when the
    # Y-gap to the current line's anchor word exceeds half the taller height.
    lines = []
    current_line = []
    if len(unnormalized_boxes) > 0:
        # Initialize with first word's Y and Height
        current_y = unnormalized_boxes[0][1]
        current_h = unnormalized_boxes[0][3]
        for i, word in enumerate(words):
            y = unnormalized_boxes[i][1]
            h = unnormalized_boxes[i][3]
            # If vertical gap > 50% of line height, it's a new line
            if abs(y - current_y) > max(current_h, h) / 2:
                lines.append(" ".join(current_line))
                current_line = []
                current_y = y
                current_h = h
            current_line.append(word)
        # Append the last line
        if current_line:
            lines.append(" ".join(current_line))
    raw_text = "\n".join(lines)
    # Handle empty OCR result
    if not words:
        return {
            "vendor": None,
            "date": None,
            "address": None,
            "receipt_number": None,
            "bill_to": None,
            "total_amount": None,
            "items": [],
            "raw_text": "",
            "raw_predictions": {}
        }
    # 4. Inference with LayoutLMv3
    encoding = PROCESSOR(
        image, text=words, boxes=normalized_boxes,
        truncation=True, max_length=512, return_tensors="pt"
    )
    # Move tensors to device for inference, but keep original encoding for word_ids()
    model_inputs = {k: v.to(DEVICE) for k, v in encoding.items()}
    with torch.no_grad():
        outputs = MODEL(**model_inputs)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)
    # 5. Construct Output
    final_output = {
        "vendor": extracted_entities.get("COMPANY", {}).get("text"),
        "date": extracted_entities.get("DATE", {}).get("text"),
        "address": extracted_entities.get("ADDRESS", {}).get("text"),
        "receipt_number": extracted_entities.get("INVOICE_NO", {}).get("text"),
        "bill_to": extracted_entities.get("BILL_TO", {}).get("text"),
        "total_amount": None,
        "items": [],
        "raw_text": raw_text,
        "raw_predictions": extracted_entities  # Contains text and bbox data for each entity
    }
    # 6. Vendor Fallback (Spatial Heuristic)
    # If ML failed to find a vendor, assume the largest text at the top is the vendor
    if not final_output["vendor"] and unnormalized_boxes:
        # Filter for words in the top 20% of the image
        top_words_indices = [
            i for i, box in enumerate(unnormalized_boxes)
            if box[1] < height * 0.2
        ]
        if top_words_indices:
            # Find the word with the largest height (font size)
            largest_idx = max(top_words_indices, key=lambda i: unnormalized_boxes[i][3])
            final_output["vendor"] = words[largest_idx]
    # --- ADDRESS FALLBACK ---
    if not final_output["address"]:
        # We pass the extracted (or fallback) Vendor Name to help anchor the search
        # Use the raw text and the known vendor to find the address spatially
        fallback_address = extract_address(raw_text, vendor_name=final_output["vendor"])
        if fallback_address:
            final_output["address"] = fallback_address
    # Backfill Bounding Boxes for Address Fallback
    # If Regex found the address but ML didn't, find its boxes in the OCR data
    if final_output["address"] and "ADDRESS" not in final_output["raw_predictions"]:
        address_text = final_output["address"]
        address_boxes = []
        # The address may span multiple words, so we search for each word
        # Split by comma first (since extract_address joins lines with ", ")
        address_parts = [part.strip() for part in address_text.split(",")]
        for part in address_parts:
            part_words = part.split()
            for target_word in part_words:
                for i, word in enumerate(words):
                    # Case-insensitive match
                    # NOTE(review): repeated target words can re-match the same
                    # OCR word index; acceptable for overlay drawing.
                    if target_word.lower() == word.lower() or target_word.lower() in word.lower():
                        address_boxes.append(unnormalized_boxes[i])
                        break  # Only match once per target word
        # If we found any boxes, inject into raw_predictions
        if address_boxes:
            final_output["raw_predictions"]["ADDRESS"] = {
                "text": address_text,
                "bbox": address_boxes
            }
    # Fallbacks
    # Prefer the ML-extracted total; strip currency symbols/letters and
    # treat ',' as a decimal separator before parsing.
    ml_total = extracted_entities.get("TOTAL", {}).get("text")
    if ml_total:
        try:
            cleaned = re.sub(r'[^\d.,]', '', ml_total).replace(',', '.')
            final_output["total_amount"] = float(cleaned)
        except (ValueError, TypeError):
            # Unparseable ML total -> fall through to the regex fallback below.
            pass
    if final_output["total_amount"] is None:
        final_output["total_amount"] = extract_total(raw_text)
    if not final_output["receipt_number"]:
        final_output["receipt_number"] = extract_invoice_number(raw_text)
    # Backfill Bounding Boxes for Regex Results
    # If Regex found the number but ML didn't, we must find its box
    # in the OCR data so the UI can draw it.
    if final_output["receipt_number"] and "INVOICE_NO" not in final_output["raw_predictions"]:
        target_val = final_output["receipt_number"].strip()
        found_box = None
        # 1. Try finding the exact word in the OCR list
        # 'words' and 'unnormalized_boxes' are available from step 3
        for i, word in enumerate(words):
            # Check for exact match or if the word contains the target (e.g. "Inv#123")
            if target_val == word or (len(target_val) > 3 and target_val in word):
                found_box = unnormalized_boxes[i]
                break
        # 2. If found, inject it into raw_predictions
        if found_box:
            # The UI expects a list of boxes
            final_output["raw_predictions"]["INVOICE_NO"] = {
                "text": target_val,
                "bbox": [found_box]
            }
    # --- TABLE EXTRACTION (Geometric Heuristic) ---
    # Use the geometric fallback to extract line items from table region
    if words and unnormalized_boxes:
        extracted_items = extract_table_items(words, unnormalized_boxes)
        if extracted_items:
            final_output["items"] = extracted_items
    return final_output