File size: 2,372 Bytes
b0c0df0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
import re
from collections import Counter

from loguru import logger

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

# Multiple-choice prompt template with seven positional "{}" slots: the
# question text first, then the six options labelled (A) through (F), one per
# line and with no trailing newline.
PROMPT = "\n".join(["Question: {}"] + [f"({letter}) {{}}" for letter in "ABCDEF"])


def ii_bench_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Render an II-Bench record as a six-option multiple-choice prompt.

    Args:
        doc: Dataset record; must provide "question" and "option1".."option6".
        lmms_eval_specific_kwargs: Optional mapping that may contain
            "pre_prompt" and/or "post_prompt" strings to wrap around the
            formatted question. A missing mapping or missing key is treated
            as an empty string instead of raising.

    Returns:
        ``pre_prompt + <formatted question and options> + post_prompt``.
    """
    kwargs = lmms_eval_specific_kwargs or {}
    options = [doc[f"option{i}"] for i in range(1, 7)]
    question = PROMPT.format(doc["question"], *options)
    pre_prompt = kwargs.get("pre_prompt", "")
    post_prompt = kwargs.get("post_prompt", "")
    return f"{pre_prompt}{question}{post_prompt}"


def ii_bench_doc_to_visual(doc):
    """Return the record's image, converted to RGB, as a one-element list.

    ``doc["image"]`` is presumably a PIL image; it only needs to expose
    ``.convert(mode)``.
    """
    rgb_image = doc["image"].convert("RGB")
    return [rgb_image]


def _last_most_common(counter):
    """Return the last-ranked label among those tied for the highest count."""
    ranked = counter.most_common()
    top_count = ranked[0][1]
    return [label for label, count in ranked if count == top_count][-1]


def extract_option_labels(text, options=None):
    """Extract a single option label from a free-form model response.

    Strategies, tried in order:
      1. Letters A-F written in parentheses, e.g. "(B)".
      2. Otherwise, standalone capital letters A-F (word-boundary match).
      3. Otherwise, if ``options`` is given, compare option texts with the
         response: an option matches when its (stripped) text occurs in the
         response, or the (stripped) response occurs in the option text.

    For strategies 1-2 the most frequent letter wins. On a tie, the winner is
    the last of the tied labels in ``Counter.most_common`` order, i.e. the one
    whose first occurrence came latest. For strategy 3 the last matching
    option wins.

    Args:
        text: Raw model response. A dict is treated as an error payload.
        options: Optional sequence of option strings in A, B, C, ... order.
            Empty or non-string options are ignored.

    Returns:
        The extracted label, ``"error"`` if ``text`` is a dict, or ``None`` if
        nothing could be extracted. ``None`` is also returned for an empty,
        whitespace-only, or non-string response.
    """
    if isinstance(text, dict):
        return "error"
    # Guard against empty input: "" is a substring of every option, so without
    # this an empty response would spuriously resolve to the last option.
    if not isinstance(text, str) or not text.strip():
        return None

    matches = re.findall(r"\(([A-F])\)", text)
    if not matches:
        # NOTE(review): this fallback also matches a sentence-initial article
        # "A", so it is a heuristic rather than a strict parser.
        matches = re.findall(r"\b([A-F])\b", text)
    if matches:
        return _last_most_common(Counter(matches))

    if not options:
        return None

    stripped_text = text.strip()
    matched_labels = []
    for i, option in enumerate(options, start=1):
        if not isinstance(option, str):
            continue
        option_stripped = option.strip()
        # An empty option would be "contained" in every response; skip it.
        if not option_stripped:
            continue
        if option_stripped in stripped_text or stripped_text in option_stripped:
            matched_labels.append(chr(64 + i))
    # Each option contributes at most one vote, so all matches tie and the
    # last matching option wins (same tie-break as strategies 1-2).
    return matched_labels[-1] if matched_labels else None


def ii_bench_process_results(doc, results):
    """Convert one model generation into a submission record for ``doc``.

    Only the first generation in ``results`` is used. The predicted label
    comes from ``extract_option_labels`` applied to that text and the
    record's six options.

    Returns:
        ``{"submission": {"id", "predict_answer", "response"}}``.
    """
    raw_response = results[0]
    choices = [doc[f"option{idx}"] for idx in range(1, 7)]
    predicted_label = extract_option_labels(raw_response, choices)
    record = {
        "id": doc["id"],
        "predict_answer": predicted_label,
        "response": raw_response,
    }
    return {"submission": record}


def ii_bench_aggregate_submissions(results, args):
    """Dump the collected submission records to a JSON file and log its path.

    ``results`` is presumably the list of per-document "submission" dicts
    produced by ``ii_bench_process_results``; verify against the task config.
    The output location is chosen by ``generate_submission_file`` from
    ``args``.
    """
    submission_path = generate_submission_file("ii_bench_test_for_submission.json", args)
    with open(submission_path, "w") as handle:
        json.dump(results, handle, indent=4)
    logger.info(f"Results saved to {submission_path}")