Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Benchmark different OCR methods for play clock reading. | |
| This script compares: | |
| 1. Tesseract (current method) | |
| 2. EasyOCR (deep learning based) | |
| 3. Template matching (custom digit templates) | |
| Usage: | |
| python scripts/benchmark_ocr.py | |
| """ | |
| import logging | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import List, Tuple, Optional, Dict | |
| import cv2 | |
| import numpy as np | |
| from detection import DetectScoreBug | |
# Base directory that every data path below is resolved against.
# NOTE(review): the module docstring invokes this file as scripts/benchmark_ocr.py, yet three
# .parent hops resolve to one level above the directory that contains scripts/ - confirm
# the actual location of this file in the repository.
PROJECT_ROOT = Path(__file__).parent.parent.parent

logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Input locations
_DATA_DIR = PROJECT_ROOT / "data"
_TEMPLATES_DIR = _DATA_DIR / "templates"
VIDEO_PATH = PROJECT_ROOT / "full_videos" / "OSU vs Tenn 12.21.24.mkv"
TEMPLATE_PATH = _TEMPLATES_DIR / "scorebug_template_main.png"
CONFIG_PATH = _DATA_DIR / "config" / "play_clock_region.json"
DIGIT_TEMPLATES_DIR = _TEMPLATES_DIR / "digits"

# Test segment: 30 timestamps in seconds, one per second, starting at 2320.0.
# The clock in this segment is expected to count down roughly
# 18->17->...->12->40->40->40->39->...; that pattern is approximate and is not encoded
# here - the benchmark uses the Tesseract readings as ground truth instead.
TEST_TIMESTAMPS = [float(second) for second in range(2320, 2350)]
def load_play_clock_config(config_path: Optional[Path] = None) -> Tuple[int, int, int, int]:
    """Load the play clock region from a JSON config file.

    Args:
        config_path: JSON file to read. Defaults to the module-level ``CONFIG_PATH``
            when omitted, which is the original behavior.

    Returns:
        ``(x_offset, y_offset, width, height)`` taken from the keys of the same
        names in the JSON object. Any other keys in the file are ignored.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If one of the four required keys is missing.
    """
    import json

    path = CONFIG_PATH if config_path is None else config_path
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return (data["x_offset"], data["y_offset"], data["width"], data["height"])
def extract_test_frames(video_path: Path, detector: DetectScoreBug, timestamps: List[float]) -> List[Tuple[float, np.ndarray, Tuple[int, int, int, int]]]:
    """Extract the frames at the given timestamps in which the scorebug is detected.

    Args:
        video_path: Video file to read.
        detector: Scorebug detector; ``detector.detect(frame)`` must return an object
            exposing ``detected`` and ``bbox``.
        timestamps: Positions in seconds; each is converted to a frame index via the
            video's reported FPS.

    Returns:
        A list of ``(timestamp, frame, bbox)`` for every timestamp whose frame could be
        read and had a detected scorebug with a non-empty bbox. Other timestamps are
        skipped silently.

    Raises:
        ValueError: If the video cannot be opened or reports a non-positive FPS.
    """
    cap = cv2.VideoCapture(str(video_path))
    frames = []
    # try/finally so the capture handle is released even if seeking, reading or
    # detection raises part-way through.
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        # With an FPS of 0 every timestamp would map to frame 0 and the benchmark
        # would silently run on the same frame 30 times.
        if not fps or fps <= 0:
            raise ValueError(f"Could not determine FPS for video: {video_path}")

        for ts in timestamps:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(ts * fps))
            ret, frame = cap.read()
            if not ret:
                continue
            detection = detector.detect(frame)
            if detection.detected and detection.bbox:
                frames.append((ts, frame, detection.bbox))
    finally:
        cap.release()
    return frames
def extract_play_clock_region(frame: np.ndarray, scorebug_bbox: Tuple[int, int, int, int], config: Tuple[int, int, int, int]) -> np.ndarray:
    """Crop the play clock region out of a frame.

    Args:
        frame: Image array indexed as ``frame[row, col, ...]``.
        scorebug_bbox: Scorebug box; only its first two entries (x, y of the origin)
            are used.
        config: ``(x_offset, y_offset, width, height)`` of the play clock relative to
            the scorebug origin.

    Returns:
        An independent copy of the crop. The part of the requested rectangle that lies
        outside the frame is dropped, so the result may be smaller than
        ``height x width`` or empty.
    """
    sb_x, sb_y, _, _ = scorebug_bbox
    x_offset, y_offset, width, height = config
    pc_x = sb_x + x_offset
    pc_y = sb_y + y_offset
    # Clamp both bounds at 0: NumPy treats a negative slice bound as "counted from the
    # end", which would return an unrelated or empty region instead of the visible
    # part of the rectangle. Overshoot past the far edge is already truncated by slicing.
    top, bottom = max(0, pc_y), max(0, pc_y + height)
    left, right = max(0, pc_x), max(0, pc_x + width)
    return frame[top:bottom, left:right].copy()
def preprocess_for_ocr(region: np.ndarray) -> np.ndarray:
    """Convert a BGR crop into an upscaled binary image shared by all OCR methods.

    Pipeline: grayscale -> 4x bilinear upscale -> Otsu threshold -> invert when the
    thresholded image is mostly dark, so the returned image is predominantly light.
    """
    upscale = 4
    grayscale = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
    enlarged = cv2.resize(grayscale, None, fx=upscale, fy=upscale, interpolation=cv2.INTER_LINEAR)
    _, thresholded = cv2.threshold(enlarged, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # A mean below mid-grey means light digits on a dark background; flip it so the
    # result is dark digits on a light background.
    mostly_dark = np.mean(thresholded) < 128
    return cv2.bitwise_not(thresholded) if mostly_dark else thresholded
| # ============================================================ | |
| # OCR Method 1: Tesseract (current baseline) | |
| # ============================================================ | |
def ocr_tesseract(region: np.ndarray) -> Tuple[Optional[int], float]:
    """Read the play clock value from a crop with Tesseract.

    Returns:
        ``(value, confidence)`` where confidence is Tesseract's score scaled to 0-1,
        or ``(None, 0.0)`` when nothing is recognised, the best text is not a number
        in 0-40, or Tesseract raises.
    """
    import pytesseract

    # White border around the binarised crop before handing it to Tesseract.
    border = 10
    image = cv2.copyMakeBorder(
        preprocess_for_ocr(region), border, border, border, border, cv2.BORDER_CONSTANT, value=255
    )
    tess_config = "--psm 7 -c tessedit_char_whitelist=0123456789"

    try:
        data = pytesseract.image_to_data(image, config=tess_config, output_type=pytesseract.Output.DICT)

        # Keep the non-blank entry with the highest strictly-positive confidence.
        top_text, top_conf = "", 0.0
        for idx, raw_text in enumerate(data["text"]):
            confidence = float(data["conf"][idx])
            if confidence > top_conf and raw_text.strip():
                top_text, top_conf = raw_text.strip(), confidence

        if top_text.isdigit() and 0 <= int(top_text) <= 40:
            return int(top_text), top_conf / 100.0
    except Exception as e:
        logger.debug(f"Tesseract error: {e}")

    return None, 0.0
| # ============================================================ | |
| # OCR Method 2: EasyOCR | |
| # ============================================================ | |
_easyocr_reader = None


def get_easyocr_reader():
    """Return the shared EasyOCR reader, constructing it on first use.

    Returns ``None`` (after logging a warning) when importing or constructing the
    reader raises ``ImportError``.
    """
    global _easyocr_reader
    if _easyocr_reader is not None:
        return _easyocr_reader

    try:
        import easyocr

        # CPU mode for fair comparison
        _easyocr_reader = easyocr.Reader(["en"], gpu=False)
        logger.info("EasyOCR reader initialized")
    except ImportError:
        logger.warning("EasyOCR not installed. Install with: pip install easyocr")
        return None
    return _easyocr_reader
def ocr_easyocr(region: np.ndarray) -> Tuple[Optional[int], float]:
    """Read the play clock value from a crop with EasyOCR.

    Returns:
        ``(value, confidence)`` for the highest-confidence detection when its text is
        a number in 0-40; otherwise ``(None, 0.0)`` (also when the reader is
        unavailable or EasyOCR raises).
    """
    reader = get_easyocr_reader()
    if reader is None:
        return None, 0.0

    image = preprocess_for_ocr(region)
    try:
        # With detail=1 each detection is indexed as [1] = text, [2] = confidence.
        detections = reader.readtext(image, allowlist="0123456789", detail=1)
        if detections:
            top = max(detections, key=lambda det: det[2])
            digits, confidence = top[1].strip(), top[2]
            if digits.isdigit() and 0 <= int(digits) <= 40:
                return int(digits), confidence
    except Exception as e:
        logger.debug(f"EasyOCR error: {e}")

    return None, 0.0
| # ============================================================ | |
| # OCR Method 3: Template Matching for Digits | |
| # ============================================================ | |
class DigitTemplateMatcher:
    """Fast digit recognition using template matching.

    Templates are per-digit binary crops harvested from regions that Tesseract reads
    with high confidence (``calibrate_from_tesseract``). ``read`` then classifies each
    digit-sized blob by mean absolute pixel difference against every stored template.
    """

    # Pixels of context kept around each digit's bounding box.
    _DIGIT_PAD = 4
    # A contour must be taller than this fraction of the image height to count as a digit.
    _MIN_HEIGHT_FRACTION = 0.3
    # A template match must score strictly above this to be accepted.
    _MIN_MATCH_SCORE = 0.5

    def __init__(self):
        # Maps a digit character ("0".."9") to its binary template image.
        self.digit_templates: Dict[str, np.ndarray] = {}
        self._calibrated = False

    @classmethod
    def _find_digit_crops(cls, preprocessed: np.ndarray) -> List[np.ndarray]:
        """Return padded crops of the digit-sized blobs in a binary image, left to right."""
        h, w = preprocessed.shape
        # Contours are searched on the inverted image, i.e. the dark pixels of
        # ``preprocessed`` are the foreground being outlined.
        contours, _ = cv2.findContours(cv2.bitwise_not(preprocessed), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        boxes = [cv2.boundingRect(c) for c in contours]
        boxes = [box for box in boxes if box[3] > h * cls._MIN_HEIGHT_FRACTION]  # Filter small noise
        boxes.sort(key=lambda box: box[0])  # Sort by x position

        pad = cls._DIGIT_PAD
        return [
            preprocessed[max(0, y - pad) : min(h, y + bh + pad), max(0, x - pad) : min(w, x + bw + pad)]
            for x, y, bw, bh in boxes
        ]

    @staticmethod
    def _match_score(template: np.ndarray, digit_img: np.ndarray) -> Optional[float]:
        """Return a 0-1 similarity between a template and a digit crop.

        The score is ``1 - mean(|a - b|) / 255`` after the template has been resized to
        the crop's height and the narrower image centre-padded with white. Returns
        ``None`` when either image is empty.
        """
        if template.size == 0 or digit_img.size == 0:
            return None

        # Resize the template to the crop's height, preserving its aspect ratio.
        target_h = digit_img.shape[0]
        scale = target_h / template.shape[0]
        new_w = max(1, int(template.shape[1] * scale))
        resized = cv2.resize(template, (new_w, target_h), interpolation=cv2.INTER_LINEAR)

        # Centre-pad whichever image is narrower so both have the same width.
        width_gap = digit_img.shape[1] - resized.shape[1]
        if width_gap > 0:
            resized = cv2.copyMakeBorder(resized, 0, 0, width_gap // 2, width_gap - width_gap // 2, cv2.BORDER_CONSTANT, value=255)
        elif width_gap < 0:
            gap = -width_gap
            digit_img = cv2.copyMakeBorder(digit_img, 0, 0, gap // 2, gap - gap // 2, cv2.BORDER_CONSTANT, value=255)

        # Defensive crop to the common size.
        min_h = min(resized.shape[0], digit_img.shape[0])
        min_w = min(resized.shape[1], digit_img.shape[1])
        resized = resized[:min_h, :min_w]
        digit_img = digit_img[:min_h, :min_w]
        if resized.size == 0 or digit_img.size == 0:
            return None

        # Simple pixel difference score (not a cross-correlation).
        diff = np.abs(resized.astype(float) - digit_img.astype(float))
        return 1.0 - float(np.mean(diff)) / 255.0

    def calibrate_from_tesseract(self, regions: List[np.ndarray]) -> bool:
        """Build digit templates using Tesseract readings as ground truth.

        For each region that Tesseract reads with confidence >= 0.7, the digit blobs
        are cropped and stored under the corresponding character of the reading. A
        region is skipped when the number of blobs differs from the number of digits
        in the reading. Per digit, the crop with the largest pixel area is kept.

        Args:
            regions: Play clock crops (as accepted by ``preprocess_for_ocr``).

        Returns:
            True when at least five distinct digit templates are available afterwards.
        """
        logger.info("Calibrating digit templates from Tesseract readings...")
        for region in regions:
            # Get Tesseract reading as ground truth
            value, conf = ocr_tesseract(region)
            if value is None or conf < 0.7:
                continue

            crops = self._find_digit_crops(preprocess_for_ocr(region))
            value_str = str(value)
            if len(crops) != len(value_str):
                continue  # Blob count does not match the reading (also covers "no blobs")

            for digit, digit_img in zip(value_str, crops):
                current = self.digit_templates.get(digit)
                # Crop area is used as the quality proxy: keep the largest per digit.
                if current is None or digit_img.size > current.size:
                    self.digit_templates[digit] = digit_img.copy()

            # Stop early once every digit 0-9 has a template.
            if all(str(d) in self.digit_templates for d in range(10)):
                break

        logger.info(f" Calibrated templates for digits: {sorted(self.digit_templates.keys())}")
        # Requires five distinct templates. NOTE: this is a count check only; it does
        # not verify that the digits 0-4 specifically are among them.
        self._calibrated = len(self.digit_templates) >= 5
        return self._calibrated

    def read(self, region: np.ndarray) -> Tuple[Optional[int], float]:
        """Read the play clock value from a crop using the calibrated templates.

        Returns:
            ``(value, confidence)`` where confidence is the mean per-digit match score,
            or ``(None, 0.0)`` when the matcher is not calibrated, no digit blobs are
            found, any blob fails to match a template, or the value is outside 0-40.
        """
        if not self._calibrated:
            return None, 0.0

        crops = self._find_digit_crops(preprocess_for_ocr(region))
        if not crops:
            return None, 0.0

        digits = []
        total_conf = 0.0
        for digit_img in crops:
            best_digit = None
            best_conf = 0.0
            for digit, template in self.digit_templates.items():
                score = self._match_score(template, digit_img)
                if score is not None and score > best_conf:
                    best_digit, best_conf = digit, score

            if best_digit is None or best_conf <= self._MIN_MATCH_SCORE:
                # Dropping an unreadable digit would shift place value (a "1?" would be
                # reported as 1), so the whole reading is rejected instead.
                return None, 0.0
            digits.append(best_digit)
            total_conf += best_conf

        # Combine digits into number
        try:
            value = int("".join(digits))
        except ValueError:
            return None, 0.0
        if 0 <= value <= 40:
            return value, total_conf / len(digits)
        return None, 0.0
_digit_matcher = None


def get_digit_matcher() -> DigitTemplateMatcher:
    """Return the module-wide DigitTemplateMatcher, creating it on the first call."""
    global _digit_matcher
    if _digit_matcher is not None:
        return _digit_matcher
    _digit_matcher = DigitTemplateMatcher()
    return _digit_matcher
def ocr_template_matching(region: np.ndarray) -> Tuple[Optional[int], float]:
    """Read the play clock value from a crop via the shared digit template matcher.

    Returns whatever ``DigitTemplateMatcher.read`` returns for ``region``.
    """
    return get_digit_matcher().read(region)
| # ============================================================ | |
| # Benchmark Runner | |
| # ============================================================ | |
def _log_section(title: str) -> None:
    """Log a blank line followed by a dashed banner containing ``title``."""
    logger.info("")
    logger.info("-" * 60)
    logger.info(title)
    logger.info("-" * 60)


def _time_ocr_method(regions: List[Tuple[float, np.ndarray]], ocr_fn) -> Tuple[List[Tuple[float, Optional[int], float]], float]:
    """Run ``ocr_fn(region) -> (value, conf)`` on every region and time the whole pass.

    Returns ``(results, elapsed_seconds)`` with one ``(timestamp, value, conf)`` per region.
    """
    results = []
    t_start = time.perf_counter()
    for ts, region in regions:
        value, conf = ocr_fn(region)
        results.append((ts, value, conf))
    return results, time.perf_counter() - t_start


def _speedup_label(baseline_time: float, method_time: float) -> str:
    """Format ``baseline_time / method_time`` as e.g. ``2.50x``; ``n/a`` if the divisor is not positive."""
    if method_time <= 0:
        return "n/a"
    return f"{baseline_time / method_time:.2f}x"


def _report_method(results: List[Tuple[float, Optional[int], float]], elapsed: float, ground_truth: Dict[float, int], total: int, baseline_time: float) -> Tuple[int, float]:
    """Log the statistics of one non-baseline method; return ``(success_count, accuracy_percent)``."""
    success = sum(1 for _, v, _ in results if v is not None)
    # Accuracy is measured only on timestamps for which the baseline produced a value.
    correct = sum(1 for ts, v, _ in results if ts in ground_truth and v == ground_truth[ts])
    accuracy = correct / len(ground_truth) * 100 if ground_truth else 0
    logger.info(f" Success rate: {success}/{total} ({100*success/total:.1f}%)")
    logger.info(f" Accuracy vs Tesseract: {correct}/{len(ground_truth)} ({accuracy:.1f}%)")
    logger.info(f" Total time: {elapsed:.3f}s")
    logger.info(f" Per-frame: {1000*elapsed/total:.1f}ms")
    logger.info(f" Speedup vs Tesseract: {_speedup_label(baseline_time, elapsed)}")
    logger.info(f" Values: {[v for _, v, _ in results]}")
    return success, accuracy


def _summary_row(method: str, per_frame: str, success: str, accuracy: str, speedup: str) -> str:
    """Format one fixed-width row of the summary table."""
    return f"{method:<20} {per_frame:<12} {success:<12} {accuracy:<12} {speedup:<10}"


def run_benchmark(frames: List[Tuple[float, np.ndarray, Tuple[int, int, int, int]]], config: Tuple[int, int, int, int]) -> None:
    """Run Tesseract, EasyOCR and template matching on the same frames and log a comparison.

    Tesseract is the baseline: its successful readings serve as ground truth for the
    accuracy of the other two methods, and its total time is the speedup reference.
    EasyOCR is skipped when its reader is unavailable; template matching is skipped
    when calibration on the first 10 regions fails.

    Args:
        frames: ``(timestamp, frame, scorebug_bbox)`` triples.
        config: Play clock region ``(x_offset, y_offset, width, height)`` relative to
            the scorebug origin.
    """
    logger.info("=" * 60)
    logger.info("OCR BENCHMARK")
    logger.info("=" * 60)
    logger.info(f"Testing {len(frames)} frames")

    # Every rate below divides by the number of regions; bail out instead of raising
    # ZeroDivisionError when there is nothing to measure.
    if not frames:
        logger.warning("No frames to benchmark")
        return

    # Extract play clock regions
    regions = [(ts, extract_play_clock_region(frame, scorebug_bbox, config)) for ts, frame, scorebug_bbox in frames]
    total = len(regions)

    # Method 1: Tesseract (baseline - also used for ground truth)
    _log_section("Method 1: Tesseract (baseline)")
    tesseract_results, tesseract_time = _time_ocr_method(regions, ocr_tesseract)
    tesseract_success = sum(1 for _, v, _ in tesseract_results if v is not None)
    logger.info(f" Success rate: {tesseract_success}/{total} ({100*tesseract_success/total:.1f}%)")
    logger.info(f" Total time: {tesseract_time:.3f}s")
    logger.info(f" Per-frame: {1000*tesseract_time/total:.1f}ms")
    logger.info(f" Values: {[v for _, v, _ in tesseract_results]}")

    # Use Tesseract results as ground truth for accuracy comparison
    ground_truth = {ts: v for ts, v, _ in tesseract_results if v is not None}

    # Method 2: EasyOCR
    _log_section("Method 2: EasyOCR")
    # Fetch the reader before timing so model initialisation is not counted.
    reader = get_easyocr_reader()
    easyocr_time = 0
    easyocr_success = 0
    easyocr_accuracy = 0
    if reader is not None:
        easyocr_results, easyocr_time = _time_ocr_method(regions, ocr_easyocr)
        easyocr_success, easyocr_accuracy = _report_method(easyocr_results, easyocr_time, ground_truth, total, tesseract_time)
    else:
        logger.info(" SKIPPED (EasyOCR not installed)")

    # Method 3: Template Matching
    _log_section("Method 3: Template Matching")
    matcher = get_digit_matcher()
    template_time = 0
    template_success = 0
    template_accuracy = 0
    # Calibrate using first 10 regions (not counted in benchmark time)
    calibration_regions = [region for _, region in regions[:10]]
    if matcher.calibrate_from_tesseract(calibration_regions):
        template_results, template_time = _time_ocr_method(regions, ocr_template_matching)
        template_success, template_accuracy = _report_method(template_results, template_time, ground_truth, total, tesseract_time)
    else:
        logger.info(" SKIPPED (calibration failed)")

    # Summary
    logger.info("")
    logger.info("=" * 60)
    logger.info("SUMMARY")
    logger.info("=" * 60)
    logger.info(_summary_row("Method", "Time/frame", "Success", "Accuracy", "Speedup"))
    logger.info("-" * 66)
    logger.info(_summary_row("Tesseract", f"{1000*tesseract_time/total:.1f}ms", f"{tesseract_success}/{total}", "(baseline)", "1.00x"))
    if reader is not None and easyocr_time > 0:
        logger.info(
            _summary_row("EasyOCR", f"{1000*easyocr_time/total:.1f}ms", f"{easyocr_success}/{total}", f"{easyocr_accuracy:.1f}%", _speedup_label(tesseract_time, easyocr_time))
        )
    if template_time > 0:
        logger.info(
            _summary_row("Template Matching", f"{1000*template_time/total:.1f}ms", f"{template_success}/{total}", f"{template_accuracy:.1f}%", _speedup_label(tesseract_time, template_time))
        )
def main():
    """Check inputs, extract the test frames and run the benchmark.

    Returns:
        0 on success; 1 when a required input file is missing or no frame with a
        detected scorebug could be extracted.
    """
    logger.info("OCR Benchmark Tool")
    logger.info("=" * 60)

    # Verify paths: report the first missing input, in video/template/config order.
    required_inputs = (
        (VIDEO_PATH, f"Video not found: {VIDEO_PATH}"),
        (TEMPLATE_PATH, f"Template not found: {TEMPLATE_PATH}"),
        (CONFIG_PATH, f"Config not found: {CONFIG_PATH}"),
    )
    missing_message = next((message for path, message in required_inputs if not path.exists()), None)
    if missing_message is not None:
        logger.error(missing_message)
        return 1

    # Load config
    config = load_play_clock_config()
    logger.info(f"Play clock config: {config}")

    # Initialize scorebug detector
    detector = DetectScoreBug(template_path=str(TEMPLATE_PATH))

    # Extract test frames
    logger.info(f"Extracting {len(TEST_TIMESTAMPS)} test frames...")
    frames = extract_test_frames(VIDEO_PATH, detector, TEST_TIMESTAMPS)
    logger.info(f"Extracted {len(frames)} frames with scorebug")
    if not frames:
        logger.error("No frames with scorebug found!")
        return 1

    # Run benchmark
    run_benchmark(frames, config)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())