mustafa2ak commited on
Commit
ec87247
·
verified ·
1 Parent(s): a26442f

Update tracking.py

Browse files
Files changed (1) hide show
  1. tracking.py +380 -248
tracking.py CHANGED
@@ -1,29 +1,33 @@
1
  """
2
- tracking.py - Enhanced tracking with proven techniques
3
- Fixes the index bug and adds robust features
4
  """
5
  import numpy as np
6
- from typing import List, Optional, Tuple
7
  from scipy.optimize import linear_sum_assignment
8
  from collections import deque
9
  import uuid
10
  from detection import Detection
 
 
 
11
 
12
  class Track:
13
- """Enhanced track with Kalman filter prediction"""
14
 
15
  def __init__(self, detection: Detection, track_id: Optional[int] = None):
16
  """Initialize track from first detection"""
17
  self.track_id = track_id if track_id else self._generate_id()
18
- self.bbox = detection.bbox
19
  self.detections = [detection]
20
- self.confidence = detection.confidence
21
 
22
  # Track state
23
  self.age = 1
24
  self.time_since_update = 0
25
  self.state = 'tentative'
26
  self.hits = 1
 
27
 
28
  # Store center points for trajectory
29
  cx = (self.bbox[0] + self.bbox[2]) / 2
@@ -42,106 +46,147 @@ class Track:
42
 
43
  # Size tracking for scale changes
44
  self.sizes = deque(maxlen=10)
45
- width = self.bbox[2] - self.bbox[0]
46
- height = self.bbox[3] - self.bbox[1]
47
  self.sizes.append((width, height))
48
 
 
 
 
 
49
  def _generate_id(self) -> int:
50
  """Generate unique track ID"""
51
  return int(uuid.uuid4().int % 100000)
52
 
53
  def predict(self):
54
- """Enhanced motion prediction with acceleration"""
55
  self.age += 1
56
  self.time_since_update += 1
 
57
 
58
- if len(self.trajectory) >= 3:
59
- # Calculate velocity and acceleration from recent positions
60
- positions = np.array(list(self.trajectory))[-3:]
61
-
62
- # Velocity from last two positions
63
- self.velocity = positions[-1] - positions[-2]
64
-
65
- # Acceleration from velocity change
66
- if len(positions) == 3:
67
- prev_velocity = positions[-2] - positions[-3]
68
- self.acceleration = (self.velocity - prev_velocity) * 0.5
69
-
70
- # Predict next position with damping
71
- predicted_pos = positions[-1] + self.velocity * 0.8 + self.acceleration * 0.2
72
-
73
- # Get average recent size for stable bbox
74
- if self.sizes:
75
- avg_width = np.mean([s[0] for s in self.sizes])
76
- avg_height = np.mean([s[1] for s in self.sizes])
77
- else:
78
- avg_width = self.bbox[2] - self.bbox[0]
79
- avg_height = self.bbox[3] - self.bbox[1]
80
-
81
- # Update bbox with predicted center and smoothed size
82
- self.bbox = [
83
- predicted_pos[0] - avg_width/2,
84
- predicted_pos[1] - avg_height/2,
85
- predicted_pos[0] + avg_width/2,
86
- predicted_pos[1] + avg_height/2
87
- ]
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  def update(self, detection: Detection):
90
  """Update track with new detection"""
91
- self.bbox = detection.bbox
92
- self.detections.append(detection)
93
- self.confidence = detection.confidence
94
- self.hits += 1
95
- self.time_since_update = 0
96
-
97
- # Update trajectory
98
- cx = (self.bbox[0] + self.bbox[2]) / 2
99
- cy = (self.bbox[1] + self.bbox[3]) / 2
100
- self.trajectory.append((cx, cy))
101
-
102
- # Update size history
103
- width = self.bbox[2] - self.bbox[0]
104
- height = self.bbox[3] - self.bbox[1]
105
- self.sizes.append((width, height))
106
-
107
- # Store appearance features
108
- if hasattr(detection, 'features'):
109
- self.appearance_features.append(detection.features)
110
- if len(self.appearance_features) > 5:
111
- self.appearance_features = self.appearance_features[-5:]
112
-
113
- # Confirm track after 2 hits
114
- if self.state == 'tentative' and self.hits >= 2:
115
- self.state = 'confirmed'
116
 
117
- # Keep only recent detections to save memory
118
- if len(self.detections) > 5:
119
- self.detections = self.detections[-5:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  def mark_missed(self):
122
  """Mark track as missed in current frame"""
123
- if self.state == 'confirmed' and self.time_since_update > 10:
124
- self.state = 'deleted'
 
 
 
 
 
 
 
125
 
126
 
127
  class RobustTracker:
128
  """
129
- Enhanced tracker with multiple association strategies
130
  """
131
 
132
  def __init__(self,
133
- match_threshold: float = 0.3,
134
  track_buffer: int = 30,
135
- min_iou_for_match: float = 0.2,
136
- use_appearance: bool = True):
137
  """
138
- Initialize tracker with multiple matching strategies
139
 
140
  Args:
141
- match_threshold: IoU threshold for matching
142
  track_buffer: Frames to keep lost tracks
143
  min_iou_for_match: Minimum IoU to consider a match
144
- use_appearance: Whether to use appearance features
145
  """
146
  self.match_threshold = match_threshold
147
  self.track_buffer = track_buffer
@@ -152,217 +197,304 @@ class RobustTracker:
152
  self.track_id_count = 1
153
 
154
  # Enhanced parameters
155
- self.max_center_distance = 200 # pixels
156
- self.min_size_similarity = 0.5 # Size change threshold
 
 
 
157
 
158
  def update(self, detections: List[Detection]) -> List[Track]:
159
  """
160
- Update tracks with multiple association strategies
161
  """
162
- # Predict existing tracks
163
- for track in self.tracks:
164
- track.predict()
 
 
165
 
166
- # Split tracks by state
167
- confirmed_tracks = [t for t in self.tracks if t.state == 'confirmed']
168
- tentative_tracks = [t for t in self.tracks if t.state == 'tentative']
169
-
170
- # First association: High confidence detections with confirmed tracks
171
- high_conf_dets = [i for i, d in enumerate(detections) if d.confidence > 0.7]
172
- low_conf_dets = [i for i, d in enumerate(detections) if d.confidence <= 0.7]
173
-
174
- matched_track_indices = set()
175
- matched_det_indices = set()
176
 
177
- # Match high confidence detections first
178
- if high_conf_dets and confirmed_tracks:
179
- cost_matrix = self._calculate_enhanced_cost_matrix(
180
- confirmed_tracks,
181
- [detections[i] for i in high_conf_dets]
182
- )
183
-
184
- if cost_matrix.size > 0:
185
- row_ind, col_ind = linear_sum_assignment(cost_matrix)
186
 
187
- for r, c in zip(row_ind, col_ind):
188
- if cost_matrix[r, c] < (1 - self.match_threshold):
189
- confirmed_tracks[r].update(detections[high_conf_dets[c]])
190
- matched_track_indices.add(r)
191
- matched_det_indices.add(high_conf_dets[c])
192
-
193
- # Get unmatched items
194
- unmatched_confirmed = [i for i in range(len(confirmed_tracks))
195
- if i not in matched_track_indices]
196
- unmatched_dets = [i for i in range(len(detections))
197
- if i not in matched_det_indices]
198
-
199
- # Second association: Remaining confirmed tracks with remaining detections
200
- if unmatched_dets and unmatched_confirmed:
201
- remaining_tracks = [confirmed_tracks[i] for i in unmatched_confirmed]
202
- remaining_dets = [detections[i] for i in unmatched_dets]
203
 
204
- cost_matrix = self._calculate_enhanced_cost_matrix(
205
- remaining_tracks, remaining_dets
206
- )
207
 
208
- if cost_matrix.size > 0:
209
- row_ind, col_ind = linear_sum_assignment(cost_matrix)
 
 
 
 
 
 
 
 
 
 
210
 
211
- for r, c in zip(row_ind, col_ind):
212
- if cost_matrix[r, c] < (1 - self.match_threshold * 0.8):
213
- track_idx = unmatched_confirmed[r]
214
- det_idx = unmatched_dets[c]
215
- confirmed_tracks[track_idx].update(detections[det_idx])
216
- matched_det_indices.add(det_idx)
217
-
218
- # Remove from unmatched lists
219
- if track_idx in unmatched_confirmed:
220
- unmatched_confirmed.remove(track_idx)
221
-
222
- # Update unmatched detections list after confirmed track matching
223
- unmatched_dets = [i for i in range(len(detections))
224
- if i not in matched_det_indices]
225
-
226
- # Third association: Tentative tracks
227
- if unmatched_dets and tentative_tracks:
228
- remaining_dets = [detections[i] for i in unmatched_dets]
229
 
230
- cost_matrix = self._calculate_enhanced_cost_matrix(
231
- tentative_tracks, remaining_dets
232
- )
 
 
 
 
 
233
 
234
- if cost_matrix.size > 0:
235
- row_ind, col_ind = linear_sum_assignment(cost_matrix)
236
-
237
- for r, c in zip(row_ind, col_ind):
238
- if cost_matrix[r, c] < (1 - self.match_threshold * 0.6):
239
- det_idx = unmatched_dets[c]
240
- tentative_tracks[r].update(detections[det_idx])
241
- matched_det_indices.add(det_idx)
242
-
243
- # Mark unmatched tracks as missed
244
- for idx in unmatched_confirmed:
245
- confirmed_tracks[idx].mark_missed()
246
 
247
- for track in tentative_tracks:
248
- if track.time_since_update > 0:
249
- track.mark_missed()
250
-
251
- # Create new tracks for unmatched detections
252
- final_unmatched_dets = [i for i in range(len(detections))
253
- if i not in matched_det_indices]
 
 
 
 
 
 
 
 
 
254
 
255
- for det_idx in final_unmatched_dets:
256
- detection = detections[det_idx]
 
 
 
 
 
 
 
 
 
 
257
 
258
- # Additional check: ensure not too close to existing tracks
259
- is_new = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  det_center = self._get_center(detection.bbox)
261
 
262
  for track in self.tracks:
263
- if track.state != 'deleted':
264
- track_center = self._get_center(track.bbox)
265
- dist = np.linalg.norm(np.array(det_center) - np.array(track_center))
266
 
267
- if dist < 40: # Very close threshold
268
- is_new = False
269
- break
270
-
271
- if is_new:
272
- new_track = Track(detection, self.track_id_count)
273
- self.track_id_count += 1
274
- self.tracks.append(new_track)
275
-
276
- # Remove deleted tracks
277
- self.tracks = [t for t in self.tracks if t.state != 'deleted']
278
-
279
- # Return only confirmed tracks
280
- return [t for t in self.tracks if t.state == 'confirmed']
281
 
282
  def _calculate_enhanced_cost_matrix(self, tracks: List[Track],
283
  detections: List[Detection]) -> np.ndarray:
284
- """Calculate cost matrix with multiple cues"""
285
- if not tracks or not detections:
286
- return np.array([])
287
-
288
- n_tracks = len(tracks)
289
- n_dets = len(detections)
290
- cost_matrix = np.ones((n_tracks, n_dets))
291
-
292
- for t_idx, track in enumerate(tracks):
293
- track_center = np.array(self._get_center(track.bbox))
294
- track_size = np.array([
295
- track.bbox[2] - track.bbox[0],
296
- track.bbox[3] - track.bbox[1]
297
- ])
298
-
299
- for d_idx, detection in enumerate(detections):
300
- # IoU cost
301
- iou = self._iou(track.bbox, detection.bbox)
302
 
303
- # Center distance cost
304
- det_center = np.array(self._get_center(detection.bbox))
305
- distance = np.linalg.norm(track_center - det_center)
306
-
307
- # Size similarity cost
308
- det_size = np.array([
309
- detection.bbox[2] - detection.bbox[0],
310
- detection.bbox[3] - detection.bbox[1]
 
 
 
 
311
  ])
312
- size_sim = np.minimum(track_size, det_size) / (np.maximum(track_size, det_size) + 1e-6)
313
- size_cost = 1 - np.mean(size_sim)
314
 
315
- # Combine costs
316
- if iou >= self.min_iou_for_match and distance < self.max_center_distance:
317
- iou_cost = 1 - iou
318
- dist_cost = distance / self.max_center_distance
 
 
319
 
320
- # Weighted combination
321
- total_cost = (0.5 * iou_cost +
322
- 0.3 * dist_cost +
323
- 0.2 * size_cost)
324
 
325
- # Appearance cost if available
326
- if self.use_appearance and track.appearance_features and hasattr(detection, 'features'):
327
- # Simple cosine similarity
328
- track_feat = np.mean(track.appearance_features, axis=0)
329
- det_feat = detection.features
330
- app_sim = np.dot(track_feat, det_feat) / (
331
- np.linalg.norm(track_feat) * np.linalg.norm(det_feat) + 1e-6
332
- )
333
- app_cost = 1 - app_sim
334
- total_cost = 0.4 * iou_cost + 0.2 * dist_cost + 0.2 * size_cost + .2 * app_cost
335
 
336
- cost_matrix[t_idx, d_idx] = total_cost
337
- else:
338
- cost_matrix[t_idx, d_idx] = 1.0
339
 
340
- return cost_matrix
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
  def _get_center(self, bbox: List[float]) -> Tuple[float, float]:
343
- return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
 
 
 
 
 
 
344
 
345
  def _iou(self, bbox1: List[float], bbox2: List[float]) -> float:
346
- x1 = max(bbox1[0], bbox2[0])
347
- y1 = max(bbox1[1], bbox2[1])
348
- x2 = min(bbox1[2], bbox2[2])
349
- y2 = min(bbox1[3], bbox2[3])
350
-
351
- if x2 < x1 or y2 < y1:
352
- return 0.0
 
 
353
 
354
- intersection = (x2 - x1) * (y2 - y1)
355
- area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
356
- area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
357
- union = area1 + area2 - intersection
358
-
359
- return intersection / (union + 1e-6)
 
 
 
 
 
 
 
 
360
 
361
  def set_match_threshold(self, threshold: float):
362
  """Update matching threshold"""
363
  self.match_threshold = max(0.1, min(0.8, threshold))
364
  print(f"Tracking threshold updated to: {self.match_threshold:.2f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
 
367
- # Use the enhanced tracker as default
368
  SimpleTracker = RobustTracker
 
1
  """
2
+ tracking.py - Production-ready tracking with comprehensive error handling
3
+ Includes all bug fixes and defensive programming
4
  """
5
  import numpy as np
6
+ from typing import List, Optional, Tuple, Dict
7
  from scipy.optimize import linear_sum_assignment
8
  from collections import deque
9
  import uuid
10
  from detection import Detection
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
 
15
  class Track:
16
+ """Enhanced track with robust state management"""
17
 
18
  def __init__(self, detection: Detection, track_id: Optional[int] = None):
19
  """Initialize track from first detection"""
20
  self.track_id = track_id if track_id else self._generate_id()
21
+ self.bbox = detection.bbox.copy() if hasattr(detection, 'bbox') else [0, 0, 100, 100]
22
  self.detections = [detection]
23
+ self.confidence = detection.confidence if hasattr(detection, 'confidence') else 0.5
24
 
25
  # Track state
26
  self.age = 1
27
  self.time_since_update = 0
28
  self.state = 'tentative'
29
  self.hits = 1
30
+ self.consecutive_misses = 0
31
 
32
  # Store center points for trajectory
33
  cx = (self.bbox[0] + self.bbox[2]) / 2
 
46
 
47
  # Size tracking for scale changes
48
  self.sizes = deque(maxlen=10)
49
+ width = max(1, self.bbox[2] - self.bbox[0])
50
+ height = max(1, self.bbox[3] - self.bbox[1])
51
  self.sizes.append((width, height))
52
 
53
+ # Track quality metrics
54
+ self.avg_confidence = self.confidence
55
+ self.max_confidence = self.confidence
56
+
57
  def _generate_id(self) -> int:
58
  """Generate unique track ID"""
59
  return int(uuid.uuid4().int % 100000)
60
 
61
  def predict(self):
62
+ """Enhanced motion prediction with safety checks"""
63
  self.age += 1
64
  self.time_since_update += 1
65
+ self.consecutive_misses += 1
66
 
67
+ try:
68
+ if len(self.trajectory) >= 3:
69
+ # Calculate velocity and acceleration from recent positions
70
+ positions = np.array(list(self.trajectory))[-3:]
71
+
72
+ # Velocity from last two positions
73
+ self.velocity = positions[-1] - positions[-2]
74
+
75
+ # Limit velocity to reasonable values
76
+ max_velocity = 50 # pixels per frame
77
+ velocity_magnitude = np.linalg.norm(self.velocity)
78
+ if velocity_magnitude > max_velocity:
79
+ self.velocity = self.velocity / velocity_magnitude * max_velocity
80
+
81
+ # Acceleration from velocity change
82
+ if len(positions) == 3:
83
+ prev_velocity = positions[-2] - positions[-3]
84
+ self.acceleration = (self.velocity - prev_velocity) * 0.3
85
+
86
+ # Predict next position with damping
87
+ predicted_pos = positions[-1] + self.velocity * 0.7 + self.acceleration * 0.1
88
+
89
+ # Get average recent size for stable bbox
90
+ if self.sizes:
91
+ avg_width = np.mean([s[0] for s in self.sizes])
92
+ avg_height = np.mean([s[1] for s in self.sizes])
93
+ else:
94
+ avg_width = max(10, self.bbox[2] - self.bbox[0])
95
+ avg_height = max(10, self.bbox[3] - self.bbox[1])
96
+
97
+ # Update bbox with predicted center and smoothed size
98
+ self.bbox = [
99
+ predicted_pos[0] - avg_width/2,
100
+ predicted_pos[1] - avg_height/2,
101
+ predicted_pos[0] + avg_width/2,
102
+ predicted_pos[1] + avg_height/2
103
+ ]
104
+ except Exception as e:
105
+ # Fallback: Keep current bbox
106
+ print(f"Track prediction error: {e}")
107
+ pass
108
 
109
  def update(self, detection: Detection):
110
  """Update track with new detection"""
111
+ try:
112
+ # Update bbox
113
+ if hasattr(detection, 'bbox'):
114
+ self.bbox = detection.bbox.copy()
115
+
116
+ self.detections.append(detection)
117
+
118
+ # Update confidence
119
+ if hasattr(detection, 'confidence'):
120
+ self.confidence = detection.confidence
121
+ self.avg_confidence = (self.avg_confidence * 0.9 + self.confidence * 0.1)
122
+ self.max_confidence = max(self.max_confidence, self.confidence)
123
+
124
+ self.hits += 1
125
+ self.time_since_update = 0
126
+ self.consecutive_misses = 0
127
+
128
+ # Update trajectory
129
+ cx = (self.bbox[0] + self.bbox[2]) / 2
130
+ cy = (self.bbox[1] + self.bbox[3]) / 2
131
+ self.trajectory.append((cx, cy))
 
 
 
 
132
 
133
+ # Update size history
134
+ width = max(1, self.bbox[2] - self.bbox[0])
135
+ height = max(1, self.bbox[3] - self.bbox[1])
136
+ self.sizes.append((width, height))
137
+
138
+ # Store appearance features if available
139
+ if hasattr(detection, 'features'):
140
+ self.appearance_features.append(detection.features)
141
+ if len(self.appearance_features) > 5:
142
+ self.appearance_features = self.appearance_features[-5:]
143
+
144
+ # Confirm track after 2 hits
145
+ if self.state == 'tentative' and self.hits >= 2:
146
+ self.state = 'confirmed'
147
+
148
+ # Keep only recent detections to save memory
149
+ if len(self.detections) > 5:
150
+ # Clear old detection images to save memory
151
+ for old_det in self.detections[:-5]:
152
+ if hasattr(old_det, 'image_crop'):
153
+ old_det.image_crop = None
154
+ self.detections = self.detections[-5:]
155
+
156
+ except Exception as e:
157
+ print(f"Track update error: {e}")
158
 
159
  def mark_missed(self):
160
  """Mark track as missed in current frame"""
161
+ if self.state == 'confirmed':
162
+ # More lenient deletion criteria
163
+ if self.consecutive_misses > 15:
164
+ self.state = 'deleted'
165
+ elif self.time_since_update > 30:
166
+ self.state = 'deleted'
167
+ elif self.state == 'tentative':
168
+ if self.consecutive_misses > 3:
169
+ self.state = 'deleted'
170
 
171
 
172
  class RobustTracker:
173
  """
174
+ Production-ready tracker with comprehensive error handling
175
  """
176
 
177
  def __init__(self,
178
+ match_threshold: float = 0.35,
179
  track_buffer: int = 30,
180
+ min_iou_for_match: float = 0.15,
181
+ use_appearance: bool = False):
182
  """
183
+ Initialize tracker with safe defaults
184
 
185
  Args:
186
+ match_threshold: IoU threshold for matching (0.35 is balanced)
187
  track_buffer: Frames to keep lost tracks
188
  min_iou_for_match: Minimum IoU to consider a match
189
+ use_appearance: Whether to use appearance features (set False for speed)
190
  """
191
  self.match_threshold = match_threshold
192
  self.track_buffer = track_buffer
 
197
  self.track_id_count = 1
198
 
199
  # Enhanced parameters
200
+ self.max_center_distance = 150 # pixels (reduced for stricter matching)
201
+ self.min_size_similarity = 0.4 # Size change threshold
202
+
203
+ # Debug mode
204
+ self.debug = False
205
 
206
  def update(self, detections: List[Detection]) -> List[Track]:
207
  """
208
+ Update tracks with robust error handling
209
  """
210
+ if not detections:
211
+ # No detections - just predict existing tracks
212
+ for track in self.tracks:
213
+ track.predict()
214
+ track.mark_missed()
215
 
216
+ # Remove deleted tracks
217
+ self.tracks = [t for t in self.tracks if t.state != 'deleted']
218
+ return [t for t in self.tracks if t.state == 'confirmed']
 
 
 
 
 
 
 
219
 
220
+ try:
221
+ # Predict existing tracks
222
+ for track in self.tracks:
223
+ track.predict()
 
 
 
 
 
224
 
225
+ # Split tracks by state
226
+ confirmed_tracks = [t for t in self.tracks if t.state == 'confirmed']
227
+ tentative_tracks = [t for t in self.tracks if t.state == 'tentative']
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
+ # Initialize matched indices
230
+ matched_track_indices = set()
231
+ matched_det_indices = set()
232
 
233
+ # Stage 1: Match confirmed tracks with all detections
234
+ if confirmed_tracks and detections:
235
+ matched_track_indices, matched_det_indices = self._associate_tracks(
236
+ confirmed_tracks, detections,
237
+ matched_track_indices, matched_det_indices,
238
+ threshold_mult=1.0
239
+ )
240
+
241
+ # Stage 2: Match tentative tracks with unmatched detections
242
+ if tentative_tracks:
243
+ unmatched_dets = [detections[i] for i in range(len(detections))
244
+ if i not in matched_det_indices]
245
 
246
+ if unmatched_dets:
247
+ # Create temporary mapping
248
+ temp_det_mapping = [i for i in range(len(detections))
249
+ if i not in matched_det_indices]
250
+
251
+ tent_matched_tracks, tent_matched_dets = self._associate_tracks(
252
+ tentative_tracks, unmatched_dets,
253
+ set(), set(),
254
+ threshold_mult=0.7
255
+ )
256
+
257
+ # Map back to original detection indices
258
+ for det_idx in tent_matched_dets:
259
+ matched_det_indices.add(temp_det_mapping[det_idx])
 
 
 
 
260
 
261
+ # Mark unmatched tracks as missed
262
+ for i, track in enumerate(confirmed_tracks):
263
+ if i not in matched_track_indices:
264
+ track.mark_missed()
265
+
266
+ for track in tentative_tracks:
267
+ if track.time_since_update > 0:
268
+ track.mark_missed()
269
 
270
+ # Create new tracks for unmatched detections
271
+ for det_idx in range(len(detections)):
272
+ if det_idx not in matched_det_indices:
273
+ detection = detections[det_idx]
274
+
275
+ # Check if detection is too close to existing tracks
276
+ if self._is_new_track(detection):
277
+ new_track = Track(detection, self.track_id_count)
278
+ self.track_id_count += 1
279
+ self.tracks.append(new_track)
 
 
280
 
281
+ # Remove deleted tracks
282
+ self.tracks = [t for t in self.tracks if t.state != 'deleted']
283
+
284
+ # Return only confirmed tracks
285
+ return [t for t in self.tracks if t.state == 'confirmed']
286
+
287
+ except Exception as e:
288
+ print(f"Tracker update error: {e}")
289
+ # Return existing confirmed tracks as fallback
290
+ return [t for t in self.tracks if t.state == 'confirmed']
291
+
292
+ def _associate_tracks(self, tracks: List[Track], detections: List[Detection],
293
+ existing_matched_tracks: set, existing_matched_dets: set,
294
+ threshold_mult: float = 1.0) -> Tuple[set, set]:
295
+ """
296
+ Safe track-detection association
297
 
298
+ Returns:
299
+ (matched_track_indices, matched_det_indices)
300
+ """
301
+ if not tracks or not detections:
302
+ return existing_matched_tracks, existing_matched_dets
303
+
304
+ try:
305
+ # Calculate cost matrix
306
+ cost_matrix = self._calculate_enhanced_cost_matrix(tracks, detections)
307
+
308
+ if cost_matrix.size == 0:
309
+ return existing_matched_tracks, existing_matched_dets
310
 
311
+ # Hungarian matching
312
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
313
+
314
+ matched_tracks = existing_matched_tracks.copy()
315
+ matched_dets = existing_matched_dets.copy()
316
+
317
+ # Process matches
318
+ for r, c in zip(row_ind, col_ind):
319
+ # Check bounds
320
+ if r >= len(tracks) or c >= len(detections):
321
+ continue
322
+
323
+ # Check cost threshold
324
+ threshold = (1 - self.match_threshold * threshold_mult)
325
+ if cost_matrix[r, c] < threshold:
326
+ tracks[r].update(detections[c])
327
+ matched_tracks.add(r)
328
+ matched_dets.add(c)
329
+
330
+ return matched_tracks, matched_dets
331
+
332
+ except Exception as e:
333
+ print(f"Association error: {e}")
334
+ return existing_matched_tracks, existing_matched_dets
335
+
336
+ def _is_new_track(self, detection: Detection) -> bool:
337
+ """Check if detection represents a new track"""
338
+ try:
339
  det_center = self._get_center(detection.bbox)
340
 
341
  for track in self.tracks:
342
+ if track.state == 'deleted':
343
+ continue
 
344
 
345
+ track_center = self._get_center(track.bbox)
346
+ dist = np.linalg.norm(np.array(det_center) - np.array(track_center))
347
+
348
+ # Very close to existing track - likely same object
349
+ if dist < 30:
350
+ return False
351
+
352
+ return True
353
+
354
+ except Exception as e:
355
+ print(f"New track check error: {e}")
356
+ return True # Default to creating new track
 
 
357
 
358
  def _calculate_enhanced_cost_matrix(self, tracks: List[Track],
359
  detections: List[Detection]) -> np.ndarray:
360
+ """Calculate cost matrix with error handling"""
361
+ try:
362
+ if not tracks or not detections:
363
+ return np.array([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
+ n_tracks = len(tracks)
366
+ n_dets = len(detections)
367
+ cost_matrix = np.ones((n_tracks, n_dets))
368
+
369
+ for t_idx, track in enumerate(tracks):
370
+ if not hasattr(track, 'bbox') or len(track.bbox) != 4:
371
+ continue
372
+
373
+ track_center = np.array(self._get_center(track.bbox))
374
+ track_size = np.array([
375
+ max(1, track.bbox[2] - track.bbox[0]),
376
+ max(1, track.bbox[3] - track.bbox[1])
377
  ])
 
 
378
 
379
+ for d_idx, detection in enumerate(detections):
380
+ if not hasattr(detection, 'bbox') or len(detection.bbox) != 4:
381
+ continue
382
+
383
+ # IoU cost
384
+ iou = self._iou(track.bbox, detection.bbox)
385
 
386
+ # Center distance cost
387
+ det_center = np.array(self._get_center(detection.bbox))
388
+ distance = np.linalg.norm(track_center - det_center)
 
389
 
390
+ # Size similarity cost
391
+ det_size = np.array([
392
+ max(1, detection.bbox[2] - detection.bbox[0]),
393
+ max(1, detection.bbox[3] - detection.bbox[1])
394
+ ])
 
 
 
 
 
395
 
396
+ # Prevent division by zero
397
+ size_ratio = np.minimum(track_size, det_size) / (np.maximum(track_size, det_size) + 1e-6)
398
+ size_cost = 1 - np.mean(size_ratio)
399
 
400
+ # Check basic constraints
401
+ if iou >= self.min_iou_for_match and distance < self.max_center_distance:
402
+ iou_cost = 1 - iou
403
+ dist_cost = distance / self.max_center_distance
404
+
405
+ # Weighted combination (IoU is most important)
406
+ total_cost = (0.6 * iou_cost +
407
+ 0.25 * dist_cost +
408
+ 0.15 * size_cost)
409
+
410
+ # Add appearance cost if available and enabled
411
+ if (self.use_appearance and
412
+ hasattr(track, 'appearance_features') and
413
+ track.appearance_features and
414
+ hasattr(detection, 'features')):
415
+ try:
416
+ track_feat = np.mean(track.appearance_features, axis=0)
417
+ det_feat = detection.features
418
+
419
+ # Cosine similarity
420
+ feat_norm = np.linalg.norm(track_feat) * np.linalg.norm(det_feat)
421
+ if feat_norm > 0:
422
+ app_sim = np.dot(track_feat, det_feat) / feat_norm
423
+ app_cost = 1 - max(0, min(1, app_sim))
424
+ total_cost = (0.5 * iou_cost + 0.2 * dist_cost +
425
+ 0.15 * size_cost + 0.15 * app_cost)
426
+ except:
427
+ pass # Use cost without appearance
428
+
429
+ cost_matrix[t_idx, d_idx] = total_cost
430
+ else:
431
+ cost_matrix[t_idx, d_idx] = 1.0
432
+
433
+ return cost_matrix
434
+
435
+ except Exception as e:
436
+ print(f"Cost matrix calculation error: {e}")
437
+ # Return high cost matrix as fallback
438
+ return np.ones((len(tracks), len(detections)))
439
 
440
  def _get_center(self, bbox: List[float]) -> Tuple[float, float]:
441
+ """Get bbox center with validation"""
442
+ try:
443
+ if len(bbox) >= 4:
444
+ return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
445
+ return (0, 0)
446
+ except:
447
+ return (0, 0)
448
 
449
  def _iou(self, bbox1: List[float], bbox2: List[float]) -> float:
450
+ """Calculate IoU with validation"""
451
+ try:
452
+ if len(bbox1) < 4 or len(bbox2) < 4:
453
+ return 0.0
454
+
455
+ x1 = max(bbox1[0], bbox2[0])
456
+ y1 = max(bbox1[1], bbox2[1])
457
+ x2 = min(bbox1[2], bbox2[2])
458
+ y2 = min(bbox1[3], bbox2[3])
459
 
460
+ if x2 < x1 or y2 < y1:
461
+ return 0.0
462
+
463
+ intersection = (x2 - x1) * (y2 - y1)
464
+
465
+ area1 = max(1, (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]))
466
+ area2 = max(1, (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]))
467
+ union = area1 + area2 - intersection
468
+
469
+ return max(0, min(1, intersection / (union + 1e-6)))
470
+
471
+ except Exception as e:
472
+ print(f"IoU calculation error: {e}")
473
+ return 0.0
474
 
475
  def set_match_threshold(self, threshold: float):
476
  """Update matching threshold"""
477
  self.match_threshold = max(0.1, min(0.8, threshold))
478
  print(f"Tracking threshold updated to: {self.match_threshold:.2f}")
479
+
480
+ def reset(self):
481
+ """Reset tracker state"""
482
+ self.tracks.clear()
483
+ self.track_id_count = 1
484
+ print("Tracker reset")
485
+
486
+ def get_statistics(self) -> Dict:
487
+ """Get tracker statistics"""
488
+ confirmed = len([t for t in self.tracks if t.state == 'confirmed'])
489
+ tentative = len([t for t in self.tracks if t.state == 'tentative'])
490
+
491
+ return {
492
+ 'total_tracks': len(self.tracks),
493
+ 'confirmed_tracks': confirmed,
494
+ 'tentative_tracks': tentative,
495
+ 'next_id': self.track_id_count
496
+ }
497
 
498
 
499
+ # Compatibility alias
500
  SimpleTracker = RobustTracker