mustafa2ak commited on
Commit
7704bc6
·
verified ·
1 Parent(s): 3c82458

Create tracking.py

Browse files
Files changed (1) hide show
  1. tracking.py +197 -0
tracking.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List, Optional, Tuple
3
+ from scipy.optimize import linear_sum_assignment
4
+ from collections import deque
5
+ import uuid
6
+
7
+ class Track:
8
+ """Simple track for a single dog"""
9
+
10
+ def __init__(self, detection: Detection, track_id: Optional[int] = None):
11
+ """Initialize track from first detection"""
12
+ self.track_id = track_id if track_id else self._generate_id()
13
+ self.bbox = detection.bbox
14
+ self.detections = [detection]
15
+ self.confidence = detection.confidence
16
+
17
+ # Track state
18
+ self.age = 1
19
+ self.time_since_update = 0
20
+ self.state = 'tentative' # tentative -> confirmed -> deleted
21
+ self.hits = 1
22
+
23
+ # Store center points for trajectory
24
+ cx = (self.bbox[0] + self.bbox[2]) / 2
25
+ cy = (self.bbox[1] + self.bbox[3]) / 2
26
+ self.trajectory = deque(maxlen=30)
27
+ self.trajectory.append((cx, cy))
28
+
29
+ def _generate_id(self) -> int:
30
+ """Generate unique track ID"""
31
+ return int(uuid.uuid4().int % 100000)
32
+
33
+ def predict(self):
34
+ """Simple prediction - just use last position"""
35
+ self.age += 1
36
+ self.time_since_update += 1
37
+
38
+ def update(self, detection: Detection):
39
+ """Update track with new detection"""
40
+ self.bbox = detection.bbox
41
+ self.detections.append(detection)
42
+ self.confidence = detection.confidence
43
+
44
+ self.hits += 1
45
+ self.time_since_update = 0
46
+
47
+ # Update trajectory
48
+ cx = (self.bbox[0] + self.bbox[2]) / 2
49
+ cy = (self.bbox[1] + self.bbox[3]) / 2
50
+ self.trajectory.append((cx, cy))
51
+
52
+ # Confirm track after 3 hits
53
+ if self.state == 'tentative' and self.hits >= 3:
54
+ self.state = 'confirmed'
55
+
56
+ # Keep only recent detections to save memory
57
+ if len(self.detections) > 10:
58
+ self.detections = self.detections[-10:]
59
+
60
+ def mark_missed(self):
61
+ """Mark track as missed in current frame"""
62
+ if self.state == 'confirmed' and self.time_since_update > 15:
63
+ self.state = 'deleted'
64
+
65
+ class SimpleTracker:
66
+ """
67
+ Simplified ByteTrack - IoU-based tracking
68
+ Robust and proven approach without complexity
69
+ """
70
+
71
+ def __init__(self,
72
+ match_threshold: float = 0.5,
73
+ track_buffer: int = 30):
74
+ """
75
+ Initialize tracker
76
+
77
+ Args:
78
+ match_threshold: IoU threshold for matching (0.5 works well)
79
+ track_buffer: Frames to keep lost tracks
80
+ """
81
+ self.match_threshold = match_threshold
82
+ self.track_buffer = track_buffer
83
+
84
+ self.tracks: List[Track] = []
85
+ self.track_id_count = 1
86
+
87
+ def update(self, detections: List[Detection]) -> List[Track]:
88
+ """
89
+ Update tracks with new detections
90
+
91
+ Args:
92
+ detections: List of detections from current frame
93
+
94
+ Returns:
95
+ List of active tracks
96
+ """
97
+ # Predict existing tracks
98
+ for track in self.tracks:
99
+ track.predict()
100
+
101
+ # Get active tracks
102
+ active_tracks = [t for t in self.tracks if t.state != 'deleted']
103
+
104
+ if len(detections) > 0 and len(active_tracks) > 0:
105
+ # Calculate IoU matrix
106
+ iou_matrix = self._calculate_iou_matrix(active_tracks, detections)
107
+
108
+ # Hungarian matching
109
+ matched, unmatched_tracks, unmatched_dets = self._associate(
110
+ iou_matrix, self.match_threshold
111
+ )
112
+
113
+ # Update matched tracks
114
+ for t_idx, d_idx in matched:
115
+ active_tracks[t_idx].update(detections[d_idx])
116
+
117
+ # Mark unmatched tracks as missed
118
+ for t_idx in unmatched_tracks:
119
+ active_tracks[t_idx].mark_missed()
120
+
121
+ # Create new tracks for unmatched detections
122
+ for d_idx in unmatched_dets:
123
+ new_track = Track(detections[d_idx], self.track_id_count)
124
+ self.track_id_count += 1
125
+ self.tracks.append(new_track)
126
+
127
+ elif len(detections) > 0:
128
+ # No existing tracks - create new ones
129
+ for detection in detections:
130
+ new_track = Track(detection, self.track_id_count)
131
+ self.track_id_count += 1
132
+ self.tracks.append(new_track)
133
+ else:
134
+ # No detections - mark all as missed
135
+ for track in active_tracks:
136
+ track.mark_missed()
137
+
138
+ # Remove deleted tracks
139
+ self.tracks = [t for t in self.tracks if t.state != 'deleted']
140
+
141
+ # Return confirmed tracks
142
+ return [t for t in self.tracks if t.state == 'confirmed']
143
+
144
+ def _calculate_iou_matrix(self, tracks: List[Track],
145
+ detections: List[Detection]) -> np.ndarray:
146
+ """Calculate IoU between all tracks and detections"""
147
+ matrix = np.zeros((len(tracks), len(detections)))
148
+
149
+ for t_idx, track in enumerate(tracks):
150
+ for d_idx, detection in enumerate(detections):
151
+ matrix[t_idx, d_idx] = self._iou(track.bbox, detection.bbox)
152
+
153
+ return matrix
154
+
155
+ def _iou(self, bbox1: List[float], bbox2: List[float]) -> float:
156
+ """Calculate Intersection over Union"""
157
+ x1 = max(bbox1[0], bbox2[0])
158
+ y1 = max(bbox1[1], bbox2[1])
159
+ x2 = min(bbox1[2], bbox2[2])
160
+ y2 = min(bbox1[3], bbox2[3])
161
+
162
+ if x2 < x1 or y2 < y1:
163
+ return 0.0
164
+
165
+ intersection = (x2 - x1) * (y2 - y1)
166
+ area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
167
+ area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
168
+ union = area1 + area2 - intersection
169
+
170
+ return intersection / union if union > 0 else 0.0
171
+
172
+ def _associate(self, iou_matrix: np.ndarray,
173
+ threshold: float) -> Tuple[List, List, List]:
174
+ """Hungarian algorithm for optimal assignment"""
175
+ matched_indices = []
176
+
177
+ if iou_matrix.max() >= threshold:
178
+ # Convert to cost matrix
179
+ cost_matrix = 1 - iou_matrix
180
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
181
+
182
+ for r, c in zip(row_ind, col_ind):
183
+ if iou_matrix[r, c] >= threshold:
184
+ matched_indices.append([r, c])
185
+
186
+ unmatched_tracks = []
187
+ unmatched_detections = []
188
+
189
+ for t in range(iou_matrix.shape[0]):
190
+ if t not in [m[0] for m in matched_indices]:
191
+ unmatched_tracks.append(t)
192
+
193
+ for d in range(iou_matrix.shape[1]):
194
+ if d not in [m[1] for m in matched_indices]:
195
+ unmatched_detections.append(d)
196
+
197
+ return matched_indices, unmatched_tracks, unmatched_detections