File size: 3,110 Bytes
b0c0df0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import random
import re
from typing import Dict, List, Tuple

import numpy as np

# Instruction prepended to CS-Bench True/False (assertion) questions, which are
# presented as a two-option question (A / B).
assertion_prompt = (
    "Answer the following multiple choice question. "
    "There is only one correct answer. "
    "The last line of your response should be in the format 'Answer: $LETTER' (without quotes), "
    "where LETTER is one of A or B."
)

# Instruction for four-option (A-D) multiple-choice questions.
mcq_prompt = (
    "Answer the following multiple choice question. "
    "There is only one correct answer. "
    "The last line of your response should be in the format 'Answer: $LETTER' (without quotes), "
    "where LETTER is one of A, B, C, or D."
)


def csbench_mcq_doc_to_text(doc: Dict, lmms_eval_specific_kwargs: Dict) -> str:
    """Build the prompt for a four-option CS-Bench multiple-choice question.

    Args:
        doc: Dataset row; reads ``"Question"`` and the option texts stored
            under the keys ``"A"``, ``"B"``, ``"C"`` and ``"D"``.
        lmms_eval_specific_kwargs: Unused here.

    Returns:
        The answer-format instruction, followed by the question and one
        ``"<letter>: <option>"`` line per option, ending with a newline.
    """
    # Use the four-letter instruction: the two-letter ``assertion_prompt``
    # would tell the model only A or B are valid while C and D are offered.
    options = "".join(f"{letter}: {doc[letter]}\n" for letter in ("A", "B", "C", "D"))
    return f"{mcq_prompt}\nQuestion: {doc['Question']}\n{options}"


def csbench_assertion_doc_to_text(doc: Dict, lmms_eval_specific_kwargs: Dict) -> str:
    """Build the prompt for a CS-Bench True/False (assertion) question.

    The statement is posed as a two-option question in which ``A`` stands for
    True and ``B`` for False.

    Args:
        doc: Dataset row; only ``"Question"`` is read.
        lmms_eval_specific_kwargs: Unused here.

    Returns:
        The answer-format instruction, the question, and the two options,
        ending with a newline.
    """
    # NOTE(review): the option lines carry a leading space (" A: True"),
    # unlike the multiple-choice prompt; kept as-is so the prompt text is unchanged.
    prompt_lines = [
        assertion_prompt,
        f"Question: {doc['Question']}",
        " A: True",
        " B: False",
        "",  # yields the trailing newline after the last option
    ]
    return "\n".join(prompt_lines)


def csbench_doc_to_target(doc: Dict) -> str:
    """Return the gold answer letter for a CS-Bench row.

    Multiple-choice rows hold the letter itself in ``"Answer"``; it is trimmed
    and upper-cased. Every other format is treated as True/False: an answer of
    exactly ``"True"`` (after trimming) maps to ``"A"`` and anything else to ``"B"``.
    """
    is_multiple_choice = doc["Format"].strip() == "Multiple-choice"
    answer = doc["Answer"].strip()
    if is_multiple_choice:
        return answer.upper()
    return "A" if answer == "True" else "B"


def csbench_doc_to_choice(doc: Dict) -> List[str]:
    """Return the valid answer letters for a CS-Bench row.

    Multiple-choice rows offer four options (A-D); every other format is a
    two-option True/False question (A, B). A new list is returned on each call.
    """
    option_count = 4 if doc["Format"].strip() == "Multiple-choice" else 2
    return list("ABCD"[:option_count])


def parse_multi_choice_response(response, all_choices):
    """
    Parse the predicted choice letter (e.g. A, B, C, D) from a generated response.

    Strategies are tried in order; the first one that finds any valid choice wins:

    1. An explicit ``Answer: X`` marker (the format the prompts ask for). The
       last marker in the response is used. Optional ``**`` / ``(`` around the
       letter are tolerated.
    2. A parenthesised choice, e.g. ``(A)``.
    3. A standalone choice delimited by whitespace, e.g. ``A``.
    4. A choice followed by a period, e.g. ``A.``.

    For strategies 2-4, if several different choices match, the one whose
    match occurs last in the response wins. If nothing matches, a random valid
    choice is returned, so unparseable output scores at chance level.

    Args:
        response: Raw model output text.
        all_choices: Valid choice labels, e.g. ``["A", "B", "C", "D"]``.

    Returns:
        One element of ``all_choices``.
    """
    # 1. Explicit "Answer: X" marker. Only the word "answer" is case-insensitive;
    #    the choice must match exactly, and must not run into further word
    #    characters (so "Answer: Because" does not yield "B").
    alternatives = "|".join(re.escape(choice) for choice in all_choices)
    marker_pattern = rf"\b(?i:answer)\s*:\s*\**\s*\(?\s*({alternatives})(?!\w)"
    marker_hits = re.findall(marker_pattern, response)
    if marker_hits:
        return marker_hits[-1]

    # Collapse all whitespace (including newlines) to single spaces, strip any
    # mix of surrounding punctuation, then pad so that a choice at either end is
    # still space-delimited.
    cleaned = " ".join(response.split()).strip(",.!?;:'")
    padded = f" {cleaned} "

    # 2-4. Progressively looser patterns. Positions are measured with the same
    #      pattern that produced the match, so "last mentioned" is accurate.
    for prefix, suffix in (("(", ")"), (" ", " "), ("", ".")):
        positions = {}
        for choice in all_choices:
            position = padded.rfind(f"{prefix}{choice}{suffix}")
            if position != -1:
                positions[choice] = position
        if positions:
            return max(positions, key=positions.get)

    # No recognisable answer: fall back to a random guess.
    return random.choice(all_choices)


def csbench_process_results(doc: Dict, result: List[str]) -> Dict[str, float]:
    """Score a single model response against the gold answer.

    Args:
        doc: CS-Bench dataset row.
        result: Model outputs for ``doc``; only the first element is scored.

    Returns:
        ``{"accuracy": 1.0}`` when the parsed letter equals the gold letter,
        otherwise ``{"accuracy": 0.0}``.
    """
    response = result[0]
    valid_letters = csbench_doc_to_choice(doc)
    predicted = parse_multi_choice_response(response, valid_letters)
    expected = csbench_doc_to_target(doc)
    return {"accuracy": float(predicted == expected)}