Didier committed on
Commit
58ba391
·
verified ·
1 Parent(s): d069605

Upload vlm_ocr.py

Browse files
Files changed (1) hide show
  1. vlm_ocr.py +162 -0
vlm_ocr.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: vlm_ocr.py
3
+
4
+ This module provides a VLM OCR model for Docling.
5
+
6
+ :author: Didier Guillevic
7
+ :email: didier.guillevic@gmail.com
8
+ :date: 2026-02-27
9
+ :license: Apache License 2.0
10
+ """
11
+ import base64
12
+ import io
13
+ import logging
14
+ import requests
15
+ import itertools
16
+ from collections.abc import Iterable
17
+ from pathlib import Path
18
+ from typing import Any, ClassVar, List, Literal, Optional, Type
19
+
20
+ from docling.datamodel.accelerator_options import AcceleratorOptions
21
+ from docling.datamodel.base_models import Page
22
+ from docling.datamodel.document import ConversionResult
23
+ from docling.datamodel.pipeline_options import OcrOptions
24
+ from docling.models.base_ocr_model import BaseOcrModel
25
+ from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
26
+ from docling_core.types.doc.page import BoundingRectangle, TextCell
27
+ from PIL import Image
28
+
29
# Module-level logger for this OCR integration.
_log = logging.getLogger(__name__)

# Cooperative cancellation flag: set by request_cancel(), cleared by
# reset_cancel(), and checked once per page inside VlmOcrModel.__call__.
_cancel_requested = False
31
+
32
def request_cancel():
    """Ask any running OCR loop to stop.

    Sets the module-level ``_cancel_requested`` flag; the OCR model polls it
    once per page and passes remaining pages through untouched.
    """
    global _cancel_requested
    _cancel_requested = True
35
+
36
def reset_cancel():
    """Clear a previous cancellation request.

    Resets the module-level ``_cancel_requested`` flag so a subsequent OCR
    run is not skipped by a stale cancel.
    """
    global _cancel_requested
    _cancel_requested = False
39
+
40
class VlmOcrOptions(OcrOptions):
    """Options for the OpenAI-compatible VLM OCR backend.

    Configures which chat-completions endpoint and model the OCR model
    sends page-region images to, plus the transcription prompt and timeout.
    """

    # Discriminator used by the pipeline to select this OCR backend.
    kind: ClassVar[Literal["vlm_ocr"]] = "vlm_ocr"
    # Declared languages (required by the OcrOptions interface; the VLM
    # itself is not restricted by this setting as far as this module shows).
    lang: List[str] = ["en"]
    # Model name placed in the chat-completions request payload.
    model: str = "Ministral-3-14B-Instruct-2512"
    # Base URL of the OpenAI-compatible server (e.g. a local llama.cpp/vLLM).
    openai_base_url: str = "http://localhost:8080/v1"
    # Bearer token sent in the Authorization header; placeholder default
    # for local servers that ignore authentication.
    openai_api_key: str = "Keep learning"
    # Instruction sent alongside each region image.
    prompt: str = "Transcribe the text in this image. Return only the transcription. Use standard Markdown table syntax for any tables found. Be extremely accurate."
    # Per-request timeout in seconds for the HTTP call.
    timeout: float = 300.0
48
+
49
class VlmOcrModel(BaseOcrModel):
    """OCR model that transcribes page regions via an OpenAI-compatible VLM.

    Each detected OCR region is rendered to a high-resolution PNG,
    base64-encoded, and sent to a chat-completions endpoint; the returned
    transcription is attached to the page as a single TextCell covering
    the region. Failures are logged per region and skipped (best effort).
    """

    # Render scale for region crops; higher resolution helps transcription.
    _IMAGE_SCALE = 3.0

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        options: VlmOcrOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """Initialize the model.

        :param enabled: when False, ``__call__`` passes pages through untouched.
        :param artifacts_path: forwarded to the base class (unused locally).
        :param options: endpoint, model, prompt and timeout configuration.
        :param accelerator_options: forwarded to the base class (unused locally).
        """
        super().__init__(
            enabled=enabled,
            artifacts_path=artifacts_path,
            options=options,
            accelerator_options=accelerator_options,
        )
        # Narrow the attribute type so the helpers below get VLM-specific fields.
        self.options: VlmOcrOptions = options

    @staticmethod
    def _encode_png_base64(image: Image.Image) -> str:
        """Return *image* serialized as a base64-encoded PNG string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def _transcribe_image(self, img_str: str) -> str:
        """Send one base64 PNG to the chat-completions endpoint, return the text.

        :param img_str: base64-encoded PNG bytes (no data-URL prefix).
        :raises requests.RequestException: on connection/timeout/HTTP errors.
        :raises KeyError, IndexError: if the response is not OpenAI-shaped.
        """
        payload = {
            "model": self.options.model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self.options.prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{img_str}"},
                        },
                    ],
                }
            ],
            # Deterministic decoding is preferable for transcription.
            "temperature": 0.0,
        }
        headers = {"Authorization": f"Bearer {self.options.openai_api_key}"}
        endpoint = f"{self.options.openai_base_url.rstrip('/')}/chat/completions"
        response = requests.post(
            endpoint,
            json=payload,
            headers=headers,
            timeout=self.options.timeout,
        )
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"]

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Yield each page with OCR cells added for its detected OCR regions.

        Pages pass through unchanged when the model is disabled, when a
        cancel was requested (module-level flag), or when the page backend
        is missing/invalid. A failed region is logged and skipped so one
        bad request never aborts the page or the batch.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            if _cancel_requested:
                _log.info("OCR execution cancelled.")
                yield page
                continue

            if page._backend is None or not page._backend.is_valid():
                yield page
                continue

            # Identify OCR regions on this page.
            ocr_rects = self.get_ocr_rects(page)
            all_ocr_cells = []

            for i, ocr_rect in enumerate(ocr_rects):
                # Degenerate rectangles have nothing to transcribe.
                if ocr_rect.area() == 0:
                    continue

                # Render the region at high resolution and encode it.
                high_res_image = page._backend.get_page_image(
                    scale=self._IMAGE_SCALE, cropbox=ocr_rect
                )
                img_str = self._encode_png_base64(high_res_image)

                try:
                    # Lazy %-args keep formatting cost off when INFO is disabled.
                    _log.info(
                        "Sending VLM OCR request for page %s, region %s",
                        page.page_no,
                        i,
                    )
                    transcription = self._transcribe_image(img_str)
                    all_ocr_cells.append(
                        TextCell(
                            index=len(all_ocr_cells),
                            text=transcription,
                            orig=transcription,
                            from_ocr=True,
                            # The API returns no per-cell score; assume full confidence.
                            confidence=1.0,
                            rect=BoundingRectangle.from_bounding_box(ocr_rect),
                        )
                    )
                except Exception as e:
                    # Best effort: log and move on to the next region/page.
                    _log.error("VLM OCR failed for page %s: %s", page.page_no, e)

            # Post-process the cells (dedup/merge handled by the base class).
            self.post_process_cells(all_ocr_cells, page)
            yield page

    @classmethod
    def get_options_type(cls) -> Type[OcrOptions]:
        """Return the options class this model consumes (pipeline dispatch key)."""
        return VlmOcrOptions
152
+
153
class LocalVlmPdfPipeline(StandardPdfPipeline):
    """Standard PDF pipeline that routes ``vlm_ocr`` options to VlmOcrModel."""

    def _make_ocr_model(self, art_path: Path | None) -> Any:
        """Build the OCR model; defer to the parent for non-VLM options."""
        ocr_opts = self.pipeline_options.ocr_options
        if not isinstance(ocr_opts, VlmOcrOptions):
            return super()._make_ocr_model(art_path)
        return VlmOcrModel(
            enabled=self.pipeline_options.do_ocr,
            artifacts_path=art_path,
            options=ocr_opts,
            accelerator_options=self.pipeline_options.accelerator_options,
        )