# R2SE_model/data_yaml/split_navtrain_logs_pct.py
# Uploaded by unknownuser6666 using huggingface_hub (commit 663494c, verified).
import argparse
import copy
import os
import random

import jsonlines
import yaml
# Default split sizes (percent of navtrain logs) written by this script.
DEFAULT_PERCENTAGES = (50, 60, 70, 80, 90)


def group_tokens_by_log(log_infos, wanted_tokens):
    """Map each log name to those of its lidar-pc tokens found in ``wanted_tokens``.

    Args:
        log_infos: iterable of records, each a mapping with a ``'log_name'``
            key and a ``'lidar_pc_tokens'`` list (one record per line of the
            log-info JSONL file). If a log name repeats, the last record wins.
        wanted_tokens: container with fast membership tests (e.g. a ``set``).

    Returns:
        dict mapping log name -> list of matching tokens, kept in the order
        they appear in that log's ``'lidar_pc_tokens'``. Logs with no match
        map to an empty list.
    """
    return {
        info['log_name']: [tok for tok in info['lidar_pc_tokens'] if tok in wanted_tokens]
        for info in log_infos
    }


def shuffle_logs(log_names, seed=None):
    """Return a shuffled copy of ``log_names``; the input is left untouched.

    A fixed ``seed`` makes the order, and therefore every split derived from
    it, reproducible across runs. ``seed=None`` seeds from system entropy.
    """
    shuffled = list(log_names)
    random.Random(seed).shuffle(shuffled)
    return shuffled


def take_log_fraction(shuffled_logs, percentage):
    """Return the first ``percentage`` % of ``shuffled_logs`` (rounded down), sorted.

    Because every split is a prefix of the same shuffled list, a smaller
    percentage always yields a subset of a larger one.

    Raises:
        ValueError: if ``percentage`` is outside ``[0, 100]``.
    """
    if not 0 <= percentage <= 100:
        raise ValueError(f"percentage must be in [0, 100], got {percentage}")
    # Integer arithmetic avoids float rounding in the count; int() keeps
    # float percentages working as they did with the original int(x / 100).
    count = int(len(shuffled_logs) * percentage // 100)
    return sorted(shuffled_logs[:count])


def build_split_filter(base_filter, split_logs, tokens_by_log):
    """Return a deep copy of ``base_filter`` restricted to ``split_logs``.

    ``'log_names'`` is replaced by ``split_logs`` and ``'tokens'`` by the
    concatenation of each log's tokens from ``tokens_by_log`` (in
    ``split_logs`` order). All other keys are deep-copied unchanged, so the
    base filter is never mutated.

    Raises:
        KeyError: if a log in ``split_logs`` has no entry in ``tokens_by_log``
            (i.e. it is absent from the log-info JSONL file).
    """
    missing = [name for name in split_logs if name not in tokens_by_log]
    if missing:
        raise KeyError(
            f"{len(missing)} log(s) missing from log infos, e.g. {missing[:3]}"
        )
    split_filter = copy.deepcopy(base_filter)
    split_filter['log_names'] = list(split_logs)
    split_filter['tokens'] = [tok for name in split_logs for tok in tokens_by_log[name]]
    return split_filter


def parse_args(argv=None):
    """Parse command-line options; defaults reproduce the original hard-coded values."""
    parser = argparse.ArgumentParser(
        description="Write navtrain scene filters restricted to a random subset of logs."
    )
    # NOTE: the two input defaults are placeholders ("xxx") - override them.
    parser.add_argument('--navtrain-filter', default="xxx/scene_filter/navtrain.yaml",
                        help="navtrain scene-filter YAML with 'log_names' and 'tokens'")
    parser.add_argument('--log-infos', default="/xxx/nuplan_log_infos.jsonl",
                        help="JSONL with one {'log_name', 'lidar_pc_tokens'} record per log")
    parser.add_argument('--output-dir', default="data_loop/navtrain_split",
                        help="directory for navtrain_<pct>pct.yaml outputs")
    parser.add_argument('--percentages', type=int, nargs='+',
                        default=list(DEFAULT_PERCENTAGES),
                        help="split sizes in percent of navtrain logs")
    parser.add_argument('--seed', type=int, default=0,
                        help="shuffle seed; fixed by default so splits are reproducible")
    return parser.parse_args(argv)


def main(argv=None):
    """Write one ``navtrain_<pct>pct.yaml`` scene filter per requested percentage.

    Prints ``<percentage> <number of tokens>`` for each split written.
    """
    args = parse_args(argv)
    with open(args.navtrain_filter, 'r', encoding='utf-8') as file:
        navtrain_filter = yaml.safe_load(file)
    with jsonlines.open(args.log_infos, 'r') as reader:
        # Consume the reader while the file is still open.
        tokens_by_log = group_tokens_by_log(reader, set(navtrain_filter['tokens']))

    # Shuffle once; each split is a prefix of this order (nested splits).
    shuffled = shuffle_logs(navtrain_filter['log_names'], args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    for percentage in args.percentages:
        split_logs = take_log_fraction(shuffled, percentage)
        split_filter = build_split_filter(navtrain_filter, split_logs, tokens_by_log)
        print(percentage, len(split_filter['tokens']))
        out_path = os.path.join(args.output_dir, f"navtrain_{percentage}pct.yaml")
        with open(out_path, "w", encoding='utf-8') as f:
            yaml.dump(split_filter, f)


if __name__ == '__main__':
    main()