# Pujan-Dev's picture
# test
# f9003ec
#!/usr/bin/env python3
"""
============== COMPLETE OCR PIPELINE (Multi-Line Support) ==============
This pipeline combines:
1. YOLO-based number plate detection
2. Character segmentation using contour detection
3. OCR using a ResNet18-based model
4. Multi-line plate support (for Nepali plates)
Usage:
python main.py <image_path>
python main.py <image_path> --no-yolo # Skip YOLO detection
python main.py <image_path> --save # Save results
"""
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import json
# Local imports
from config.config import (
CONTOUR_CONFIG, INFERENCE_CONFIG, VIZ_CONFIG,
OCR_MODEL_PATH, LABEL_MAP_PATH, YOLO_MODEL_PATH,
setup_directories, get_device, RESULTS_DIR, CONTOURS_BW_DIR
)
from model.ocr import CharacterRecognizer
from model.plate_detector import get_detector
from utils.helper import (
detect_contours, filter_contours_by_size, extract_roi,
convert_to_binary, remove_overlapping_centers,
group_contours_by_line, format_plate_number,
draw_detections, calculate_confidence_stats, save_contour_images
)
class NumberPlateOCR:
    """
    Complete Number Plate OCR Pipeline.

    Supports:
    - YOLO-based plate detection (optional)
    - Multi-line plate recognition
    - Nepali and English characters
    - Embossed number plates
    """

    def __init__(self, use_yolo: bool = True, verbose: bool = True):
        """
        Initialize the OCR pipeline.

        Args:
            use_yolo: Whether to use YOLO for plate detection.
            verbose: Print progress messages.
        """
        self.verbose = verbose
        self.device = get_device()

        # Ensure output directories (results, contour dumps) exist.
        setup_directories()

        # Initialize OCR model
        self._log("Loading OCR model...")
        self.ocr = CharacterRecognizer(
            model_path=str(OCR_MODEL_PATH),
            label_map_path=str(LABEL_MAP_PATH),
            device=self.device
        )

        # Initialize plate detector (only loaded when YOLO is enabled)
        self.use_yolo = use_yolo
        if use_yolo:
            self._log("Loading YOLO plate detector...")
            self.detector = get_detector(use_yolo=True, model_path=str(YOLO_MODEL_PATH))
        else:
            self.detector = None

        self._log("✓ Pipeline initialized successfully!")

    @staticmethod
    def _is_nepali_token(token: str) -> bool:
        """Check if token is Nepali (Devanagari) or Nepali-specific label.

        Returns True for the special "Nepali Flag" class label or for any
        token containing a character in the Devanagari Unicode block
        (U+0900–U+097F). Empty tokens are not Nepali.
        """
        if not token:
            return False
        if token == "Nepali Flag":
            return True
        return any('\u0900' <= ch <= '\u097F' for ch in token)

    @staticmethod
    def _is_english_token(token: str) -> bool:
        """Check if token is plain English alphanumeric (A-Z, a-z, 0-9 only).

        Empty tokens are not English.
        """
        if not token:
            return False
        return all(('0' <= ch <= '9') or ('A' <= ch <= 'Z') or ('a' <= ch <= 'z') for ch in token)

    @staticmethod
    def _english_digit_to_nepali(token: str) -> str:
        """Convert English digits to Nepali digits (keeps non-digits unchanged)."""
        digit_map = str.maketrans("0123456789", "०१२३४५६७८९")
        return token.translate(digit_map)

    def _apply_nepali_dominant_correction(self, line_results: List[Dict]):
        """
        If a line is predominantly Nepali, replace English predictions using
        next Nepali top-k prediction from OCR model.

        Mutates each affected result's 'char'/'conf' in place. Each result is
        expected to carry a '_roi_bw' key (the binary character crop) for
        re-querying the OCR model's top-k output.
        """
        if not line_results:
            return
        nepali_count = sum(1 for r in line_results if self._is_nepali_token(r['char']))
        english_count = sum(1 for r in line_results if self._is_english_token(r['char']))
        # Only correct when Nepali tokens strictly dominate the line.
        if nepali_count <= english_count:
            return
        for r in line_results:
            curr_char = r['char']
            if not self._is_english_token(curr_char):
                continue
            replacement_char = None
            replacement_conf = None
            # Skip the top-1 (the English prediction itself); look for the
            # highest-ranked Nepali alternative among the remaining candidates.
            top_k = self.ocr.get_top_k_predictions(r['_roi_bw'], k=5)
            for candidate_char, candidate_conf in top_k[1:]:
                if self._is_nepali_token(candidate_char):
                    replacement_char = candidate_char
                    replacement_conf = candidate_conf
                    break
            # Fallback: transliterate English digits directly, keeping the
            # original confidence since no model alternative was found.
            if replacement_char is None and any(ch.isdigit() for ch in curr_char):
                replacement_char = self._english_digit_to_nepali(curr_char)
                replacement_conf = r['conf']
            if replacement_char is not None:
                r['char'] = replacement_char
                r['conf'] = float(replacement_conf)

    def _log(self, message: str):
        """Print log message if verbose."""
        if self.verbose:
            print(message)

    def process_image(self, image_path: str,
                      save_contours: bool = False,
                      show_visualization: bool = True) -> Dict:
        """
        Process an image and extract plate number.

        Args:
            image_path: Path to input image
            save_contours: Whether to save extracted character images
            show_visualization: Whether to display matplotlib visualizations

        Returns:
            Dict with recognition results ('image_path', 'num_plates', 'plates')

        Raises:
            ValueError: If the image cannot be loaded.
        """
        # Load image
        self._log(f"\n{'='*60}")
        self._log(f"Processing: {image_path}")
        self._log(f"{'='*60}")
        # NOTE: the image is read once in color; grayscale versions are
        # derived per plate crop below (avoids a redundant second disk read).
        orig_image = cv2.imread(image_path)
        if orig_image is None:
            raise ValueError(f"Could not load image: {image_path}")

        # Step 1: Detect plates (optional YOLO step)
        if self.use_yolo and self.detector:
            self._log("\n📍 Step 1: Detecting number plates with YOLO...")
            plates = self._detect_plates(orig_image)
            if not plates:
                # Fall back to treating the entire frame as a single plate.
                self._log("⚠ No plates detected by YOLO, processing full image...")
                plates = [{'plate_image': orig_image, 'bbox': None, 'confidence': 1.0}]
        else:
            self._log("\n📍 Step 1: Using full image (YOLO disabled)...")
            plates = [{'plate_image': orig_image, 'bbox': None, 'confidence': 1.0}]

        # Process each detected plate
        all_results = []
        for plate_idx, plate_data in enumerate(plates):
            self._log(f"\n📋 Processing Plate {plate_idx + 1}/{len(plates)}")
            plate_img = plate_data['plate_image']
            plate_gray = cv2.cvtColor(plate_img, cv2.COLOR_BGR2GRAY) if len(plate_img.shape) == 3 else plate_img

            # Step 2: Extract character contours
            self._log("📍 Step 2: Detecting character contours...")
            contours = self._extract_contours(plate_gray, plate_img)
            if not contours:
                self._log("⚠ No characters detected in plate")
                continue

            # Save contours if requested
            if save_contours:
                self._log(f" Saving contour images to {CONTOURS_BW_DIR}")
                save_contour_images(contours, plate_img, str(CONTOURS_BW_DIR))

            # Step 3: Group by lines
            self._log("📍 Step 3: Grouping characters by lines...")
            lines = group_contours_by_line(contours)
            self._log(f" Detected {len(lines)} line(s)")
            for i, line in enumerate(lines):
                self._log(f" Line {i+1}: {len(line)} characters")

            # Step 4: Run OCR
            self._log("📍 Step 4: Running OCR on characters...")
            ocr_results = self._run_ocr(lines, plate_img)

            # Step 5: Format results
            formatted = format_plate_number(lines, ocr_results)
            confidence_stats = calculate_confidence_stats(ocr_results)
            result = {
                'plate_index': plate_idx,
                'plate_bbox': plate_data['bbox'],
                'plate_confidence': plate_data.get('confidence', 1.0),
                'plate_image': plate_img,
                'lines': formatted['lines'],
                'multiline_text': formatted['multiline'],
                'singleline_text': formatted['singleline'],
                'num_lines': formatted['num_lines'],
                'total_chars': formatted['total_chars'],
                'details': formatted['details'],
                'confidence_stats': confidence_stats,
                'raw_ocr_results': ocr_results
            }
            all_results.append(result)

            # Visualize
            if show_visualization:
                self._visualize_plate(plate_img, lines, ocr_results, plate_idx)

        # Print final summary
        self._print_results(all_results)
        return {
            'image_path': image_path,
            'num_plates': len(all_results),
            'plates': all_results
        }

    def _detect_plates(self, image: np.ndarray) -> List[Dict]:
        """Detect plates using YOLO.

        Returns the detector's list of dicts (each with at least
        'plate_image', 'bbox' and 'confidence').
        """
        detections = self.detector.detect(image)
        self._log(f" Found {len(detections)} plate(s)")
        for i, det in enumerate(detections):
            self._log(f" Plate {i+1}: confidence={det['confidence']:.2%}")
        return detections

    def _extract_contours(self, gray_image: np.ndarray,
                          color_image: np.ndarray) -> List[Dict]:
        """Extract and filter character contours.

        Pipeline: detect → size-filter → sort by x → optional edge-artifact
        filter → attach binary ROI per contour → drop overlapping centers.
        Each returned dict gains a 'roi_bw' binary crop.
        """
        # Detect contours
        contours, hierarchy, thresh = detect_contours(gray_image)
        self._log(f" Total contours found: {len(contours)}")

        # Filter by size
        filtered = filter_contours_by_size(contours, gray_image.shape)
        self._log(f" After size filter: {len(filtered)}")

        # Sort by x position (y breaks ties for vertically stacked chars)
        sorted_contours = sorted(filtered, key=lambda c: (c['x'], c['y']))

        # Remove only true edge artifacts (do not blindly drop first contours)
        remove_edge_artifacts = CONTOUR_CONFIG.get("remove_edge_artifacts", True)
        edge_margin = CONTOUR_CONFIG.get("edge_margin", 2)
        if remove_edge_artifacts and len(sorted_contours) > 4:
            image_h, image_w = gray_image.shape[:2]
            non_edge_contours = [
                c for c in sorted_contours
                if (
                    c['x'] > edge_margin and
                    c['y'] > edge_margin and
                    (c['x'] + c['w']) < (image_w - edge_margin) and
                    (c['y'] + c['h']) < (image_h - edge_margin)
                )
            ]
            # Keep edge filtering only if it does not remove too many candidates
            if len(non_edge_contours) >= max(3, int(0.6 * len(sorted_contours))):
                sorted_contours = non_edge_contours
                self._log(f" After edge-artifact filter: {len(sorted_contours)}")

        # Extract ROI for each contour
        for c in sorted_contours:
            roi = extract_roi(color_image, c)
            c['roi_bw'] = convert_to_binary(roi)

        # Remove overlapping centers (like inner hole of '0')
        final_contours = remove_overlapping_centers(sorted_contours, verbose=self.verbose)
        removed = len(sorted_contours) - len(final_contours)
        if removed > 0:
            self._log(f" Removed {removed} overlapping contours")
        return final_contours

    def _run_ocr(self, lines: List[List[Dict]],
                 plate_image: np.ndarray) -> List[List[Dict]]:
        """Run OCR on grouped character lines.

        Predictions below the configured minimum confidence are dropped.
        Nepali-dominant correction is applied per line; the temporary
        '_roi_bw' key used by the correction step is removed before return.
        """
        min_confidence = INFERENCE_CONFIG["min_confidence"]
        results_by_line = []
        for line_idx, line in enumerate(lines):
            line_results = []
            for c in line:
                char, conf, processed_img = self.ocr.predict(c['roi_bw'])
                if conf > min_confidence:
                    line_results.append({
                        'char': char,
                        'conf': conf,
                        'x': c['x'],
                        'y': c['y'],
                        'w': c['w'],
                        'h': c['h'],
                        'processed_img': processed_img,
                        '_roi_bw': c['roi_bw']
                    })
            self._apply_nepali_dominant_correction(line_results)
            for r in line_results:
                r.pop('_roi_bw', None)
            results_by_line.append(line_results)
        total_chars = sum(len(line) for line in results_by_line)
        self._log(f" Characters with confidence > {min_confidence*100:.0f}%: {total_chars}")
        return results_by_line

    def _visualize_plate(self, plate_image: np.ndarray,
                         lines: List[List[Dict]],
                         ocr_results: List[List[Dict]],
                         plate_idx: int):
        """Visualize OCR results (no-op if plots are disabled in VIZ_CONFIG)."""
        if not VIZ_CONFIG["show_plots"]:
            return
        # Show original plate
        plt.figure(figsize=VIZ_CONFIG["figure_size"])
        plt.imshow(cv2.cvtColor(plate_image, cv2.COLOR_BGR2RGB))
        plt.title(f'Plate {plate_idx + 1} - {len(lines)} Line(s) Detected')
        plt.axis('off')
        plt.show()
        # Show OCR results for each line
        for line_idx, line_results in enumerate(ocr_results):
            n = len(line_results)
            if n > 0:
                cols = min(VIZ_CONFIG["max_cols"], n)
                rows = (n + cols - 1) // cols
                fig, axes = plt.subplots(rows, cols, figsize=(cols*1.5, rows*2))
                # subplots returns a bare Axes for a 1x1 grid; normalize to a flat list
                axes = np.array(axes).reshape(-1) if n > 1 else [axes]
                for i, r in enumerate(line_results):
                    axes[i].imshow(r['processed_img'], cmap='gray')
                    axes[i].set_title(f'"{r["char"]}" ({r["conf"]:.0%})',
                                      fontsize=VIZ_CONFIG["font_size"])
                    axes[i].axis('off')
                # Hide empty subplots
                for i in range(n, len(axes)):
                    axes[i].axis('off')
                line_text = "".join([r['char'] for r in line_results])
                plt.suptitle(f'Line {line_idx+1}: "{line_text}"', fontsize=12)
                plt.tight_layout()
                plt.show()

    def _print_results(self, results: List[Dict]):
        """Print formatted results."""
        print("\n" + "="*60)
        print("📋 PLATE NUMBER RECOGNITION RESULTS")
        print("="*60)
        for result in results:
            plate_idx = result['plate_index'] + 1
            print(f"\n🏷️ PLATE {plate_idx}:")
            print("-"*40)
            for line_detail in result['details']:
                print(f"\n 📌 Line {line_detail['line_num']}:")
                for i, char_info in enumerate(line_detail['characters']):
                    print(f" {i+1}. '{char_info['char']}' ({char_info['conf']:.1%})")
                print(f" → Result: {line_detail['text']}")
            # Final result
            print("\n" + "-"*40)
            if result['num_lines'] > 1:
                print(" Multi-line format:")
                for i, line in enumerate(result['lines']):
                    print(f" Line {i+1}: {line}")
                print(f"\n Single-line: {result['singleline_text']}")
            else:
                text = result['lines'][0] if result['lines'] else 'No characters detected'
                print(f" Result: {text}")
            # Confidence stats
            stats = result['confidence_stats']
            print(f"\n Confidence: avg={stats['mean']:.1%}, min={stats['min']:.1%}, max={stats['max']:.1%}")
        print("\n" + "="*60)

    def process_from_plate_image(self, plate_image: np.ndarray,
                                 show_visualization: bool = True) -> Dict:
        """
        Process a pre-cropped plate image (skip YOLO detection).

        Args:
            plate_image: Cropped plate image (BGR)
            show_visualization: Whether to show plots

        Returns:
            Recognition result dict
        """
        plate_gray = cv2.cvtColor(plate_image, cv2.COLOR_BGR2GRAY) if len(plate_image.shape) == 3 else plate_image
        # Extract contours
        contours = self._extract_contours(plate_gray, plate_image)
        if not contours:
            return {'lines': [], 'singleline_text': '', 'total_chars': 0}
        # Group by lines
        lines = group_contours_by_line(contours)
        # Run OCR
        ocr_results = self._run_ocr(lines, plate_image)
        # Format results
        formatted = format_plate_number(lines, ocr_results)
        if show_visualization:
            self._visualize_plate(plate_image, lines, ocr_results, 0)
        return {
            'lines': formatted['lines'],
            'multiline_text': formatted['multiline'],
            'singleline_text': formatted['singleline'],
            'num_lines': formatted['num_lines'],
            'total_chars': formatted['total_chars'],
            'details': formatted['details'],
            'confidence_stats': calculate_confidence_stats(ocr_results)
        }
def main():
    """Main entry point: parse CLI args, run the pipeline, optionally save JSON."""
    parser = argparse.ArgumentParser(
        description="Number Plate OCR Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python main.py image.jpg
python main.py image.jpg --no-yolo
python main.py image.jpg --save --no-viz
python main.py image.jpg --output results.json
"""
    )
    parser.add_argument('image', type=str, help='Path to input image')
    parser.add_argument('--no-yolo', action='store_true',
                        help='Skip YOLO plate detection')
    parser.add_argument('--save', action='store_true',
                        help='Save extracted character images')
    parser.add_argument('--no-viz', action='store_true',
                        help='Disable visualization')
    parser.add_argument('--output', '-o', type=str,
                        help='Save results to JSON file')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='Suppress progress messages')
    args = parser.parse_args()

    # Bail out early when the input path does not exist.
    if not os.path.exists(args.image):
        print(f"Error: Image not found: {args.image}")
        return 1

    # Build and run the pipeline with flags inverted into positives.
    pipeline = NumberPlateOCR(
        use_yolo=not args.no_yolo,
        verbose=not args.quiet
    )
    results = pipeline.process_image(
        args.image,
        save_contours=args.save,
        show_visualization=not args.no_viz
    )

    if args.output:
        # Keep only JSON-serializable fields (drops images / raw OCR arrays).
        serializable_keys = (
            'plate_index', 'plate_bbox', 'lines', 'multiline_text',
            'singleline_text', 'num_lines', 'total_chars', 'confidence_stats'
        )
        save_results = {
            'image_path': results['image_path'],
            'num_plates': results['num_plates'],
            'plates': [
                {key: plate[key] for key in serializable_keys}
                for plate in results['plates']
            ]
        }
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(save_results, f, indent=2, ensure_ascii=False)
        print(f"\n✓ Results saved to: {args.output}")

    return 0
if __name__ == "__main__":
    # raise SystemExit instead of the site-module exit() builtin, which is
    # intended for interactive use and may be absent (e.g. under `python -S`).
    raise SystemExit(main())