| import sys | |
| import mmcv | |
| import scipy | |
| import cv2 | |
| import pickle | |
| import yaml | |
| import numpy as np | |
| import pandas as pd | |
| from copy import deepcopy | |
| import lzma | |
class LogFilter(object):
    """Select a subset of scenario tokens from a per-token metric log.

    The log is a CSV loaded with ``pandas.read_csv``; only rows whose
    ``valid`` column equals ``True`` are kept. ``filter_cases`` optionally
    keeps only cases that fail selected metrics, optionally ranks/truncates
    them, and writes the result both as a CSV and as a copy of
    ``template_yaml`` whose ``tokens`` entry is replaced by the selected
    tokens.

    Args:
        log_name: Path of the CSV metric log.
        template_yaml: Path of the YAML file used as the output template.
        save_path: String prefix for the output files. It is concatenated,
            not path-joined, so it should normally end with a separator.
        keep_col: Keep cases whose ``no_at_fault_collisions`` equals 0.
        keep_off_route: Keep cases whose ``drivable_area_compliance`` equals 0.
        ttc: Keep cases whose ``time_to_collision_within_bound`` equals 0.
        sort_case: Rank the selected cases before exporting them.
        keep_num: If > 0, keep at most this many ranked cases.
        keep_ratio: If not -1, rank by ascending ``score`` and keep
            ``int(n_rows * keep_ratio)`` rows.
    """

    def __init__(
        self,
        log_name,
        template_yaml,
        save_path,
        keep_col=True,
        keep_off_route=False,
        ttc=False,
        sort_case=True,
        keep_num=-1,
        keep_ratio=-1,
    ):
        self.keep_col = keep_col
        self.keep_off_route = keep_off_route
        self.ttc = ttc
        self.sort_case = sort_case
        self.keep_num = keep_num
        self.template_yaml = template_yaml
        self.save_path = save_path
        self.keep_ratio = keep_ratio
        self.log_data = pd.read_csv(log_name)
        # Drop rows the producer flagged as invalid.
        self.log_data = self.log_data[self.log_data['valid'] == True]
        print(f'Loaded from {log_name}')

    def filter_cases(self, prefix='', metric_filter=True):
        """Select cases and write ``filtered_<split>_<prefix>.csv`` / ``.yaml``.

        Args:
            prefix: Suffix inserted into the output file names.
            metric_filter: If True, keep only cases that fail at least one
                of the metrics enabled in the constructor. Otherwise start
                from every loaded row.
        """
        length = self.log_data.shape[0]
        if metric_filter:
            filtered_data = self.log_data[self._metric_mask()]
        else:
            filtered_data = self.log_data
        if self.sort_case:
            filtered_data = self._rank(filtered_data)
        print(f'filter {length-1} cases to {filtered_data.shape[0]} cases')
        self._export(filtered_data, prefix)

    def _metric_mask(self):
        """Return a boolean row mask: True where an enabled metric equals 0."""
        # The final row never passes the metric filter. NOTE(review): this is
        # presumably a summary/average row appended by the log producer;
        # confirm against the tool that writes the CSV.
        body = self.log_data.iloc[:-1]
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        val_mask = np.zeros(body.shape[0], dtype=bool)
        checks = (
            (self.keep_col, 'no_at_fault_collisions'),
            (self.keep_off_route, 'drivable_area_compliance'),
            (self.ttc, 'time_to_collision_within_bound'),
        )
        for enabled, column in checks:
            if enabled:
                val_mask |= (body[column] == 0).to_numpy()
        # Re-append the excluded final row as "not kept". An empty log has no
        # final row, so skip this to keep the mask length equal to the row count.
        if self.log_data.shape[0] > 0:
            val_mask = np.append(val_mask, False)
        return val_mask

    def _rank(self, filtered_data):
        """Order cases lowest-score first and apply ``keep_ratio``/``keep_num``."""
        # Copy first: the input may be a slice of (or identical to)
        # self.log_data, and adding a column must not mutate it.
        filtered_data = filtered_data.copy()
        ep = filtered_data['ego_progress']
        ttc = filtered_data['time_to_collision_within_bound']
        comfort = filtered_data['comfort']
        dac = filtered_data['drivable_area_compliance']
        non_col_pdm = dac + (5 * ep + 5 * ttc + 2 * comfort) / 12
        filtered_data['non_col_pdm'] = non_col_pdm
        # Argsort plain ndarrays: Series.argsort (what np.argsort dispatches to)
        # returns -1 for NaN in older pandas, which iloc would read as "last row".
        if self.keep_ratio != -1:
            num_keep = int(filtered_data.shape[0] * self.keep_ratio)
            sorted_ind = np.argsort(filtered_data['score'].to_numpy())[:num_keep]
            print(num_keep)
        else:
            sorted_ind = np.argsort(non_col_pdm.to_numpy())
        # Optional hard cap, applied after either ranking mode. Slicing already
        # clamps to the available length.
        if self.keep_num > 0:
            sorted_ind = sorted_ind[:self.keep_num]
        return filtered_data.iloc[sorted_ind]

    def _export(self, filtered_data, prefix):
        """Write selected rows to CSV and their tokens into a copy of the YAML template."""
        split_name = self.template_yaml.split('/')[-1].split('.')[0]
        out_stem = self.save_path + f'filtered_{split_name}_{prefix}'
        filtered_data.to_csv(out_stem + '.csv')
        with open(self.template_yaml, 'r') as f:
            temp_log = yaml.safe_load(f)
        temp_log['tokens'] = filtered_data['token'].tolist()
        with open(out_stem + '.yaml', 'w') as yaml_file:
            yaml.dump(temp_log, yaml_file, default_flow_style=False, sort_keys=False)
if __name__ == '__main__':
    # Example run: no metric-based filtering; rank every loaded case by
    # ascending 'score' and keep int(n_rows * 0.01) of them.
    log_filter = LogFilter(
        log_name='/xxx/split_1.csv',
        template_yaml='/xxx/split_1.yaml',
        save_path='/xxx/',
        sort_case=True,
        keep_ratio=0.01,
    )
    log_filter.filter_cases(prefix='per1', metric_filter=False)