# Source: ocr/inference.py (Hugging Face Space upload by hanz245, commit 7111e1a "set up")
"""
Inference Script for CRNN+CTC Civil Registry OCR
TWO NORMALIZERS:
SimpleNormalizer β€” for PIL-rendered synthetic images (matches training exactly)
AdaptiveNormalizer β€” for physical/scanned images (any zoom, any size)
AUTO-DETECT MODE: automatically decides which pipeline to use based on
text density in the image β€” zoomed-in images get adaptive treatment,
clean synthetic images get simple treatment.
"""
import torch
import cv2
import numpy as np
from pathlib import Path
from typing import Dict, List
from crnn_model import get_crnn_model
from utils import decode_ctc_predictions, extract_form_fields
# ─────────────────────────────────────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def _to_gray(img: np.ndarray) -> np.ndarray:
    """Return a single-channel version of *img*.

    BGR input (3-D) is converted with OpenCV; an already-gray 2-D image
    is returned as an independent copy so callers may mutate it safely.
    """
    if img.ndim == 3:
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return img.copy()
def _binarize(gray: np.ndarray) -> np.ndarray:
    """Binarize a grayscale image.

    Otsu thresholding is tried first; if the result is implausibly dark
    (<30% white) or nearly all white (>97%) — a sign of uneven
    background illumination — a Gaussian adaptive threshold is used.
    """
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    frac_white = np.mean(binary == 255)
    if 0.30 <= frac_white <= 0.97:
        return binary
    return cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY, 11, 2)
def _crop_to_text(gray: np.ndarray, pad_ratio=0.15) -> np.ndarray:
    """Return the sub-image bounding all dark (ink) pixels, with padding.

    Padding is proportional to the text height (pad_ratio), never less
    than 4 px. If the image holds no dark pixels it is returned untouched.
    """
    _, ink = cv2.threshold(cv2.bitwise_not(gray), 20, 255, cv2.THRESH_BINARY)
    ys, xs = np.where(ink > 0)
    if ys.size == 0:
        return gray
    top, bottom = ys.min(), ys.max()
    left, right = xs.min(), xs.max()
    margin = max(4, int((bottom - top) * pad_ratio))
    top = max(0, top - margin)
    left = max(0, left - margin)
    bottom = min(gray.shape[0] - 1, bottom + margin)
    right = min(gray.shape[1] - 1, right + margin)
    return gray[top:bottom + 1, left:right + 1]
def _aspect_resize(gray: np.ndarray, H: int, W: int) -> np.ndarray:
    """Resize preserving aspect ratio, pad with white to fill an HxW canvas.

    The image is scaled to the target height (or target width, if that
    would overflow horizontally) and centered on a white canvas.

    Args:
        gray: 2-D uint8 grayscale image.
        H, W: target canvas height and width in pixels.
    Returns:
        HxW uint8 canvas containing the resized image.
    """
    h, w = gray.shape
    if h == 0 or w == 0:
        return np.ones((H, W), dtype=np.uint8) * 255
    scale = H / h
    new_w = int(w * scale)
    new_h = H
    if new_w > W:
        scale = W / w
        new_h = int(h * scale)
        new_w = W
    # FIX: clamp to >= 1 pixel. For extreme aspect ratios (e.g. a 1x200
    # strip with H=64) int(w * scale) truncates to 0 and cv2.resize
    # raises on a zero-sized target.
    new_w = max(1, new_w)
    new_h = max(1, new_h)
    resized = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
    canvas = np.ones((H, W), dtype=np.uint8) * 255
    y_off = (H - new_h) // 2
    x_off = (W - new_w) // 2
    canvas[y_off:y_off + new_h, x_off:x_off + new_w] = resized
    return canvas
def _detect_mode(gray: np.ndarray) -> str:
    """Decide which normalizer a grayscale image needs.

    Returns 'adaptive' when the text looks very large/zoomed (dark pixels
    cover more than a quarter of the frame) or when the canvas is outside
    roughly 0.5x-2x of the 512x64 training geometry; 'simple' otherwise.
    """
    h, w = gray.shape
    _, bw = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    dark_fraction = np.mean(bw == 0)
    if dark_fraction > 0.25:
        # Ink dominates the frame -> zoomed-in photo (like shane.jpg).
        return 'adaptive'
    width_ok = 256 <= w <= 1024
    height_ok = 32 <= h <= 128
    if width_ok and height_ok:
        # Close enough to the training canvas to use the simple pipeline.
        return 'simple'
    return 'adaptive'
def _to_tensor(img: np.ndarray) -> torch.Tensor:
    """Scale uint8 pixels to [0, 1] and add batch+channel dims -> (1, 1, H, W)."""
    scaled = img.astype(np.float32) / 255.0
    return torch.FloatTensor(scaled).unsqueeze(0).unsqueeze(0)
# ─────────────────────────────────────────────────────────────────────────────
# SIMPLE NORMALIZER ← for PIL-rendered / training-matched images
# ─────────────────────────────────────────────────────────────────────────────
class SimpleNormalizer:
    """Training-matched preprocessing pipeline.

    Replicates the fix_data.py training steps exactly:
    grayscale -> direct resize to (W, H) -> binarize.
    Best suited to images produced by create_test_images.py (PIL output).
    """

    def __init__(self, H=64, W=512):
        self.H = H
        self.W = W

    def normalize(self, img: np.ndarray) -> np.ndarray:
        """Grayscale, stretch onto the fixed canvas, then binarize."""
        stretched = cv2.resize(_to_gray(img), (self.W, self.H),
                               interpolation=cv2.INTER_LANCZOS4)
        return _binarize(stretched)

    def normalize_from_path(self, path: str) -> np.ndarray:
        """Load an image from disk and normalize it; raises ValueError on load failure."""
        loaded = cv2.imread(str(path))
        if loaded is None:
            raise ValueError(f"Cannot load: {path}")
        return self.normalize(loaded)
# ─────────────────────────────────────────────────────────────────────────────
# ADAPTIVE NORMALIZER ← for real / physical / scanned images
# ─────────────────────────────────────────────────────────────────────────────
class AdaptiveNormalizer:
    """Preprocessing for physical/scanned images of arbitrary zoom or size.

    Pipeline: grayscale -> denoise -> crop to ink -> aspect-preserving
    resize -> binarize. Cropping to the actual text first means a
    zoomed-in image is scaled down to training geometry rather than
    squeezed or stretched.
    """

    def __init__(self, H=64, W=512):
        self.H = H
        self.W = W

    def normalize(self, img: np.ndarray) -> np.ndarray:
        """Denoise, crop to the text region, fit onto the canvas, binarize."""
        cleaned = cv2.fastNlMeansDenoising(_to_gray(img), None, 10, 7, 21)
        cropped = _crop_to_text(cleaned)
        return _binarize(_aspect_resize(cropped, self.H, self.W))

    def normalize_from_path(self, path: str) -> np.ndarray:
        """Load an image from disk and normalize it; raises ValueError on load failure."""
        loaded = cv2.imread(str(path))
        if loaded is None:
            raise ValueError(f"Cannot load: {path}")
        return self.normalize(loaded)
# ─────────────────────────────────────────────────────────────────────────────
# AUTO NORMALIZER ← detects which pipeline to use per image automatically
# ─────────────────────────────────────────────────────────────────────────────
class AutoNormalizer:
    """Per-image dispatcher between the simple and adaptive pipelines.

    _detect_mode() inspects text density and canvas size; training-like
    images go through SimpleNormalizer, everything else through
    AdaptiveNormalizer.
    """

    def __init__(self, H=64, W=512, verbose=False):
        self.H = H
        self.W = W
        self.verbose = verbose
        self._simple = SimpleNormalizer(H, W)
        self._adaptive = AdaptiveNormalizer(H, W)

    def _pick(self, mode):
        # Map the detected mode name onto the matching pipeline object.
        return self._simple if mode == 'simple' else self._adaptive

    def normalize(self, img: np.ndarray) -> np.ndarray:
        """Detect the right pipeline for an in-memory image and run it."""
        mode = _detect_mode(_to_gray(img))
        if self.verbose:
            print(f" auto β†’ {mode}")
        return self._pick(mode).normalize(img)

    def normalize_from_path(self, path: str) -> np.ndarray:
        """Load an image, detect its pipeline, and normalize it."""
        loaded = cv2.imread(str(path))
        if loaded is None:
            raise ValueError(f"Cannot load: {path}")
        mode = _detect_mode(_to_gray(loaded))
        if self.verbose:
            print(f" [{Path(path).name}] β†’ {mode}")
        return self._pick(mode).normalize(loaded)

    def to_tensor(self, img: np.ndarray) -> torch.Tensor:
        """Convert a normalized image to a (1, 1, H, W) float tensor."""
        return _to_tensor(img)
# ─────────────────────────────────────────────────────────────────────────────
# MAIN OCR CLASS
# ─────────────────────────────────────────────────────────────────────────────
class CivilRegistryOCR:
    """End-to-end OCR wrapper: loads a CRNN+CTC checkpoint, normalizes
    input images with the chosen pipeline, and decodes model output
    into text or form fields."""

    def __init__(self, checkpoint_path, device='cuda', mode='auto', verbose=False):
        """
        Args:
            checkpoint_path : path to best_model.pth
            device          : 'cuda' or 'cpu' (falls back to 'cpu' when CUDA is absent)
            mode            : 'auto'     -> auto-detect pipeline per image (recommended)
                              'simple'   -> always use simple pipeline
                              'adaptive' -> always use adaptive pipeline
            verbose         : print which mode was chosen per image (auto mode only)
        """
        if device == 'cuda' and not torch.cuda.is_available():
            device = 'cpu'
        self.device = torch.device(device)
        self.verbose = verbose
        print(f"Loading model from {checkpoint_path}...")
        # NOTE(review): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device,
                                weights_only=False)
        # Character vocabulary mappings saved alongside the weights.
        self.char_to_idx = checkpoint['char_to_idx']
        self.idx_to_char = checkpoint['idx_to_char']
        self.config = checkpoint.get('config', {})
        img_height = self.config.get('img_height', 64)
        img_width = self.config.get('img_width', 512)
        if mode == 'simple':
            self.normalizer = SimpleNormalizer(img_height, img_width)
        elif mode == 'adaptive':
            self.normalizer = AdaptiveNormalizer(img_height, img_width)
        else:
            self.normalizer = AutoNormalizer(img_height, img_width, verbose=verbose)
        # The output-layer row count in the saved weights determines the
        # charset size, so the rebuilt model always matches the checkpoint.
        self.model = get_crnn_model(
            model_type=self.config.get('model_type', 'standard'),
            img_height=img_height,
            num_chars=checkpoint['model_state_dict']['fc.weight'].shape[0],
            hidden_size=self.config.get('hidden_size', 128),
            num_lstm_layers=self.config.get('num_lstm_layers', 1)
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"Model loaded successfully")
        # Support both key names: val_loss (fine-tuned) and val_cer (synthetic baseline)
        # FIXED Bug 5: removed incorrect `val_cer < 10` heuristic that mislabelled
        # the metric. The key name alone is the reliable indicator.
        val_loss = checkpoint.get('val_loss', None)
        val_cer = checkpoint.get('val_cer', None)
        if val_loss is not None and val_cer is not None:
            print(f" Val Loss : {val_loss:.4f} | Val CER: {val_cer:.2f}%")
        elif val_loss is not None:
            print(f" Val Loss : {val_loss:.4f} (fine-tuned checkpoint β€” run compare_live_cer.py for true CER)")
        elif val_cer is not None:
            print(f" Val CER : {val_cer:.2f}%")
        else:
            print(f" Val CER : N/A (run check_cer.py for true CER)")
        print(f" Device : {self.device}")
        print(f" Mode : {mode} ({img_height}x{img_width})")

    def _preprocess(self, image_path) -> torch.Tensor:
        """Normalize the image at *image_path* and convert it to a (1, 1, H, W) tensor."""
        normalized = self.normalizer.normalize_from_path(str(image_path))
        return _to_tensor(normalized)

    def predict(self, image_path, decode_method='greedy') -> str:
        """Run OCR on one image and return the decoded text string.

        decode_method is forwarded to decode_ctc_predictions
        (e.g. 'greedy'; see utils for supported values).
        """
        img = self._preprocess(image_path).to(self.device)
        with torch.no_grad():
            outputs = self.model(img)
        decoded = decode_ctc_predictions(
            outputs.cpu(), self.idx_to_char, method=decode_method)
        return decoded[0]

    def predict_batch(self, image_paths, decode_method='greedy') -> List[Dict]:
        """OCR each path in *image_paths*; per-image failures are captured,
        not raised. Each result dict has 'image_path', 'success', and
        either 'text' or 'error'."""
        results = []
        for image_path in image_paths:
            try:
                text = self.predict(image_path, decode_method)
                results.append({'image_path': str(image_path),
                                'text': text, 'success': True})
            except Exception as e:
                results.append({'image_path': str(image_path),
                                'error': str(e), 'success': False})
        return results

    def process_form(self, form_image_path, form_type) -> Dict:
        """OCR a form image and split the text into fields via
        extract_form_fields; the full OCR output is kept under 'raw_text'."""
        text = self.predict(form_image_path)
        fields = extract_form_fields(text, form_type)
        fields['raw_text'] = text
        return fields
# ─────────────────────────────────────────────────────────────────────────────
# FORM FIELD EXTRACTOR
# ─────────────────────────────────────────────────────────────────────────────
class FormFieldExtractor:
    """Per-form-type extraction helpers layered on a CivilRegistryOCR model."""

    def __init__(self, ocr_model: CivilRegistryOCR):
        self.ocr = ocr_model

    def _extract(self, path, label):
        # Shared path: OCR the page, then tag the result with the form label.
        return {'form_type': label, 'raw_text': self.ocr.predict(path)}

    def extract_form1a_fields(self, path):
        """Birth certificate (Form 1A)."""
        return self._extract(path, 'Form 1A - Birth Certificate')

    def extract_form2a_fields(self, path):
        """Death certificate (Form 2A)."""
        return self._extract(path, 'Form 2A - Death Certificate')

    def extract_form3a_fields(self, path):
        """Marriage certificate (Form 3A)."""
        return self._extract(path, 'Form 3A - Marriage Certificate')

    def extract_form90_fields(self, path):
        """Marriage license application (Form 90)."""
        return self._extract(path, 'Form 90 - Marriage License Application')
# ─────────────────────────────────────────────────────────────────────────────
# DEMO
# ─────────────────────────────────────────────────────────────────────────────
def demo_inference():
    """Run the OCR model over the bundled test images and print results.

    Demonstrates single prediction, batch prediction, and form processing.
    Requires checkpoints/best_model.pth and the test_images/ directory.
    """
    print("=" * 70)
    print("Civil Registry OCR (auto-adaptive normalizer)")
    print("=" * 70)
    ocr = CivilRegistryOCR(
        checkpoint_path='checkpoints/best_model.pth',
        device='cuda',  # falls back to CPU automatically if CUDA is absent
        mode='auto',
        verbose=True  # shows which mode each image triggers
    )
    print("\n1. Single Prediction:")
    try:
        result = ocr.predict('test_images/date1.jpg')
        print(f" Recognized text: {result}")
    except Exception as e:
        print(f" Error: {e}")
    # FIX: this section previously printed its header but the batch call
    # was dead commented-out code; run it with the same error handling
    # as the other demo sections.
    print("\n2. Batch Prediction:")
    try:
        batch_results = ocr.predict_batch([
            'test_images/name1.jpg',
            'test_images/shane.jpg',
            'test_images/date1.jpg',
            'test_images/place1.jpg',
        ])
        for r in batch_results:
            status = r['text'] if r['success'] else f"ERROR - {r['error']}"
            print(f" {r['image_path']}: {status}")
    except Exception as e:
        print(f" Error: {e}")
    print("\n3. Form Processing:")
    try:
        form_data = ocr.process_form('test_images/form1a_sample.jpg', 'form1a')
        print(f" Form Type: Form 1A - Birth Certificate")
        print(f" Raw Text: {form_data['raw_text']}")
    except Exception as e:
        print(f" Error: {e}")
def create_inference_api():
    """Build and return the OCR_API wrapper class (the class itself, not an instance)."""

    class OCR_API:
        """Thin facade bundling text recognition and form-field extraction."""

        def __init__(self, checkpoint_path, mode='auto'):
            self.ocr = CivilRegistryOCR(checkpoint_path, mode=mode)
            self.extractor = FormFieldExtractor(self.ocr)

        def recognize_text(self, p):
            """Free-text OCR on a single image path."""
            return {'text': self.ocr.predict(p), 'success': True}

        def process_birth_certificate(self, p):
            """Form 1A extraction."""
            return self.extractor.extract_form1a_fields(p)

        def process_death_certificate(self, p):
            """Form 2A extraction."""
            return self.extractor.extract_form2a_fields(p)

        def process_marriage_certificate(self, p):
            """Form 3A extraction."""
            return self.extractor.extract_form3a_fields(p)

        def process_marriage_license(self, p):
            """Form 90 extraction."""
            return self.extractor.extract_form90_fields(p)

    return OCR_API
if __name__ == "__main__":
demo_inference()