# guide / src/document_processor/vit_model.py
# Author: sangram kumar yerra
# Commit b861cd9: phase 4 - NextAction MLP & Document Processor
# Size: 7.09 kB
"""
DocumentViT — Vision Transformer for structured evidence extraction.
Architecture: google/vit-base-patch16-224 (encoder, no classification head).
Statistics of the patch-token embeddings (the [CLS] token is excluded) are
used as a "document quality score" that weights entity confidences.
Actual entity spans are extracted from OCR text via regex and
then confidence-adjusted by the ViT's image assessment.
Why ViT alongside OCR?
Tesseract excels at clean printed text but degrades on blurry screenshots
and poorly-scanned documents. The ViT path runs Tesseract itself when no
pre-computed text is supplied and attaches its own image-quality signal, and
the DocumentProcessor takes the higher-confidence span when both paths find
the same field.
Input: PIL Image (any size; auto-resized to 224×224 by the ViT processor).
Output: List of Entity spans β€” same schema as EvidenceNER (src.ner.model.Entity).
Fine-tuning note:
The current implementation uses the pre-trained ViT encoder as a feature
extractor for confidence scoring. Swap in a fine-tuned checkpoint by
passing the local model directory to __init__(model_name=...).
"""
from __future__ import annotations
import logging
import re
import pytesseract
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTModel
from src.document_processor.ocr import (
_clean_text,
_extract_entities_from_text,
_preprocess_pil,
)
from src.ner.model import Entity
# Hugging Face checkpoint loaded when DocumentViT() is given no model_name.
_DEFAULT_MODEL: str = "google/vit-base-patch16-224"

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Document quality scoring from ViT patch embeddings
# ---------------------------------------------------------------------------
def _document_score(last_hidden_state: torch.Tensor) -> float:
"""
Estimate how "document-like" an image is from its ViT patch embeddings.
Structured documents (receipts, bills, screenshots) have high spatial
variation across their 196 patches (text, logos, lines, whitespace).
Natural photos tend to have smoother patch distributions.
Returns a score in [0, 1] used to scale entity confidences.
"""
# Patch embeddings: shape (1, 196, 768)
patches = last_hidden_state[:, 1:, :] # exclude [CLS]
# Mean variance across the embedding dimension, averaged over patches
# Higher β†’ more information diversity β†’ more document-like
patch_var = patches.var(dim=-1).mean().item() # scalar
# Empirically, document images yield patch_var ~ 1.0–3.0
# Natural photos tend to be lower or at different scales
score = min(1.0, max(0.3, patch_var / 2.5))
return round(score, 3)
# ---------------------------------------------------------------------------
# DocumentViT
# ---------------------------------------------------------------------------
class DocumentViT:
    """
    ViT-based document evidence extractor.

    Combines:
      1. ViT image quality scoring (using patch embedding variance)
      2. Pytesseract OCR for text recovery from the image
      3. Regex entity extraction on the OCR text
      4. Confidence values scaled by the ViT document score

    The intent is that the ViT pass adds value even on blurry images where
    OCR quality is low: a lower document score lowers entity confidence,
    signalling to the DocumentProcessor that these entities need review.
    """

    # Confidence multiplier used when the ViT forward pass fails. It sits
    # inside the [0.3, 1.0] range _document_score produces by default.
    _FALLBACK_SCORE: float = 0.6

    def __init__(
        self,
        model_name: str = _DEFAULT_MODEL,
        device: str | torch.device | None = None,
    ) -> None:
        """
        Load the ViT encoder and image processor from *model_name*.

        Args:
            model_name: Hugging Face hub id or local directory of a ViT
                checkpoint (e.g. a fine-tuned model).
            device: Device to run the encoder on. When None (the default),
                CUDA is used if available, otherwise CPU.
        """
        logger.info("Loading DocumentViT from %s …", model_name)
        self._processor = AutoImageProcessor.from_pretrained(model_name)
        self._model = ViTModel.from_pretrained(model_name)
        self._model.eval()  # inference only (e.g. dropout disabled)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(device)
        self._model.to(self._device)
        logger.info("DocumentViT ready on %s.", self._device)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _score(self, image: Image.Image) -> float:
        """Run the ViT encoder on *image* and return a document quality score."""
        inputs = self._processor(images=image.convert("RGB"), return_tensors="pt")
        inputs = {k: v.to(self._device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self._model(**inputs)
        # Move back to CPU so the scalar statistic is computed host-side.
        return _document_score(outputs.last_hidden_state.cpu())

    def _safe_score(self, image: Image.Image) -> float:
        """Return the ViT document score, or ``_FALLBACK_SCORE`` if scoring fails."""
        try:
            return self._score(image)
        except Exception:
            # Deliberate best-effort boundary: a scoring failure must not
            # discard the regex entities, so fall back to a neutral weight.
            logger.warning("ViT scoring failed — using default confidence.", exc_info=True)
            return self._FALLBACK_SCORE

    @staticmethod
    def _ocr_from_image(image: Image.Image) -> str:
        """
        Pre-process *image* and run Tesseract.

        Uses the same preprocessing pipeline as ocr.py so results are
        consistent and can be merged without offset conflicts.
        """
        preprocessed = _preprocess_pil(image.convert("RGB"))
        raw = pytesseract.image_to_string(preprocessed, lang="eng")
        return _clean_text(raw)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract(
        self,
        image: Image.Image,
        ocr_text: str = "",
    ) -> list[Entity]:
        """
        Extract evidence entities from *image*.

        Args:
            image: Source document image (PIL Image, any mode/size).
            ocr_text: Pre-computed OCR text (e.g. from ocr.py). It is used
                verbatim, so entity start/end offsets index into exactly this
                string. If empty, None or whitespace-only, OCR is run
                internally so the ViT path is self-contained.

        Returns:
            List of Entity objects (text/label/start/end/confidence) in the
            same schema as EvidenceNER. Each confidence is the extractor's
            confidence multiplied by the ViT document quality score (rounded
            to 4 decimals), so the DocumentProcessor can prefer
            higher-confidence spans during merging. Returns [] when no text
            or no entities are found.
        """
        # 1. OCR text. The caller's text is not stripped: stripping would
        #    shift every entity offset by the leading-whitespace length.
        if ocr_text and ocr_text.strip():
            text = ocr_text
        else:
            text = self._ocr_from_image(image)
        if not text or not text.strip():
            logger.debug("DocumentViT: no OCR text recovered — returning empty entity list.")
            return []

        # 2. Regex entity extraction.
        raw_entities = _extract_entities_from_text(text, base_confidence=1.0)
        if not raw_entities:
            # Nothing to weight, so skip the (comparatively expensive) ViT pass.
            logger.debug("DocumentViT: no entities found — skipping ViT scoring.")
            return []

        # 3. Document quality score from the ViT (falls back on failure).
        doc_score = self._safe_score(image)

        # 4. Scale confidences by the ViT document score.
        scaled: list[Entity] = [
            Entity(
                text=e.text,
                label=e.label,
                start=e.start,
                end=e.end,
                confidence=round(e.confidence * doc_score, 4),
            )
            for e in raw_entities
        ]
        logger.debug(
            "DocumentViT: doc_score=%.3f, %d entities extracted.",
            doc_score, len(scaled),
        )
        return scaled