File size: 13,584 Bytes
d79b7f7
 
7630bcd
d79b7f7
 
7630bcd
d79b7f7
ec0b507
d79b7f7
 
2a944a5
90dbe20
ec0b507
 
d79b7f7
 
ec0b507
 
d79b7f7
ec0b507
d79b7f7
7630bcd
 
 
 
 
d79b7f7
7630bcd
 
d79b7f7
7630bcd
 
 
 
d79b7f7
 
ec0b507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d79b7f7
ec0b507
d79b7f7
 
 
 
 
 
 
 
 
 
ec0b507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d79b7f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec0b507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d79b7f7
ec0b507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d79b7f7
ec0b507
d79b7f7
 
 
2a944a5
 
 
d79b7f7
 
2a944a5
d79b7f7
 
 
 
ec0b507
d79b7f7
 
 
 
 
 
 
 
343b0c3
 
d79b7f7
 
ec0b507
 
 
 
 
 
 
 
 
 
 
 
 
 
097a95c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d79b7f7
 
 
 
 
 
 
 
 
 
 
 
 
 
ec0b507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d79b7f7
90dbe20
 
 
 
 
 
 
42e1c04
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
# src/ml_extraction.py

import os
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from huggingface_hub import snapshot_download
from PIL import Image
from typing import List, Dict, Any, Tuple
import re
import numpy as np
from src.extraction import extract_invoice_number, extract_total, extract_address
from src.table_extraction import extract_table_items
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# --- CONFIGURATION ---
# Local cache directory for the fine-tuned LayoutLMv3 checkpoint.
LOCAL_MODEL_PATH = "./models/layoutlmv3-doctr-trained"
# Hugging Face Hub repo id used to download the checkpoint when the local
# cache is missing or empty (see load_model_and_processor).
HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-doctr-invoice-processor" 

# --- Load LayoutLMv3 Model ---
def load_model_and_processor(model_path, hub_id):
    """Load the fine-tuned LayoutLMv3 token-classification model and processor.

    The processor always comes from the base checkpoint with built-in OCR
    disabled, because words and boxes are supplied externally (DocTR).
    Model weights are loaded from `model_path`, downloading a snapshot from
    the Hub first if the local directory is missing or empty.

    Args:
        model_path: Local directory holding (or receiving) the checkpoint.
        hub_id: Hugging Face Hub repo id used for download and as fallback.

    Returns:
        Tuple of (model, processor).
    """
    print("Loading processor from microsoft/layoutlmv3-base...")
    # apply_ocr=False: we feed pre-computed DocTR words + boxes ourselves.
    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

    if not os.path.exists(model_path) or not os.listdir(model_path):
        print(f"Downloading model from Hub: {hub_id}...")
        # NOTE: local_dir_use_symlinks is deprecated (and ignored) in recent
        # huggingface_hub versions; kept for compatibility with older installs.
        snapshot_download(repo_id=hub_id, local_dir=model_path, local_dir_use_symlinks=False)

    try:
        model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
    except Exception as err:
        # Surface why the local load failed instead of swallowing it silently,
        # then retry straight from the Hub.
        print(f"Local load failed ({err!r}). Fallback: Loading directly from Hub {hub_id}...")
        model = LayoutLMv3ForTokenClassification.from_pretrained(hub_id)

    return model, processor

# --- Load DocTR OCR Predictor ---
def load_doctr_predictor():
    """Build the DocTR OCR predictor, moving it to CUDA when available."""
    print("Loading DocTR OCR predictor...")
    # db_resnet50 handles text detection, crnn_vgg16_bn handles recognition.
    predictor = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)

    has_gpu = torch.cuda.is_available()
    if has_gpu:
        print("🚀 Moving DocTR to GPU (CUDA)...")
        predictor.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (slow).")
        
    print("DocTR OCR predictor is ready.")
    return predictor

# Load the heavyweight models once at import time so repeated calls to
# extract_ml_based() reuse them.
# NOTE(review): this triggers downloads / model loading as a module import
# side effect — consider lazy initialization if import cost becomes an issue.
MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)
DOCTR_PREDICTOR = load_doctr_predictor()

# Place the LayoutLMv3 model on GPU when available and switch to eval mode.
if MODEL and PROCESSOR:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(DEVICE)
    MODEL.eval()
    print(f"ML Model is ready on device: {DEVICE}")
else:
    # Defensive branch: load_model_and_processor has no code path that
    # returns None, so this should not normally fire.
    DEVICE = None
    print("❌ Could not load ML model.")


def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]], List[List[int]]]:
    """Flatten DocTR's hierarchical output into word/box lists for LayoutLMv3.

    DocTR reports word geometry in relative 0-1.0 coordinates as
    ((x_min, y_min), (x_max, y_max)). Two box formats are produced:

    - unnormalized_boxes: pixel coordinates [x, y, width, height] for
      visualization overlays.
    - normalized_boxes: [x0, y0, x1, y1] on the 0-1000 scale LayoutLMv3
      expects, clamped to stay within [0, 1000].

    Whitespace-only words are skipped.

    Args:
        doctr_result: Output from the DocTR predictor
            (Document -> Page -> Block -> Line -> Word hierarchy).
        img_width: Original image width in pixels.
        img_height: Original image height in pixels.

    Returns:
        (words, unnormalized_boxes, normalized_boxes) — three parallel lists.
    """
    words = []
    unnormalized_boxes = []
    normalized_boxes = []
    
    # DocTR hierarchy: Document -> Page -> Block -> Line -> Word
    for page in doctr_result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    # Skip empty words
                    if not word.value.strip():
                        continue
                    
                    words.append(word.value)
                    
                    # DocTR bbox format: ((x_min, y_min), (x_max, y_max)) in 0-1 scale
                    (x_min, y_min), (x_max, y_max) = word.geometry
                    
                    # Convert to pixel coordinates for visualization
                    px_x0 = int(x_min * img_width)
                    px_y0 = int(y_min * img_height)
                    px_x1 = int(x_max * img_width)
                    px_y1 = int(y_max * img_height)
                    
                    # Unnormalized box: [x, y, width, height] for visualization overlay
                    unnormalized_boxes.append([
                        px_x0, 
                        px_y0, 
                        px_x1 - px_x0,  # width
                        px_y1 - px_y0   # height
                    ])
                    
                    # Normalized box: [x0, y0, x1, y1] in 0-1000 scale for LayoutLMv3
                    # Clamp values to ensure they stay within [0, 1000]
                    normalized_boxes.append([
                        max(0, min(1000, int(x_min * 1000))),
                        max(0, min(1000, int(y_min * 1000))),
                        max(0, min(1000, int(x_max * 1000))),
                        max(0, min(1000, int(y_max * 1000))),
                    ])
    
    return words, unnormalized_boxes, normalized_boxes


def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
    word_ids = encoding.word_ids(batch_index=0)
    word_level_preds = {} 
    for idx, word_id in enumerate(word_ids):
        if word_id is not None:
            label_id = predictions[idx]
            if label_id != -100:
                if word_id not in word_level_preds:
                    word_level_preds[word_id] = id2label[label_id]

    entities = {}
    for word_idx, label in word_level_preds.items():
        if label == 'O': continue
        entity_type = label[2:] 
        word = words[word_idx]
        
        if label.startswith('B-'):
            entities[entity_type] = {"text": word, "bbox": [unnormalized_boxes[word_idx]]}
        elif label.startswith('I-') and entity_type in entities:
            entities[entity_type]['text'] += " " + word
            entities[entity_type]['bbox'].append(unnormalized_boxes[word_idx])
    
    for entity in entities.values():
        entity['text'] = entity['text'].strip()

    return entities

def extract_ml_based(image_path: str) -> Dict[str, Any]:
    """Run the full ML extraction pipeline on a single invoice image.

    Pipeline: DocTR OCR -> LayoutLMv3 token classification -> regex/spatial
    fallbacks for fields the model missed -> geometric table extraction.

    Args:
        image_path: Path to the invoice image file on disk.

    Returns:
        Dict with keys vendor, date, address, receipt_number, bill_to,
        total_amount, items, raw_text and raw_predictions (per-entity text
        plus pixel bounding boxes for UI overlays).

    Raises:
        RuntimeError: If the module-level MODEL/PROCESSOR failed to load.
    """
    if not MODEL or not PROCESSOR:
        raise RuntimeError("ML model is not loaded.")

    # 1. Load Image
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    
    # 2. Run DocTR OCR
    doc = DocumentFile.from_images(image_path)
    doctr_result = DOCTR_PREDICTOR(doc)
    
    # 3. Parse DocTR output to get words and boxes
    words, unnormalized_boxes, normalized_boxes = parse_doctr_output(
        doctr_result, width, height
    )
    
    # Reconstructs lines so regex can work line-by-line
    lines = []
    current_line = []
    if len(unnormalized_boxes) > 0:
        # Initialize with first word's Y and Height
        current_y = unnormalized_boxes[0][1]
        current_h = unnormalized_boxes[0][3]
        
        for i, word in enumerate(words):
            y = unnormalized_boxes[i][1]
            h = unnormalized_boxes[i][3]
            
            # If vertical gap > 50% of line height, it's a new line
            if abs(y - current_y) > max(current_h, h) / 2:
                lines.append(" ".join(current_line))
                current_line = []
                current_y = y
                current_h = h
            
            current_line.append(word)
            
        # Append the last line
        if current_line:
            lines.append(" ".join(current_line))
            
    raw_text = "\n".join(lines)
    
    # Handle empty OCR result
    if not words:
        return {
            "vendor": None,
            "date": None,
            "address": None,
            "receipt_number": None,
            "bill_to": None,
            "total_amount": None,
            "items": [],
            "raw_text": "",
            "raw_predictions": {}
        }

    # 4. Inference with LayoutLMv3
    encoding = PROCESSOR(
        image, text=words, boxes=normalized_boxes, 
        truncation=True, max_length=512, return_tensors="pt"
    )
    # Move tensors to device for inference, but keep original encoding for word_ids()
    model_inputs = {k: v.to(DEVICE) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = MODEL(**model_inputs)

    # NOTE(review): .squeeze().tolist() returns a bare int if only one token
    # survives truncation — confirm _process_predictions tolerates that case.
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)

    # 5. Construct Output
    final_output = {
        "vendor": extracted_entities.get("COMPANY", {}).get("text"),
        "date": extracted_entities.get("DATE", {}).get("text"),
        "address": extracted_entities.get("ADDRESS", {}).get("text"),
        "receipt_number": extracted_entities.get("INVOICE_NO", {}).get("text"),
        "bill_to": extracted_entities.get("BILL_TO", {}).get("text"),
        "total_amount": None, 
        "items": [],
        "raw_text": raw_text,
        "raw_predictions": extracted_entities  # Contains text and bbox data for each entity
    }

    # 6. Vendor Fallback (Spatial Heuristic)
    # If ML failed to find a vendor, assume the largest text at the top is the vendor
    if not final_output["vendor"] and unnormalized_boxes:
        # Filter for words in the top 20% of the image
        top_words_indices = [
            i for i, box in enumerate(unnormalized_boxes) 
            if box[1] < height * 0.2
        ]
        
        if top_words_indices:
            # Find the word with the largest height (font size)
            largest_idx = max(top_words_indices, key=lambda i: unnormalized_boxes[i][3])
            final_output["vendor"] = words[largest_idx]
    
    # --- ADDRESS FALLBACK ---
    if not final_output["address"]:
        # We pass the extracted (or fallback) Vendor Name to help anchor the search
        # Use the raw text and the known vendor to find the address spatially
        fallback_address = extract_address(raw_text, vendor_name=final_output["vendor"])
        
        if fallback_address:
            final_output["address"] = fallback_address
    
    # Backfill Bounding Boxes for Address Fallback
    # If Regex found the address but ML didn't, find its boxes in the OCR data
    if final_output["address"] and "ADDRESS" not in final_output["raw_predictions"]:
        address_text = final_output["address"]
        address_boxes = []
        
        # The address may span multiple words, so we search for each word
        # Split by comma first (since extract_address joins lines with ", ")
        address_parts = [part.strip() for part in address_text.split(",")]
        
        for part in address_parts:
            part_words = part.split()
            for target_word in part_words:                    
                for i, word in enumerate(words):
                    # Case-insensitive match
                    # NOTE(review): matches the FIRST OCR word containing the
                    # target, which may pick a box from elsewhere on the page.
                    if target_word.lower() == word.lower() or target_word.lower() in word.lower():
                        address_boxes.append(unnormalized_boxes[i])
                        break  # Only match once per target word
        
        # If we found any boxes, inject into raw_predictions
        if address_boxes:
            final_output["raw_predictions"]["ADDRESS"] = {
                "text": address_text,
                "bbox": address_boxes
            }

    # Fallbacks
    ml_total = extracted_entities.get("TOTAL", {}).get("text")
    if ml_total:
        try:
            # Strip currency symbols/letters; treat comma as decimal separator.
            cleaned = re.sub(r'[^\d.,]', '', ml_total).replace(',', '.')
            final_output["total_amount"] = float(cleaned)
        except (ValueError, TypeError):
            pass
            
    if final_output["total_amount"] is None:
        final_output["total_amount"] = extract_total(raw_text)

    if not final_output["receipt_number"]:
        final_output["receipt_number"] = extract_invoice_number(raw_text)

    # Backfill Bounding Boxes for Regex Results
    # If Regex found the number but ML didn't, we must find its box 
    # in the OCR data so the UI can draw it.

    if final_output["receipt_number"] and "INVOICE_NO" not in final_output["raw_predictions"]:
        target_val = final_output["receipt_number"].strip()
        found_box = None
        
        # 1. Try finding the exact word in the OCR list
        # 'words' and 'unnormalized_boxes' are available from step 3
        for i, word in enumerate(words):
            # Check for exact match or if the word contains the target (e.g. "Inv#123")
            if target_val == word or (len(target_val) > 3 and target_val in word):
                found_box = unnormalized_boxes[i]
                break
        
        # 2. If found, inject it into raw_predictions
        if found_box:
            # The UI expects a list of boxes
            final_output["raw_predictions"]["INVOICE_NO"] = {
                "text": target_val,
                "bbox": [found_box] 
            }
    
    # --- TABLE EXTRACTION (Geometric Heuristic) ---
    # Use the geometric fallback to extract line items from table region
    if words and unnormalized_boxes:
        extracted_items = extract_table_items(words, unnormalized_boxes)
        if extracted_items:
            final_output["items"] = extracted_items
    
    return final_output