File size: 17,191 Bytes
b0c0df0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 |
import json
import os
import random
import re
from collections import Counter, defaultdict
from loguru import logger as eval_logger
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
# Prompt-assembly templates for CMMMU. "task_instructions" indices:
# [0] multiple choice ("选择"), [1] true/false ("判断"), [2] fill in the blank.
# The *_example_format strings are filled via str.format() in
# construct_prompt(). All values are runtime strings shown to the model and
# must remain in Chinese.
PROMPT = {
    "task_instructions": [
        "请回答以下多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。",
        "请回答以下判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。",
        "请回答以下填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。",
    ],
    "multi_choice_example_format": ["问题:{}\n选项:\n{}\n正确答案:\n"],
    "T/F_example_format": ["问题:{}\n正确答案:\n"],
    "short_ans_example_format": ["问题:{}\n正确答案:\n"],
}
def construct_prompt(sample):
    """Build the full Chinese input prompt for one CMMMU sample.

    Selects the instruction and example template matching the question type
    (multiple choice / true-false / fill-in-the-blank), then rewrites inline
    <img="filename"> tags into numbered <图片 i> placeholders.
    """
    instructions = PROMPT["task_instructions"]
    question = sample["question"]
    qtype = sample["type"]

    if qtype == "选择":
        # Render the four options as "(A) ..." through "(D) ...", one per line.
        option_lines = "".join(f"({chr(ord('A') + idx)}) {sample[f'option{idx + 1}']}\n" for idx in range(4))
        example = PROMPT["multi_choice_example_format"][0].format(question, option_lines)
        prompt = instructions[0] + "\n\n" + example
    elif qtype == "判断":
        example = PROMPT["T/F_example_format"][0].format(question)
        prompt = instructions[1] + "\n\n" + example
    else:
        # Fill-in-the-blank questions.
        example = PROMPT["short_ans_example_format"][0].format(question)
        prompt = instructions[2] + "\n\n" + example

    # Replace each image-filename tag with its positional <图片 i> token.
    for idx in range(1, 6):
        prompt = prompt.replace(f'<img="{sample[f"image_{idx}_filename"]}">', f"<图片 {idx}>")
    return prompt
def cmmmu_doc_to_text(doc):
    """Return the text prompt for *doc* (thin wrapper over construct_prompt)."""
    return construct_prompt(doc)
def cmmmu_doc_to_visual(doc):
    """Collect the RGB images referenced by *doc*'s prompt, in order of appearance.

    Each "<图片 i>" token in the prompt maps to the doc field "image_i", whose
    value is expected to be a PIL-style image (has .convert).
    """
    prompt = construct_prompt(doc)
    # "<图片 3>" -> doc key "image_3"; preserve prompt order.
    keys = [token.strip("<>").replace(" ", "_").replace("图片", "image") for token in re.findall(r"<图片 \d+>", prompt)]
    return [doc[key].convert("RGB") for key in keys]
def cmmmu_process_results(doc, results):
    """Parse the model response in *results* into a prediction record for aggregation.

    Dispatches on the question type to the matching parser, then returns the
    record under the "cmmmu_acc" metric key expected by the harness.
    """
    response = results[0]
    qtype = doc["type"]
    if qtype == "选择":
        options = [doc[f"option{n}"] for n in range(1, 5)]
        index2ans, all_choices = get_multi_choice_info(options)
        parsed = get_multi_choice_prediction(response, all_choices, index2ans)
    elif qtype == "判断":
        parsed = get_TF_prediction(response)
    else:
        parsed = get_fill_blank_prediction(response, doc["answer"])
    return {
        "cmmmu_acc": {
            "id": doc["id"],
            "subdomain": doc["subcategory"],
            "question_type": qtype,
            "answer": doc["answer"],
            "parsed_pred": parsed,
        }
    }
def cmmmu_aggregate_results(results):
    """Aggregate per-sample records into per-subdomain and per-domain accuracies.

    Returns the overall instance-level accuracy (float). The full breakdown
    (per domain and per subdomain) is logged for inspection.
    """
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        subset_to_eval_samples[result["subdomain"]].append(result)

    # Per-subdomain metrics: {"correct_num", "entries_num", "acc"}.
    evaluation_result = {subset: eval_cmmmu(samples) for subset, samples in subset_to_eval_samples.items()}

    printable_results = {}
    for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
        # Keep only subdomains that actually appeared in the results
        # (the dead `else: pass` branch of the original is dropped).
        in_domain_cat_results = {cat: evaluation_result[cat] for cat in in_domain_cats if cat in evaluation_result}
        in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
        in_domain_data_num = sum(cat_results["entries_num"] for cat_results in in_domain_cat_results.values())
        printable_results["Overall-" + domain] = {
            "num": int(in_domain_data_num),
            "acc": round(in_domain_ins_acc, 3),
        }
        # Per-subdomain rows under the domain summary.
        for cat_name, cat_results in in_domain_cat_results.items():
            printable_results[cat_name] = {
                "num": int(cat_results["entries_num"]),
                "acc": round(cat_results["acc"], 3),
            }

    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    printable_results["Overall"] = {
        "num": sum(cat_results["entries_num"] for cat_results in evaluation_result.values()),
        "acc": round(all_ins_acc, 3),
    }
    # Log via the task logger instead of bare print, consistent with the rest
    # of this module.
    eval_logger.info(f"CMMMU evaluation results: {printable_results}")
    return printable_results["Overall"]["acc"]
def cmmmu_process_test_results_for_submission(doc, results):
    """Package one test-split response into the submission-file record format."""
    return {
        "submission": {
            "id": doc["id"],
            "type": doc["type"],
            "response": results[0],
        }
    }
def cmmmu_test_aggregate_results_for_submission(results, args):
    """Write every submission record to a JSONL file and log its location."""
    path = generate_submission_file("cmmmu_test_for_submission.jsonl", args)
    with open(path, "w", encoding="utf8") as fout:
        # One JSON object per line; keep non-ASCII (Chinese) characters as-is.
        fout.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in results)
    eval_logger.info(f"Submission file saved to {path}")
##################
# Helper functions
##################

# Top-level CMMMU domain (Chinese) -> list of its subdomain names, matching
# each sample's "subcategory" field. cmmmu_aggregate_results uses this to roll
# per-subdomain metrics up into per-domain summaries.
DOMAIN_CAT2SUB_CAT = {
    "艺术与设计": ["艺术", "艺术理论", "设计", "音乐"],
    "商业": ["会计", "经济", "金融", "管理", "营销"],
    "科学": ["生物", "化学", "地理", "数学", "物理"],
    "健康与医学": ["基础医学", "临床医学", "诊断学与实验室医学", "制药", "公共卫生"],
    "人文社会科学": ["历史", "文献学", "社会学", "心理学"],
    "技术与工程": ["农业", "建筑学", "计算机科学", "电子学", "能源和电力", "材料", "机械工程"],
}
# Keyword lists used to grade true/false ("判断") answers. Hoisted to module
# level: the original rebuilt these lists (and the judging helper) on every
# loop iteration.
_TF_POSITIVE_KEYWORDS = ["正确", "对", "准确", "肯定", "对的"]
_TF_NEGATIVE_KEYWORDS = ["不对", "错误", "不正确", "不准确", "不合适", "否定", "错的", "错"]
_TF_AMBIGUOUS_KEYWORDS = ["对错", "是否正确", "否正确", "或者", "是否", "正确性", "对不"]


def _judge_tf(pred_list):
    """Majority-vote 对/错 over parsed sub-responses; random choice on a tie.

    NOTE(review): positive keywords are tested first, so a fragment such as
    "不对" (which contains "对") counts as positive — this matches the original
    evaluation behavior and is deliberately preserved.
    """
    positive_count = 0
    negative_count = 0
    for pred in pred_list:
        if any(word in pred for word in _TF_POSITIVE_KEYWORDS):
            positive_count += 1
        elif any(word in pred for word in _TF_NEGATIVE_KEYWORDS):
            negative_count += 1
    if positive_count > negative_count:
        return "对"
    if negative_count > positive_count:
        return "错"
    return random.choice(["对", "错"])


def eval_cmmmu(entries):
    """Score a list of parsed-prediction entries.

    Each entry is annotated in place with entry["judge"] ("正确"/"错误"), and
    the function returns {"correct_num", "entries_num", "acc"} (acc is 0 for
    an empty list).
    """
    correct_cnt = 0
    for entry in entries:
        parsed_pred = entry.get("parsed_pred", "")
        correct = False
        if entry.get("question_type") == "选择":
            # Multiple choice: exact match on the joined letter string.
            if parsed_pred == entry["answer"]:
                correct = True
        elif entry.get("question_type") == "填空":
            # Fill in the blank: any normalized prediction matching any
            # normalized answer counts as correct.
            norm_answers = normalize_str(entry["answer"], entry["answer"])
            for pred in parsed_pred:  # predictions are already normalized
                if isinstance(pred, str):
                    # String prediction: a string answer must appear inside it.
                    if any(isinstance(ans, str) and ans in pred for ans in norm_answers):
                        correct = True
                        break
                elif pred in norm_answers:
                    # Numeric prediction: exact (rounded) match.
                    correct = True
                    break
        else:
            # True/false: drop ambiguous fragments, then majority-vote.
            filtered = [word for word in parsed_pred if not any(amb in word for amb in _TF_AMBIGUOUS_KEYWORDS)]
            if _judge_tf(filtered) == entry["answer"]:
                correct = True
        if correct:
            correct_cnt += 1
        entry["judge"] = "正确" if correct else "错误"
    if not entries:
        eval_logger.info("entries_num == 0, please check your file")
        return {"correct_num": 0, "entries_num": 0, "acc": 0}
    return {"correct_num": correct_cnt, "entries_num": len(entries), "acc": correct_cnt / len(entries)}
def get_multi_choice_prediction(response, all_choices, index2ans):
    """Parse a multiple-choice response into its answer letter(s).

    Matching is tried in decreasing order of confidence: "(A)"-style tokens,
    then bare letters, then full option text. Every letter tied for the
    highest occurrence count is kept, joined in ABCD order. A response with
    no recognisable answer gets a uniformly random letter.

    Changes vs. the original: the unused local `index_ans` is removed, and the
    fourth matching loop is dropped as dead code (it only ran when candidates
    was still empty after the count-based content loop, but `ans in response`
    is equivalent to `response.count(ans) > 0`, so it could never add
    anything).
    """
    # Trim leading/trailing punctuation one character-type at a time
    # (sequential strips, preserved from the original heuristic).
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "  # pad to reduce accidental partial matches

    candidates = []
    # 1) Parenthesised letters: (A) (B) (C) (D) — one vote per occurrence.
    for choice in all_choices:
        candidates.extend([choice] * response.count(f"({choice})"))
    # 2) Bare letters anywhere in the response.
    if not candidates:
        for choice in all_choices:
            candidates.extend([choice] * response.count(choice))
    # 3) Full option text; each occurrence votes for its letter.
    if not candidates and len(response.split()) >= 1:
        for index, ans in index2ans.items():
            candidates.extend([index] * response.count(ans))

    if not candidates:
        # Nothing recognisable: fall back to a random letter.
        return random.choice(all_choices)

    # Keep every letter tied for the highest count, in ABCD order.
    candidate_counts = Counter(candidates)
    max_count = max(candidate_counts.values())
    return "".join(c for c in all_choices if candidate_counts.get(c, 0) == max_count)
def extract_numbers(string):
    """Extract numeric substrings from *string*.

    Three categories are matched with separate patterns and concatenated in
    this order: comma-grouped integers (e.g. "1,234"), scientific notation
    (e.g. "2.5e3"), and plain numbers. A value may appear in more than one
    category; deduplication is left to the caller.
    """
    comma_grouped = re.findall(r"-?\d{1,3}(?:,\d{3})+", string)
    scientific = re.findall(r"-?\d+(?:\.\d+)?[eE][+-]?\d+", string)
    # Plain numbers: lookaheads avoid re-matching the mantissa of scientific
    # notation or the head of a comma-grouped integer.
    plain = re.findall(r"-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)", string)
    return comma_grouped + scientific + plain
def check_is_number(string):
    """Return True when *string* (with any commas removed) parses as a float."""
    try:
        float(string.replace(",", ""))
    except ValueError:
        return False
    return True
def count_letters(string):
    """Count ASCII letters (a-z, A-Z) in *string*; all other characters are ignored."""
    return sum(1 for ch in string if ("a" <= ch <= "z") or ("A" <= ch <= "Z"))
def normalize_str(string, answer):
    """Normalize one answer/prediction fragment for fill-in-the-blank matching.

    Numeric strings become a float rounded to 2 decimals; other strings pass
    through unless they are far longer than *answer*, in which case they are
    treated as unparsed noise and dropped. Always returns a (possibly empty)
    list.
    """
    if string is None:  # fixed: identity check instead of `== None`
        return [string]
    string = string.strip()
    if check_is_number(string):
        # Drop thousands separators and round for tolerant numeric comparison.
        return [round(float(string.replace(",", "")), 2)]
    # Likely a textual answer; reject fragments much longer than the answer.
    if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
        return []
    return [string]
def get_fill_blank_prediction(response, answer):
    """Parse a free-form fill-in-the-blank response.

    Returns a deduplicated list of candidate answers: key sub-sentence tails
    plus any numbers they contain, each normalized via normalize_str. Order is
    unspecified (set-based deduplication).
    """

    def get_key_subresponses(response):
        # Split on sentence breaks and, per sentence, keep the shortest text
        # following an answer-indicator word.
        response = response.strip("。").strip()
        sub_responses = re.split(r"。|\n", response)
        indicators_of_keys = ["是", "为", "所以", "等于", "方案", "选择", "正确答案", "因此", "最后", "答案", "结果"]
        key_responses = []  # the original initialized this list twice; once is enough
        for index, resp in enumerate(sub_responses):
            if index == len(sub_responses) - 1:
                # The whole response may be a single equation; accept "=" too.
                indicators_of_keys.extend(["="])
            shortest_key_response = None
            for indicator in indicators_of_keys:
                if indicator in resp:
                    tail = resp.split(indicator)[-1].strip()
                    # Adopt the tail when nothing non-empty is held yet, or
                    # when it is strictly shorter than the current holder
                    # (exactly the original's two-branch update).
                    if not shortest_key_response or len(tail) < len(shortest_key_response):
                        shortest_key_response = tail
            if shortest_key_response:
                # Skip trivial punctuation-only leftovers.
                if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
                    key_responses.append(shortest_key_response)
        if not key_responses:  # no indicator found anywhere: keep the raw response
            return [response]
        return key_responses

    key_responses = get_key_subresponses(response)
    pred_list = key_responses.copy()  # keep the original string fragments
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))
    normalized = []
    for pred in pred_list:
        normalized.extend(normalize_str(pred, answer))
    # Deduplicate; resulting order is unspecified.
    return list(set(normalized))
def get_TF_prediction(response):
    """Parse a true/false response into candidate judgement fragments.

    Returns a deduplicated list of sub-sentence tails that follow judgement
    indicator words, or the raw response when no indicator is present. Order
    is unspecified (set-based deduplication).
    """

    def get_key_subresponses(response):
        response = response.strip("。").strip()
        sub_responses = re.split(r"。|\n", response)
        indicators_of_keys = ["是", "为", "所以", "判断", "陈述", "说法", "表达", "答案", "结果"]
        key_responses = []  # the original initialized this list twice and never used the enumerate index
        for resp in sub_responses:
            shortest_key_response = None  # shortest tail that may hold the judgement
            for indicator in indicators_of_keys:
                if indicator in resp:
                    tail = resp.split(indicator)[-1].strip()
                    # Adopt when nothing non-empty is held yet, or when the
                    # tail is strictly shorter (matches the original update).
                    if not shortest_key_response or len(tail) < len(shortest_key_response):
                        shortest_key_response = tail
            if shortest_key_response:
                # Skip trivial punctuation-only leftovers.
                if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
                    key_responses.append(shortest_key_response)
        if not key_responses:  # no indicator found anywhere: keep the raw response
            return [response]
        return key_responses

    pred_list = get_key_subresponses(response)
    return list(set(pred_list))
def get_multi_choice_info(options):
    """Assign letters starting at 'A' to the given option texts.

    Returns (index2ans, all_choices): index2ans maps "A" -> options[0] etc.,
    and all_choices is the ordered list of letters used.
    """
    all_choices = [chr(ord("A") + i) for i in range(len(options))]
    index2ans = dict(zip(all_choices, options))
    return index2ans, all_choices
def calculate_ins_level_acc(results):
    """Instance-level accuracy across categories: total correct / total entries.

    *results* maps category name -> {"correct_num", "entries_num", ...}.
    Returns 0 when there are no entries at all.
    """
    correct_sum = sum(cat["correct_num"] for cat in results.values())
    entries_sum = sum(cat["entries_num"] for cat in results.values())
    return correct_sum / entries_sum if entries_sum else 0
|