# File size: 3,411 Bytes
# 663494c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import sys
import mmcv
import scipy
import cv2
import pickle
import yaml
import numpy as np
import pandas as pd
from copy import deepcopy

import lzma

class LogFilter(object):
    """Select a subset of scenario tokens from a per-scenario metrics CSV.

    The CSV at ``log_name`` is loaded and restricted to rows whose ``valid``
    column equals True. :meth:`filter_cases` then optionally keeps only rows
    matching the enabled metric criteria, optionally ranks/truncates them, and
    writes the selection both as a CSV and as a copy of ``template_yaml``
    whose ``tokens`` entry is replaced by the selected tokens.
    """

    def __init__(
        self,
        log_name,
        template_yaml,
        save_path,
        keep_col=True,
        keep_off_route=False,
        ttc=False,
        sort_case=True,
        keep_num=-1,
        keep_ratio=-1,
    ):
        """Load the metrics CSV and store the selection options.

        Args:
            log_name: path of the metrics CSV (read with ``pd.read_csv``).
            template_yaml: YAML file whose contents are copied to the output
                with its ``tokens`` entry replaced.
            save_path: string prepended verbatim to the output file names
                (include a trailing ``/`` when it is a directory).
            keep_col: with ``metric_filter``, keep rows whose
                ``no_at_fault_collisions`` column equals 0.
            keep_off_route: with ``metric_filter``, keep rows whose
                ``drivable_area_compliance`` column equals 0.
            ttc: with ``metric_filter``, keep rows whose
                ``time_to_collision_within_bound`` column equals 0.
            sort_case: rank the selected rows in ascending order (see
                :meth:`filter_cases`) and optionally truncate them.
            keep_num: when > 0 and ``keep_ratio`` is -1, keep at most this many
                ranked rows.
            keep_ratio: when not -1, rank by ``score`` and keep
                ``int(n_cases * keep_ratio)`` rows.
        """
        self.keep_col = keep_col
        self.keep_off_route = keep_off_route
        self.ttc = ttc

        self.sort_case = sort_case
        self.keep_num = keep_num
        self.template_yaml = template_yaml
        self.save_path = save_path
        self.keep_ratio = keep_ratio

        self.log_data = pd.read_csv(log_name)
        self.log_data = self.log_data[self.log_data['valid'] == True]
        print(f'Loaded from {log_name}')

    def filter_cases(self, prefix='', metric_filter=True):
        """Select cases and write ``filtered_<split>_<prefix>.csv`` / ``.yaml``.

        The last row of the loaded table is never treated as a case. It is
        excluded from selection in both branches and from the reported case
        count. Presumably it is a summary/average row; TODO confirm against
        the CSV producer.

        NOTE(review): if that trailing row was dropped by the ``valid`` filter
        in ``__init__``, the last real case is excluded instead. Confirm.

        Args:
            prefix: suffix inserted into the output file names.
            metric_filter: when True, keep only rows matching at least one of
                the enabled criteria (``keep_col`` / ``keep_off_route`` /
                ``ttc``). When False, keep every case.

        Returns:
            The selected ``pd.DataFrame``, in the order it was written.
        """
        # Every row except the trailing one is a candidate case.
        cases = self.log_data.iloc[:-1]

        if metric_filter:
            # Logical OR of the enabled criteria.
            val_mask = np.zeros(len(cases), dtype=bool)
            if self.keep_col:
                val_mask |= (cases['no_at_fault_collisions'] == 0).to_numpy()
            if self.keep_off_route:
                val_mask |= (cases['drivable_area_compliance'] == 0).to_numpy()
            if self.ttc:
                val_mask |= (cases['time_to_collision_within_bound'] == 0).to_numpy()
            filtered_data = cases[val_mask]
        else:
            filtered_data = cases
        # Work on an explicit copy so adding columns below never writes into
        # (or warns about) a view of self.log_data.
        filtered_data = filtered_data.copy()

        if self.sort_case:
            ep = filtered_data['ego_progress']
            ttc = filtered_data['time_to_collision_within_bound']
            comfort = filtered_data['comfort']
            dac = filtered_data['drivable_area_compliance']

            # drivable_area_compliance plus a 5/5/2-weighted mean of
            # ego_progress, ttc and comfort; the name suggests a PDM-style
            # score without the collision term.
            non_col_pdm = dac + (5 * ep + 5 * ttc + 2 * comfort) / 12
            filtered_data['non_col_pdm'] = non_col_pdm

            # Stable argsort on plain arrays: ties keep CSV order, and the
            # indices are positional for .iloc below.
            if self.keep_ratio != -1:
                num_keep = int(len(filtered_data) * self.keep_ratio)
                scores = filtered_data['score'].to_numpy()
                sorted_ind = np.argsort(scores, kind='stable')[:num_keep]
                print(num_keep)
            else:
                sorted_ind = np.argsort(non_col_pdm.to_numpy(), kind='stable')
                if self.keep_num > 0:
                    # Slicing past the end is safe; no explicit min() needed.
                    sorted_ind = sorted_ind[:self.keep_num]
            filtered_data = filtered_data.iloc[sorted_ind]

        print(f'filter {len(cases)} cases to {len(filtered_data)} cases')

        split_name = self.template_yaml.split('/')[-1].split('.')[0]
        out_stem = self.save_path + f'filtered_{split_name}_{prefix}'
        filtered_data.to_csv(out_stem + '.csv')

        with open(self.template_yaml, 'r') as f:
            temp_log = yaml.safe_load(f)

        temp_log['tokens'] = filtered_data['token'].tolist()
        with open(out_stem + '.yaml', 'w') as yaml_file:
            yaml.dump(temp_log, yaml_file, default_flow_style=False, sort_keys=False)

        return filtered_data

if __name__ == '__main__':
    # Keep the lowest-scoring 1% of the valid cases, with no metric-based
    # pre-filtering; replace the /xxx/ placeholders with real paths.
    log_filter = LogFilter(
        log_name='/xxx/split_1.csv',
        template_yaml='/xxx/split_1.yaml',
        save_path='/xxx/',
        sort_case=True,
        keep_ratio=0.01,
    )
    log_filter.filter_cases(prefix='per1', metric_filter=False)