SUM010's picture
Upload 2120 files
7b7527a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import numpy as np
from collections import defaultdict
from ..matching import jde_matching as matching
from ..motion import KalmanFilter
from .base_jde_tracker import TrackState, STrack
from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
__all__ = ['JDETracker']
class JDETracker(object):
__shared__ = ['num_classes']
"""
JDE tracker, support single class and multi classes
Args:
use_byte (bool): Whether use ByteTracker, default False
num_classes (int): the number of classes
det_thresh (float): threshold of detection score
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results. If set <= 0 means no need to filter bboxes,usually set
1.6 for pedestrian tracking.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
r_tracked_thresh (float): linear assignment threshold of
tracked stracks and unmatched detections
unconfirmed_thresh (float): linear assignment threshold of
unconfirmed stracks and unmatched detections
conf_thres (float): confidence threshold for tracking, also used in
ByteTracker as higher confidence threshold
match_thres (float): linear assignment threshold of tracked
stracks and detections in ByteTracker
low_conf_thres (float): lower confidence threshold for tracking in
ByteTracker
input_size (list): input feature map size to reid model, [h, w] format,
[64, 192] as default.
motion (str): motion model, KalmanFilter as default
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
"""
def __init__(self,
use_byte=False,
num_classes=1,
det_thresh=0.3,
track_buffer=30,
min_box_area=0,
vertical_ratio=0,
tracked_thresh=0.7,
r_tracked_thresh=0.5,
unconfirmed_thresh=0.7,
conf_thres=0,
match_thres=0.8,
low_conf_thres=0.2,
input_size=[64, 192],
motion='KalmanFilter',
metric_type='euclidean'):
self.use_byte = use_byte
self.num_classes = num_classes
self.det_thresh = det_thresh if not use_byte else conf_thres + 0.1
self.track_buffer = track_buffer
self.min_box_area = min_box_area
self.vertical_ratio = vertical_ratio
self.tracked_thresh = tracked_thresh
self.r_tracked_thresh = r_tracked_thresh
self.unconfirmed_thresh = unconfirmed_thresh
self.conf_thres = conf_thres
self.match_thres = match_thres
self.low_conf_thres = low_conf_thres
self.input_size = input_size
if motion == 'KalmanFilter':
self.motion = KalmanFilter()
self.metric_type = metric_type
self.frame_id = 0
self.tracked_tracks_dict = defaultdict(list) # dict(list[STrack])
self.lost_tracks_dict = defaultdict(list) # dict(list[STrack])
self.removed_tracks_dict = defaultdict(list) # dict(list[STrack])
self.max_time_lost = 0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def update(self, pred_dets, pred_embs=None):
"""
Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles
lost, removed, refound and active tracklets.
Args:
pred_dets (np.array): Detection results of the image, the shape is
[N, 6], means 'cls_id, score, x0, y0, x1, y1'.
pred_embs (np.array): Embedding results of the image, the shape is
[N, 128] or [N, 512].
Return:
output_stracks_dict (dict(list)): The list contains information
regarding the online_tracklets for the received image tensor.
"""
self.frame_id += 1
if self.frame_id == 1:
STrack.init_count(self.num_classes)
activated_tracks_dict = defaultdict(list)
refined_tracks_dict = defaultdict(list)
lost_tracks_dict = defaultdict(list)
removed_tracks_dict = defaultdict(list)
output_tracks_dict = defaultdict(list)
pred_dets_dict = defaultdict(list)
pred_embs_dict = defaultdict(list)
# unify single and multi classes detection and embedding results
for cls_id in range(self.num_classes):
cls_idx = (pred_dets[:, 0:1] == cls_id).squeeze(-1)
pred_dets_dict[cls_id] = pred_dets[cls_idx]
if pred_embs is not None:
pred_embs_dict[cls_id] = pred_embs[cls_idx]
else:
pred_embs_dict[cls_id] = None
for cls_id in range(self.num_classes):
""" Step 1: Get detections by class"""
pred_dets_cls = pred_dets_dict[cls_id]
pred_embs_cls = pred_embs_dict[cls_id]
remain_inds = (pred_dets_cls[:, 1:2] > self.conf_thres).squeeze(-1)
if remain_inds.sum() > 0:
pred_dets_cls = pred_dets_cls[remain_inds]
if pred_embs_cls is None:
# in original ByteTrack
detections = [
STrack(
STrack.tlbr_to_tlwh(tlbrs[2:6]),
tlbrs[1],
cls_id,
30,
temp_feat=None) for tlbrs in pred_dets_cls
]
else:
pred_embs_cls = pred_embs_cls[remain_inds]
detections = [
STrack(
STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1], cls_id,
30, temp_feat) for (tlbrs, temp_feat) in
zip(pred_dets_cls, pred_embs_cls)
]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed_dict = defaultdict(list)
tracked_tracks_dict = defaultdict(list)
for track in self.tracked_tracks_dict[cls_id]:
if not track.is_activated:
# previous tracks which are not active in the current frame are added in unconfirmed list
unconfirmed_dict[cls_id].append(track)
else:
# Active tracks are added to the local list 'tracked_stracks'
tracked_tracks_dict[cls_id].append(track)
""" Step 2: First association, with embedding"""
# building tracking pool for the current frame
track_pool_dict = defaultdict(list)
track_pool_dict[cls_id] = joint_stracks(
tracked_tracks_dict[cls_id], self.lost_tracks_dict[cls_id])
# Predict the current location with KalmanFilter
STrack.multi_predict(track_pool_dict[cls_id], self.motion)
if pred_embs_cls is None:
# in original ByteTrack
dists = matching.iou_distance(track_pool_dict[cls_id],
detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.match_thres) # not self.tracked_thresh
else:
dists = matching.embedding_distance(
track_pool_dict[cls_id],
detections,
metric=self.metric_type)
dists = matching.fuse_motion(
self.motion, dists, track_pool_dict[cls_id], detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.tracked_thresh)
for i_tracked, idet in matches:
# i_tracked is the id of the track and idet is the detection
track = track_pool_dict[cls_id][i_tracked]
det = detections[idet]
if track.state == TrackState.Tracked:
# If the track is active, add the detection to the track
track.update(detections[idet], self.frame_id)
activated_tracks_dict[cls_id].append(track)
else:
# We have obtained a detection from a track which is not active,
# hence put the track in refind_stracks list
track.re_activate(det, self.frame_id, new_id=False)
refined_tracks_dict[cls_id].append(track)
# None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU"""
if self.use_byte:
inds_low = pred_dets_dict[cls_id][:, 1:2] > self.low_conf_thres
inds_high = pred_dets_dict[cls_id][:, 1:2] < self.conf_thres
inds_second = np.logical_and(inds_low, inds_high).squeeze(-1)
pred_dets_cls_second = pred_dets_dict[cls_id][inds_second]
# association the untrack to the low score detections
if len(pred_dets_cls_second) > 0:
if pred_embs_dict[cls_id] is None:
# in original ByteTrack
detections_second = [
STrack(
STrack.tlbr_to_tlwh(tlbrs[2:6]),
tlbrs[1],
cls_id,
30,
temp_feat=None)
for tlbrs in pred_dets_cls_second
]
else:
pred_embs_cls_second = pred_embs_dict[cls_id][
inds_second]
detections_second = [
STrack(
STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1],
cls_id, 30, temp_feat) for (tlbrs, temp_feat) in
zip(pred_dets_cls_second, pred_embs_cls_second)
]
else:
detections_second = []
r_tracked_stracks = [
track_pool_dict[cls_id][i] for i in u_track
if track_pool_dict[cls_id][i].state == TrackState.Tracked
]
dists = matching.iou_distance(r_tracked_stracks,
detections_second)
matches, u_track, u_detection_second = matching.linear_assignment(
dists, thresh=0.4) # not r_tracked_thresh
else:
detections = [detections[i] for i in u_detection]
r_tracked_stracks = []
for i in u_track:
if track_pool_dict[cls_id][i].state == TrackState.Tracked:
r_tracked_stracks.append(track_pool_dict[cls_id][i])
dists = matching.iou_distance(r_tracked_stracks, detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.r_tracked_thresh)
for i_tracked, idet in matches:
track = r_tracked_stracks[i_tracked]
det = detections[
idet] if not self.use_byte else detections_second[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_tracks_dict[cls_id].append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refined_tracks_dict[cls_id].append(track)
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_tracks_dict[cls_id].append(track)
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
dists = matching.iou_distance(unconfirmed_dict[cls_id], detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(
dists, thresh=self.unconfirmed_thresh)
for i_tracked, idet in matches:
unconfirmed_dict[cls_id][i_tracked].update(detections[idet],
self.frame_id)
activated_tracks_dict[cls_id].append(unconfirmed_dict[cls_id][
i_tracked])
for it in u_unconfirmed:
track = unconfirmed_dict[cls_id][it]
track.mark_removed()
removed_tracks_dict[cls_id].append(track)
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.det_thresh:
continue
track.activate(self.motion, self.frame_id)
activated_tracks_dict[cls_id].append(track)
""" Step 5: Update state"""
for track in self.lost_tracks_dict[cls_id]:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_tracks_dict[cls_id].append(track)
self.tracked_tracks_dict[cls_id] = [
t for t in self.tracked_tracks_dict[cls_id]
if t.state == TrackState.Tracked
]
self.tracked_tracks_dict[cls_id] = joint_stracks(
self.tracked_tracks_dict[cls_id], activated_tracks_dict[cls_id])
self.tracked_tracks_dict[cls_id] = joint_stracks(
self.tracked_tracks_dict[cls_id], refined_tracks_dict[cls_id])
self.lost_tracks_dict[cls_id] = sub_stracks(
self.lost_tracks_dict[cls_id], self.tracked_tracks_dict[cls_id])
self.lost_tracks_dict[cls_id].extend(lost_tracks_dict[cls_id])
self.lost_tracks_dict[cls_id] = sub_stracks(
self.lost_tracks_dict[cls_id], self.removed_tracks_dict[cls_id])
self.removed_tracks_dict[cls_id].extend(removed_tracks_dict[cls_id])
self.tracked_tracks_dict[cls_id], self.lost_tracks_dict[
cls_id] = remove_duplicate_stracks(
self.tracked_tracks_dict[cls_id],
self.lost_tracks_dict[cls_id])
# get scores of lost tracks
output_tracks_dict[cls_id] = [
track for track in self.tracked_tracks_dict[cls_id]
if track.is_activated
]
return output_tracks_dict