Spaces:
Sleeping
Sleeping
| """ | |
| DocumentViT — Vision Transformer for structured evidence extraction. | |
| Architecture: google/vit-base-patch16-224 (encoder, no classification head). | |
| The [CLS] token embedding and patch embedding statistics are used | |
| as a "document quality score" that weights entity confidences. | |
| Actual entity spans are extracted from OCR text via regex and | |
| then confidence-adjusted by the ViT's image assessment. | |
| Why ViT alongside OCR? | |
| Tesseract excels at clean printed text but degrades on blurry screenshots | |
| and poorly-scanned documents. The ViT path applies OCR independently | |
| with its own quality signal, and the DocumentProcessor takes the | |
| higher-confidence span when both paths find the same field. | |
| Input: PIL Image (any size; auto-resized to 224×224 by the ViT processor). | |
| Output: List of Entity spans — same schema as EvidenceNER (src.ner.model.Entity). | |
| Fine-tuning note: | |
| The current implementation uses the pre-trained ViT encoder as a feature | |
| extractor for confidence scoring. Swap in a fine-tuned checkpoint by | |
| passing the local model directory to __init__(model_name=...). | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import re | |
| import pytesseract | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, ViTModel | |
| from src.document_processor.ocr import ( | |
| _clean_text, | |
| _extract_entities_from_text, | |
| _preprocess_pil, | |
| ) | |
| from src.ner.model import Entity | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Checkpoint identifier used when DocumentViT is constructed without an
# explicit model_name argument.
_DEFAULT_MODEL = "google/vit-base-patch16-224"
| # --------------------------------------------------------------------------- | |
| # Document quality scoring from ViT patch embeddings | |
| # --------------------------------------------------------------------------- | |
def _document_score(last_hidden_state: torch.Tensor) -> float:
    """
    Turn ViT encoder output into a scalar "document-likeness" score.

    The first token along dim 1 (the [CLS] position) is discarded. For the
    remaining patch tokens the variance over the embedding dimension is
    computed per patch, and those variances are averaged into one number.
    That mean variance is divided by 2.5 and clamped.

    Args:
        last_hidden_state: Encoder output indexed as (batch, tokens, embedding).
            The original author's note gives the expected shape as
            (1, 197, 768), i.e. 196 patches plus [CLS].

    Returns:
        The clamped score rounded to 3 decimals. Because of the lower clamp
        the result always lies in [0.3, 1.0]; it is used by callers to scale
        entity confidences.
    """
    # Everything after index 0 on the token axis is a patch embedding.
    patch_tokens = last_hidden_state[:, 1:, :]

    # One variance value per patch, then a single average over all of them.
    per_patch_variance = patch_tokens.var(dim=-1)
    mean_variance = per_patch_variance.mean().item()

    # Per the original author's note, document images empirically land at a
    # mean variance of roughly 1.0-3.0, hence the 2.5 divisor.
    normalised = mean_variance / 2.5

    # Lower clamp at 0.3, upper clamp at 1.0. The comparisons are written so
    # that a NaN input falls through to the 0.3 floor.
    floored = normalised if normalised > 0.3 else 0.3
    clamped = floored if floored < 1.0 else 1.0
    return round(clamped, 3)
| # --------------------------------------------------------------------------- | |
| # DocumentViT | |
| # --------------------------------------------------------------------------- | |
class DocumentViT:
    """
    ViT-based document evidence extractor.

    Combines:
    1. ViT image quality scoring (using patch embedding variance)
    2. Pytesseract OCR for text recovery from the image
    3. Regex entity extraction on the OCR text
    4. Confidence values scaled by the ViT document score

    This design means the ViT pass adds value even on blurry images where
    OCR quality is low: the document score automatically lowers confidence,
    signalling to the DocumentProcessor that these entities need human review.
    """

    # Document score substituted when the ViT forward pass raises.
    _FALLBACK_DOC_SCORE = 0.6

    def __init__(self, model_name: str = _DEFAULT_MODEL) -> None:
        """
        Load the ViT encoder and image processor from *model_name*.

        The model is put in eval mode and moved to CUDA when available,
        otherwise it stays on the CPU.

        Args:
            model_name: Identifier or local directory passed to
                ``from_pretrained`` for both the processor and the encoder.
        """
        logger.info("Loading DocumentViT from %s …", model_name)
        self._processor = AutoImageProcessor.from_pretrained(model_name)
        self._model = ViTModel.from_pretrained(model_name)
        self._model.eval()
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model.to(self._device)
        logger.info("DocumentViT ready on %s.", self._device)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _score(self, image: Image.Image) -> float:
        """Run the ViT encoder on *image* and return a document quality score."""
        inputs = self._processor(images=image.convert("RGB"), return_tensors="pt")
        inputs = {k: v.to(self._device) for k, v in inputs.items()}
        # Inference only: no gradients are needed for scoring.
        with torch.no_grad():
            outputs = self._model(**inputs)
        # Move to CPU before reducing to a Python float.
        return _document_score(outputs.last_hidden_state.cpu())

    @staticmethod
    def _ocr_from_image(image: Image.Image) -> str:
        """
        Pre-process *image* and run Tesseract.

        Uses the same preprocessing pipeline as ocr.py so results are
        consistent and can be merged without offset conflicts.

        Declared as a static method: it reads no instance state, and
        ``extract`` invokes it as ``self._ocr_from_image(image)``, which
        would otherwise bind the instance to the *image* parameter.
        """
        preprocessed = _preprocess_pil(image.convert("RGB"))
        raw = pytesseract.image_to_string(preprocessed, lang="eng")
        return _clean_text(raw)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract(
        self,
        image: Image.Image,
        ocr_text: str = "",
    ) -> list[Entity]:
        """
        Extract evidence entities from *image*.

        Args:
            image: Source document image (PIL Image, any mode/size).
            ocr_text: Pre-computed OCR text (e.g. from ocr.py). If empty,
                whitespace-only or None, OCR is run internally so the ViT
                path is self-contained.

        Returns:
            List of Entity objects (text/label/start/end/confidence) in the
            same schema as EvidenceNER. Each confidence is the regex
            extractor's confidence multiplied by the ViT document quality
            score, so the DocumentProcessor can prefer higher-confidence
            spans during merging. An empty list is returned when no text
            could be recovered.
        """
        # 1. OCR text (use provided or run internally). Done first so the
        #    ViT forward pass is skipped when there is nothing to extract from.
        text = (ocr_text or "").strip() or self._ocr_from_image(image)
        if not text:
            logger.debug("DocumentViT: no OCR text recovered — returning empty entity list.")
            return []

        # 2. Document quality score from ViT. Scoring is best-effort: any
        #    failure is logged and a fixed fallback score is used instead.
        try:
            doc_score = self._score(image)
        except Exception:
            logger.warning("ViT scoring failed — using default confidence.", exc_info=True)
            doc_score = self._FALLBACK_DOC_SCORE

        # 3. Regex entity extraction
        raw_entities = _extract_entities_from_text(text, base_confidence=1.0)

        # 4. Scale confidences by ViT document score
        scaled: list[Entity] = [
            Entity(
                text=e.text,
                label=e.label,
                start=e.start,
                end=e.end,
                confidence=round(e.confidence * doc_score, 4),
            )
            for e in raw_entities
        ]
        logger.debug(
            "DocumentViT: doc_score=%.3f, %d entities extracted.",
            doc_score, len(scaled),
        )
        return scaled