"""OCR vision stack only (no LLM). Used by Celery OCR worker and composed by ``agents.ocr_agent``.""" from __future__ import annotations import logging import os import uuid from typing import Any, Dict, List, Optional, Tuple import cv2 import numpy as np from .compat import allow_ultralytics_weights logger = logging.getLogger(__name__) _OCR_MAX_EDGE = 2000 _CROP_PAD = 4 class OcrVisionPipeline: """ Hybrid pipeline: 1. YOLO for layout analysis (weights preloaded; layout path reserved). 2. PaddleOCR for Vietnamese text extraction. 3. Pix2Tex for LaTeX formula extraction. """ def __init__(self) -> None: logger.info("[OcrVisionPipeline] Initializing engines...") try: from ultralytics import YOLO allow_ultralytics_weights() logger.info("[OcrVisionPipeline] Loading YOLO...") self.layout_model = YOLO("yolov8n.pt") logger.info("[OcrVisionPipeline] YOLO initialized.") except Exception as e: logger.error("[OcrVisionPipeline] YOLO init failed: %s", e) self.layout_model = None try: from paddleocr import PaddleOCR logger.info("[OcrVisionPipeline] Loading PaddleOCR...") self.text_model = PaddleOCR(use_angle_cls=True, lang="vi") logger.info("[OcrVisionPipeline] PaddleOCR (vi) initialized.") except Exception as e: logger.error("[OcrVisionPipeline] PaddleOCR init failed: %s", e) self.text_model = None try: from pix2tex.cli import LatexOCR logger.info("[OcrVisionPipeline] Loading Pix2Tex...") self.math_model = LatexOCR() logger.info("[OcrVisionPipeline] Pix2Tex initialized.") except Exception as e: logger.error("[OcrVisionPipeline] Pix2Tex init failed: %s", e) self.math_model = None def _preprocess_image_for_ocr(self, src_path: str) -> Tuple[str, bool]: """Resize large images, CLAHE on luminance; returns path (may be new temp file).""" img = cv2.imread(src_path, cv2.IMREAD_COLOR) if img is None: g = cv2.imread(src_path, cv2.IMREAD_GRAYSCALE) if g is None: logger.warning("[OcrVisionPipeline] OpenCV could not read %s; using original.", src_path) return src_path, False img = cv2.cvtColor(g, cv2.COLOR_GRAY2BGR) h, w = img.shape[:2] max_dim = max(h, w) if max_dim > _OCR_MAX_EDGE: scale = _OCR_MAX_EDGE / max_dim img = cv2.resize( img, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_AREA ) logger.info("[OcrVisionPipeline] Resized for OCR to max edge %s", _OCR_MAX_EDGE) gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray) den = cv2.fastNlMeansDenoising(gray, None, 8, 7, 21) out = f"temp_ocr_prep_{uuid.uuid4().hex}.png" cv2.imwrite(out, den) return out, True def _load_bgr_for_crops(self, path: str) -> Optional[np.ndarray]: im = cv2.imread(path, cv2.IMREAD_COLOR) if im is None: g = cv2.imread(path, cv2.IMREAD_GRAYSCALE) if g is None: return None im = cv2.cvtColor(g, cv2.COLOR_GRAY2BGR) return im def _crop_from_quad(self, img_bgr: np.ndarray, bbox) -> Optional[np.ndarray]: try: pts = np.array(bbox, dtype=np.float32) xs = pts[:, 0] ys = pts[:, 1] H, W = img_bgr.shape[:2] x1 = max(0, int(xs.min()) - _CROP_PAD) y1 = max(0, int(ys.min()) - _CROP_PAD) x2 = min(W, int(xs.max()) + _CROP_PAD) y2 = min(H, int(ys.max()) + _CROP_PAD) if x2 <= x1 or y2 <= y1: return None return img_bgr[y1:y2, x1:x2].copy() except Exception as e: logger.debug("[OcrVisionPipeline] crop failed: %s", e) return None def _latex_from_crop_bgr(self, crop_bgr: np.ndarray) -> Optional[str]: if self.math_model is None or crop_bgr is None or crop_bgr.size == 0: return None ch, cw = crop_bgr.shape[:2] if ch < 10 or cw < 10: return None try: from PIL import Image rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB) pil = Image.fromarray(rgb) out = self.math_model(pil) if isinstance(out, str) and out.strip(): return out.strip() except Exception as e: logger.debug("[OcrVisionPipeline] Pix2Tex on crop failed: %s", e) return None def _maybe_math_from_crop(self, img_bgr: Optional[np.ndarray], bbox, text: str) -> str: if img_bgr is None or not self.math_model: return text is_math_hint = any( c in text for c in ["\\", "^", "_", "{", "}", "=", "+", "-", "*", "/"] ) if not is_math_hint: return text crop = self._crop_from_quad(img_bgr, bbox) latex = self._latex_from_crop_bgr(crop) if crop is not None else None if latex: logger.info("[OcrVisionPipeline] Pix2Tex replaced line fragment (len=%s)", len(latex)) return f"${latex}$" return text async def process_image(self, image_path: str) -> str: """Return assembled raw OCR text (no LLM).""" logger.info("==[OcrVisionPipeline] Processing: %s==", image_path) if not os.path.exists(image_path): return f"Error: File {image_path} not found." prep_path, prep_cleanup = self._preprocess_image_for_ocr(image_path) paddle_path = prep_path if prep_cleanup else image_path img_bgr = self._load_bgr_for_crops(prep_path if prep_cleanup else image_path) raw_fragments: List[Dict[str, Any]] = [] try: if self.text_model: logger.info("[OcrVisionPipeline] Running PaddleOCR on %s...", paddle_path) result = self.text_model.ocr(paddle_path) logger.info("[OcrVisionPipeline] PaddleOCR raw result: %s", result) if not result: logger.warning("[OcrVisionPipeline] PaddleOCR returned no results.") return "" if isinstance(result[0], dict): res_dict = result[0] rec_texts = res_dict.get("rec_texts", []) rec_scores = res_dict.get("rec_scores", []) rec_polys = res_dict.get("rec_polys", []) for i in range(len(rec_texts)): text = rec_texts[i] bbox = rec_polys[i] score = rec_scores[i] if i < len(rec_scores) else None if score is not None and float(score) < 0.45: logger.debug( "[OcrVisionPipeline] Low-confidence line (score=%s): %s", score, text[:80], ) y_top = int(min(p[1] for p in bbox)) if hasattr(bbox, "__iter__") else 0 content = self._maybe_math_from_crop(img_bgr, bbox, text) raw_fragments.append({"y": y_top, "content": content, "type": "text"}) elif isinstance(result[0], list): for line in result[0]: bbox = line[0] text = line[1][0] score = line[1][1] if len(line[1]) > 1 else None if score is not None and float(score) < 0.45: logger.debug( "[OcrVisionPipeline] Low-confidence line (score=%s): %s", score, text[:80], ) y_top = int(bbox[0][1]) content = self._maybe_math_from_crop(img_bgr, bbox, text) raw_fragments.append({"y": y_top, "content": content, "type": "text"}) finally: if prep_cleanup and os.path.exists(prep_path): try: os.remove(prep_path) except OSError: pass raw_fragments.sort(key=lambda x: x["y"]) combined_text = "\n".join([f["content"] for f in raw_fragments]) logger.info( "[OcrVisionPipeline] Raw OCR output assembled:\n---\n%s\n---", combined_text ) if not combined_text.strip(): logger.warning("[OcrVisionPipeline] No text detected.") return "" return combined_text async def process_url(self, url: str) -> str: """Download image and run ``process_image`` (raw only).""" import httpx from app.url_utils import sanitize_url url = sanitize_url(url) if not url: return "Error: Empty image URL after cleanup." async with httpx.AsyncClient() as client: resp = await client.get(url) if resp.status_code == 200: temp_path = "temp_url_image.png" with open(temp_path, "wb") as f: f.write(resp.content) try: return await self.process_image(temp_path) finally: if os.path.exists(temp_path): os.remove(temp_path) return f"Error: Failed to fetch image from URL {url}"