"""
HuggingFace Inference Endpoint custom handler for PaddleOCR.
Conforms to the HF Inference Toolkit EndpointHandler contract:
- __init__(path="") called once at startup
- __call__(data) called per request; data always contains "inputs" key
Supports two call modes determined by the shape of `data["inputs"]`:
Single image:
{ "inputs": "<base64-string>" }
Returns: { "results": [["text", confidence], ...] }
On failure: { "error": "<message>", "results": [] }
Batch images (send ALL tiles for a page in one call for maximum GPU throughput):
{ "inputs": [{"id": "<any>", "image_base64": "<base64-string>"}, ...] }
Returns: { "results": {"<id>": [["text", confidence], ...], ...} }
Performance note:
PaddleOCR 3.x predict() accepts a list of numpy arrays and processes them
as a single GPU batch — dramatically faster than calling it per-image.
Always prefer one batch call per page over multiple single calls.
"""
from __future__ import annotations
import base64
import io
import logging
from typing import Any, Dict, List, Tuple
# Install langchain shim BEFORE paddleocr / paddlex ever gets imported.
import _shim # noqa: F401
import numpy as np
from PIL import Image
# Root logging: INFO and above, each line timestamped and tagged with its level.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Named logger used by the EndpointHandler class for its status/warning output.
logger = logging.getLogger("paddleocr-handler")
def _decode_image(image_base64: str) -> np.ndarray:
    """
    Decode a base64-encoded PNG/JPEG string into an RGB numpy array.

    Accepts either a bare base64 payload or a ``data:<mime>;base64,<payload>``
    data URL, in which case everything up to the first comma is discarded.
    Whitespace anywhere in the payload is ignored and missing ``=`` padding
    is restored, since some clients line-wrap the payload or strip padding.

    NOTE(review): PaddleOCR conventionally treats ndarray input as BGR
    (OpenCV channel order) while this returns RGB — confirm whether the
    channels should be reversed before inference.

    Raises:
        binascii.Error: the payload is not decodable base64.
        PIL.UnidentifiedImageError: the decoded bytes are not a readable image.
    """
    payload = "".join(image_base64.split())
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    # b64decode raises "Incorrect padding" on unpadded input; in its default
    # (non-validating) mode surplus "=" is ignored, so pad to a multiple of 4.
    payload += "=" * (-len(payload) % 4)
    raw = base64.b64decode(payload)
    # Close the source image promptly; convert() returns an independent,
    # fully loaded copy, so it stays valid after the with-block exits.
    with Image.open(io.BytesIO(raw)) as img:
        rgb = img.convert("RGB")
    return np.array(rgb)
def _parse_single_result(page_result) -> List[Tuple[str, float]]:
    """
    Normalize one page's PaddleOCR output to a flat list of (text, confidence).

    Two shapes are understood:

    * PaddleOCR 3.x ``predict()`` result: a dict (or dict subclass) holding
      parallel ``"rec_texts"`` and ``"rec_scores"`` sequences.
    * Legacy 2.x ``ocr()`` page: a list of ``[box, (text, confidence)]`` lines.
      ``ocr()`` returns one such list per image, so callers pass ``raw[0]``.

    Entries whose text/score cannot be read are skipped rather than failing
    the whole page. A falsy page (``None``, ``[]``, ``{}``) yields ``[]``.
    """
    results: List[Tuple[str, float]] = []
    if not page_result:
        return results

    if isinstance(page_result, dict):
        # Explicit None checks instead of `value or []`: if either field is a
        # numpy array, its truth value is ambiguous and `or` would raise.
        texts = page_result.get("rec_texts")
        scores = page_result.get("rec_scores")
        if texts is None or scores is None:
            return results
        for text, score in zip(texts, scores):
            try:
                results.append((str(text), float(score)))
            except (TypeError, ValueError):
                continue  # skip an unreadable score, as the legacy path does
        return results

    # Legacy: [[box, (text, conf)], ...]
    for line in page_result:
        try:
            text_part = line[1]
            results.append((str(text_part[0]), float(text_part[1])))
        except (IndexError, TypeError, ValueError):
            continue
    return results
class EndpointHandler:
    """
    HF Inference Toolkit entry point wrapping one PaddleOCR engine.

    The engine is built once in ``__init__`` and reused for every request;
    ``__call__`` handles either a single base64 image or a list of images.
    """

    def __init__(self, path: str = ""):
        """
        Load the PaddleOCR engine once, when the endpoint starts.

        Args:
            path: path passed by the HF toolkit at startup; unused here.
        """
        from paddleocr import PaddleOCR

        logger.info("Initializing PaddleOCR engine (GPU)...")
        if hasattr(PaddleOCR, "predict"):
            # PaddleOCR 3.x (same `predict` detection as _run_batch): the
            # device is chosen with `device=`; the 2.x `use_gpu` flag is
            # rejected as an unknown argument.
            self._ocr = PaddleOCR(lang="en", use_textline_orientation=True, device="gpu")
        else:
            # PaddleOCR 2.x: text-orientation classification is `use_angle_cls`;
            # the 3.x name `use_textline_orientation` would be silently ignored.
            self._ocr = PaddleOCR(lang="en", use_angle_cls=True, use_gpu=True)
        logger.info("PaddleOCR engine ready.")

    def _run_batch(self, images: List[np.ndarray]) -> List[List[Tuple[str, float]]]:
        """
        OCR ``images`` and return one list of (text, confidence) per image,
        in input order.

        Prefers a single ``predict(images)`` call, which hands the whole list
        to the engine at once. If ``predict`` is missing or raises, falls back
        to one ``ocr()`` call per image; a failure on an individual image is
        logged and yields ``[]`` for that image only.
        """
        if not images:
            return []
        ocr = self._ocr

        # --- Fast path: predict() accepts a list and returns one result per image ---
        if hasattr(ocr, "predict"):
            try:
                return [_parse_single_result(r) for r in ocr.predict(images)]
            except Exception as exc:
                logger.warning(
                    "predict() batch failed, falling back to per-image ocr(): %s", exc
                )

        # --- Fallback: ocr() called per image ---
        return [_parse_single_result(self._ocr_one(img)) for img in images]

    def _ocr_one(self, img: np.ndarray) -> Any:
        """
        Run ``ocr()`` on one image and return its page result, or None.

        ``cls=True`` is tried first; if ``ocr()`` rejects it with a TypeError
        the call is retried without it. Any other failure, including one in
        the retry, is logged and mapped to None so one bad image cannot
        abort the whole batch.
        """
        try:
            try:
                raw = self._ocr.ocr(img, cls=True)
            except TypeError:
                raw = self._ocr.ocr(img)
        except Exception:
            logger.exception("Per-image ocr() failed")
            return None
        # ocr() returns a list with one page per input image.
        return raw[0] if raw else None

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Route to single or batch processing based on the type of ``inputs``.

        Single: ``{"inputs": "<base64>"}`` returns
            ``{"results": [[text, conf], ...]}``, or
            ``{"error": "<message>", "results": []}`` on failure.
        Batch: ``{"inputs": [{"id": "...", "image_base64": "..."}, ...]}``
            returns ``{"results": {id: [[text, conf], ...], ...}}``. Items that
            fail to decode map to ``[]`` and are reported under
            ``"errors": {id: message}``. A failure of the OCR call itself adds
            a top-level ``"error"``. Bare base64 strings are also accepted
            as list items, keyed by their position.

        If ``data`` has no ``"inputs"`` key, ``data`` itself is used.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, list):
            return self._call_batch(inputs)
        return self._call_single(inputs)

    def _call_batch(self, items: List[Any]) -> Dict[str, Any]:
        """Decode all items, OCR the decodable ones in one batch, key results by id."""
        ids: List[Any] = []
        arrays: List[Any] = []  # decoded ndarray, or None where decoding failed
        decode_errors: Dict[Any, str] = {}

        for index, item in enumerate(items):
            is_record = isinstance(item, dict)
            item_id = item.get("id", str(index)) if is_record else str(index)
            ids.append(item_id)
            try:
                payload = item["image_base64"] if is_record else item
                arrays.append(_decode_image(payload))
            except Exception as exc:
                logger.warning("Decode error for id=%s: %s", item_id, exc)
                decode_errors[item_id] = str(exc)
                arrays.append(None)  # placeholder keeps arrays aligned with ids

        out: Dict[Any, Any] = {}
        response: Dict[str, Any] = {"results": out}

        valid = [arr for arr in arrays if arr is not None]
        try:
            batch_results = self._run_batch(valid) if valid else []
        except Exception as exc:
            # Report in-band, as single mode does, instead of failing the request.
            logger.exception("Batch OCR error: %s", exc)
            response["error"] = str(exc)
            batch_results = []

        # Results arrive in the order of `valid`; consume them in step with the
        # successfully decoded slots so each id gets its own result.
        result_iter = iter(batch_results)
        for item_id, arr in zip(ids, arrays):
            out[item_id] = [] if arr is None else next(result_iter, [])

        if decode_errors:
            response["errors"] = decode_errors
        return response

    def _call_single(self, payload: Any) -> Dict[str, Any]:
        """OCR one base64 image; failures are reported in-band under ``error``."""
        try:
            img_array = _decode_image(str(payload))
            results = self._run_batch([img_array])
            return {"results": results[0] if results else []}
        except Exception as exc:
            logger.exception("OCR error: %s", exc)
            return {"error": str(exc), "results": []}
|