import os

from datasets import Dataset
from openai import OpenAI
from tqdm import tqdm

from lmms_eval.tasks.charxiv.constant import REASONING_RESP_INST
from lmms_eval.tasks.charxiv.descriptive_utils import (
    build_descriptive_grading_queries,
    descriptive_query_helper,
    get_descriptive_result_gpt,
    postprocess_descriptive_grading_queries,
)
from lmms_eval.tasks.charxiv.reasoning_utils import (
    build_reasoning_grading_queries,
    get_number_instruction,
    get_reasoning_result_gpt,
)

# Read the grading configuration from the environment, falling back to placeholder values
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "YOUR_OPENAI_BASE_URL")
MODEL_VERSION = os.getenv("MODEL_VERSION", "YOUR_MODEL_VERSION")

client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
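
# Example environment setup for the grading client (illustrative placeholder
# values, not real credentials or a required model):
#   export OPENAI_API_KEY="sk-..."
#   export OPENAI_BASE_URL="https://api.openai.com/v1"
#   export MODEL_VERSION="gpt-4o-2024-08-06"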


def charxiv_reasoning_doc_to_text_cot(doc, lmms_eval_specific_kwargs=None):
    """Build the reasoning prompt for a doc from its question source category."""
    inst_category = doc["reasoning_q_source"]
    # 1-3: categories whose templates only need the question text
    if inst_category in [1, 2, 3]:
        question = REASONING_RESP_INST[inst_category].format(doc["reasoning_q"])
    # 4: number-in-general -> need to specify the number of decimal places
    elif inst_category == 4:
        question = REASONING_RESP_INST[inst_category].format(doc["reasoning_q"], get_number_instruction(doc["reasoning_a"]))
    else:
        raise ValueError(f"Unknown reasoning question source category: {inst_category}")
    return question
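
# Illustrative call (hypothetical field values; the actual prompt wording comes
# from the REASONING_RESP_INST templates in lmms_eval.tasks.charxiv.constant):
#   doc = {"reasoning_q_source": 4, "reasoning_q": "What is the peak value?", "reasoning_a": "3.14"}
#   prompt = charxiv_reasoning_doc_to_text_cot(doc)  # category 4 appends a decimal-place instruction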


def charxiv_descriptive_process_docs(dataset: Dataset) -> Dataset:
    """Expand the dataset so every figure contributes its four descriptive questions."""
    # Four descriptive questions accompany each reasoning question
    dataset = dataset.repeat(4)

    def _process_row(example, idx):
        # The row's position in the repeated dataset selects which of the
        # four descriptive questions (1-4) this copy carries
        q_number = idx % 4 + 1
        descriptive_q = example[f"descriptive_q{q_number}"]
        # The raw descriptive_q{n} field is the question-template id; keep it
        # as qid before the helper renders it into prompt text
        qid = descriptive_q
        subplot_loc = example["subplot_loc"]
        if subplot_loc is None:
            # Fall back to explicit (row, col) coordinates when no textual
            # subplot location is given
            subplot_loc = [example["subplot_row"], example["subplot_col"]]
        descriptive_q = descriptive_query_helper(descriptive_q, subplot_loc)
        example["descriptive_q"] = descriptive_q
        example["descriptive_a"] = example[f"descriptive_a{q_number}"]
        return {"qid": qid, **example}

    dataset = dataset.map(_process_row, with_indices=True, num_proc=1)
    return dataset
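
# Sketch of the index-to-question mapping (illustrative 2-row dataset): repeat(4)
# concatenates four full copies, so with_indices yields
#   idx:      0  1  2  3  4  5  6  7
#   q_number: 1  2  3  4  1  2  3  4
# and each mapped row picks up descriptive_q{q_number} / descriptive_a{q_number}.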


def charxiv_descriptive_doc_to_text_cot(doc, lmms_eval_specific_kwargs=None):
    return doc["descriptive_q"]


def charxiv_doc_to_visual(doc):
    return [doc["image"].convert("RGB")]


def charxiv_descriptive_doc_to_messages_cot(doc, lmms_eval_specific_kwargs=None):
    question = charxiv_descriptive_doc_to_text_cot(doc, lmms_eval_specific_kwargs)
    visuals = charxiv_doc_to_visual(doc)
    content = [
        {"type": "image", "url": visuals[0]},
        {"type": "text", "text": question.strip()},
    ]
    return [{"role": "user", "content": content}]


def charxiv_reasoning_doc_to_messages_cot(doc, lmms_eval_specific_kwargs=None):
    question = charxiv_reasoning_doc_to_text_cot(doc, lmms_eval_specific_kwargs)
    visuals = charxiv_doc_to_visual(doc)
    content = [
        {"type": "image", "url": visuals[0]},
        {"type": "text", "text": question.strip()},
    ]
    return [{"role": "user", "content": content}]
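
# Both builders produce a single user turn with the image followed by the
# question text (shape only; values are illustrative):
#   [{"role": "user", "content": [
#       {"type": "image", "url": <PIL.Image.Image>},
#       {"type": "text", "text": "<question>"}]}]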


def charxiv_descriptive_process_results(doc, results):
    figure_id = doc["figure_path"]
    qid = doc["qid"]
    resp_key = f"{figure_id}_{qid}"
    response = results[0].strip()
    answer = doc["descriptive_a"]
    return {"descriptive_acc": {"group": (resp_key, response, answer), "qid": qid}}


def charxiv_descriptive_aggregate_results(results):
    # Bucket the graded triples by descriptive question id (qids 1-19)
    groups = {i: [] for i in range(1, 20)}
    for result in results:
        groups[result["qid"]].append(result["group"])
    queries = build_descriptive_grading_queries(groups)
    combined_queries = []
    for query in tqdm(queries):
        result = get_descriptive_result_gpt(client, query["grading_query"], len(query["resp_keys"]), model=MODEL_VERSION)
        # merge the grader output into the query record, which then contains
        # resp_keys, grading_query, extract_answer and score
        combined_queries.append({**query, **result})
    queries = combined_queries
    # flatten the queries and only keep the necessary fields
    queries = postprocess_descriptive_grading_queries(queries)
    # Return the average score across all graded responses
    scores = [query["score"] for query in queries.values()]
    return sum(scores) / len(scores)
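
# End-to-end shape of the descriptive grading flow (illustrative, inferred from
# the calls above):
#   results --> groups {qid: [(resp_key, response, answer), ...]}
#           --> build_descriptive_grading_queries --> [{"resp_keys": [...], "grading_query": ...}, ...]
#           --> get_descriptive_result_gpt merges in extract_answer / score
#           --> postprocess_descriptive_grading_queries --> dict whose values carry a "score"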


def charxiv_reasoning_process_results(doc, results):
    figure_id = doc["figure_path"]
    response = results[0].strip()
    data = {
        "inst_category": doc["reasoning_q_source"],
        "figure_id": figure_id,
        "answer": doc["reasoning_a"],
    }
    resp_value = {"raw_question": doc["reasoning_q"], "response": response}
    return {"reasoning_acc": {"resp_value": resp_value, "resp_key": figure_id, "data": data}}


def charxiv_reasoning_aggregate_results(results):
    data = {}
    resps = {}
    for i, result in enumerate(results):
        data[i] = result["data"]
        # one reasoning response per figure, keyed by figure id
        resps[result["resp_key"]] = result["resp_value"]
    queries = build_reasoning_grading_queries(data, resps)
    for figure_id, query in tqdm(queries.items()):
        # NOTE: unlike the descriptive grader above, no model override is passed
        # here, so get_reasoning_result_gpt uses its own default model
        ext, scr = get_reasoning_result_gpt(client, query["grading_query"])
        queries[figure_id]["extracted_answer"] = ext
        queries[figure_id]["score"] = scr
        queries[figure_id].pop("grading_query")
    # Return the average score
    scores = [query["score"] for query in queries.values()]
    return sum(scores) / len(scores)
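

if __name__ == "__main__":
    # Minimal offline smoke test of the result-shaping helper (no API calls;
    # the field values below are hypothetical stand-ins for real CharXiv rows,
    # not actual benchmark data).
    fake_doc = {
        "figure_path": "123.jpg",
        "reasoning_q_source": 1,
        "reasoning_q": "Which line peaks first?",
        "reasoning_a": "The solid line",
    }
    print(charxiv_reasoning_process_results(fake_doc, ["The solid line peaks first."]))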