File size: 7,086 Bytes
cbb1b1a
 
 
b861cd9
 
 
 
 
cbb1b1a
b861cd9
 
 
 
 
 
 
 
 
 
 
 
 
cbb1b1a
 
b861cd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbb1b1a
 
b861cd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbb1b1a
b861cd9
 
 
 
 
 
 
 
 
 
 
cbb1b1a
b861cd9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""
DocumentViT — Vision Transformer for structured evidence extraction.

Architecture: google/vit-base-patch16-224 (encoder, no classification head).
              Patch embedding statistics (the [CLS] token is excluded) are
              turned into a "document quality score" that weights entity
              confidences.  Actual entity spans are extracted from OCR text via
              regex and then confidence-adjusted by the ViT's image assessment.

Why ViT alongside OCR?
    Tesseract excels at clean printed text but degrades on blurry screenshots
    and poorly-scanned documents.  The ViT path can run OCR on its own (when
    no pre-computed text is supplied) and attaches an image-quality signal to
    every span, and the DocumentProcessor takes the higher-confidence span
    when both paths find the same field.

Input:  PIL Image (any size; auto-resized to 224×224 by the ViT processor).
Output: List of Entity spans β€” same schema as EvidenceNER (src.ner.model.Entity).

Fine-tuning note:
    The current implementation uses the pre-trained ViT encoder as a feature
    extractor for confidence scoring.  Swap in a fine-tuned checkpoint by
    passing the local model directory to __init__(model_name=...).
"""

from __future__ import annotations

import logging
import re

import pytesseract
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTModel

from src.document_processor.ocr import (
    _clean_text,
    _extract_entities_from_text,
    _preprocess_pil,
)
from src.ner.model import Entity

logger = logging.getLogger(__name__)

_DEFAULT_MODEL = "google/vit-base-patch16-224"

# ---------------------------------------------------------------------------
# Document quality scoring from ViT patch embeddings
# ---------------------------------------------------------------------------

def _document_score(last_hidden_state: torch.Tensor) -> float:
    """
    Estimate how "document-like" an image is from its ViT patch embeddings.

    Structured documents (receipts, bills, screenshots) have high spatial
    variation across their 196 patches (text, logos, lines, whitespace).
    Natural photos tend to have smoother patch distributions.

    Returns a score in [0, 1] used to scale entity confidences.
    """
    # Patch embeddings: shape (1, 196, 768)
    patches = last_hidden_state[:, 1:, :]   # exclude [CLS]

    # Mean variance across the embedding dimension, averaged over patches
    # Higher β†’ more information diversity β†’ more document-like
    patch_var = patches.var(dim=-1).mean().item()   # scalar

    # Empirically, document images yield patch_var ~ 1.0–3.0
    # Natural photos tend to be lower or at different scales
    score = min(1.0, max(0.3, patch_var / 2.5))
    return round(score, 3)


# ---------------------------------------------------------------------------
# DocumentViT
# ---------------------------------------------------------------------------

class DocumentViT:
    """
    ViT-based document evidence extractor.

    Combines:
      1. Pytesseract OCR for text recovery from the image (skipped when the
         caller supplies pre-computed OCR text)
      2. Regex entity extraction on the OCR text
      3. ViT image quality scoring (using patch embedding variance)
      4. Confidence values scaled by the ViT document score

    This design means the ViT pass adds value even on blurry images where
    OCR quality is low: the document score automatically lowers confidence,
    signalling to the DocumentProcessor that these entities need human review.
    """

    # Document score used when the ViT forward pass raises for any reason.
    _FALLBACK_DOC_SCORE = 0.6

    def __init__(self, model_name: str = _DEFAULT_MODEL) -> None:
        """
        Load the ViT encoder and image processor from *model_name*.

        Args:
            model_name: Hugging Face model id or a local checkpoint directory
                (e.g. a fine-tuned model).

        The model is switched to eval mode and placed on CUDA when available,
        otherwise on CPU.
        """
        logger.info("Loading DocumentViT from %s …", model_name)
        self._processor = AutoImageProcessor.from_pretrained(model_name)
        self._model = ViTModel.from_pretrained(model_name)
        self._model.eval()
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model.to(self._device)
        logger.info("DocumentViT ready on %s.", self._device)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _score(self, image: Image.Image) -> float:
        """
        Run the ViT encoder on *image* and return its document quality score.

        The image is converted to RGB, preprocessed by the image processor,
        and passed through the encoder without gradient tracking.  Exceptions
        from the processor or model propagate to the caller.
        """
        inputs = self._processor(images=image.convert("RGB"), return_tensors="pt")
        inputs = {k: v.to(self._device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self._model(**inputs)
        # Move the encoder output to CPU before scoring.
        return _document_score(outputs.last_hidden_state.cpu())

    @staticmethod
    def _ocr_from_image(image: Image.Image) -> str:
        """
        Pre-process *image* and run Tesseract (English) on it.

        Uses the same preprocessing and text-cleaning helpers as ocr.py so
        results are consistent and can be merged without offset conflicts.
        """
        preprocessed = _preprocess_pil(image.convert("RGB"))
        raw = pytesseract.image_to_string(preprocessed, lang="eng")
        return _clean_text(raw)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract(
        self,
        image: Image.Image,
        ocr_text: str = "",
    ) -> list[Entity]:
        """
        Extract evidence entities from *image*.

        Args:
            image:    Source document image (PIL Image, any mode/size).
            ocr_text: Pre-computed OCR text (e.g. from ocr.py).  It is used
                      verbatim, so entity offsets index into this exact
                      string.  If empty, whitespace-only, or None, OCR is run
                      internally so the ViT path is self-contained.

        Returns:
            List of Entity objects (text/label/start/end/confidence) in the
            same schema as EvidenceNER.  Each confidence is the regex
            extractor's confidence multiplied by the ViT document quality
            score (or by ``_FALLBACK_DOC_SCORE`` if scoring fails), so the
            DocumentProcessor can prefer higher-confidence spans during
            merging.  Returns an empty list when no text or no entities are
            found.
        """
        # 1. OCR text.  Do not strip caller-supplied text: stripping would
        #    shift every entity offset relative to the string the caller
        #    holds and break offset-based merging with the OCR path.
        if ocr_text and ocr_text.strip():
            text = ocr_text
        else:
            text = self._ocr_from_image(image)

        if not text or not text.strip():
            logger.debug("DocumentViT: no OCR text recovered — returning empty entity list.")
            return []

        # 2. Regex entity extraction at full confidence; scaled below.
        raw_entities = _extract_entities_from_text(text, base_confidence=1.0)
        if not raw_entities:
            # Nothing to scale, so skip the (comparatively expensive) ViT pass.
            logger.debug("DocumentViT: no entities found in OCR text.")
            return []

        # 3. Document quality score from ViT.
        try:
            doc_score = self._score(image)
        except Exception:
            # Best-effort: a scoring failure must not discard the OCR entities.
            logger.warning("ViT scoring failed — using default confidence.", exc_info=True)
            doc_score = self._FALLBACK_DOC_SCORE

        # 4. Scale confidences by the ViT document score.
        scaled: list[Entity] = [
            Entity(
                text=e.text,
                label=e.label,
                start=e.start,
                end=e.end,
                confidence=round(e.confidence * doc_score, 4),
            )
            for e in raw_entities
        ]

        logger.debug(
            "DocumentViT: doc_score=%.3f, %d entities extracted.",
            doc_score, len(scaled),
        )
        return scaled