Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Test template matching accuracy against OCR ground truth.

This test:
1. Collects samples from a longer segment to get full digit coverage
2. Splits samples into training (build templates) and test (evaluate accuracy) sets
3. Compares template matching results against OCR ground truth
4. Measures timing improvement
5. Saves debug images for wrong/undetected cases (if <= 10 total errors)

Uses dual-mode matching to handle both single-digit (centered) and double-digit
(left/right) layouts. Templates needed: 25 total (10 center + 10 right + 4 tens + 1 blank).

Usage:
    cd /Users/andytaylor/Documents/Personal/cfb40
    source .venv/bin/activate
    python tests/test_digit_templates/test_template_accuracy.py
"""
| import logging | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import List | |
| import cv2 | |
| import numpy as np | |
| from detection import DetectScoreBug | |
| from readers import ReadPlayClock | |
| from setup import DigitTemplateBuilder, PlayClockRegionExtractor | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") | |
| logger = logging.getLogger(__name__) | |
| # Test configuration | |
| VIDEO_PATH = "full_videos/OSU vs Tenn 12.21.24.mkv" | |
| TEMPLATE_PATH = "output/OSU_vs_Tenn_12_21_24_template.png" | |
| PLAYCLOCK_CONFIG_PATH = "output/OSU_vs_Tenn_12_21_24_playclock_config.json" | |
| # Use longer segment to get more digit coverage | |
| # 38:40 to 48:40 = 10-minute segment with ~13 plays per v3 baseline | |
| START_TIME = 38 * 60 + 40 # 2320 seconds | |
| END_TIME = 48 * 60 + 40 # 2920 seconds | |
| SAMPLE_INTERVAL = 0.5 | |
| # Debug output directory | |
| DEBUG_DIR = Path("output/debug/digit_templates/errors") | |
| def collect_all_samples(video_path: str, start_time: float, end_time: float, sample_interval: float): | |
| """ | |
| Collect play clock samples with OCR ground truth. | |
| Returns list of (timestamp, clock_value, region_image, confidence) | |
| """ | |
| scorebug_detector = DetectScoreBug(template_path=TEMPLATE_PATH) | |
| clock_reader = PlayClockRegionExtractor(region_config_path=PLAYCLOCK_CONFIG_PATH) | |
| cap = cv2.VideoCapture(video_path) | |
| if not cap.isOpened(): | |
| raise RuntimeError(f"Could not open video: {video_path}") | |
| fps = cap.get(cv2.CAP_PROP_FPS) | |
| start_frame = int(start_time * fps) | |
| end_frame = int(end_time * fps) | |
| frame_skip = int(sample_interval * fps) | |
| logger.info("Collecting samples from %.1fs to %.1fs", start_time, end_time) | |
| # Lock scorebug region | |
| cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) | |
| ret, frame = cap.read() | |
| if ret: | |
| scorebug_detector.discover_and_lock_region(frame) | |
| logger.info("Scorebug region: %s", scorebug_detector.fixed_region) | |
| samples = [] | |
| current_frame = start_frame | |
| cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) | |
| while current_frame < end_frame: | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| current_time = current_frame / fps | |
| detection = scorebug_detector.detect(frame) | |
| if detection.detected and detection.bbox: | |
| sb_x, sb_y, _, _ = detection.bbox | |
| pc_config = clock_reader.config | |
| pc_x = sb_x + pc_config.x_offset | |
| pc_y = sb_y + pc_config.y_offset | |
| pc_w = pc_config.width | |
| pc_h = pc_config.height | |
| frame_h, frame_w = frame.shape[:2] | |
| if 0 <= pc_x and 0 <= pc_y and pc_x + pc_w <= frame_w and pc_y + pc_h <= frame_h: | |
| region = frame[pc_y : pc_y + pc_h, pc_x : pc_x + pc_w].copy() | |
| reading = clock_reader.read(frame, detection.bbox) | |
| if reading.detected and reading.value is not None: | |
| samples.append((current_time, reading.value, region, reading.confidence)) | |
| for _ in range(frame_skip - 1): | |
| cap.grab() | |
| current_frame += frame_skip | |
| cap.release() | |
| return samples | |
| def split_samples(samples: List, train_ratio: float = 0.7): | |
| """Split samples into training and test sets.""" | |
| # Sort by timestamp to ensure temporal split | |
| sorted_samples = sorted(samples, key=lambda x: x[0]) | |
| split_idx = int(len(sorted_samples) * train_ratio) | |
| train_samples = sorted_samples[:split_idx] | |
| test_samples = sorted_samples[split_idx:] | |
| return train_samples, test_samples | |
| def save_debug_images(error_results: List[dict], output_dir: Path): | |
| """ | |
| Save debug images for error cases. | |
| Each image shows: | |
| - Original region (scaled up) | |
| - Preprocessed binary image | |
| - Annotation with OCR value, template value, confidence | |
| """ | |
| output_dir.mkdir(parents=True, exist_ok=True) | |
| # Clear previous debug images | |
| for f in output_dir.glob("*.png"): | |
| f.unlink() | |
| for i, result in enumerate(error_results): | |
| timestamp = result["timestamp"] | |
| ocr_value = result["ocr_value"] | |
| template_value = result["template_value"] | |
| confidence = result["confidence"] | |
| status = result["status"] | |
| region = result["region"] | |
| # Scale up the region for visibility (4x) | |
| scale = 4 | |
| scaled_region = cv2.resize(region, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST) | |
| # Create a larger canvas with annotation space | |
| canvas_h = scaled_region.shape[0] + 60 | |
| canvas_w = max(scaled_region.shape[1], 300) | |
| canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8) | |
| canvas[:] = (40, 40, 40) # Dark gray background | |
| # Place scaled region at top | |
| x_offset = (canvas_w - scaled_region.shape[1]) // 2 | |
| canvas[0 : scaled_region.shape[0], x_offset : x_offset + scaled_region.shape[1]] = scaled_region | |
| # Add annotations | |
| y_text = scaled_region.shape[0] + 20 | |
| font = cv2.FONT_HERSHEY_SIMPLEX | |
| font_scale = 0.5 | |
| color = (0, 0, 255) if status == "WRONG" else (0, 165, 255) # Red for wrong, orange for undetected | |
| # Status and timestamp | |
| cv2.putText(canvas, f"{status} @ {timestamp:.1f}s", (10, y_text), font, font_scale, color, 1) | |
| # OCR vs Template | |
| y_text += 18 | |
| template_str = str(template_value) if template_value is not None else "None" | |
| cv2.putText(canvas, f"OCR: {ocr_value} Template: {template_str} Conf: {confidence:.2f}", (10, y_text), font, font_scale, (200, 200, 200), 1) | |
| # Save with descriptive filename | |
| if status == "WRONG": | |
| filename = f"wrong_{i:02d}_t{timestamp:.0f}s_ocr{ocr_value}_tmpl{template_value}.png" | |
| else: | |
| filename = f"missed_{i:02d}_t{timestamp:.0f}s_ocr{ocr_value}.png" | |
| cv2.imwrite(str(output_dir / filename), canvas) | |
| logger.info("Saved %d debug images to: %s", len(error_results), output_dir) | |
| def test_template_accuracy(): | |
| """Test template matching accuracy against OCR ground truth.""" | |
| logger.info("=" * 60) | |
| logger.info("TEST: Template Matching Accuracy vs OCR") | |
| logger.info("=" * 60) | |
| # Check files exist | |
| for path, name in [(VIDEO_PATH, "Video"), (TEMPLATE_PATH, "Template"), (PLAYCLOCK_CONFIG_PATH, "Config")]: | |
| if not Path(path).exists(): | |
| logger.error("%s not found: %s", name, path) | |
| return False | |
| # Collect all samples | |
| logger.info("\n[Step 1] Collecting samples with OCR ground truth...") | |
| all_samples = collect_all_samples(VIDEO_PATH, START_TIME, END_TIME, SAMPLE_INTERVAL) | |
| logger.info("Total samples: %d", len(all_samples)) | |
| # Split into train/test | |
| logger.info("\n[Step 2] Splitting samples (70% train, 30% test)...") | |
| train_samples, test_samples = split_samples(all_samples, train_ratio=0.7) | |
| logger.info("Training samples: %d", len(train_samples)) | |
| logger.info("Test samples: %d", len(test_samples)) | |
| # Build templates from training set | |
| logger.info("\n[Step 3] Building templates from training samples...") | |
| builder = DigitTemplateBuilder() | |
| for timestamp, clock_value, region, confidence in train_samples: | |
| builder.add_sample(region, clock_value, timestamp, confidence) | |
| # Coverage with dual-mode templates (center + right positions) | |
| coverage = builder.get_coverage_status() | |
| logger.info("Training coverage (dual-mode):") | |
| logger.info(" Ones (center): %s (missing: %s)", coverage["ones_center"], coverage["ones_center_missing"]) | |
| logger.info(" Ones (right): %s (missing: %s)", coverage["ones_right"], coverage["ones_right_missing"]) | |
| logger.info(" Tens (left): %s (missing: %s)", coverage["tens"], coverage["tens_missing"]) | |
| logger.info(" Blank: %s", "YES" if coverage["has_blank"] else "NO") | |
| library = builder.build_templates(min_samples=2) | |
| lib_status = library.get_coverage_status() | |
| logger.info("Templates built: %d/%d", lib_status["total_have"], lib_status["total_needed"]) | |
| # Test template matching on test set | |
| logger.info("\n[Step 4] Testing template matching accuracy...") | |
| template_reader = ReadPlayClock(library) | |
| correct = 0 | |
| wrong = 0 | |
| undetected = 0 | |
| error_results = [] # Store errors with region images for debug | |
| # Also measure timing | |
| template_times = [] | |
| for timestamp, ocr_value, region, ocr_confidence in test_samples: | |
| # Template matching | |
| t_start = time.perf_counter() | |
| template_result = template_reader.read(region) | |
| t_template = time.perf_counter() - t_start | |
| template_times.append(t_template) | |
| if template_result.detected and template_result.value is not None: | |
| if template_result.value == ocr_value: | |
| correct += 1 | |
| else: | |
| wrong += 1 | |
| error_results.append( | |
| { | |
| "timestamp": timestamp, | |
| "ocr_value": ocr_value, | |
| "template_value": template_result.value, | |
| "confidence": template_result.confidence, | |
| "status": "WRONG", | |
| "region": region, # Store region for debug image | |
| } | |
| ) | |
| else: | |
| undetected += 1 | |
| error_results.append( | |
| { | |
| "timestamp": timestamp, | |
| "ocr_value": ocr_value, | |
| "template_value": None, | |
| "confidence": template_result.confidence, | |
| "status": "UNDETECTED", | |
| "region": region, # Store region for debug image | |
| } | |
| ) | |
| total = correct + wrong + undetected | |
| accuracy = correct / total if total > 0 else 0 | |
| detection_rate = (correct + wrong) / total if total > 0 else 0 | |
| logger.info("\nAccuracy Results:") | |
| logger.info(" Correct: %d (%.1f%%)", correct, 100 * correct / total if total > 0 else 0) | |
| logger.info(" Wrong: %d (%.1f%%)", wrong, 100 * wrong / total if total > 0 else 0) | |
| logger.info(" Undetected: %d (%.1f%%)", undetected, 100 * undetected / total if total > 0 else 0) | |
| logger.info(" Accuracy (correct/total): %.1f%%", accuracy * 100) | |
| logger.info(" Detection rate: %.1f%%", detection_rate * 100) | |
| # Show error details | |
| if error_results: | |
| logger.info("\nError details:") | |
| for r in error_results[:10]: | |
| if r["status"] == "WRONG": | |
| logger.info(" WRONG @ t=%.1fs: OCR=%d, Template=%d, conf=%.2f", r["timestamp"], r["ocr_value"], r["template_value"], r["confidence"]) | |
| else: | |
| logger.info(" UNDETECTED @ t=%.1fs: OCR=%d, conf=%.2f", r["timestamp"], r["ocr_value"], r["confidence"]) | |
| # Save debug images if <= 10 total errors | |
| if len(error_results) > 0 and len(error_results) <= 10: | |
| logger.info("\n[Step 4.5] Saving debug images for %d errors...", len(error_results)) | |
| save_debug_images(error_results, DEBUG_DIR) | |
| elif len(error_results) > 10: | |
| logger.info("\nSkipping debug images: %d errors > 10 threshold", len(error_results)) | |
| # Timing comparison | |
| logger.info("\n[Step 5] Timing comparison...") | |
| avg_template_time = sum(template_times) / len(template_times) if template_times else 0 | |
| logger.info(" Template matching: %.3fms/frame", avg_template_time * 1000) | |
| logger.info(" EasyOCR (benchmark): ~48.9ms/frame") | |
| logger.info(" Speedup: ~%.0fx", 48.9 / (avg_template_time * 1000) if avg_template_time > 0 else 0) | |
| # Summary | |
| logger.info("\n" + "=" * 60) | |
| logger.info("TEST SUMMARY") | |
| logger.info("=" * 60) | |
| logger.info("Templates built: %d/%d (%.1f%%)", lib_status["total_have"], lib_status["total_needed"], 100 * lib_status["total_have"] / lib_status["total_needed"]) | |
| logger.info("Accuracy: %.1f%% (%d/%d correct)", accuracy * 100, correct, total) | |
| logger.info("Detection rate: %.1f%%", detection_rate * 100) | |
| logger.info("Speedup: ~%.0fx faster than OCR", 48.9 / (avg_template_time * 1000) if avg_template_time > 0 else 0) | |
| # Pass criteria: >= 95% accuracy | |
| passed = accuracy >= 0.95 or (accuracy >= 0.90 and lib_status["total_have"] < lib_status["total_needed"]) | |
| if passed: | |
| logger.info("\nTEST: PASSED") | |
| else: | |
| logger.info("\nTEST: FAILED (accuracy %.1f%% < 95%%)", accuracy * 100) | |
| # Save library for use in integration tests | |
| output_dir = Path("output/debug/digit_templates") | |
| output_dir.mkdir(parents=True, exist_ok=True) | |
| library.save(str(output_dir)) | |
| logger.info("\nTemplates saved to: %s", output_dir) | |
| if len(error_results) > 0 and len(error_results) <= 10: | |
| logger.info("Debug images saved to: %s", DEBUG_DIR) | |
| return passed | |
| if __name__ == "__main__": | |
| success = test_template_accuracy() | |
| sys.exit(0 if success else 1) | |