Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| ============== COMPLETE OCR PIPELINE (Multi-Line Support) ============== | |
| This pipeline combines: | |
| 1. YOLO-based number plate detection | |
| 2. Character segmentation using contour detection | |
| 3. OCR using a ResNet18-based model | |
| 4. Multi-line plate support (for Nepali plates) | |
| Usage: | |
| python main.py <image_path> | |
| python main.py <image_path> --no-yolo # Skip YOLO detection | |
| python main.py <image_path> --save # Save results | |
| """ | |
| import cv2 | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import argparse | |
| import os | |
| from pathlib import Path | |
| from typing import List, Dict, Optional, Tuple | |
| import json | |
| # Local imports | |
| from config.config import ( | |
| CONTOUR_CONFIG, INFERENCE_CONFIG, VIZ_CONFIG, | |
| OCR_MODEL_PATH, LABEL_MAP_PATH, YOLO_MODEL_PATH, | |
| setup_directories, get_device, RESULTS_DIR, CONTOURS_BW_DIR | |
| ) | |
| from model.ocr import CharacterRecognizer | |
| from model.plate_detector import get_detector | |
| from utils.helper import ( | |
| detect_contours, filter_contours_by_size, extract_roi, | |
| convert_to_binary, remove_overlapping_centers, | |
| group_contours_by_line, format_plate_number, | |
| draw_detections, calculate_confidence_stats, save_contour_images | |
| ) | |
class NumberPlateOCR:
    """
    Complete Number Plate OCR Pipeline.

    Supports:
    - YOLO-based plate detection (optional)
    - Multi-line plate recognition
    - Nepali and English characters
    - Embossed number plates
    """

    def __init__(self, use_yolo: bool = True, verbose: bool = True):
        """
        Initialize the OCR pipeline.

        Args:
            use_yolo: Whether to use YOLO for plate detection
            verbose: Print progress messages
        """
        self.verbose = verbose
        self.device = get_device()

        # Make sure the output directories exist before any stage writes to them.
        setup_directories()

        # Character recognizer (ResNet18-based OCR head).
        self._log("Loading OCR model...")
        self.ocr = CharacterRecognizer(
            model_path=str(OCR_MODEL_PATH),
            label_map_path=str(LABEL_MAP_PATH),
            device=self.device,
        )

        # Optional YOLO plate localizer; detector stays None when disabled.
        self.use_yolo = use_yolo
        self.detector = None
        if use_yolo:
            self._log("Loading YOLO plate detector...")
            self.detector = get_detector(use_yolo=True, model_path=str(YOLO_MODEL_PATH))

        self._log("✓ Pipeline initialized successfully!")
| def _is_nepali_token(token: str) -> bool: | |
| """Check if token is Nepali (Devanagari) or Nepali-specific label.""" | |
| if not token: | |
| return False | |
| if token == "Nepali Flag": | |
| return True | |
| return any('\u0900' <= ch <= '\u097F' for ch in token) | |
| def _is_english_token(token: str) -> bool: | |
| """Check if token is plain English alphanumeric.""" | |
| if not token: | |
| return False | |
| return all(('0' <= ch <= '9') or ('A' <= ch <= 'Z') or ('a' <= ch <= 'z') for ch in token) | |
| def _english_digit_to_nepali(token: str) -> str: | |
| """Convert English digits to Nepali digits (keeps non-digits unchanged).""" | |
| digit_map = str.maketrans("0123456789", "०१२३४५६७८९") | |
| return token.translate(digit_map) | |
    def _apply_nepali_dominant_correction(self, line_results: List[Dict]):
        """
        If a line is predominantly Nepali, replace English predictions using
        next Nepali top-k prediction from OCR model.

        Mutates each entry's 'char'/'conf' in place. Expects entries to still
        carry their temporary '_roi_bw' crop (attached by _run_ocr and removed
        again after this call).
        """
        if not line_results:
            return
        # Majority vote over the line: only correct when Nepali tokens strictly
        # outnumber English ones.
        nepali_count = sum(1 for r in line_results if self._is_nepali_token(r['char']))
        english_count = sum(1 for r in line_results if self._is_english_token(r['char']))
        if nepali_count <= english_count:
            return
        for r in line_results:
            curr_char = r['char']
            # Only English-looking predictions are candidates for replacement.
            if not self._is_english_token(curr_char):
                continue
            replacement_char = None
            replacement_conf = None
            # Re-query the model: skip the top-1 (the English prediction itself)
            # and take the first Nepali candidate among the next best guesses.
            top_k = self.ocr.get_top_k_predictions(r['_roi_bw'], k=5)
            for candidate_char, candidate_conf in top_k[1:]:
                if self._is_nepali_token(candidate_char):
                    replacement_char = candidate_char
                    replacement_conf = candidate_conf
                    break
            # Fallback for digits: transliterate '0'-'9' to Devanagari digits,
            # keeping the original confidence.
            if replacement_char is None and any(ch.isdigit() for ch in curr_char):
                replacement_char = self._english_digit_to_nepali(curr_char)
                replacement_conf = r['conf']
            if replacement_char is not None:
                r['char'] = replacement_char
                r['conf'] = float(replacement_conf)
| def _log(self, message: str): | |
| """Print log message if verbose.""" | |
| if self.verbose: | |
| print(message) | |
| def process_image(self, image_path: str, | |
| save_contours: bool = False, | |
| show_visualization: bool = True) -> Dict: | |
| """ | |
| Process an image and extract plate number. | |
| Args: | |
| image_path: Path to input image | |
| save_contours: Whether to save extracted character images | |
| show_visualization: Whether to display matplotlib visualizations | |
| Returns: | |
| Dict with recognition results | |
| """ | |
| # Load image | |
| self._log(f"\n{'='*60}") | |
| self._log(f"Processing: {image_path}") | |
| self._log(f"{'='*60}") | |
| orig_image = cv2.imread(image_path) | |
| gray_image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) | |
| if orig_image is None: | |
| raise ValueError(f"Could not load image: {image_path}") | |
| # Step 1: Detect plates (optional YOLO step) | |
| if self.use_yolo and self.detector: | |
| self._log("\n📍 Step 1: Detecting number plates with YOLO...") | |
| plates = self._detect_plates(orig_image) | |
| if not plates: | |
| self._log("⚠ No plates detected by YOLO, processing full image...") | |
| plates = [{'plate_image': orig_image, 'bbox': None, 'confidence': 1.0}] | |
| else: | |
| self._log("\n📍 Step 1: Using full image (YOLO disabled)...") | |
| plates = [{'plate_image': orig_image, 'bbox': None, 'confidence': 1.0}] | |
| # Process each detected plate | |
| all_results = [] | |
| for plate_idx, plate_data in enumerate(plates): | |
| self._log(f"\n📋 Processing Plate {plate_idx + 1}/{len(plates)}") | |
| plate_img = plate_data['plate_image'] | |
| plate_gray = cv2.cvtColor(plate_img, cv2.COLOR_BGR2GRAY) if len(plate_img.shape) == 3 else plate_img | |
| # Step 2: Extract character contours | |
| self._log("📍 Step 2: Detecting character contours...") | |
| contours = self._extract_contours(plate_gray, plate_img) | |
| if not contours: | |
| self._log("⚠ No characters detected in plate") | |
| continue | |
| # Save contours if requested | |
| if save_contours: | |
| self._log(f" Saving contour images to {CONTOURS_BW_DIR}") | |
| save_contour_images(contours, plate_img, str(CONTOURS_BW_DIR)) | |
| # Step 3: Group by lines | |
| self._log("📍 Step 3: Grouping characters by lines...") | |
| lines = group_contours_by_line(contours) | |
| self._log(f" Detected {len(lines)} line(s)") | |
| for i, line in enumerate(lines): | |
| self._log(f" Line {i+1}: {len(line)} characters") | |
| # Step 4: Run OCR | |
| self._log("📍 Step 4: Running OCR on characters...") | |
| ocr_results = self._run_ocr(lines, plate_img) | |
| # Step 5: Format results | |
| formatted = format_plate_number(lines, ocr_results) | |
| confidence_stats = calculate_confidence_stats(ocr_results) | |
| result = { | |
| 'plate_index': plate_idx, | |
| 'plate_bbox': plate_data['bbox'], | |
| 'plate_confidence': plate_data.get('confidence', 1.0), | |
| 'plate_image': plate_img, | |
| 'lines': formatted['lines'], | |
| 'multiline_text': formatted['multiline'], | |
| 'singleline_text': formatted['singleline'], | |
| 'num_lines': formatted['num_lines'], | |
| 'total_chars': formatted['total_chars'], | |
| 'details': formatted['details'], | |
| 'confidence_stats': confidence_stats, | |
| 'raw_ocr_results': ocr_results | |
| } | |
| all_results.append(result) | |
| # Visualize | |
| if show_visualization: | |
| self._visualize_plate(plate_img, lines, ocr_results, plate_idx) | |
| # Print final summary | |
| self._print_results(all_results) | |
| return { | |
| 'image_path': image_path, | |
| 'num_plates': len(all_results), | |
| 'plates': all_results | |
| } | |
| def _detect_plates(self, image: np.ndarray) -> List[Dict]: | |
| """Detect plates using YOLO.""" | |
| detections = self.detector.detect(image) | |
| self._log(f" Found {len(detections)} plate(s)") | |
| for i, det in enumerate(detections): | |
| self._log(f" Plate {i+1}: confidence={det['confidence']:.2%}") | |
| return detections | |
    def _extract_contours(self, gray_image: np.ndarray,
                          color_image: np.ndarray) -> List[Dict]:
        """Extract and filter character contours.

        Args:
            gray_image: Grayscale plate image used for contour detection.
            color_image: Matching colour image the character ROIs are cropped from.
        Returns:
            Contour dicts (with x/y/w/h and a binarized 'roi_bw' crop attached),
            sorted left-to-right and cleaned of edge artifacts and nested holes.
        """
        # Detect contours
        contours, hierarchy, thresh = detect_contours(gray_image)
        self._log(f" Total contours found: {len(contours)}")
        # Filter by size
        filtered = filter_contours_by_size(contours, gray_image.shape)
        self._log(f" After size filter: {len(filtered)}")
        # Sort by x position (ties broken by y)
        sorted_contours = sorted(filtered, key=lambda c: (c['x'], c['y']))
        # Remove only true edge artifacts (do not blindly drop first contours)
        remove_edge_artifacts = CONTOUR_CONFIG.get("remove_edge_artifacts", True)
        edge_margin = CONTOUR_CONFIG.get("edge_margin", 2)
        if remove_edge_artifacts and len(sorted_contours) > 4:
            image_h, image_w = gray_image.shape[:2]
            # Keep only contours whose bounding box lies strictly inside the
            # image minus the configured margin.
            non_edge_contours = [
                c for c in sorted_contours
                if (
                    c['x'] > edge_margin and
                    c['y'] > edge_margin and
                    (c['x'] + c['w']) < (image_w - edge_margin) and
                    (c['y'] + c['h']) < (image_h - edge_margin)
                )
            ]
            # Keep edge filtering only if it does not remove too many candidates
            if len(non_edge_contours) >= max(3, int(0.6 * len(sorted_contours))):
                sorted_contours = non_edge_contours
                self._log(f" After edge-artifact filter: {len(sorted_contours)}")
        # Extract ROI for each contour and binarize it for the OCR model.
        for c in sorted_contours:
            roi = extract_roi(color_image, c)
            c['roi_bw'] = convert_to_binary(roi)
        # Remove overlapping centers (like inner hole of '0')
        final_contours = remove_overlapping_centers(sorted_contours, verbose=self.verbose)
        removed = len(sorted_contours) - len(final_contours)
        if removed > 0:
            self._log(f" Removed {removed} overlapping contours")
        return final_contours
| def _run_ocr(self, lines: List[List[Dict]], | |
| plate_image: np.ndarray) -> List[List[Dict]]: | |
| """Run OCR on grouped character lines.""" | |
| min_confidence = INFERENCE_CONFIG["min_confidence"] | |
| results_by_line = [] | |
| for line_idx, line in enumerate(lines): | |
| line_results = [] | |
| for c in line: | |
| char, conf, processed_img = self.ocr.predict(c['roi_bw']) | |
| if conf > min_confidence: | |
| line_results.append({ | |
| 'char': char, | |
| 'conf': conf, | |
| 'x': c['x'], | |
| 'y': c['y'], | |
| 'w': c['w'], | |
| 'h': c['h'], | |
| 'processed_img': processed_img, | |
| '_roi_bw': c['roi_bw'] | |
| }) | |
| self._apply_nepali_dominant_correction(line_results) | |
| for r in line_results: | |
| r.pop('_roi_bw', None) | |
| results_by_line.append(line_results) | |
| total_chars = sum(len(line) for line in results_by_line) | |
| self._log(f" Characters with confidence > {min_confidence*100:.0f}%: {total_chars}") | |
| return results_by_line | |
    def _visualize_plate(self, plate_image: np.ndarray,
                         lines: List[List[Dict]],
                         ocr_results: List[List[Dict]],
                         plate_idx: int):
        """Visualize OCR results.

        Shows the plate crop, then one figure per line with each recognized
        character crop, its predicted label and confidence. No-op when
        VIZ_CONFIG["show_plots"] is disabled.
        """
        if not VIZ_CONFIG["show_plots"]:
            return
        # Show original plate
        plt.figure(figsize=VIZ_CONFIG["figure_size"])
        # Matplotlib expects RGB, OpenCV delivers BGR.
        plt.imshow(cv2.cvtColor(plate_image, cv2.COLOR_BGR2RGB))
        plt.title(f'Plate {plate_idx + 1} - {len(lines)} Line(s) Detected')
        plt.axis('off')
        plt.show()
        # Show OCR results for each line
        for line_idx, line_results in enumerate(ocr_results):
            n = len(line_results)
            if n > 0:
                cols = min(VIZ_CONFIG["max_cols"], n)
                rows = (n + cols - 1) // cols
                fig, axes = plt.subplots(rows, cols, figsize=(cols*1.5, rows*2))
                # subplots returns a bare Axes when n == 1; normalize to a flat list.
                axes = np.array(axes).reshape(-1) if n > 1 else [axes]
                for i, r in enumerate(line_results):
                    axes[i].imshow(r['processed_img'], cmap='gray')
                    axes[i].set_title(f'"{r["char"]}" ({r["conf"]:.0%})',
                                      fontsize=VIZ_CONFIG["font_size"])
                    axes[i].axis('off')
                # Hide empty subplots
                for i in range(n, len(axes)):
                    axes[i].axis('off')
                line_text = "".join([r['char'] for r in line_results])
                plt.suptitle(f'Line {line_idx+1}: "{line_text}"', fontsize=12)
                plt.tight_layout()
                plt.show()
| def _print_results(self, results: List[Dict]): | |
| """Print formatted results.""" | |
| print("\n" + "="*60) | |
| print("📋 PLATE NUMBER RECOGNITION RESULTS") | |
| print("="*60) | |
| for result in results: | |
| plate_idx = result['plate_index'] + 1 | |
| print(f"\n🏷️ PLATE {plate_idx}:") | |
| print("-"*40) | |
| for line_detail in result['details']: | |
| print(f"\n 📌 Line {line_detail['line_num']}:") | |
| for i, char_info in enumerate(line_detail['characters']): | |
| print(f" {i+1}. '{char_info['char']}' ({char_info['conf']:.1%})") | |
| print(f" → Result: {line_detail['text']}") | |
| # Final result | |
| print("\n" + "-"*40) | |
| if result['num_lines'] > 1: | |
| print(" Multi-line format:") | |
| for i, line in enumerate(result['lines']): | |
| print(f" Line {i+1}: {line}") | |
| print(f"\n Single-line: {result['singleline_text']}") | |
| else: | |
| text = result['lines'][0] if result['lines'] else 'No characters detected' | |
| print(f" Result: {text}") | |
| # Confidence stats | |
| stats = result['confidence_stats'] | |
| print(f"\n Confidence: avg={stats['mean']:.1%}, min={stats['min']:.1%}, max={stats['max']:.1%}") | |
| print("\n" + "="*60) | |
    def process_from_plate_image(self, plate_image: np.ndarray,
                                 show_visualization: bool = True) -> Dict:
        """
        Process a pre-cropped plate image (skip YOLO detection).

        Args:
            plate_image: Cropped plate image (BGR)
            show_visualization: Whether to show plots
        Returns:
            Recognition result dict
        """
        # Accept either BGR or already-grayscale crops.
        plate_gray = cv2.cvtColor(plate_image, cv2.COLOR_BGR2GRAY) if len(plate_image.shape) == 3 else plate_image
        # Extract contours
        contours = self._extract_contours(plate_gray, plate_image)
        if not contours:
            # Nothing recognizable on the plate: return an empty result shell.
            return {'lines': [], 'singleline_text': '', 'total_chars': 0}
        # Group by lines
        lines = group_contours_by_line(contours)
        # Run OCR
        ocr_results = self._run_ocr(lines, plate_image)
        # Format results
        formatted = format_plate_number(lines, ocr_results)
        if show_visualization:
            self._visualize_plate(plate_image, lines, ocr_results, 0)
        return {
            'lines': formatted['lines'],
            'multiline_text': formatted['multiline'],
            'singleline_text': formatted['singleline'],
            'num_lines': formatted['num_lines'],
            'total_chars': formatted['total_chars'],
            'details': formatted['details'],
            'confidence_stats': calculate_confidence_stats(ocr_results)
        }
def main():
    """Main entry point.

    Parses CLI arguments, runs the OCR pipeline on the given image and
    optionally dumps a JSON summary. Returns a process exit code
    (0 = success, 1 = input image missing).
    """
    parser = argparse.ArgumentParser(
        description="Number Plate OCR Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python main.py image.jpg
python main.py image.jpg --no-yolo
python main.py image.jpg --save --no-viz
python main.py image.jpg --output results.json
"""
    )
    parser.add_argument('image', type=str, help='Path to input image')
    parser.add_argument('--no-yolo', action='store_true',
                        help='Skip YOLO plate detection')
    parser.add_argument('--save', action='store_true',
                        help='Save extracted character images')
    parser.add_argument('--no-viz', action='store_true',
                        help='Disable visualization')
    parser.add_argument('--output', '-o', type=str,
                        help='Save results to JSON file')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='Suppress progress messages')
    args = parser.parse_args()
    # Validate input
    if not os.path.exists(args.image):
        print(f"Error: Image not found: {args.image}")
        return 1
    # Initialize pipeline
    pipeline = NumberPlateOCR(
        use_yolo=not args.no_yolo,
        verbose=not args.quiet
    )
    # Process image
    results = pipeline.process_image(
        args.image,
        save_contours=args.save,
        show_visualization=not args.no_viz
    )
    # Save results if requested
    if args.output:
        # Remove non-serializable items (numpy arrays such as 'plate_image'
        # and 'raw_ocr_results' are dropped; only JSON-safe fields are kept).
        save_results = {
            'image_path': results['image_path'],
            'num_plates': results['num_plates'],
            'plates': []
        }
        for plate in results['plates']:
            save_plate = {
                'plate_index': plate['plate_index'],
                'plate_bbox': plate['plate_bbox'],
                'lines': plate['lines'],
                'multiline_text': plate['multiline_text'],
                'singleline_text': plate['singleline_text'],
                'num_lines': plate['num_lines'],
                'total_chars': plate['total_chars'],
                'confidence_stats': plate['confidence_stats']
            }
            save_results['plates'].append(save_plate)
        # ensure_ascii=False keeps Devanagari characters readable in the JSON.
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(save_results, f, indent=2, ensure_ascii=False)
        print(f"\n✓ Results saved to: {args.output}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer status to the shell. `raise SystemExit(...)`
    # is used instead of the bare `exit()` builtin, which is injected by the
    # `site` module and absent under `python -S` or in frozen executables.
    raise SystemExit(main())