"""Build nested navtrain subsets that keep 50%-90% of the training logs.

The script reads the navtrain scene filter (YAML) and the nuPlan log-info
JSONL. It shuffles the filter's log names once with a seeded RNG. For each
percentage it writes a copy of the filter restricted to the first N% of the
shuffled logs and to the navtrain tokens that belong to those logs. Because
every subset is a prefix of the same shuffle, the subsets are nested
(50% ⊂ 60% ⊂ ... ⊂ 90%).
"""
import copy
import os
import random

NAVTRAIN_FILTER_PATH = "xxx/scene_filter/navtrain.yaml"
LOG_INFOS_PATH = "/xxx/nuplan_log_infos.jsonl"
OUTPUT_DIR = "data_loop/navtrain_split"
PERCENTAGES = (50, 60, 70, 80, 90)
# Fixed seed so re-running the script reproduces the same splits. The value
# itself is arbitrary. Pass seed=None to main() for a fresh random split.
SEED = 0


def group_tokens_by_log(log_infos, navtrain_tokens):
    """Map each log name to those of its lidar-pc tokens that are in navtrain.

    Args:
        log_infos: iterable of records, each with 'log_name' and
            'lidar_pc_tokens' keys. If a log name repeats, the last
            record wins.
        navtrain_tokens: collection of tokens to keep.

    Returns:
        dict mapping log_name -> list of kept tokens. Tokens stay in the
        order they appear in the record.
    """
    keep = set(navtrain_tokens)  # O(1) membership tests.
    return {
        info['log_name']: [tok for tok in info['lidar_pc_tokens'] if tok in keep]
        for info in log_infos
    }


def nested_log_subsets(log_names, percentages=PERCENTAGES, seed=SEED):
    """Split log names into nested, sorted subsets by percentage.

    The logs are shuffled once, and each subset is the first
    ``len(log_names) * p // 100`` shuffled logs. Larger percentages
    therefore always contain the smaller ones.

    Args:
        log_names: sequence of log names. It is not modified.
        percentages: iterable of integer percentages.
        seed: RNG seed. None uses a non-reproducible system seed.

    Returns:
        dict mapping percentage -> sorted list of selected log names.
    """
    shuffled = list(log_names)
    random.Random(seed).shuffle(shuffled)
    count = len(shuffled)
    return {p: sorted(shuffled[:count * p // 100]) for p in percentages}


def build_subset_filter(navtrain_filter, subset_logs, log_tokens):
    """Return a deep copy of navtrain_filter restricted to subset_logs.

    The 'log_names' field becomes subset_logs. The 'tokens' field becomes the
    concatenation of each log's tokens from log_tokens, in subset_logs order.
    All other keys are deep-copied unchanged.

    Raises:
        KeyError: if a log in subset_logs has no entry in log_tokens, which
            means it was absent from the log-info file.
    """
    missing = [name for name in subset_logs if name not in log_tokens]
    if missing:
        raise KeyError(
            f"{len(missing)} navtrain log(s) missing from log infos, "
            f"e.g. {missing[0]!r}"
        )
    tokens = [tok for name in subset_logs for tok in log_tokens[name]]
    subset = copy.deepcopy(navtrain_filter)
    subset['log_names'] = list(subset_logs)
    subset['tokens'] = tokens
    return subset


def main(seed=SEED):
    """Read the inputs and write one navtrain_{p}pct.yaml per percentage."""
    # Third-party imports are local so the pure helpers above can be imported
    # (and unit-tested) without PyYAML / jsonlines installed.
    import jsonlines
    import yaml

    with open(NAVTRAIN_FILTER_PATH, 'r', encoding='utf-8') as file:
        navtrain_filter = yaml.safe_load(file)
    # Keep only the navtrain tokens per log, not the full JSONL records.
    with jsonlines.open(LOG_INFOS_PATH, 'r') as reader:
        log_tokens = group_tokens_by_log(reader, navtrain_filter['tokens'])

    subsets = nested_log_subsets(navtrain_filter['log_names'], PERCENTAGES, seed)
    # The output directory may not exist on a fresh checkout.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    for percentage, subset_logs in subsets.items():
        subset_filter = build_subset_filter(navtrain_filter, subset_logs, log_tokens)
        print(percentage, len(subset_filter['tokens']))
        out_path = os.path.join(OUTPUT_DIR, f"navtrain_{percentage}pct.yaml")
        with open(out_path, "w", encoding="utf-8") as f:
            yaml.dump(subset_filter, f)


if __name__ == '__main__':
    main()