# bat_tracker/tests/test_reusable_context.py
# kerojohan
# Fix Space after bat_tracker sync
# 231135a
from __future__ import annotations
import csv
from pathlib import Path
import cv2
import numpy as np
import yaml
from bat_tracker.pipeline import run_pipeline
def _write_video(path: Path, frames: list[np.ndarray], fps: int = 10) -> None:
    """Encode grayscale ``frames`` into an ``mp4v`` video file at ``path``.

    Args:
        path: Destination file for the encoded video.
        frames: Non-empty list of 2-D single-channel frames. The shape of the
            first frame determines the video size.
        fps: Frame rate handed to the writer (converted to ``float``).

    Raises:
        AssertionError: If the OpenCV writer could not be opened for ``path``.
    """
    height, width = frames[0].shape
    writer = cv2.VideoWriter(
        str(path),
        cv2.VideoWriter_fourcc(*"mp4v"),
        float(fps),
        (width, height),
    )
    # Release in ``finally`` so a failed open or a failing write never leaves
    # the writer handle dangling.
    try:
        assert writer.isOpened(), f"could not open writer for {path}"
        for frame in frames:
            # Frames are converted from grayscale to BGR before being written.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
    finally:
        writer.release()
def _read_tracks(path: Path) -> list[dict[str, str]]:
    """Load the CSV file at ``path`` and return one dict per data row."""
    records: list[dict[str, str]] = []
    with path.open(newline="", encoding="utf-8") as stream:
        # The header row supplies the dict keys; every value stays a string.
        records.extend(csv.DictReader(stream))
    return records
def _normalize_tracks(rows: list[dict[str, str]]) -> list[tuple[str, ...]]:
    """Project every track row onto a fixed column order.

    Each row becomes a tuple holding only the columns listed below, in that
    order; any other columns present in ``rows`` are ignored. A row that lacks
    one of the columns raises ``KeyError``.
    """
    columns = (
        "track_id",
        "frame",
        "time_sec",
        "x",
        "y",
        "vx",
        "vy",
        "bbox_x1",
        "bbox_y1",
        "bbox_x2",
        "bbox_y2",
        "area",
    )
    normalized: list[tuple[str, ...]] = []
    for record in rows:
        normalized.append(tuple(record[column] for column in columns))
    return normalized
def _filter_rows_to_frame_range(
    rows: list[dict[str, str]],
    *,
    start_frame: int,
    end_frame: int,
) -> list[dict[str, str]]:
    """Keep the rows whose ``frame`` value lies in ``[start_frame, end_frame]``.

    The ``frame`` field is parsed with ``int``; both bounds are inclusive and
    the input order of the surviving rows is preserved.
    """
    selected: list[dict[str, str]] = []
    for record in rows:
        frame_index = int(record["frame"])
        if frame_index < start_frame or frame_index > end_frame:
            continue
        selected.append(record)
    return selected
def _base_config() -> dict:
    """Return a freshly built baseline configuration for the pipeline tests.

    A new nested dict is created on every call, so a test can modify its copy
    without affecting any other caller.
    """
    background = {
        "sample_frames": 20,
        "uniform_sampling": True,
    }
    detection = {
        "blur_kernel": 1,
        "threshold_mode": "fixed",
        "diff_threshold": 10,
        "morph_open": 1,
        "morph_close": 1,
        "min_area": 20,
        "max_area": 5000,
        "max_global_intensity_shift": -1.0,
        "max_foreground_ratio": -1.0,
        "max_detections_per_frame": 0,
        "temporal_burst_min_detections": 0,
        "temporal_burst_window_frames": 0,
        "temporal_burst_trigger_frames": 0,
        "temporal_burst_cooldown_frames": 0,
    }
    tracking = {
        "max_distance": 30,
        "max_missed": 2,
        "min_track_length": 1,
        "min_track_displacement": 0.0,
        "min_track_path_length": 0.0,
        "min_track_straightness": 0.0,
        "min_track_duration_sec": 0.0,
        "auto_merge_suggested": False,
        "require_start_or_end_in_valid_region": False,
        "valid_region_gate_dilate_px": 0,
    }
    valid_region = {
        "enabled": False,
    }
    output = {
        "progress_enabled": False,
        "export_track_clips": False,
    }
    return {
        "background": background,
        "detection": detection,
        "tracking": tracking,
        "valid_region": valid_region,
        "output": output,
    }
def _make_prefix_clip_case(tmp_path: Path) -> tuple[Path, Path]:
    """Write a synthetic full video and a clip of its leading frames.

    The full video has 40 frames of 48x64 pixels. In the first 12 frames a
    rectangle of intensity 220 is drawn, shifted one pixel to the right per
    frame; the remaining frames stay all-zero. The clip consists of copies of
    those first 12 frames.

    Returns:
        ``(full_path, clip_path)`` pointing at ``full.mp4`` and ``clip.mp4``
        inside ``tmp_path``.
    """
    total_frames = 40
    prefix_frames = 12

    full_frames: list[np.ndarray] = []
    for idx in range(total_frames):
        canvas = np.zeros((48, 64), dtype=np.uint8)
        if idx < prefix_frames:
            left = 18 + idx
            cv2.rectangle(canvas, (left, 18), (left + 18, 32), 220, -1)
        full_frames.append(canvas)

    # The clip is exactly the prefix of the full video, as independent arrays.
    clip_frames: list[np.ndarray] = [
        canvas.copy() for canvas in full_frames[:prefix_frames]
    ]

    full_path = tmp_path / "full.mp4"
    clip_path = tmp_path / "clip.mp4"
    _write_video(full_path, full_frames)
    _write_video(clip_path, clip_frames)
    return full_path, clip_path
def test_clip_can_reuse_full_background_for_prefix_reproducibility(tmp_path: Path) -> None:
    """Clip tracks equal the full run's tracks only with the full background.

    The clip is processed twice: once with the plain base config and once with
    ``background.input_image`` pointing at the ``background.png`` written by
    the full-video run. Only the second run must reproduce the full tracks.
    """
    full_video, clip_video = _make_prefix_clip_case(tmp_path)

    plain_cfg_file = tmp_path / "cfg.yaml"
    plain_cfg_file.write_text(yaml.safe_dump(_base_config()), encoding="utf-8")

    full_out = tmp_path / "out_full"
    clip_out = tmp_path / "out_clip"
    reused_out = tmp_path / "out_clip_reused"

    run_pipeline(str(full_video), str(full_out), str(plain_cfg_file))
    run_pipeline(str(clip_video), str(clip_out), str(plain_cfg_file))

    # Third run: same clip, but with the background image of the full run.
    reused_cfg = _base_config()
    reused_cfg["background"]["input_image"] = str(full_out / "background.png")
    reused_cfg_file = tmp_path / "cfg_reused.yaml"
    reused_cfg_file.write_text(yaml.safe_dump(reused_cfg), encoding="utf-8")
    run_pipeline(str(clip_video), str(reused_out), str(reused_cfg_file))

    full_rows = _read_tracks(full_out / "tracks.csv")
    clip_rows = _read_tracks(clip_out / "tracks.csv")
    reused_rows = _read_tracks(reused_out / "tracks.csv")

    expected = _normalize_tracks(full_rows)
    assert _normalize_tracks(clip_rows) != expected
    assert _normalize_tracks(reused_rows) == expected
def test_precomputed_valid_region_mask_is_loaded_verbatim(tmp_path: Path) -> None:
    """A mask given via ``valid_region.input_mask`` is exported unchanged.

    Also checks that the returned metadata reports the ``input_mask`` method
    and the resolved path of the supplied mask file.
    """
    _, clip_video = _make_prefix_clip_case(tmp_path)

    # Mask with a vertical band (columns 20..43) set to 255.
    expected_mask = np.zeros((48, 64), dtype=np.uint8)
    expected_mask[:, 20:44] = 255
    mask_file = tmp_path / "mask.png"
    cv2.imwrite(str(mask_file), expected_mask)

    config = _base_config()
    config["valid_region"] = dict(
        enabled=True,
        input_mask=str(mask_file),
        apply_to_detection=False,
    )
    config_file = tmp_path / "cfg_mask.yaml"
    config_file.write_text(yaml.safe_dump(config), encoding="utf-8")

    output_dir = tmp_path / "out_mask"
    meta = run_pipeline(str(clip_video), str(output_dir), str(config_file))

    exported_file = output_dir / "valid_region" / "mask.png"
    exported_mask = cv2.imread(str(exported_file), cv2.IMREAD_GRAYSCALE)
    assert exported_mask is not None
    assert np.array_equal(exported_mask, expected_mask)

    valid_region_meta = meta["valid_region"]
    assert valid_region_meta["method"] == "input_mask"
    assert meta["valid_region"]["input_mask"] == str(mask_file.resolve())
def test_prefix_context_window_makes_full_and_clip_match(tmp_path: Path) -> None:
    """With a shared background context window, full and clip tracks agree.

    Both videos are processed with ``background.context_start_sec = 0.0`` and
    ``background.context_duration_sec = 1.2``; the full-video tracks limited
    to frames 0..11 must equal the clip tracks.
    """
    full_video, clip_video = _make_prefix_clip_case(tmp_path)

    config = _base_config()
    config["background"].update(
        context_start_sec=0.0,
        context_duration_sec=1.2,
    )
    config_file = tmp_path / "cfg_prefix.yaml"
    config_file.write_text(yaml.safe_dump(config), encoding="utf-8")

    full_out = tmp_path / "out_full_prefix"
    clip_out = tmp_path / "out_clip_prefix"
    run_pipeline(str(full_video), str(full_out), str(config_file))
    run_pipeline(str(clip_video), str(clip_out), str(config_file))

    full_rows = _read_tracks(full_out / "tracks.csv")
    clip_rows = _read_tracks(clip_out / "tracks.csv")

    # Only the frames that also exist in the clip are comparable.
    full_prefix_rows = _filter_rows_to_frame_range(
        full_rows,
        start_frame=0,
        end_frame=11,
    )
    assert _normalize_tracks(full_prefix_rows) == _normalize_tracks(clip_rows)
def test_valid_region_can_use_dedicated_context_without_changing_detection_background(tmp_path: Path) -> None:
    """The valid region may use its own context window settings.

    The background section is configured with ``context_duration_sec = -1.0``
    while the valid-region section asks for a 1.2 second window. The test
    checks that the metadata still reports ``-1.0`` for the background, that
    the valid region is enabled, that the exported mask has non-zero pixels,
    and that ``gate_overlay.png`` is written.
    """
    full_video, _ = _make_prefix_clip_case(tmp_path)

    config = _base_config()
    config["background"].update(
        context_start_sec=0.0,
        context_duration_sec=-1.0,
    )
    config["valid_region"] = dict(
        enabled=True,
        method="horizontal_illumination_profile",
        apply_to_detection=False,
        context_start_sec=0.0,
        context_duration_sec=1.2,
        blur_kernel_size=31,
        profile_smooth_window=9,
        threshold_ratio=0.4,
        safety_margin=0,
        min_region_width_ratio=0.2,
    )
    config_file = tmp_path / "cfg_vr_context.yaml"
    config_file.write_text(yaml.safe_dump(config), encoding="utf-8")

    output_dir = tmp_path / "out_vr_context"
    meta = run_pipeline(str(full_video), str(output_dir), str(config_file))

    assert meta["background"]["context_duration_sec"] == -1.0
    assert meta["valid_region"]["enabled"] is True

    valid_region_dir = output_dir / "valid_region"
    exported_mask = cv2.imread(str(valid_region_dir / "mask.png"), cv2.IMREAD_GRAYSCALE)
    assert exported_mask is not None
    assert int(np.count_nonzero(exported_mask)) > 0

    overlay_file = valid_region_dir / "gate_overlay.png"
    assert overlay_file.exists()