Khmer OCR — ResNet + BiLSTM + CTC
This repository contains a deep learning model designed for Khmer and English Optical Character Recognition (OCR). It utilizes a ResNet backbone for spatial feature extraction, a bidirectional LSTM for sequence modeling, and Connectionist Temporal Classification (CTC) loss for alignment-free text recognition.
Model Details
Model Description
- Model type: `khm_ocr_general_document`
- Language(s): Khmer (`km`), English (`en`). Note: The model exhibits higher accuracy on, and is better optimized for, Khmer character compositions.
- Training State: Trained from scratch on a comprehensive dataset of printed text lines.
- Architecture:
- CNN: ResNet blocks processing grayscale document line images downscaled to a fixed height.
- RNN: 2-layer Bidirectional LSTM tracking character transitions.
- Classifier: Linear layer mapping features to vocabulary indices decoded via greedy CTC.
- Character Error Rate: 0.005589
- Word Error Rate: 0.045868
- Training:
- batch_size: 16
- device: "cuda"
- epochs: 30
- learning_rate: 0.001
- num_workers: 4
- optimizer: "adam"
- scheduler_factor: 0.5
- scheduler_patience: 5
Out-of-Scope Use
While this model handles varied font distributions well, it is strictly a line-level text recognizer.
- Optimal on: Printed text documents featuring standard degradation, light blur, or low levels of scanner noise.
- Fails on: Heavily underlined text, highly blurred captures, handwritten manuscripts, or document images that are still warped or rotated. Ensure input crops are straight and baseline-aligned before inference.
Recommendations
This architecture is under active development. While it successfully handles clean, isolated line crops, performance drops on extreme layouts. Preprocessing elements (like layout analysis and text-line deskewing) should be executed prior to feeding imagery to this network.
How to Get Started with the Model
File Layout Requirements
Before running predictions, ensure your repository or local folder contains your model files and characters map structured as follows:
.
├── predict.py
└── khmer_ocr_model_CRNN/
    ├── best_model.pth
    └── char.json
predict.py
import argparse
import json
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as T
# Directory that contains this script; the default model, charset, and output
# paths of the command-line interface are resolved relative to it.
HERE = Path(__file__).parent
# Lower-case file suffixes accepted as images when the input path is a directory.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
# ── Model Architecture ────────────────────────────────────────────────────────
class ResBlock(nn.Module):
    """Two-convolution residual unit with an optional projection shortcut.

    Main path: Conv3x3 -> BatchNorm -> ReLU -> Conv3x3 -> BatchNorm.
    When the stride is not 1 or the channel count changes, the skip
    connection goes through a 1x1 convolution + BatchNorm so that both paths
    have the same shape before they are added.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        main_path = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        # Attribute names and layer order define the state_dict keys, so they
        # are kept exactly as checkpoints expect them.
        self.conv = nn.Sequential(*main_path)

        needs_projection = stride != 1 or in_ch != out_ch
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.downsample = None
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(conv(x) + shortcut(x))``."""
        residual = self.conv(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)
class ResNetBiLSTMCTC(nn.Module):
    """CRNN recognizer: ResNet-style CNN -> bidirectional LSTM -> linear CTC head.

    Consumes single-channel images ``[B, 1, H, W]`` and returns per-timestep
    log-probabilities ``[W', B, num_classes]``. The CNN reduces the height
    twice (2x2 max-pool, then a stride-2 residual block) and the LSTM input
    size is fixed at ``256 * 16``, i.e. the feature map must be 16 rows high.
    """

    def __init__(self, num_classes, hidden_size=256, num_layers=2, dropout=0.3):
        super().__init__()
        stem = [
            nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        # Every residual block is wrapped in its own nn.Sequential so the
        # parameter paths keep the form ``cnn.<i>.0.*``.
        stages = [
            nn.Sequential(ResBlock(64, 64, stride=1)),
            nn.Sequential(ResBlock(64, 128, stride=2)),
            nn.Sequential(ResBlock(128, 256, stride=1)),
        ]
        self.cnn = nn.Sequential(*stem, *stages)

        # Dropout is switched off when the LSTM has a single layer.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = nn.LSTM(
            input_size=256 * 16,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            dropout=lstm_dropout,
            batch_first=False,
        )
        # Forward and backward hidden states are concatenated, hence 2 * hidden_size.
        self.classifier = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x):
        """Map an image batch to time-major log-probabilities."""
        features = self.cnn(x)  # [B, 256, H', W']
        # Width becomes the time axis; channels and height are merged into
        # one feature vector per time step.
        sequence = features.permute(3, 0, 1, 2).flatten(start_dim=2)  # [W', B, C*H']
        hidden, _ = self.rnn(sequence)  # [W', B, 2*hidden]
        logits = self.classifier(hidden)  # [W', B, num_classes]
        return logits.log_softmax(dim=2)
# ── Charset Parsing ───────────────────────────────────────────────────────────
def load_charset(path):
    """Build an index -> character lookup from a ``char.json`` file.

    Supported JSON layouts:

    * list: position ``i`` maps to ``data[i]``.
    * dict containing any of the keys ``khmer``, ``latin``, ``digits``,
      ``special``: the categories are concatenated in that order and numbered
      from 1; index 0 is reserved for ``"<blank>"``. A category may be a
      string (one entry per code point) or a list of tokens.
    * dict whose keys are all integer strings: index -> character.
    * any other dict: treated as character -> index and inverted; indices
      stored as numeric strings are converted to ``int``.

    Args:
        path: Location of the JSON charset file (read as UTF-8).

    Returns:
        dict mapping an ``int`` class index to its character/token.

    Raises:
        ValueError: if the JSON layout matches none of the layouts above.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, list):
        return dict(enumerate(data))

    if isinstance(data, dict):
        category_order = ("khmer", "latin", "digits", "special")
        if any(cat in data for cat in category_order):
            tokens = []
            for cat in category_order:
                # extend() takes code points from a str and tokens from a
                # list; joining with "".join() crashed on list-valued
                # categories. ``or ""`` also tolerates a JSON null.
                tokens.extend(data.get(cat) or "")
            return {0: "<blank>", **{i + 1: tok for i, tok in enumerate(tokens)}}

        try:
            return {int(k): v for k, v in data.items()}
        except ValueError:
            pass  # keys are not indices; try the character -> index layout

        try:
            # int() keeps integer indices as-is and repairs "5"-style strings,
            # so callers can always compute max(index) + 1.
            return {int(idx): ch for ch, idx in data.items()}
        except (TypeError, ValueError):
            pass  # values are not indices either; report below

    raise ValueError(f"Unrecognized token map schema inside character file: {path}")
# ── Processing & Greedy Decoding ──────────────────────────────────────────────
def ctc_decode(log_probs, idx2char, blank=0):
    """Greedy CTC decode: best class per step, merge repeats, drop blanks.

    ``log_probs`` is indexed as [time, classes] (the argmax is taken over
    dim 1). Indices missing from ``idx2char`` are rendered as "?".
    """
    best_path = log_probs.argmax(dim=1).tolist()
    # Pair every step with its predecessor; the first step has none.
    predecessors = [None] + best_path[:-1]
    return "".join(
        idx2char.get(current, "?")
        for before, current in zip(predecessors, best_path)
        if current != before and current != blank
    )
# Shared tensor conversion: PIL image -> tensor, then normalisation with
# mean 0.5 and std 0.5 on the single grayscale channel.
_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])


def preprocess(image_path, img_h=64, img_w=512):
    """Load an image file as a normalised grayscale tensor with a batch dim.

    The image is converted to 8-bit grayscale, resized to exactly
    ``img_w`` x ``img_h`` pixels with bicubic resampling (the aspect ratio is
    not preserved), passed through ``_transform`` and given a leading batch
    dimension.

    Args:
        image_path: Path of the image file to read.
        img_h: Target height in pixels.
        img_w: Target width in pixels.

    Returns:
        The transformed tensor with one extra leading dimension (expected to
        be ``[1, 1, img_h, img_w]`` for the transform defined above).
    """
    # Image.open() is lazy and can keep the file handle open (notably for
    # multi-frame TIFF/WebP). The context manager releases it as soon as the
    # pixel data has been copied by convert().
    with Image.open(image_path) as source:
        gray = source.convert("L")
    # PIL's resize() takes (width, height).
    resized = gray.resize((img_w, img_h), Image.BICUBIC)
    return _transform(resized).unsqueeze(0)
def predict_one(model, image_path, idx2char, device, blank=0):
    """Recognise the text of one image file and return it as a string.

    The image is preprocessed, moved to ``device``, run through ``model``
    without gradient tracking, and greedily CTC-decoded on the CPU.
    """
    batch = preprocess(image_path).to(device)
    with torch.no_grad():
        output = model(batch)
        # The output is time-major; keep the only batch element -> [Time, Classes].
        per_step = output[:, 0, :]
    return ctc_decode(per_step.cpu(), idx2char, blank=blank)
# ── Core Runtime Entrypoint ───────────────────────────────────────────────────
def _build_arg_parser():
    """Create the command-line parser for the inference script."""
    parser = argparse.ArgumentParser(description="Khmer OCR Inference System Stack")
    parser.add_argument(
        "input",
        help="Target filepath to standalone line crop OR parent path containing image lists.",
    )
    parser.add_argument(
        "--model",
        default=str(HERE / "khmer_ocr_model_CRNN" / "best_model.pth"),
        help="Checkpoint parameter file destination location path.",
    )
    parser.add_argument(
        "--charset",
        default=str(HERE / "khmer_ocr_model_CRNN" / "char.json"),
        help="JSON configuration text format vocabulary parsing schema.",
    )
    parser.add_argument(
        "--output",
        default=str(HERE / "predictions.txt"),
        help="Target text document destination to record string output arrays.",
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Hardware execution pipeline override runtime flag.",
    )
    return parser


def _load_model(model_path, num_classes, device):
    """Instantiate the network, load the checkpoint, and switch to eval mode."""
    model = ResNetBiLSTMCTC(num_classes=num_classes)
    # SECURITY: checkpoints are pickle files. weights_only=True restricts
    # unpickling to tensors and plain containers, so a tampered checkpoint
    # cannot execute arbitrary code while it is being loaded.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model


def main():
    """CLI entry point: recognise one line image, or every image in a directory.

    Results are printed and written to the ``--output`` file, one
    tab-separated ``<file name> <text>`` record per line.

    Raises:
        SystemExit: with status 1 when the charset is empty or the input path
            is neither a file nor a directory; argparse exits on bad arguments.
    """
    args = _build_arg_parser().parse_args()
    device = torch.device(args.device)

    # Initialize vocabulary bounds
    idx2char = load_charset(args.charset)
    if not idx2char:
        # max() on an empty mapping would fail with an opaque ValueError.
        print(f"Invalid Operation Error: charset file '{args.charset}' defines no characters.")
        raise SystemExit(1)
    num_classes = max(idx2char) + 1

    model = _load_model(args.model, num_classes, device)
    print(f"Model File : {args.model}")
    print(f"Charset File : {args.charset} ({num_classes} distribution classes)")
    print(f"Target Device: {device}\n")

    input_path = Path(args.input)
    out_path = Path(args.output)

    # ── Single image ──────────────────────────────────────────────────────
    if input_path.is_file():
        text = predict_one(model, input_path, idx2char, device)
        print(f"Image File : {input_path.name}")
        print(f"Prediction : {text}")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(f"{input_path.name}\t{text}\n")
        print(f"Saved logs : {out_path}")
        return

    # ── Directory of images ───────────────────────────────────────────────
    if input_path.is_dir():
        # is_file() skips sub-directories whose name happens to end in an
        # image suffix.
        images = sorted(
            p for p in input_path.iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_EXTS
        )
        if not images:
            print(f"Termination: No valid image file variations found inside directory context '{input_path}'")
            return
        print(f"Queued Processing Run: Found {len(images)} sequence elements inside directory structure.\n")
        with open(out_path, "w", encoding="utf-8") as f:
            for i, img_path in enumerate(images, 1):
                try:
                    text = predict_one(model, img_path, idx2char, device)
                except Exception as error_exception:
                    # Deliberate best-effort: one unreadable image must not
                    # abort the batch, so the failure is recorded in its row.
                    text = f"RUNTIME ERROR METRIC EXCEPTION: {error_exception}"
                print(f"[{i}/{len(images)}] {img_path.name} -> {text}")
                f.write(f"{img_path.name}\t{text}\n")
        print(f"\nExecution Pipeline Complete: Stream saved down cleanly into '{out_path}'")
        return

    # Neither a file nor a directory: report it and signal failure to the shell.
    print(f"Invalid Operation Error: Resource tracking index target reference '{input_path}' does not point to structural nodes.")
    raise SystemExit(1)
# Script entry point: run inference only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()