import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from torchvision import transforms
import os
import argparse
class KhmerOCR:
    """Inference wrapper for the Khmer SeqSE-CRNN-Transformer OCR model.

    Loads the model/tokenizer from a Hugging Face repo (with remote code),
    splits input images into fixed-width overlapping chunks, encodes them,
    and decodes the text with greedy or beam search.
    """

    def __init__(self, model_repo="Darayut/khmer-SeqSE-CRNN-Transformer", device=None):
        """
        Initializes the Khmer OCR model and tokenizer.

        Args:
            model_repo: Hugging Face repo id to load model/tokenizer from.
            device: torch.device to run on; auto-selects CUDA when available.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        print(f"⏳ Loading model from {model_repo} on {self.device}...")
        # Load Model & Tokenizer with trust_remote_code=True.
        # NOTE(review): trust_remote_code executes code shipped in the repo —
        # only point this at repos you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(model_repo, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_repo, trust_remote_code=True).to(self.device)
        self.model.eval()
        # Vocab mappings (id -> char) for manual detokenization in predict().
        self.vocab = self.tokenizer.get_vocab()
        self.id2char = {v: k for k, v in self.vocab.items()}
        # Special token ids (fall back to conventional ids if absent from vocab).
        self.sos_idx = self.vocab.get("<sos>", 1)
        self.eos_idx = self.vocab.get("<eos>", 2)
        self.pad_idx = self.vocab.get("<pad>", 0)
        self.unk_idx = self.vocab.get("<unk>", 3)
        # Image transform (matches training): grayscale, [0,1] -> [-1,1].
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5)
        ])

    def _chunk_image(self, img_tensor, chunk_width=100, overlap=16):
        """Split a (C, H, W) image tensor into fixed-width overlapping chunks.

        The final chunk is right-padded with 1.0 (white, in the normalized
        [-1, 1] space) so every chunk is exactly `chunk_width` wide.

        Raises:
            ValueError: if overlap >= chunk_width (the step would be <= 0,
                which previously caused an infinite loop).
        """
        if overlap >= chunk_width:
            raise ValueError("overlap must be smaller than chunk_width")
        C, H, W = img_tensor.shape
        chunks = []
        start = 0
        step = chunk_width - overlap
        while start < W:
            end = min(start + chunk_width, W)
            chunk = img_tensor[:, :, start:end]
            if chunk.shape[2] < chunk_width:
                pad_size = chunk_width - chunk.shape[2]
                chunk = F.pad(chunk, (0, pad_size, 0, 0), value=1.0)
            chunks.append(chunk)
            start += step
        return chunks

    def preprocess(self, image_source):
        """
        Preprocesses an image path or PIL Object into a batch of chunks.

        Args:
            image_source: file path or PIL.Image.Image.

        Returns:
            Tensor of shape (num_chunks, 1, 48, 100) on self.device.

        Raises:
            FileNotFoundError: if a path is given and does not exist.
            ValueError: if image_source is neither a path nor a PIL image.
        """
        # Load image as grayscale.
        if isinstance(image_source, str):
            if not os.path.exists(image_source):
                raise FileNotFoundError(f"Image not found at {image_source}")
            image = Image.open(image_source).convert('L')
        elif isinstance(image_source, Image.Image):
            image = image_source.convert('L')
        else:
            raise ValueError("Input must be a file path or PIL Image.")
        # Resize to fixed height 48 preserving aspect ratio (min width 10).
        target_height = 48
        aspect_ratio = image.width / image.height
        new_width = max(10, int(target_height * aspect_ratio))
        image = image.resize((new_width, target_height), Image.Resampling.BILINEAR)
        # Transform & chunk.
        img_tensor = self.transform(image)
        chunks = self._chunk_image(img_tensor, chunk_width=100, overlap=16)
        return torch.stack(chunks).to(self.device)

    def predict(self, image_source, method="beam", beam_width=3, max_len=256):
        """
        Main prediction method.

        Args:
            image_source: Path to image or PIL object.
            method: 'greedy' or 'beam'.
            beam_width: Width for beam search (default 3).
            max_len: Max decoded sequence length.

        Returns:
            Decoded text with special tokens stripped.
        """
        chunks_tensor = self.preprocess(image_source)
        # 1. Encode (CNN + Transformer + BiLSTM smoothing).
        with torch.no_grad():
            # Wrap in a list as the model expects a batch of images.
            memory = self.model([chunks_tensor])
        # 2. Decode.
        if method == "greedy" or beam_width <= 1:
            token_ids = self._greedy_decode(memory, max_len)
        else:
            token_ids = self._beam_search(memory, max_len, beam_width)
        # 3. Convert ids to text manually (avoids tokenizer spacing issues).
        special = {self.sos_idx, self.eos_idx, self.pad_idx, self.unk_idx}
        return "".join(self.id2char.get(idx, "") for idx in token_ids if idx not in special)

    def _greedy_decode(self, memory, max_len):
        """Greedy autoregressive decode.

        Returns the generated token ids including the leading <sos>
        (special tokens are stripped by the caller).
        """
        B, T, _ = memory.shape
        memory_mask = torch.zeros((B, T), dtype=torch.bool, device=self.device)
        generated = [self.sos_idx]
        with torch.no_grad():
            for _ in range(max_len):
                tgt = torch.tensor([generated], dtype=torch.long, device=self.device)
                logits = self.model.dec(tgt, memory, memory_mask)
                next_token = torch.argmax(logits[0, -1, :]).item()
                if next_token == self.eos_idx:
                    break
                generated.append(next_token)
        return generated

    def _beam_search(self, memory, max_len, beam_width):
        """Length-normalized beam search decode.

        Args:
            memory: encoder output; assumed batch size 1 — TODO confirm.
            max_len: maximum number of decoding steps.
            beam_width: number of beams kept alive each step.

        Returns:
            Best token id sequence including the leading <sos>.
        """
        _, T, _ = memory.shape
        # Replicate the single-image memory view across all beams (no copy).
        memory = memory.expand(beam_width, -1, -1)
        memory_mask = torch.zeros((beam_width, T), dtype=torch.bool, device=self.device)
        beams = [(0.0, [self.sos_idx])]  # (cumulative log-prob, token ids)
        completed_beams = []
        with torch.no_grad():
            for step in range(max_len):
                k_curr = len(beams)
                current_seqs = [b[1] for b in beams]
                tgt = torch.tensor(current_seqs, dtype=torch.long, device=self.device)
                step_logits = self.model.dec(tgt, memory[:k_curr], memory_mask[:k_curr])
                log_probs = F.log_softmax(step_logits[:, -1, :], dim=-1)
                # Expand each live beam by its top-k continuations.
                candidates = []
                for i in range(k_curr):
                    score_so_far, seq_so_far = beams[i]
                    topk_probs, topk_idx = log_probs[i].topk(beam_width)
                    for k in range(beam_width):
                        candidates.append((score_so_far + topk_probs[k].item(),
                                           seq_so_far + [topk_idx[k].item()]))
                candidates.sort(key=lambda x: x[0], reverse=True)
                # Keep the best beam_width unfinished beams; bank finished ones.
                next_beams = []
                for score, seq in candidates:
                    if seq[-1] == self.eos_idx:
                        # Normalize by generated length so shorter sequences
                        # don't win purely by accumulating less penalty.
                        norm_score = score / (len(seq) - 1)
                        completed_beams.append((norm_score, seq))
                    else:
                        next_beams.append((score, seq))
                    if len(next_beams) == beam_width:
                        break
                beams = next_beams
                if not beams:
                    break
        if completed_beams:
            completed_beams.sort(key=lambda x: x[0], reverse=True)
            return completed_beams[0][1]
        elif beams:
            return beams[0][1]
        else:
            return [self.sos_idx]
# ==============================================================================
# CLI USAGE
# ==============================================================================
def _main():
    """CLI entry point: parse args, OCR one image, save the text next to it."""
    parser = argparse.ArgumentParser(description="Khmer OCR Inference")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--method", type=str, default="beam", choices=["greedy", "beam"], help="Decoding method")
    parser.add_argument("--beam_width", type=int, default=3, help="Width for beam search")
    parser.add_argument("--max_len", type=int, default=256, help="Max output length")
    parser.add_argument("--repo", type=str, default="Darayut/khmer-SeqSE-CRNN-Transformer", help="HF Model Repo")
    args = parser.parse_args()
    try:
        # Initialize
        ocr = KhmerOCR(model_repo=args.repo)
        # Run
        print(f"📷 Processing: {args.image}")
        text = ocr.predict(args.image, method=args.method, beam_width=args.beam_width, max_len=args.max_len)
        print("\n" + "="*30)
        print(f"RESULT: {text}")
        print("="*30)
        # Auto-save the transcription next to the input image.
        out_path = os.path.splitext(args.image)[0] + ".txt"
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(text)
        print(f"💾 Saved to: {out_path}")
    except Exception as e:
        # Report the failure and exit non-zero so shell callers can detect it
        # (previously the script swallowed the error and exited with status 0).
        print(f"❌ Error: {e}")
        raise SystemExit(1)


if __name__ == "__main__":
    _main()