Spaces:
Sleeping
Sleeping
Raghu
committed on
Commit
·
eb79113
1
Parent(s):
6fad358
Re-enable LayoutLMv3 field extractor with cached weights
Browse files
app.py
CHANGED
|
@@ -15,7 +15,12 @@ import re
|
|
| 15 |
from PIL import Image, ImageDraw
|
| 16 |
from datetime import datetime
|
| 17 |
from torchvision import transforms, models
|
| 18 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
from sklearn.ensemble import IsolationForest
|
| 20 |
import warnings
|
| 21 |
warnings.filterwarnings('ignore')
|
|
@@ -421,6 +426,126 @@ class ReceiptOCR:
|
|
| 421 |
return match.group() if match else None
|
| 422 |
|
| 423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
# ============================================================================
|
| 425 |
# Anomaly Detection
|
| 426 |
# ============================================================================
|
|
@@ -512,6 +637,13 @@ except Exception as e:
|
|
| 512 |
print(f"Warning: Could not load OCR: {e}")
|
| 513 |
receipt_ocr = None
|
| 514 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
anomaly_detector = AnomalyDetector()
|
| 516 |
|
| 517 |
print("\n" + "="*50)
|
|
@@ -608,19 +740,21 @@ def process_receipt(image):
|
|
| 608 |
fields = {}
|
| 609 |
fields_html = ""
|
| 610 |
try:
|
| 611 |
-
if
|
|
|
|
|
|
|
| 612 |
fields = receipt_ocr.postprocess_receipt(ocr_results)
|
| 613 |
|
| 614 |
fields_html = "<div style='padding: 16px; background: #f8f9fa; border-radius: 12px;'><h4>Extracted Fields</h4>"
|
| 615 |
for name, value in [('Vendor', fields.get('vendor')), ('Date', fields.get('date')),
|
| 616 |
-
('Total', f"${fields.get('total')}" if fields.get('total') else None),
|
| 617 |
('Time', fields.get('time'))]:
|
| 618 |
-
display = value or '<span style
|
| 619 |
-
fields_html += f"<div style='padding: 8px; background: white; border-radius: 6px; margin: 4px 0;'><b>{name}:</b> {display}</div
|
| 620 |
-
fields_html += "</div
|
| 621 |
results['fields'] = fields
|
| 622 |
except Exception as e:
|
| 623 |
-
fields_html = f"<div style='color: red;'>Extraction error: {e}</div
|
| 624 |
|
| 625 |
# 4. Anomaly Detection
|
| 626 |
anomaly_html = ""
|
|
|
|
| 15 |
from PIL import Image, ImageDraw
|
| 16 |
from datetime import datetime
|
| 17 |
from torchvision import transforms, models
|
| 18 |
+
from transformers import (
|
| 19 |
+
ViTForImageClassification,
|
| 20 |
+
ViTImageProcessor,
|
| 21 |
+
LayoutLMv3ForTokenClassification,
|
| 22 |
+
LayoutLMv3Processor,
|
| 23 |
+
)
|
| 24 |
from sklearn.ensemble import IsolationForest
|
| 25 |
import warnings
|
| 26 |
warnings.filterwarnings('ignore')
|
|
|
|
| 426 |
return match.group() if match else None
|
| 427 |
|
| 428 |
|
| 429 |
+
# ============================================================================
# LayoutLMv3 Field Extractor
# ============================================================================

class LayoutLMFieldExtractor:
    """LayoutLMv3-based field extractor using fine-tuned weights if available.

    Uses the ``microsoft/layoutlmv3-base`` backbone with a BIO token-
    classification head over four receipt fields (vendor, date, total, time).
    If a fine-tuned checkpoint exists at ``model_path`` it is loaded on top;
    otherwise the base weights are used as-is.
    """

    def __init__(self, model_path=None):
        # Path to an optional fine-tuned checkpoint: either a bare state dict
        # or a {'model_state_dict': ...} wrapper from a training loop.
        self.model_path = model_path or os.path.join(MODELS_DIR, 'layoutlm_extractor.pt')
        # BIO tagging scheme for the four fields we extract.
        self.id2label = {
            0: 'O',
            1: 'B-VENDOR', 2: 'I-VENDOR',
            3: 'B-DATE', 4: 'I-DATE',
            5: 'B-TOTAL', 6: 'I-TOTAL',
            7: 'B-TIME', 8: 'I-TIME'
        }
        self.label2id = {v: k for k, v in self.id2label.items()}
        self.processor = None  # created lazily in load()
        self.model = None      # created lazily in load()

    def load(self):
        """Load processor and model, applying cached fine-tuned weights if present.

        Returns:
            self, so callers can chain ``LayoutLMFieldExtractor().load()``.
        """
        print("Loading LayoutLMv3 extractor...")
        # apply_ocr=False: we always supply our own words/boxes from the OCR
        # stage. The default (apply_ocr=True) raises if boxes are passed in.
        self.processor = LayoutLMv3Processor.from_pretrained(
            "microsoft/layoutlmv3-base", apply_ocr=False
        )
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
            "microsoft/layoutlmv3-base",
            num_labels=len(self.id2label),
            id2label=self.id2label,
            label2id=self.label2id,
        )
        if os.path.exists(self.model_path):
            # weights_only=True: the checkpoint is expected to be a plain
            # state dict, so avoid unpickling arbitrary objects.
            checkpoint = torch.load(self.model_path, map_location=DEVICE, weights_only=True)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                checkpoint = checkpoint['model_state_dict']
            if isinstance(checkpoint, dict):
                # strict=False tolerates naming/head-size drift between the
                # training script and this inference code.
                missing, unexpected = self.model.load_state_dict(checkpoint, strict=False)
                print(f"Loaded LayoutLM weights; missing={len(missing)}, unexpected={len(unexpected)}")
        self.model = self.model.to(DEVICE)
        self.model.eval()
        print("LayoutLMv3 ready")
        return self

    def _prepare_boxes(self, ocr_results, image_size):
        """Convert absolute pixel boxes to LayoutLM 0-1000 format.

        ocr_results: list of dicts with 'text' and 'bbox' ([x0, y0, x1, y1]
        in pixels) — TODO confirm against the OCR stage's output schema.
        """
        width, height = image_size
        boxes = []
        words = []
        for r in ocr_results:
            bbox = r.get("bbox", [0, 0, width, height])
            x0, y0, x1, y1 = bbox
            # LayoutLM expects coordinates normalized to a 0-1000 grid.
            boxes.append([
                int(1000 * x0 / width),
                int(1000 * y0 / height),
                int(1000 * x1 / width),
                int(1000 * y1 / height),
            ])
            words.append(r.get("text", ""))
        return words, boxes

    def predict_fields(self, image, ocr_results=None):
        """Run token classification and aggregate BIO tags into field strings.

        Args:
            image: PIL Image or ndarray of the receipt.
            ocr_results: list of {'text', 'bbox'} dicts from the OCR stage.

        Returns:
            dict with keys 'vendor', 'date', 'total', 'time' mapping to the
            first extracted string for each field, or None if not found.
        """
        if self.model is None:
            self.load()

        # With apply_ocr=False the processor cannot generate words itself,
        # so without OCR results there is nothing to classify.
        if not ocr_results:
            return {"vendor": None, "date": None, "total": None, "time": None}

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        words, boxes = self._prepare_boxes(ocr_results, image.size)
        # NOTE: the processor keyword for the word list is `text`, not `words`.
        encoding = self.processor(
            image,
            text=words,
            boxes=boxes,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        )

        encoding = {k: v.to(DEVICE) for k, v in encoding.items()}
        with torch.no_grad():
            outputs = self.model(**encoding)
        logits = outputs.logits[0]
        preds = logits.argmax(-1).cpu().tolist()
        tokenizer = self.processor.tokenizer
        tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu())
        # LayoutLMv3's tokenizer is RoBERTa-based: special tokens are
        # <s>/</s>/<pad>, not BERT's [CLS]/[SEP]/[PAD]. Ask the tokenizer.
        special_tokens = set(tokenizer.all_special_tokens)

        entities = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
        current = {"label": None, "tokens": []}

        def flush():
            # Detokenize through the tokenizer so BPE markers ('Ġ') are
            # handled correctly instead of hand-stripping SentencePiece '▁'.
            if current["label"] and current["tokens"]:
                text = tokenizer.convert_tokens_to_string(current["tokens"]).strip()
                if text:
                    entities[current["label"]].append(text)

        for token, pred in zip(tokens, preds):
            label = self.id2label.get(pred, "O")
            if token in special_tokens:
                continue
            if label.startswith("B-"):
                flush()  # close the previous entity, if any
                current = {"label": label[2:], "tokens": [token]}
            elif label.startswith("I-") and current["label"] == label[2:]:
                current["tokens"].append(token)
            else:
                flush()
                current = {"label": None, "tokens": []}
        flush()  # close a trailing entity at end of sequence

        def pick_first(key):
            vals = entities.get(key, [])
            return vals[0] if vals else None

        return {
            "vendor": pick_first("VENDOR"),
            "date": pick_first("DATE"),
            "total": pick_first("TOTAL"),
            "time": pick_first("TIME"),
        }
|
| 547 |
+
|
| 548 |
+
|
| 549 |
# ============================================================================
|
| 550 |
# Anomaly Detection
|
| 551 |
# ============================================================================
|
|
|
|
| 637 |
print(f"Warning: Could not load OCR: {e}")
|
| 638 |
receipt_ocr = None
|
| 639 |
|
| 640 |
+
# Instantiate the LayoutLMv3 extractor eagerly at startup; fall back to None
# so the downstream pipeline can use the rule-based extractor instead.
layoutlm_extractor = None
try:
    layoutlm_extractor = LayoutLMFieldExtractor().load()
except Exception as e:
    print(f"Warning: Could not load LayoutLMv3 extractor: {e}")
|
| 646 |
+
|
| 647 |
anomaly_detector = AnomalyDetector()
|
| 648 |
|
| 649 |
print("\n" + "="*50)
|
|
|
|
| 740 |
fields = {}
|
| 741 |
fields_html = ""
|
| 742 |
try:
|
| 743 |
+
if layoutlm_extractor:
|
| 744 |
+
fields = layoutlm_extractor.predict_fields(image, ocr_results)
|
| 745 |
+
elif receipt_ocr and ocr_results:
|
| 746 |
fields = receipt_ocr.postprocess_receipt(ocr_results)
|
| 747 |
|
| 748 |
fields_html = "<div style='padding: 16px; background: #f8f9fa; border-radius: 12px;'><h4>Extracted Fields</h4>"
|
| 749 |
for name, value in [('Vendor', fields.get('vendor')), ('Date', fields.get('date')),
|
| 750 |
+
('Total', f"${fields.get('total')}" if fields.get('total') else None),
|
| 751 |
('Time', fields.get('time'))]:
|
| 752 |
+
display = value or '<span style="color: #adb5bd;">Not found</span>'
|
| 753 |
+
fields_html += f"<div style='padding: 8px; background: white; border-radius: 6px; margin: 4px 0;'><b>{name}:</b> {display}</div>"
|
| 754 |
+
fields_html += "</div>"
|
| 755 |
results['fields'] = fields
|
| 756 |
except Exception as e:
|
| 757 |
+
fields_html = f"<div style='color: red;'>Extraction error: {e}</div>"
|
| 758 |
|
| 759 |
# 4. Anomaly Detection
|
| 760 |
anomaly_html = ""
|