| """ |
| compare_live_cer.py |
| =================== |
Computes character error rate (CER) and exact-match accuracy for each checkpoint in CHECKPOINTS on a validation sample, to find the best one.
| Usage: python compare_live_cer.py |
| """ |
|
|
| import os |
| import sys |
| import json |
| import random |
| import cv2 |
| import numpy as np |
| import editdistance |
| import torch |
| import torch.nn.functional as F |
| sys.path.append('.') |
| from crnn_model import get_crnn_model |
|
|
# Evaluate on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Validation annotations: a JSON list whose records carry 'image_path'
# (relative to VAL_DIR) and 'text' (the ground-truth transcription).
VAL_ANN = 'data/val_annotations.json'
VAL_DIR = 'data/val'
# Upper bound on how many annotations are scored per checkpoint.
MAX_SAMPLES = 200

# Display label -> checkpoint path, evaluated in this order.
CHECKPOINTS = dict(
    Synthetic='checkpoints/best_model.pth',
    EMNIST='checkpoints/best_model_emnist.pth',
    IAM='checkpoints/best_model_iam.pth',
)
|
|
|
|
def normalize(img, H=64, W=512):
    """Turn a text-line image into a binarized H x W model input.

    Pipeline: grayscale -> non-local-means denoise -> crop to the ink
    bounding box (with padding) -> aspect-preserving resize to fit inside
    H x W -> centre on a white canvas -> Otsu binarization.

    Args:
        img: 8-bit image as returned by cv2 — 2-D grayscale, or 3-D with
            1 (gray), 3 (BGR) or 4 (BGRA) channels. cv2.fastNlMeansDenoising
            expects 8-bit input.
        H: output height in pixels.
        W: output width in pixels.

    Returns:
        uint8 array of shape (H, W) whose pixels are 0 or 255. An all-white
        canvas is returned when the (cropped) image has no pixels.
    """
    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 4:
            gray = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
        elif channels == 1:
            gray = img[:, :, 0].copy()
        else:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        gray = img.copy()

    # Denoise: filter strength h=10, template window 7, search window 21.
    gray = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)

    # Treat any pixel darker than 235 (inverted value > 20) as ink and crop
    # to its bounding box, padded by 15% of the ink height (at least 4 px).
    inv = cv2.bitwise_not(gray)
    _, thresh = cv2.threshold(inv, 20, 255, cv2.THRESH_BINARY)
    coords = np.column_stack(np.where(thresh > 0))
    if len(coords) > 0:
        y_min, x_min = coords.min(axis=0)
        y_max, x_max = coords.max(axis=0)
        pad = max(4, int((y_max - y_min) * 0.15))
        y_min = max(0, y_min - pad)
        x_min = max(0, x_min - pad)
        y_max = min(gray.shape[0] - 1, y_max + pad)
        x_max = min(gray.shape[1] - 1, x_max + pad)
        gray = gray[y_min:y_max + 1, x_min:x_max + 1]

    h, w = gray.shape
    if h == 0 or w == 0:
        return np.full((H, W), 255, dtype=np.uint8)

    # Scale to full height; if that overflows the width, fit the width instead.
    scale = H / h
    new_w = int(w * scale)
    if new_w > W:
        scale = W / w
        new_w = W
        new_h = int(h * scale)
    else:
        new_h = H
    # Very wide or very tall crops can round a side down to 0, which
    # cv2.resize rejects; keep at least one pixel in each dimension.
    new_w = max(1, new_w)
    new_h = max(1, new_h)

    resized = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
    canvas = np.full((H, W), 255, dtype=np.uint8)
    top = (H - new_h) // 2
    left = (W - new_w) // 2
    canvas[top:top + new_h, left:left + new_w] = resized

    _, otsu = cv2.threshold(canvas, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return otsu
|
|
|
|
def greedy_decode(outputs, idx_to_char):
    """Greedy (best-path) CTC decoding for a batch of model outputs.

    The arg-max class is taken at every step. Consecutive repeats are
    collapsed. Index 0 (the blank) and any index absent from
    ``idx_to_char`` are dropped.

    Args:
        outputs: score tensor indexed as (time, batch, classes). The
            arg-max is over dim 2 and the result is transposed so that each
            row is one batch element's time sequence.
        idx_to_char: mapping from class index to character.

    Returns:
        List with one decoded string per batch element.
    """
    best_path = torch.argmax(outputs, dim=2).permute(1, 0)
    decoded = []
    for row in best_path:
        pieces = []
        last = -1
        for token in row.tolist():
            # Emit only at the start of a run of identical tokens. A blank (0)
            # still ends a run, so "a, blank, a" decodes to "aa".
            if token != last and token != 0 and token in idx_to_char:
                pieces.append(idx_to_char[token])
            last = token
        decoded.append(''.join(pieces))
    return decoded
|
|
|
|
def _load_model(checkpoint_path):
    """Rebuild the CRNN stored in ``checkpoint_path`` and put it in eval mode.

    Returns:
        (model, idx_to_char, img_height). ``img_height`` is the input height
        the model was built with (config value, default 64).
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point this at trusted, locally produced checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = ckpt.get('config', {})

    # Checkpoints without an embedded vocabulary fall back to the dataset maps.
    idx_to_char = ckpt.get('idx_to_char', None)
    if idx_to_char is None:
        from dataset import build_char_maps
        _, idx_to_char, _ = build_char_maps()

    state_dict = ckpt['model_state_dict']
    img_height = config.get('img_height', 64)
    model = get_crnn_model(
        model_type=config.get('model_type', 'standard'),
        img_height=img_height,
        # Take the class count from the saved classifier so the shapes match.
        num_chars=state_dict['fc.weight'].shape[0],
        hidden_size=config.get('hidden_size', 128),
        num_lstm_layers=config.get('num_lstm_layers', 1),
    ).to(device)

    # strict=False tolerates key mismatches. Report them so an architecture
    # mismatch does not silently evaluate partly uninitialised weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"    warning: {len(incompatible.missing_keys)} missing / "
              f"{len(incompatible.unexpected_keys)} unexpected state-dict keys")
    model.eval()
    return model, idx_to_char, img_height


def _load_val_samples():
    """Return up to MAX_SAMPLES validation annotations, chosen reproducibly."""
    with open(VAL_ANN, 'r', encoding='utf-8') as f:
        anns = json.load(f)
    if len(anns) > MAX_SAMPLES:
        # A private RNG seeded with 42 picks the same subset as
        # random.seed(42) followed by random.sample, without resetting the
        # process-wide random state.
        anns = random.Random(42).sample(anns, MAX_SAMPLES)
    return anns


def evaluate(checkpoint_path, label):
    """Score one checkpoint on the validation sample and print a summary.

    Prints CER (total edit distance / total ground-truth characters), the
    exact-match rate, and the two worst predictions.

    Args:
        checkpoint_path: path to a .pth checkpoint.
        label: display name used in the printed report.

    Returns:
        The CER in percent, or None if the checkpoint file is missing or no
        sample could be scored.
    """
    if not os.path.exists(checkpoint_path):
        print(f" {label:<12}: FILE NOT FOUND — skipping")
        return None

    model, idx_to_char, img_height = _load_model(checkpoint_path)
    anns = _load_val_samples()

    total_cd, total_c = 0, 0
    exact, n, skipped = 0, 0, 0
    worst = []

    with torch.no_grad():
        for ann in anns:
            img_path = os.path.join(VAL_DIR, ann['image_path'])
            gt = ann['text']
            raw = cv2.imread(img_path) if os.path.exists(img_path) else None
            if raw is None:
                skipped += 1
                continue
            # Preprocess to the height the model was built for, not a fixed 64.
            norm = normalize(raw, H=img_height)
            tensor = torch.from_numpy(
                norm.astype(np.float32) / 255.0
            ).unsqueeze(0).unsqueeze(0).to(device)
            out = model(tensor)
            pred = greedy_decode(out.cpu(), idx_to_char)[0]
            cd = editdistance.eval(pred, gt)
            total_cd += cd
            total_c += len(gt)
            if pred == gt:
                exact += 1
            if cd > 0:
                worst.append((gt, pred, cd))
            n += 1

    cer = (total_cd / total_c * 100) if total_c > 0 else 0
    acc = (exact / n * 100) if n > 0 else 0
    print(f" {label:<12}: CER={cer:.2f}% ExactMatch={acc:.1f}% (n={n})")
    if skipped:
        print(f"    ({skipped} samples skipped: image missing or unreadable)")

    for gt, pred, d in sorted(worst, key=lambda x: x[2], reverse=True)[:2]:
        print(f" [{d}] '{gt}' -> '{pred}'")

    return cer if n > 0 else None
|
|
|
|
def main():
    """Evaluate every checkpoint in CHECKPOINTS and report the lowest CER."""
    print("=" * 60)
    print(" LIVE CER COMPARISON — all checkpoints")
    print("=" * 60)
    results = {}
    for label, path in CHECKPOINTS.items():
        results[label] = evaluate(path, label)
    print("=" * 60)
    # evaluate() may return None (e.g. missing file); rank only numeric CERs.
    scored = {
        label: cer for label, cer in results.items()
        if isinstance(cer, (int, float))
    }
    if scored:
        best = min(scored, key=scored.get)
        print(f"Lowest CER: {best} ({scored[best]:.2f}%)")
    print("Use the checkpoint with the lowest CER for IAM/physical fine-tuning.")


if __name__ == '__main__':
    main()