"""
mini-kh-OCR pipeline.

Chains two pretrained models:
  - phonsobon/mini-text-detection  (YOLO11n; detects subject / reference / content regions)
  - phonsobon/mini-ocr             (CRNN + CTC; recognises Khmer and English text)

Usage:
    from mini_kh_ocr import MiniKhOCR

    ocr = MiniKhOCR()
    result = ocr("your_image.jpg")
    print(result)
"""

import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# ══════════════════════════════════════════════════════════════════════════════
# 1. CONSTANTS
# ══════════════════════════════════════════════════════════════════════════════

# Detector class index -> region label (0: subject, 1: reference, 2: content).
CLASS_NAMES = dict(enumerate(("subject", "reference", "content")))

# Recogniser alphabet. The position of each character fixes its class index,
# so the order of these pieces must never change.
TOKENS = "".join(
    (
        "abcdefghijklmnopqrstuvwxyz",
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
        "0123456789",
        "កខគឃងចឆជឈញដឋឌឍណតថទធនបផពភមយរលវឝឞសហឡអឣឤឥឦឧឩឪឫឬឭឮឯឰឱឲឳ",
        "ាិីឹឺុូួើឿៀេែៃោៅំះៈ៉៊់៌៍៎៏័៑្។៕៖ៗ៘៛៝",
        "០១២៣៤៥៦៧៨៩៳",
        "!@#$%^&*()-_=+[]{};:'\",.<>?/|\\ ",
    )
)

NUM_CHARS = len(TOKENS)

# Class index -> character, numbered from 1; index 0 is not assigned to any
# character (the greedy decoder in this file skips index 0).
IDX2CHAR = dict(enumerate(TOKENS, start=1))
# ══════════════════════════════════════════════════════════════════════════════
# 2. OCR MODEL DEFINITION  (KhmerOCR_DTWG)
# ══════════════════════════════════════════════════════════════════════════════

class KhmerOCR_DTWG(nn.Module):
    """CRNN recogniser: conv feature extractor -> BiLSTM -> linear -> BiLSTM -> classifier.

    The attribute names (``cnn``, ``lstm1``, ``fc1``, ``lstm2``, ``fc``) and the
    order of the layers inside ``cnn`` determine the state-dict keys, so they
    have to stay exactly as they are for a saved state dict to load.

    Parameters
    ----------
    num_chars   : int — alphabet size; the classifier emits ``num_chars + 1`` scores
    hidden_size : int — hidden width of each LSTM direction (default 256)
    """

    def __init__(self, num_chars=NUM_CHARS, hidden_size=256):
        super().__init__()

        # 3x3 convs keep the spatial size; only the pools shrink it.
        # Height is divided by 2 * 2 * 2 * 4 = 32, width by 2 * 2 = 4.
        feature_layers = [
            self._conv(1, 32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            self._conv(32, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            self._conv(64, 128),
            self._conv(128, 128),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
            self._conv(128, 256),
            self._conv(256, 256),
            nn.MaxPool2d(kernel_size=(4, 1), stride=(4, 1)),
        ]
        self.cnn = nn.Sequential(*feature_layers)

        bidir_width = hidden_size * 2  # forward + backward LSTM outputs, concatenated
        self.lstm1 = nn.LSTM(256, hidden_size, bidirectional=True, batch_first=True)
        self.fc1 = nn.Linear(bidir_width, hidden_size)
        self.lstm2 = nn.LSTM(hidden_size, hidden_size, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(bidir_width, num_chars + 1)

    def _conv(self, i, o):
        """Build a Conv(3x3, stride 1, pad 1, no bias) -> BatchNorm -> ReLU unit mapping ``i`` to ``o`` channels."""
        conv = nn.Conv2d(i, o, kernel_size=3, stride=1, padding=1, bias=False)
        norm = nn.BatchNorm2d(o)
        act = nn.ReLU(inplace=True)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Map an image batch to per-time-step class scores.

        ``x`` is a 4-D batch with one channel; its height must pool down to 1
        (e.g. 32) so that ``squeeze(2)`` can drop the height axis. The result
        is time-major: ``(T, N, num_chars + 1)``.
        """
        feats = self.cnn(x)
        # (N, C, 1, W') -> (N, C, W') -> (N, W', C): the width axis becomes the sequence axis.
        seq = feats.squeeze(2).permute(0, 2, 1)
        seq, _ = self.lstm1(seq)
        seq = torch.relu(self.fc1(seq))
        seq, _ = self.lstm2(seq)
        scores = self.fc(seq)
        return scores.permute(1, 0, 2)
# ══════════════════════════════════════════════════════════════════════════════
# 3. HELPERS
# ══════════════════════════════════════════════════════════════════════════════

# Height (in pixels) every crop is resized to before recognition.
_OCR_HEIGHT = 32

# Narrowest input the recogniser can take: KhmerOCR_DTWG halves the width twice
# (two 2x2 max-pools), so anything under 4 columns pools down to zero columns
# and the forward pass raises.
_MIN_OCR_WIDTH = 4


def _ocr_target_width(w: int, h: int) -> int:
    """Return the width a ``w`` x ``h`` crop gets when scaled to the OCR height.

    The aspect ratio is preserved, a zero height is treated as 1 to avoid a
    division by zero, and the result is never below ``_MIN_OCR_WIDTH``.
    """
    h = max(h, 1)
    return max(_MIN_OCR_WIDTH, int(w / h * _OCR_HEIGHT))


def _load_crop_for_ocr(pil_img: "Image.Image") -> torch.Tensor:
    """Resize a PIL crop to height=32, normalise to [0, 1], return a (1,1,32,W) tensor.

    W follows the crop's aspect ratio but is clamped to at least 4 so that very
    tall, thin crops still yield one time step in the recogniser.
    """
    grey = pil_img.convert("L")
    w, h = grey.size
    grey = grey.resize((_ocr_target_width(w, h), _OCR_HEIGHT))
    arr = np.array(grey, dtype=np.float32) / 255.0
    return torch.tensor(arr).unsqueeze(0).unsqueeze(0)  # (1,1,32,W)


def _ctc_decode(logits: torch.Tensor) -> str:
    """Greedy CTC decode — logits shape: (T, 1, C).

    Takes the arg-max class per time step, collapses consecutive repeats and
    drops index 0; indices missing from ``IDX2CHAR`` decode to nothing.
    """
    best_path = torch.argmax(logits, dim=2)[:, 0].tolist()
    prev, chars = -1, []
    for idx in best_path:
        if idx != prev and idx != 0:
            chars.append(IDX2CHAR.get(idx, ""))
        prev = idx
    return "".join(chars)


def _sort_boxes_top_to_bottom(boxes, cls_ids, confs):
    """Sort detections by the top edge (y1) of their box, top → bottom.

    The three lists are parallel; they are reordered together and returned as
    three new lists. Ties keep their input order (the sort is stable).
    """
    if not boxes:
        return [], [], []
    ordered = sorted(zip(boxes, cls_ids, confs), key=lambda det: det[0][1])
    sorted_boxes, sorted_cls, sorted_confs = zip(*ordered)
    return list(sorted_boxes), list(sorted_cls), list(sorted_confs)


# ══════════════════════════════════════════════════════════════════════════════
# 4. MAIN PIPELINE CLASS
# ══════════════════════════════════════════════════════════════════════════════

class MiniKhOCR:
    """
    End-to-end Khmer OCR pipeline.

    Parameters
    ----------
    det_conf  : float — detection confidence threshold (default 0.25)
    det_iou   : float — NMS IoU threshold             (default 0.45)
    det_imgsz : int   — detection image size           (default 640)
    device    : str   — 'cuda' | 'cpu' | 'auto'
    """

    def __init__(
        self,
        det_conf: float = 0.25,
        det_iou: float = 0.45,
        det_imgsz: int = 640,
        device: str = "auto",
    ):
        """Download both checkpoints from the Hugging Face Hub and load them."""
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"[mini-kh-OCR] Device: {self.device}")

        # ── detection model ────────────────────────────────────────────────
        print("[mini-kh-OCR] Loading detection model ...")
        det_path = hf_hub_download(
            repo_id="phonsobon/mini-text-detection",
            filename="khmer-text-detection-mini.pt",
        )
        self.detector = YOLO(det_path)
        self.det_conf = det_conf
        self.det_iou = det_iou
        self.det_imgsz = det_imgsz

        # ── recognition model ──────────────────────────────────────────────
        print("[mini-kh-OCR] Loading recognition model ...")
        ocr_path = hf_hub_download(
            repo_id="phonsobon/mini-ocr",
            filename="model.pt",
        )
        self.recogniser = KhmerOCR_DTWG(NUM_CHARS).to(self.device)
        # NOTE(security): torch.load unpickles a file fetched from the network.
        # On torch versions where `weights_only` defaults to False, a tampered
        # checkpoint can run arbitrary code. Consider passing weights_only=True
        # once the minimum supported torch version provides it.
        self.recogniser.load_state_dict(
            torch.load(ocr_path, map_location=self.device)
        )
        self.recogniser.eval()
        print("[mini-kh-OCR] Ready ✅")

    # ──────────────────────────────────────────────────────────────────────────
    def _recognise(self, crop: "Image.Image") -> str:
        """Run OCR on a single PIL crop; an empty (zero-area) crop yields ''."""
        # Truncating box coordinates to int can collapse a thin box to zero
        # width or height; there are no pixels to read in that case.
        if crop.width == 0 or crop.height == 0:
            return ""
        tensor = _load_crop_for_ocr(crop).to(self.device)
        with torch.no_grad():
            logits = self.recogniser(tensor)
        return _ctc_decode(logits)

    # ──────────────────────────────────────────────────────────────────────────
    def __call__(
        self,
        image: "str | os.PathLike | Image.Image",
        return_crops: bool = False,
        verbose: bool = False,
    ) -> dict:
        """
        Run detection + recognition on an image.

        Parameters
        ----------
        image        : str | os.PathLike | PIL.Image — file path or PIL image
        return_crops : bool — include cropped PIL images in output
        verbose      : bool — print each detected region

        Returns
        -------
        dict with keys:
            "subject"   : list of str
            "reference" : list of str
            "content"   : list of str
            "regions"   : list of dicts with box, class, conf, text (and crop if requested)

        Regions are ordered top → bottom. A detection whose class id is not in
        ``CLASS_NAMES`` is labelled "unknown" and appears only under "regions".
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() returns a fully loaded copy, so the file can be closed
            # right away instead of staying open until garbage collection.
            with Image.open(image) as opened:
                pil_img = opened.convert("RGB")
        else:
            pil_img = image.convert("RGB")

        # ── Step 1: detect ────────────────────────────────────────────────
        # NOTE(review): self.device is not forwarded to the detector here, so
        # the detector presumably picks its own device — confirm whether
        # `device=` should be passed to predict().
        det_results = self.detector.predict(
            source=pil_img,
            conf=self.det_conf,
            iou=self.det_iou,
            imgsz=self.det_imgsz,
            verbose=False,
        )
        detections = det_results[0].boxes
        raw_boxes = detections.xyxy.cpu().numpy().astype(int).tolist()
        raw_cls = [int(c) for c in detections.cls.cpu().numpy()]
        raw_conf = [float(c) for c in detections.conf.cpu().numpy()]

        # ── Step 2: sort top → bottom ─────────────────────────────────────
        boxes, cls_ids, confs = _sort_boxes_top_to_bottom(raw_boxes, raw_cls, raw_conf)

        # ── Step 3: recognise each crop ───────────────────────────────────
        result = {"subject": [], "reference": [], "content": [], "regions": []}

        for box, cls_id, conf in zip(boxes, cls_ids, confs):
            x1, y1, x2, y2 = box
            label = CLASS_NAMES.get(cls_id, "unknown")
            crop = pil_img.crop((x1, y1, x2, y2))
            text = self._recognise(crop)

            # "unknown" has no per-class list; it is reported under "regions" only.
            if label in result:
                result[label].append(text)

            region = {
                "class": label,
                "conf": round(conf, 3),
                "box": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
                "text": text,
            }
            if return_crops:
                region["crop"] = crop
            result["regions"].append(region)

            if verbose:
                print(f" [{label}] ({x1},{y1})→({x2},{y2}) conf={conf:.2f} → {text!r}")

        return result

    # ──────────────────────────────────────────────────────────────────────────
    def to_document(self, result: dict) -> str:
        """
        Format result as a structured text document.

        Each non-empty class becomes a "[CLASS]" header followed by its texts,
        one per line, with a blank line between sections. Classes with no text
        are omitted, and surrounding whitespace is stripped from the whole
        document.

        Example output:
            [SUBJECT]
            Monthly report

            [REFERENCE]
            No. 001

            [CONTENT]
            First paragraph
            Second paragraph
        """
        lines = []
        for cls in ("subject", "reference", "content"):
            texts = result.get(cls, [])
            if texts:
                lines.append(f"[{cls.upper()}]")
                lines.extend(texts)
                lines.append("")  # blank separator after each section
        return "\n".join(lines).strip()