Spaces:
Running
Running
| from __future__ import annotations | |
| import io | |
| import threading | |
| from typing import Union | |
| from urllib.parse import urlparse | |
| import numpy as np | |
| from PIL import Image | |
| from app.config import get_settings | |
| from app.core.logger import get_logger | |
# Process-wide logger and application settings, both resolved once when the
# module is imported.
_logger = get_logger(__name__)
_settings = get_settings()

# Shared OCR engine, created lazily by _get_engine() on first use, and the
# lock that serialises that one-time construction across threads.
_engine = None
_lock = threading.Lock()
def _get_engine():
    """Return the process-wide RapidOCR engine, constructing it on first use.

    Uses double-checked locking: an already-built engine is returned without
    touching the lock; otherwise the lock is taken and the ``None`` check is
    repeated so that concurrent first callers build only one engine. The
    per-stage (Det / Cls / Rec) CUDA and DirectML switches come from the
    application settings.
    """
    global _engine
    if _engine is not None:
        return _engine

    with _lock:
        if _engine is None:
            # Deferred import: rapidocr_onnxruntime is not imported until an
            # engine is actually needed.
            from rapidocr_onnxruntime import RapidOCR

            det_opts = {"use_cuda": _settings.ocr_det_cuda, "use_dml": _settings.ocr_det_dml}
            cls_opts = {"use_cuda": _settings.ocr_cls_cuda, "use_dml": _settings.ocr_cls_dml}
            rec_opts = {"use_cuda": _settings.ocr_rec_cuda, "use_dml": _settings.ocr_rec_dml}
            _engine = RapidOCR(
                Det=det_opts,
                Cls=cls_opts,
                Rec=rec_opts,
                print_verbose=False,
            )
    return _engine
# PIL modes passed to the OCR engine unchanged; any other mode (palette,
# CMYK, 1-bit, LA, 16-bit, ...) is converted to RGB first.
_PASSTHROUGH_MODES = frozenset({"RGB", "L", "RGBA"})


def _pil_to_numpy(img) -> np.ndarray:
    """Convert a PIL image to an ndarray, normalising unusual modes to RGB.

    NOTE(review): the resulting array is in PIL's RGB channel order; confirm
    the installed RapidOCR version does not assume BGR for raw ndarrays.
    """
    if img.mode not in _PASSTHROUGH_MODES:
        img = img.convert("RGB")
    return np.array(img)


def _decode_image_bytes(data) -> np.ndarray:
    """Decode encoded image bytes (PNG, JPEG, ...) into an ndarray."""
    # The context manager releases PIL's handle once the pixels are copied.
    with Image.open(io.BytesIO(data)) as img:
        return _pil_to_numpy(img)


def _to_numpy(source) -> Union[np.ndarray, str]:
    """Normalise an OCR input into a form the OCR engine accepts.

    Accepted inputs:

    * ``numpy.ndarray``: returned unchanged.
    * ``str``: an ``http``/``https`` URL is downloaded and decoded to an
      ndarray. Any other string is returned as-is; presumably a local path
      the engine opens itself.
    * ``bytes`` / ``bytearray`` / ``memoryview``: decoded with PIL.
    * ``PIL.Image.Image``: converted to an ndarray.

    Decoded images whose mode is not RGB, L or RGBA are converted to RGB.

    Raises:
        TypeError: for any other input type.
        Errors from ``httpx`` (including a non-2xx status) and from PIL
        decoding propagate unchanged.
    """
    if isinstance(source, np.ndarray):
        return source

    if isinstance(source, str):
        if urlparse(source).scheme not in {"http", "https"}:
            return source
        import httpx  # only needed when a remote image is requested

        # NOTE(review): no response-size cap or host allow-list. If URLs can
        # come from untrusted callers, this is an SSRF / memory risk.
        resp = httpx.get(source, follow_redirects=True, timeout=30)
        resp.raise_for_status()
        return _decode_image_bytes(resp.content)

    if isinstance(source, (bytes, bytearray, memoryview)):
        return _decode_image_bytes(source)

    if isinstance(source, Image.Image):
        return _pil_to_numpy(source)

    raise TypeError(
        f"ocr_image expects bytes, str, numpy.ndarray or PIL.Image; got {type(source).__name__}"
    )
def ocr_image(
    source,
    *,
    use_det: bool = True,
    use_cls: bool = True,
    use_rec: bool = True,
    text_score: float = 0.5,
) -> str:
    """Run OCR on one image and return the recognised text.

    Args:
        source: Any input ``_to_numpy`` accepts: ndarray, PIL image, encoded
            image bytes, an http(s) URL, or a local path string.
        use_det, use_cls, use_rec: Stage toggles forwarded unchanged to the
            engine. With ``use_rec=False`` nothing is recognised, so the
            result is always ``""``.
        text_score: Score threshold forwarded unchanged to the engine.

    Returns:
        Recognised text fragments in engine order, one per line, or ``""``
        when nothing was recognised.
    """
    engine = _get_engine()
    img = _to_numpy(source)
    result, _ = engine(
        img,
        use_det=use_det,
        use_cls=use_cls,
        use_rec=use_rec,
        text_score=text_score,
    )
    if not result or not use_rec:
        # Without the recognition stage the engine's items carry no text.
        return ""

    lines = []
    for item in result:
        # NOTE(review): per rapidocr_onnxruntime, items are [box, text, score]
        # with detection and [text, score] without it, so a fixed item[1]
        # would pick up the float score. Taking the first str field handles
        # both layouts; confirm against the installed version.
        text = next((field for field in item if isinstance(field, str)), "")
        if text:
            lines.append(text)
    return "\n".join(lines)
def _ocr_pdf_page(pdf, page_index: int, scale: float) -> str:
    """Render one PDF page at *scale*, OCR it, and release its native handles."""
    page = pdf[page_index]
    try:
        bitmap = page.render(scale=scale, rotation=0)
        try:
            # to_pil() may wrap the bitmap's buffer without copying, so OCR
            # must finish before the bitmap is closed.
            return ocr_image(bitmap.to_pil())
        finally:
            bitmap.close()
    finally:
        page.close()


def ocr_pdf(source: Union[str, bytes], *, dpi: int = 150) -> str:
    """OCR every page of a PDF and return the combined text.

    Each page is rasterised with pypdfium2 and passed through
    :func:`ocr_image`. Pages yielding no text are skipped; the remaining
    page texts are separated by a blank line.

    This is best-effort. If pypdfium2 is missing, or anything fails while
    opening, rendering or recognising the document, the error is logged and
    ``""`` is returned instead of raising.

    Args:
        source: PDF file path or raw PDF bytes, passed to ``PdfDocument``.
        dpi: Render resolution. PDF user space is 72 units per inch, hence a
            render scale of ``dpi / 72``.
    """
    try:
        import pypdfium2 as pdfium
    except ImportError:
        _logger.error("pypdfium2 not installed")
        return ""

    try:
        scale = dpi / 72.0
        pdf = pdfium.PdfDocument(source)
        try:
            page_texts: list[str] = []
            for page_index in range(len(pdf)):
                page_text = _ocr_pdf_page(pdf, page_index, scale)
                if page_text:
                    page_texts.append(page_text)
        finally:
            # Close the document even when a page fails (it previously leaked).
            pdf.close()
        return "\n\n".join(page_texts)
    except Exception as exc:
        # Deliberate best-effort boundary; keep the traceback for diagnosis.
        _logger.exception("Failed to OCR PDF: %s", exc)
        return ""
class OCRService:
    """Object-style facade over :func:`ocr_image` and :func:`ocr_pdf`.

    Creating an instance calls ``_get_engine()``, so the shared OCR engine is
    built when the service is constructed rather than on its first request.
    Every instance uses the same process-wide engine.
    """

    def __init__(self) -> None:
        # Eagerly build (or fetch) the shared engine. The handle is kept on the
        # instance, but the methods below go through the module functions,
        # which look the engine up themselves.
        self._engine = _get_engine()

    def image_to_text(self, source, text_score: float = 0.5) -> str:
        """Return the text recognised in one image.

        *source* accepts the same inputs as :func:`ocr_image`. *text_score* is
        forwarded unchanged; all stage toggles keep ``ocr_image``'s defaults.
        """
        recognised = ocr_image(source, text_score=text_score)
        return recognised

    def pdf_to_text(self, source: Union[str, bytes], dpi: int = 150) -> str:
        """Return the OCR text of every page of a PDF (path or raw bytes).

        Delegates to :func:`ocr_pdf` at the given render *dpi*; like that
        function, it yields ``""`` on failure rather than raising.
        """
        recognised = ocr_pdf(source, dpi=dpi)
        return recognised