"""Build an ego-trajectory vocabulary by K-means clustering of end points.

Pipeline:
1. ``load_and_process_traj`` extracts the first ``_NUM_STEPS`` future ego
   xy points from an annotation pickle and saves them with their map
   locations.
2. ``process_kmeans_vocab`` clusters the trajectory end points with K-means
   and, for every centroid, keeps the recorded trajectory whose end point is
   closest to it.
3. ``visualization`` plots the centroids over those representative
   trajectories.
"""
import os
import pickle
from glob import glob

import matplotlib.pyplot as plt
import mmcv
import numpy as np
from sklearn.cluster import KMeans
from tqdm import tqdm, trange

# Default directory for every artifact this script writes (and the
# __main__ block reads).
VOCAB_DIR = '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab'

# Number of leading future steps that must all be valid for a sample to be
# kept; kept trajectories are truncated to this many steps.
_NUM_STEPS = 8


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it is missing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def load_and_process_traj(ann_file, out_file=os.path.join(VOCAB_DIR, 'full_traj.pkl')):
    """Extract fully valid future ego trajectories from an annotation file.

    Args:
        ann_file: Pickle readable by ``mmcv.load``. Assumed to be an iterable
            of per-sample dicts with the keys ``gt_fut_bbox_sdc_mask``,
            ``gt_fut_bbox_sdc_lidar`` and ``map_location`` (TODO confirm the
            schema against the data converter that produced it).
        out_file: Destination pickle; defaults to ``full_traj.pkl`` in
            ``VOCAB_DIR``. Its parent directory is created if needed.

    Writes ``[ego_trajs, map_locs]`` to ``out_file``: ``ego_trajs[i]`` is
    ``gt_fut_bbox_sdc_lidar[0, :_NUM_STEPS, :2]`` of a kept sample and
    ``map_locs[i]`` is that sample's ``map_location``. A sample is kept only
    when its first ``_NUM_STEPS`` mask entries sum to ``_NUM_STEPS``.
    """
    data_infos = mmcv.load(ann_file, file_format='pkl')
    ego_trajs = []
    map_locs = []
    for data in tqdm(data_infos):
        mask = np.asarray(data['gt_fut_bbox_sdc_mask'][0, :_NUM_STEPS], dtype=np.float32)
        # Presumably a 0/1 validity mask: a full sum means every step is valid.
        if np.sum(mask) == _NUM_STEPS:
            ego_trajs.append(data['gt_fut_bbox_sdc_lidar'][0, :_NUM_STEPS, :2])
            map_locs.append(data['map_location'])
    _ensure_parent_dir(out_file)
    with open(out_file, 'wb') as writer:
        pickle.dump([ego_trajs, map_locs], writer)


def process_kmeans_vocab(traj_file, anchors=4096, loc='', out_dir=VOCAB_DIR, random_state=None):
    """Cluster trajectory end points and pick one representative per cluster.

    Args:
        traj_file: Pickle holding ``[traj_data, locs]`` as written by
            ``load_and_process_traj``.
        anchors: Number of K-means clusters (vocabulary size).
        loc: If non-empty, only trajectories whose location equals ``loc``
            are clustered.
        out_dir: Directory for the output pickles (created if missing).
        random_state: Forwarded to ``KMeans`` for reproducible clustering;
            ``None`` keeps sklearn's default (non-deterministic) seeding.

    Writes ``vocab_center_{tag}.pkl`` (the ``(anchors, 2)`` centroid array)
    and ``vocab_traj_{tag}.pkl`` (``[rep_traj, rep_loc]``, one entry per
    centroid) into ``out_dir``. ``tag`` is ``{anchors}`` when ``loc`` is
    empty, else ``{anchors}_{loc}``, so filtered runs do not overwrite the
    all-location vocabulary.

    Raises:
        ValueError: if there are no trajectories, or fewer (after filtering)
            than ``anchors``.
    """
    with open(traj_file, 'rb') as reader:
        traj_data, locs = pickle.load(reader)
    if len(traj_data) == 0:
        raise ValueError(f'No trajectories found in {traj_file}')
    print(traj_data[0].shape)

    if loc:
        # Filter trajectories and locations together so that an index into
        # end_p below also indexes the matching trajectory and location.
        kept = [(traj, l) for traj, l in zip(traj_data, locs) if l == loc]
        traj_data = [traj for traj, _ in kept]
        locs = [l for _, l in kept]
    # End point (last xy) of every trajectory considered.
    end_p = np.array([traj[-1, :2] for traj in traj_data])
    print(end_p.shape)
    if len(end_p) < anchors:
        raise ValueError(
            f'Need at least {anchors} trajectories to fit {anchors} clusters, '
            f'got {len(end_p)} (loc={loc!r})'
        )

    kmeans = KMeans(n_clusters=anchors, verbose=True, random_state=random_state)
    kmeans.fit(end_p)
    print('fit end')
    centroids = kmeans.cluster_centers_

    tag = f'{anchors}_{loc}' if loc else f'{anchors}'
    center_file = os.path.join(out_dir, f'vocab_center_{tag}.pkl')
    _ensure_parent_dir(center_file)
    with open(center_file, 'wb') as writer:
        pickle.dump(centroids, writer)

    print('processing the representative trajs...')
    rep_traj = []
    rep_loc = []
    # One centroid at a time: a full (anchors x N) distance matrix could be
    # too large to hold in memory.
    for i in trange(centroids.shape[0]):
        centroid = centroids[i]
        nearest = int(np.argmin(np.linalg.norm(end_p - centroid[np.newaxis, :2], axis=1)))
        rep_traj.append(traj_data[nearest])
        rep_loc.append(locs[nearest])
    traj_out = os.path.join(out_dir, f'vocab_traj_{tag}.pkl')
    with open(traj_out, 'wb') as writer:
        pickle.dump([rep_traj, rep_loc], writer)


def visualization(traj_file, kmeans_files, suffix='4096', out_dir=VOCAB_DIR):
    """Plot vocabulary centroids over their representative trajectories.

    Args:
        traj_file: Pickle holding ``[traj_data, loc_data]``; expected to be a
            ``vocab_traj_*`` file so that ``loc_data[i]`` belongs to
            centroid ``i``.
        kmeans_files: Pickle holding the ``(K, 2)`` centroid array
            (``vocab_center_*``).
        suffix: Inserted into the output name ``vocab_{suffix}.png``.
        out_dir: Directory for the figure (created if missing).

    Centroids are drawn as orange stars; those whose location is
    ``'singapore'`` are redrawn in red on top. Trajectories are drawn as navy
    lines underneath.

    Raises:
        ValueError: if the number of locations does not match the number of
            centroids.
    """
    with open(traj_file, 'rb') as reader:
        traj_data, loc_data = pickle.load(reader)
    with open(kmeans_files, 'rb') as reader:
        k_means_data = pickle.load(reader)
    sing_mask = np.array(loc_data) == 'singapore'
    if len(sing_mask) != len(k_means_data):
        raise ValueError(
            f'{traj_file} has {len(sing_mask)} locations but {kmeans_files} '
            f'has {len(k_means_data)} centroids'
        )

    fig = plt.figure()
    try:
        plt.scatter(k_means_data[:, 0], k_means_data[:, 1], c='orange', marker='*', s=5, zorder=3)
        plt.scatter(k_means_data[sing_mask, 0], k_means_data[sing_mask, 1],
                    c='red', marker='*', s=5, zorder=4)
        for traj in tqdm(traj_data):
            plt.plot(traj[:, 0], traj[:, 1], color='navy', alpha=0.6, zorder=1)
        out_file = os.path.join(out_dir, f'vocab_{suffix}.png')
        _ensure_parent_dir(out_file)
        fig.savefig(out_file)
    finally:
        # Release the figure; otherwise every call leaks an open figure.
        plt.close(fig)


if __name__ == '__main__':
    # Step 1: extract ego trajectories from the annotation file.
    # ann_file = '/nas/shared/opendrivelab/litianyu/paradrive/data/navsim_infos/nuplan_navsim_train.pkl'
    # load_and_process_traj(ann_file)
    # Step 2: cluster trajectory end points into vocabularies.
    # traj_file = os.path.join(VOCAB_DIR, 'full_traj.pkl')
    # process_kmeans_vocab(traj_file, anchors=4096)
    # process_kmeans_vocab(traj_file, anchors=8192)
    # Step 3: plot each vocabulary.
    for vocab_size in (8192, 4096):
        visualization(
            os.path.join(VOCAB_DIR, f'vocab_traj_{vocab_size}.pkl'),
            os.path.join(VOCAB_DIR, f'vocab_center_{vocab_size}.pkl'),
            str(vocab_size),
        )