custom_padleocr / handler.py
Gurdaan Walia
Optimize: use predict() for native GPU batching, improve result parsing
246af2f
Raw
History Blame Contribute Delete
6.42 kB
"""
HuggingFace Inference Endpoint custom handler for PaddleOCR.
Conforms to the HF Inference Toolkit EndpointHandler contract:
- __init__(path="") called once at startup
- __call__(data) called per request; data always contains "inputs" key
Supports two call modes determined by the shape of `data["inputs"]`:
Single image:
{ "inputs": "<base64-string>" }
Returns: { "results": [["text", confidence], ...] }
Batch images (send ALL tiles for a page in one call for maximum GPU throughput):
{ "inputs": [{"id": "<any>", "image_base64": "<base64-string>"}, ...] }
Returns: { "results": {"<id>": [["text", confidence], ...], ...} }
Performance note:
PaddleOCR 3.x predict() accepts a list of numpy arrays and processes them
as a single GPU batch — dramatically faster than calling it per-image.
Always prefer one batch call per page over multiple single calls.
"""
from __future__ import annotations
import base64
import io
import logging
from typing import Any, Dict, List, Tuple
# Install langchain shim BEFORE paddleocr / paddlex ever gets imported.
import _shim # noqa: F401
import numpy as np
from PIL import Image
# Configure the root logger at import time: INFO level, with a timestamp and
# the level name in front of every message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Named logger used by EndpointHandler for startup, decode and OCR messages.
logger = logging.getLogger("paddleocr-handler")
def _decode_image(image_base64: str) -> np.ndarray:
    """
    Turn a base64 image payload into an RGB numpy array.

    The payload may be bare base64 or a full data URI
    ("data:<mime>;base64,<payload>"); for a data URI everything up to and
    including the first comma is dropped before decoding. The decoded bytes
    are opened with PIL and converted to RGB before being handed to numpy.
    """
    payload = image_base64
    if "," in payload:
        # Only treat the comma as a header separator for real data URIs.
        if payload.strip().startswith("data:"):
            _header, payload = payload.split(",", 1)

    buffer = io.BytesIO(base64.b64decode(payload))
    rgb_image = Image.open(buffer).convert("RGB")
    return np.array(rgb_image)
def _parse_single_result(page_result) -> List[Tuple[str, float]]:
    """
    Normalize one page's PaddleOCR output to a flat list of (text, confidence).

    Two input shapes are handled:
    - dict (PaddleOCR 3.x predict() per-image result): parallel sequences under
      the "rec_texts" and "rec_scores" keys, paired up positionally.
    - list (legacy ocr() page): [[box, (text, conf)], ...]. When ocr() returned
      [[...]], the caller is expected to pass raw[0].

    Returns an empty list for None/empty input or when either dict key is
    missing. Legacy lines that do not have the expected shape are skipped.
    """
    if not page_result:
        return []

    if isinstance(page_result, dict):
        texts = page_result.get("rec_texts")
        scores = page_result.get("rec_scores")
        # Compare against None instead of using `value or []`: truth-testing a
        # multi-element numpy array raises ValueError, which would make the
        # caller believe the whole predict() batch had failed.
        if texts is None or scores is None:
            return []
        return [(str(text), float(score)) for text, score in zip(texts, scores)]

    results: List[Tuple[str, float]] = []
    for line in page_result:
        try:
            text_part = line[1]
            results.append((str(text_part[0]), float(text_part[1])))
        except (IndexError, TypeError, ValueError):
            # Malformed or empty line entry (e.g. None): skip it, keep the rest.
            continue
    return results
class EndpointHandler:
    """
    HF Inference Endpoint handler that runs PaddleOCR on base64-encoded images.

    The OCR engine is created once in __init__ and reused by every __call__.
    """

    def __init__(self, path: str = ""):
        """
        Called once when the endpoint starts. Loads the PaddleOCR engine.

        `path` is accepted for the handler contract and is not used here.
        """
        from paddleocr import PaddleOCR

        logger.info("Initializing PaddleOCR engine (GPU)...")
        try:
            # Legacy keyword first, so installs where it already worked keep
            # exactly the same construction call.
            self._ocr = PaddleOCR(lang="en", use_textline_orientation=True, use_gpu=True)
        except (TypeError, ValueError) as exc:
            # PaddleOCR 3.x rejects `use_gpu` as an unknown argument and
            # selects hardware through `device` instead.
            logger.warning(f"PaddleOCR rejected use_gpu ({exc}); retrying with device='gpu'")
            self._ocr = PaddleOCR(lang="en", use_textline_orientation=True, device="gpu")
        logger.info("PaddleOCR engine ready.")

    def _ocr_single(self, img: np.ndarray) -> List[Tuple[str, float]]:
        """
        Run the per-image ocr() API on one image.

        Tries ocr(img, cls=True) first and retries without `cls` when the
        engine raises TypeError for it. Any other engine failure is logged and
        yields an empty result so one bad image cannot abort a whole batch.
        """
        ocr = self._ocr
        try:
            try:
                raw = ocr.ocr(img, cls=True)
            except TypeError:
                raw = ocr.ocr(img)
        except Exception as exc:
            logger.warning(f"Per-image ocr() failed: {exc}")
            return []
        # ocr() wraps its output in an outer list; the first entry is the page.
        page = raw[0] if raw else None
        return _parse_single_result(page)

    def _run_batch(self, images: List[np.ndarray]) -> List[List[Tuple[str, float]]]:
        """
        Run OCR on a list of images, preferring a single predict() call.

        Falls back to per-image ocr() when predict() is unavailable, raises,
        or returns a different number of results than images (which would
        otherwise misalign results and inputs).

        Returns a list of parsed results, one per input image, in input order.
        """
        ocr = self._ocr

        # --- Fast path: hand the whole list to predict() in one call ---
        if hasattr(ocr, "predict"):
            try:
                raw_batch = ocr.predict(images)
                parsed = [_parse_single_result(r) for r in raw_batch]
            except Exception as exc:
                logger.warning(f"predict() batch failed, falling back to per-image ocr(): {exc}")
            else:
                if len(parsed) == len(images):
                    return parsed
                logger.warning(
                    f"predict() returned {len(parsed)} results for {len(images)} images, "
                    "falling back to per-image ocr()"
                )

        # --- Fallback: ocr() called per image ---
        return [self._ocr_single(img) for img in images]

    def _handle_batch(self, inputs: List[Any]) -> Dict[str, Any]:
        """
        Decode every batch item, OCR the decodable ones together, and map the
        results back to their ids.

        Each item is either {"id": ..., "image_base64": "..."} (id defaults to
        the item's position as a string) or a bare base64 string (id is the
        position). Items that fail to decode get an empty result list, and the
        failure message is reported under "errors", which is only present when
        at least one item failed.
        """
        ids: List[Any] = []
        arrays: List[Any] = []  # np.ndarray per item, or None where decoding failed
        decode_errors: Dict[Any, str] = {}

        for i, item in enumerate(inputs):
            item_id = item.get("id", str(i)) if isinstance(item, dict) else str(i)
            ids.append(item_id)
            try:
                payload = item["image_base64"] if isinstance(item, dict) else item
                if not isinstance(payload, str):
                    raise TypeError(f"expected a base64 string, got {type(payload).__name__}")
                arrays.append(_decode_image(payload))
            except Exception as exc:
                logger.warning(f"Decode error for id={item_id}: {exc}")
                decode_errors[item_id] = str(exc)
                arrays.append(None)  # placeholder to keep index alignment

        # Only successfully decoded images go to the engine.
        valid_indices = [i for i, arr in enumerate(arrays) if arr is not None]
        batch_results = self._run_batch([arrays[i] for i in valid_indices]) if valid_indices else []
        results_by_index = dict(zip(valid_indices, batch_results))

        # Walk the ids in request order so a repeated id resolves to its last
        # occurrence; failed or missing entries become an empty list.
        out: Dict[Any, Any] = {}
        for i, item_id in enumerate(ids):
            out[item_id] = results_by_index.get(i, [])

        response: Dict[str, Any] = {"results": out}
        if decode_errors:
            response["errors"] = decode_errors
        return response

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Route to single or batch processing based on the shape of `inputs`.

        Single: data = {"inputs": "<base64>"}
            Returns {"results": [(text, confidence), ...]}; on failure returns
            {"error": "<message>", "results": []}.
        Batch: data = {"inputs": [{"id": "...", "image_base64": "..."}, ...]}
            Returns {"results": {id: [(text, confidence), ...], ...}}, plus an
            "errors" mapping of id -> message when some items failed to decode.

        If `data` has no "inputs" key, `data` itself is treated as the input.
        """
        inputs = data.get("inputs", data)

        # ---- Batch mode: decode all images then run them as one batch ----
        if isinstance(inputs, list):
            return self._handle_batch(inputs)

        # ---- Single mode ----
        try:
            img_array = _decode_image(str(inputs))
            results = self._run_batch([img_array])
            return {"results": results[0] if results else []}
        except Exception as exc:
            logger.exception(f"OCR error: {exc}")
            return {"error": str(exc), "results": []}