AI_Safety_Demo5 / tracker.py
neerajkalyank's picture
Update tracker.py
6251904 verified
import numpy as np
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment
import logging
# Module-level logger; BYTETracker.__init__ reports its configuration through it.
logger = logging.getLogger(__name__)
class BYTETracker:
    """IoU-based multi-object tracker with Kalman-predicted box centres.

    Boxes are ``(cx, cy, w, h)`` sequences: centre x/y plus width/height
    (see ``_calculate_iou``). Each live track owns a constant-velocity Kalman
    filter over its centre; width/height come from the last matched detection.

    Association is a single Hungarian IoU matching pass over *all* detections,
    whatever their score. Only unmatched detections scoring at least
    ``track_thresh`` may start a new track or revive a removed one.

    Tracks unmatched for longer than ``track_buffer / frame_rate`` time units
    are moved to ``recently_removed``. A later unmatched detection close to a
    removed track's last position (``_is_same_worker``) revives that track
    under its old id.
    """

    def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5,
                 frame_rate=30, reid_window=None):
        """Configure the tracker.

        Args:
            track_thresh: minimum score for an unmatched detection to start
                or revive a track. Matching to existing tracks ignores it.
            track_buffer: number of frames a track may stay unmatched before
                it is pruned; converted to time as ``track_buffer / frame_rate``.
            match_thresh: minimum IoU for a detection/track pair to match.
            frame_rate: frames per unit of ``current_time`` (presumably
                seconds -- confirm against the caller).
            reid_window: how long, in ``current_time`` units after removal,
                a removed track stays eligible for re-identification.
                ``None`` (default) keeps removed tracks indefinitely; set it
                on long-running streams so ``recently_removed`` stays bounded.
        """
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        self.frame_rate = frame_rate
        self.reid_window = reid_window
        self.next_id = 1
        self.tracks = {}            # id -> {'bbox', 'score', 'cls', 'last_seen'}
        self.kalman_filters = {}    # id -> KalmanFilter, state (cx, cy, vx, vy)
        self.worker_history = {}    # id -> list of centres seen for the track
        self.last_positions = {}    # id -> last observed centre
        self.recently_removed = {}  # id -> {'bbox', 'last_seen', 'last_position'}
        logger.info("BYTETracker initialized with track_thresh=%.2f, track_buffer=%d, match_thresh=%.2f",
                    track_thresh, track_buffer, match_thresh)

    def _init_kalman(self):
        """Build a 4-state / 2-measurement constant-velocity Kalman filter.

        The state is ``(cx, cy, vx, vy)``. Velocity is added to position once
        per ``predict()``, i.e. once per ``update`` call, not per time unit.
        Only ``(cx, cy)`` is observed.
        """
        kf = KalmanFilter(dim_x=4, dim_z=2)
        kf.F = np.array([[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]])
        kf.H = np.array([[1, 0, 0, 0], [0, 1, 0, 0]])
        kf.P *= 1000.  # large initial uncertainty: the first measurements dominate
        kf.R = np.array([[10, 0], [0, 10]])
        kf.Q = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
        return kf

    def update(self, dets, scores, cls, current_time):
        """Advance the tracker by one frame.

        Args:
            dets: sequence of ``(cx, cy, w, h)`` boxes (array rows or lists).
            scores: per-detection confidence, aligned with ``dets``.
            cls: per-detection class label, aligned with ``dets``.
            current_time: timestamp of this frame, in the same unit as
                ``track_buffer / frame_rate``.

        Returns:
            A list of ``{'id', 'bbox', 'score', 'cls'}`` dicts, one for each
            detection that matched, revived or started a track this frame.
            ``'bbox'`` is the caller's detection object. The tracker keeps its
            own copy, so later frames never mutate it. Unmatched detections
            below ``track_thresh`` are dropped.
        """
        self._prune_stale(current_time)
        self._predict()
        tracks = []
        if len(dets) == 0:
            return tracks

        # Hungarian assignment on (1 - IoU) against the *predicted* boxes.
        matched = set()
        track_ids = list(self.tracks.keys())
        if track_ids:
            cost_matrix = np.zeros((len(dets), len(track_ids)))
            for i, det in enumerate(dets):
                for j, track_id in enumerate(track_ids):
                    cost_matrix[i, j] = 1 - self._calculate_iou(det, self.tracks[track_id]['bbox'])
            row_indices, col_indices = linear_sum_assignment(cost_matrix)
            for i, j in zip(row_indices, col_indices):
                # cost < 1 - match_thresh  <=>  IoU > match_thresh
                if cost_matrix[i, j] >= 1 - self.match_thresh:
                    continue
                track_id = track_ids[j]
                box = self._as_box(dets[i])
                self.tracks[track_id].update({
                    'bbox': box,
                    'score': scores[i],
                    'cls': cls[i],
                    'last_seen': current_time
                })
                self.kalman_filters[track_id].update(box[:2])
                centre = box[:2].copy()
                self.worker_history[track_id].append(centre)
                self.last_positions[track_id] = centre
                tracks.append({'id': track_id, 'bbox': dets[i], 'score': scores[i], 'cls': cls[i]})
                matched.add(i)

        # Unmatched, confident detections revive a nearby removed track or start a new one.
        for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
            if i in matched or score < self.track_thresh:
                continue
            box = self._as_box(det)
            track_id = self._find_reid_candidate(box[:2])
            if track_id is None:
                track_id = self.next_id
                self.next_id += 1
            else:
                del self.recently_removed[track_id]
            self._start_track(track_id, box, score, cl, current_time)
            tracks.append({'id': track_id, 'bbox': det, 'score': score, 'cls': cl})
        return tracks

    def _prune_stale(self, current_time):
        """Move tracks unmatched for too long into ``recently_removed``.

        Also drops removed tracks older than ``reid_window`` when it is set.
        """
        max_age = self.track_buffer / self.frame_rate
        for track_id in list(self.tracks.keys()):
            if current_time - self.tracks[track_id]['last_seen'] <= max_age:
                continue
            self.recently_removed[track_id] = {
                'bbox': self.tracks[track_id]['bbox'],
                'last_seen': current_time,  # time of removal, not of last match
                'last_position': self.last_positions.get(track_id, [0, 0])
            }
            del self.tracks[track_id]
            del self.kalman_filters[track_id]
            self.worker_history.pop(track_id, None)
            self.last_positions.pop(track_id, None)
        if self.reid_window is not None:
            expired = [tid for tid, info in self.recently_removed.items()
                       if current_time - info['last_seen'] > self.reid_window]
            for tid in expired:
                del self.recently_removed[tid]

    def _predict(self):
        """Step every live Kalman filter and move each box centre to its prediction."""
        for track_id, kf in self.kalman_filters.items():
            kf.predict()
            # In-place write is safe: 'bbox' is a private float copy (see _as_box).
            self.tracks[track_id]['bbox'][:2] = kf.x[:2].flatten()

    @staticmethod
    def _as_box(det):
        """Return a private float64 copy of ``det`` so tracker state never aliases caller data."""
        return np.array(det, dtype=float)

    def _start_track(self, track_id, box, score, cl, current_time):
        """Create (or re-create) the live state of ``track_id`` at ``box``."""
        self.tracks[track_id] = {
            'bbox': box,
            'score': score,
            'cls': cl,
            'last_seen': current_time
        }
        kf = self._init_kalman()
        kf.x[:2] = box[:2].reshape(2, 1)
        self.kalman_filters[track_id] = kf
        centre = box[:2].copy()
        self.worker_history[track_id] = [centre]
        self.last_positions[track_id] = centre

    def _find_reid_candidate(self, position):
        """Return the id of the closest removed track near ``position``, or ``None``.

        Only tracks accepted by ``_is_same_worker`` qualify. Ties go to the
        earliest-removed track.
        """
        pos = np.asarray(position, dtype=float)
        best_id, best_dist = None, np.inf
        for track_id, info in self.recently_removed.items():
            if not self._is_same_worker(position, info['last_position']):
                continue
            dist = np.linalg.norm(pos - np.asarray(info['last_position'], dtype=float))
            if dist < best_dist:
                best_id, best_dist = track_id, dist
        return best_id

    def _calculate_iou(self, box1, box2):
        """Return the intersection-over-union of two ``(cx, cy, w, h)`` boxes.

        Returns 0.0 for disjoint boxes and for degenerate pairs whose union
        area is not positive (e.g. two zero-size boxes).
        """
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2
        x_left = max(x1 - w1/2, x2 - w2/2)
        y_top = max(y1 - h1/2, y2 - h2/2)
        x_right = min(x1 + w1/2, x2 + w2/2)
        y_bottom = min(y1 + h1/2, y2 + h2/2)
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection = (x_right - x_left) * (y_bottom - y_top)
        union = w1 * h1 + w2 * h2 - intersection
        if union <= 0:
            return 0.0
        return intersection / union

    def _is_same_worker(self, pos1, pos2, threshold=150):
        """Return True if two centres are closer than ``threshold`` (Euclidean).

        ``threshold`` is in box-coordinate units (presumably pixels).
        """
        return np.sqrt(np.sum((np.array(pos1) - np.array(pos2))**2)) < threshold