Khmer OCR: ResNet + BiLSTM + CTC

This repository contains a deep learning model for line-level Optical Character Recognition (OCR) of Khmer and English text. It uses a ResNet backbone for spatial feature extraction, a bidirectional LSTM for sequence modeling, and Connectionist Temporal Classification (CTC) loss for alignment-free text recognition.
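
CTC lets the network emit one prediction per horizontal time step without a character-level alignment. The greedy decoding rule that predict.py applies to that output fits in a few lines of plain Python: take the most likely class at each step, merge consecutive repeats, then drop the blank. The function name `greedy_ctc_collapse` and the toy vocabulary are hypothetical; index 0 is treated as the blank, as in predict.py.

```python
def greedy_ctc_collapse(path, idx2char, blank=0):
    """Merge consecutive repeated indices, then drop blanks.

    `path` is the per-time-step argmax sequence (one class index per step).
    """
    out = []
    prev = None
    for idx in path:
        # Emit only when the index changes and is not the blank.
        if idx != prev and idx != blank:
            out.append(idx2char[idx])
        prev = idx
    return "".join(out)


# Toy vocabulary (hypothetical): index 0 is the CTC blank.
vocab = {0: "<blank>", 1: "a", 2: "b", 3: "o", 4: "k"}

# b b o <blank> o k k -> merge repeats -> b o <blank> o k -> drop blanks -> "book"
print(greedy_ctc_collapse([2, 2, 3, 0, 3, 4, 4], vocab))  # book
# Without a blank between them, repeated steps collapse to one character.
print(greedy_ctc_collapse([3, 3], vocab))  # o
```

The blank is what allows a doubled character such as the "oo" in "book" to survive the repeat-merging step.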

Model Details

Model Description

  • Model type: khm_ocr_general_document
  • Language(s): Khmer (km), English (en). Note: the model is primarily optimized for Khmer and is more accurate on Khmer text than on English text.
  • Training State: Trained from scratch on a dataset of printed text-line images.
  • Architecture:
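    • CNN: ResNet blocks over grayscale text-line images resized to a fixed 64 × 512 pixels (height × width); the CNN reduces each image to a 16 × 128 grid of 256-channel features.
    • RNN: 2-layer bidirectional LSTM (256 hidden units per direction) that reads the 128 feature columns as a left-to-right sequence.
    • Classifier: Linear layer mapping each time step to vocabulary scores (index 0 is the CTC blank); outputs are decoded with greedy CTC decoding.
  • Character Error Rate (CER): 0.005589
  • Word Error Rate (WER): 0.045868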
  • Training:
    • batch_size: 16
    • device: "cuda"
    • epochs: 30
    • learning_rate: 0.001
    • num_workers: 4
    • optimizer: "adam"
    • scheduler_factor: 0.5
    • scheduler_patience: 5
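
CER and WER are conventionally defined as the Levenshtein (edit) distance between the predicted and reference text divided by the reference length, counted in characters and in words respectively. The card does not say how its figures were computed or aggregated, or how Khmer (usually written without spaces between words) was split into words, so the sketch below is only an illustration under stated assumptions: it scores a single line, counts Unicode code points as characters, and splits words on whitespace. The function names `edit_distance`, `cer`, and `wer` are hypothetical.

```python
from typing import Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance (insertions, deletions, substitutions each cost 1)."""
    # prev[j] holds the distance between the previous ref prefix and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            )
        prev = curr
    return prev[-1]


def cer(ref: str, hyp: str) -> float:
    """Character error rate; assumption: one character = one Unicode code point."""
    return edit_distance(ref, hyp) / max(len(ref), 1)


def wer(ref: str, hyp: str) -> float:
    """Word error rate; assumption: words are whitespace-separated tokens."""
    ref_words = ref.split()
    return edit_distance(ref_words, hyp.split()) / max(len(ref_words), 1)


print(cer("hello world", "helo world"))  # 1 deletion / 11 characters
print(wer("hello world", "helo world"))  # 1 substitution / 2 words = 0.5
```

Because a single Khmer syllable often spans several code points (consonant, coeng, subscript consonant, vowel sign), a code-point CER and a grapheme-cluster CER can differ for the same prediction.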

Out-of-Scope Use

The model handles a variety of printed fonts, but it is strictly a line-level text recognizer: it expects a single cropped line of text, not a full page or paragraph.

  • Works best on: printed text with mild degradation, light blur, or low levels of scanner noise.
  • Fails on: heavily underlined text, strongly blurred captures, handwritten manuscripts, and warped or rotated images that have not been rectified. Make sure input crops are straight and baseline-aligned before inference.

Recommendations

This model is under active development. It handles clean, isolated line crops well, but accuracy drops on complex layouts. Run preprocessing steps such as layout analysis and text-line deskewing before passing images to the network.
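
One common way to estimate a line's skew before deskewing is the projection-profile method: when the text is level, ink concentrates in few rows, so the row-sum profile is most peaked. The sketch below illustrates that technique; it is not the author's preprocessing pipeline. It assumes an already binarized image given as a list of rows (1 = ink), approximates small rotations by shearing each column vertically, and scores each candidate angle by the sum of squared row counts. All function names, including the `synthetic_line` test input, are hypothetical.

```python
import math
from typing import List, Sequence


def profile_sharpness(img: List[List[int]], angle_deg: float) -> int:
    """Sum of squared row counts after shearing columns by angle_deg.

    img is a binary image as a list of rows (1 = ink, 0 = background).
    """
    slope = math.tan(math.radians(angle_deg))
    counts = {}
    for y, row in enumerate(img):
        for x, ink in enumerate(row):
            if ink:
                # Shift column x up by x * slope: a baseline that descends
                # at angle_deg then falls into a single row bin.
                yy = round(y - x * slope)
                counts[yy] = counts.get(yy, 0) + 1
    # Total ink is fixed, so the sum of squares is largest when the ink is
    # concentrated in the fewest rows, i.e. when the text is level.
    return sum(c * c for c in counts.values())


def estimate_skew(img: List[List[int]], angles: Sequence[float]) -> float:
    """Return the candidate angle (degrees) with the sharpest row profile.

    Positive angles mean the baseline descends from left to right.
    """
    return max(angles, key=lambda a: profile_sharpness(img, a))


def synthetic_line(width: int, height: int, angle_deg: float) -> List[List[int]]:
    """Hypothetical test input: a 1-pixel line starting at row 5, tilted by angle_deg."""
    img = [[0] * width for _ in range(height)]
    slope = math.tan(math.radians(angle_deg))
    for x in range(width):
        img[5 + round(x * slope)][x] = 1
    return img


print(estimate_skew(synthetic_line(60, 20, 5), range(-10, 11)))  # 5
```

With Pillow, `Image.rotate(angle, expand=True, fillcolor=255)` rotates counter-clockwise by `angle` degrees, which counteracts a positive (descending) estimate here; `fillcolor=255` assumes a white background on a grayscale "L" image. A shear is only a good stand-in for rotation at small angles, so the candidate range should stay narrow.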


How to Get Started with the Model

File Layout Requirements

Before running predictions, make sure the folder containing predict.py also holds the model checkpoint and the character map, laid out as follows:

.
├── predict.py
└── khmer_ocr_model_CRNN/
    ├── best_model.pth
    └── char.json
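
A quick pre-flight check can catch a missing checkpoint or character map before the model is loaded. The helper `missing_model_files` below is hypothetical (it is not part of predict.py); it assumes you run it from the folder that contains predict.py and only checks that the three files in the tree exist.

```python
from pathlib import Path
from typing import List

# Paths expected under the folder that holds predict.py, mirroring the tree.
REQUIRED = (
    Path("predict.py"),
    Path("khmer_ocr_model_CRNN") / "best_model.pth",
    Path("khmer_ocr_model_CRNN") / "char.json",
)


def missing_model_files(root) -> List[str]:
    """Return the required paths, relative to root, that are not existing files."""
    root = Path(root)
    return [str(rel) for rel in REQUIRED if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_model_files(Path.cwd())
    print("Layout OK" if not missing else "Missing: " + ", ".join(missing))
```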

predict.py

import argparse
import json
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as T

HERE = Path(__file__).parent
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}


# ── Model Architecture ────────────────────────────────────────────────────────

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        self.downsample = None
        if stride != 1 or in_ch != out_ch:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.conv(x)
        if self.downsample is not None:
            identity = self.downsample(x)
        return self.relu(out + identity)


class ResNetBiLSTMCTC(nn.Module):
    def __init__(self, num_classes, hidden_size=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.cnn = nn.Sequential(
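            nn.Conv2d(1, 64, 3, padding=1, bias=False),     # [B, 64, H, W]
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),                             # [B, 64, H/2, W/2]
            # Keep each ResBlock wrapped in nn.Sequential: the nesting sets the
            # state_dict key names (e.g. "cnn.4.0.conv.0.weight") that
            # load_state_dict() must match when loading the checkpoint.
            nn.Sequential(ResBlock(64, 64, stride=1)),      # [B, 64, H/2, W/2]
            nn.Sequential(ResBlock(64, 128, stride=2)),     # [B, 128, H/4, W/4]
            nn.Sequential(ResBlock(128, 256, stride=1)),    # [B, 256, H/4, W/4]
        )
        self.rnn = nn.LSTM(
            input_size=256 * 16,  # 256 channels x 16 feature rows (64-px input / 4)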
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=False,
        )
        self.classifier = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        feat = self.cnn(x)                       # [B, 256, H', W']
        b, c, h, w = feat.shape
        seq = feat.permute(3, 0, 1, 2)          # [W', B, C, H']
        seq = seq.reshape(w, b, c * h)           # [W', B, C*H']
        out, _ = self.rnn(seq)                   # [W', B, 2*hidden]
        return self.classifier(out).log_softmax(2)  # [W', B, num_classes]


# ── Charset Parsing ───────────────────────────────────────────────────────────

def load_charset(path):
    """
    Builds an index-to-character mapping lookup from char.json files.
    Accommodates categories grouped by blocks, lists, or custom string token pairs.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, list):
        return {i: ch for i, ch in enumerate(data)}

    if isinstance(data, dict):
        known_categories = {"khmer", "latin", "digits", "special"}
        if known_categories & set(data.keys()):
            order = ["khmer", "latin", "digits", "special"]
            all_chars = "".join(data.get(cat, "") for cat in order)
            return {0: "<blank>", **{i + 1: ch for i, ch in enumerate(all_chars)}}

        try:
            return {int(k): v for k, v in data.items()}
        except ValueError:
            pass

        return {v: k for k, v in data.items()}

    raise ValueError(f"Unrecognized token map schema inside character file: {path}")


# ── Processing & Greedy Decoding ─────────────────────────────────────────────

def ctc_decode(log_probs, idx2char, blank=0):
    """Collapses consecutive duplicate indexes and strips blank tokens."""
    indices = log_probs.argmax(dim=1).tolist()
    chars, prev = [], None
    for idx in indices:
        if idx != prev and idx != blank:
            chars.append(idx2char.get(idx, "?"))
        prev = idx
    return "".join(chars)


_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])

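def preprocess(image_path, img_h=64, img_w=512):
    """Load an image as grayscale, resize it to img_w x img_h, scale pixels to
    [-1, 1], and add a batch dimension: returns a [1, 1, img_h, img_w] tensor.
    Keep img_h at 64: the LSTM input size (256 * 16) assumes 16 feature rows.
    """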
    img = Image.open(image_path).convert("L")
    img = img.resize((img_w, img_h), Image.BICUBIC)
    return _transform(img).unsqueeze(0)


def predict_one(model, image_path, idx2char, device, blank=0):
    tensor = preprocess(image_path).to(device)
    with torch.no_grad():
        log_probs = model(tensor)[:, 0, :]   # [T, B, C] -> first (only) batch item: [T, C]
    return ctc_decode(log_probs.cpu(), idx2char, blank=blank)


# ── Core Runtime Entrypoint ───────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(description="Khmer OCR line-level inference")
    parser.add_argument(
        "input",
        help="Path to a single line-crop image, or to a directory of images.",
    )
    parser.add_argument(
        "--model",
        default=str(HERE / "khmer_ocr_model_CRNN" / "best_model.pth"),
        help="Path to the model checkpoint (state_dict).",
    )
    parser.add_argument(
        "--charset",
        default=str(HERE / "khmer_ocr_model_CRNN" / "char.json"),
        help="Path to the character map (char.json).",
    )
    parser.add_argument(
        "--output",
        default=str(HERE / "predictions.txt"),
        help="Output text file; one tab-separated 'filename<TAB>prediction' line per image.",
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on, e.g. 'cpu' or 'cuda' (default: cuda if available).",
    )
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load the index-to-character map; the class count is the highest index + 1.
    idx2char = load_charset(args.charset)
    num_classes = max(idx2char.keys()) + 1

    # Build the network and load the trained weights (a plain state_dict).
    model = ResNetBiLSTMCTC(num_classes=num_classes)
    state_dict = torch.load(args.model, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(device).eval()

    print(f"Model File   : {args.model}")
    print(f"Charset File : {args.charset} ({num_classes} classes)")
    print(f"Target Device: {device}\n")

    input_path = Path(args.input)
    out_path = Path(args.output)

    # ── Single image ─────────────────────────────────────────────────────────
    if input_path.is_file():
        text = predict_one(model, input_path, idx2char, device)
        print(f"Image File : {input_path.name}")
        print(f"Prediction : {text}")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(f"{input_path.name}\t{text}\n")
        print(f"Saved to   : {out_path}")
        return

    # ── Directory: predict every image file, sorted by name ──────────────────
    if input_path.is_dir():
        images = sorted(
            p for p in input_path.iterdir()
            if p.suffix.lower() in IMAGE_EXTS
        )
        if not images:
            print(f"No image files found in directory '{input_path}'.")
            return

        print(f"Found {len(images)} image(s) in '{input_path}'.\n")
        with open(out_path, "w", encoding="utf-8") as f:
            for i, img_path in enumerate(images, 1):
                try:
                    text = predict_one(model, img_path, idx2char, device)
                except Exception as exc:  # record the failure and continue
                    text = f"ERROR: {exc}"
                print(f"[{i}/{len(images)}] {img_path.name} -> {text}")
                f.write(f"{img_path.name}\t{text}\n")

        print(f"\nExecution Pipeline Complete: Stream saved down cleanly into '{out_path}'")
        return

    print(f"Invalid Operation Error: Resource tracking index target reference '{input_path}' does not point to structural nodes.")


if __name__ == "__main__":
    main()
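
The LSTM's `input_size=256 * 16` in ResNetBiLSTMCTC assumes the CNN leaves 16 rows of 256-channel features, which is what a 64-pixel-high input produces; that is why `preprocess()` resizes every crop to 64 × 512. The sketch below traces the spatial size through the CNN with PyTorch's output-size formula for Conv2d and MaxPool2d (floor((n + 2p - k) / s) + 1, dilation 1) and also checks the CTC length constraint: a label fits in T time steps only if T is at least its length plus the number of adjacent repeated characters. The helper names are hypothetical and are not part of predict.py.

```python
from typing import Tuple


def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of Conv2d / MaxPool2d along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def cnn_output_hw(h: int, w: int) -> Tuple[int, int]:
    """Trace (height, width) through the CNN stack of ResNetBiLSTMCTC."""
    dims = []
    for n in (h, w):
        n = out_size(n, 3, 1, 1)  # stem Conv2d 3x3, padding 1: size unchanged
        n = out_size(n, 2, 2, 0)  # MaxPool2d(2, 2): halves the size
        n = out_size(n, 3, 1, 1)  # ResBlock(64, 64, stride=1)
        n = out_size(n, 3, 2, 1)  # ResBlock(64, 128, stride=2): first conv strides
        n = out_size(n, 3, 1, 1)  # ResBlock(128, 256, stride=1)
        dims.append(n)
    return dims[0], dims[1]


def min_ctc_steps(label: str) -> int:
    """Fewest time steps CTC needs to emit label.

    Each character needs one step, and every pair of identical neighbours
    needs an extra blank step between them so they are not merged.
    """
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats


h, w = cnn_output_hw(64, 512)
print(h, w)                    # 16 128 -> LSTM input 256 * 16, T = 128 steps
print(cnn_output_hw(32, 512))  # (8, 128): 256 * 8 would not match input_size
print(min_ctc_steps("book"))   # 5: b, o, <blank>, o, k
```

With T = 128, a line can yield at most 128 output symbols, fewer when characters repeat. If labels are counted in Unicode code points, as the category layout in `load_charset` implies, long Khmer lines with many combining marks may approach this limit.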