Image-to-Text
Transformers
Safetensors
Khmer
khmer-ocr
feature-extraction
transformer
text-recognition
crnn
khmer-text-recognition
custom_code
File size: 8,448 Bytes
3a57793
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from torchvision import transforms
import os
import argparse

class KhmerOCR:
    """Inference wrapper for the Khmer SeqSE-CRNN-Transformer OCR model.

    Loads a model and tokenizer from the Hugging Face Hub (with
    ``trust_remote_code=True``), preprocesses a text-line image into
    fixed-width overlapping chunks, encodes them with the remote model,
    and decodes the character sequence via greedy or beam search.
    """

    def __init__(self, model_repo: str = "Darayut/khmer-SeqSE-CRNN-Transformer", device: "torch.device | str | None" = None) -> None:
        """Initialize the Khmer OCR model and tokenizer.

        Args:
            model_repo: Hugging Face repo id to load model/tokenizer from.
            device: Explicit torch device; when None, CUDA is used if
                available, otherwise CPU.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        
        print(f"⏳ Loading model from {model_repo} on {self.device}...")
        
        # Load Model & Tokenizer with trust_remote_code=True
        # NOTE(review): trust_remote_code executes Python shipped inside the
        # repo — only point this at repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(model_repo, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_repo, trust_remote_code=True).to(self.device)
        self.model.eval()

        # Build Vocab Mappings
        self.vocab = self.tokenizer.get_vocab()  # token string -> id
        self.id2char = {v: k for k, v in self.vocab.items()}  # id -> token string
        
        # Special Tokens
        # Fallback ids (1/2/0/3) are used only if the tokenizer vocab lacks
        # the literal "<sos>"/"<eos>"/"<pad>"/"<unk>" entries.
        self.sos_idx = self.vocab.get("<sos>", 1)
        self.eos_idx = self.vocab.get("<eos>", 2)
        self.pad_idx = self.vocab.get("<pad>", 0)
        self.unk_idx = self.vocab.get("<unk>", 3)

        # Image Transform (Matches Training)
        # Normalize(0.5, 0.5) maps pixel values from [0, 1] to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(), 
            transforms.Normalize(0.5, 0.5)
        ])

    def _chunk_image(self, img_tensor: "torch.Tensor", chunk_width: int = 100, overlap: int = 16) -> "list[torch.Tensor]":
        """Internal helper to split image into overlapping chunks.

        Args:
            img_tensor: Image tensor of shape (C, H, W).
            chunk_width: Width of each chunk in pixels.
            overlap: Horizontal overlap between consecutive chunks.

        Returns:
            List of (C, H, chunk_width) tensors; the last chunk is
            right-padded with 1.0 (white, in the [-1, 1] normalized space)
            to reach chunk_width.
        """
        C, H, W = img_tensor.shape
        chunks = []
        start = 0
        while start < W:
            end = min(start + chunk_width, W)
            chunk = img_tensor[:, :, start:end]
            if chunk.shape[2] < chunk_width:
                pad_size = chunk_width - chunk.shape[2]
                # F.pad pads the last dim (width) on the right only.
                chunk = F.pad(chunk, (0, pad_size, 0, 0), value=1.0)
            chunks.append(chunk)
            # Advance by (chunk_width - overlap) so adjacent chunks share
            # `overlap` columns of pixels.
            start += chunk_width - overlap
        return chunks

    def preprocess(self, image_source: "str | Image.Image") -> "torch.Tensor":
        """Preprocess an image path or PIL object into a batch of chunks.

        Args:
            image_source: File path to an image, or a PIL Image.

        Returns:
            Tensor of shape (num_chunks, 1, 48, 100) on ``self.device``.

        Raises:
            FileNotFoundError: If a path is given and does not exist.
            ValueError: If the input is neither a path nor a PIL Image.
        """
        # Load Image
        if isinstance(image_source, str):
            if not os.path.exists(image_source):
                raise FileNotFoundError(f"Image not found at {image_source}")
            image = Image.open(image_source).convert('L')
        elif isinstance(image_source, Image.Image):
            image = image_source.convert('L')
        else:
            raise ValueError("Input must be a file path or PIL Image.")

        # Resize (Fixed Height: 48)
        # Width scales with aspect ratio; floor of 10 px guards against
        # degenerate (near-zero-width) inputs.
        target_height = 48
        aspect_ratio = image.width / image.height
        new_width = int(target_height * aspect_ratio)
        new_width = max(10, new_width) 
        image = image.resize((new_width, target_height), Image.Resampling.BILINEAR)

        # Transform & Chunk
        img_tensor = self.transform(image) 
        chunks = self._chunk_image(img_tensor, chunk_width=100, overlap=16)
        chunks_tensor = torch.stack(chunks).to(self.device)
        return chunks_tensor

    def predict(self, image_source: "str | Image.Image", method: str = "beam", beam_width: int = 3, max_len: int = 256) -> str:
        """Run OCR on an image and return the decoded text.

        Args:
            image_source: Path to image or PIL object.
            method: 'greedy' or 'beam'.
            beam_width: Width for beam search (default 3); a width <= 1
                falls back to greedy decoding.
            max_len: Max decoded sequence length.

        Returns:
            The decoded text with special tokens stripped.
        """
        chunks_tensor = self.preprocess(image_source)
        
        # 1. Encode (CNN + Transformer + BiLSTM Smoothing)
        with torch.no_grad():
            # Wrap in list as model expects batch of images
            # NOTE(review): presumably the remote model merges the chunk
            # sequence into one encoder memory of shape (1, T, D) — confirm
            # against the repo's modeling code.
            memory = self.model([chunks_tensor])

        # 2. Decode
        if method == "greedy" or beam_width <= 1:
            token_ids = self._greedy_decode(memory, max_len)
        else:
            token_ids = self._beam_search(memory, max_len, beam_width)

        # 3. Convert IDs to Text (Manual mapping to avoid spacing issues)
        # Special tokens are skipped; unknown ids map to "".
        result_text = ""
        for idx in token_ids:
            if idx in [self.sos_idx, self.eos_idx, self.pad_idx, self.unk_idx]:
                continue
            char = self.id2char.get(idx, "")
            result_text += char

        return result_text

    def _greedy_decode(self, memory: "torch.Tensor", max_len: int) -> "list[int]":
        """Greedy autoregressive decode; returns token ids incl. leading <sos>.

        Assumes memory is (B, T, D); only batch element 0 is decoded —
        TODO confirm the encoder always returns B == 1 here.
        """
        B, T, _ = memory.shape
        # All-False mask: no memory positions are masked out.
        memory_mask = torch.zeros((B, T), dtype=torch.bool, device=self.device)
        generated = [self.sos_idx]

        with torch.no_grad():
            for _ in range(max_len):
                # Re-feed the full prefix each step (no KV caching).
                tgt = torch.LongTensor([generated]).to(self.device)
                logits = self.model.dec(tgt, memory, memory_mask)
                next_token = torch.argmax(logits[0, -1, :]).item()
                if next_token == self.eos_idx: break
                generated.append(next_token)
        return generated

    def _beam_search(self, memory: "torch.Tensor", max_len: int, beam_width: int) -> "list[int]":
        """Beam-search decode; returns the best token id sequence.

        Completed hypotheses are scored with length normalization
        (sum of log-probs divided by generated length) so shorter
        sequences are not unfairly favored.
        """
        B, T, D = memory.shape
        # Expand (not copy) memory across the beam dimension so all beams
        # attend to the same encoder output.
        memory = memory.expand(beam_width, -1, -1)
        memory_mask = torch.zeros((beam_width, T), dtype=torch.bool, device=self.device)
        
        # Each beam is (cumulative log-prob, token id sequence).
        beams = [(0.0, [self.sos_idx])]
        completed_beams = []

        with torch.no_grad():
            for step in range(max_len):
                k_curr = len(beams)
                current_seqs = [b[1] for b in beams]
                # All live beams share the same length, so they stack into
                # one (k_curr, step+1) batch.
                tgt = torch.tensor(current_seqs, dtype=torch.long, device=self.device)
                
                step_logits = self.model.dec(tgt, memory[:k_curr], memory_mask[:k_curr])
                log_probs = F.log_softmax(step_logits[:, -1, :], dim=-1)

                # Expand every live beam by its top-k next tokens.
                candidates = []
                for i in range(k_curr):
                    score_so_far, seq_so_far = beams[i]
                    topk_probs, topk_idx = log_probs[i].topk(beam_width)
                    for k in range(beam_width):
                        candidates.append((score_so_far + topk_probs[k].item(), seq_so_far + [topk_idx[k].item()]))

                # Keep the best beam_width non-terminated candidates; move
                # <eos>-terminated ones to completed_beams (length-normalized).
                candidates.sort(key=lambda x: x[0], reverse=True)
                next_beams = []
                for score, seq in candidates:
                    if seq[-1] == self.eos_idx:
                        # len(seq) - 1 excludes the leading <sos>.
                        norm_score = score / (len(seq) - 1)
                        completed_beams.append((norm_score, seq))
                    else:
                        next_beams.append((score, seq))
                        if len(next_beams) == beam_width: break
                beams = next_beams
                if not beams: break

        # Prefer a completed hypothesis; fall back to the best live beam,
        # or a bare <sos> if nothing was generated at all.
        if completed_beams:
            completed_beams.sort(key=lambda x: x[0], reverse=True)
            return completed_beams[0][1]
        elif beams:
            return beams[0][1]
        else:
            return [self.sos_idx]

# ==============================================================================
# CLI USAGE
# ==============================================================================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Khmer OCR Inference")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--method", type=str, default="beam", choices=["greedy", "beam"], help="Decoding method")
    parser.add_argument("--beam_width", type=int, default=3, help="Width for beam search")
    parser.add_argument("--max_len", type=int, default=256, help="Max output length")
    parser.add_argument("--repo", type=str, default="Darayut/khmer-SeqSE-CRNN-Transformer", help="HF Model Repo")
    
    args = parser.parse_args()

    try:
        # Initialize (downloads/loads model + tokenizer)
        ocr = KhmerOCR(model_repo=args.repo)
        
        # Run OCR on the requested image
        print(f"📷 Processing: {args.image}")
        text = ocr.predict(args.image, method=args.method, beam_width=args.beam_width, max_len=args.max_len)
        
        print("\n" + "="*30)
        print(f"RESULT: {text}")
        print("="*30)
        
        # Auto-Save: write the result next to the input image as <name>.txt
        out_path = os.path.splitext(args.image)[0] + ".txt"
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(text)
        print(f"💾 Saved to: {out_path}")

    except Exception as e:
        # Report the failure and exit non-zero so callers (shell scripts,
        # CI) can detect it — previously the script swallowed the error
        # and exited with status 0.
        print(f"❌ Error: {e}")
        raise SystemExit(1)