oneocr / ocr /engine.py
OneOCR Dev
OneOCR - reverse engineering complete, ONNX pipeline 53% match rate
ce847d4
"""OCR engine — ctypes wrapper for Windows 11 SnippingTool's oneocr.dll.
Provides offline OCR capability using Microsoft's AI model from Snipping Tool.
Requires oneocr.dll, oneocr.onemodel, and onnxruntime.dll in ocr_data/ directory.
Usage:
    from ocr.engine import OcrEngine
    engine = OcrEngine()
    result = engine.recognize_pil(pil_image)
    print(result.text)
"""
from __future__ import annotations
import ctypes
import os
from contextlib import contextmanager
from ctypes import (
POINTER,
Structure,
byref,
c_char_p,
c_float,
c_int32,
c_int64,
c_ubyte,
)
from pathlib import Path
from typing import TYPE_CHECKING
from ocr.models import BoundingRect, OcrLine, OcrResult, OcrWord
# Constants (previously imported from src.config.constants).
# File names looked up inside the engine's data directory.
OCR_DLL_NAME, OCR_MODEL_NAME = "oneocr.dll", "oneocr.onemodel"
# Key handed to CreateOcrPipeline together with the model path.
OCR_MODEL_KEY = b'kj)TGtrK>f]b[Piow.gU+nC@s""""""4'
# Upper bound handed to OcrProcessOptionsSetMaxRecognitionLineCount.
OCR_MAX_LINES = 200

if TYPE_CHECKING:
    # Needed only for annotations; PIL is imported inside the methods that use it.
    from PIL import Image

# Pointer aliases used by the ctypes signatures and structures below.
c_int64_p, c_float_p, c_ubyte_p = (
    POINTER(ctype) for ctype in (c_int64, c_float, c_ubyte)
)
class _ImageStructure(Structure):
    """Image descriptor handed to ``RunOcrPipeline`` (4-byte pixels, CV_8UC4-style).

    The field order and C types must mirror the DLL's native layout exactly.
    """

    _fields_ = [
        ("type", c_int32),               # image-type tag; the engine passes 3
        ("width", c_int32),              # width in pixels
        ("height", c_int32),             # height in pixels
        ("_reserved", c_int32),          # unused; the engine always sets 0
        ("step_size", c_int64),          # bytes per row (width * 4 in this module)
        ("data_ptr", POINTER(c_ubyte)),  # address of the first pixel byte
    ]
class _BoundingBox(Structure):
    """Quadrilateral reported by the DLL: four float corners laid out x1, y1, ..., x4, y4."""

    # Generated in corner-major order so the layout is x1, y1, x2, y2, x3, y3, x4, y4.
    _fields_ = [
        (f"{axis}{corner}", c_float)
        for corner in range(1, 5)
        for axis in ("x", "y")
    ]


# Out-parameter type for Get*BoundingBox: the DLL writes a pointer to a _BoundingBox.
_BoundingBox_p = POINTER(_BoundingBox)
# DLL function signatures: (name, argtypes, restype).
# OcrEngine._load_dll binds every entry onto the loaded DLL by setting
# ``argtypes``/``restype`` on the exported function of the same name.
# Conventions visible from the call sites in this module:
#   * functions with an int64 restype return a status code; callers treat 0 as success;
#   * results come back through out-parameters (the pointer-typed arguments,
#     always passed with ``byref``);
#   * options, pipeline, result, line and word handles are opaque int64 values;
#   * Release* functions free a handle and return nothing (restype None).
# ctypes cannot check these against the native exports, so a mismatch silently
# corrupts arguments — change entries only with care.
_DLL_FUNCTIONS: list[tuple[str, list[type], type | None]] = [
    # Init options and pipeline creation
    ("CreateOcrInitOptions", [c_int64_p], c_int64),
    ("OcrInitOptionsSetUseModelDelayLoad", [c_int64, ctypes.c_char], c_int64),
    ("CreateOcrPipeline", [c_char_p, c_char_p, c_int64, c_int64_p], c_int64),
    # Per-run process options
    ("CreateOcrProcessOptions", [c_int64_p], c_int64),
    ("OcrProcessOptionsSetMaxRecognitionLineCount", [c_int64, c_int64], c_int64),
    # Running OCR and reading the result handle
    ("RunOcrPipeline", [c_int64, POINTER(_ImageStructure), c_int64, c_int64_p], c_int64),
    ("GetImageAngle", [c_int64, c_float_p], c_int64),
    ("GetOcrLineCount", [c_int64, c_int64_p], c_int64),
    # Line accessors (content is returned as a UTF-8 char*)
    ("GetOcrLine", [c_int64, c_int64, c_int64_p], c_int64),
    ("GetOcrLineContent", [c_int64, POINTER(c_char_p)], c_int64),
    ("GetOcrLineBoundingBox", [c_int64, POINTER(_BoundingBox_p)], c_int64),
    ("GetOcrLineWordCount", [c_int64, c_int64_p], c_int64),
    # Word accessors
    ("GetOcrWord", [c_int64, c_int64, c_int64_p], c_int64),
    ("GetOcrWordContent", [c_int64, POINTER(c_char_p)], c_int64),
    ("GetOcrWordBoundingBox", [c_int64, POINTER(_BoundingBox_p)], c_int64),
    ("GetOcrWordConfidence", [c_int64, c_float_p], c_int64),
    # Handle release (void)
    ("ReleaseOcrResult", [c_int64], None),
    ("ReleaseOcrInitOptions", [c_int64], None),
    ("ReleaseOcrPipeline", [c_int64], None),
    ("ReleaseOcrProcessOptions", [c_int64], None),
]
@contextmanager
def _suppress_output():
    """Silence stdout/stderr at the file-descriptor level for the duration of the block.

    The DLL prints to the console while it initializes. That output comes from
    native code, so replacing ``sys.stdout`` would not catch it. Instead, fds 1
    and 2 are temporarily pointed at ``os.devnull`` with ``os.dup2``.

    Suppression is best-effort. If the descriptors cannot be duplicated (e.g. the
    process has no valid fd 1/2, as in a console-less GUI process), the body runs
    unsuppressed instead of failing. Whatever was redirected is always restored,
    and every descriptor opened here is closed, even if the body raises.
    """
    import sys

    # Flush Python-level buffers first. Otherwise text written before the block
    # could be flushed while the redirect is active and vanish into devnull.
    for stream in (sys.stdout, sys.stderr):
        if stream is not None:
            try:
                stream.flush()
            except (OSError, ValueError):  # closed or broken stream
                pass

    saved: list[tuple[int, int]] = []  # (redirected fd, duplicate of its original target)
    devnull = -1
    try:
        try:
            devnull = os.open(os.devnull, os.O_WRONLY)
            for target in (1, 2):
                backup = os.dup(target)
                saved.append((target, backup))
                os.dup2(devnull, target)
        except OSError:
            # Best-effort: keep whatever redirection succeeded and run the body anyway.
            pass
        yield
    finally:
        # Restore in reverse order; each backup fd is closed even if its restore fails.
        for target, backup in reversed(saved):
            try:
                os.dup2(backup, target)
            finally:
                os.close(backup)
        if devnull >= 0:
            os.close(devnull)
class OcrEngine:
    """Offline OCR engine using Windows 11 SnippingTool's oneocr.dll.

    The engine owns three native handles: init options, pipeline and process
    options. Call :meth:`close`, or use the engine as a context manager, to
    release them deterministically. ``__del__`` calls :meth:`close` as a
    best-effort fallback::

        with OcrEngine() as engine:
            result = engine.recognize_pil(pil_image)

    Args:
        ocr_data_dir: Path to directory containing oneocr.dll, oneocr.onemodel, onnxruntime.dll.
            Defaults to ``Path(__file__).parent.parent / "ocr_data"`` (PROJECT_ROOT/ocr_data/).

    Raises:
        RuntimeError: If the DLL cannot be loaded or bound, or a pipeline setup
            call returns a non-zero status.
        FileNotFoundError: If the model file is missing from ``ocr_data_dir``.
    """

    def __init__(self, ocr_data_dir: str | Path | None = None) -> None:
        if ocr_data_dir is None:
            ocr_data_dir = Path(__file__).resolve().parent.parent / "ocr_data"
        self._data_dir = str(Path(ocr_data_dir).resolve())
        self._dll: ctypes.WinDLL | None = None  # type: ignore[name-defined]
        # Opaque int64 handles, filled in through the DLL's Create* out-parameters.
        self._init_options = c_int64()
        self._pipeline = c_int64()
        self._process_options = c_int64()
        self._load_dll()
        try:
            self._initialize_pipeline()
        except BaseException:
            # Free whatever was created before the failure instead of waiting for GC.
            self.close()
            raise

    def __del__(self) -> None:
        # Best-effort fallback for callers that never call close(); finalizers must not raise.
        try:
            self.close()
        except Exception:
            pass

    def __enter__(self) -> OcrEngine:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    # --- Public API ---

    def close(self) -> None:
        """Release the native handles. Safe to call more than once.

        This also works on partially constructed instances (``__init__`` failed
        midway), and it skips handles that were never filled in (value 0). After
        closing, any recognition that reaches the native pipeline raises
        ``RuntimeError``.
        """
        dll = getattr(self, "_dll", None)
        if dll is None:
            return
        # Mark the engine closed first so a failure below can never cause a double release.
        self._dll = None
        for release_name, attr in (
            ("ReleaseOcrProcessOptions", "_process_options"),
            ("ReleaseOcrPipeline", "_pipeline"),
            ("ReleaseOcrInitOptions", "_init_options"),
        ):
            handle = getattr(self, attr, None)
            if handle is not None and handle.value:
                getattr(dll, release_name)(handle)
                handle.value = 0

    def recognize_pil(self, image: Image.Image) -> OcrResult:
        """Run OCR on a PIL Image.

        Args:
            image: PIL Image object (any mode). It is converted to RGBA and then
                serialized in the BGRA channel order the DLL expects.

        Returns:
            OcrResult with recognized text, lines, words, and confidence values.
            Returns ``OcrResult(error=...)`` if either side is outside
            50-10000 px or the pipeline reports a failure.

        Raises:
            RuntimeError: If the engine has been closed.
        """
        if any(x < 50 or x > 10000 for x in image.size):
            return OcrResult(error="Unsupported image size (must be 50-10000px)")
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        # Pillow's raw encoder emits RGBA pixels directly in BGRA byte order,
        # so no split()/merge() round trip is needed.
        bgra_bytes = image.tobytes("raw", "BGRA")
        return self._process_image(
            width=image.width,
            height=image.height,
            step=image.width * 4,
            data=bgra_bytes,
        )

    def recognize_bytes(self, image_bytes: bytes) -> OcrResult:
        """Run OCR on encoded image bytes (PNG/JPEG/etc).

        Args:
            image_bytes: Raw image file bytes in any format Pillow can decode.

        Returns:
            OcrResult, as for :meth:`recognize_pil`.
        """
        from io import BytesIO

        from PIL import Image

        # The with-block closes the decoder once recognize_pil has consumed the pixels.
        with Image.open(BytesIO(image_bytes)) as img:
            return self.recognize_pil(img)

    # --- Internal ---

    def _load_dll(self) -> None:
        """Load oneocr.dll from the data directory and bind every signature in ``_DLL_FUNCTIONS``.

        ``self._dll`` is assigned only after every function has been bound.

        Raises:
            RuntimeError: If the DLL is missing, cannot be loaded (including on
                platforms without ``ctypes.WinDLL``), or lacks an expected export.
        """
        try:
            kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)  # type: ignore[attr-defined]
            if hasattr(kernel32, "SetDllDirectoryW"):
                # Presumably so dependencies stored next to oneocr.dll (e.g.
                # onnxruntime.dll) are found. NOTE: this changes the DLL search
                # path for the whole process.
                kernel32.SetDllDirectoryW(self._data_dir)
            dll_path = os.path.join(self._data_dir, OCR_DLL_NAME)
            if not os.path.exists(dll_path):
                raise FileNotFoundError(f"OCR DLL not found: {dll_path}")
            dll = ctypes.WinDLL(dll_path)  # type: ignore[attr-defined]
            for name, argtypes, restype in _DLL_FUNCTIONS:
                func = getattr(dll, name)  # AttributeError if the export is missing
                func.argtypes = argtypes
                func.restype = restype
            self._dll = dll
        except (OSError, RuntimeError, AttributeError) as e:
            raise RuntimeError(f"Failed to load OCR DLL from {self._data_dir}: {e}") from e

    def _initialize_pipeline(self) -> None:
        """Create the init options, the pipeline (loads the model), and the process options.

        The DLL writes each handle directly into the ``c_int64`` stored on
        ``self``. This lets :meth:`close` release a partially built setup.

        Raises:
            FileNotFoundError: If the model file is missing.
            RuntimeError: If any DLL call returns a non-zero status.
        """
        dll = self._dll
        assert dll is not None  # _load_dll() has just succeeded
        # Init options
        self._check(
            dll.CreateOcrInitOptions(byref(self._init_options)),
            "CreateOcrInitOptions failed",
        )
        # 0 is converted to the c_char argument b"\x00" (delay-load flag off).
        self._check(
            dll.OcrInitOptionsSetUseModelDelayLoad(self._init_options, 0),
            "OcrInitOptionsSetUseModelDelayLoad failed",
        )
        # Pipeline (loads AI model — suppress DLL stdout noise)
        model_path = os.path.join(self._data_dir, OCR_MODEL_NAME)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"OCR model not found: {model_path}")
        # NOTE(review): the path is passed as UTF-8 bytes; confirm how the DLL
        # decodes non-ASCII paths (narrow Windows APIs expect the ANSI code page).
        model_buf = ctypes.create_string_buffer(model_path.encode())
        key_buf = ctypes.create_string_buffer(OCR_MODEL_KEY)
        with _suppress_output():
            result = dll.CreateOcrPipeline(
                model_buf, key_buf, self._init_options, byref(self._pipeline)
            )
        self._check(result, "CreateOcrPipeline failed (wrong key or corrupted model?)")
        # Process options
        self._check(
            dll.CreateOcrProcessOptions(byref(self._process_options)),
            "CreateOcrProcessOptions failed",
        )
        self._check(
            dll.OcrProcessOptionsSetMaxRecognitionLineCount(
                self._process_options, OCR_MAX_LINES
            ),
            "OcrProcessOptionsSetMaxRecognitionLineCount failed",
        )

    def _process_image(self, width: int, height: int, step: int, data: bytes) -> OcrResult:
        """Wrap BGRA pixel bytes in an image structure, run the pipeline and parse the result.

        The native result handle is always released, even if parsing raises.

        Raises:
            RuntimeError: If the engine has been closed.
        """
        dll = self._dll
        if dll is None:
            raise RuntimeError("OcrEngine has been closed")
        # Copy into a ctypes array. The structure keeps a reference to it, so the
        # buffer stays alive for the duration of the native call.
        pixels = (c_ubyte * len(data)).from_buffer_copy(data)
        img_struct = _ImageStructure(
            type=3,  # image-type tag the DLL is given for these 4-byte pixels (TODO confirm meaning)
            width=width,
            height=height,
            _reserved=0,
            step_size=step,
            data_ptr=pixels,
        )
        ocr_result = c_int64()
        if dll.RunOcrPipeline(
            self._pipeline, byref(img_struct), self._process_options, byref(ocr_result)
        ) != 0:
            return OcrResult(error="RunOcrPipeline returned non-zero")
        try:
            return self._parse_results(ocr_result)
        finally:
            dll.ReleaseOcrResult(ocr_result)

    def _parse_results(self, ocr_result: c_int64) -> OcrResult:
        """Extract the text angle, lines and words from a DLL result handle.

        Lines the DLL fails to return are skipped. The full text joins the
        non-empty line texts with newlines.
        """
        dll = self._dll
        assert dll is not None
        line_count = c_int64()
        if dll.GetOcrLineCount(ocr_result, byref(line_count)) != 0:
            return OcrResult(error="GetOcrLineCount failed")
        lines: list[OcrLine] = []
        for idx in range(line_count.value):
            line = self._parse_line(ocr_result, idx)
            if line:
                lines.append(line)
        # Text angle (left as None if the DLL reports an error)
        text_angle_val = c_float()
        text_angle: float | None = None
        if dll.GetImageAngle(ocr_result, byref(text_angle_val)) == 0:
            text_angle = text_angle_val.value
        full_text = "\n".join(line.text for line in lines if line.text)
        return OcrResult(text=full_text, text_angle=text_angle, lines=lines)

    def _parse_line(self, ocr_result: c_int64, line_index: int) -> OcrLine | None:
        """Parse line ``line_index`` of a result; return None if the DLL yields no line."""
        dll = self._dll
        assert dll is not None
        line_handle = c_int64()
        if dll.GetOcrLine(ocr_result, line_index, byref(line_handle)) != 0:
            return None
        if not line_handle.value:
            return None
        # Line text (UTF-8 from the DLL; undecodable bytes are dropped)
        content = c_char_p()
        line_text = ""
        if dll.GetOcrLineContent(line_handle, byref(content)) == 0 and content.value:
            line_text = content.value.decode("utf-8", errors="ignore")
        # Line bounding box
        line_bbox = self._get_bbox(line_handle, dll.GetOcrLineBoundingBox)
        # Words
        word_count = c_int64()
        words: list[OcrWord] = []
        if dll.GetOcrLineWordCount(line_handle, byref(word_count)) == 0:
            for wi in range(word_count.value):
                word = self._parse_word(line_handle, wi)
                if word:
                    words.append(word)
        return OcrLine(text=line_text, bounding_rect=line_bbox, words=words)

    def _parse_word(self, line_handle: c_int64, word_index: int) -> OcrWord | None:
        """Parse word ``word_index`` of a line; return None if the DLL yields no word."""
        dll = self._dll
        assert dll is not None
        word_handle = c_int64()
        if dll.GetOcrWord(line_handle, word_index, byref(word_handle)) != 0:
            return None
        if not word_handle.value:
            # Same guard as _parse_line: never pass a null handle back into the DLL.
            return None
        # Word text
        content = c_char_p()
        word_text = ""
        if dll.GetOcrWordContent(word_handle, byref(content)) == 0 and content.value:
            word_text = content.value.decode("utf-8", errors="ignore")
        # Word bounding box
        word_bbox = self._get_bbox(word_handle, dll.GetOcrWordBoundingBox)
        # Word confidence (0.0 if the DLL reports an error)
        confidence_val = c_float()
        confidence = 0.0
        if dll.GetOcrWordConfidence(word_handle, byref(confidence_val)) == 0:
            confidence = confidence_val.value
        return OcrWord(text=word_text, bounding_rect=word_bbox, confidence=confidence)

    @staticmethod
    def _get_bbox(handle: c_int64, bbox_fn: object) -> BoundingRect | None:
        """Read a bounding box via ``bbox_fn``; return None on error or a null pointer."""
        bbox_ptr = _BoundingBox_p()
        if bbox_fn(handle, byref(bbox_ptr)) == 0 and bbox_ptr:  # type: ignore[operator]
            bb = bbox_ptr.contents
            return BoundingRect(
                x1=bb.x1, y1=bb.y1, x2=bb.x2, y2=bb.y2,
                x3=bb.x3, y3=bb.y3, x4=bb.x4, y4=bb.y4,
            )
        return None

    @staticmethod
    def _check(result_code: int, msg: str) -> None:
        """Raise RuntimeError with ``msg`` and the code if a DLL status is non-zero."""
        if result_code != 0:
            raise RuntimeError(f"{msg} (code: {result_code})")