# cfb40 / src/setup/template_builder.py
# Author: andytaylor-smg
# Commit 719b8f7: "perfect mypy"
"""
Digit template builder for creating play clock digit templates from OCR samples.
This module provides the DigitTemplateBuilder class which collects samples from
the play clock region, extracts individual digits, and builds averaged templates
for each unique digit value.
Region extraction and preprocessing utilities are shared from utils to eliminate code duplication.
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
from utils import (
extract_center_region,
extract_far_left_region,
extract_left_region,
extract_right_region,
preprocess_playclock_region,
)
from .coverage import ONES_DIGITS, categorize_template_keys
from .models import DigitSample, DigitTemplate
from .template_library import DigitTemplateLibrary
# Logger named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
class DigitTemplateBuilder:
    """
    Builds digit templates from OCR-labeled play clock samples.

    Collects samples from the play clock region, extracts individual digits,
    and builds averaged templates for each unique digit value.
    Uses color normalization so red and white digits produce the same template.

    Samples are grouped by ``(is_tens, digit_value, position)``:

    - Double-digit clocks (10-40): tens digit at ``"left"``, ones digit at ``"right"``.
    - Single-digit clocks (0-9): blank tens slot (``BLANK_DIGIT``) at ``"left"``,
      ones digit at ``"center"``.
    """

    # Play clock region dimensions (from config)
    DEFAULT_REGION_WIDTH = 50
    DEFAULT_REGION_HEIGHT = 28

    # Digit value used for the empty tens slot of a single-digit clock.
    BLANK_DIGIT = -1
    # Tens digits that a 0-40 play clock can display.
    TENS_DIGITS = frozenset({1, 2, 3, 4})

    def __init__(self, region_width: int = DEFAULT_REGION_WIDTH, region_height: int = DEFAULT_REGION_HEIGHT):
        """
        Initialize the template builder.

        Args:
            region_width: Width of play clock region in pixels
            region_height: Height of play clock region in pixels
        """
        self.region_width = region_width
        self.region_height = region_height
        # Collected samples: {(is_tens, digit_value, position): [DigitSample, ...]}
        self.samples: Dict[Tuple[bool, int, str], List[DigitSample]] = {}
        # Raw clock region copies kept for potential reprocessing: (timestamp, clock_value, region)
        self.raw_regions: List[Tuple[float, int, np.ndarray[Any, Any]]] = []
        logger.info("DigitTemplateBuilder initialized (region: %dx%d)", region_width, region_height)

    def preprocess_region(self, region: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
        """
        Preprocess play clock region for template extraction.

        Delegates to the shared ``preprocess_playclock_region`` utility with a
        4x scale factor.

        Args:
            region: Play clock region (BGR format)

        Returns:
            Preprocessed binary image (white digits on black background)
        """
        return preprocess_playclock_region(region, scale_factor=4)

    def extract_digits(
        self, preprocessed: np.ndarray[Any, Any], clock_value: int
    ) -> Tuple[Optional[np.ndarray[Any, Any]], Optional[np.ndarray[Any, Any]], Optional[np.ndarray[Any, Any]], Optional[np.ndarray[Any, Any]]]:
        """
        Extract individual digit images from preprocessed play clock region.

        For double-digit values (10-40): extracts left (tens) and right (ones)
        For single-digit values (0-9): extracts far-left (blank) and center (ones)

        Args:
            preprocessed: Preprocessed play clock image (scaled 4x)
            clock_value: The known clock value (0-40)

        Returns:
            Tuple of (tens_digit_image, ones_right_image, ones_center_image, blank_image)
            - For double-digit: tens=left, ones_right=right, ones_center=None, blank=None
            - For single-digit: tens=None, ones_right=None, ones_center=center, blank=far_left
        """
        if clock_value >= 10:
            # Double-digit: standard left/right split
            return extract_left_region(preprocessed), extract_right_region(preprocessed), None, None
        # Single-digit: far-left is blank (truly empty), ones is centered
        return None, None, extract_center_region(preprocessed), extract_far_left_region(preprocessed)

    def add_sample(self, region: np.ndarray[Any, Any], clock_value: int, timestamp: float, confidence: float = 1.0) -> None:
        """
        Add a play clock sample for template building.

        Routes samples based on display layout:
        - Single-digit (0-9): Digit is CENTER-aligned, tens position is blank
        - Double-digit (10-40): Tens on LEFT, ones on RIGHT

        Values outside 0-40 are logged and ignored (nothing is stored).

        Args:
            region: Play clock region (BGR format, original size)
            clock_value: OCR-determined clock value (0-40)
            timestamp: Video timestamp
            confidence: OCR confidence score
        """
        if clock_value < 0 or clock_value > 40:
            logger.warning("Invalid clock value %d, skipping sample", clock_value)
            return

        # Keep a private copy of the raw region for potential reprocessing, so
        # later changes to the caller's array do not affect the stored sample.
        self.raw_regions.append((timestamp, clock_value, region.copy()))

        # Preprocess (handles red-to-white conversion automatically)
        preprocessed = self.preprocess_region(region)
        tens_img, ones_right_img, ones_center_img, blank_img = self.extract_digits(preprocessed, clock_value)
        ones_digit = clock_value % 10

        if clock_value >= 10:
            # Double-digit display (10-40): tens on left, ones on right.
            # extract_digits returns both images for this layout; the asserts
            # narrow the Optional types for mypy.
            assert tens_img is not None
            assert ones_right_img is not None
            tens_digit = clock_value // 10
            self._append_sample(
                DigitSample(
                    digit_value=tens_digit,
                    is_tens_digit=True,
                    position="left",
                    image=tens_img,
                    source_clock_value=clock_value,
                    timestamp=timestamp,
                    confidence=confidence,
                )
            )
            self._append_sample(
                DigitSample(
                    digit_value=ones_digit,
                    is_tens_digit=False,
                    position="right",
                    image=ones_right_img,
                    source_clock_value=clock_value,
                    timestamp=timestamp,
                    confidence=confidence,
                )
            )
            logger.debug(
                "Added double-digit sample: clock=%d, tens=%d (left), ones=%d (right), t=%.1f",
                clock_value,
                tens_digit,
                ones_digit,
                timestamp,
            )
            return

        # Single-digit display (0-9): digit is centered, tens position is blank.
        # extract_digits returns both images for this layout; the asserts narrow
        # the Optional types for mypy.
        assert blank_img is not None
        assert ones_center_img is not None
        self._append_sample(
            DigitSample(
                digit_value=self.BLANK_DIGIT,
                is_tens_digit=True,
                position="left",  # "left" is kept as the blank's position key for compatibility
                image=blank_img,  # far-left region, which should be truly empty
                source_clock_value=clock_value,
                timestamp=timestamp,
                confidence=confidence,
            )
        )
        self._append_sample(
            DigitSample(
                digit_value=ones_digit,
                is_tens_digit=False,
                position="center",
                image=ones_center_img,
                source_clock_value=clock_value,
                timestamp=timestamp,
                confidence=confidence,
            )
        )
        logger.debug(
            "Added single-digit sample: clock=%d, ones=%d (center), blank (far-left), t=%.1f",
            clock_value,
            ones_digit,
            timestamp,
        )

    def _append_sample(self, sample: "DigitSample") -> None:
        """Store ``sample`` under its ``(is_tens_digit, digit_value, position)`` key."""
        # Deriving the key from the sample keeps the key and the sample in agreement.
        key = (sample.is_tens_digit, sample.digit_value, sample.position)
        self.samples.setdefault(key, []).append(sample)

    @classmethod
    def _digit_label(cls, digit: int) -> str:
        """Return the display label for a digit value (``"blank"`` for BLANK_DIGIT)."""
        return "blank" if digit == cls.BLANK_DIGIT else str(digit)

    def get_sample_count(self) -> Dict[str, int]:
        """
        Get count of samples collected for each digit and position.

        Returns:
            Mapping like ``{"tens_blank_left": 4, "ones_7_center": 2, ...}``
            (format: ``"{tens|ones}_{digit|blank}_{position}"``).
        """
        counts: Dict[str, int] = {}
        for (is_tens, digit, position), samples in self.samples.items():
            type_str = "tens" if is_tens else "ones"
            counts[f"{type_str}_{self._digit_label(digit)}_{position}"] = len(samples)
        return counts

    def build_templates(self, min_samples: int = 3) -> "DigitTemplateLibrary":
        """
        Build templates from collected samples.

        For each digit/position combination with at least ``min_samples``
        samples, averages the samples and binarizes the result to create a
        robust template. Combinations with too few (or no) samples are skipped.

        Args:
            min_samples: Minimum samples required to build a template (default: 3)

        Returns:
            DigitTemplateLibrary with built templates
        """
        library = DigitTemplateLibrary()

        for (is_tens, digit, position), samples in self.samples.items():
            type_str = "tens" if is_tens else "ones"
            digit_display = self._digit_label(digit)

            if len(samples) < min_samples:
                logger.warning(
                    "Insufficient samples for %s digit %s (%s): %d < %d",
                    type_str,
                    digit_display,
                    position,
                    len(samples),
                    min_samples,
                )
                continue
            if not samples:
                # Only reachable when min_samples <= 0: nothing to average.
                continue

            sample_count = len(samples)
            library.add_template(
                DigitTemplate(
                    digit_value=digit,
                    is_tens_digit=is_tens,
                    position=position,
                    template=self._average_samples(samples),
                    sample_count=sample_count,
                    avg_confidence=sum(sample.confidence for sample in samples) / sample_count,
                )
            )
            logger.info(
                "Built template: %s digit %s (%s) from %d samples",
                type_str,
                digit_display,
                position,
                sample_count,
            )

        # Log coverage status (guarding against an empty requirement set)
        coverage = library.get_coverage_status()
        total_have = coverage["total_have"]
        total_needed = coverage["total_needed"]
        percent = 100 * total_have / total_needed if total_needed else 0.0
        logger.info("Template coverage: %d/%d (%.1f%%)", total_have, total_needed, percent)
        return library

    @staticmethod
    def _average_samples(samples: List["DigitSample"]) -> np.ndarray[Any, Any]:
        """
        Average sample images and binarize the result.

        Every image is resized to the first sample's shape before summing.
        ``samples`` must be non-empty.
        """
        target_shape = samples[0].image.shape
        sum_image = np.zeros(target_shape, dtype=np.float32)
        for sample in samples:
            img = sample.image
            if img.shape != target_shape:
                # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
                img = cv2.resize(img, (target_shape[1], target_shape[0]))
            sum_image += img.astype(np.float32)

        avg_image = (sum_image / len(samples)).astype(np.uint8)
        # Threshold the averaged image to clean it up. With 0/255 binary inputs
        # this acts roughly as a per-pixel majority vote.
        _, template_img = cv2.threshold(avg_image, 127, 255, cv2.THRESH_BINARY)
        return template_img

    def get_coverage_status(self) -> Dict[str, Any]:
        """
        Get current sample coverage status.

        Returns:
            Dict with sorted lists ``ones_center``, ``ones_right``, ``tens`` of
            digits that have at least one sample, the ``*_missing`` counterparts,
            and ``has_blank`` (whether a blank tens sample exists).
        """
        # Keys for samples that have at least one entry
        keys_with_samples = [key for key, samples in self.samples.items() if samples]
        # Use shared utility to categorize
        ones_center_have, ones_right_have, tens_have, has_blank = categorize_template_keys(keys_with_samples)
        return {
            "ones_center": sorted(ones_center_have),
            "ones_right": sorted(ones_right_have),
            "tens": sorted(tens_have),
            "has_blank": has_blank,
            "ones_center_missing": sorted(ONES_DIGITS - ones_center_have),
            "ones_right_missing": sorted(ONES_DIGITS - ones_right_have),
            "tens_missing": sorted(self.TENS_DIGITS - tens_have),
        }

    def get_coverage_estimate(self) -> float:
        """
        Get a simple coverage estimate as a float (0.0-1.0).

        Returns:
            Coverage estimate where 1.0 = all templates have samples
        """
        status = self.get_coverage_status()
        # Count what we have (with at least 1 sample each)
        total_have = len(status["ones_center"]) + len(status["ones_right"]) + len(status["tens"])
        if status["has_blank"]:
            total_have += 1
        # Total needed: ones_center + ones_right + tens + 1 blank
        # (10 + 10 + 4 + 1 = 25 for ones digits 0-9 and tens digits 1-4)
        total_needed = 2 * len(ONES_DIGITS) + len(self.TENS_DIGITS) + 1
        return total_have / total_needed