File size: 2,506 Bytes
f08920f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# =========================================================
# File 4: simulation/tracking.py
# =========================================================
import numpy as np
from .visualization import find_peak_box_improved, convert_grid_to_xy

class TrackedObject:
    """State and observation history of a single track.

    ``update`` expects ``object_info`` laid out as ``[x, y, *rest]``: the first
    two entries become the position and the remaining slice is stored as the
    feature entry.
    """

    def __init__(self):
        self.last_step = 0        # step of the most recent update
        self.last_pos = [0, 0]    # position from the most recent update
        # Initialised here so the attribute exists before the first update.
        self.feature = None       # feature slice from the most recent update
        self.historical_pos = []
        self.historical_steps = []
        self.historical_features = []

    def update(self, step, object_info):
        """Record a new observation of this track.

        Args:
            step: Time step of the observation.
            object_info: Sequence whose first two entries are the position and
                whose remaining entries (``object_info[2:]``) are stored as the
                feature slice.
        """
        self.last_step = step
        self.last_pos = object_info[:2]
        self.feature = object_info[2:]
        self.historical_pos.append(self.last_pos)
        self.historical_steps.append(step)
        self.historical_features.append(self.feature)

class Tracker:
    """Greedy nearest-neighbour tracker that associates detections across steps.

    Track IDs are indices into ``self.tracks``. Tracks are only ever appended,
    never removed, so an ID stays valid for the tracker's lifetime. Only the
    tracks listed in ``self.alive_ids`` are candidates for matching.
    """

    def __init__(self, frequency=10, max_match_distance=2.0, max_age=6):
        """Create an empty tracker.

        Args:
            frequency: Stored as ``self.frequency``; not read by this class.
            max_match_distance: A detection may only be associated with a
                track whose last position is strictly closer than this. It is
                measured in the same frame as the detection positions.
            max_age: A track stays alive (matchable) while
                ``track.last_step > step - max_age``.
        """
        self.tracks = []
        self.alive_ids = []
        self.frequency = frequency
        self.max_match_distance = max_match_distance
        self.max_age = max_age

    def update_and_predict(self, det_data, pos, theta, step, merge_percent=0.4):
        """Extract detections from ``det_data``, transform them and track them.

        Each detection's grid cell is converted to local coordinates and
        shifted by ``raw_data[1]`` / ``raw_data[2]``. The result is rotated by
        ``theta`` and translated by ``pos`` into the tracker's global frame.

        Args:
            det_data: Detection data passed to ``find_peak_box_improved``.
            pos: Translation whose first two entries are added after rotation
                (presumably the ego position — confirm with caller).
            theta: Rotation angle in radians.
            step: Current time step, recorded on every updated track.
            merge_percent: Accepted for API compatibility; currently unused.

        Returns:
            ``det_data`` unchanged; no prediction is applied here.
        """
        detected_objects = find_peak_box_improved(det_data)
        # R rotates by -theta, so its transpose rotates local points by +theta.
        R = np.array([[np.cos(-theta), -np.sin(-theta)],
                      [np.sin(-theta), np.cos(-theta)]])
        to_global = R.T
        objects_info = []
        for obj in detected_objects:
            i, j = obj['coords']
            obj_data = obj['raw_data']
            # The helper's result is unpacked as (y, x).
            center_y, center_x = convert_grid_to_xy(i, j)
            center_x += obj_data[1]
            center_y += obj_data[2]
            loc = to_global.dot(np.array([center_x, center_y]))
            objects_info.append([loc[0] + pos[0], loc[1] + pos[1], obj_data[1:]])
        self._update(objects_info, step)
        return det_data

    def _update(self, objects_info, step):
        """Assign detections to tracks, creating new tracks for unmatched ones.

        Afterwards ``alive_ids`` is recomputed for ``step``.
        """
        # With no tracks yet there is nothing to match against.
        if self.tracks:
            to_ids = self._match(objects_info)
        else:
            to_ids = [-1] * len(objects_info)
        for info, to_id in zip(objects_info, to_ids):
            if to_id == -1:
                self.tracks.append(TrackedObject())
                to_id = len(self.tracks) - 1
            self.tracks[to_id].update(step, info)
        self.alive_ids = [idx for idx, track in enumerate(self.tracks)
                          if track.last_step > step - self.max_age]

    def _match(self, objects_info):
        """Greedily match detections to alive tracks by last known position.

        Detections are processed in order. Each one picks the closest alive
        track that is not yet claimed and lies strictly within
        ``max_match_distance``.

        Returns:
            A list parallel to ``objects_info``. Each entry is the matched
            track index, or -1 when there is no match.
        """
        results = []
        claimed = set()
        for info in objects_info:
            pos_x, pos_y = info[:2]
            best_id, best_dist = -1, float('inf')
            for track_id in self.alive_ids:
                if track_id in claimed:
                    continue
                track_pos = self.tracks[track_id].last_pos
                dist = np.hypot(track_pos[0] - pos_x, track_pos[1] - pos_y)
                if dist < self.max_match_distance and dist < best_dist:
                    best_id, best_dist = track_id, dist
            results.append(best_id)
            if best_id != -1:
                claimed.add(best_id)
        return results