""" DocumentViT — Vision Transformer for structured evidence extraction. Architecture: google/vit-base-patch16-224 (encoder, no classification head). The [CLS] token embedding and patch embedding statistics are used as a "document quality score" that weights entity confidences. Actual entity spans are extracted from OCR text via regex and then confidence-adjusted by the ViT's image assessment. Why ViT alongside OCR? Tesseract excels at clean printed text but degrades on blurry screenshots and poorly-scanned documents. The ViT path applies OCR independently with its own quality signal, and the DocumentProcessor takes the higher-confidence span when both paths find the same field. Input: PIL Image (any size; auto-resized to 224×224 by the ViT processor). Output: List of Entity spans — same schema as EvidenceNER (src.ner.model.Entity). Fine-tuning note: The current implementation uses the pre-trained ViT encoder as a feature extractor for confidence scoring. Swap in a fine-tuned checkpoint by passing the local model directory to __init__(model_name=...). """ from __future__ import annotations import logging import re import pytesseract import torch from PIL import Image from transformers import AutoImageProcessor, ViTModel from src.document_processor.ocr import ( _clean_text, _extract_entities_from_text, _preprocess_pil, ) from src.ner.model import Entity logger = logging.getLogger(__name__) _DEFAULT_MODEL = "google/vit-base-patch16-224" # --------------------------------------------------------------------------- # Document quality scoring from ViT patch embeddings # --------------------------------------------------------------------------- def _document_score(last_hidden_state: torch.Tensor) -> float: """ Estimate how "document-like" an image is from its ViT patch embeddings. Structured documents (receipts, bills, screenshots) have high spatial variation across their 196 patches (text, logos, lines, whitespace). Natural photos tend to have smoother patch distributions. Returns a score in [0, 1] used to scale entity confidences. """ # Patch embeddings: shape (1, 196, 768) patches = last_hidden_state[:, 1:, :] # exclude [CLS] # Mean variance across the embedding dimension, averaged over patches # Higher → more information diversity → more document-like patch_var = patches.var(dim=-1).mean().item() # scalar # Empirically, document images yield patch_var ~ 1.0–3.0 # Natural photos tend to be lower or at different scales score = min(1.0, max(0.3, patch_var / 2.5)) return round(score, 3) # --------------------------------------------------------------------------- # DocumentViT # --------------------------------------------------------------------------- class DocumentViT: """ ViT-based document evidence extractor. Combines: 1. ViT image quality scoring (using patch embedding variance) 2. Pytesseract OCR for text recovery from the image 3. Regex entity extraction on the OCR text 4. Confidence values scaled by the ViT document score This design means the ViT pass adds value even on blurry images where OCR quality is low: the document score automatically lowers confidence, signalling to the DocumentProcessor that these entities need human review. """ def __init__(self, model_name: str = _DEFAULT_MODEL) -> None: """Load the ViT encoder and image processor from *model_name*.""" logger.info("Loading DocumentViT from %s …", model_name) self._processor = AutoImageProcessor.from_pretrained(model_name) self._model = ViTModel.from_pretrained(model_name) self._model.eval() self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._model.to(self._device) logger.info("DocumentViT ready on %s.", self._device) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _score(self, image: Image.Image) -> float: """Run the ViT encoder on *image* and return a document quality score.""" inputs = self._processor(images=image.convert("RGB"), return_tensors="pt") inputs = {k: v.to(self._device) for k, v in inputs.items()} with torch.no_grad(): outputs = self._model(**inputs) return _document_score(outputs.last_hidden_state.cpu()) @staticmethod def _ocr_from_image(image: Image.Image) -> str: """ Pre-process *image* and run Tesseract. Uses the same preprocessing pipeline as ocr.py so results are consistent and can be merged without offset conflicts. """ preprocessed = _preprocess_pil(image.convert("RGB")) raw = pytesseract.image_to_string(preprocessed, lang="eng") return _clean_text(raw) # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def extract( self, image: Image.Image, ocr_text: str = "", ) -> list[Entity]: """ Extract evidence entities from *image*. Args: image: Source document image (PIL Image, any mode/size). ocr_text: Pre-computed OCR text (e.g. from ocr.py). If empty, OCR is run internally so the ViT path is self-contained. Returns: List of Entity objects (text/label/start/end/confidence) in the same schema as EvidenceNER. Confidence values are in [0, 1] and are scaled by the ViT document quality score so the DocumentProcessor can prefer higher-confidence spans during merging. """ # 1. Document quality score from ViT try: doc_score = self._score(image) except Exception: logger.warning("ViT scoring failed — using default confidence.", exc_info=True) doc_score = 0.6 # 2. OCR text (use provided or run internally) text = ocr_text.strip() if ocr_text.strip() else self._ocr_from_image(image) if not text: logger.debug("DocumentViT: no OCR text recovered — returning empty entity list.") return [] # 3. Regex entity extraction raw_entities = _extract_entities_from_text(text, base_confidence=1.0) # 4. Scale confidences by ViT document score scaled: list[Entity] = [ Entity( text=e.text, label=e.label, start=e.start, end=e.end, confidence=round(e.confidence * doc_score, 4), ) for e in raw_entities ] logger.debug( "DocumentViT: doc_score=%.3f, %d entities extracted.", doc_score, len(scaled), ) return scaled