Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Test digit extraction from play clock regions. | |
| This test verifies that: | |
| 1. Individual digits are correctly isolated from play clock regions | |
| 2. Color normalization correctly handles both red and white digits | |
| 3. Digit templates can be built from collected samples | |
| Usage: | |
| cd /Users/andytaylor/Documents/Personal/cfb40 | |
| source .venv/bin/activate | |
| python tests/test_digit_templates/test_digit_extraction.py | |
| """ | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import cv2 | |
| from detection import DetectScoreBug | |
| from readers import normalize_to_grayscale, detect_red_digits | |
| from setup import DigitTemplateBuilder, PlayClockRegionExtractor | |
# Emit timestamped, level-tagged records at INFO and above so each test step is visible.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Test inputs. All paths are relative and resolve against the working directory.
# ---------------------------------------------------------------------------
VIDEO_PATH = "full_videos/OSU vs Tenn 12.21.24.mkv"
TEMPLATE_PATH = "output/OSU_vs_Tenn_12_21_24_template.png"
PLAYCLOCK_CONFIG_PATH = "output/OSU_vs_Tenn_12_21_24_playclock_config.json"

# Segment under test: 38:40 through 41:40 (3-min quick test with 5 plays).
START_TIME = 38 * 60 + 40  # 2320 seconds
END_TIME = 41 * 60 + 40  # 2500 seconds

# Spacing between sampled frames, in seconds.
SAMPLE_INTERVAL = 0.5
def _crop_play_clock_region(frame, scorebug_bbox, pc_config):
    """Return a copy of the play clock pixels, or None if the region falls outside the frame."""
    sb_x, sb_y, _, _ = scorebug_bbox
    # The play clock position is stored as an offset from the scorebug's top-left corner.
    pc_x = sb_x + pc_config.x_offset
    pc_y = sb_y + pc_config.y_offset
    pc_w = pc_config.width
    pc_h = pc_config.height

    frame_h, frame_w = frame.shape[:2]
    if pc_x < 0 or pc_y < 0 or pc_x + pc_w > frame_w or pc_y + pc_h > frame_h:
        return None
    # Copy so the sample does not keep the whole frame buffer alive.
    return frame[pc_y : pc_y + pc_h, pc_x : pc_x + pc_w].copy()


def extract_samples_with_ocr(video_path: str, start_time: float, end_time: float, sample_interval: float) -> list:
    """
    Extract play clock samples using OCR for ground truth labeling.

    Args:
        video_path: Path of the video file to open with OpenCV.
        start_time: Start of the segment to sample, in seconds.
        end_time: End of the segment (exclusive), in seconds.
        sample_interval: Spacing between sampled frames, in seconds. Intervals
            shorter than one frame are clamped to one frame.

    Returns:
        List of (timestamp, clock_value, region_image, confidence) tuples. A tuple
        is produced for each sampled frame where the scorebug was detected, the
        play clock region lay inside the frame, and the reader reported a value.
        ``timestamp`` is the frame index divided by the reported FPS.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    # Initialize components
    scorebug_detector = DetectScoreBug(template_path=TEMPLATE_PATH)
    clock_reader = PlayClockRegionExtractor(region_config_path=PLAYCLOCK_CONFIG_PATH)

    samples = []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        logger.info("Video FPS: %.2f", fps)

        # Calculate frame positions
        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps)
        # A skip of 0 would never advance current_frame, so the loop would ignore
        # end_frame and stamp every sample with the same time. Step at least one frame.
        frame_skip = max(1, int(sample_interval * fps))
        logger.info("Processing frames %d to %d (skip=%d)", start_frame, end_frame, frame_skip)

        # Lock scorebug region on first detection
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        ret, frame = cap.read()
        if ret:
            scorebug_detector.discover_and_lock_region(frame)
            logger.info("Scorebug region locked: %s", scorebug_detector.fixed_region)
        else:
            logger.warning("Could not read frame %d; scorebug region not locked", start_frame)

        # Collect samples, rewinding to the start because the lock step consumed a frame.
        current_frame = start_frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        while current_frame < end_frame:
            ret, frame = cap.read()
            if not ret:
                break
            current_time = current_frame / fps

            detection = scorebug_detector.detect(frame)
            if detection.detected and detection.bbox:
                region = _crop_play_clock_region(frame, detection.bbox, clock_reader.config)
                if region is not None:
                    # OCR the region for ground truth
                    reading = clock_reader.read(frame, detection.bbox)
                    if reading.detected and reading.value is not None:
                        samples.append((current_time, reading.value, region, reading.confidence))

            # read() already consumed one frame; grab() advances past the rest without decoding.
            for _ in range(frame_skip - 1):
                cap.grab()
            current_frame += frame_skip
    finally:
        # Release the capture even when detection or reading raises mid-loop.
        cap.release()

    logger.info("Collected %d samples with OCR ground truth", len(samples))
    return samples
def test_digit_extraction():
    """
    Test that digits are correctly extracted with color normalization.

    Steps: verify the input files exist, collect labeled play clock samples,
    feed them to a DigitTemplateBuilder, check that every sample normalizes to
    a non-empty grayscale image, build and save the template library, and save
    one sample image per digit key for visual inspection.

    Returns:
        True when template coverage reaches at least 50%. False when an input
        file is missing, normalization fails for any sample, or coverage is
        below the threshold.
    """
    logger.info("=" * 60)
    logger.info("TEST: Digit Extraction with Color Normalization")
    logger.info("=" * 60)

    # Check that every required input exists before doing any expensive work.
    required_inputs = (
        ("Video not found: %s", VIDEO_PATH),
        ("Template not found: %s", TEMPLATE_PATH),
        ("Playclock config not found: %s", PLAYCLOCK_CONFIG_PATH),
    )
    for missing_message, input_path in required_inputs:
        if not Path(input_path).exists():
            logger.error(missing_message, input_path)
            return False

    # Extract samples
    logger.info("\n[Step 1] Extracting samples with OCR...")
    samples = extract_samples_with_ocr(VIDEO_PATH, START_TIME, END_TIME, SAMPLE_INTERVAL)
    if len(samples) < 50:
        # Advisory only: the run continues and the coverage check decides pass/fail.
        logger.warning("Only %d samples collected, expected at least 50", len(samples))

    # Build templates with color normalization
    logger.info("\n[Step 2] Building digit templates (with color normalization)...")
    builder = DigitTemplateBuilder()
    for timestamp, clock_value, region, confidence in samples:
        builder.add_sample(region, clock_value, timestamp, confidence)

    # Check coverage - now position-aware (center for single-digit, right for double-digit)
    coverage = builder.get_coverage_status()
    logger.info("\nSample Coverage (position-aware):")
    logger.info(" Ones (center, from 0-9): %s (missing: %s)", coverage["ones_center"], coverage["ones_center_missing"])
    logger.info(" Ones (right, from 10-40): %s (missing: %s)", coverage["ones_right"], coverage["ones_right_missing"])
    logger.info(" Tens (left, from 10-40): %s (missing: %s)", coverage["tens"], coverage["tens_missing"])
    logger.info(" Blank (left, from 0-9): %s", "YES" if coverage["has_blank"] else "NO")

    # Test color normalization on different clock values
    logger.info("\n[Step 3] Testing color normalization...")
    red_count = 0
    white_count = 0
    for timestamp, clock_value, region, _ in samples:
        # Anything not flagged as red by the shared detector is counted as white.
        if detect_red_digits(region):
            red_count += 1
        else:
            white_count += 1

        # Verify normalization produces valid output
        gray = normalize_to_grayscale(region)
        if gray is None or gray.size == 0:
            logger.error("Color normalization failed for clock=%d at t=%.1f", clock_value, timestamp)
            return False

    logger.info("Samples with red digits detected: %d", red_count)
    logger.info("Samples with white digits detected: %d", white_count)
    logger.info("Color normalization: PASSED (all samples normalized successfully)")

    # Build templates
    logger.info("\n[Step 4] Building templates...")
    library = builder.build_templates(min_samples=2)  # Lower threshold for testing
    status = library.get_coverage_status()
    logger.info("\nTemplate Library Coverage: %d/%d", status["total_have"], status["total_needed"])
    logger.info(" Ones (center): %s (missing: %s)", status["ones_center_have"], status["ones_center_missing"])
    logger.info(" Ones (right): %s (missing: %s)", status["ones_right_have"], status["ones_right_missing"])
    logger.info(" Tens (left): %s (missing: %s)", status["tens_have"], status["tens_missing"])
    logger.info(" Blank: %s", "YES" if status["has_blank"] else "NO")

    # Save templates for inspection
    output_dir = Path("output/debug/digit_templates")
    output_dir.mkdir(parents=True, exist_ok=True)
    library.save(str(output_dir))
    logger.info("\nTemplates saved to: %s", output_dir)

    # Save some sample digit images for visual inspection
    sample_dir = output_dir / "samples"
    sample_dir.mkdir(exist_ok=True)
    logger.info("\nSaving sample digit images for visual inspection...")
    sample_count = 0
    for (is_tens, digit, position), digit_samples in builder.samples.items():
        if len(digit_samples) == 0:
            continue
        sample = digit_samples[0]  # First sample
        type_str = "tens" if is_tens else "ones"
        digit_str = "blank" if digit == -1 else str(digit)
        sample_path = sample_dir / f"{type_str}_{digit_str}_{position}_sample.png"
        # cv2.imwrite signals failure through its return value, so only count
        # the files that were actually written.
        if cv2.imwrite(str(sample_path), sample.image):
            sample_count += 1
        else:
            logger.warning("Failed to write sample image: %s", sample_path)
    logger.info("Saved %d sample images to: %s", sample_count, sample_dir)

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("TEST SUMMARY")
    logger.info("=" * 60)
    logger.info("Samples collected: %d", len(samples))
    logger.info("Templates built: %d/%d (25 total needed for dual-mode)", status["total_have"], status["total_needed"])
    logger.info("Color normalization: PASSED")

    # Check if we have enough coverage
    min_coverage = 0.5  # At least 50% coverage for basic test
    total_needed = status["total_needed"]
    # A library that reports nothing needed counts as zero coverage instead of
    # raising ZeroDivisionError.
    actual_coverage = status["total_have"] / total_needed if total_needed else 0.0
    if actual_coverage >= min_coverage:
        logger.info("Coverage: PASSED (%.1f%% >= %.1f%%)", actual_coverage * 100, min_coverage * 100)
        return True

    logger.warning("Coverage: FAILED (%.1f%% < %.1f%%)", actual_coverage * 100, min_coverage * 100)
    return False
if __name__ == "__main__":
    # Map the boolean test outcome to a process exit status: 0 on pass, 1 on fail.
    if test_digit_extraction():
        sys.exit(0)
    sys.exit(1)