import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from torchvision import transforms
import os
import argparse


class KhmerOCR:
    """OCR engine for Khmer text using a CRNN + Transformer model from the HF Hub."""

    def __init__(self, model_repo="Darayut/khmer-SeqSE-CRNN-Transformer", device=None):
        """
        Initializes the Khmer OCR model and tokenizer.

        Args:
            model_repo: HuggingFace Hub repo id of the model/tokenizer.
            device: torch.device to run on; auto-selects CUDA, else CPU, when None.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        print(f"⏳ Loading model from {model_repo} on {self.device}...")

        # Load Model & Tokenizer with trust_remote_code=True (model class is
        # defined in the repo, not in the transformers library).
        self.tokenizer = AutoTokenizer.from_pretrained(model_repo, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_repo, trust_remote_code=True).to(self.device)
        self.model.eval()

        # Build vocab mappings (char -> id and id -> char).
        self.vocab = self.tokenizer.get_vocab()
        self.id2char = {v: k for k, v in self.vocab.items()}

        # Special tokens.
        # FIX(review): the original looked up the empty string "" for all four
        # special tokens, so they could all resolve to the SAME id whenever ""
        # was present in the vocab. Resolve by conventional special-token names
        # instead, keeping the original fallback indices
        # (pad=0, sos=1, eos=2, unk=3).
        self.sos_idx = self._special_token_id(("<sos>", "<s>", "[SOS]"), 1)
        self.eos_idx = self._special_token_id(("<eos>", "</s>", "[EOS]"), 2)
        self.pad_idx = self._special_token_id(("<pad>", "[PAD]"), 0)
        self.unk_idx = self._special_token_id(("<unk>", "[UNK]"), 3)

        # Image transform (matches training): grayscale, [0,1] tensor,
        # then normalize to [-1, 1] via (x - 0.5) / 0.5.
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5)
        ])

    def _special_token_id(self, candidates, default):
        """Return the vocab id of the first candidate token name present, else default."""
        for name in candidates:
            if name in self.vocab:
                return self.vocab[name]
        return default

    def _chunk_image(self, img_tensor, chunk_width=100, overlap=16):
        """Split a (C, H, W) image tensor into fixed-width, overlapping chunks.

        The final chunk is right-padded with 1.0 (white: raw pixel 1.0 maps to
        1.0 under Normalize(0.5, 0.5)) so every chunk is exactly chunk_width wide.

        Args:
            img_tensor: normalized image tensor of shape (C, H, W).
            chunk_width: width of each chunk in pixels.
            overlap: horizontal overlap between consecutive chunks.

        Returns:
            List of (C, H, chunk_width) tensors.

        Raises:
            ValueError: if overlap >= chunk_width (the scan would never advance).
        """
        # FIX(review): the original would loop forever if overlap >= chunk_width.
        if chunk_width <= overlap:
            raise ValueError("chunk_width must be greater than overlap")

        C, H, W = img_tensor.shape
        chunks = []
        start = 0
        while start < W:
            end = min(start + chunk_width, W)
            chunk = img_tensor[:, :, start:end]
            if chunk.shape[2] < chunk_width:
                pad_size = chunk_width - chunk.shape[2]
                # Pad on the right only: (left, right, top, bottom).
                chunk = F.pad(chunk, (0, pad_size, 0, 0), value=1.0)
            chunks.append(chunk)
            start += chunk_width - overlap
        return chunks

    def preprocess(self, image_source):
        """
        Preprocesses an image path or PIL object into a batch of chunks.

        Args:
            image_source: filesystem path (str) or PIL.Image.Image.

        Returns:
            Tensor of shape (num_chunks, 1, 48, 100) on self.device.

        Raises:
            FileNotFoundError: path does not exist.
            ValueError: input is neither a path nor a PIL Image.
        """
        # Load image as single-channel ('L') grayscale.
        if isinstance(image_source, str):
            if not os.path.exists(image_source):
                raise FileNotFoundError(f"Image not found at {image_source}")
            image = Image.open(image_source).convert('L')
        elif isinstance(image_source, Image.Image):
            image = image_source.convert('L')
        else:
            raise ValueError("Input must be a file path or PIL Image.")

        # Resize to fixed height 48, preserving aspect ratio (min width 10
        # so degenerate inputs still produce a usable tensor).
        target_height = 48
        aspect_ratio = image.width / image.height
        new_width = max(10, int(target_height * aspect_ratio))
        image = image.resize((new_width, target_height), Image.Resampling.BILINEAR)

        # Normalize then split into overlapping fixed-width chunks.
        img_tensor = self.transform(image)
        chunks = self._chunk_image(img_tensor, chunk_width=100, overlap=16)
        return torch.stack(chunks).to(self.device)

    def predict(self, image_source, method="beam", beam_width=3, max_len=256):
        """
        Main prediction method.

        Args:
            image_source: Path to image or PIL object.
            method: 'greedy' or 'beam'.
            beam_width: Width for beam search (default 3).
            max_len: Max decoded sequence length.

        Returns:
            The decoded text string (special tokens stripped).
        """
        chunks_tensor = self.preprocess(image_source)

        # 1. Encode (CNN + Transformer + BiLSTM smoothing).
        with torch.no_grad():
            # Wrap in a list: the model expects a batch of images.
            memory = self.model([chunks_tensor])

        # 2. Decode.
        if method == "greedy" or beam_width <= 1:
            token_ids = self._greedy_decode(memory, max_len)
        else:
            token_ids = self._beam_search(memory, max_len, beam_width)

        # 3. Convert ids to text manually (avoids tokenizer spacing issues).
        skip_ids = {self.sos_idx, self.eos_idx, self.pad_idx, self.unk_idx}
        return "".join(
            self.id2char.get(idx, "") for idx in token_ids if idx not in skip_ids
        )

    def _greedy_decode(self, memory, max_len):
        """Greedy autoregressive decoding: pick argmax token each step until EOS."""
        B, T, _ = memory.shape
        memory_mask = torch.zeros((B, T), dtype=torch.bool, device=self.device)
        generated = [self.sos_idx]
        with torch.no_grad():
            for _ in range(max_len):
                tgt = torch.tensor([generated], dtype=torch.long, device=self.device)
                logits = self.model.dec(tgt, memory, memory_mask)
                next_token = torch.argmax(logits[0, -1, :]).item()
                if next_token == self.eos_idx:
                    break
                generated.append(next_token)
        return generated

    def _beam_search(self, memory, max_len, beam_width):
        """Beam-search decoding with length-normalized scores for finished beams.

        NOTE(review): memory.expand(beam_width, -1, -1) assumes the encoder
        returns a batch of size 1 — TODO confirm against the model's forward.
        """
        B, T, D = memory.shape
        # Replicate encoder memory across the beam dimension (view, no copy).
        memory = memory.expand(beam_width, -1, -1)
        memory_mask = torch.zeros((beam_width, T), dtype=torch.bool, device=self.device)

        # Each beam is (cumulative log-prob, token id sequence).
        beams = [(0.0, [self.sos_idx])]
        completed_beams = []

        with torch.no_grad():
            for step in range(max_len):
                k_curr = len(beams)
                current_seqs = [b[1] for b in beams]
                tgt = torch.tensor(current_seqs, dtype=torch.long, device=self.device)
                step_logits = self.model.dec(tgt, memory[:k_curr], memory_mask[:k_curr])
                log_probs = F.log_softmax(step_logits[:, -1, :], dim=-1)

                # Expand every live beam by its top-k next tokens.
                candidates = []
                for i in range(k_curr):
                    score_so_far, seq_so_far = beams[i]
                    topk_probs, topk_idx = log_probs[i].topk(beam_width)
                    for k in range(beam_width):
                        candidates.append((
                            score_so_far + topk_probs[k].item(),
                            seq_so_far + [topk_idx[k].item()],
                        ))

                candidates.sort(key=lambda x: x[0], reverse=True)

                # Keep the best beam_width unfinished candidates; EOS-terminated
                # ones move to completed with length normalization (SOS excluded).
                next_beams = []
                for score, seq in candidates:
                    if seq[-1] == self.eos_idx:
                        norm_score = score / (len(seq) - 1)
                        completed_beams.append((norm_score, seq))
                    else:
                        next_beams.append((score, seq))
                        if len(next_beams) == beam_width:
                            break
                beams = next_beams
                if not beams:
                    break

        if completed_beams:
            completed_beams.sort(key=lambda x: x[0], reverse=True)
            return completed_beams[0][1]
        elif beams:
            return beams[0][1]
        else:
            return [self.sos_idx]
============================================================================== # CLI USAGE # ============================================================================== if __name__ == "__main__": parser = argparse.ArgumentParser(description="Khmer OCR Inference") parser.add_argument("--image", type=str, required=True, help="Path to input image") parser.add_argument("--method", type=str, default="beam", choices=["greedy", "beam"], help="Decoding method") parser.add_argument("--beam_width", type=int, default=3, help="Width for beam search") parser.add_argument("--max_len", type=int, default=256, help="Max output length") parser.add_argument("--repo", type=str, default="Darayut/khmer-SeqSE-CRNN-Transformer", help="HF Model Repo") args = parser.parse_args() try: # Initialize ocr = KhmerOCR(model_repo=args.repo) # Run print(f"📷 Processing: {args.image}") text = ocr.predict(args.image, method=args.method, beam_width=args.beam_width, max_len=args.max_len) print("\n" + "="*30) print(f"RESULT: {text}") print("="*30) # Auto-Save out_path = os.path.splitext(args.image)[0] + ".txt" with open(out_path, "w", encoding="utf-8") as f: f.write(text) print(f"💾 Saved to: {out_path}") except Exception as e: print(f"❌ Error: {e}")