# NOTE: HuggingFace Spaces page residue ("Spaces: Sleeping") removed from the top of this file.
import numpy as np
import scipy.linalg
from scipy.optimize import linear_sum_assignment
class KalmanFilter:
    """
    A simple Kalman Filter for tracking bounding boxes in image space.

    The 8-dimensional state space is (x, y, a, h, vx, vy, va, vh), where
    (x, y) is the box center position, a is the aspect ratio (width / height),
    h is the height, and (vx, vy, va, vh) are the respective velocities.
    Motion follows a constant-velocity model; (x, y, a, h) is observed
    directly.
    """

    def __init__(self):
        # 4 observed dimensions (x, y, a, h) plus their velocities; dt is the
        # fixed time step between consecutive frames.
        ndim, dt = 4, 1.0
        # Create Kalman filter model matrices.
        # Motion matrix F: identity plus dt on the position->velocity band
        # (constant-velocity model).
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        # Observation matrix H: selects the first 4 (positional) state dims.
        self._update_mat = np.eye(ndim, 2 * ndim)
        # Motion and observation uncertainty are chosen relative to the current
        # state estimate. These weights control the amount of uncertainty in
        # the model. This is a bit heuristic.
        self._std_weight_position = 1.0 / 20
        self._std_weight_velocity = 1.0 / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.

        Parameters
        ----------
        measurement : ndarray
            Bounding box coordinates (x1, y1, x2, y2); only the first four
            values are read, so trailing entries (e.g. a confidence score)
            are ignored. Converted internally to (x, y, a, h).

        Returns
        -------
        (mean, covariance)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8)
            of the new track. Unobserved velocities start at mean 0.
        """
        mean_pos = self._xyah_from_xyxy(measurement)
        mean = np.r_[mean_pos, np.zeros_like(mean_pos)]
        # Initial std devs scale with box height; aspect ratio gets small
        # fixed uncertainties (1e-2 position, 1e-5 velocity).
        std = [
            2 * self._std_weight_position * mean_pos[3],
            2 * self._std_weight_position * mean_pos[3],
            1e-2,
            2 * self._std_weight_position * mean_pos[3],
            10 * self._std_weight_velocity * mean_pos[3],
            10 * self._std_weight_velocity * mean_pos[3],
            1e-5,
            10 * self._std_weight_velocity * mean_pos[3],
        ]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.

        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (mean, covariance)
            Returns the mean vector and covariance matrix of the predicted
            state.
        """
        # Process noise Q scales with the current box height (mean[3]).
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3],
        ]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3],
        ]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
        # x' = F x ;  P' = F P F^T + Q
        mean = np.dot(self._motion_mat, mean)
        covariance = (
            np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T))
            + motion_cov
        )
        return mean, covariance

    def project(self, mean, covariance):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).

        Returns
        -------
        (mean, covariance)
            Returns the projected mean (4d) and covariance (4x4, including
            observation noise R) of the given state estimate.
        """
        # Observation noise R also scales with box height.
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3],
        ]
        innovation_cov = np.diag(np.square(std))
        # z = H x ;  S = H P H^T + R
        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def update(self, mean, covariance, measurement):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
            is the center position, a the aspect ratio, and h the height.

        Returns
        -------
        (mean, covariance)
            Returns the measurement-corrected state distribution.
        """
        projected_mean, projected_cov = self.project(mean, covariance)
        # Kalman gain K = P H^T S^-1, computed via a Cholesky solve instead of
        # an explicit matrix inverse for numerical stability.
        chol_factor, lower = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False
        )
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower),
            np.dot(covariance, self._update_mat.T).T,
            check_finite=False,
        ).T
        innovation = measurement - projected_mean
        # x = x + K y ;  P = P - K S K^T
        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot(
            (kalman_gain, projected_cov, kalman_gain.T)
        )
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements, only_position=False, metric="mahalanobis"):
        """Compute gating distance between state distribution and measurements.

        Parameters
        ----------
        mean : ndarray
            State mean vector (8 dimensional).
        covariance : ndarray
            State covariance matrix (8x8).
        measurements : ndarray
            Nx4 matrix of measurements in (x, y, a, h) format.
        only_position : bool
            If True, only the box centers (x, y) enter the distance.
        metric : str
            "gaussian" for squared Euclidean distance, "mahalanobis" for
            squared Mahalanobis distance.

        Returns
        -------
        ndarray
            Length-N array of squared distances.

        Raises
        ------
        ValueError
            If `metric` is not "gaussian" or "mahalanobis".
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]
        d = measurements - mean
        if metric == "gaussian":
            return np.sum(d * d, axis=1)
        elif metric == "mahalanobis":
            # Solve L z = d with L the Cholesky factor of S, so that
            # sum(z^2) = d^T S^-1 d.
            cholesky_factor = np.linalg.cholesky(covariance)
            z = scipy.linalg.solve_triangular(
                cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True
            )
            squared_maha = np.sum(z * z, axis=0)
            return squared_maha
        else:
            raise ValueError("invalid distance metric")

    def _xyah_from_xyxy(self, xyxy):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.

        NOTE(review): assumes the box height is non-zero; a degenerate box
        would divide by zero here — confirm upstream detections guarantee it.
        """
        bbox = np.asarray(xyxy).copy()
        cx = (bbox[0] + bbox[2]) / 2.0
        cy = (bbox[1] + bbox[3]) / 2.0
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
        ret = np.zeros(4, dtype=bbox.dtype)
        ret[0] = cx
        ret[1] = cy
        ret[2] = w / h
        ret[3] = h
        return ret
class STrack:
    """
    Single object track: a wrapper around one KalmanFilter state plus
    bookkeeping (id, score, label, frame counters, persistent GPT data).

    BUGFIX: `tlwh` and `tlbr` are now properties — the rest of this file
    accesses them as attributes (`self.tlwh.copy()`, `track.tlbr`), which
    previously returned bound methods and crashed. `next_id` is now a
    @staticmethod — it was defined without `self`, so `self.next_id()`
    raised TypeError.
    """

    def __init__(self, tlwh, score, label):
        # NOTE: despite the parameter name (kept for backward compatibility),
        # the incoming box is in xyxy format; it is converted to tlwh here.
        self._tlwh = np.asarray(self._tlwh_from_xyxy(tlwh), dtype=np.float32)
        self.is_activated = False
        self.track_id = 0
        self.state = 1  # 1: New, 2: Tracked, 3: Lost, 4: Removed
        self.score = score
        self.label = label
        self.start_frame = 0
        self.frame_id = 0
        self.time_since_update = 0
        # Multi-frame history
        self.history = []
        # Kalman Filter state (set on activate)
        self.kalman_filter = None
        self.mean = None
        self.covariance = None
        # GPT attributes (persistent across frames)
        self.gpt_data = {}

    def _tlwh_from_xyxy(self, xyxy):
        """Convert xyxy to tlwh."""
        w = xyxy[2] - xyxy[0]
        h = xyxy[3] - xyxy[1]
        return [xyxy[0], xyxy[1], w, h]

    def _xyxy_from_tlwh(self, tlwh):
        """Convert tlwh to xyxy."""
        x1 = tlwh[0]
        y1 = tlwh[1]
        x2 = x1 + tlwh[2]
        y2 = y1 + tlwh[3]
        return [x1, y1, x2, y2]

    @property
    def tlwh(self):
        """Current box as `(top left x, top left y, width, height)`.

        Falls back to the initial detection box until the Kalman state exists.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]        # aspect * height -> width
        ret[:2] -= ret[2:] / 2  # center -> top-left
        return ret

    @property
    def tlbr(self):
        """Current box as `(min x, min y, max x, max y)`."""
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    def activate(self, kalman_filter, frame_id):
        """Start a new track from this detection."""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        # initiate() expects an xyxy box.
        self.mean, self.covariance = self.kalman_filter.initiate(self.tlbr)
        self.state = 2  # Tracked
        self.frame_id = frame_id
        self.start_frame = frame_id
        self.is_activated = True

    def re_activate(self, new_track, frame_id, new_id=False):
        """Reactivate a lost track with a new detection."""
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self._xyah_from_xyxy(new_track.tlbr)
        )
        self.time_since_update = 0
        self.state = 2  # Tracked
        self.frame_id = frame_id
        self.score = new_track.score
        if new_id:
            self.track_id = self.next_id()

    def update(self, new_track, frame_id):
        """Update a tracked object with a new detection."""
        self.frame_id = frame_id
        self.time_since_update = 0
        self.score = new_track.score
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self._xyah_from_xyxy(new_track.tlbr)
        )
        self.state = 2  # Tracked
        self.is_activated = True

    def predict(self):
        """Propagate the state distribution one time step forward."""
        if self.mean is None:
            return
        self.mean, self.covariance = self.kalman_filter.predict(self.mean, self.covariance)

    def _xyah_from_xyxy(self, xyxy):
        """Convert an xyxy box to the (x, y, a, h) measurement format."""
        bbox = np.asarray(xyxy).copy()
        cx = (bbox[0] + bbox[2]) / 2.0
        cy = (bbox[1] + bbox[3]) / 2.0
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
        ret = np.zeros(4, dtype=bbox.dtype)
        ret[0] = cx
        ret[1] = cy
        ret[2] = w / h
        ret[3] = h
        return ret

    @staticmethod
    def next_id():
        """Return the next globally unique track id (class-level counter)."""
        if not hasattr(STrack, "_count"):
            STrack._count = 0
        STrack._count += 1
        return STrack._count
class ByteTracker:
    """ByteTrack-style two-stage tracker operating on per-frame detections.

    High-score detections are associated with existing tracks first; still
    unmatched tracks then get a second chance against low-score detections
    before being marked lost.
    """

    def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
        # Score threshold splitting detections into high/low-confidence
        # association stages; also gates creation of new tracks.
        self.track_thresh = track_thresh
        # Number of frames a lost track is retained before removal.
        self.track_buffer = track_buffer
        # Cost threshold for the first (high-score) association.
        self.match_thresh = match_thresh
        # NOTE(review): frame_rate is accepted but never used in this class.
        self.frame_id = 0
        self.tracked_stracks = []  # Type: List[STrack]
        self.lost_stracks = []  # Type: List[STrack]
        self.removed_stracks = []  # Type: List[STrack]
        self.kalman_filter = KalmanFilter()

    def update(self, detections_list):
        """
        Update the tracker with a list of detections for one frame.

        Args:
            detections_list: List of dicts, each having:
                - bbox: [x1, y1, x2, y2]
                - score: float
                - label: str
                - (optional) other keys preserved
        Returns:
            List of dicts (one per currently active track) with 'track_id'
            added, 'bbox' replaced by the Kalman-smoothed box, persistent
            GPT attributes forward-filled, and a rolling 'history' of recent
            boxes attached.
        """
        self.frame_id += 1
        # 0. Wrap raw detections in STrack objects, keeping a link back to
        # the original dict so output can preserve caller-supplied keys.
        activated_stracks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []
        # NOTE(review): scores/bboxes are computed but never used below.
        scores = [d['score'] for d in detections_list]
        bboxes = [d['bbox'] for d in detections_list]
        # Split into high and low confidence sets for the two passes.
        detections = []
        detections_second = []
        for d in detections_list:
            bbox = d['bbox']
            score = d['score']
            label = d['label']
            t = STrack(bbox, score, label)
            t.original_data = d  # link back to the caller's dict
            if score >= self.track_thresh:
                detections.append(t)
            else:
                detections_second.append(t)
        # 1. Prediction. Split existing tracks into confirmed / unconfirmed.
        # NOTE(review): `unconfirmed` is collected but never matched or
        # removed below; since activate() sets is_activated immediately,
        # it appears to stay empty in practice — confirm.
        unconfirmed = []
        tracked_stracks = []  # Type: List[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)
        strack_pool = join_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location of every pooled track with the KF.
        STrack.multi_predict(strack_pool, self.kalman_filter)
        # 2. First association: high-score detections, IoU cost fused with
        # detection confidence (standard ByteTrack behaviour).
        dists = iou_distance(strack_pool, detections)
        dists = fuse_score(dists, detections)
        matches, u_track, u_detection = linear_assignment(dists, thresh=self.match_thresh)
        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == 2:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                # Previously lost track matched again -> re-activate.
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
            # Persist GPT attributes from the matched detection.
            self._sync_data(track, det)
        # 3. Second association: remaining *tracked* tracks vs low-score
        # detections, with a fixed 0.5 IoU-cost threshold.
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == 2]
        dists = iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, u_detection_second = linear_assignment(dists, thresh=0.5)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == 2:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
            self._sync_data(track, det)
        # Tracks unmatched after both passes become lost.
        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == 3:  # if not already lost
                track.state = 3  # Lost
                lost_stracks.append(track)
        # 4. Init new tracks from unmatched high-score detections.
        # Unmatched low-score detections are discarded as noise.
        unmatched_dets = [detections[i] for i in u_detection]
        for track in unmatched_dets:
            if track.score < self.track_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_stracks.append(track)
            self._sync_data(track, track)  # sync from its own source dict
        # 5. Rebuild the tracker's canonical track lists.
        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == 2]
        self.tracked_stracks = join_stracks(self.tracked_stracks, activated_stracks)
        self.tracked_stracks = join_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        # NOTE(review): the local removed_stracks list is never populated, so
        # this extend is currently a no-op — confirm intended.
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
        # 6. Age out lost tracks beyond the buffer window.
        for track in self.lost_stracks:
            if self.frame_id - track.frame_id > self.track_buffer:
                self.removed_stracks.append(track)
        self.lost_stracks = [t for t in self.lost_stracks if self.frame_id - t.frame_id <= self.track_buffer]
        # 7. Output construction: one dict per active track, based on the
        # original detection dict (when available) with the tracked box and
        # a formatted track id.
        output_stracks = [t for t in self.tracked_stracks if t.is_activated]
        results = []
        for track in output_stracks:
            d_out = track.original_data.copy() if hasattr(track, 'original_data') else {}
            # Use the Kalman-smoothed box rather than the raw detection box.
            # NOTE(review): relies on STrack.tlbr being a property; as written
            # in STrack it is a plain method — confirm.
            tracked_bbox = track.tlbr
            d_out['bbox'] = [float(x) for x in tracked_bbox]
            d_out['track_id'] = f"T{str(track.track_id).zfill(2)}"
            # Forward-fill persistent GPT attributes missing this frame.
            for k, v in track.gpt_data.items():
                if k not in d_out:
                    d_out[k] = v
            # Rolling per-track history of the last 30 output boxes.
            if 'history' not in track.gpt_data:
                track.gpt_data['history'] = []
            track.gpt_data['history'].append(d_out['bbox'])
            if len(track.gpt_data['history']) > 30:
                track.gpt_data['history'].pop(0)
            d_out['history'] = track.gpt_data['history']
            results.append(d_out)
        return results

    def _sync_data(self, track, det_source):
        """Copy persistent GPT attributes from a detection's source dict onto the track."""
        # 1. Source -> track (update on every successful match).
        source_data = det_source.original_data if hasattr(det_source, 'original_data') else {}
        for k in ['gpt_distance_m', 'gpt_direction', 'gpt_description']:
            if k in source_data:
                track.gpt_data[k] = source_data[k]
        # 2. Track -> output forward-fill happens in output construction.
# --- Helper Functions ---
def linear_assignment(cost_matrix, thresh):
    """Min-cost assignment with a cost threshold, via scipy.

    Pairs whose assigned cost exceeds `thresh` are rejected and their row /
    column indices reported as unmatched, alongside rows/columns scipy left
    unassigned in a rectangular matrix.

    (Cleanup: the previous version filtered threshold failures in two
    duplicated loops with ad-hoc membership guards, and returned tuples in
    the empty case but lists otherwise.)

    Parameters
    ----------
    cost_matrix : ndarray, shape (n_a, n_b)
        Pairwise association costs.
    thresh : float
        Maximum cost for a pair to count as a match.

    Returns
    -------
    (matches, unmatched_a, unmatched_b)
        matches is an (m, 2) int array of (row, col) pairs; unmatched_a and
        unmatched_b are lists of unmatched row / column indices.
    """
    n_rows, n_cols = cost_matrix.shape
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), list(range(n_rows)), list(range(n_cols))
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    matched = [(r, c) for r, c in zip(row_ind, col_ind) if cost_matrix[r, c] <= thresh]
    matched_rows = {r for r, _ in matched}
    matched_cols = {c for _, c in matched}
    unmatched_a = [r for r in range(n_rows) if r not in matched_rows]
    unmatched_b = [c for c in range(n_cols) if c not in matched_cols]
    matches = np.array(matched) if matched else np.empty((0, 2), dtype=int)
    return matches, unmatched_a, unmatched_b
def iou_distance(atracks, btracks):
    """Compute the IoU cost matrix (1 - IoU) between tracks and detections.

    Both inputs are lists of objects exposing an `.tlbr` xyxy box. If either
    list is empty, an appropriately shaped zero matrix is returned.
    (Cleanup: the previous `(A and B) or A or B` condition was redundant.)
    """
    if not atracks or not btracks:
        return np.zeros((len(atracks), len(btracks)), dtype=float)
    atlbrs = [track.tlbr for track in atracks]
    btlbrs = [track.tlbr for track in btracks]
    _ious = bbox_ious(np.array(atlbrs), np.array(btlbrs))
    return 1 - _ious
def bbox_ious(boxes1, boxes2):
    """Pairwise IoU matrix between two sets of xyxy boxes.

    Returns an array of shape (len(boxes1), len(boxes2)); a small epsilon in
    the denominator guards against division by zero for degenerate boxes.
    """
    a = boxes1[:, None, :]  # (N, 1, 4)
    b = boxes2[None, :, :]  # (1, M, 4)
    # Intersection rectangle corners, broadcast over all pairs.
    top_left = np.maximum(a[..., :2], b[..., :2])
    bottom_right = np.minimum(a[..., 2:], b[..., 2:])
    extent = np.maximum(bottom_right - top_left, 0)
    inter_area = extent[..., 0] * extent[..., 1]
    area_a = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area_b = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    union = area_a[:, None] + area_b - inter_area + 1e-6
    return inter_area / union
def fuse_score(cost_matrix, detections):
    """Blend IoU similarity with detection confidence.

    Similarity becomes (1 - cost) * score, so low-confidence detections end
    up with a higher association cost. Returns the fused cost matrix.
    """
    if cost_matrix.size == 0:
        return cost_matrix
    scores_row = np.asarray([det.score for det in detections])[None, :]
    similarity = (1 - cost_matrix) * np.broadcast_to(scores_row, cost_matrix.shape)
    return 1 - similarity
# STrack collection helpers
def join_stracks(tlist_a, tlist_b):
    """Concatenate two track lists, skipping tracks from `tlist_b` whose
    track_id already appeared (in `tlist_a` or earlier in `tlist_b`).

    All of `tlist_a` is kept verbatim; order is preserved.
    """
    seen_ids = {track.track_id for track in tlist_a}
    merged = list(tlist_a)
    for track in tlist_b:
        if track.track_id not in seen_ids:
            seen_ids.add(track.track_id)
            merged.append(track)
    return merged
def sub_stracks(tlist_a, tlist_b):
    """Return tracks from `tlist_a` whose track_id does not occur in `tlist_b`.

    If `tlist_a` contains several tracks with the same id, the last one wins
    (same as the original dict-based behaviour).
    (Cleanup: replaced the truthiness check `stracks.get(tid, 0)` — which
    would misbehave for any falsy track object — with a proper `pop`.)
    """
    remaining = {t.track_id: t for t in tlist_a}
    for t in tlist_b:
        remaining.pop(t.track_id, None)
    return list(remaining.values())
def remove_duplicate_stracks(stracksa, stracksb):
    """Resolve near-duplicate tracks (IoU distance < 0.15) between two lists.

    For each overlapping pair, the track with the longer lifetime
    (frame_id - start_frame) is kept and the other is dropped from its list.

    BUGFIX: the previous version seeded the removal lists with *all* paired
    indices (`list(pairs[0])`, `list(pairs[1])`), so both members of every
    duplicate pair were removed instead of only the younger one. The lists
    now start empty, matching the upstream ByteTrack logic.
    """
    pdist = iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    dupa, dupb = [], []
    for a, b in zip(*pairs):
        time_a = stracksa[a].frame_id - stracksa[a].start_frame
        time_b = stracksb[b].frame_id - stracksb[b].start_frame
        if time_a > time_b:
            dupb.append(b)  # a has the longer history -> drop b
        else:
            dupa.append(a)
    res_a = [t for i, t in enumerate(stracksa) if i not in dupa]
    res_b = [t for i, t in enumerate(stracksb) if i not in dupb]
    return res_a, res_b
# Monkey patch for multi_predict since STrack is defined above in this module.
def multi_predict(stracks, kalman_filter):
    """Run the Kalman prediction step for every track in `stracks`.

    Tracks that are not in the Tracked state (state != 2) get their height
    velocity (mean[7]) zeroed before prediction, as in standard ByteTrack.
    """
    for track in stracks:
        if track.state != 2:
            track.mean[7] = 0  # freeze height velocity while not tracked
        track.mean, track.covariance = kalman_filter.predict(track.mean, track.covariance)


# Attach to STrack so callers can invoke STrack.multi_predict(pool, kf);
# the module-level alias is kept for backward compatibility.
STrack.multi_predict = static_method_multi_predict = multi_predict