mini-kh-OCR / mini_kh_ocr.py
phonsobon's picture
Upload mini_kh_ocr.py with huggingface_hub
f12abb2 verified
"""
mini-kh-OCR Pipeline
--------------------
Combines:
- phonsobon/mini-text-detection (YOLO11n — detects subject / reference / content)
- phonsobon/mini-ocr (CRNN + CTC — recognises Khmer & English text)
Usage:
from mini_kh_ocr import MiniKhOCR
ocr = MiniKhOCR()
result = ocr("your_image.jpg")
print(result)
"""
import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
# ══════════════════════════════════════════════════════════════════════════════
# 1. CONSTANTS
# ══════════════════════════════════════════════════════════════════════════════
# Detector class index -> region label used as a key in the pipeline output.
CLASS_NAMES = dict(enumerate(("subject", "reference", "content")))
# Recognition alphabet, in class-index order (IDX2CHAR maps position i to
# class i + 1). The adjacent literals are concatenated into one string:
# Latin letters, ASCII digits, Khmer letters, Khmer vowel signs / diacritics /
# punctuation, Khmer digits, then ASCII punctuation and the space character.
# NOTE(review): presumably this order must match the vocabulary the released
# checkpoint was trained with — confirm before reordering or editing.
TOKENS = (
"abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"0123456789"
"កខគឃងចឆជឈញដឋឌឍណតថទធនបផពភមយរលវឝឞសហឡអឣឤឥឦឧឩឪឫឬឭឮឯឰឱឲឳ"
"ាិីឹឺុូួើឿៀេែៃោៅំះៈ៉៊់៌៍៎៏័៑្។៕៖ៗ៘៛៝"
"០១២៣៤៥៦៧៨៩៳"
"!@#$%^&*()-_=+[]{};:'\",.<>?/|\\ "
)
# Alphabet size; the recogniser's output layer has one extra class (the CTC blank).
NUM_CHARS = len(TOKENS)
# Class index -> character. Numbering starts at 1 because index 0 is the
# CTC blank, which the greedy decoder drops.
IDX2CHAR = dict(enumerate(TOKENS, start=1))
# ══════════════════════════════════════════════════════════════════════════════
# 2. OCR MODEL DEFINITION (KhmerOCR_DTWG)
# ══════════════════════════════════════════════════════════════════════════════
class KhmerOCR_DTWG(nn.Module):
    """CRNN recogniser: convolutional feature extractor + two BiLSTM stages.

    The convolutional stack downsamples height by 32 (2 * 2 * 2 * 4) and
    width by 4 (only the first two poolings touch the width). The
    ``squeeze(2)`` in :meth:`forward` therefore only removes the height axis
    when the pooled height is exactly 1, e.g. for 32-pixel-high inputs.

    Parameters
    ----------
    num_chars : int — alphabet size; the output layer has ``num_chars + 1``
        classes (one extra for the CTC blank)
    hidden_size : int — hidden width of each LSTM direction
    """

    def __init__(self, num_chars=NUM_CHARS, hidden_size=256):
        super().__init__()

        # The layer order fixes the ``cnn.<idx>`` names in the state dict,
        # so it must stay exactly as is for checkpoints to load.
        block = self._conv
        stages = [
            block(1, 32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            block(32, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            block(64, 128),
            block(128, 128),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # height only
            block(128, 256),
            block(256, 256),
            nn.MaxPool2d(kernel_size=(4, 1), stride=(4, 1)),  # height only
        ]
        self.cnn = nn.Sequential(*stages)

        bidirectional_width = hidden_size * 2
        self.lstm1 = nn.LSTM(
            input_size=256,
            hidden_size=hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        self.fc1 = nn.Linear(bidirectional_width, hidden_size)
        self.lstm2 = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        self.fc = nn.Linear(bidirectional_width, num_chars + 1)

    def _conv(self, i, o):
        """Build a 3x3 conv (stride 1, padding 1, no bias) -> BatchNorm -> ReLU block."""
        convolution = nn.Conv2d(
            in_channels=i,
            out_channels=o,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        normalisation = nn.BatchNorm2d(o)
        activation = nn.ReLU(inplace=True)
        return nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Map an image batch (N, 1, H, W) to time-major logits (T, N, num_chars + 1)."""
        features = self.cnn(x)
        # (N, 256, 1, T) -> (N, 256, T) -> (N, T, 256): width becomes the time axis.
        sequence = features.squeeze(2).permute(0, 2, 1)
        sequence, _ = self.lstm1(sequence)
        sequence = torch.relu(self.fc1(sequence))
        sequence, _ = self.lstm2(sequence)
        logits = self.fc(sequence)
        # Batch-first -> time-first, the layout the CTC decoder indexes.
        return logits.permute(1, 0, 2)
# ══════════════════════════════════════════════════════════════════════════════
# 3. HELPERS
# ══════════════════════════════════════════════════════════════════════════════
def _load_crop_for_ocr(
    pil_img: "Image.Image",
    target_height: int = 32,
    min_width: int = 4,
) -> torch.Tensor:
    """Convert a PIL crop into a normalised (1, 1, target_height, W) tensor.

    The crop is converted to greyscale, resized to ``target_height`` while
    keeping its aspect ratio, and scaled to the range [0, 1].

    Parameters
    ----------
    pil_img : PIL.Image — the cropped text region
    target_height : int — output height in pixels (default 32)
    min_width : int — lower bound on the output width (default 4). The
        recogniser's first two poolings each halve the width, so an input
        narrower than 4 columns leaves no time step and the pooling layer
        raises. Very tall, narrow crops are stretched to this width.

    Returns
    -------
    torch.Tensor of dtype float32 and shape (1, 1, target_height, W)
    """
    img = pil_img.convert("L")
    w, h = img.size
    if h == 0:
        # Degenerate crop: avoid dividing by zero in the aspect-ratio maths.
        h = 1
    new_w = max(min_width, int(w / h * target_height))
    img = img.resize((new_w, target_height))
    arr = np.array(img, dtype=np.float32) / 255.0
    # Add batch and channel axes: (H, W) -> (1, 1, H, W).
    return torch.tensor(arr).unsqueeze(0).unsqueeze(0)
def _ctc_decode(logits: torch.Tensor) -> str:
"""Greedy CTC decode — logits shape: (T, 1, C)."""
preds = torch.argmax(logits, dim=2)[:, 0].cpu().numpy()
prev, text = -1, []
for p in preds:
if p != prev and p != 0:
text.append(IDX2CHAR.get(int(p), ""))
prev = p
return "".join(text)
def _sort_boxes_top_to_bottom(boxes, cls_ids, confs):
"""Sort detections by vertical position (top → bottom)."""
order = sorted(range(len(boxes)), key=lambda i: boxes[i][1])
return [boxes[i] for i in order], [cls_ids[i] for i in order], [confs[i] for i in order]
# ══════════════════════════════════════════════════════════════════════════════
# 4. MAIN PIPELINE CLASS
# ══════════════════════════════════════════════════════════════════════════════
class MiniKhOCR:
    """
    End-to-end Khmer OCR pipeline: region detection followed by recognition.

    Parameters
    ----------
    det_conf : float — detection confidence threshold (default 0.25)
    det_iou : float — NMS IoU threshold (default 0.45)
    det_imgsz : int — detection image size (default 640)
    device : str — 'cuda' | 'cpu' | 'auto'
    """

    def __init__(
        self,
        det_conf: float = 0.25,
        det_iou: float = 0.45,
        det_imgsz: int = 640,
        device: str = "auto",
    ):
        """Download both checkpoints from the Hugging Face Hub and load them."""
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"[mini-kh-OCR] Device: {self.device}")

        # ── detection model ────────────────────────────────────────────────
        # NOTE(review): no device is passed to the detector here or in
        # predict(), so it appears to choose its own — confirm it follows
        # ``self.device`` when 'cpu' is requested explicitly.
        print("[mini-kh-OCR] Loading detection model ...")
        det_path = hf_hub_download(
            repo_id="phonsobon/mini-text-detection",
            filename="khmer-text-detection-mini.pt",
        )
        self.detector = YOLO(det_path)
        self.det_conf = det_conf
        self.det_iou = det_iou
        self.det_imgsz = det_imgsz

        # ── recognition model ──────────────────────────────────────────────
        print("[mini-kh-OCR] Loading recognition model ...")
        ocr_path = hf_hub_download(
            repo_id="phonsobon/mini-ocr",
            filename="model.pt",
        )
        self.recogniser = KhmerOCR_DTWG(NUM_CHARS).to(self.device)
        # NOTE(review): torch.load unpickles a downloaded file. If every
        # supported torch version accepts it, consider ``weights_only=True``
        # so that only tensors can be deserialised.
        self.recogniser.load_state_dict(
            torch.load(ocr_path, map_location=self.device)
        )
        self.recogniser.eval()
        print("[mini-kh-OCR] Ready ✅")

    # ──────────────────────────────────────────────────────────────────────────
    def _recognise(self, crop: "Image.Image") -> str:
        """Run OCR on a single PIL crop and return the decoded text."""
        tensor = _load_crop_for_ocr(crop).to(self.device)
        with torch.no_grad():
            logits = self.recogniser(tensor)
        return _ctc_decode(logits)

    # ──────────────────────────────────────────────────────────────────────────
    def __call__(
        self,
        image,
        return_crops: bool = False,
        verbose: bool = False,
    ) -> dict:
        """
        Run detection + recognition on an image.

        Parameters
        ----------
        image : str | os.PathLike | PIL.Image — file path or PIL image
        return_crops : bool — include cropped PIL images in output
        verbose : bool — print each detected region

        Returns
        -------
        dict with keys:
        "subject" : list of str
        "reference" : list of str
        "content" : list of str
        "regions" : list of dicts with box, class, conf, text (and crop if requested)

        Regions are ordered top to bottom. A detection whose integer box has
        zero area is still reported, with an empty ``text``.
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() returns an independent image, so the file handle can
            # be released as soon as the pixels have been read.
            with Image.open(image) as opened:
                pil_img = opened.convert("RGB")
        else:
            pil_img = image.convert("RGB")

        # ── Step 1: detect ────────────────────────────────────────────────
        det_results = self.detector.predict(
            source=pil_img,
            conf=self.det_conf,
            iou=self.det_iou,
            imgsz=self.det_imgsz,
            verbose=False,
        )
        detections = det_results[0].boxes
        raw_boxes = detections.xyxy.cpu().numpy().astype(int).tolist()
        raw_cls = [int(c) for c in detections.cls.cpu().numpy()]
        raw_conf = [float(c) for c in detections.conf.cpu().numpy()]

        # ── Step 2: sort top → bottom ─────────────────────────────────────
        boxes, cls_ids, confs = _sort_boxes_top_to_bottom(raw_boxes, raw_cls, raw_conf)

        # ── Step 3: recognise each crop ───────────────────────────────────
        result = {"subject": [], "reference": [], "content": [], "regions": []}
        for box, cls_id, conf in zip(boxes, cls_ids, confs):
            x1, y1, x2, y2 = box
            label = CLASS_NAMES.get(cls_id, "unknown")
            crop = pil_img.crop((x1, y1, x2, y2))
            # Truncating box coordinates to int can collapse a very thin
            # detection to zero width or height; there is nothing to read in
            # such a crop, so skip the recogniser instead of feeding it an
            # empty image.
            text = self._recognise(crop) if (x2 > x1 and y2 > y1) else ""
            if label in result:
                result[label].append(text)
            region = {
                "class": label,
                "conf": round(conf, 3),
                "box": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
                "text": text,
            }
            if return_crops:
                region["crop"] = crop
            result["regions"].append(region)
            if verbose:
                print(f" [{label}] ({x1},{y1})→({x2},{y2}) conf={conf:.2f} {text!r}")
        return result

    # ──────────────────────────────────────────────────────────────────────────
    def to_document(self, result: dict) -> str:
        """
        Format result as a structured text document.

        Each non-empty class becomes a ``[CLASS]`` header followed by its
        texts, one per line, with a blank line between sections. Classes with
        no text are omitted; leading and trailing whitespace of the whole
        document is stripped.

        Example output:
        [SUBJECT]
        ភ្នំពេញ ក្រុង

        [REFERENCE]
        លេខ ០០១

        [CONTENT]
        អត្ថបទដំបូង
        អត្ថបទទីពីរ
        """
        lines = []
        for cls in ("subject", "reference", "content"):
            texts = result.get(cls, [])
            if texts:
                lines.append(f"[{cls.upper()}]")
                lines.extend(texts)
                lines.append("")
        return "\n".join(lines).strip()