|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
from .visualization import find_peak_box_improved, convert_grid_to_xy |
|
|
|
|
|
class TrackedObject:
    """A single track together with the history of its observations.

    Each call to :meth:`update` records one observation. ``object_info[:2]``
    is stored as the position and ``object_info[2:]`` as the feature.
    """

    def __init__(self):
        self.last_step = 0  # step of the most recent update
        self.last_pos = [0, 0]  # position from the most recent update
        # Feature from the most recent update; None until update() is called,
        # so the attribute always exists.
        self.feature = None
        self.historical_pos = []  # one position per update, oldest first
        self.historical_steps = []  # one step per update, oldest first
        self.historical_features = []  # one feature per update, oldest first

    def update(self, step, object_info):
        """Record an observation made at ``step``.

        Args:
            step: Time step of the observation.
            object_info: Sequence whose first two entries are the position;
                the remaining entries (as a slice) are stored as the feature.
        """
        self.last_step = step
        self.last_pos = object_info[:2]
        self.feature = object_info[2:]
        self.historical_pos.append(self.last_pos)
        self.historical_steps.append(step)
        self.historical_features.append(self.feature)
|
|
|
|
|
class Tracker:
    """Greedy nearest-neighbour tracker.

    Detections are rotated and offset into a common frame, then associated
    with recently seen tracks by Euclidean distance. Unmatched detections
    start new tracks.

    Args:
        frequency: Stored as ``self.frequency``; not used by this class.
        match_threshold: A detection can only be associated with a track whose
            last position is strictly closer than this distance.
        max_age: A track remains eligible for matching while
            ``step - track.last_step < max_age``.
    """

    def __init__(self, frequency=10, match_threshold=2.0, max_age=6):
        self.tracks = []  # every track ever created; list index == track id
        self.alive_ids = []  # ids of tracks eligible for the next _match call
        self.frequency = frequency
        self.match_threshold = match_threshold
        self.max_age = max_age

    def update_and_predict(self, det_data, pos, theta, step, merge_percent=0.4):
        """Extract detections from ``det_data`` and feed them to the tracker.

        Args:
            det_data: Detection data passed to ``find_peak_box_improved``.
            pos: ``(x, y)`` offset added to every rotated detection center
                (presumably the sensor/ego position — confirm with callers).
            theta: Angle in radians by which detection centers are rotated.
            step: Current time step.
            merge_percent: Unused; kept for interface compatibility.

        Returns:
            ``det_data``, unchanged.
        """
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        # Counter-clockwise rotation by theta, applied to each detection center.
        rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])

        objects_info = []
        for obj in find_peak_box_improved(det_data):
            i, j = obj['coords']
            obj_data = obj['raw_data']
            center_y, center_x = convert_grid_to_xy(i, j)
            # raw_data[1] / raw_data[2] are added to the cell center as x / y
            # offsets (presumably sub-cell regression — TODO confirm).
            center_x += obj_data[1]
            center_y += obj_data[2]
            loc = rotation.dot(np.array([center_x, center_y]))
            objects_info.append([loc[0] + pos[0], loc[1] + pos[1], obj_data[1:]])

        self._update(objects_info, step)
        return det_data

    def _update(self, objects_info, step):
        """Associate detections with tracks, spawn new tracks, refresh alive ids.

        Args:
            objects_info: List of ``[x, y, feature...]`` entries.
            step: Current time step.
        """
        # When no tracks exist, alive_ids is empty and _match returns -1 for
        # every detection, so the first call needs no special case.
        for info, track_id in zip(objects_info, self._match(objects_info)):
            if track_id == -1:
                self.tracks.append(TrackedObject())
                track_id = len(self.tracks) - 1
            self.tracks[track_id].update(step, info)

        self.alive_ids = [
            idx
            for idx, track in enumerate(self.tracks)
            if track.last_step > step - self.max_age
        ]

    def _match(self, objects_info):
        """Greedily assign each detection to the nearest unclaimed alive track.

        Detections are processed in list order. Each one claims the closest
        track in ``alive_ids`` that is strictly within ``match_threshold`` and
        not already claimed by an earlier detection. The result therefore
        depends on detection order.

        Args:
            objects_info: List of ``[x, y, ...]`` entries.

        Returns:
            One entry per detection: the matched track id, or -1 if none.
        """
        results = []
        claimed = set()
        for info in objects_info:
            x, y = info[0], info[1]
            best_id, best_dist = -1, float('inf')
            for track_id in self.alive_ids:
                if track_id in claimed:
                    continue
                track_pos = self.tracks[track_id].last_pos
                dist = np.hypot(track_pos[0] - x, track_pos[1] - y)
                if dist < self.match_threshold and dist < best_dist:
                    best_id, best_dist = track_id, dist
            results.append(best_id)
            if best_id != -1:
                claimed.add(best_id)
        return results
|
|
|