kerojohan commited on
Commit
737952c
·
1 Parent(s): 0b4d993

Sync algorithm changes with bat_tracker v1.1.5

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. bat_tracker/pipeline.py +38 -17
  3. tests/test_track_exports.py +123 -1
app.py CHANGED
@@ -13,7 +13,7 @@ import yaml
13
  from bat_tracker.pipeline import run_pipeline
14
 
15
 
16
- APP_VERSION = "v1.1.4"
17
  APP_TITLE = f"Bat Tracker {APP_VERSION}"
18
  APP_DESCRIPTION = (
19
  "Sube un video IR monocromo para ejecutar el pipeline, revisar la region valida "
 
13
  from bat_tracker.pipeline import run_pipeline
14
 
15
 
16
+ APP_VERSION = "v1.1.5"
17
  APP_TITLE = f"Bat Tracker {APP_VERSION}"
18
  APP_DESCRIPTION = (
19
  "Sube un video IR monocromo para ejecutar el pipeline, revisar la region valida "
bat_tracker/pipeline.py CHANGED
@@ -491,7 +491,7 @@ def _filter_track_points(
491
  if not accepted:
492
  reasons_set = set(reject_reasons)
493
  if (
494
- reasons_set.issubset({"min_track_length", "valid_region_gate"})
495
  and len(track_points) >= min_track_length_from_sec
496
  and score >= strong_short_score_min
497
  ):
@@ -538,6 +538,14 @@ def _vector_cosine(v0: tuple[float, float], v1: tuple[float, float]) -> float |
538
  return (v0[0] * v1[0] + v0[1] * v1[1]) / (n0 * n1)
539
 
540
 
 
 
 
 
 
 
 
 
541
  def _auto_merge_track_points(points: List[TrackPoint], tracking_cfg: Dict) -> tuple[List[TrackPoint], List[Dict]]:
542
  if not bool(tracking_cfg.get("auto_merge_suggested", False)):
543
  return points, []
@@ -556,6 +564,7 @@ def _auto_merge_track_points(points: List[TrackPoint], tracking_cfg: Dict) -> tu
556
  min_overlap_common = int(tracking_cfg.get("merge_overlap_min_common_frames", 3))
557
  max_overlap_mean_dist = float(tracking_cfg.get("merge_overlap_max_mean_distance", 60.0))
558
  min_overlap_cos = float(tracking_cfg.get("merge_overlap_min_direction_cosine", 0.8))
 
559
 
560
  parent: Dict[int, int] = {track_id: track_id for track_id in by_track}
561
 
@@ -581,16 +590,14 @@ def _auto_merge_track_points(points: List[TrackPoint], tracking_cfg: Dict) -> tu
581
  a_pts = by_track[track_a_id]
582
  a_start = a_pts[0]
583
  a_end = a_pts[-1]
584
- a_start_vec = (a_pts[min(2, len(a_pts) - 1)].x - a_pts[0].x, a_pts[min(2, len(a_pts) - 1)].y - a_pts[0].y)
585
- a_end_vec = (a_pts[-1].x - a_pts[max(0, len(a_pts) - 3)].x, a_pts[-1].y - a_pts[max(0, len(a_pts) - 3)].y)
586
  a_frames = {p.frame: p for p in a_pts}
587
 
588
  for track_b_id in track_ids[idx + 1 :]:
589
  b_pts = by_track[track_b_id]
590
  b_start = b_pts[0]
591
  b_end = b_pts[-1]
592
- b_start_vec = (b_pts[min(2, len(b_pts) - 1)].x - b_pts[0].x, b_pts[min(2, len(b_pts) - 1)].y - b_pts[0].y)
593
- b_end_vec = (b_pts[-1].x - b_pts[max(0, len(b_pts) - 3)].x, b_pts[-1].y - b_pts[max(0, len(b_pts) - 3)].y)
594
 
595
  reason = None
596
  reason_data: Dict[str, float | int] = {}
@@ -610,28 +617,42 @@ def _auto_merge_track_points(points: List[TrackPoint], tracking_cfg: Dict) -> tu
610
  else:
611
  b_frames = {p.frame: p for p in b_pts}
612
  common_frames = sorted(set(a_frames.keys()).intersection(b_frames.keys()))
613
- if len(common_frames) >= min_overlap_common:
614
  distances = []
615
- cosines = []
616
  for frame in common_frames:
617
  pa = a_frames[frame]
618
  pb = b_frames[frame]
619
  distances.append(hypot(pa.x - pb.x, pa.y - pb.y))
620
 
621
  mean_distance = sum(distances) / len(distances)
622
- c0 = _vector_cosine(a_start_vec, b_start_vec)
623
- c1 = _vector_cosine(a_end_vec, b_end_vec)
624
- if c0 is not None:
625
- cosines.append(c0)
626
- if c1 is not None:
627
- cosines.append(c1)
628
- mean_cos = (sum(cosines) / len(cosines)) if cosines else None
629
- if mean_distance <= max_overlap_mean_dist and (mean_cos is None or mean_cos >= min_overlap_cos):
630
- reason = "overlap"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
631
  reason_data = {
632
  "common_frames": len(common_frames),
633
  "mean_distance": mean_distance,
634
- "mean_direction_cosine": mean_cos if mean_cos is not None else 1.0,
635
  }
636
 
637
  if reason is None:
 
491
  if not accepted:
492
  reasons_set = set(reject_reasons)
493
  if (
494
+ reasons_set.issubset({"min_track_length"})
495
  and len(track_points) >= min_track_length_from_sec
496
  and score >= strong_short_score_min
497
  ):
 
538
  return (v0[0] * v1[0] + v0[1] * v1[1]) / (n0 * n1)
539
 
540
 
541
+ def _track_edge_vectors(points: List[TrackPoint]) -> tuple[tuple[float, float], tuple[float, float]]:
542
+ start_idx = min(2, len(points) - 1)
543
+ end_idx = max(0, len(points) - 3)
544
+ start_vec = (points[start_idx].x - points[0].x, points[start_idx].y - points[0].y)
545
+ end_vec = (points[-1].x - points[end_idx].x, points[-1].y - points[end_idx].y)
546
+ return start_vec, end_vec
547
+
548
+
549
  def _auto_merge_track_points(points: List[TrackPoint], tracking_cfg: Dict) -> tuple[List[TrackPoint], List[Dict]]:
550
  if not bool(tracking_cfg.get("auto_merge_suggested", False)):
551
  return points, []
 
564
  min_overlap_common = int(tracking_cfg.get("merge_overlap_min_common_frames", 3))
565
  max_overlap_mean_dist = float(tracking_cfg.get("merge_overlap_max_mean_distance", 60.0))
566
  min_overlap_cos = float(tracking_cfg.get("merge_overlap_min_direction_cosine", 0.8))
567
+ local_overlap_min_cos = max(0.65, min_overlap_cos - 0.15)
568
 
569
  parent: Dict[int, int] = {track_id: track_id for track_id in by_track}
570
 
 
590
  a_pts = by_track[track_a_id]
591
  a_start = a_pts[0]
592
  a_end = a_pts[-1]
593
+ a_start_vec, a_end_vec = _track_edge_vectors(a_pts)
 
594
  a_frames = {p.frame: p for p in a_pts}
595
 
596
  for track_b_id in track_ids[idx + 1 :]:
597
  b_pts = by_track[track_b_id]
598
  b_start = b_pts[0]
599
  b_end = b_pts[-1]
600
+ b_start_vec, b_end_vec = _track_edge_vectors(b_pts)
 
601
 
602
  reason = None
603
  reason_data: Dict[str, float | int] = {}
 
617
  else:
618
  b_frames = {p.frame: p for p in b_pts}
619
  common_frames = sorted(set(a_frames.keys()).intersection(b_frames.keys()))
620
+ if len(common_frames) >= 2:
621
  distances = []
 
622
  for frame in common_frames:
623
  pa = a_frames[frame]
624
  pb = b_frames[frame]
625
  distances.append(hypot(pa.x - pb.x, pa.y - pb.y))
626
 
627
  mean_distance = sum(distances) / len(distances)
628
+ start_cos = _vector_cosine(a_start_vec, b_start_vec)
629
+ end_cos = _vector_cosine(a_end_vec, b_end_vec)
630
+ connector_cos = _vector_cosine(a_end_vec, b_start_vec)
631
+ global_cosines = [c for c in (start_cos, end_cos) if c is not None]
632
+ mean_cos = (sum(global_cosines) / len(global_cosines)) if global_cosines else None
633
+
634
+ overlap_reason = None
635
+ if len(common_frames) >= min_overlap_common:
636
+ if mean_distance <= max_overlap_mean_dist and (
637
+ mean_cos is None or mean_cos >= min_overlap_cos or connector_cos is not None and connector_cos >= min_overlap_cos
638
+ ):
639
+ overlap_reason = "overlap"
640
+ elif (
641
+ mean_distance <= max_overlap_mean_dist
642
+ and connector_cos is not None
643
+ and connector_cos >= local_overlap_min_cos
644
+ ):
645
+ overlap_reason = "overlap_local"
646
+
647
+ if overlap_reason is not None:
648
+ direction_score = connector_cos
649
+ if direction_score is None:
650
+ direction_score = mean_cos if mean_cos is not None else 1.0
651
+ reason = overlap_reason
652
  reason_data = {
653
  "common_frames": len(common_frames),
654
  "mean_distance": mean_distance,
655
+ "mean_direction_cosine": direction_score,
656
  }
657
 
658
  if reason is None:
tests/test_track_exports.py CHANGED
@@ -9,8 +9,9 @@ import cv2
9
  import numpy as np
10
  import yaml
11
 
12
- from bat_tracker.pipeline import run_pipeline
13
  from bat_tracker.render import export_tracks_render_json, export_tracks_svg
 
14
 
15
 
16
  SVG_NS = {"svg": "http://www.w3.org/2000/svg"}
@@ -196,3 +197,124 @@ def test_svg_and_render_json_export_empty_tracks_as_valid_empty_documents(tmp_pa
196
  svg_root = ET.parse(svg_path).getroot()
197
  assert svg_root.attrib["viewBox"] == "0 0 64 48"
198
  assert svg_root.findall("svg:g[@class='track']", SVG_NS) == []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import numpy as np
10
  import yaml
11
 
12
+ from bat_tracker.pipeline import _auto_merge_track_points, run_pipeline
13
  from bat_tracker.render import export_tracks_render_json, export_tracks_svg
14
+ from bat_tracker.tracker import TrackPoint
15
 
16
 
17
  SVG_NS = {"svg": "http://www.w3.org/2000/svg"}
 
197
  svg_root = ET.parse(svg_path).getroot()
198
  assert svg_root.attrib["viewBox"] == "0 0 64 48"
199
  assert svg_root.findall("svg:g[@class='track']", SVG_NS) == []
200
+
201
+
202
+ def _make_track_point(track_id: int, frame: int, x: float, y: float) -> TrackPoint:
203
+ return TrackPoint(
204
+ video_id="video",
205
+ track_id=track_id,
206
+ frame=frame,
207
+ time_sec=frame / 30.0,
208
+ x=x,
209
+ y=y,
210
+ vx=0.0,
211
+ vy=0.0,
212
+ bbox_x1=int(round(x)) - 1,
213
+ bbox_y1=int(round(y)) - 1,
214
+ bbox_x2=int(round(x)) + 1,
215
+ bbox_y2=int(round(y)) + 1,
216
+ area=20.0,
217
+ )
218
+
219
+
220
+ def test_auto_merge_uses_local_overlap_continuity_for_short_shared_window() -> None:
221
+ points = [
222
+ _make_track_point(202, 32671, 715.5, 621.0),
223
+ _make_track_point(202, 32672, 739.5, 631.5),
224
+ _make_track_point(202, 32673, 760.5, 635.5),
225
+ _make_track_point(202, 32674, 785.0, 644.0),
226
+ _make_track_point(202, 32675, 810.5, 641.5),
227
+ _make_track_point(202, 32676, 837.5, 625.5),
228
+ _make_track_point(202, 32677, 860.5, 609.0),
229
+ _make_track_point(202, 32678, 895.0, 570.0),
230
+ _make_track_point(202, 32679, 922.0, 531.0),
231
+ _make_track_point(202, 32680, 917.5, 465.5),
232
+ _make_track_point(203, 32679, 882.0, 494.5),
233
+ _make_track_point(203, 32680, 979.0, 449.0),
234
+ _make_track_point(203, 32681, 1027.0, 404.0),
235
+ _make_track_point(203, 32682, 1035.5, 285.0),
236
+ _make_track_point(203, 32683, 1078.0, 172.0),
237
+ _make_track_point(203, 32684, 1128.5, 54.0),
238
+ ]
239
+ cfg = {
240
+ "auto_merge_suggested": True,
241
+ "merge_max_gap_frames": 12,
242
+ "merge_max_endpoint_distance": 100.0,
243
+ "merge_overlap_min_common_frames": 3,
244
+ "merge_overlap_max_mean_distance": 60.0,
245
+ "merge_overlap_min_direction_cosine": 0.8,
246
+ }
247
+
248
+ merged_points, merges = _auto_merge_track_points(points, cfg)
249
+
250
+ assert any(
251
+ merge["track_a"] == 202 and merge["track_b"] == 203 and merge["reason"] == "overlap_local"
252
+ for merge in merges
253
+ )
254
+ assert {point.track_id for point in merged_points} == {202}
255
+
256
+
257
+ def test_auto_merge_uses_connector_direction_for_overlapping_fragments() -> None:
258
+ points = [
259
+ _make_track_point(180, 13432, 708.0, 667.0),
260
+ _make_track_point(180, 13433, 738.0, 666.5),
261
+ _make_track_point(180, 13434, 769.0, 661.0),
262
+ _make_track_point(180, 13435, 797.5, 650.5),
263
+ _make_track_point(180, 13436, 830.0, 634.5),
264
+ _make_track_point(180, 13437, 859.0, 606.0),
265
+ _make_track_point(180, 13438, 881.5, 588.0),
266
+ _make_track_point(180, 13439, 916.0, 549.5),
267
+ _make_track_point(180, 13440, 956.0, 486.5),
268
+ _make_track_point(180, 13441, 1012.0, 461.0),
269
+ _make_track_point(180, 13442, 1029.0, 363.5),
270
+ _make_track_point(180, 13443, 1081.0, 303.0),
271
+ _make_track_point(181, 13439, 937.5, 550.0),
272
+ _make_track_point(181, 13440, 910.5, 504.5),
273
+ _make_track_point(181, 13441, 962.5, 477.5),
274
+ _make_track_point(181, 13442, 984.5, 393.0),
275
+ _make_track_point(181, 13443, 1002.5, 335.0),
276
+ ]
277
+ cfg = {
278
+ "auto_merge_suggested": True,
279
+ "merge_max_gap_frames": 12,
280
+ "merge_max_endpoint_distance": 100.0,
281
+ "merge_overlap_min_common_frames": 3,
282
+ "merge_overlap_max_mean_distance": 60.0,
283
+ "merge_overlap_min_direction_cosine": 0.8,
284
+ }
285
+
286
+ merged_points, merges = _auto_merge_track_points(points, cfg)
287
+
288
+ assert any(
289
+ merge["track_a"] == 180 and merge["track_b"] == 181 and merge["reason"] == "overlap"
290
+ for merge in merges
291
+ )
292
+ assert {point.track_id for point in merged_points} == {180}
293
+
294
+
295
+ def test_auto_merge_keeps_nearby_tracks_separate_when_connector_direction_breaks() -> None:
296
+ points = [
297
+ _make_track_point(86, 1091, 1703.5, 36.0),
298
+ _make_track_point(86, 1092, 1703.5, 36.0),
299
+ _make_track_point(86, 1093, 1722.0, 47.5),
300
+ _make_track_point(86, 1094, 1693.5, 16.5),
301
+ _make_track_point(86, 1095, 1663.0, 24.5),
302
+ _make_track_point(86, 1096, 1671.5, 2.5),
303
+ _make_track_point(90, 1092, 1640.0, 71.0),
304
+ _make_track_point(90, 1093, 1651.0, 59.5),
305
+ _make_track_point(90, 1094, 1671.5, 46.5),
306
+ _make_track_point(90, 1095, 1707.5, 50.0),
307
+ ]
308
+ cfg = {
309
+ "auto_merge_suggested": True,
310
+ "merge_max_gap_frames": 12,
311
+ "merge_max_endpoint_distance": 100.0,
312
+ "merge_overlap_min_common_frames": 3,
313
+ "merge_overlap_max_mean_distance": 60.0,
314
+ "merge_overlap_min_direction_cosine": 0.8,
315
+ }
316
+
317
+ merged_points, merges = _auto_merge_track_points(points, cfg)
318
+
319
+ assert merges == []
320
+ assert {point.track_id for point in merged_points} == {86, 90}