# R2SE_model / data_yaml / format_rare_log.py
# (uploaded via huggingface_hub, commit 663494c)
import sys
import mmcv
import scipy
import cv2
import pickle
import yaml
import numpy as np
import pandas as pd
from copy import deepcopy
import lzma
class LogFilter(object):
    """Filter a per-case evaluation log CSV down to the "rare"/hard cases.

    The log is expected to have one row per case plus a trailing
    aggregate/summary row (the last row is always excluded from metric
    filtering), and at least the columns: ``valid``, ``token``,
    ``no_at_fault_collisions``, ``drivable_area_compliance``,
    ``time_to_collision_within_bound``, ``ego_progress``, ``comfort``,
    and (when ``keep_ratio`` is used) ``score``.
    The surviving cases are written both as a CSV and as a YAML token
    list derived from ``template_yaml``.
    """

    def __init__(
        self,
        log_name,
        template_yaml,
        save_path,
        keep_col=True,
        keep_off_route=False,
        ttc=False,
        sort_case=True,
        keep_num=-1,
        keep_ratio=-1,
    ):
        """Load the log CSV and remember the filtering configuration.

        Args:
            log_name: Path to the metrics CSV; rows whose ``valid`` column
                is not True are dropped on load.
            template_yaml: YAML file whose contents are copied to the output
                YAML with its ``tokens`` entry replaced.
            save_path: Directory prefix for the output CSV/YAML
                (used via string concatenation, so it should end with '/').
            keep_col: Keep cases with ``no_at_fault_collisions == 0``.
            keep_off_route: Keep cases with ``drivable_area_compliance == 0``.
            ttc: Keep cases with ``time_to_collision_within_bound == 0``.
            sort_case: Sort the surviving cases (hardest first) and
                optionally truncate via keep_num/keep_ratio.
            keep_num: If > 0, keep at most this many cases after sorting.
            keep_ratio: If != -1, keep the lowest-``score`` fraction of
                cases instead of sorting by the non-collision PDM score.
        """
        self.keep_col = keep_col
        self.keep_off_route = keep_off_route
        self.ttc = ttc
        self.sort_case = sort_case
        self.keep_num = keep_num
        self.template_yaml = template_yaml
        self.save_path = save_path
        self.keep_ratio = keep_ratio
        log_data = pd.read_csv(log_name)
        # Keep only rows flagged valid. .copy() detaches the result from the
        # original frame so later column assignment does not trigger pandas'
        # SettingWithCopyWarning (or silently write to a view).
        self.log_data = log_data[log_data['valid'] == True].copy()
        print(f'Loaded from {log_name}')

    def filter_cases(self, prefix='', metric_filter=True):
        """Select the rare cases and write ``filtered_<split>_<prefix>``
        CSV and YAML files under ``save_path``.

        Args:
            prefix: Suffix tag appended to the output file names.
            metric_filter: If True, keep only cases failing at least one of
                the configured metrics; if False, keep all valid cases.
        """
        length = self.log_data.values.shape[0]
        if metric_filter:
            # Mask over all rows except the last one, which is assumed to be
            # an aggregate/summary row — TODO confirm against the log format.
            # BUG FIX: np.bool was removed in NumPy 1.24; use builtin bool.
            val_mask = np.zeros((length - 1,), dtype=bool)
            if self.keep_col:
                # A zero metric means the case FAILED that check; those are
                # exactly the rare cases we want to keep.
                case = self.log_data['no_at_fault_collisions'][:-1] == 0
                val_mask[case] = True
            if self.keep_off_route:
                case = self.log_data['drivable_area_compliance'][:-1] == 0
                val_mask[case] = True
            if self.ttc:
                case = self.log_data['time_to_collision_within_bound'][:-1] == 0
                val_mask[case] = True
            # Always drop the final (summary) row.
            val_mask = np.concatenate((val_mask, [False]), axis=0)
            filtered_data = self.log_data[val_mask].copy()
        else:
            filtered_data = self.log_data.copy()
        if self.sort_case:
            ep = filtered_data['ego_progress']
            ttc = filtered_data['time_to_collision_within_bound']
            comfort = filtered_data['comfort']
            dac = filtered_data['drivable_area_compliance']
            # PDM-style weighted score without the collision term; lower
            # values mean harder cases, so an ascending argsort ranks
            # hardest first.
            non_col_pdm = dac + (5 * ep + 5 * ttc + 2 * comfort) / 12
            filtered_data['non_col_pdm'] = non_col_pdm
            if self.keep_ratio != -1:
                # BUG FIX: the original bound this count to the name `len`,
                # shadowing the builtin and breaking len(sorted_ind) below
                # whenever keep_num > 0.
                n_cases = filtered_data['score'].shape[0]
                sorted_ind = np.argsort(filtered_data['score'])[:int(n_cases * self.keep_ratio)]
                print(int(n_cases * self.keep_ratio))
            else:
                sorted_ind = np.argsort(non_col_pdm)
            if self.keep_num > 0:
                k = min(self.keep_num, len(sorted_ind))
                sorted_ind = sorted_ind[:k]
            filtered_data = filtered_data.iloc[sorted_ind]
        print(f'filter {length-1} cases to {filtered_data.values.shape[0]} cases')
        # Derive the split name from the template file name (sans extension).
        split_name = self.template_yaml.split('/')[-1].split('.')[0]
        filtered_data.to_csv(self.save_path + f'filtered_{split_name}_{prefix}.csv')
        with open(self.template_yaml, 'r') as f:
            temp_log = yaml.safe_load(f)
        temp_log['tokens'] = filtered_data['token'].tolist()
        with open(self.save_path + f'filtered_{split_name}_{prefix}.yaml', 'w') as yaml_file:
            yaml.dump(temp_log, yaml_file, default_flow_style=False, sort_keys=False)
if __name__ == '__main__':
    # Example driver: rank all valid cases by raw score and keep the
    # hardest 1% (metric filtering disabled). Paths are placeholders.
    config = dict(
        log_name='/xxx/split_1.csv',
        template_yaml='/xxx/split_1.yaml',
        save_path='/xxx/',
        sort_case=True,
        keep_ratio=0.01,
    )
    log_filter = LogFilter(**config)
    log_filter.filter_cases(metric_filter=False, prefix='per1')