import yaml
import numpy as np
import pandas as pd


class LogFilter(object):
    """Filter driving-log evaluation cases by safety metrics and/or score.

    Loads a per-case metrics CSV, keeps only rows marked valid, and writes
    out a filtered CSV plus a YAML split file (a copy of ``template_yaml``
    with its ``tokens`` list replaced by the surviving case tokens).
    """

    def __init__(
        self,
        log_name,
        template_yaml,
        save_path,
        keep_col=True,
        keep_off_route=False,
        ttc=False,
        sort_case=True,
        keep_num=-1,
        keep_ratio=-1,
    ):
        """
        Args:
            log_name: path to the metrics CSV (must contain a ``valid`` column).
            template_yaml: YAML split file whose ``tokens`` entry is rewritten.
            save_path: directory prefix where filtered CSV/YAML are written.
            keep_col: keep cases with ``no_at_fault_collisions == 0``.
            keep_off_route: keep cases with ``drivable_area_compliance == 0``.
            ttc: keep cases with ``time_to_collision_within_bound == 0``.
            sort_case: sort surviving cases by a PDM-like score before saving.
            keep_num: if > 0, cap the number of kept (sorted) cases.
            keep_ratio: if != -1, keep the lowest-``score`` fraction instead
                of sorting by the composite metric.
        """
        self.keep_col = keep_col
        self.keep_off_route = keep_off_route
        self.ttc = ttc
        self.sort_case = sort_case
        self.keep_num = keep_num
        self.template_yaml = template_yaml
        self.save_path = save_path
        self.keep_ratio = keep_ratio
        self.log_data = pd.read_csv(log_name)
        # Drop invalid rows up front; ``valid`` is treated as boolean.
        self.log_data = self.log_data[self.log_data['valid'] == True]  # noqa: E712
        print(f'Loaded from {log_name}')

    def filter_cases(self, prefix='', metric_filter=True):
        """Select cases, optionally sort/cap them, and save CSV + YAML.

        Args:
            prefix: suffix tag inserted into the output file names.
            metric_filter: if True, keep only rows failing at least one of
                the enabled safety metrics; otherwise keep all valid rows.
        """
        length = self.log_data.values.shape[0]
        if metric_filter:
            # The last row appears to be an aggregate/summary row: it is
            # excluded from every metric mask and force-dropped below.
            # NOTE: np.bool was removed in NumPy 1.24 — use the builtin.
            val_mask = np.zeros((length - 1,), dtype=bool)
            if self.keep_col:
                case = (self.log_data['no_at_fault_collisions'][:-1] == 0).to_numpy()
                val_mask |= case
            if self.keep_off_route:
                case = (self.log_data['drivable_area_compliance'][:-1] == 0).to_numpy()
                val_mask |= case
            if self.ttc:
                case = (self.log_data['time_to_collision_within_bound'][:-1] == 0).to_numpy()
                val_mask |= case
            val_mask = np.concatenate((val_mask, [False]), axis=0)
            # .copy() so the column assignment below does not hit a view of
            # self.log_data (pandas SettingWithCopyWarning / lost writes).
            filtered_data = self.log_data[val_mask].copy()
        else:
            filtered_data = self.log_data.copy()

        if self.sort_case:
            ep = filtered_data['ego_progress']
            ttc = filtered_data['time_to_collision_within_bound']
            comfort = filtered_data['comfort']
            dac = filtered_data['drivable_area_compliance']
            # Composite PDM-like score without the collision term.
            non_col_pdm = dac + (5 * ep + 5 * ttc + 2 * comfort) / 12
            filtered_data['non_col_pdm'] = non_col_pdm
            if self.keep_ratio != -1:
                # Keep the lowest-score fraction of cases.
                # (Do not shadow the builtin ``len`` — it is needed below.)
                n_rows = filtered_data['score'].shape[0]
                n_keep = int(n_rows * self.keep_ratio)
                sorted_ind = np.argsort(filtered_data['score'].to_numpy())[:n_keep]
                print(n_keep)
            else:
                sorted_ind = np.argsort(non_col_pdm.to_numpy())
            if self.keep_num > 0:
                k = min(self.keep_num, len(sorted_ind))
                sorted_ind = sorted_ind[:k]
            filtered_data = filtered_data.iloc[sorted_ind]

        print(f'filter {length-1} cases to {filtered_data.values.shape[0]} cases')
        split_name = self.template_yaml.split('/')[-1].split('.')[0]
        filtered_data.to_csv(self.save_path + f'filtered_{split_name}_{prefix}.csv')
        # Rewrite the template split with the surviving case tokens.
        with open(self.template_yaml, 'r') as f:
            temp_log = yaml.safe_load(f)
        temp_log['tokens'] = filtered_data['token'].tolist()
        with open(self.save_path + f'filtered_{split_name}_{prefix}.yaml', 'w') as yaml_file:
            yaml.dump(temp_log, yaml_file, default_flow_style=False, sort_keys=False)


if __name__ == '__main__':
    filters = LogFilter(
        log_name='/xxx/split_1.csv',
        template_yaml='/xxx/split_1.yaml',
        save_path='/xxx/',
        sort_case=True,
        keep_ratio=0.01,
    )
    filters.filter_cases(metric_filter=False, prefix='per1')