"""OCR engine — ctypes wrapper for Windows 11 SnippingTool's oneocr.dll. Provides offline OCR capability using Microsoft's AI model from Snipping Tool. Requires oneocr.dll, oneocr.onemodel, and onnxruntime.dll in ocr_data/ directory. Usage: from src.services.ocr.engine import OcrEngine engine = OcrEngine() result = engine.recognize_pil(pil_image) print(result.text) """ from __future__ import annotations import ctypes import os from contextlib import contextmanager from ctypes import ( POINTER, Structure, byref, c_char_p, c_float, c_int32, c_int64, c_ubyte, ) from pathlib import Path from typing import TYPE_CHECKING from ocr.models import BoundingRect, OcrLine, OcrResult, OcrWord # Constants (previously imported from src.config.constants) OCR_DLL_NAME = "oneocr.dll" OCR_MODEL_NAME = "oneocr.onemodel" OCR_MODEL_KEY = b'kj)TGtrK>f]b[Piow.gU+nC@s""""""4' OCR_MAX_LINES = 200 if TYPE_CHECKING: from PIL import Image c_int64_p = POINTER(c_int64) c_float_p = POINTER(c_float) c_ubyte_p = POINTER(c_ubyte) class _ImageStructure(Structure): """Image data structure for oneocr.dll (CV_8UC4 format).""" _fields_ = [ ("type", c_int32), ("width", c_int32), ("height", c_int32), ("_reserved", c_int32), ("step_size", c_int64), ("data_ptr", c_ubyte_p), ] class _BoundingBox(Structure): """Bounding box coordinates from DLL.""" _fields_ = [ ("x1", c_float), ("y1", c_float), ("x2", c_float), ("y2", c_float), ("x3", c_float), ("y3", c_float), ("x4", c_float), ("y4", c_float), ] _BoundingBox_p = POINTER(_BoundingBox) # DLL function signatures: (name, argtypes, restype) _DLL_FUNCTIONS: list[tuple[str, list[type], type | None]] = [ ("CreateOcrInitOptions", [c_int64_p], c_int64), ("OcrInitOptionsSetUseModelDelayLoad", [c_int64, ctypes.c_char], c_int64), ("CreateOcrPipeline", [c_char_p, c_char_p, c_int64, c_int64_p], c_int64), ("CreateOcrProcessOptions", [c_int64_p], c_int64), ("OcrProcessOptionsSetMaxRecognitionLineCount", [c_int64, c_int64], c_int64), ("RunOcrPipeline", [c_int64, POINTER(_ImageStructure), c_int64, c_int64_p], c_int64), ("GetImageAngle", [c_int64, c_float_p], c_int64), ("GetOcrLineCount", [c_int64, c_int64_p], c_int64), ("GetOcrLine", [c_int64, c_int64, c_int64_p], c_int64), ("GetOcrLineContent", [c_int64, POINTER(c_char_p)], c_int64), ("GetOcrLineBoundingBox", [c_int64, POINTER(_BoundingBox_p)], c_int64), ("GetOcrLineWordCount", [c_int64, c_int64_p], c_int64), ("GetOcrWord", [c_int64, c_int64, c_int64_p], c_int64), ("GetOcrWordContent", [c_int64, POINTER(c_char_p)], c_int64), ("GetOcrWordBoundingBox", [c_int64, POINTER(_BoundingBox_p)], c_int64), ("GetOcrWordConfidence", [c_int64, c_float_p], c_int64), ("ReleaseOcrResult", [c_int64], None), ("ReleaseOcrInitOptions", [c_int64], None), ("ReleaseOcrPipeline", [c_int64], None), ("ReleaseOcrProcessOptions", [c_int64], None), ] @contextmanager def _suppress_output(): """Suppress stdout/stderr during DLL initialization (it prints to console).""" devnull = os.open(os.devnull, os.O_WRONLY) original_stdout = os.dup(1) original_stderr = os.dup(2) os.dup2(devnull, 1) os.dup2(devnull, 2) try: yield finally: os.dup2(original_stdout, 1) os.dup2(original_stderr, 2) os.close(original_stdout) os.close(original_stderr) os.close(devnull) class OcrEngine: """Offline OCR engine using Windows 11 SnippingTool's oneocr.dll. Args: ocr_data_dir: Path to directory containing oneocr.dll, oneocr.onemodel, onnxruntime.dll. Defaults to PROJECT_ROOT/ocr_data/. """ def __init__(self, ocr_data_dir: str | Path | None = None) -> None: if ocr_data_dir is None: ocr_data_dir = Path(__file__).resolve().parent.parent / "ocr_data" self._data_dir = str(Path(ocr_data_dir).resolve()) self._dll: ctypes.WinDLL | None = None # type: ignore[name-defined] self._init_options = c_int64() self._pipeline = c_int64() self._process_options = c_int64() self._load_dll() self._initialize_pipeline() def __del__(self) -> None: if self._dll: try: self._dll.ReleaseOcrProcessOptions(self._process_options) self._dll.ReleaseOcrPipeline(self._pipeline) self._dll.ReleaseOcrInitOptions(self._init_options) except Exception: pass # --- Public API --- def recognize_pil(self, image: Image.Image) -> OcrResult: """Run OCR on a PIL Image. Args: image: PIL Image object (any mode — will be converted to RGBA/BGRA). Returns: OcrResult with recognized text, lines, words, and confidence values. """ if any(x < 50 or x > 10000 for x in image.size): return OcrResult(error="Unsupported image size (must be 50-10000px)") if image.mode != "RGBA": image = image.convert("RGBA") # Convert RGB(A) → BGRA (DLL expects BGRA channel order) r, g, b, a = image.split() from PIL import Image as PILImage bgra_image = PILImage.merge("RGBA", (b, g, r, a)) return self._process_image( width=bgra_image.width, height=bgra_image.height, step=bgra_image.width * 4, data=bgra_image.tobytes(), ) def recognize_bytes(self, image_bytes: bytes) -> OcrResult: """Run OCR on raw image bytes (PNG/JPEG/etc). Args: image_bytes: Raw image file bytes. Returns: OcrResult. """ from io import BytesIO from PIL import Image img = Image.open(BytesIO(image_bytes)) return self.recognize_pil(img) # --- Internal --- def _load_dll(self) -> None: """Load oneocr.dll and bind function signatures.""" try: kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] if hasattr(kernel32, "SetDllDirectoryW"): kernel32.SetDllDirectoryW(self._data_dir) dll_path = os.path.join(self._data_dir, OCR_DLL_NAME) if not os.path.exists(dll_path): raise FileNotFoundError(f"OCR DLL not found: {dll_path}") self._dll = ctypes.WinDLL(dll_path) # type: ignore[attr-defined] for name, argtypes, restype in _DLL_FUNCTIONS: func = getattr(self._dll, name) func.argtypes = argtypes func.restype = restype except (OSError, RuntimeError) as e: raise RuntimeError(f"Failed to load OCR DLL from {self._data_dir}: {e}") from e def _initialize_pipeline(self) -> None: """Create OCR init options, pipeline, and process options.""" assert self._dll is not None # Init options self._check( self._dll.CreateOcrInitOptions(byref(self._init_options)), "CreateOcrInitOptions failed", ) self._check( self._dll.OcrInitOptionsSetUseModelDelayLoad(self._init_options, 0), "OcrInitOptionsSetUseModelDelayLoad failed", ) # Pipeline (loads AI model — suppress DLL stdout noise) model_path = os.path.join(self._data_dir, OCR_MODEL_NAME) if not os.path.exists(model_path): raise FileNotFoundError(f"OCR model not found: {model_path}") model_buf = ctypes.create_string_buffer(model_path.encode()) key_buf = ctypes.create_string_buffer(OCR_MODEL_KEY) with _suppress_output(): result = self._dll.CreateOcrPipeline( model_buf, key_buf, self._init_options, byref(self._pipeline) ) self._check(result, "CreateOcrPipeline failed (wrong key or corrupted model?)") # Process options self._check( self._dll.CreateOcrProcessOptions(byref(self._process_options)), "CreateOcrProcessOptions failed", ) self._check( self._dll.OcrProcessOptionsSetMaxRecognitionLineCount( self._process_options, OCR_MAX_LINES ), "OcrProcessOptionsSetMaxRecognitionLineCount failed", ) def _process_image(self, width: int, height: int, step: int, data: bytes) -> OcrResult: """Create image structure and run OCR pipeline.""" assert self._dll is not None data_ptr = (c_ubyte * len(data)).from_buffer_copy(data) img_struct = _ImageStructure( type=3, width=width, height=height, _reserved=0, step_size=step, data_ptr=data_ptr, ) ocr_result = c_int64() if self._dll.RunOcrPipeline( self._pipeline, byref(img_struct), self._process_options, byref(ocr_result) ) != 0: return OcrResult(error="RunOcrPipeline returned non-zero") parsed = self._parse_results(ocr_result) self._dll.ReleaseOcrResult(ocr_result) return parsed def _parse_results(self, ocr_result: c_int64) -> OcrResult: """Extract text, lines, words from DLL result handle.""" assert self._dll is not None line_count = c_int64() if self._dll.GetOcrLineCount(ocr_result, byref(line_count)) != 0: return OcrResult(error="GetOcrLineCount failed") lines: list[OcrLine] = [] for idx in range(line_count.value): line = self._parse_line(ocr_result, idx) if line: lines.append(line) # Text angle text_angle_val = c_float() text_angle: float | None = None if self._dll.GetImageAngle(ocr_result, byref(text_angle_val)) == 0: text_angle = text_angle_val.value full_text = "\n".join(line.text for line in lines if line.text) return OcrResult(text=full_text, text_angle=text_angle, lines=lines) def _parse_line(self, ocr_result: c_int64, line_index: int) -> OcrLine | None: """Parse a single line from OCR result.""" assert self._dll is not None line_handle = c_int64() if self._dll.GetOcrLine(ocr_result, line_index, byref(line_handle)) != 0: return None if not line_handle.value: return None # Line text content = c_char_p() line_text = "" if self._dll.GetOcrLineContent(line_handle, byref(content)) == 0 and content.value: line_text = content.value.decode("utf-8", errors="ignore") # Line bounding box line_bbox = self._get_bbox(line_handle, self._dll.GetOcrLineBoundingBox) # Words word_count = c_int64() words: list[OcrWord] = [] if self._dll.GetOcrLineWordCount(line_handle, byref(word_count)) == 0: for wi in range(word_count.value): word = self._parse_word(line_handle, wi) if word: words.append(word) return OcrLine(text=line_text, bounding_rect=line_bbox, words=words) def _parse_word(self, line_handle: c_int64, word_index: int) -> OcrWord | None: """Parse a single word.""" assert self._dll is not None word_handle = c_int64() if self._dll.GetOcrWord(line_handle, word_index, byref(word_handle)) != 0: return None # Word text content = c_char_p() word_text = "" if self._dll.GetOcrWordContent(word_handle, byref(content)) == 0 and content.value: word_text = content.value.decode("utf-8", errors="ignore") # Word bounding box word_bbox = self._get_bbox(word_handle, self._dll.GetOcrWordBoundingBox) # Word confidence confidence_val = c_float() confidence = 0.0 if self._dll.GetOcrWordConfidence(word_handle, byref(confidence_val)) == 0: confidence = confidence_val.value return OcrWord(text=word_text, bounding_rect=word_bbox, confidence=confidence) @staticmethod def _get_bbox(handle: c_int64, bbox_fn: object) -> BoundingRect | None: """Extract bounding box from a handle.""" bbox_ptr = _BoundingBox_p() if bbox_fn(handle, byref(bbox_ptr)) == 0 and bbox_ptr: # type: ignore[operator] bb = bbox_ptr.contents return BoundingRect( x1=bb.x1, y1=bb.y1, x2=bb.x2, y2=bb.y2, x3=bb.x3, y3=bb.y3, x4=bb.x4, y4=bb.y4, ) return None @staticmethod def _check(result_code: int, msg: str) -> None: if result_code != 0: raise RuntimeError(f"{msg} (code: {result_code})")