File size: 7,638 Bytes
231135a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from __future__ import annotations

import csv
from pathlib import Path

import cv2
import numpy as np
import yaml

from bat_tracker.pipeline import run_pipeline


def _write_video(path: Path, frames: list[np.ndarray], fps: int = 10) -> None:
    height, width = frames[0].shape
    writer = cv2.VideoWriter(
        str(path),
        cv2.VideoWriter_fourcc(*"mp4v"),
        float(fps),
        (width, height),
    )
    assert writer.isOpened(), f"could not open writer for {path}"
    for frame in frames:
        writer.write(cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
    writer.release()


def _read_tracks(path: Path) -> list[dict[str, str]]:
    with path.open(newline="", encoding="utf-8") as handle:
        return list(csv.DictReader(handle))


def _normalize_tracks(rows: list[dict[str, str]]) -> list[tuple[str, ...]]:
    keys = [
        "track_id",
        "frame",
        "time_sec",
        "x",
        "y",
        "vx",
        "vy",
        "bbox_x1",
        "bbox_y1",
        "bbox_x2",
        "bbox_y2",
        "area",
    ]
    return [tuple(row[key] for key in keys) for row in rows]


def _filter_rows_to_frame_range(
    rows: list[dict[str, str]],
    *,
    start_frame: int,
    end_frame: int,
) -> list[dict[str, str]]:
    return [row for row in rows if start_frame <= int(row["frame"]) <= end_frame]


def _base_config() -> dict:
    return {
        "background": {
            "sample_frames": 20,
            "uniform_sampling": True,
        },
        "detection": {
            "blur_kernel": 1,
            "threshold_mode": "fixed",
            "diff_threshold": 10,
            "morph_open": 1,
            "morph_close": 1,
            "min_area": 20,
            "max_area": 5000,
            "max_global_intensity_shift": -1.0,
            "max_foreground_ratio": -1.0,
            "max_detections_per_frame": 0,
            "temporal_burst_min_detections": 0,
            "temporal_burst_window_frames": 0,
            "temporal_burst_trigger_frames": 0,
            "temporal_burst_cooldown_frames": 0,
        },
        "tracking": {
            "max_distance": 30,
            "max_missed": 2,
            "min_track_length": 1,
            "min_track_displacement": 0.0,
            "min_track_path_length": 0.0,
            "min_track_straightness": 0.0,
            "min_track_duration_sec": 0.0,
            "auto_merge_suggested": False,
            "require_start_or_end_in_valid_region": False,
            "valid_region_gate_dilate_px": 0,
        },
        "valid_region": {
            "enabled": False,
        },
        "output": {
            "progress_enabled": False,
            "export_track_clips": False,
        },
    }


def _make_prefix_clip_case(tmp_path: Path) -> tuple[Path, Path]:
    """Write a 40-frame "full" video and a "clip" made of its first 12 frames.

    Each 48x64 grayscale frame is black. In frames 0-11 a filled rectangle of
    intensity 220 is drawn, shifted one pixel right per frame. The clip holds
    copies of exactly those 12 frames.

    Returns:
        ``(full_path, clip_path)`` inside *tmp_path*.
    """
    total_frames, moving_frames = 40, 12

    full_frames: list[np.ndarray] = []
    for frame_idx in range(total_frames):
        canvas = np.zeros((48, 64), dtype=np.uint8)
        if frame_idx < moving_frames:
            left = 18 + frame_idx
            cv2.rectangle(canvas, (left, 18), (left + 18, 32), 220, -1)
        full_frames.append(canvas)

    # Independent copies so the clip never aliases the full video's arrays.
    clip_frames: list[np.ndarray] = [frame.copy() for frame in full_frames[:moving_frames]]

    full_path = tmp_path / "full.mp4"
    clip_path = tmp_path / "clip.mp4"
    _write_video(full_path, full_frames)
    _write_video(clip_path, clip_frames)
    return full_path, clip_path


def test_clip_can_reuse_full_background_for_prefix_reproducibility(tmp_path: Path) -> None:
    """Reusing the full video's background image lets a prefix clip reproduce its tracks.

    Three runs share the same base config:
    the full video, the clip with its own background, and the clip with
    ``background.input_image`` pointing at the ``background.png`` written by
    the full run. The clip's own background must yield different tracks. The
    reused background must yield identical ones.
    """
    full_path, clip_path = _make_prefix_clip_case(tmp_path)

    cfg = _base_config()
    cfg_path = tmp_path / "cfg.yaml"
    cfg_path.write_text(yaml.safe_dump(cfg), encoding="utf-8")

    out_full = tmp_path / "out_full"
    out_clip = tmp_path / "out_clip"
    out_clip_reused = tmp_path / "out_clip_reused"

    run_pipeline(str(full_path), str(out_full), str(cfg_path))
    run_pipeline(str(clip_path), str(out_clip), str(cfg_path))

    cfg_reused = _base_config()
    cfg_reused["background"]["input_image"] = str(out_full / "background.png")
    cfg_reused_path = tmp_path / "cfg_reused.yaml"
    cfg_reused_path.write_text(yaml.safe_dump(cfg_reused), encoding="utf-8")
    run_pipeline(str(clip_path), str(out_clip_reused), str(cfg_reused_path))

    full_tracks = _read_tracks(out_full / "tracks.csv")
    clip_tracks = _read_tracks(out_clip / "tracks.csv")
    reused_tracks = _read_tracks(out_clip_reused / "tracks.csv")

    # Guard against a vacuous pass: if the full run found nothing, the
    # equality check below could succeed on two empty track lists.
    assert full_tracks, "full-video run produced no tracks; comparison would be vacuous"
    assert _normalize_tracks(clip_tracks) != _normalize_tracks(full_tracks)
    assert _normalize_tracks(reused_tracks) == _normalize_tracks(full_tracks)


def test_precomputed_valid_region_mask_is_loaded_verbatim(tmp_path: Path) -> None:
    """A user-supplied valid-region mask is exported unchanged and recorded in metadata."""
    clip_path = _make_prefix_clip_case(tmp_path)[1]

    # Valid region is the vertical band of columns 20..43 (inclusive).
    expected_mask = np.zeros((48, 64), dtype=np.uint8)
    expected_mask[:, 20:44] = 255
    mask_file = tmp_path / "mask.png"
    cv2.imwrite(str(mask_file), expected_mask)

    config = _base_config()
    config["valid_region"] = {"enabled": True, "input_mask": str(mask_file), "apply_to_detection": False}
    config_file = tmp_path / "cfg_mask.yaml"
    config_file.write_text(yaml.safe_dump(config), encoding="utf-8")

    output_dir = tmp_path / "out_mask"
    metadata = run_pipeline(str(clip_path), str(output_dir), str(config_file))

    loaded_mask = cv2.imread(str(output_dir / "valid_region" / "mask.png"), cv2.IMREAD_GRAYSCALE)
    assert loaded_mask is not None
    assert np.array_equal(loaded_mask, expected_mask)
    region_meta = metadata["valid_region"]
    assert region_meta["method"] == "input_mask"
    assert region_meta["input_mask"] == str(mask_file.resolve())


def test_prefix_context_window_makes_full_and_clip_match(tmp_path: Path) -> None:
    """A fixed background context window makes the full video and clip agree on frames 0-11.

    Both runs take their background from the same window
    (``context_start_sec=0.0``, ``context_duration_sec=1.2``). The clip is the
    first 12 frames of the full video, written at the default 10 fps. So the
    full run's tracks restricted to frames 0-11 must equal the clip's tracks.
    """
    full_path, clip_path = _make_prefix_clip_case(tmp_path)

    cfg = _base_config()
    cfg["background"]["context_start_sec"] = 0.0
    cfg["background"]["context_duration_sec"] = 1.2
    cfg_path = tmp_path / "cfg_prefix.yaml"
    cfg_path.write_text(yaml.safe_dump(cfg), encoding="utf-8")

    out_full = tmp_path / "out_full_prefix"
    out_clip = tmp_path / "out_clip_prefix"
    run_pipeline(str(full_path), str(out_full), str(cfg_path))
    run_pipeline(str(clip_path), str(out_clip), str(cfg_path))

    full_tracks = _read_tracks(out_full / "tracks.csv")
    clip_tracks = _read_tracks(out_clip / "tracks.csv")
    # Guard against a vacuous pass: two empty track lists would compare equal.
    assert clip_tracks, "clip run produced no tracks; prefix comparison would be vacuous"
    full_prefix_tracks = _filter_rows_to_frame_range(full_tracks, start_frame=0, end_frame=11)
    assert _normalize_tracks(full_prefix_tracks) == _normalize_tracks(clip_tracks)


def test_valid_region_can_use_dedicated_context_without_changing_detection_background(tmp_path: Path) -> None:
    """The valid-region estimator can use its own context window.

    The background section is configured with ``context_duration_sec=-1.0``
    while the valid region uses a 1.2 s window. The metadata must still report
    -1.0 for the background. A non-empty mask and a gate overlay must be
    exported.
    """
    video_path, _ = _make_prefix_clip_case(tmp_path)

    config = _base_config()
    config["background"].update(context_start_sec=0.0, context_duration_sec=-1.0)
    config["valid_region"] = {
        "enabled": True,
        "method": "horizontal_illumination_profile",
        "apply_to_detection": False,
        "context_start_sec": 0.0,
        "context_duration_sec": 1.2,
        "blur_kernel_size": 31,
        "profile_smooth_window": 9,
        "threshold_ratio": 0.4,
        "safety_margin": 0,
        "min_region_width_ratio": 0.2,
    }
    config_file = tmp_path / "cfg_vr_context.yaml"
    config_file.write_text(yaml.safe_dump(config), encoding="utf-8")

    output_dir = tmp_path / "out_vr_context"
    metadata = run_pipeline(str(video_path), str(output_dir), str(config_file))

    assert metadata["background"]["context_duration_sec"] == -1.0
    assert metadata["valid_region"]["enabled"] is True

    region_dir = output_dir / "valid_region"
    region_mask = cv2.imread(str(region_dir / "mask.png"), cv2.IMREAD_GRAYSCALE)
    assert region_mask is not None
    assert int(np.count_nonzero(region_mask)) > 0
    assert (region_dir / "gate_overlay.png").exists()