mustafa2ak commited on
Commit
06d6b9f
·
verified ·
1 Parent(s): e4afadb

Update tracking.py

Browse files
Files changed (1) hide show
  1. tracking.py +202 -95
tracking.py CHANGED
@@ -1,4 +1,7 @@
1
-
 
 
 
2
  import numpy as np
3
  from typing import List, Optional, Tuple
4
  from scipy.optimize import linear_sum_assignment
@@ -7,7 +10,7 @@ import uuid
7
  from detection import Detection
8
 
9
  class Track:
10
- """Simple track for a single dog"""
11
 
12
  def __init__(self, detection: Detection, track_id: Optional[int] = None):
13
  """Initialize track from first detection"""
@@ -28,45 +31,66 @@ class Track:
28
  self.trajectory = deque(maxlen=30)
29
  self.trajectory.append((cx, cy))
30
 
31
- # Add velocity for better prediction
32
- self.velocity = [0, 0]
33
-
 
 
 
 
 
 
 
 
 
 
 
 
34
  def _generate_id(self) -> int:
35
  """Generate unique track ID"""
36
  return int(uuid.uuid4().int % 100000)
37
-
38
  def predict(self):
39
- """Simple motion prediction"""
40
  self.age += 1
41
  self.time_since_update += 1
42
 
43
- # Simple constant velocity model
44
- if len(self.trajectory) >= 2:
45
- # Calculate velocity from last two positions
46
- curr = self.trajectory[-1]
47
- prev = self.trajectory[-2]
48
- self.velocity = [curr[0] - prev[0], curr[1] - prev[1]]
49
 
50
- # Predict next position
51
- predicted_cx = curr[0] + self.velocity[0] * 0.5 # Damping factor
52
- predicted_cy = curr[1] + self.velocity[1] * 0.5
53
 
54
- # Update bbox based on predicted center
55
- width = self.bbox[2] - self.bbox[0]
56
- height = self.bbox[3] - self.bbox[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  self.bbox = [
58
- predicted_cx - width/2,
59
- predicted_cy - height/2,
60
- predicted_cx + width/2,
61
- predicted_cy + height/2
62
  ]
63
-
64
  def update(self, detection: Detection):
65
  """Update track with new detection"""
66
  self.bbox = detection.bbox
67
  self.detections.append(detection)
68
  self.confidence = detection.confidence
69
-
70
  self.hits += 1
71
  self.time_since_update = 0
72
 
@@ -75,119 +99,172 @@ class Track:
75
  cy = (self.bbox[1] + self.bbox[3]) / 2
76
  self.trajectory.append((cx, cy))
77
 
78
- # Confirm track after 2 hits (reduced from 3 for faster confirmation)
 
 
 
 
 
 
 
 
 
 
 
79
  if self.state == 'tentative' and self.hits >= 2:
80
  self.state = 'confirmed'
81
 
82
  # Keep only recent detections to save memory
83
  if len(self.detections) > 5:
84
  self.detections = self.detections[-5:]
85
-
86
  def mark_missed(self):
87
  """Mark track as missed in current frame"""
88
- if self.state == 'confirmed' and self.time_since_update > 10: # Reduced from 15
89
  self.state = 'deleted'
90
 
91
- class SimpleTracker:
 
92
  """
93
- Fixed ByteTrack with stricter matching to prevent merging dogs
94
  """
95
 
96
- def __init__(self,
97
- match_threshold: float = 0.3, # LOWERED from 0.5 - stricter matching
98
  track_buffer: int = 30,
99
- min_iou_for_match: float = 0.2): # Added minimum IoU
 
100
  """
101
- Initialize tracker with stricter parameters
102
 
103
  Args:
104
- match_threshold: IoU threshold for matching (lower = stricter)
105
  track_buffer: Frames to keep lost tracks
106
- min_iou_for_match: Minimum IoU to even consider a match
 
107
  """
108
  self.match_threshold = match_threshold
109
  self.track_buffer = track_buffer
110
  self.min_iou_for_match = min_iou_for_match
 
111
 
112
  self.tracks: List[Track] = []
113
  self.track_id_count = 1
114
 
115
- # Add distance threshold to prevent far matches
116
  self.max_center_distance = 200 # pixels
 
117
 
118
  def update(self, detections: List[Detection]) -> List[Track]:
119
  """
120
- Update tracks with stricter matching criteria
121
  """
122
  # Predict existing tracks
123
  for track in self.tracks:
124
  track.predict()
125
-
126
- # Split tracks by confidence level for cascade matching
127
  confirmed_tracks = [t for t in self.tracks if t.state == 'confirmed']
128
  tentative_tracks = [t for t in self.tracks if t.state == 'tentative']
129
 
130
- # First match with confirmed tracks (stricter threshold)
131
- matched_indices = []
132
- unmatched_dets = list(range(len(detections)))
133
- unmatched_confirmed = list(range(len(confirmed_tracks)))
 
 
134
 
135
- if len(detections) > 0 and len(confirmed_tracks) > 0:
136
- # Calculate cost matrix with IoU and distance
137
- cost_matrix = self._calculate_cost_matrix(confirmed_tracks, detections)
 
 
 
138
 
139
- # Hungarian matching
140
  if cost_matrix.size > 0:
141
  row_ind, col_ind = linear_sum_assignment(cost_matrix)
142
 
143
  for r, c in zip(row_ind, col_ind):
144
- # Stricter matching criteria
145
  if cost_matrix[r, c] < (1 - self.match_threshold):
146
- confirmed_tracks[r].update(detections[c])
147
- matched_indices.append((r, c))
148
- if c in unmatched_dets:
149
- unmatched_dets.remove(c)
150
- if r in unmatched_confirmed:
151
- unmatched_confirmed.remove(r)
152
 
153
- # Match remaining detections with tentative tracks
154
- remaining_dets = [detections[i] for i in unmatched_dets]
155
- if len(remaining_dets) > 0 and len(tentative_tracks) > 0:
156
- cost_matrix = self._calculate_cost_matrix(tentative_tracks, remaining_dets)
 
 
 
 
 
 
 
 
 
 
157
 
158
  if cost_matrix.size > 0:
159
  row_ind, col_ind = linear_sum_assignment(cost_matrix)
160
 
161
  for r, c in zip(row_ind, col_ind):
162
- if cost_matrix[r, c] < (1 - self.match_threshold * 0.7): # Even stricter for tentative
163
- tentative_tracks[r].update(remaining_dets[c])
164
- if unmatched_dets[c] in unmatched_dets:
165
- unmatched_dets.remove(unmatched_dets[c])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  # Mark unmatched tracks as missed
168
  for idx in unmatched_confirmed:
169
  confirmed_tracks[idx].mark_missed()
170
-
171
  for track in tentative_tracks:
172
  if track.time_since_update > 0:
173
  track.mark_missed()
174
 
175
  # Create new tracks for unmatched detections
176
- for det_idx in unmatched_dets:
177
- # Check if detection is far from all existing tracks
 
 
178
  detection = detections[det_idx]
179
- is_new = True
180
 
181
- # Additional check: ensure new detection is not too close to existing tracks
 
182
  det_center = self._get_center(detection.bbox)
 
183
  for track in self.tracks:
184
  if track.state != 'deleted':
185
  track_center = self._get_center(track.bbox)
186
- dist = np.sqrt((det_center[0] - track_center[0])**2 +
187
- (det_center[1] - track_center[1])**2)
188
 
189
- # If too close to existing track, likely same dog
190
- if dist < 50: # Very close threshold
191
  is_new = False
192
  break
193
 
@@ -202,44 +279,70 @@ class SimpleTracker:
202
  # Return only confirmed tracks
203
  return [t for t in self.tracks if t.state == 'confirmed']
204
 
205
- def _calculate_cost_matrix(self, tracks: List[Track],
206
- detections: List[Detection]) -> np.ndarray:
207
-
208
  if not tracks or not detections:
209
  return np.array([])
210
-
211
- matrix = np.ones((len(tracks), len(detections)))
 
 
212
 
213
  for t_idx, track in enumerate(tracks):
214
- track_center = self._get_center(track.bbox)
 
 
 
 
215
 
216
  for d_idx, detection in enumerate(detections):
217
- # Calculate IoU
218
  iou = self._iou(track.bbox, detection.bbox)
219
 
220
- # Calculate center distance
221
- det_center = self._get_center(detection.bbox)
222
- distance = np.sqrt((track_center[0] - det_center[0])**2 +
223
- (track_center[1] - det_center[1])**2)
224
 
225
- # Combined cost (lower is better)
 
 
 
 
 
 
 
 
226
  if iou >= self.min_iou_for_match and distance < self.max_center_distance:
227
- # Weighted combination: IoU is more important
228
  iou_cost = 1 - iou
229
  dist_cost = distance / self.max_center_distance
230
- matrix[t_idx, d_idx] = 0.7 * iou_cost + 0.3 * dist_cost
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  else:
232
- # No match possible
233
- matrix[t_idx, d_idx] = 1.0
234
-
235
- return matrix
236
 
237
  def _get_center(self, bbox: List[float]) -> Tuple[float, float]:
238
-
239
  return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
240
 
241
  def _iou(self, bbox1: List[float], bbox2: List[float]) -> float:
242
-
243
  x1 = max(bbox1[0], bbox2[0])
244
  y1 = max(bbox1[1], bbox2[1])
245
  x2 = min(bbox1[2], bbox2[2])
@@ -247,7 +350,7 @@ class SimpleTracker:
247
 
248
  if x2 < x1 or y2 < y1:
249
  return 0.0
250
-
251
  intersection = (x2 - x1) * (y2 - y1)
252
  area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
253
  area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
@@ -256,6 +359,10 @@ class SimpleTracker:
256
  return intersection / (union + 1e-6)
257
 
258
  def set_match_threshold(self, threshold: float):
259
-
260
  self.match_threshold = max(0.1, min(0.8, threshold))
261
-
 
 
 
 
 
1
+ """
2
+ tracking.py - Enhanced tracking with proven techniques
3
+ Fixes the index bug and adds robust features
4
+ """
5
  import numpy as np
6
  from typing import List, Optional, Tuple
7
  from scipy.optimize import linear_sum_assignment
 
10
  from detection import Detection
11
 
12
  class Track:
13
+ """Enhanced track with Kalman filter prediction"""
14
 
15
  def __init__(self, detection: Detection, track_id: Optional[int] = None):
16
  """Initialize track from first detection"""
 
31
  self.trajectory = deque(maxlen=30)
32
  self.trajectory.append((cx, cy))
33
 
34
+ # Enhanced motion model
35
+ self.velocity = np.array([0.0, 0.0])
36
+ self.acceleration = np.array([0.0, 0.0])
37
+
38
+ # Appearance features for re-association
39
+ self.appearance_features = []
40
+ if hasattr(detection, 'features'):
41
+ self.appearance_features.append(detection.features)
42
+
43
+ # Size tracking for scale changes
44
+ self.sizes = deque(maxlen=10)
45
+ width = self.bbox[2] - self.bbox[0]
46
+ height = self.bbox[3] - self.bbox[1]
47
+ self.sizes.append((width, height))
48
+
49
  def _generate_id(self) -> int:
50
  """Generate unique track ID"""
51
  return int(uuid.uuid4().int % 100000)
52
+
53
  def predict(self):
54
+ """Enhanced motion prediction with acceleration"""
55
  self.age += 1
56
  self.time_since_update += 1
57
 
58
+ if len(self.trajectory) >= 3:
59
+ # Calculate velocity and acceleration from recent positions
60
+ positions = np.array(list(self.trajectory))[-3:]
 
 
 
61
 
62
+ # Velocity from last two positions
63
+ self.velocity = positions[-1] - positions[-2]
 
64
 
65
+ # Acceleration from velocity change
66
+ if len(positions) == 3:
67
+ prev_velocity = positions[-2] - positions[-3]
68
+ self.acceleration = (self.velocity - prev_velocity) * 0.5
69
+
70
+ # Predict next position with damping
71
+ predicted_pos = positions[-1] + self.velocity * 0.8 + self.acceleration * 0.2
72
+
73
+ # Get average recent size for stable bbox
74
+ if self.sizes:
75
+ avg_width = np.mean([s[0] for s in self.sizes])
76
+ avg_height = np.mean([s[1] for s in self.sizes])
77
+ else:
78
+ avg_width = self.bbox[2] - self.bbox[0]
79
+ avg_height = self.bbox[3] - self.bbox[1]
80
+
81
+ # Update bbox with predicted center and smoothed size
82
  self.bbox = [
83
+ predicted_pos[0] - avg_width/2,
84
+ predicted_pos[1] - avg_height/2,
85
+ predicted_pos[0] + avg_width/2,
86
+ predicted_pos[1] + avg_height/2
87
  ]
88
+
89
  def update(self, detection: Detection):
90
  """Update track with new detection"""
91
  self.bbox = detection.bbox
92
  self.detections.append(detection)
93
  self.confidence = detection.confidence
 
94
  self.hits += 1
95
  self.time_since_update = 0
96
 
 
99
  cy = (self.bbox[1] + self.bbox[3]) / 2
100
  self.trajectory.append((cx, cy))
101
 
102
+ # Update size history
103
+ width = self.bbox[2] - self.bbox[0]
104
+ height = self.bbox[3] - self.bbox[1]
105
+ self.sizes.append((width, height))
106
+
107
+ # Store appearance features
108
+ if hasattr(detection, 'features'):
109
+ self.appearance_features.append(detection.features)
110
+ if len(self.appearance_features) > 5:
111
+ self.appearance_features = self.appearance_features[-5:]
112
+
113
+ # Confirm track after 2 hits
114
  if self.state == 'tentative' and self.hits >= 2:
115
  self.state = 'confirmed'
116
 
117
  # Keep only recent detections to save memory
118
  if len(self.detections) > 5:
119
  self.detections = self.detections[-5:]
120
+
121
  def mark_missed(self):
122
  """Mark track as missed in current frame"""
123
+ if self.state == 'confirmed' and self.time_since_update > 10:
124
  self.state = 'deleted'
125
 
126
+
127
+ class RobustTracker:
128
  """
129
+ Enhanced tracker with multiple association strategies
130
  """
131
 
132
+ def __init__(self,
133
+ match_threshold: float = 0.3,
134
  track_buffer: int = 30,
135
+ min_iou_for_match: float = 0.2,
136
+ use_appearance: bool = True):
137
  """
138
+ Initialize tracker with multiple matching strategies
139
 
140
  Args:
141
+ match_threshold: IoU threshold for matching
142
  track_buffer: Frames to keep lost tracks
143
+ min_iou_for_match: Minimum IoU to consider a match
144
+ use_appearance: Whether to use appearance features
145
  """
146
  self.match_threshold = match_threshold
147
  self.track_buffer = track_buffer
148
  self.min_iou_for_match = min_iou_for_match
149
+ self.use_appearance = use_appearance
150
 
151
  self.tracks: List[Track] = []
152
  self.track_id_count = 1
153
 
154
+ # Enhanced parameters
155
  self.max_center_distance = 200 # pixels
156
+ self.min_size_similarity = 0.5 # Size change threshold
157
 
158
  def update(self, detections: List[Detection]) -> List[Track]:
159
  """
160
+ Update tracks with multiple association strategies
161
  """
162
  # Predict existing tracks
163
  for track in self.tracks:
164
  track.predict()
165
+
166
+ # Split tracks by state
167
  confirmed_tracks = [t for t in self.tracks if t.state == 'confirmed']
168
  tentative_tracks = [t for t in self.tracks if t.state == 'tentative']
169
 
170
+ # First association: High confidence detections with confirmed tracks
171
+ high_conf_dets = [i for i, d in enumerate(detections) if d.confidence > 0.7]
172
+ low_conf_dets = [i for i, d in enumerate(detections) if d.confidence <= 0.7]
173
+
174
+ matched_track_indices = set()
175
+ matched_det_indices = set()
176
 
177
+ # Match high confidence detections first
178
+ if high_conf_dets and confirmed_tracks:
179
+ cost_matrix = self._calculate_enhanced_cost_matrix(
180
+ confirmed_tracks,
181
+ [detections[i] for i in high_conf_dets]
182
+ )
183
 
 
184
  if cost_matrix.size > 0:
185
  row_ind, col_ind = linear_sum_assignment(cost_matrix)
186
 
187
  for r, c in zip(row_ind, col_ind):
 
188
  if cost_matrix[r, c] < (1 - self.match_threshold):
189
+ confirmed_tracks[r].update(detections[high_conf_dets[c]])
190
+ matched_track_indices.add(r)
191
+ matched_det_indices.add(high_conf_dets[c])
 
 
 
192
 
193
+ # Get unmatched items
194
+ unmatched_confirmed = [i for i in range(len(confirmed_tracks))
195
+ if i not in matched_track_indices]
196
+ unmatched_dets = [i for i in range(len(detections))
197
+ if i not in matched_det_indices]
198
+
199
+ # Second association: Remaining confirmed tracks with remaining detections
200
+ if unmatched_dets and unmatched_confirmed:
201
+ remaining_tracks = [confirmed_tracks[i] for i in unmatched_confirmed]
202
+ remaining_dets = [detections[i] for i in unmatched_dets]
203
+
204
+ cost_matrix = self._calculate_enhanced_cost_matrix(
205
+ remaining_tracks, remaining_dets
206
+ )
207
 
208
  if cost_matrix.size > 0:
209
  row_ind, col_ind = linear_sum_assignment(cost_matrix)
210
 
211
  for r, c in zip(row_ind, col_ind):
212
+ if cost_matrix[r, c] < (1 - self.match_threshold * 0.8):
213
+ track_idx = unmatched_confirmed[r]
214
+ det_idx = unmatched_dets[c]
215
+ confirmed_tracks[track_idx].update(detections[det_idx])
216
+ matched_det_indices.add(det_idx)
217
+
218
+ # Remove from unmatched lists
219
+ if track_idx in unmatched_confirmed:
220
+ unmatched_confirmed.remove(track_idx)
221
+
222
+ # Update unmatched detections list after confirmed track matching
223
+ unmatched_dets = [i for i in range(len(detections))
224
+ if i not in matched_det_indices]
225
+
226
+ # Third association: Tentative tracks
227
+ if unmatched_dets and tentative_tracks:
228
+ remaining_dets = [detections[i] for i in unmatched_dets]
229
+
230
+ cost_matrix = self._calculate_enhanced_cost_matrix(
231
+ tentative_tracks, remaining_dets
232
+ )
233
+
234
+ if cost_matrix.size > 0:
235
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
236
+
237
+ for r, c in zip(row_ind, col_ind):
238
+ if cost_matrix[r, c] < (1 - self.match_threshold * 0.6):
239
+ det_idx = unmatched_dets[c]
240
+ tentative_tracks[r].update(detections[det_idx])
241
+ matched_det_indices.add(det_idx)
242
 
243
  # Mark unmatched tracks as missed
244
  for idx in unmatched_confirmed:
245
  confirmed_tracks[idx].mark_missed()
246
+
247
  for track in tentative_tracks:
248
  if track.time_since_update > 0:
249
  track.mark_missed()
250
 
251
  # Create new tracks for unmatched detections
252
+ final_unmatched_dets = [i for i in range(len(detections))
253
+ if i not in matched_det_indices]
254
+
255
+ for det_idx in final_unmatched_dets:
256
  detection = detections[det_idx]
 
257
 
258
+ # Additional check: ensure not too close to existing tracks
259
+ is_new = True
260
  det_center = self._get_center(detection.bbox)
261
+
262
  for track in self.tracks:
263
  if track.state != 'deleted':
264
  track_center = self._get_center(track.bbox)
265
+ dist = np.linalg.norm(np.array(det_center) - np.array(track_center))
 
266
 
267
+ if dist < 40: # Very close threshold
 
268
  is_new = False
269
  break
270
 
 
279
  # Return only confirmed tracks
280
  return [t for t in self.tracks if t.state == 'confirmed']
281
 
282
+ def _calculate_enhanced_cost_matrix(self, tracks: List[Track],
283
+ detections: List[Detection]) -> np.ndarray:
284
+ """Calculate cost matrix with multiple cues"""
285
  if not tracks or not detections:
286
  return np.array([])
287
+
288
+ n_tracks = len(tracks)
289
+ n_dets = len(detections)
290
+ cost_matrix = np.ones((n_tracks, n_dets))
291
 
292
  for t_idx, track in enumerate(tracks):
293
+ track_center = np.array(self._get_center(track.bbox))
294
+ track_size = np.array([
295
+ track.bbox[2] - track.bbox[0],
296
+ track.bbox[3] - track.bbox[1]
297
+ ])
298
 
299
  for d_idx, detection in enumerate(detections):
300
+ # IoU cost
301
  iou = self._iou(track.bbox, detection.bbox)
302
 
303
+ # Center distance cost
304
+ det_center = np.array(self._get_center(detection.bbox))
305
+ distance = np.linalg.norm(track_center - det_center)
 
306
 
307
+ # Size similarity cost
308
+ det_size = np.array([
309
+ detection.bbox[2] - detection.bbox[0],
310
+ detection.bbox[3] - detection.bbox[1]
311
+ ])
312
+ size_sim = np.minimum(track_size, det_size) / (np.maximum(track_size, det_size) + 1e-6)
313
+ size_cost = 1 - np.mean(size_sim)
314
+
315
+ # Combine costs
316
  if iou >= self.min_iou_for_match and distance < self.max_center_distance:
 
317
  iou_cost = 1 - iou
318
  dist_cost = distance / self.max_center_distance
319
+
320
+ # Weighted combination
321
+ total_cost = (0.5 * iou_cost +
322
+ 0.3 * dist_cost +
323
+ 0.2 * size_cost)
324
+
325
+ # Appearance cost if available
326
+ if self.use_appearance and track.appearance_features and hasattr(detection, 'features'):
327
+ # Simple cosine similarity
328
+ track_feat = np.mean(track.appearance_features, axis=0)
329
+ det_feat = detection.features
330
+ app_sim = np.dot(track_feat, det_feat) / (
331
+ np.linalg.norm(track_feat) * np.linalg.norm(det_feat) + 1e-6
332
+ )
333
+ app_cost = 1 - app_sim
334
+ total_cost = 0.4 * iou_cost + 0.2 * dist_cost + 0.2 * size_cost + .2 * app_cost
335
+
336
+ cost_matrix[t_idx, d_idx] = total_cost
337
  else:
338
+ cost_matrix[t_idx, d_idx] = 1.0
339
+
340
+ return cost_matrix
 
341
 
342
  def _get_center(self, bbox: List[float]) -> Tuple[float, float]:
 
343
  return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
344
 
345
  def _iou(self, bbox1: List[float], bbox2: List[float]) -> float:
 
346
  x1 = max(bbox1[0], bbox2[0])
347
  y1 = max(bbox1[1], bbox2[1])
348
  x2 = min(bbox1[2], bbox2[2])
 
350
 
351
  if x2 < x1 or y2 < y1:
352
  return 0.0
353
+
354
  intersection = (x2 - x1) * (y2 - y1)
355
  area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
356
  area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
 
359
  return intersection / (union + 1e-6)
360
 
361
  def set_match_threshold(self, threshold: float):
362
+ """Update matching threshold"""
363
  self.match_threshold = max(0.1, min(0.8, threshold))
364
+ print(f"Tracking threshold updated to: {self.match_threshold:.2f}")
365
+
366
+
367
+ # Use the enhanced tracker as default
368
+ SimpleTracker = RobustTracker