# Commit f9003ec ("test") by Pujan-Dev
"""
OCR Model Definition and Inference for Number Plate Character Recognition
"""
import torch
import torch.nn as nn
import numpy as np
from torchvision import models
import json
from sklearn.preprocessing import LabelEncoder
from pathlib import Path
import cv2
import sys
sys.path.append(str(Path(__file__).parent.parent))
from config.config import OCR_CONFIG, PREPROCESS_CONFIG, get_device
class OCRModel(nn.Module):
    """
    ResNet18-based OCR model for single-character recognition.

    The backbone's first conv layer is replaced so the network accepts
    single-channel (grayscale) input, and the original ImageNet FC layer
    is replaced by a small dropout-regularized classifier head.

    Args:
        num_classes: Number of output character classes.
    """
    def __init__(self, num_classes: int):
        super().__init__()
        pretrained = OCR_CONFIG.get("pretrained", False)
        # torchvision >= 0.13 uses the `weights=` keyword; the legacy
        # `pretrained=` flag was removed in 0.15 and would crash there.
        # Prefer the new API and fall back for older installs.
        try:
            weights = models.ResNet18_Weights.DEFAULT if pretrained else None
            self.features = models.resnet18(weights=weights)
        except AttributeError:
            self.features = models.resnet18(pretrained=pretrained)
        # Accept single-channel (grayscale) images instead of RGB.
        # NOTE: this discards any pretrained conv1 weights by design.
        self.features.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        # Drop the ImageNet FC head; `features` now emits a 512-dim vector.
        self.features.fc = nn.Identity()
        # Classifier head. Attribute names (`features`, `classifier`) are kept
        # stable so existing checkpoints still load via load_state_dict.
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a batch of (N, 1, H, W) images through backbone + head.

        Returns:
            Logits tensor of shape (N, num_classes).
        """
        features = self.features(x)
        return self.classifier(features)
class CharacterRecognizer:
    """
    High-level wrapper for character recognition.

    Handles model loading, preprocessing (resize -> CLAHE -> Gaussian blur),
    and inference for single regions, batches, and top-k queries.
    """
    def __init__(self, model_path: str, label_map_path: str, device: torch.device = None):
        """
        Args:
            model_path: Path to the trained model state_dict file.
            label_map_path: Path to a JSON file mapping class index
                (as a string key) to its character.
            device: Torch device to run inference on; defaults to get_device().
        """
        self.device = device or get_device()
        self.model_path = Path(model_path)
        self.label_map_path = Path(label_map_path)
        # Label map must be loaded first: it determines num_classes for the model.
        self._load_label_map()
        self._load_model()
        # One reusable CLAHE instance for all preprocess() calls.
        self.clahe = cv2.createCLAHE(
            clipLimit=PREPROCESS_CONFIG["clahe_clip_limit"],
            tileGridSize=PREPROCESS_CONFIG["clahe_grid_size"]
        )

    def _load_label_map(self):
        """Load the index->character label map from JSON and build the encoder."""
        with open(self.label_map_path, 'r', encoding='utf-8') as f:
            self.label_map = json.load(f)
        self.num_classes = len(self.label_map)
        # Populate a LabelEncoder directly; classes_ is ordered by class index
        # so inverse_transform(i) yields the character for model output i.
        self.label_encoder = LabelEncoder()
        self.label_encoder.classes_ = np.array([
            self.label_map[str(i)] for i in range(self.num_classes)
        ])

    def _load_model(self):
        """Load trained model weights and switch to eval mode."""
        self.model = OCRModel(self.num_classes).to(self.device)
        try:
            # weights_only=True (torch >= 1.13) avoids arbitrary code execution
            # through pickle when loading an untrusted checkpoint.
            state_dict = torch.load(
                self.model_path, map_location=self.device, weights_only=True
            )
        except TypeError:
            # Older torch without the weights_only keyword.
            state_dict = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print(f"✓ OCR Model loaded on: {self.device}")

    def preprocess(self, img_region: np.ndarray) -> tuple:
        """
        Preprocess an image region for OCR.

        Pipeline: resize to model input size -> CLAHE contrast equalization
        -> Gaussian blur -> normalized float tensor on self.device.

        Args:
            img_region: Grayscale image region (numpy array).

        Returns:
            Tuple of (tensor of shape (1, 1, H, W), preprocessed uint8 image).
        """
        input_size = OCR_CONFIG["input_size"]
        img_resized = cv2.resize(img_region, input_size)
        # CLAHE: Contrast Limited Adaptive Histogram Equalization.
        img_eq = self.clahe.apply(img_resized)
        # Gaussian blur reduces sensor/compression noise before inference.
        img_blur = cv2.GaussianBlur(
            img_eq, PREPROCESS_CONFIG["gaussian_blur_kernel"], 0
        )
        # Add batch and channel dims, scale pixel values to [0, 1].
        img_tensor = torch.from_numpy(img_blur).unsqueeze(0).unsqueeze(0).float() / 255.0
        img_tensor = img_tensor.to(self.device)
        return img_tensor, img_blur

    def predict(self, img_region: np.ndarray) -> tuple:
        """
        Perform OCR on a single image region.

        Args:
            img_region: Grayscale image region.

        Returns:
            Tuple of (predicted_char, confidence, preprocessed_image).
        """
        img_tensor, preprocessed_img = self.preprocess(img_region)
        with torch.no_grad():
            output = self.model(img_tensor)
            # One softmax pass provides both the winning index and its
            # probability (argmax is invariant under softmax).
            probs = torch.softmax(output, dim=1)
            predicted_index = int(probs.argmax(dim=1).item())
            confidence = probs.max().item()
        predicted_char = self.label_encoder.inverse_transform([predicted_index])[0]
        return predicted_char, confidence, preprocessed_img

    def predict_batch(self, img_regions: list) -> list:
        """
        Perform OCR on multiple image regions in a single forward pass.

        Args:
            img_regions: List of grayscale image regions.

        Returns:
            List of (predicted_char, confidence, preprocessed_image) tuples;
            empty list for empty input.
        """
        if not img_regions:
            return []
        tensors = []
        preprocessed_imgs = []
        for img in img_regions:
            tensor, preprocessed = self.preprocess(img)
            tensors.append(tensor)
            preprocessed_imgs.append(preprocessed)
        # Stack (1, 1, H, W) tensors into one (N, 1, H, W) batch.
        batch_tensor = torch.cat(tensors, dim=0)
        with torch.no_grad():
            outputs = self.model(batch_tensor)
            predicted_indices = outputs.argmax(dim=1).cpu().numpy()
            confidences = torch.softmax(outputs, dim=1).max(dim=1).values.cpu().numpy()
        predicted_chars = self.label_encoder.inverse_transform(predicted_indices)
        return list(zip(predicted_chars, confidences, preprocessed_imgs))

    def get_top_k_predictions(self, img_region: np.ndarray, k: int = 5) -> list:
        """
        Get top-k predictions with confidence scores.

        Args:
            img_region: Grayscale image region.
            k: Number of top predictions to return; values larger than the
                number of classes are clamped (torch.topk would raise).

        Returns:
            List of (char, confidence) tuples, best first.
        """
        k = min(k, self.num_classes)
        img_tensor, _ = self.preprocess(img_region)
        with torch.no_grad():
            output = self.model(img_tensor)
            probs = torch.softmax(output, dim=1)[0]
            top_k = torch.topk(probs, k)
        results = []
        for idx, conf in zip(top_k.indices.cpu().numpy(), top_k.values.cpu().numpy()):
            char = self.label_encoder.inverse_transform([idx])[0]
            results.append((char, float(conf)))
        return results