| import os |
| import json |
| from util.json import load_from_json |
|
|
def first_non_space_char(s):
    """Return the first non-whitespace character of *s*, or None.

    Every character for which ``str.isspace()`` is true is skipped: spaces,
    newlines, tabs, carriage returns, and Unicode spaces such as the
    full-width space U+3000.

    Args:
        s: The string to scan (e.g. a model prediction).

    Returns:
        The first non-whitespace character of *s*, or None if *s* is empty
        or consists only of whitespace.
    """
    for char in s:
        # A narrower check such as char not in (' ', '\n') would return '\t'
        # for "\tA" and cause a correct prediction to be scored as wrong.
        if not char.isspace():
            return char
    return None
|
|
def check_multi_choose(json_file_path):
    """Compute the accuracy of the multiple-choice predictions in one result file.

    An entry counts as correct when the first non-whitespace character of its
    ``'pred'`` field equals its ``'right_option'`` field.

    Args:
        json_file_path: Path to the JSON result file. It is loaded with
            ``load_from_json``, and the loaded value is iterated as a sized
            sequence of entries indexed by ``'pred'`` and ``'right_option'``.

    Returns:
        float: Fraction of correct entries. Returns 0.0 (and prints a
        warning) if the file contains no entries.
    """
    result = load_from_json(json_file_path)
    # Without this guard an empty file raises ZeroDivisionError, which would
    # abort the evaluation of every other task as well.
    if not result:
        print(f"Warning: no entries in {json_file_path}")
        return 0.0
    correct = sum(
        1
        for entry in result
        if first_non_space_char(entry['pred']) == entry['right_option']
    )
    return correct / len(result)
|
|
def read_json_files(root_folder, file_name):
    """Compute the accuracy for each entry of *root_folder* that holds *file_name*.

    Every immediate entry of *root_folder* is visited in sorted name order,
    and its name is printed. If ``<root_folder>/<entry>/<file_name>`` is an
    existing file, its accuracy is computed with ``check_multi_choose``.
    Entries without that file are skipped.

    Args:
        root_folder: Root directory to scan.
        file_name: Name of the JSON result file to look for inside each entry.

    Returns:
        dict: Maps entry name to accuracy (float), with keys in sorted order.
    """
    acc = {}
    # os.listdir() returns names in arbitrary, filesystem-dependent order.
    # Sorting makes the printed progress and the final report reproducible.
    for entry in sorted(os.listdir(root_folder)):
        print(entry)
        json_file_path = os.path.join(root_folder, entry, file_name)
        if os.path.isfile(json_file_path):
            acc[entry] = check_multi_choose(json_file_path)
    return acc
|
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('file_name', type=str, help='要获取的实验结果文件名')
    # The results root is optional; its default keeps the original path so
    # that existing invocations behave exactly as before.
    parser.add_argument(
        '--root_folder',
        type=str,
        default='./result/eval/mvbench_qa',
        help='Root directory containing one sub-directory per task',
    )
    args = parser.parse_args()

    root_folder = args.root_folder
    file_name = args.file_name

    acc_list = read_json_files(root_folder, file_name)
    if not acc_list:
        # Averaging an empty list would raise ZeroDivisionError. An empty
        # result usually means a mistyped file_name or root_folder.
        raise SystemExit(
            f"No '{file_name}' found in any sub-directory of {root_folder}"
        )

    just_acc = list(acc_list.values())
    acc_list['average'] = sum(just_acc) / len(just_acc)
    print("--------------------------Result--------------------------")

    formatted_acc_list = {k: f"{v:.3f}" for k, v in acc_list.items()}
    print(formatted_acc_list)
|
|