"""Write nested ``navtrain`` scene-filter subsets covering N% of its logs.

Reads the full ``navtrain`` scene filter (YAML) and a jsonl index that maps
each ``log_name`` to its ``lidar_pc_tokens``. Shuffles the filter's log names
once with a fixed seed. For each percentage, writes a copy of the filter whose
``log_names`` are the (sorted) first N% of the shuffled logs and whose
``tokens`` are the navtrain tokens belonging to those logs. Because every split
takes a prefix of the same shuffled order, smaller splits are subsets of
larger ones.
"""
import copy
import os
import random

import jsonlines
import yaml

NAVTRAIN_FILTER_PATH = "xxx/scene_filter/navtrain.yaml"
LOG_INFOS_PATH = "/xxx/nuplan_log_infos.jsonl"
OUTPUT_DIR = "data_loop/navtrain_split"
PERCENTAGES = (50, 60, 70, 80, 90)
# Fixed seed so re-running the script reproduces the same log subsets; an
# unseeded shuffle gives different splits on every run.
SEED = 0


def _navtrain_tokens_per_log(log_infos, log_names, navtrain_tokens):
    """Map each log in ``log_names`` to its tokens that are in ``navtrain_tokens``.

    Args:
        log_infos: dict of ``log_name`` -> record with a ``"lidar_pc_tokens"`` list.
        log_names: the log names to index.
        navtrain_tokens: set of tokens to keep.

    Returns:
        dict ``log_name`` -> list of kept tokens, in the record's original order.

    Raises:
        KeyError: if any name in ``log_names`` is absent from ``log_infos``.
            This is checked before any output is written, so a failure never
            leaves a partial set of split files behind.
    """
    missing = [name for name in log_names if name not in log_infos]
    if missing:
        raise KeyError(
            f"{len(missing)} navtrain log(s) not found in log infos, e.g. {missing[:5]}"
        )
    return {
        name: [tok for tok in log_infos[name]["lidar_pc_tokens"] if tok in navtrain_tokens]
        for name in log_names
    }


def _make_subset_filter(navtrain_filter, subset_logs, tokens_per_log):
    """Return a deep copy of ``navtrain_filter`` restricted to ``subset_logs``.

    ``tokens`` is the concatenation of each subset log's tokens, in the order
    of ``subset_logs``; all other filter keys are copied unchanged.
    """
    subset_filter = copy.deepcopy(navtrain_filter)
    subset_filter["log_names"] = subset_logs
    subset_filter["tokens"] = [tok for name in subset_logs for tok in tokens_per_log[name]]
    return subset_filter


def main(seed=SEED):
    """Generate ``navtrain_{N}pct.yaml`` in ``OUTPUT_DIR`` for every N in ``PERCENTAGES``."""
    with open(NAVTRAIN_FILTER_PATH, "r", encoding="utf-8") as file:
        navtrain_filter = yaml.safe_load(file)
    with jsonlines.open(LOG_INFOS_PATH, "r") as reader:
        # If a log_name repeats, the last record wins.
        log_infos = {item["log_name"]: item for item in reader}

    navtrain_tokens = set(navtrain_filter["tokens"])  # set: O(1) membership tests
    log_names = list(navtrain_filter["log_names"])  # copy; shuffled in place below
    tokens_per_log = _navtrain_tokens_per_log(log_infos, log_names, navtrain_tokens)

    random.Random(seed).shuffle(log_names)
    num_logs = len(log_names)
    os.makedirs(OUTPUT_DIR, exist_ok=True)  # open(..., "w") does not create dirs
    for percentage in PERCENTAGES:
        subset_logs = sorted(log_names[: num_logs * percentage // 100])
        subset_filter = _make_subset_filter(navtrain_filter, subset_logs, tokens_per_log)
        print(percentage, len(subset_filter["tokens"]))
        out_path = os.path.join(OUTPUT_DIR, f"navtrain_{percentage}pct.yaml")
        with open(out_path, "w", encoding="utf-8") as f:
            yaml.dump(subset_filter, f)


if __name__ == "__main__":
    main()