| """ |
| mini-kh-OCR Pipeline |
| -------------------- |
| Combines: |
| - phonsobon/mini-text-detection (YOLO11n — detects subject / reference / content) |
| - phonsobon/mini-ocr (CRNN + CTC — recognises Khmer & English text) |
| |
| Usage: |
| from mini_kh_ocr import MiniKhOCR |
| ocr = MiniKhOCR() |
| result = ocr("your_image.jpg") |
| print(result) |
| """ |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
| from ultralytics import YOLO |
|
|
|
|
| |
| |
| |
|
|
# Detector class index -> label of the detected text region.
CLASS_NAMES = dict(enumerate(("subject", "reference", "content")))


# Recogniser vocabulary. Order matters: a character's position fixes the
# class index it is decoded from (see IDX2CHAR below).
TOKENS = "".join(
    [
        "abcdefghijklmnopqrstuvwxyz",  # Latin lowercase
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ",  # Latin uppercase
        "0123456789",  # ASCII digits
        "កខគឃងចឆជឈញដឋឌឍណតថទធនបផពភមយរលវឝឞសហឡអឣឤឥឦឧឩឪឫឬឭឮឯឰឱឲឳ",  # Khmer consonants and independent vowels
        "ាិីឹឺុូួើឿៀេែៃោៅំះៈ៉៊់៌៍៎៏័៑្។៕៖ៗ៘៛៝",  # Khmer dependent vowels, signs and punctuation
        "០១២៣៤៥៦៧៨៩៳",  # Khmer digits (plus ៳)
        "!@#$%^&*()-_=+[]{};:'\",.<>?/|\\ ",  # ASCII punctuation and space
    ]
)

# Vocabulary size, excluding the extra class index 0.
NUM_CHARS = len(TOKENS)

# Class index -> character. Numbering starts at 1; index 0 is deliberately
# left unmapped (the greedy decoder in this file skips index 0).
IDX2CHAR = dict(enumerate(TOKENS, start=1))
|
|
|
|
| |
| |
| |
|
|
class KhmerOCR_DTWG(nn.Module):
    """CRNN text-line recogniser.

    A convolutional feature extractor is followed by two bidirectional LSTM
    layers (with a linear projection in between) and a per-timestep linear
    classifier that emits ``num_chars + 1`` scores.

    Parameters
    ----------
    num_chars : number of characters in the vocabulary (default NUM_CHARS)
    hidden_size : LSTM hidden size per direction (default 256)
    """

    def __init__(self, num_chars=NUM_CHARS, hidden_size=256):
        super().__init__()

        # The four pooling stages shrink height by 2*2*2*4 = 32 and width by
        # 2*2 = 4; the 3x3 / stride 1 / padding 1 convolutions keep the
        # spatial size. Layer order is kept fixed so the sub-module names
        # (cnn.0, cnn.1, ...) stay the same.
        backbone = [
            self._conv(1, 32),
            nn.MaxPool2d(2, 2),
            self._conv(32, 64),
            nn.MaxPool2d(2, 2),
            self._conv(64, 128),
            self._conv(128, 128),
            nn.MaxPool2d((2, 1), (2, 1)),
            self._conv(128, 256),
            self._conv(256, 256),
            nn.MaxPool2d((4, 1), (4, 1)),
        ]
        self.cnn = nn.Sequential(*backbone)

        # Both LSTMs are bidirectional, so each one outputs 2 * hidden_size
        # features per timestep.
        bidirectional_width = hidden_size * 2
        self.lstm1 = nn.LSTM(
            256, hidden_size, bidirectional=True, batch_first=True
        )
        self.fc1 = nn.Linear(bidirectional_width, hidden_size)
        self.lstm2 = nn.LSTM(
            hidden_size, hidden_size, bidirectional=True, batch_first=True
        )
        self.fc = nn.Linear(bidirectional_width, num_chars + 1)

    def _conv(self, i, o):
        """Build one conv stage: 3x3 Conv2d (no bias) -> BatchNorm2d -> ReLU."""
        stage = [
            nn.Conv2d(i, o, 3, 1, 1, bias=False),
            nn.BatchNorm2d(o),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map a single-channel image batch to per-timestep class scores.

        ``x`` is a 4-D batch with one channel. The height axis must be
        reduced to 1 by the CNN (true for 32-pixel-high input) so that
        ``squeeze(2)`` can remove it.

        Returns the scores with the time axis first: (T, B, num_chars + 1).
        """
        feature_map = self.cnn(x)
        # (B, 256, 1, W') -> (B, 256, W') -> (B, W', 256): width becomes time.
        sequence = feature_map.squeeze(2).permute(0, 2, 1)
        sequence, _ = self.lstm1(sequence)
        sequence = torch.relu(self.fc1(sequence))
        sequence, _ = self.lstm2(sequence)
        scores = self.fc(sequence)
        # Batch-first -> time-first.
        return scores.permute(1, 0, 2)
|
|
|
|
| |
| |
| |
|
|
def _load_crop_for_ocr(
    pil_img: Image.Image,
    target_height: int = 32,
    min_width: int = 4,
) -> torch.Tensor:
    """Convert a PIL crop into the recogniser's input tensor.

    The crop is converted to greyscale, resized to ``target_height`` while
    keeping its aspect ratio, and scaled from 0..255 to 0..1.

    Parameters
    ----------
    pil_img : PIL image of one text region.
    target_height : output height in pixels (default 32).
    min_width : smallest allowed output width in pixels (default 4). The
        recogniser in this file halves the width twice with 2x2 max-pooling,
        so an input narrower than 4 pixels leaves no output column and the
        pooling layer raises. Very narrow crops are therefore stretched to
        this width.

    Returns
    -------
    float32 tensor of shape (1, 1, target_height, W) with W >= min_width.
    """
    grey = pil_img.convert("L")
    src_w, src_h = grey.size
    # Avoid dividing by zero for a crop that has no rows.
    src_h = max(src_h, 1)

    scaled_w = int(src_w / src_h * target_height)
    out_w = max(1, min_width, scaled_w)

    grey = grey.resize((out_w, target_height))
    pixels = np.asarray(grey, dtype=np.float32) / 255.0
    # Add batch and channel axes: (H, W) -> (1, 1, H, W).
    return torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0)
|
|
|
|
| def _ctc_decode(logits: torch.Tensor) -> str: |
| """Greedy CTC decode — logits shape: (T, 1, C).""" |
| preds = torch.argmax(logits, dim=2)[:, 0].cpu().numpy() |
| prev, text = -1, [] |
| for p in preds: |
| if p != prev and p != 0: |
| text.append(IDX2CHAR.get(int(p), "")) |
| prev = p |
| return "".join(text) |
|
|
|
|
| def _sort_boxes_top_to_bottom(boxes, cls_ids, confs): |
| """Sort detections by vertical position (top → bottom).""" |
| order = sorted(range(len(boxes)), key=lambda i: boxes[i][1]) |
| return [boxes[i] for i in order], [cls_ids[i] for i in order], [confs[i] for i in order] |
|
|
|
|
| |
| |
| |
|
|
class MiniKhOCR:
    """
    End-to-end Khmer OCR pipeline: text-region detection followed by
    per-region text recognition.

    Both model files are fetched with ``hf_hub_download`` when the object is
    constructed.

    Parameters
    ----------
    det_conf : float — detection confidence threshold (default 0.25)
    det_iou : float — NMS IoU threshold (default 0.45)
    det_imgsz : int — detection image size (default 640)
    device : str — 'cuda' | 'cpu' | 'auto' ('auto' picks CUDA when available)
    """

    def __init__(
        self,
        det_conf: float = 0.25,
        det_iou: float = 0.45,
        det_imgsz: int = 640,
        device: str = "auto",
    ):
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"[mini-kh-OCR] Device: {self.device}")

        # --- Detection model ---
        print("[mini-kh-OCR] Loading detection model ...")
        det_path = hf_hub_download(
            repo_id="phonsobon/mini-text-detection",
            filename="khmer-text-detection-mini.pt",
        )
        # NOTE(review): self.device is not passed to the detector, here or in
        # predict(); it runs on whatever device ultralytics selects itself.
        self.detector = YOLO(det_path)
        self.det_conf = det_conf
        self.det_iou = det_iou
        self.det_imgsz = det_imgsz

        # --- Recognition model ---
        print("[mini-kh-OCR] Loading recognition model ...")
        ocr_path = hf_hub_download(
            repo_id="phonsobon/mini-ocr",
            filename="model.pt",
        )
        self.recogniser = KhmerOCR_DTWG(NUM_CHARS).to(self.device)
        # NOTE(security): torch.load unpickles a downloaded file. If the
        # checkpoint is a plain state dict and the installed torch supports
        # it, prefer torch.load(..., weights_only=True).
        self.recogniser.load_state_dict(
            torch.load(ocr_path, map_location=self.device)
        )
        self.recogniser.eval()

        print("[mini-kh-OCR] Ready ✅")

    @staticmethod
    def _load_image(image) -> "Image.Image":
        """Return ``image`` as an RGB PIL image; accepts a path or a PIL image."""
        if isinstance(image, (str, os.PathLike)):
            # convert() loads the pixel data into a new image, so the file
            # can be closed as soon as the block exits.
            with Image.open(image) as source:
                return source.convert("RGB")
        return image.convert("RGB")

    def _recognise(self, crop: "Image.Image") -> str:
        """Run OCR on a single PIL crop and return the decoded text."""
        tensor = _load_crop_for_ocr(crop).to(self.device)
        with torch.no_grad():
            logits = self.recogniser(tensor)
        return _ctc_decode(logits)

    def __call__(
        self,
        image,
        return_crops: bool = False,
        verbose: bool = False,
    ) -> dict:
        """
        Run detection + recognition on an image.

        Parameters
        ----------
        image : str | os.PathLike | PIL.Image — file path or PIL image
        return_crops : bool — include cropped PIL images in output
        verbose : bool — print each detected region

        Returns
        -------
        dict with keys:
        "subject" : list of str
        "reference" : list of str
        "content" : list of str
        "regions" : list of dicts with box, class, conf, text (and crop if requested)

        Regions are ordered top to bottom. A region whose class id is not in
        CLASS_NAMES is labelled "unknown" and appears only under "regions".
        """
        pil_img = self._load_image(image)

        # --- Detect ---
        det_results = self.detector.predict(
            source=pil_img,
            conf=self.det_conf,
            iou=self.det_iou,
            imgsz=self.det_imgsz,
            verbose=False,
        )
        found = det_results[0].boxes
        raw_boxes = found.xyxy.cpu().numpy().astype(int).tolist()
        raw_cls = [int(c) for c in found.cls.cpu().numpy()]
        raw_conf = [float(c) for c in found.conf.cpu().numpy()]

        # --- Order ---
        boxes, cls_ids, confs = _sort_boxes_top_to_bottom(raw_boxes, raw_cls, raw_conf)

        # --- Recognise ---
        result = {"subject": [], "reference": [], "content": [], "regions": []}

        for box, cls_id, conf in zip(boxes, cls_ids, confs):
            x1, y1, x2, y2 = box
            label = CLASS_NAMES.get(cls_id, "unknown")

            crop = pil_img.crop((x1, y1, x2, y2))
            # Corners were truncated to int above, so a very thin detection
            # can collapse to zero width or height; there are no pixels to
            # recognise in that case.
            has_pixels = x2 > x1 and y2 > y1
            text = self._recognise(crop) if has_pixels else ""

            if label in result:
                result[label].append(text)

            region = {
                "class": label,
                "conf": round(conf, 3),
                "box": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
                "text": text,
            }
            if return_crops:
                region["crop"] = crop

            result["regions"].append(region)

            if verbose:
                print(f" [{label}] ({x1},{y1})→({x2},{y2}) conf={conf:.2f} → {text!r}")

        return result

    def to_document(self, result: dict) -> str:
        """
        Format result as a structured text document.

        Sections appear in the order subject, reference, content; a section
        with no text is omitted. Leading and trailing whitespace of the whole
        document is stripped.

        Example output:
        [SUBJECT]
        ភ្នំពេញ ក្រុង

        [REFERENCE]
        លេខ ០០១

        [CONTENT]
        អត្ថបទដំបូង
        អត្ថបទទីពីរ
        """
        lines = []
        for cls in ("subject", "reference", "content"):
            texts = result.get(cls, [])
            if not texts:
                continue
            lines.append(f"[{cls.upper()}]")
            lines.extend(texts)
            # Blank line between sections.
            lines.append("")
        return "\n".join(lines).strip()
|
|