# Author: mohammed-aljafry
# Commit f08920f (verified): "Final fix v9: Re-upload with fully self-contained and correct package structure"
# Size: 2.51 kB
# =========================================================
# File 4: simulation/tracking.py
# =========================================================
import numpy as np
from .visualization import find_peak_box_improved, convert_grid_to_xy
class TrackedObject:
    """History of a single tracked object: its latest state plus per-update records."""

    def __init__(self):
        self.last_step = 0  # step passed to the most recent update (0 before any update)
        self.last_pos = [0, 0]  # position from the most recent update
        # Trailing fields of the most recent object_info. Initialised here so the
        # attribute exists (as None) before the first update instead of raising
        # AttributeError.
        self.feature = None
        self.historical_pos = []  # every last_pos, in update order
        self.historical_steps = []  # every step, in update order
        self.historical_features = []  # every feature, in update order

    def update(self, step, object_info):
        """Record a new observation of this object.

        Args:
            step: Time step of the observation.
            object_info: Sequence whose first two items are the position; the
                slice of the remaining items (``object_info[2:]``) is stored as
                the feature.
        """
        self.last_step = step
        self.last_pos = object_info[:2]
        self.feature = object_info[2:]
        self.historical_pos.append(self.last_pos)
        self.historical_steps.append(step)
        self.historical_features.append(self.feature)
class Tracker:
    """Greedy nearest-neighbour tracker over detections found in a detection grid.

    Each call to ``update_and_predict`` turns the peaks found in ``det_data``
    into positions (rotated by ``theta`` and offset by ``pos``) and associates
    them with existing tracks by Euclidean distance.
    """

    # Default association gate: a detection can only join a track whose last
    # position is strictly closer than this (same units as the positions).
    match_distance = 2.0
    # Default liveness window: a track remains matchable while
    # ``track.last_step > step - max_age``.
    max_age = 6

    def __init__(self, frequency=10, match_distance=2.0, max_age=6):
        """Create an empty tracker.

        Args:
            frequency: Stored on the instance; not read by this class.
            match_distance: Strict distance gate used when matching detections.
            max_age: Liveness window, in steps, used to compute ``alive_ids``.
        """
        self.tracks = []  # every TrackedObject ever created, in creation order
        self.alive_ids = []  # indices into self.tracks eligible for matching
        self.frequency = frequency
        self.match_distance = match_distance
        self.max_age = max_age

    def update_and_predict(self, det_data, pos, theta, step, merge_percent=0.4):
        """Detect objects in ``det_data``, transform them and update the tracks.

        Args:
            det_data: Detection data handed to ``find_peak_box_improved``. Each
                detection's ``'raw_data'`` is indexed at [1] and [2] as x/y
                offsets. This layout is assumed; confirm against the detector.
            pos: 2-D offset added to every rotated detection position
                (presumably the ego position; confirm with the caller).
            theta: Angle in radians. Detection positions are rotated by it.
            step: Current time step, recorded on matched and newly created tracks.
            merge_percent: Accepted for interface compatibility; unused here.

        Returns:
            The ``det_data`` object that was passed in. No prediction is
            computed in this method.
        """
        detected_objects = find_peak_box_improved(det_data)
        cos_t, sin_t = np.cos(-theta), np.sin(-theta)
        # R rotates by -theta, so R.T (used below) rotates by +theta.
        R = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
        objects_info = []
        for obj in detected_objects:
            i, j = obj['coords']
            obj_data = obj['raw_data']
            # NOTE(review): the helper's result is unpacked as (y, x); kept as in
            # the original. Confirm against convert_grid_to_xy.
            center_y, center_x = convert_grid_to_xy(i, j)
            center_x += obj_data[1]
            center_y += obj_data[2]
            loc = R.T.dot(np.array([center_x, center_y]))
            # The third element becomes TrackedObject.feature (as object_info[2:]).
            objects_info.append([loc[0] + pos[0], loc[1] + pos[1], obj_data[1:]])
        self._update(objects_info, step)
        return det_data

    def _update(self, objects_info, step):
        """Assign detections to tracks at ``step`` and refresh ``alive_ids``.

        Detections that match no alive track start a new TrackedObject.
        """
        # With no tracks there is nothing to match against: every detection
        # starts a new track.
        if self.tracks:
            to_ids = self._match(objects_info)
        else:
            to_ids = [-1] * len(objects_info)
        for info, to_id in zip(objects_info, to_ids):
            if to_id == -1:
                self.tracks.append(TrackedObject())
                to_id = len(self.tracks) - 1
            self.tracks[to_id].update(step, info)
        cutoff = step - self.max_age
        self.alive_ids = [
            idx for idx, track in enumerate(self.tracks) if track.last_step > cutoff
        ]

    def _match(self, objects_info):
        """Greedily pair each detection with the nearest unclaimed alive track.

        Detections are processed in order, and each track is claimed at most
        once per call.

        Returns:
            A list the same length as ``objects_info``. Each entry is a track
            index, or -1 when no alive track is strictly within
            ``match_distance``.
        """
        results = []
        matched_ids = set()
        for info in objects_info:
            pos_x, pos_y = info[:2]
            best_id, best_dist = -1, float('inf')
            for track_id in self.alive_ids:
                if track_id in matched_ids:
                    continue
                track_pos = self.tracks[track_id].last_pos
                dist = np.hypot(track_pos[0] - pos_x, track_pos[1] - pos_y)
                if dist < self.match_distance and dist < best_dist:
                    best_id, best_dist = track_id, dist
            results.append(best_id)
            if best_id != -1:
                matched_ids.add(best_id)
        return results