File size: 3,528 Bytes
663494c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from glob import glob 
from tqdm import tqdm, trange
import numpy as np 
import mmcv 
import pickle
from sklearn.cluster import KMeans

def load_and_process_traj(
    ann_file,
    out_file='/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/full_traj.pkl',
    horizon=8,
):
    """Collect fully-valid ego future trajectories and dump them to a pickle.

    Every record of the annotation file is inspected; a record is kept only
    when the first ``horizon`` entries of row 0 of its
    ``gt_fut_bbox_sdc_mask`` sum to ``horizon``. For kept records the first
    ``horizon`` steps and first two columns of row 0 of
    ``gt_fut_bbox_sdc_lidar`` are stored together with ``map_location``.

    Args:
        ann_file: Path to the ``.pkl`` annotation file, read with
            ``mmcv.load``. It is expected to yield an iterable of dict-like
            records carrying the three keys named above.
        out_file: Destination pickle. Defaults to the path that was
            previously hard-coded, so existing callers are unaffected.
        horizon: Number of leading future steps that must be valid and that
            are kept per trajectory. Defaults to the original value ``8``.

    The output pickle holds ``[ego_trajs, map_locs]``, two parallel lists.
    """
    data_infos = mmcv.load(ann_file, file_format='pkl')
    ego_trajs = []
    map_locs = []

    for data in tqdm(data_infos):
        # NOTE(review): presumably a 0/1 validity mask, so "sum == horizon"
        # means every one of the leading steps is valid -- confirm upstream.
        valid_steps = np.sum(
            np.array(data["gt_fut_bbox_sdc_mask"][0, :horizon], dtype=np.float32)
        )
        if valid_steps != horizon:
            continue
        # Only the first two columns are kept (assumed x/y -- TODO confirm).
        ego_trajs.append(data["gt_fut_bbox_sdc_lidar"][0, :horizon, :2])
        map_locs.append(data['map_location'])

    with open(out_file, 'wb') as writer:
        pickle.dump([ego_trajs, map_locs], writer)

def process_kmeans_vocab(
    traj_file,
    anchors=4096,
    loc='',
    out_dir='/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab',
):
    """Cluster trajectory end points and save a trajectory vocabulary.

    K-means with ``anchors`` clusters is fitted on the last (x, y) point of
    every trajectory. For each cluster centre, the sample whose end point is
    closest to the centre is picked as its representative trajectory.

    Args:
        traj_file: Pickle containing ``[trajectories, locations]`` as two
            parallel sequences (the layout written by
            ``load_and_process_traj``).
        anchors: Number of k-means clusters.
        loc: If non-empty, only samples whose location equals ``loc`` are
            used. The empty string keeps every sample.
        out_dir: Directory that receives ``vocab_center_{anchors}.pkl`` (the
            centroid array) and ``vocab_traj_{anchors}.pkl``
            (``[rep_traj, rep_loc]``). Defaults to the previously hard-coded
            directory.

    Raises:
        ValueError: If no sample matches ``loc``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only feed this
    # function files produced by a trusted pipeline.
    with open(traj_file, 'rb') as reader:
        traj_data, locs = pickle.load(reader)
    print(traj_data[0].shape)

    if loc != '':
        # Filter trajectories and locations *together*. The nearest-sample
        # indices computed below index into end_p, so traj_data and locs must
        # be restricted to the same subset; otherwise the wrong trajectory
        # would be returned as representative.
        kept = [(traj, l) for traj, l in zip(traj_data, locs) if l == loc]
        if not kept:
            raise ValueError(f'no trajectory found for location {loc!r}')
        traj_data = [traj for traj, _ in kept]
        locs = [l for _, l in kept]

    end_p = np.array([traj[-1, :2] for traj in traj_data])
    print(end_p.shape)

    kmeans = KMeans(n_clusters=anchors, verbose=True)
    kmeans.fit(end_p)
    print('fit end')
    centroids = kmeans.cluster_centers_

    with open(f'{out_dir}/vocab_center_{anchors}.pkl', 'wb') as writer:
        pickle.dump(centroids, writer)

    print('processing the representative trajs...')
    rep_traj = []
    rep_loc = []
    # One centroid at a time keeps memory at O(samples) instead of building a
    # full (centroids x samples) distance matrix.
    for i in trange(centroids.shape[0]):
        centroid = centroids[i]
        dist_arg = np.argmin(
            np.linalg.norm(end_p - centroid[np.newaxis, :2], axis=1)
        )
        rep_traj.append(traj_data[dist_arg])
        rep_loc.append(locs[dist_arg])

    with open(f'{out_dir}/vocab_traj_{anchors}.pkl', 'wb') as writer:
        pickle.dump([rep_traj, rep_loc], writer)

import matplotlib.pyplot as plt 
def visualization(
    traj_file,
    kmeans_files,
    suffix='4096',
    out_dir='/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab',
):
    """Plot vocabulary trajectories with their cluster centres and save a PNG.

    Args:
        traj_file: Pickle containing ``[trajectories, locations]``.
        kmeans_files: Pickle containing the centroid array; columns 0 and 1
            are plotted as the scatter coordinates.
        suffix: Tag inserted into the output file name
            ``vocab_{suffix}.png``.
        out_dir: Directory receiving the image. Defaults to the previously
            hard-coded directory.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only open files
    # produced by a trusted pipeline.
    with open(traj_file, 'rb') as reader:
        traj_data, loc_data = pickle.load(reader)

    with open(kmeans_files, 'rb') as reader:
        k_means_data = pickle.load(reader)

    # The mask is applied to the centroid rows, so loc_data must hold exactly
    # one label per centroid.
    # NOTE(review): confirm 'singapore' matches the location labels actually
    # stored in the data; if it never matches, the red overlay is empty.
    sing_mask = np.array(loc_data) == 'singapore'

    fig = plt.figure()
    try:
        plt.scatter(k_means_data[:, 0], k_means_data[:, 1], c='orange',
                    marker='*', s=5, zorder=3)
        plt.scatter(k_means_data[sing_mask, 0], k_means_data[sing_mask, 1],
                    c='red', marker='*', s=5, zorder=4)

        for traj in tqdm(traj_data):
            plt.plot(traj[:, 0], traj[:, 1], color='navy',
                     alpha=0.6, zorder=1)

        plt.savefig(f'{out_dir}/vocab_{suffix}.png')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, repeated calls accumulate figures in memory.
        plt.close(fig)



if __name__ == '__main__':
    # Earlier pipeline stages, kept disabled for reference. Re-enable them to
    # regenerate the inputs consumed by the plots below.
    #
    # ann_file = '/nas/shared/opendrivelab/litianyu/paradrive/data/navsim_infos/nuplan_navsim_train.pkl'
    # load_and_process_traj(ann_file)
    #
    # traj_file = '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/full_traj.pkl'
    # process_kmeans_vocab(traj_file, anchors=4096)
    # process_kmeans_vocab(traj_file, anchors=8192)

    # (representative-trajectory pickle, centroid pickle, output tag),
    # rendered in this order.
    vocab_runs = (
        (
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_traj_8192.pkl',
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_center_8192.pkl',
            '8192',
        ),
        (
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_traj_4096.pkl',
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_center_4096.pkl',
            '4096',
        ),
    )
    for rep_traj_path, centroid_path, tag in vocab_runs:
        visualization(rep_traj_path, centroid_path, tag)