cfb40 / scripts /test_flag_overlap_fix.py
andytaylor-smg's picture
documented issues with oregon game, moving to fix now
6b03145
#!/usr/bin/env python3
"""
Test script to verify FLAG overlap fix in PlayMerger.
Tests that FLAG plays overlapping with the previous play are:
1. Trimmed to start after the previous play's end time
2. Removed if resulting duration would be too short
"""
import sys
from pathlib import Path
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from tracking.models import PlayEvent
from tracking.play_merger import PlayMerger
def create_play(play_number: int, start: float, end: float, play_type: str = "normal") -> PlayEvent:
    """Build a PlayEvent fixture for the merger tests.

    Only the play number, time window, and play type vary between tests.
    The confidence (0.9) and the start/end detection methods ("test") are
    fixed placeholder values.

    Args:
        play_number: Sequence number assigned to the play.
        start: Play start time, in seconds.
        end: Play end time, in seconds.
        play_type: Play category, e.g. "normal" or "flag".

    Returns:
        A new PlayEvent populated with the given values.
    """
    event_kwargs = {
        "play_number": play_number,
        "start_time": start,
        "end_time": end,
        "confidence": 0.9,
        "start_method": "test",
        "end_method": "test",
        "play_type": play_type,
    }
    return PlayEvent(**event_kwargs)
def test_flag_overlap_adjustment():
    """Check that a FLAG play starting inside the previous play is trimmed.

    The normal play spans 90-100s and the FLAG play spans 95-115s, so the
    first 5s of the FLAG play overlap. After merging, both plays must remain,
    and the FLAG play must begin exactly where the normal play ends (100.0s).
    """
    banner = "=" * 60
    print(banner)
    print("Test: FLAG overlap adjustment")
    print(banner)

    # Normal play ends at 100s; the FLAG play begins 5s earlier, at 95s.
    normal = create_play(1, start=90.0, end=100.0, play_type="normal")
    flag = create_play(2, start=95.0, end=115.0, play_type="flag")
    overlap = normal.end_time - flag.start_time

    print("\nBefore merge:")
    print(f" Normal play: {normal.start_time:.1f}s - {normal.end_time:.1f}s")
    print(f" FLAG play: {flag.start_time:.1f}s - {flag.end_time:.1f}s")
    print(f" Overlap: {overlap:.1f}s")

    merged = PlayMerger().merge([normal], [flag])

    print(f"\nAfter merge ({len(merged)} plays):")
    for result in merged:
        print(f" Play #{result.play_number} ({result.play_type}): {result.start_time:.1f}s - {result.end_time:.1f}s")

    # The FLAG play must survive, stay in second position, and start at 100s.
    assert len(merged) == 2, f"Expected 2 plays, got {len(merged)}"
    second = merged[1]
    assert second.play_type == "flag", "Second play should be FLAG"
    assert second.start_time == 100.0, f"FLAG start should be 100.0, got {second.start_time}"
    print("\n✓ PASS: FLAG play adjusted to start after previous play")
def test_flag_overlap_removal():
    """Check that a FLAG play left too short by trimming is dropped.

    The normal play spans 90-100s and the FLAG play spans 98-101s. Trimming
    the 2s overlap would leave only 1s of FLAG play, so the merger is
    expected to discard it and return only the normal play.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Test: FLAG overlap removal (too short after trim)")
    print(banner)

    # After trimming, only 1s of FLAG play would remain. This assumes the
    # merger's minimum FLAG duration is 2s; verify against PlayMerger.
    normal = create_play(1, start=90.0, end=100.0, play_type="normal")
    flag = create_play(2, start=98.0, end=101.0, play_type="flag")

    flag_duration = flag.end_time - flag.start_time
    overlap = normal.end_time - flag.start_time
    remaining = flag.end_time - normal.end_time

    print("\nBefore merge:")
    print(f" Normal play: {normal.start_time:.1f}s - {normal.end_time:.1f}s")
    print(f" FLAG play: {flag.start_time:.1f}s - {flag.end_time:.1f}s (duration: {flag_duration:.1f}s)")
    print(f" Overlap: {overlap:.1f}s")
    print(f" After trim: {remaining:.1f}s (would be too short)")

    merged = PlayMerger().merge([normal], [flag])

    print(f"\nAfter merge ({len(merged)} plays):")
    for result in merged:
        print(f" Play #{result.play_number} ({result.play_type}): {result.start_time:.1f}s - {result.end_time:.1f}s")

    # Only the normal play should be left once the FLAG play is discarded.
    assert len(merged) == 1, f"Expected 1 play (FLAG removed), got {len(merged)}"
    survivor = merged[0]
    assert survivor.play_type == "normal", "Only play should be normal"
    print("\n✓ PASS: FLAG play removed (too short after trim)")
def test_no_overlap():
    """Check that a FLAG play starting after the previous play is left alone.

    The normal play ends at 100s and the FLAG play starts at 105s, which
    leaves a 5s gap. The merger should keep both plays and must not move
    the FLAG start time.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Test: No overlap (FLAG starts after previous play)")
    print(banner)

    # The FLAG play begins 5s after the normal play ends, so nothing overlaps.
    normal = create_play(1, start=90.0, end=100.0, play_type="normal")
    flag = create_play(2, start=105.0, end=120.0, play_type="flag")
    gap = flag.start_time - normal.end_time

    print("\nBefore merge:")
    print(f" Normal play: {normal.start_time:.1f}s - {normal.end_time:.1f}s")
    print(f" FLAG play: {flag.start_time:.1f}s - {flag.end_time:.1f}s")
    print(f" Gap: {gap:.1f}s (no overlap)")

    merged = PlayMerger().merge([normal], [flag])

    print(f"\nAfter merge ({len(merged)} plays):")
    for result in merged:
        print(f" Play #{result.play_number} ({result.play_type}): {result.start_time:.1f}s - {result.end_time:.1f}s")

    # Both plays remain, and the FLAG start is untouched.
    assert len(merged) == 2, f"Expected 2 plays, got {len(merged)}"
    flag_start = merged[1].start_time
    assert flag_start == 105.0, f"FLAG start should be unchanged at 105.0, got {flag_start}"
    print("\n✓ PASS: Non-overlapping FLAG play unchanged")
def test_multiple_flags():
    """Check overlap handling when several FLAG plays are merged together.

    Play 2 (FLAG, 95-110s) overlaps play 1 (normal, 90-100s) by 5s and should
    be trimmed to start at 100s. Play 3 (FLAG, 120-135s) overlaps nothing and
    should keep its 120s start.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Test: Multiple FLAG plays with various overlaps")
    print(banner)

    plays = [
        create_play(1, start=90.0, end=100.0, play_type="normal"),
        create_play(2, start=95.0, end=110.0, play_type="flag"),  # overlaps play 1 by 5s -> trimmed
        create_play(3, start=120.0, end=135.0, play_type="flag"),  # clear of play 2 -> unchanged
    ]

    print("\nBefore merge:")
    for event in plays:
        print(f" Play #{event.play_number} ({event.play_type}): {event.start_time:.1f}s - {event.end_time:.1f}s")

    # Unlike the other tests, every play is passed to merge() in one list.
    merged = PlayMerger().merge(plays)

    print(f"\nAfter merge ({len(merged)} plays):")
    for result in merged:
        print(f" Play #{result.play_number} ({result.play_type}): {result.start_time:.1f}s - {result.end_time:.1f}s")

    assert len(merged) == 3, f"Expected 3 plays, got {len(merged)}"
    first_flag, second_flag = merged[1], merged[2]
    assert first_flag.start_time == 100.0, f"First FLAG should be trimmed to 100.0, got {first_flag.start_time}"
    assert second_flag.start_time == 120.0, f"Second FLAG should be unchanged at 120.0, got {second_flag.start_time}"
    print("\n✓ PASS: Multiple FLAG plays handled correctly")
if __name__ == "__main__":
    # Run every test in order. A failing assert propagates and stops the run
    # before the final "ALL TESTS PASSED" banner is printed.
    rule = "=" * 60
    print("\n" + rule)
    print(" FLAG OVERLAP FIX TEST SUITE")
    print(rule)
    for run_case in (
        test_flag_overlap_adjustment,
        test_flag_overlap_removal,
        test_no_overlap,
        test_multiple_flags,
    ):
        run_case()
    print("\n" + rule)
    print(" ALL TESTS PASSED")
    print(rule + "\n")