File size: 3,528 Bytes
663494c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
import pickle
from glob import glob

import mmcv
import numpy as np
from sklearn.cluster import KMeans
from tqdm import tqdm, trange
def load_and_process_traj(
        ann_file,
        out_file='/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/full_traj.pkl',
        num_steps=8):
    """Extract ego future trajectories whose whole horizon is valid and pickle them.

    Loads the annotation list from ``ann_file`` with ``mmcv.load``. A sample is
    kept when the first ``num_steps`` entries of ``gt_fut_bbox_sdc_mask[0]``
    sum to ``num_steps`` (every step marked valid, assuming a 0/1 mask --
    TODO confirm). For each kept sample, the first two coordinates of
    ``gt_fut_bbox_sdc_lidar[0, :num_steps]`` are stored together with the
    sample's ``map_location``.

    Args:
        ann_file: Path of the annotation pickle (read via ``mmcv.load``).
        out_file: Destination pickle; receives ``[ego_trajs, map_locs]``.
            Defaults to the original hard-coded location.
        num_steps: Number of future steps that must all be valid; this is
            also the length of each stored trajectory.

    Returns:
        ``(ego_trajs, map_locs)``, the same two lists written to ``out_file``.
    """
    data_infos = mmcv.load(ann_file, file_format='pkl')
    ego_trajs = []
    map_locs = []
    for data in tqdm(data_infos):
        valid = np.asarray(data['gt_fut_bbox_sdc_mask'][0, :num_steps],
                           dtype=np.float32)
        # Skip samples with any invalid step. A mask shorter than num_steps
        # also sums to less than num_steps and is skipped too.
        if np.sum(valid) != num_steps:
            continue
        ego_trajs.append(data['gt_fut_bbox_sdc_lidar'][0, :num_steps, :2])
        map_locs.append(data['map_location'])
    # Write once, after all samples have been scanned.
    with open(out_file, 'wb') as writer:
        pickle.dump([ego_trajs, map_locs], writer)
    return ego_trajs, map_locs
def process_kmeans_vocab(
        traj_file,
        anchors=4096,
        loc='',
        out_dir='/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab',
        random_state=None):
    """Build a trajectory vocabulary by k-means clustering trajectory end points.

    Loads ``[trajectories, locations]`` from ``traj_file`` (the format written
    by ``load_and_process_traj``). It clusters the first two coordinates of each
    trajectory's last point into ``anchors`` clusters. For every cluster centre
    it then picks the stored trajectory whose end point is nearest.

    Args:
        traj_file: Pickle file holding ``[traj_list, loc_list]``.
        anchors: Number of k-means clusters, i.e. the vocabulary size.
        loc: If non-empty, only trajectories whose location equals ``loc``
            are used; ``''`` uses all trajectories.
        out_dir: Directory the two output pickles are written to.
        random_state: Passed to ``KMeans``; ``None`` keeps sklearn's default
            (non-deterministic) seeding.

    Writes (``<tag>`` is ``{anchors}``, or ``{anchors}_{loc}`` when ``loc``
    is given, so per-location runs do not overwrite the global vocabulary):
        ``vocab_center_<tag>.pkl``: ``kmeans.cluster_centers_`` of shape
        ``(anchors, 2)``.
        ``vocab_traj_<tag>.pkl``: ``[rep_traj, rep_loc]``, one representative
        trajectory and its location per centre.

    Raises:
        ValueError: If there are no trajectories, or none match ``loc``.
    """
    with open(traj_file, 'rb') as reader:
        traj_data, locs = pickle.load(reader)
    if loc:
        # Filter trajectories and locations together so that indices into
        # end_p (below) refer to the same entries of traj_data and locs.
        # Previously only end_p was filtered, which misaligned the lookup.
        kept = [(traj, location) for traj, location in zip(traj_data, locs)
                if location == loc]
        traj_data = [traj for traj, _ in kept]
        locs = [location for _, location in kept]
    if not traj_data:
        raise ValueError(
            f'no trajectories found in {traj_file!r} for loc={loc!r}')
    print(traj_data[0].shape)
    end_p = np.array([traj[-1, :2] for traj in traj_data])
    print(end_p.shape)

    kmeans = KMeans(n_clusters=anchors, verbose=True, random_state=random_state)
    kmeans.fit(end_p)
    print('fit end')
    centroids = kmeans.cluster_centers_

    tag = f'{anchors}_{loc}' if loc else f'{anchors}'
    with open(os.path.join(out_dir, f'vocab_center_{tag}.pkl'), 'wb') as writer:
        pickle.dump(centroids, writer)

    print('processing the representative trajs...')
    rep_traj = []
    rep_loc = []
    for i in trange(centroids.shape[0]):
        # Nearest stored end point to this centre (Euclidean). The loop runs
        # once per centre to avoid building a full (anchors x N) distance matrix.
        dists = np.linalg.norm(end_p - centroids[i][np.newaxis, :2], axis=1)
        nearest = int(np.argmin(dists))
        rep_traj.append(traj_data[nearest])
        rep_loc.append(locs[nearest])
    with open(os.path.join(out_dir, f'vocab_traj_{tag}.pkl'), 'wb') as writer:
        pickle.dump([rep_traj, rep_loc], writer)
import matplotlib.pyplot as plt
def visualization(traj_file, kmeans_files, suffix='4096',
                  out_dir='/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab'):
    """Plot vocabulary trajectories and k-means centres and save them as a PNG.

    Args:
        traj_file: Pickle holding ``[traj_list, loc_list]``, e.g. a
            ``vocab_traj_*.pkl`` produced by ``process_kmeans_vocab``.
        kmeans_files: Pickle holding the centre array, indexed as
            ``[:, 0]`` / ``[:, 1]`` for plotting.
        suffix: Inserted into the output filename ``vocab_{suffix}.png``.
        out_dir: Directory the PNG is written to.

    Centres are drawn as orange stars. Centres whose paired location equals
    ``'singapore'`` are overdrawn in red. Every trajectory is drawn as a
    navy line underneath.
    """
    with open(traj_file, 'rb') as reader:
        traj_data, loc_data = pickle.load(reader)
    with open(kmeans_files, 'rb') as reader:
        k_means_data = pickle.load(reader)
    # The mask is built from loc_data but indexes the centre array, so it
    # assumes one location per centre, as in process_kmeans_vocab's
    # vocab_traj output.
    # NOTE(review): exact match on 'singapore'; confirm this is the spelling
    # stored in the dataset's map_location field.
    sing_mask = np.array(loc_data) == 'singapore'
    fig = plt.figure()
    try:
        plt.scatter(k_means_data[:, 0], k_means_data[:, 1], c='orange',
                    marker='*', s=5, zorder=3)
        plt.scatter(k_means_data[sing_mask, 0], k_means_data[sing_mask, 1],
                    c='red', marker='*', s=5, zorder=4)
        for traj in tqdm(traj_data):
            plt.plot(traj[:, 0], traj[:, 1], color='navy',
                     alpha=0.6, zorder=1)
        plt.savefig(os.path.join(out_dir, f'vocab_{suffix}.png'))
    finally:
        # Release the figure. Otherwise each call leaves thousands of line
        # artists alive in pyplot's figure registry.
        plt.close(fig)
if __name__ == '__main__':
    # Earlier pipeline stages, already run to produce the vocab files below.
    # Uncomment to regenerate them:
    # ann_file = '/nas/shared/opendrivelab/litianyu/paradrive/data/navsim_infos/nuplan_navsim_train.pkl'
    # load_and_process_traj(ann_file)
    # traj_file = '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/full_traj.pkl'
    # process_kmeans_vocab(traj_file, anchors=4096)
    # process_kmeans_vocab(traj_file, anchors=8192)

    # Each entry is (representative-trajectory pickle, centre pickle, PNG
    # suffix). Plots are rendered in this order.
    vocab_plots = (
        (
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_traj_8192.pkl',
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_center_8192.pkl',
            '8192',
        ),
        (
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_traj_4096.pkl',
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_center_4096.pkl',
            '4096',
        ),
    )
    for traj_path, center_path, tag in vocab_plots:
        visualization(traj_path, center_path, tag)
|