Spaces:
Running
Running
| from __future__ import annotations | |
| import csv | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| import yaml | |
| from bat_tracker.pipeline import run_pipeline | |
| def _write_video(path: Path, frames: list[np.ndarray], fps: int = 10) -> None: | |
| height, width = frames[0].shape | |
| writer = cv2.VideoWriter( | |
| str(path), | |
| cv2.VideoWriter_fourcc(*"mp4v"), | |
| float(fps), | |
| (width, height), | |
| ) | |
| assert writer.isOpened(), f"could not open writer for {path}" | |
| for frame in frames: | |
| writer.write(cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)) | |
| writer.release() | |
| def _read_tracks(path: Path) -> list[dict[str, str]]: | |
| with path.open(newline="", encoding="utf-8") as handle: | |
| return list(csv.DictReader(handle)) | |
| def _normalize_tracks(rows: list[dict[str, str]]) -> list[tuple[str, ...]]: | |
| keys = [ | |
| "track_id", | |
| "frame", | |
| "time_sec", | |
| "x", | |
| "y", | |
| "vx", | |
| "vy", | |
| "bbox_x1", | |
| "bbox_y1", | |
| "bbox_x2", | |
| "bbox_y2", | |
| "area", | |
| ] | |
| return [tuple(row[key] for key in keys) for row in rows] | |
| def _filter_rows_to_frame_range( | |
| rows: list[dict[str, str]], | |
| *, | |
| start_frame: int, | |
| end_frame: int, | |
| ) -> list[dict[str, str]]: | |
| return [row for row in rows if start_frame <= int(row["frame"]) <= end_frame] | |
| def _base_config() -> dict: | |
| return { | |
| "background": { | |
| "sample_frames": 20, | |
| "uniform_sampling": True, | |
| }, | |
| "detection": { | |
| "blur_kernel": 1, | |
| "threshold_mode": "fixed", | |
| "diff_threshold": 10, | |
| "morph_open": 1, | |
| "morph_close": 1, | |
| "min_area": 20, | |
| "max_area": 5000, | |
| "max_global_intensity_shift": -1.0, | |
| "max_foreground_ratio": -1.0, | |
| "max_detections_per_frame": 0, | |
| "temporal_burst_min_detections": 0, | |
| "temporal_burst_window_frames": 0, | |
| "temporal_burst_trigger_frames": 0, | |
| "temporal_burst_cooldown_frames": 0, | |
| }, | |
| "tracking": { | |
| "max_distance": 30, | |
| "max_missed": 2, | |
| "min_track_length": 1, | |
| "min_track_displacement": 0.0, | |
| "min_track_path_length": 0.0, | |
| "min_track_straightness": 0.0, | |
| "min_track_duration_sec": 0.0, | |
| "auto_merge_suggested": False, | |
| "require_start_or_end_in_valid_region": False, | |
| "valid_region_gate_dilate_px": 0, | |
| }, | |
| "valid_region": { | |
| "enabled": False, | |
| }, | |
| "output": { | |
| "progress_enabled": False, | |
| "export_track_clips": False, | |
| }, | |
| } | |
def _make_prefix_clip_case(tmp_path: Path) -> tuple[Path, Path]:
    """Write a 40-frame synthetic video plus a clip of its first 12 frames.

    Frames are 48x64 uint8. In frames 0-11 a filled rectangle of intensity
    220 spans (x0, 18)-(x0 + 18, 32) with x0 = 18 + frame index, so it moves
    one pixel right per frame; frames 12-39 are all black.

    Returns:
        ``(full_path, clip_path)`` — ``tmp_path / "full.mp4"`` and
        ``tmp_path / "clip.mp4"``.
    """
    total_frames = 40
    moving_frames = 12
    full_frames: list[np.ndarray] = [
        np.zeros((48, 64), dtype=np.uint8) for _ in range(total_frames)
    ]
    for index in range(moving_frames):
        left = 18 + index
        cv2.rectangle(full_frames[index], (left, 18), (left + 18, 32), 220, -1)
    # The clip gets independent copies of the moving-object prefix.
    clip_frames: list[np.ndarray] = [frame.copy() for frame in full_frames[:moving_frames]]
    full_path = tmp_path / "full.mp4"
    clip_path = tmp_path / "clip.mp4"
    _write_video(full_path, full_frames)
    _write_video(clip_path, clip_frames)
    return full_path, clip_path
def test_clip_can_reuse_full_background_for_prefix_reproducibility(tmp_path: Path) -> None:
    """Reusing the full run's background makes the clip reproduce its tracks.

    The clip contains only frames in which the object is present, so its own
    run (presumably via a different background estimate) yields tracks that
    differ from the full run. Re-running the clip with the full run's
    ``background.png`` injected through ``background.input_image`` must
    reproduce the full run's tracks exactly.
    """
    full_path, clip_path = _make_prefix_clip_case(tmp_path)
    cfg = _base_config()
    cfg_path = tmp_path / "cfg.yaml"
    cfg_path.write_text(yaml.safe_dump(cfg), encoding="utf-8")
    out_full = tmp_path / "out_full"
    out_clip = tmp_path / "out_clip"
    out_clip_reused = tmp_path / "out_clip_reused"
    # Baseline runs: each video is processed with the shared config.
    run_pipeline(str(full_path), str(out_full), str(cfg_path))
    run_pipeline(str(clip_path), str(out_clip), str(cfg_path))
    full_background = out_full / "background.png"
    assert full_background.exists(), "full run did not export background.png"
    # Re-run the clip with the full run's background image injected.
    cfg_reused = _base_config()
    cfg_reused["background"]["input_image"] = str(full_background)
    cfg_reused_path = tmp_path / "cfg_reused.yaml"
    cfg_reused_path.write_text(yaml.safe_dump(cfg_reused), encoding="utf-8")
    run_pipeline(str(clip_path), str(out_clip_reused), str(cfg_reused_path))
    full_tracks = _read_tracks(out_full / "tracks.csv")
    clip_tracks = _read_tracks(out_clip / "tracks.csv")
    reused_tracks = _read_tracks(out_clip_reused / "tracks.csv")
    # Without this guard an empty full run plus an empty reused run would
    # satisfy the equality below vacuously.
    assert full_tracks, "full run produced no tracks; comparison would be vacuous"
    assert _normalize_tracks(clip_tracks) != _normalize_tracks(full_tracks)
    assert _normalize_tracks(reused_tracks) == _normalize_tracks(full_tracks)
def test_precomputed_valid_region_mask_is_loaded_verbatim(tmp_path: Path) -> None:
    """A mask supplied via ``valid_region.input_mask`` is exported unchanged.

    The run's metadata must report the ``input_mask`` method and the resolved
    mask path.
    """
    _, clip_path = _make_prefix_clip_case(tmp_path)
    # Vertical band of valid columns 20..43 over the full 48x64 frame.
    mask = np.zeros((48, 64), dtype=np.uint8)
    mask[:, 20:44] = 255
    mask_path = tmp_path / "mask.png"
    # Check the write result explicitly (kept outside `assert` so the write
    # itself still happens under `python -O`).
    written = cv2.imwrite(str(mask_path), mask)
    assert written, f"failed to write test mask to {mask_path}"
    cfg = _base_config()
    cfg["valid_region"] = {
        "enabled": True,
        "input_mask": str(mask_path),
        "apply_to_detection": False,
    }
    cfg_path = tmp_path / "cfg_mask.yaml"
    cfg_path.write_text(yaml.safe_dump(cfg), encoding="utf-8")
    out_dir = tmp_path / "out_mask"
    meta = run_pipeline(str(clip_path), str(out_dir), str(cfg_path))
    exported_mask = cv2.imread(str(out_dir / "valid_region" / "mask.png"), cv2.IMREAD_GRAYSCALE)
    assert exported_mask is not None
    assert np.array_equal(exported_mask, mask)
    assert meta["valid_region"]["method"] == "input_mask"
    assert meta["valid_region"]["input_mask"] == str(mask_path.resolve())
def test_prefix_context_window_makes_full_and_clip_match(tmp_path: Path) -> None:
    """A shared background context window makes full-video and clip tracks agree.

    Both runs use ``context_start_sec=0.0`` and ``context_duration_sec=1.2``.
    At the 10 fps default of ``_write_video``, that window covers exactly the
    12 clip frames. The full run's tracks restricted to frames 0-11 must
    equal the clip's tracks.
    """
    full_path, clip_path = _make_prefix_clip_case(tmp_path)
    cfg = _base_config()
    cfg["background"]["context_start_sec"] = 0.0
    cfg["background"]["context_duration_sec"] = 1.2
    cfg_path = tmp_path / "cfg_prefix.yaml"
    cfg_path.write_text(yaml.safe_dump(cfg), encoding="utf-8")
    out_full = tmp_path / "out_full_prefix"
    out_clip = tmp_path / "out_clip_prefix"
    run_pipeline(str(full_path), str(out_full), str(cfg_path))
    run_pipeline(str(clip_path), str(out_clip), str(cfg_path))
    full_tracks = _read_tracks(out_full / "tracks.csv")
    clip_tracks = _read_tracks(out_clip / "tracks.csv")
    # Frames 0..11 are the frames shared by the full video and the clip.
    full_prefix_tracks = _filter_rows_to_frame_range(full_tracks, start_frame=0, end_frame=11)
    # Without this guard two empty outputs would satisfy the equality below.
    assert clip_tracks, "clip run produced no tracks; comparison would be vacuous"
    assert _normalize_tracks(full_prefix_tracks) == _normalize_tracks(clip_tracks)
def test_valid_region_can_use_dedicated_context_without_changing_detection_background(tmp_path: Path) -> None:
    """The valid-region stage accepts its own context window.

    The background context duration is set to -1.0 while the valid region
    uses a 0-1.2 s window. The run must still report the background setting
    as -1.0 and the valid region as enabled. It must also export a non-empty
    ``valid_region/mask.png`` and a ``valid_region/gate_overlay.png``.
    """
    video_path = _make_prefix_clip_case(tmp_path)[0]
    config = _base_config()
    config["background"].update(context_start_sec=0.0, context_duration_sec=-1.0)
    config["valid_region"] = dict(
        enabled=True,
        method="horizontal_illumination_profile",
        apply_to_detection=False,
        context_start_sec=0.0,
        context_duration_sec=1.2,
        blur_kernel_size=31,
        profile_smooth_window=9,
        threshold_ratio=0.4,
        safety_margin=0,
        min_region_width_ratio=0.2,
    )
    config_path = tmp_path / "cfg_vr_context.yaml"
    config_path.write_text(yaml.safe_dump(config), encoding="utf-8")
    result_dir = tmp_path / "out_vr_context"
    meta = run_pipeline(str(video_path), str(result_dir), str(config_path))

    assert meta["background"]["context_duration_sec"] == -1.0
    assert meta["valid_region"]["enabled"] is True

    region_dir = result_dir / "valid_region"
    exported = cv2.imread(str(region_dir / "mask.png"), cv2.IMREAD_GRAYSCALE)
    assert exported is not None
    assert int(np.count_nonzero(exported)) > 0
    assert (region_dir / "gate_overlay.png").exists()