File size: 12,173 Bytes
091afb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
"""
check_cer.py
============
Measures TRUE CER by actually running the model on images.

Usage:
    python check_cer.py                        # live CER on val set
    python check_cer.py --saved                # old behavior (fast, unreliable)
    python check_cer.py --images test_images/  # run on any image folder
"""

import os
import sys
import json
import random
import cv2
import numpy as np
import editdistance
from pathlib import Path

try:
    import torch
except ImportError:
    print("ERROR: torch not installed. Run: pip install torch")
    # Use sys.exit rather than the bare exit() helper: exit() is injected by
    # the `site` module for interactive use and is not guaranteed to exist
    # (e.g. `python -S` or frozen executables).
    sys.exit(1)

# --saved anywhere on the command line selects the checkpoint-stored metrics
# instead of running the model.
USE_SAVED = any(token == '--saved' for token in sys.argv)

# --images DIR or --images=DIR selects a folder to run predictions on.
# Each argument is paired with the one that follows it (None after the last),
# so a trailing bare "--images" is ignored; the last occurrence wins.
IMAGE_DIR = None
for token, following in zip(sys.argv[1:], sys.argv[2:] + [None]):
    if token == '--images' and following is not None:
        IMAGE_DIR = following
    elif token.startswith('--images='):
        IMAGE_DIR = token.partition('=')[2]

# Checkpoint files to inspect: every 10th epoch from 50 through 100.
CHECKPOINTS = [f'checkpoint_epoch_{epoch}.pth' for epoch in range(50, 101, 10)]
# Directory the checkpoint file names above are joined onto.
CHECKPOINT_DIR = 'checkpoints'
# Validation images and their JSON annotation file.
VAL_DATA_DIR = 'data/val'
VAL_ANN_FILE = 'data/val_annotations.json'


class AdaptiveImageNormalizer:
    """Convert an image into a fixed-size, binarized grayscale canvas.

    ``normalize`` chains: grayscale conversion, non-local-means denoising,
    cropping to the dark content, aspect-preserving resize centred on a white
    ``target_height`` x ``target_width`` canvas, and binarization.
    ``to_tensor`` turns the result into a float tensor scaled to [0, 1].
    """

    def __init__(self, target_height=64, target_width=512):
        """Store the output canvas size in pixels."""
        self.H = target_height
        self.W = target_width

    def _crop_to_text(self, gray):
        """Crop ``gray`` to the bounding box of its dark pixels plus padding.

        Pixels whose inverted value exceeds 20 (i.e. original value below
        235) count as content. The image is returned unchanged when no such
        pixel exists.
        """
        inv = cv2.bitwise_not(gray)
        _, thresh = cv2.threshold(inv, 20, 255, cv2.THRESH_BINARY)
        # Rows of (y, x) for every content pixel.
        coords = np.column_stack(np.where(thresh > 0))
        if len(coords) == 0:
            return gray
        y_min, x_min = coords.min(axis=0)
        y_max, x_max = coords.max(axis=0)
        # Pad by 15% of the content height, but never less than 4 px, and
        # keep the box inside the image.
        pad   = max(4, int((y_max - y_min) * 0.15))
        y_min = max(0, y_min - pad)
        x_min = max(0, x_min - pad)
        y_max = min(gray.shape[0] - 1, y_max + pad)
        x_max = min(gray.shape[1] - 1, x_max + pad)
        return gray[y_min:y_max + 1, x_min:x_max + 1]

    def _smart_resize_gray(self, gray):
        """Resize ``gray`` to fit the canvas, keeping aspect ratio.

        The image is scaled to the canvas height; if that makes it wider than
        the canvas it is scaled to the canvas width instead. The result is
        centred on a white (255) canvas of shape ``(H, W)``. An input with a
        zero-sized dimension yields an all-white canvas.
        """
        h, w = gray.shape
        if h == 0 or w == 0:
            return np.ones((self.H, self.W), dtype=np.uint8) * 255
        scale = self.H / h
        new_w = int(w * scale)
        new_h = self.H
        if new_w > self.W:
            scale = self.W / w
            new_h = int(h * scale)
            new_w = self.W
        # Extreme aspect ratios can truncate a dimension to 0, which
        # cv2.resize rejects; keep at least one pixel in each direction.
        new_w = max(1, new_w)
        new_h = max(1, new_h)
        resized = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
        canvas  = np.ones((self.H, self.W), dtype=np.uint8) * 255
        y_off   = (self.H - new_h) // 2
        x_off   = (self.W - new_w) // 2
        canvas[y_off:y_off + new_h, x_off:x_off + new_w] = resized
        return canvas

    def _binarize(self, img):
        """Binarize with Otsu, falling back to adaptive thresholding.

        If the Otsu result is less than 30% or more than 97% white, a
        Gaussian adaptive threshold (block size 11, constant 2) is used
        instead.
        """
        _, otsu = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        white_ratio = np.mean(otsu == 255)
        if white_ratio < 0.30 or white_ratio > 0.97:
            return cv2.adaptiveThreshold(
                img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY, 11, 2)
        return otsu

    def normalize(self, img):
        """Return the binarized ``(H, W)`` canvas for ``img``.

        A 3-dimensional array is converted from BGR to grayscale; any other
        array is copied and treated as already grayscale.
        """
        if len(img.shape) == 3:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        else:
            gray = img.copy()
        gray = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)
        gray = self._crop_to_text(gray)
        gray = self._smart_resize_gray(gray)
        return self._binarize(gray)

    def to_tensor(self, img):
        """Scale ``img`` to [0, 1] float32 and add two leading singleton dims."""
        return torch.FloatTensor(
            img.astype(np.float32) / 255.0
        ).unsqueeze(0).unsqueeze(0)


def greedy_decode(outputs, idx_to_char):
    """Decode the most likely class per step into one string per sequence.

    The argmax is taken over dimension 2 of ``outputs`` and the remaining two
    dimensions are swapped, so each row of the result is one sequence.
    Within a sequence, index 0 is dropped, a label equal to the one directly
    before it is dropped, and labels missing from ``idx_to_char`` are
    dropped; the surviving labels are mapped to characters and joined.

    Returns:
        A list of decoded strings, one per sequence.
    """
    best_paths = outputs.argmax(dim=2).transpose(0, 1).tolist()
    decoded = []
    for labels in best_paths:
        # Pair every label with its predecessor (-1 before the first step).
        predecessors = [-1] + labels[:-1]
        decoded.append(''.join(
            idx_to_char[current]
            for current, previous in zip(labels, predecessors)
            if current != 0 and current != previous and current in idx_to_char
        ))
    return decoded


def measure_live_cer(model, idx_to_char, img_h, img_w,
                     ann_file, data_dir, device, max_samples=200):
    """Run ``model`` on annotated images and compute CER, WER and exact match.

    Args:
        model: network invoked as ``model(tensor)`` after ``model.eval()``.
        idx_to_char: mapping from class index to character, used to decode.
        img_h, img_w: canvas size passed to ``AdaptiveImageNormalizer``.
        ann_file: JSON file containing a list of records, each with
            ``image_path`` and ``text`` keys.
        data_dir: directory the ``image_path`` values are joined onto.
        device: torch device the input tensor is moved to.
        max_samples: upper bound on evaluated annotations; larger sets are
            subsampled with a fixed seed so repeated runs see the same images.

    Returns:
        ``(metrics, n_evaluated, error)``. On success ``metrics`` is a dict
        with ``cer``, ``wer`` and ``exact_match`` (all percentages),
        ``n_evaluated``, ``n_skipped`` and ``errors`` (up to five
        ``(ground_truth, prediction, char_distance)`` tuples with the largest
        distance), and ``error`` is ``None``. On failure ``metrics`` is
        ``None``, the count is 0 and ``error`` describes the reason.
    """
    if not os.path.exists(ann_file):
        return None, 0, f"Annotation file not found: {ann_file}"

    with open(ann_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    if len(annotations) > max_samples:
        # A private generator seeded with 42 draws the same subset as seeding
        # the module-level generator would, without clobbering global state.
        annotations = random.Random(42).sample(annotations, max_samples)

    normalizer = AdaptiveImageNormalizer(img_h, img_w)
    model.eval()

    total_char_dist = 0
    total_chars     = 0
    total_word_dist = 0
    total_words     = 0
    n_exact         = 0
    n_evaluated     = 0
    n_skipped       = 0
    first_error     = None
    worst_errors    = []

    with torch.no_grad():
        for ann in annotations:
            img_path = os.path.join(data_dir, ann['image_path'])
            gt       = ann['text']
            if not os.path.exists(img_path):
                n_skipped += 1
                continue
            try:
                raw = cv2.imread(img_path)
                if raw is None:
                    n_skipped += 1
                    continue
                norm   = normalizer.normalize(raw)
                tensor = normalizer.to_tensor(norm).to(device)
                out    = model(tensor)
                pred   = greedy_decode(out.cpu(), idx_to_char)[0]

                # Character-level distance on the raw strings; word-level
                # distance on whitespace-separated tokens.
                cd = editdistance.eval(pred, gt)
                wd = editdistance.eval(pred.split(), gt.split())

                total_char_dist += cd
                total_chars     += len(gt)
                total_word_dist += wd
                total_words     += len(gt.split())
                if pred == gt:
                    n_exact += 1
                if cd > 0:
                    worst_errors.append((gt, pred, cd))
                n_evaluated += 1
            except Exception as exc:
                # Best effort: one bad sample must not abort the run, but keep
                # the first failure so a systematic problem is not hidden.
                n_skipped += 1
                if first_error is None:
                    first_error = f"{type(exc).__name__}: {exc}"
                continue

    if n_evaluated == 0:
        reason = "No images could be evaluated"
        if first_error is not None:
            reason += f" (first error: {first_error})"
        return None, 0, reason

    cer = (total_char_dist / total_chars * 100) if total_chars > 0 else 0
    wer = (total_word_dist / total_words * 100) if total_words > 0 else 0
    acc = (n_exact / n_evaluated * 100)

    return {
        'cer': cer, 'wer': wer, 'exact_match': acc,
        'n_evaluated': n_evaluated,
        'n_skipped': n_skipped,
        'errors': sorted(worst_errors, key=lambda x: x[2], reverse=True)[:5]
    }, n_evaluated, None


def run_on_folder(model, idx_to_char, img_h, img_w, folder, device):
    """Predict text for every image found under ``folder``.

    Files are discovered recursively, filtered by extension
    (.jpg, .jpeg, .png, .bmp, case-insensitive) and processed in sorted
    path order.

    Args:
        model: network invoked as ``model(tensor)`` after ``model.eval()``.
        idx_to_char: mapping from class index to character, used to decode.
        img_h, img_w: canvas size passed to ``AdaptiveImageNormalizer``.
        folder: directory to search.
        device: torch device the input tensor is moved to.

    Returns:
        A list of ``(file_name, prediction)`` tuples. When an image cannot be
        processed the second element is an ``'ERROR: ...'`` string instead.
    """
    normalizer = AdaptiveImageNormalizer(img_h, img_w)
    model.eval()
    exts  = {'.jpg', '.jpeg', '.png', '.bmp'}
    paths = sorted(p for p in Path(folder).rglob('*') if p.suffix.lower() in exts)
    results = []
    with torch.no_grad():
        for p in paths:
            try:
                raw = cv2.imread(str(p))
                if raw is None:
                    # imread signals an unreadable file by returning None;
                    # report that directly instead of a confusing
                    # AttributeError from the normalizer.
                    results.append((p.name, 'ERROR: could not read image'))
                    continue
                norm   = normalizer.normalize(raw)
                tensor = normalizer.to_tensor(norm).to(device)
                pred   = greedy_decode(model(tensor).cpu(), idx_to_char)[0]
                results.append((p.name, pred))
            except Exception as e:
                results.append((p.name, f'ERROR: {e}'))
    return results


# ─────────────────────────────────────────────────────────────────────────────
# MAIN
# ─────────────────────────────────────────────────────────────────────────────

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if USE_SAVED:
    # Fast path: print the metrics stored inside each checkpoint at training
    # time. No model is built and no image is read.
    print("=" * 65)
    print("  SAVED CER  (training-time value — may not reflect real accuracy)")
    print("  Run without --saved for true live CER.")
    print("=" * 65)
    print("{:<8} {:<12} {:<12} {}".format("Epoch", "CER(%)", "WER(%)", "File"))
    print("-" * 65)
    best_cer, best_cp = float('inf'), None
    for cp in CHECKPOINTS:
        path = os.path.join(CHECKPOINT_DIR, cp)
        if not os.path.exists(path):
            continue
        try:
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load checkpoint files from a trusted source.
            # map_location lets checkpoints saved on a GPU load on CPU hosts.
            c        = torch.load(path, map_location=device, weights_only=False)
            cer      = c.get('val_cer', c.get('val_loss', 0))
            epoch    = c['epoch']
            history  = c.get('history', {})
            wer_list = history.get('val_wer', [])
            # History is indexed by epoch - 1; guard both ends so epoch 0
            # cannot wrap around to the last entry.
            wer      = wer_list[epoch - 1] if 1 <= epoch <= len(wer_list) else None
            # "is not None" so a genuine WER of 0.0 is still printed.
            wer_s    = f"{wer:.4f}%" if wer is not None else 'N/A'
            marker   = ' <-- BEST' if cer < best_cer else ''
            print("{:<8} {:<12} {:<12} {}{}".format(
                epoch, f"{cer:.4f}%", wer_s, cp, marker))
            if cer < best_cer:
                best_cer, best_cp = cer, cp
        except Exception as e:
            print(f"  Could not load {cp}: {e}")
    print("=" * 65)
    if best_cp is None:
        print(f"\nNo loadable checkpoint found in '{CHECKPOINT_DIR}/'.")
    else:
        print(f"\nBEST: {best_cp}  CER={best_cer:.4f}%")

else:
    # Live path: rebuild the model from each checkpoint and either run it on
    # the validation annotations or, with --images, on an arbitrary folder.
    print("=" * 78)
    print("  LIVE CER  —  model actually runs on images  (true accuracy)")
    print("=" * 78)
    print("{:<8} {:<10} {:<10} {:<12} {:<8} {}".format(
        "Epoch", "CER(%)", "WER(%)", "ExactMatch", "N", "File"))
    print("-" * 78)

    best_cer, best_cp, best_metrics = float('inf'), None, None

    for cp in CHECKPOINTS:
        cp_path = os.path.join(CHECKPOINT_DIR, cp)
        if not os.path.exists(cp_path):
            print(f"  (skipping {cp} — not found)")
            continue
        try:
            from crnn_model import get_crnn_model
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load checkpoint files from a trusted source.
            c           = torch.load(cp_path, map_location=device, weights_only=False)
            epoch       = c['epoch']
            idx_to_char = c['idx_to_char']
            config      = c.get('config', {})
            img_h       = config.get('img_height', 64)
            img_w       = config.get('img_width', 512)
            saved_cer   = c.get('val_cer', c.get('val_loss', None))

            model = get_crnn_model(
                model_type=config.get('model_type', 'standard'),
                img_height=img_h,
                # Output size is read from the saved classifier weights.
                num_chars=c['model_state_dict']['fc.weight'].shape[0],
                hidden_size=config.get('hidden_size', 128),       # default was 256
                num_lstm_layers=config.get('num_lstm_layers', 1)  # default was 2
            ).to(device)
            model.load_state_dict(c['model_state_dict'])

            if IMAGE_DIR:
                print(f"\nPredictions from {cp}:")
                for fname, pred in run_on_folder(
                        model, idx_to_char, img_h, img_w, IMAGE_DIR, device):
                    print(f"  {fname:<35} ->  {pred}")
                continue

            metrics, n, err = measure_live_cer(
                model, idx_to_char, img_h, img_w,
                VAL_ANN_FILE, VAL_DATA_DIR, device)

            if metrics is None:
                print(f"  Epoch {epoch}  SKIP: {err}")
                continue

            cer    = metrics['cer']
            marker = ' <-- BEST' if cer < best_cer else ''
            print("{:<8} {:<10} {:<10} {:<12} {:<8} {}{}".format(
                epoch,
                f"{cer:.2f}%",
                f"{metrics['wer']:.2f}%",
                f"{metrics['exact_match']:.1f}%",
                n, cp, marker))

            # "is not None" so a stored value of 0.0 is still compared.
            if saved_cer is not None and abs(cer - saved_cer) > 2.0:
                print(f"          ^ MISMATCH: saved={saved_cer:.2f}%  live={cer:.2f}%"
                      f"  diff={abs(cer - saved_cer):.2f}%")
                print("            Cause: model trained on clean synthetic only.")
                print("            Fix:   regenerate data with fix_data.py + retrain.")

            if cer < best_cer:
                best_cer, best_cp, best_metrics = cer, cp, metrics

        except Exception as e:
            print(f"  Could not evaluate {cp}: {e}")

    if not IMAGE_DIR:
        print("=" * 78)
        if best_cp is None:
            # Nothing was evaluated: do not print a bogus "None" / inf summary
            # or a copy command pointing at a non-existent file.
            print("\nNo checkpoint could be evaluated; nothing to recommend.")
        else:
            print(f"\nBEST CHECKPOINT : {best_cp}")
            print(f"BEST LIVE CER   : {best_cer:.4f}%")

            if best_metrics and best_metrics['errors']:
                print("\nWorst predictions (GT -> Predicted):")
                for gt, pred, dist in best_metrics['errors']:
                    print(f"  [{dist:2d}]  '{gt}'")
                    print(f"        '{pred}'")

            print("\nTo use best model:")
            print("  import shutil")
            print(f"  shutil.copy('{CHECKPOINT_DIR}/{best_cp}', "
                  f"'{CHECKPOINT_DIR}/best_model.pth')")
        print("=" * 78)