Raghu committed on
Commit
eb79113
·
1 Parent(s): 6fad358

Re-enable LayoutLMv3 field extractor with cached weights

Browse files
Files changed (1) hide show
  1. app.py +141 -7
app.py CHANGED
@@ -15,7 +15,12 @@ import re
15
  from PIL import Image, ImageDraw
16
  from datetime import datetime
17
  from torchvision import transforms, models
18
- from transformers import ViTForImageClassification, ViTImageProcessor
 
 
 
 
 
19
  from sklearn.ensemble import IsolationForest
20
  import warnings
21
  warnings.filterwarnings('ignore')
@@ -421,6 +426,126 @@ class ReceiptOCR:
421
  return match.group() if match else None
422
 
423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  # ============================================================================
425
  # Anomaly Detection
426
  # ============================================================================
@@ -512,6 +637,13 @@ except Exception as e:
512
  print(f"Warning: Could not load OCR: {e}")
513
  receipt_ocr = None
514
 
 
 
 
 
 
 
 
515
  anomaly_detector = AnomalyDetector()
516
 
517
  print("\n" + "="*50)
@@ -608,19 +740,21 @@ def process_receipt(image):
608
  fields = {}
609
  fields_html = ""
610
  try:
611
- if receipt_ocr and ocr_results:
 
 
612
  fields = receipt_ocr.postprocess_receipt(ocr_results)
613
 
614
  fields_html = "<div style='padding: 16px; background: #f8f9fa; border-radius: 12px;'><h4>Extracted Fields</h4>"
615
  for name, value in [('Vendor', fields.get('vendor')), ('Date', fields.get('date')),
616
- ('Total', f"${fields.get('total')}" if fields.get('total') else None),
617
  ('Time', fields.get('time'))]:
618
- display = value or '<span style="color: #adb5bd;">Not found</span>'
619
- fields_html += f"<div style='padding: 8px; background: white; border-radius: 6px; margin: 4px 0;'><b>{name}:</b> {display}</div>"
620
- fields_html += "</div>"
621
  results['fields'] = fields
622
  except Exception as e:
623
- fields_html = f"<div style='color: red;'>Extraction error: {e}</div>"
624
 
625
  # 4. Anomaly Detection
626
  anomaly_html = ""
 
15
  from PIL import Image, ImageDraw
16
  from datetime import datetime
17
  from torchvision import transforms, models
18
+ from transformers import (
19
+ ViTForImageClassification,
20
+ ViTImageProcessor,
21
+ LayoutLMv3ForTokenClassification,
22
+ LayoutLMv3Processor,
23
+ )
24
  from sklearn.ensemble import IsolationForest
25
  import warnings
26
  warnings.filterwarnings('ignore')
 
426
  return match.group() if match else None
427
 
428
 
429
+ # ============================================================================
430
+ # LayoutLMv3 Field Extractor
431
+ # ============================================================================
432
+
433
+ class LayoutLMFieldExtractor:
434
+ """LayoutLMv3-based field extractor using fine-tuned weights if available."""
435
+
436
+ def __init__(self, model_path=None):
437
+ self.model_path = model_path or os.path.join(MODELS_DIR, 'layoutlm_extractor.pt')
438
+ self.id2label = {
439
+ 0: 'O',
440
+ 1: 'B-VENDOR', 2: 'I-VENDOR',
441
+ 3: 'B-DATE', 4: 'I-DATE',
442
+ 5: 'B-TOTAL', 6: 'I-TOTAL',
443
+ 7: 'B-TIME', 8: 'I-TIME'
444
+ }
445
+ self.label2id = {v: k for k, v in self.id2label.items()}
446
+ self.processor = None
447
+ self.model = None
448
+
449
+ def load(self):
450
+ print("Loading LayoutLMv3 extractor...")
451
+ self.processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
452
+ self.model = LayoutLMv3ForTokenClassification.from_pretrained(
453
+ "microsoft/layoutlmv3-base",
454
+ num_labels=len(self.id2label),
455
+ id2label=self.id2label,
456
+ label2id=self.label2id,
457
+ )
458
+ if os.path.exists(self.model_path):
459
+ checkpoint = torch.load(self.model_path, map_location=DEVICE)
460
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
461
+ checkpoint = checkpoint['model_state_dict']
462
+ if isinstance(checkpoint, dict):
463
+ missing, unexpected = self.model.load_state_dict(checkpoint, strict=False)
464
+ print(f"Loaded LayoutLM weights; missing={len(missing)}, unexpected={len(unexpected)}")
465
+ self.model = self.model.to(DEVICE)
466
+ self.model.eval()
467
+ print("LayoutLMv3 ready")
468
+ return self
469
+
470
+ def _prepare_boxes(self, ocr_results, image_size):
471
+ """Convert absolute pixel boxes to LayoutLM 0-1000 format."""
472
+ width, height = image_size
473
+ boxes = []
474
+ words = []
475
+ for r in ocr_results:
476
+ bbox = r.get("bbox", [0, 0, width, height])
477
+ x0, y0, x1, y1 = bbox
478
+ boxes.append([
479
+ int(1000 * x0 / width),
480
+ int(1000 * y0 / height),
481
+ int(1000 * x1 / width),
482
+ int(1000 * y1 / height),
483
+ ])
484
+ words.append(r.get("text", ""))
485
+ return words, boxes
486
+
487
+ def predict_fields(self, image, ocr_results=None):
488
+ if self.model is None:
489
+ self.load()
490
+
491
+ if not isinstance(image, Image.Image):
492
+ image = Image.fromarray(image)
493
+ image = image.convert("RGB")
494
+
495
+ if ocr_results:
496
+ words, boxes = self._prepare_boxes(ocr_results, image.size)
497
+ encoding = self.processor(
498
+ image,
499
+ words=words,
500
+ boxes=boxes,
501
+ return_tensors="pt",
502
+ truncation=True,
503
+ padding="max_length",
504
+ max_length=512,
505
+ )
506
+ else:
507
+ encoding = self.processor(image, return_tensors="pt")
508
+
509
+ encoding = {k: v.to(DEVICE) for k, v in encoding.items()}
510
+ with torch.no_grad():
511
+ outputs = self.model(**encoding)
512
+ logits = outputs.logits[0]
513
+ preds = logits.argmax(-1).cpu().tolist()
514
+ tokens = self.processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu())
515
+
516
+ entities = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
517
+ current = {"label": None, "tokens": []}
518
+
519
+ for token, pred in zip(tokens, preds):
520
+ label = self.id2label.get(pred, "O")
521
+ if token in ["[PAD]", "[CLS]", "[SEP]"]:
522
+ continue
523
+ if label.startswith("B-"):
524
+ # flush previous
525
+ if current["label"] and current["tokens"]:
526
+ entities[current["label"]].append(" ".join(current["tokens"]))
527
+ current = {"label": label[2:], "tokens": [token]}
528
+ elif label.startswith("I-") and current["label"] == label[2:]:
529
+ current["tokens"].append(token)
530
+ else:
531
+ if current["label"] and current["tokens"]:
532
+ entities[current["label"]].append(" ".join(current["tokens"]))
533
+ current = {"label": None, "tokens": []}
534
+ if current["label"] and current["tokens"]:
535
+ entities[current["label"]].append(" ".join(current["tokens"]))
536
+
537
+ def pick_first(key):
538
+ vals = entities.get(key, [])
539
+ return vals[0].replace("▁", " ").strip() if vals else None
540
+
541
+ return {
542
+ "vendor": pick_first("VENDOR"),
543
+ "date": pick_first("DATE"),
544
+ "total": pick_first("TOTAL"),
545
+ "time": pick_first("TIME"),
546
+ }
547
+
548
+
549
  # ============================================================================
550
  # Anomaly Detection
551
  # ============================================================================
 
637
  print(f"Warning: Could not load OCR: {e}")
638
  receipt_ocr = None
639
 
640
+ try:
641
+ layoutlm_extractor = LayoutLMFieldExtractor()
642
+ layoutlm_extractor.load()
643
+ except Exception as e:
644
+ print(f"Warning: Could not load LayoutLMv3 extractor: {e}")
645
+ layoutlm_extractor = None
646
+
647
  anomaly_detector = AnomalyDetector()
648
 
649
  print("\n" + "="*50)
 
740
  fields = {}
741
  fields_html = ""
742
  try:
743
+ if layoutlm_extractor:
744
+ fields = layoutlm_extractor.predict_fields(image, ocr_results)
745
+ elif receipt_ocr and ocr_results:
746
  fields = receipt_ocr.postprocess_receipt(ocr_results)
747
 
748
  fields_html = "<div style='padding: 16px; background: #f8f9fa; border-radius: 12px;'><h4>Extracted Fields</h4>"
749
  for name, value in [('Vendor', fields.get('vendor')), ('Date', fields.get('date')),
750
+ ('Total', f"${fields.get('total')}" if fields.get('total') else None),
751
  ('Time', fields.get('time'))]:
752
+ display = value or '<span style="color: #adb5bd;">Not found</span>'
753
+ fields_html += f"<div style='padding: 8px; background: white; border-radius: 6px; margin: 4px 0;'><b>{name}:</b> {display}</div>"
754
+ fields_html += "</div>"
755
  results['fields'] = fields
756
  except Exception as e:
757
+ fields_html = f"<div style='color: red;'>Extraction error: {e}</div>"
758
 
759
  # 4. Anomaly Detection
760
  anomaly_html = ""