cfb40 / scripts /test_digit_extraction.py
andytaylor-smg's picture
moving stuff all around
6c65498
#!/usr/bin/env python3
"""
Test digit extraction from play clock regions.
This test verifies that:
1. Individual digits are correctly isolated from play clock regions
2. Color normalization correctly handles both red and white digits
3. Digit templates can be built from collected samples
Usage:
cd /Users/andytaylor/Documents/Personal/cfb40
source .venv/bin/activate
python tests/test_digit_templates/test_digit_extraction.py
"""
import logging
import sys
from pathlib import Path
import cv2
from detection import DetectScoreBug
from readers import normalize_to_grayscale, detect_red_digits
from setup import DigitTemplateBuilder, PlayClockRegionExtractor
# Root logger: timestamped records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Test configuration. These paths are relative to the current working
# directory, so run the script from the project root (see the module docstring).
# ---------------------------------------------------------------------------
VIDEO_PATH = "full_videos/OSU vs Tenn 12.21.24.mkv"
TEMPLATE_PATH = "output/OSU_vs_Tenn_12_21_24_template.png"
PLAYCLOCK_CONFIG_PATH = "output/OSU_vs_Tenn_12_21_24_playclock_config.json"

# Quick 3-minute test segment (38:40 -> 41:40, 5 plays), in seconds.
START_TIME = 2320  # 38:40 == 38 * 60 + 40
END_TIME = 2500  # 41:40 == 41 * 60 + 40
SAMPLE_INTERVAL = 0.5  # seconds between sampled frames
def _crop_playclock_region(frame, scorebug_bbox, pc_config):
    """Crop the play clock region positioned relative to a detected scorebug.

    The region's top-left corner is the scorebug bbox's top-left corner shifted
    by ``pc_config.x_offset`` / ``pc_config.y_offset``. Its size is
    ``pc_config.width`` x ``pc_config.height``.

    Returns:
        An independent copy of the region (not a view into ``frame``), or
        ``None`` when any part of the region lies outside the frame.
    """
    sb_x, sb_y, _, _ = scorebug_bbox
    pc_x = sb_x + pc_config.x_offset
    pc_y = sb_y + pc_config.y_offset
    pc_w = pc_config.width
    pc_h = pc_config.height

    frame_h, frame_w = frame.shape[:2]
    if pc_x < 0 or pc_y < 0 or pc_x + pc_w > frame_w or pc_y + pc_h > frame_h:
        return None
    return frame[pc_y : pc_y + pc_h, pc_x : pc_x + pc_w].copy()


def extract_samples_with_ocr(
    video_path: str,
    start_time: float,
    end_time: float,
    sample_interval: float,
    template_path=None,
    playclock_config_path=None,
):
    """
    Extract play clock samples using OCR for ground truth labeling.

    Frames are sampled every ``sample_interval`` seconds from ``start_time``
    (inclusive) up to ``end_time`` (exclusive). For each sampled frame where the
    scorebug is detected and the play clock region fits inside the frame, the
    region is cropped and read with ``PlayClockRegionExtractor.read``. Only
    frames with a successful reading are kept.

    Args:
        video_path: Path to the video file.
        start_time: Segment start, in seconds.
        end_time: Segment end, in seconds.
        sample_interval: Seconds between sampled frames. It is rounded to the
            nearest whole number of frames and never less than one frame.
        template_path: Scorebug template image. Defaults to ``TEMPLATE_PATH``.
        playclock_config_path: Play clock region config. Defaults to
            ``PLAYCLOCK_CONFIG_PATH``.

    Returns:
        List of ``(timestamp, clock_value, region_image, confidence)`` tuples.

    Raises:
        RuntimeError: If the video cannot be opened or reports a non-positive FPS.
    """
    if template_path is None:
        template_path = TEMPLATE_PATH
    if playclock_config_path is None:
        playclock_config_path = PLAYCLOCK_CONFIG_PATH

    # Initialize components
    scorebug_detector = DetectScoreBug(template_path=template_path)
    clock_reader = PlayClockRegionExtractor(region_config_path=playclock_config_path)

    cap = cv2.VideoCapture(video_path)
    # Release the capture on every exit path, including exceptions raised by
    # the detector or the OCR reader.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        logger.info("Video FPS: %.2f", fps)
        # OpenCV reports 0 for unsupported properties. Every frame/time
        # conversion below depends on fps, so refuse non-positive (or NaN) values.
        if not fps > 0:
            raise RuntimeError(f"Video reports an invalid FPS ({fps}): {video_path}")

        # Calculate frame positions
        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps)
        # Round to the nearest whole frame and never step by less than one.
        # A step of 0 would leave current_frame stuck and scan the rest of the
        # video with a frozen timestamp.
        frame_skip = max(1, round(sample_interval * fps))
        logger.info("Processing frames %d to %d (skip=%d)", start_frame, end_frame, frame_skip)

        # Lock the scorebug region once, from the first frame of the segment.
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        ret, frame = cap.read()
        if ret:
            scorebug_detector.discover_and_lock_region(frame)
            logger.info("Scorebug region locked: %s", scorebug_detector.fixed_region)
        else:
            logger.warning("Could not read frame %d; scorebug region not locked", start_frame)

        # Collect samples
        pc_config = clock_reader.config
        samples = []
        current_frame = start_frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        while current_frame < end_frame:
            ret, frame = cap.read()
            if not ret:
                break

            detection = scorebug_detector.detect(frame)
            if detection.detected and detection.bbox:
                region = _crop_playclock_region(frame, detection.bbox, pc_config)
                if region is not None:
                    # OCR the region for ground truth
                    reading = clock_reader.read(frame, detection.bbox)
                    if reading.detected and reading.value is not None:
                        samples.append((current_frame / fps, reading.value, region, reading.confidence))

            # cap.read() consumed one frame. Grab (without decoding) the rest of
            # the step so the next read lands on current_frame + frame_skip.
            for _ in range(frame_skip - 1):
                cap.grab()
            current_frame += frame_skip
    finally:
        cap.release()

    logger.info("Collected %d samples with OCR ground truth", len(samples))
    return samples
def _find_missing_inputs():
    """Return ``(label, path)`` for each required input file that does not exist."""
    required_inputs = (
        ("Video", VIDEO_PATH),
        ("Template", TEMPLATE_PATH),
        ("Playclock config", PLAYCLOCK_CONFIG_PATH),
    )
    return [(label, path) for label, path in required_inputs if not Path(path).exists()]


def _check_color_normalization(samples):
    """Classify samples as red/white digits and verify each normalizes to a non-empty image.

    Args:
        samples: ``(timestamp, clock_value, region, confidence)`` tuples, as
            returned by ``extract_samples_with_ocr``.

    Returns:
        ``False`` as soon as a region fails to normalize (after logging it).
        Otherwise ``True``, after logging the red/white counts.
    """
    red_count = 0
    white_count = 0
    for timestamp, clock_value, region, _ in samples:
        # Check if normalization detects red using the shared function
        if detect_red_digits(region):
            red_count += 1
        else:
            white_count += 1
        # Verify normalization produces valid output
        gray = normalize_to_grayscale(region)
        if gray is None or gray.size == 0:
            logger.error("Color normalization failed for clock=%d at t=%.1f", clock_value, timestamp)
            return False
    logger.info("Samples with red digits detected: %d", red_count)
    logger.info("Samples with white digits detected: %d", white_count)
    logger.info("Color normalization: PASSED (all samples normalized successfully)")
    return True


def _save_sample_images(builder, sample_dir):
    """Write the first collected sample image for each builder key to ``sample_dir``.

    Keys of ``builder.samples`` are ``(is_tens, digit, position)``, where
    ``digit == -1`` denotes the blank (left, single-digit) slot.

    Returns:
        Number of images actually written. Failed writes are logged and not counted.
    """
    sample_dir.mkdir(parents=True, exist_ok=True)
    written = 0
    for (is_tens, digit, position), digit_samples in builder.samples.items():
        if not digit_samples:
            continue
        type_str = "tens" if is_tens else "ones"
        digit_str = "blank" if digit == -1 else str(digit)
        out_path = sample_dir / f"{type_str}_{digit_str}_{position}_sample.png"
        # cv2.imwrite reports failure via its return value, not an exception.
        if cv2.imwrite(str(out_path), digit_samples[0].image):
            written += 1
        else:
            logger.warning("Failed to write sample image: %s", out_path)
    return written


def test_digit_extraction():
    """Test that digits are correctly extracted with color normalization.

    Steps:
        1. Collect OCR-labelled play clock regions from the test segment.
        2. Feed them to ``DigitTemplateBuilder`` and log per-position coverage.
        3. Verify every region passes color normalization.
        4. Build the template library and save templates and sample images
           under ``output/debug/digit_templates``.

    Returns:
        ``True`` when template coverage reaches 50%. ``False`` when an input
        file is missing, no samples were collected, color normalization fails,
        or coverage is too low.

    NOTE(review): pytest collects this function by name but ignores its return
    value (and warns about it). A ``False`` result therefore only fails the run
    when the file is executed as a script.
    """
    logger.info("=" * 60)
    logger.info("TEST: Digit Extraction with Color Normalization")
    logger.info("=" * 60)

    # Report every missing input, not just the first one.
    missing = _find_missing_inputs()
    if missing:
        for label, path in missing:
            logger.error("%s not found: %s", label, path)
        return False

    # Extract samples
    logger.info("\n[Step 1] Extracting samples with OCR...")
    samples = extract_samples_with_ocr(VIDEO_PATH, START_TIME, END_TIME, SAMPLE_INTERVAL)
    if not samples:
        logger.error("No samples collected; cannot build templates")
        return False
    if len(samples) < 50:
        logger.warning("Only %d samples collected, expected at least 50", len(samples))

    # Build templates with color normalization
    logger.info("\n[Step 2] Building digit templates (with color normalization)...")
    builder = DigitTemplateBuilder()
    for timestamp, clock_value, region, confidence in samples:
        builder.add_sample(region, clock_value, timestamp, confidence)

    # Check coverage - position-aware (center for single-digit, right for double-digit)
    coverage = builder.get_coverage_status()
    logger.info("\nSample Coverage (position-aware):")
    logger.info(" Ones (center, from 0-9): %s (missing: %s)", coverage["ones_center"], coverage["ones_center_missing"])
    logger.info(" Ones (right, from 10-40): %s (missing: %s)", coverage["ones_right"], coverage["ones_right_missing"])
    logger.info(" Tens (left, from 10-40): %s (missing: %s)", coverage["tens"], coverage["tens_missing"])
    logger.info(" Blank (left, from 0-9): %s", "YES" if coverage["has_blank"] else "NO")

    # Test color normalization on different clock values
    logger.info("\n[Step 3] Testing color normalization...")
    if not _check_color_normalization(samples):
        return False

    # Build templates
    logger.info("\n[Step 4] Building templates...")
    library = builder.build_templates(min_samples=2)  # Lower threshold for testing
    status = library.get_coverage_status()
    logger.info("\nTemplate Library Coverage: %d/%d", status["total_have"], status["total_needed"])
    logger.info(" Ones (center): %s (missing: %s)", status["ones_center_have"], status["ones_center_missing"])
    logger.info(" Ones (right): %s (missing: %s)", status["ones_right_have"], status["ones_right_missing"])
    logger.info(" Tens (left): %s (missing: %s)", status["tens_have"], status["tens_missing"])
    logger.info(" Blank: %s", "YES" if status["has_blank"] else "NO")

    # Save templates for inspection
    output_dir = Path("output/debug/digit_templates")
    output_dir.mkdir(parents=True, exist_ok=True)
    library.save(str(output_dir))
    logger.info("\nTemplates saved to: %s", output_dir)

    # Save one sample digit image per key for visual inspection
    sample_dir = output_dir / "samples"
    logger.info("\nSaving sample digit images for visual inspection...")
    sample_count = _save_sample_images(builder, sample_dir)
    logger.info("Saved %d sample images to: %s", sample_count, sample_dir)

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("TEST SUMMARY")
    logger.info("=" * 60)
    logger.info("Samples collected: %d", len(samples))
    logger.info("Templates built: %d/%d (25 total needed for dual-mode)", status["total_have"], status["total_needed"])
    logger.info("Color normalization: PASSED")

    # At least 50% coverage for basic test. An empty requirement set counts
    # as 0% rather than raising ZeroDivisionError.
    min_coverage = 0.5
    total_needed = status["total_needed"]
    actual_coverage = status["total_have"] / total_needed if total_needed else 0.0
    if actual_coverage >= min_coverage:
        logger.info("Coverage: PASSED (%.1f%% >= %.1f%%)", actual_coverage * 100, min_coverage * 100)
        return True
    logger.warning("Coverage: FAILED (%.1f%% < %.1f%%)", actual_coverage * 100, min_coverage * 100)
    return False
if __name__ == "__main__":
    # Map the boolean test outcome to a process exit status (0 = pass, 1 = fail)
    # so shell scripts and CI can branch on it.
    exit_code = 1
    if test_digit_extraction():
        exit_code = 0
    sys.exit(exit_code)