File size: 3,411 Bytes
663494c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import sys
import mmcv
import scipy
import cv2
import pickle
import yaml
import numpy as np
import pandas as pd
from copy import deepcopy
import lzma
class LogFilter(object):
    """Filter a driving-log metrics CSV down to "interesting" cases and
    export them as a CSV plus a YAML token list.

    The input CSV must contain a boolean ``valid`` column; rows where it is
    False are dropped on load.  ``filter_cases`` then optionally keeps only
    rows that violate selected metrics, optionally sorts/truncates them, and
    writes ``filtered_<split>_<prefix>.csv`` / ``.yaml`` under ``save_path``.
    """
    def __init__(
        self,
        log_name,
        template_yaml,
        save_path,
        keep_col=True,
        keep_off_route=False,
        ttc=False,
        sort_case=True,
        keep_num=-1,
        keep_ratio=-1,
    ):
        """
        Args:
            log_name: path of the metrics CSV to load.
            template_yaml: YAML whose ``tokens`` entry is replaced with the
                surviving tokens; its basename also names the output files.
            save_path: directory prefix for the output files (expected to
                end with a path separator, since it is concatenated).
            keep_col: keep rows whose ``no_at_fault_collisions`` is 0.
            keep_off_route: keep rows whose ``drivable_area_compliance`` is 0.
            ttc: keep rows whose ``time_to_collision_within_bound`` is 0.
            sort_case: sort the surviving rows before truncation/export.
            keep_num: keep at most this many rows after sorting (-1 = all).
            keep_ratio: keep this fraction of rows ordered by ``score``
                (-1 = disabled; when enabled it overrides the combined-score
                ordering).
        """
        self.keep_col = keep_col
        self.keep_off_route = keep_off_route
        self.ttc = ttc
        self.sort_case = sort_case
        self.keep_num = keep_num
        self.template_yaml = template_yaml
        self.save_path = save_path
        self.keep_ratio = keep_ratio
        self.log_data = pd.read_csv(log_name)
        # Drop invalid rows up front; all later logic assumes valid rows only.
        self.log_data = self.log_data[self.log_data['valid'] == True]
        print(f'Loaded from {log_name}')

    def filter_cases(self, prefix='', metric_filter=True):
        """Select, sort and export cases.

        Args:
            prefix: tag appended to the output file names.
            metric_filter: when True, keep only rows flagged by the
                ``keep_col`` / ``keep_off_route`` / ``ttc`` criteria; the
                last row (presumably a summary row -- TODO confirm) is
                always dropped in this mode.

        Side effects: writes ``filtered_<split>_<prefix>.csv`` and
        ``filtered_<split>_<prefix>.yaml`` under ``self.save_path``.
        """
        length = self.log_data.values.shape[0]
        if metric_filter:
            # Keep-mask over every row except the trailing row.
            # BUGFIX: np.bool was removed in NumPy 1.24 -- use builtin bool.
            val_mask = np.zeros((length - 1,), dtype=bool)
            if self.keep_col:
                case = self.log_data['no_at_fault_collisions'][:-1] == 0
                val_mask[case] = True
            if self.keep_off_route:
                case = self.log_data['drivable_area_compliance'][:-1] == 0
                val_mask[case] = True
            if self.ttc:
                case = self.log_data['time_to_collision_within_bound'][:-1] == 0
                val_mask[case] = True
            # Re-append an always-False slot so the trailing row is dropped.
            val_mask = np.concatenate((val_mask, [False]), axis=0)
            filtered_data = self.log_data[val_mask]
        else:
            filtered_data = self.log_data
        if self.sort_case:
            # Work on a copy so adding the new column below does not raise
            # pandas' SettingWithCopyWarning on a view of self.log_data.
            filtered_data = filtered_data.copy()
            ep = filtered_data['ego_progress']
            ttc = filtered_data['time_to_collision_within_bound']
            comfort = filtered_data['comfort']
            dac = filtered_data['drivable_area_compliance']
            # Combined PDM-style score: drivable-area compliance plus a
            # weighted average of progress / time-to-collision / comfort.
            non_col_pdm = dac + (5 * ep + 5 * ttc + 2 * comfort) / 12
            filtered_data['non_col_pdm'] = non_col_pdm
            if self.keep_ratio != -1:
                # BUGFIX: the original named this local `len`, shadowing the
                # builtin and crashing len(sorted_ind) in the keep_num branch.
                n_rows = filtered_data['score'].shape[0]
                sorted_ind = np.argsort(filtered_data['score'])[:int(n_rows * self.keep_ratio)]
                print(int(n_rows * self.keep_ratio))
            else:
                sorted_ind = np.argsort(non_col_pdm)
            if self.keep_num > 0:
                k = min(self.keep_num, len(sorted_ind))
                sorted_ind = sorted_ind[:k]
            filtered_data = filtered_data.iloc[sorted_ind]
        print(f'filter {length-1} cases to {filtered_data.values.shape[0]} cases')
        # Derive the split name from the template YAML's basename.
        split_name = self.template_yaml.split('/')[-1].split('.')[0]
        filtered_data.to_csv(self.save_path+f'filtered_{split_name}_{prefix}.csv')
        with open(self.template_yaml, 'r') as f:
            temp_log = yaml.safe_load(f)
        temp_log['tokens'] = filtered_data['token'].tolist()
        with open(self.save_path+f'filtered_{split_name}_{prefix}.yaml', 'w') as yaml_file:
            yaml.dump(temp_log, yaml_file, default_flow_style=False, sort_keys=False)
if __name__=='__main__':
    # Example run: skip metric filtering and keep the 1% of cases with the
    # lowest 'score' from split_1.
    log_filter = LogFilter(
        log_name='/xxx/split_1.csv',
        template_yaml='/xxx/split_1.yaml',
        save_path='/xxx/',
        sort_case=True,
        keep_ratio=0.01,
    )
    log_filter.filter_cases(metric_filter=False, prefix='per1')
|