Spaces:
Sleeping
Sleeping
File size: 7,086 Bytes
cbb1b1a b861cd9 cbb1b1a b861cd9 cbb1b1a b861cd9 cbb1b1a b861cd9 cbb1b1a b861cd9 cbb1b1a b861cd9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 | """
DocumentViT — Vision Transformer for structured evidence extraction.
Architecture: google/vit-base-patch16-224 (encoder, no classification head).
Patch-embedding statistics (computed with the [CLS] token excluded) are used
as a "document quality score" that weights entity confidences.
Actual entity spans are extracted from OCR text via regex and
then confidence-adjusted by the ViT's image assessment.
Why ViT alongside OCR?
Tesseract excels at clean printed text but degrades on blurry screenshots
and poorly-scanned documents. The ViT path applies OCR independently
with its own quality signal, and the DocumentProcessor takes the
higher-confidence span when both paths find the same field.
Input: PIL Image (any size; auto-resized to 224×224 by the ViT processor).
Output: List of Entity spans — same schema as EvidenceNER (src.ner.model.Entity).
Fine-tuning note:
The current implementation uses the pre-trained ViT encoder as a feature
extractor for confidence scoring. Swap in a fine-tuned checkpoint by
passing the local model directory to __init__(model_name=...).
"""
from __future__ import annotations
import logging
import re
import pytesseract
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTModel
from src.document_processor.ocr import (
_clean_text,
_extract_entities_from_text,
_preprocess_pil,
)
from src.ner.model import Entity
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Checkpoint identifier used by DocumentViT when the caller passes no model_name.
_DEFAULT_MODEL = "google/vit-base-patch16-224"
# ---------------------------------------------------------------------------
# Document quality scoring from ViT patch embeddings
# ---------------------------------------------------------------------------
def _document_score(last_hidden_state: torch.Tensor) -> float:
    """
    Estimate how "document-like" an image is from its ViT patch embeddings.

    Structured documents (receipts, bills, screenshots) have high spatial
    variation across their patches (text, logos, lines, whitespace).
    Natural photos tend to have smoother patch distributions.

    Args:
        last_hidden_state: Encoder output of shape
            (batch, 1 + num_patches, hidden), with the [CLS] token at
            sequence position 0 — e.g. (1, 197, 768) for
            vit-base-patch16-224.

    Returns:
        A score in [0.3, 1.0], rounded to three decimals, used to scale
        entity confidences. The floor (0.3) is also returned when the
        variance is not finite (no patch tokens, a 1-wide embedding, or
        NaN/inf activations), so a broken forward pass never yields a
        high score.
    """
    floor = 0.3   # never scale confidences all the way down to zero
    scale = 2.5   # variance at (or above) which the score saturates at 1.0

    # Drop the leading [CLS] token and upcast, so half-precision outputs are
    # reduced in float32 and give the same score as full-precision ones.
    patches = last_hidden_state[:, 1:, :].float()

    # Variance across the embedding dimension, averaged over all patches.
    # Higher → more information diversity → more document-like.
    patch_var = patches.var(dim=-1).mean()

    # An empty patch set or NaN/inf activations produce a non-finite mean;
    # treat that explicitly as the lowest-quality signal.
    if not bool(torch.isfinite(patch_var)):
        return floor

    # Empirically, document images yield patch_var ~ 1.0–3.0; natural photos
    # tend to be lower or at different scales.
    score = min(1.0, max(floor, patch_var.item() / scale))
    return round(score, 3)
# ---------------------------------------------------------------------------
# DocumentViT
# ---------------------------------------------------------------------------
class DocumentViT:
    """
    ViT-based document evidence extractor.

    Combines:
      1. ViT image quality scoring (using patch embedding variance)
      2. Pytesseract OCR for text recovery from the image
      3. Regex entity extraction on the OCR text
      4. Confidence values scaled by the ViT document score

    This design means the ViT pass adds value even on blurry images where
    OCR quality is low: the document score automatically lowers confidence,
    signalling to the DocumentProcessor that these entities need human review.
    """

    # Document score substituted when the ViT forward pass raises, so entity
    # extraction can still proceed with a middling confidence.
    _FALLBACK_DOC_SCORE = 0.6

    def __init__(self, model_name: str = _DEFAULT_MODEL) -> None:
        """
        Load the ViT encoder and image processor from *model_name*.

        Args:
            model_name: Hub identifier or local directory passed to
                ``from_pretrained`` for both the image processor and the
                encoder.
        """
        logger.info("Loading DocumentViT from %s …", model_name)
        self._processor = AutoImageProcessor.from_pretrained(model_name)
        self._model = ViTModel.from_pretrained(model_name)
        self._model.eval()  # inference only
        # Use the GPU when one is visible, otherwise stay on the CPU.
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model.to(self._device)
        logger.info("DocumentViT ready on %s.", self._device)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _score(self, image: Image.Image) -> float:
        """Run the ViT encoder on *image* and return a document quality score."""
        inputs = self._processor(images=image.convert("RGB"), return_tensors="pt")
        # Move every processor output onto the same device as the model.
        inputs = {key: tensor.to(self._device) for key, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self._model(**inputs)
        # Score on the CPU copy of the hidden states.
        return _document_score(outputs.last_hidden_state.cpu())

    @staticmethod
    def _ocr_from_image(image: Image.Image) -> str:
        """
        Pre-process *image* and run Tesseract (English) on it.

        Uses the preprocessing and text-cleaning helpers imported from
        ocr.py, so the text produced here goes through the same steps as
        that module's output.
        """
        preprocessed = _preprocess_pil(image.convert("RGB"))
        raw = pytesseract.image_to_string(preprocessed, lang="eng")
        return _clean_text(raw)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def extract(
        self,
        image: Image.Image,
        ocr_text: str = "",
    ) -> list[Entity]:
        """
        Extract evidence entities from *image*.

        Args:
            image: Source document image (PIL Image, any mode/size).
            ocr_text: Pre-computed OCR text (e.g. from ocr.py). If empty,
                whitespace-only or None, OCR is run internally so the ViT
                path is self-contained. A non-blank value is used exactly
                as given, so the returned ``start``/``end`` offsets index
                into the caller's string.

        Returns:
            List of Entity objects (text/label/start/end/confidence). Each
            confidence is the extractor's confidence multiplied by the ViT
            document quality score and rounded to four decimals. An empty
            list is returned when no text is available.
        """
        # 1. Document quality score from ViT. Scoring is best-effort: any
        #    failure is logged and replaced by a neutral fallback score.
        try:
            doc_score = self._score(image)
        except Exception:
            logger.warning("ViT scoring failed — using default confidence.", exc_info=True)
            doc_score = self._FALLBACK_DOC_SCORE

        # 2. OCR text (use provided or run internally). The supplied text is
        #    deliberately NOT stripped: stripping would shift every entity
        #    offset relative to the string the caller holds.
        if ocr_text and ocr_text.strip():
            text = ocr_text
        else:
            text = self._ocr_from_image(image)
        if not text:
            logger.debug("DocumentViT: no OCR text recovered — returning empty entity list.")
            return []

        # 3. Regex entity extraction
        raw_entities = _extract_entities_from_text(text, base_confidence=1.0)

        # 4. Scale confidences by ViT document score
        scaled: list[Entity] = [
            Entity(
                text=entity.text,
                label=entity.label,
                start=entity.start,
                end=entity.end,
                confidence=round(entity.confidence * doc_score, 4),
            )
            for entity in raw_entities
        ]
        logger.debug(
            "DocumentViT: doc_score=%.3f, %d entities extracted.",
            doc_score, len(scaled),
        )
        return scaled
|