import os

from datasets import Dataset
from openai import OpenAI
from tqdm import tqdm

from lmms_eval.tasks.charxiv.constant import REASONING_RESP_INST
from lmms_eval.tasks.charxiv.descriptive_utils import (
    build_descriptive_grading_queries,
    descriptive_query_helper,
    get_descriptive_result_gpt,
    postprocess_descriptive_grading_queries,
)
from lmms_eval.tasks.charxiv.reasoning_utils import (
    build_reasoning_grading_queries,
    get_number_instruction,
    get_reasoning_result_gpt,
)

# Read the grading configuration from the environment, falling back to placeholder values
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "YOUR_OPENAI_BASE_URL")
MODEL_VERSION = os.getenv("MODEL_VERSION", "YOUR_MODEL_VERSION")

client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
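
# The aggregate functions below pass this client and MODEL_VERSION to the GPT grading
# helpers, so real values must be set before aggregation runs; the fallbacks above are
# placeholders only. Hypothetical shell example (values are illustrative, not defaults
# of this repo):
#   export OPENAI_API_KEY=sk-...
#   export OPENAI_BASE_URL=https://api.openai.com/v1
#   export MODEL_VERSION=gpt-4o-2024-05-13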


def charxiv_reasoning_doc_to_text_cot(doc, lmms_eval_specific_kwargs=None):
    inst_category = doc["reasoning_q_source"]
    if inst_category in [1, 2, 3]:
        question = REASONING_RESP_INST[inst_category].format(doc["reasoning_q"])
    # 4: number-in-general -> need to specify the number of decimal places
    elif inst_category == 4:
        question = REASONING_RESP_INST[inst_category].format(doc["reasoning_q"], get_number_instruction(doc["reasoning_a"]))
    else:
        raise ValueError(f"Unknown reasoning question source: {inst_category}")
    return question


def charxiv_descriptive_process_docs(dataset: Dataset) -> Dataset:
    # Four descriptive questions for each reasoning question
    dataset = dataset.repeat(4)

    def _process_row(example, index):
        q_number = index % 4 + 1
        descriptive_q = example[f"descriptive_q{q_number}"]
        qid = descriptive_q
        subplot_loc = example["subplot_loc"]
        if subplot_loc is None:
            subplot_row = example["subplot_row"]
            subplot_col = example["subplot_col"]
            subplot_loc = [subplot_row, subplot_col]
        descriptive_q = descriptive_query_helper(descriptive_q, subplot_loc)
        example["descriptive_q"] = descriptive_q
        example["descriptive_a"] = example[f"descriptive_a{q_number}"]
        return {"qid": qid, **example}

    dataset = dataset.map(_process_row, with_indices=True, num_proc=1)
    return dataset
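
# Sketch of a row produced above (field names from this module, values hypothetical):
#   {"qid": 3, "descriptive_q": "...", "descriptive_a": "...", "figure_path": ..., ...}
# The raw question id is kept in "qid" so the descriptive aggregator below can regroup
# responses by question type (ids 1-19) when building grading batches.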


def charxiv_descriptive_doc_to_text_cot(doc, lmms_eval_specific_kwargs=None):
    return doc["descriptive_q"]


def charxiv_doc_to_visual(doc):
    return [doc["image"].convert("RGB")]


def charxiv_descriptive_doc_to_messages_cot(doc, lmms_eval_specific_kwargs=None):
    question = charxiv_descriptive_doc_to_text_cot(doc, lmms_eval_specific_kwargs)
    visuals = charxiv_doc_to_visual(doc)
    messages = [{"role": "user", "content": []}]
    messages[0]["content"].append({"type": "image", "url": visuals[0]})
    messages[0]["content"].append({"type": "text", "text": question.strip()})
    return messages


def charxiv_reasoning_doc_to_messages_cot(doc, lmms_eval_specific_kwargs=None):
    question = charxiv_reasoning_doc_to_text_cot(doc, lmms_eval_specific_kwargs)
    visuals = charxiv_doc_to_visual(doc)
    messages = [{"role": "user", "content": []}]
    messages[0]["content"].append({"type": "image", "url": visuals[0]})
    messages[0]["content"].append({"type": "text", "text": question.strip()})
    return messages
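
# Both builders above return the same single-turn structure, e.g.
#   [{"role": "user", "content": [{"type": "image", "url": <PIL.Image>},
#                                 {"type": "text", "text": "<question>"}]}]
# with the RGB-converted figure first and the formatted question second.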


def charxiv_descriptive_process_results(doc, results):
    figure_id = doc["figure_path"]
    qid = doc["qid"]
    resp_key = f"{figure_id}_{qid}"
    response = results[0].strip()
    answer = doc["descriptive_a"]
    return {"descriptive_acc": {"group": (resp_key, response, answer), "qid": qid}}


def charxiv_descriptive_aggregate_results(results):
    # Group responses by descriptive question id (1-19)
    groups = {i: [] for i in range(1, 20)}
    for result in results:
        groups[result["qid"]].append(result["group"])
    queries = build_descriptive_grading_queries(groups)
    combined_queries = []
    for query in tqdm(queries):
        result = get_descriptive_result_gpt(client, query["grading_query"], len(query["resp_keys"]), model=MODEL_VERSION)
        # each merged entry contains resp_keys, grading_query, extract_answer and score
        combined_queries.append({**query, **result})
    queries = combined_queries
    # flatten the queries and only keep the necessary fields
    queries = postprocess_descriptive_grading_queries(queries)
    # Return the average score
    scores = [query["score"] for query in queries.values()]
    return sum(scores) / len(scores)


def charxiv_reasoning_process_results(doc, results):
    figure_id = doc["figure_path"]
    inst_category = doc["reasoning_q_source"]
    response = results[0].strip()
    answer = doc["reasoning_a"]
    data = {}
    data["inst_category"] = inst_category
    data["figure_id"] = figure_id
    data["answer"] = answer
    resp_value = {"raw_question": doc["reasoning_q"], "response": response}
    return {"reasoning_acc": {"resp_value": resp_value, "resp_key": figure_id, "data": data}}


def charxiv_reasoning_aggregate_results(results):
    data = {}
    resps = {}
    for i, result in enumerate(results):
        data[i] = result["data"]
        resps[result["resp_key"]] = result["resp_value"]
    queries = build_reasoning_grading_queries(data, resps)
    for figure_id, query in tqdm(queries.items()):
        extracted_answer, score = get_reasoning_result_gpt(client, query["grading_query"])
        queries[figure_id]["extracted_answer"] = extracted_answer
        queries[figure_id]["score"] = score
        queries[figure_id].pop("grading_query")
    # Return the average score
    scores = [query["score"] for query in queries.values()]
    return sum(scores) / len(scores)
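

if __name__ == "__main__":
    # Minimal sketch with synthetic data (not part of the benchmark) showing the
    # message structure built by the doc_to_messages helpers; needs Pillow only.
    from PIL import Image

    _demo_doc = {
        "descriptive_q": "What is the title of the plot?",  # hypothetical question
        "image": Image.new("RGB", (8, 8)),
    }
    print(charxiv_descriptive_doc_to_messages_cot(_demo_doc))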