|
|
from glob import glob |
|
|
from tqdm import tqdm, trange |
|
|
import numpy as np |
|
|
import mmcv |
|
|
import pickle |
|
|
from sklearn.cluster import KMeans |
|
|
|
|
|
def load_and_process_traj(ann_file, out_file=None, num_steps=8):
    """Collect fully-valid future ego trajectories from an annotation file.

    Loads ``ann_file`` via ``mmcv.load`` (pickle format). For each sample it sums
    the first ``num_steps`` entries of ``data["gt_fut_bbox_sdc_mask"][0]``. When
    the sum equals ``num_steps``, it keeps two things:
    - the first ``num_steps`` rows and first two columns of
      ``data["gt_fut_bbox_sdc_lidar"][0]``;
    - ``data['map_location']``.
    (Presumably the mask is a 0/1 validity flag per future step, so this keeps
    only samples valid over the whole horizon -- confirm against the data
    pipeline.)

    The collected lists are pickled to ``out_file`` as ``[ego_trajs, map_locs]``.

    Args:
        ann_file: Path to the pickled annotation list.
        out_file: Destination pickle. Defaults to the original hard-coded
            ``.../traj_vocab/full_traj.pkl`` path.
        num_steps: Number of future steps that must be valid and are kept
            (default 8, the original horizon).

    Returns:
        tuple: ``(ego_trajs, map_locs)``, i.e. the same lists that are written to
        ``out_file``.
    """
    if out_file is None:
        out_file = '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/full_traj.pkl'

    data_infos = mmcv.load(ann_file, file_format='pkl')

    ego_trajs = []
    map_locs = []
    for data in tqdm(data_infos):
        mask = np.array(data["gt_fut_bbox_sdc_mask"][0, :num_steps], dtype=np.float32)
        # Summing (rather than np.all) also rejects samples whose mask is
        # shorter than num_steps, which matches the original behaviour.
        if np.sum(mask) == num_steps:
            ego_trajs.append(data["gt_fut_bbox_sdc_lidar"][0, :num_steps, :2])
            map_locs.append(data['map_location'])

    with open(out_file, 'wb') as writer:
        pickle.dump([ego_trajs, map_locs], writer)

    return ego_trajs, map_locs
|
|
|
|
|
def process_kmeans_vocab(traj_file, anchors=4096, loc='', out_dir=None, random_state=None):
    """Build a trajectory vocabulary by K-means clustering of trajectory end points.

    Reads ``[traj_data, locs]`` from ``traj_file``. It optionally keeps only the
    entries whose location equals ``loc``. It then clusters the final (x, y)
    point of every trajectory into ``anchors`` clusters. For each centroid, it
    picks the trajectory whose end point is closest as that centroid's
    representative.

    Writes, into ``out_dir``:
      * ``vocab_center_{anchors}.pkl``: the ``(anchors, 2)`` centroid array.
      * ``vocab_traj_{anchors}.pkl``: ``[rep_traj, rep_loc]``, where entry ``i``
        belongs to centroid ``i``.

    Args:
        traj_file: Pickle containing ``[traj_data, locs]`` (parallel lists).
        anchors: Number of K-means clusters.
        loc: If non-empty, only trajectories whose location equals ``loc`` are used.
        out_dir: Output directory. Defaults to the original hard-coded
            ``.../traj_vocab`` directory.
        random_state: Forwarded to ``KMeans`` for reproducible clustering
            (default None, the same as sklearn's default).

    Raises:
        ValueError: If no trajectory remains after location filtering.
    """
    import os

    if out_dir is None:
        out_dir = '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab'

    # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
    with open(traj_file, 'rb') as reader:
        [traj_data, locs] = pickle.load(reader)

    print(traj_data[0].shape)

    # Filter trajectories and locations *together*, so that an index into end_p
    # refers to the same entry in traj_data / locs. Filtering only end_p (as
    # before) made the nearest-point lookup below pick unrelated trajectories.
    if loc != '':
        keep = [i for i, l in enumerate(locs) if l == loc]
        traj_data = [traj_data[i] for i in keep]
        locs = [locs[i] for i in keep]

    if len(traj_data) == 0:
        raise ValueError(f'no trajectories found for location {loc!r} in {traj_file}')

    end_p = np.array([traj[-1, :2] for traj in traj_data])
    print(end_p.shape)

    kmeans = KMeans(n_clusters=anchors, verbose=True, random_state=random_state)
    kmeans.fit(end_p)
    print('fit end')
    centroids = kmeans.cluster_centers_

    with open(os.path.join(out_dir, f'vocab_center_{anchors}.pkl'), 'wb') as writer:
        pickle.dump(centroids, writer)

    print('processing the representative trajs...')
    rep_traj = []
    rep_loc = []
    for i in trange(centroids.shape[0]):
        centroid = centroids[i]
        # Index of the trajectory whose end point is nearest to this centroid.
        dist_arg = np.argmin(np.linalg.norm(end_p - centroid[np.newaxis, :2], axis=1))
        rep_traj.append(traj_data[dist_arg])
        rep_loc.append(locs[dist_arg])

    with open(os.path.join(out_dir, f'vocab_traj_{anchors}.pkl'), 'wb') as writer:
        pickle.dump([rep_traj, rep_loc], writer)
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
def visualization(traj_file, kmeans_files, suffix='4096', out_dir=None, highlight_loc='singapore'):
    """Plot a trajectory vocabulary together with its K-means centroids.

    Draws every trajectory in ``traj_file`` as a navy line. It scatters all
    centroids from ``kmeans_files`` as orange stars, and overlays in red the
    centroids whose paired location equals ``highlight_loc``. The figure is
    saved to ``{out_dir}/vocab_{suffix}.png`` and then closed.

    Args:
        traj_file: Pickle containing ``[traj_data, loc_data]``.
        kmeans_files: Pickle containing the centroid array (at least 2 columns).
        suffix: Tag used in the output file name.
        out_dir: Output directory. Defaults to the original hard-coded
            ``.../traj_vocab`` directory.
        highlight_loc: Location whose centroids are drawn in red (default
            ``'singapore'``, as before).
    """
    import os

    if out_dir is None:
        out_dir = '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab'

    with open(traj_file, 'rb') as reader:
        [traj_data, loc_data] = pickle.load(reader)

    with open(kmeans_files, 'rb') as reader:
        k_means_data = pickle.load(reader)

    # Assumes loc_data is aligned row-for-row with k_means_data, as is the case
    # for a vocab_traj_*.pkl / vocab_center_*.pkl pair of the same size.
    highlight_mask = np.array(loc_data) == highlight_loc

    fig = plt.figure()
    try:
        plt.scatter(k_means_data[:, 0], k_means_data[:, 1], c='orange',
                    marker='*', s=5, zorder=3)
        plt.scatter(k_means_data[highlight_mask, 0], k_means_data[highlight_mask, 1], c='red',
                    marker='*', s=5, zorder=4)

        for traj in tqdm(traj_data):
            plt.plot(traj[:, 0], traj[:, 1], color='navy',
                     alpha=0.6, zorder=1)

        plt.savefig(os.path.join(out_dir, f'vocab_{suffix}.png'))
    finally:
        # Release the figure; otherwise every call leaks an open figure.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Each entry: (representative-trajectory pickle, centroid pickle, output tag).
    # Rendered in order: the 8192-anchor vocabulary first, then the 4096 one.
    vocab_runs = (
        (
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_traj_8192.pkl',
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_center_8192.pkl',
            '8192',
        ),
        (
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_traj_4096.pkl',
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_center_4096.pkl',
            '4096',
        ),
    )
    for traj_path, center_path, tag in vocab_runs:
        visualization(traj_path, center_path, tag)
|
|
|
|
|
|
|
|
|