|
|
import os |
|
|
import re |
|
|
import signal |
|
|
from collections import Counter |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
import datasets |
|
|
|
|
|
from lmms_eval.utils import eval_logger |
|
|
|
|
|
# Choose the prompt wrapper from the environment. The first variable that is
# set wins, checked in this order: PROMPTSTEP, PROMPTTOKEN, PROMPTLONG,
# PROMPTSHORT. With none of them set the question is passed through unchanged.
# The "{Question}" placeholder is filled later via str.format.
if "PROMPTSTEP" in os.environ:
    # The variable's value is the step budget quoted in the prompt.
    QUERY_TEMPLATE = "{Question}\n\nThink for up to " + os.environ["PROMPTSTEP"] + " steps."
elif "PROMPTTOKEN" in os.environ:
    # The variable's value is the token budget quoted in the prompt.
    QUERY_TEMPLATE = "{Question}\n\nThink for up to " + os.environ["PROMPTTOKEN"] + " tokens."
elif "PROMPTLONG" in os.environ:
    # Only the presence of the variable matters; its value is ignored.
    QUERY_TEMPLATE = "{Question}\n\nAnswer after a long amount of thinking. If you feel like you are finished early, spend the extra time trying to double-check your work until you are absolutely sure that you have the correct answer."
elif "PROMPTSHORT" in os.environ:
    # Only the presence of the variable matters; its value is ignored.
    QUERY_TEMPLATE = "{Question}\n\nAnswer after a short amount of thinking. Do not spend excessive time double-checking your work."
else:
    QUERY_TEMPLATE = "{Question}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Case-insensitive pattern that captures everything following an "Answer:"
# marker (optional whitespace around the colon). process_results applies it
# with re.DOTALL as a fallback when no boxed expression is found.
ANSWER_PATTERN = r"(?i)Answer\s*:\s*(.*)"
|
|
|
|
|
# Few-shot grading prompt used by extract_answer_idx. It is filled with the
# %-operator, so the only placeholders are %(expression1)s (the candidate
# options) and %(expression2)s (the attempt being graded); any other literal
# "%" added to this text would have to be written as "%%". The raw-string
# prefix keeps the LaTeX backslashes in the examples literal, and .strip()
# removes the leading/trailing newlines of the triple-quoted block.
EXTRACTION_TEMPLATE_IDX = r"""
Look at the following attempt by a student and extract the student's answer. If it is equivalent (ignoring trivial simplifications) to any of the provided options, return the index of that option starting from 1. Else, return -1.

Examples:

Options: ['2x+4', '2x', '4x']
Attempt: The answer is 3+2x.

-1
(the student's answer is not among the options)

Options: ['72,000']
Attempt: 72000 \text{ cents}.

1
(always give benefit of the doubt to units and ignore formatting which makes the 1st option match)

Options: ['2/(-3)', '2/3']
Attempt: -1 * 2/3

1
(the 1st option matches after trivial simplifications which are fine)

Options: ['x=5']
Attempt: 5

1

Options: ['\dfrac{33}{100}']
Attempt: 0.33

1

Options: ['75^\circ']
Attempt: ...various calculations and explanations...hence the answer is $\boxed{x in 75}$.

1

Options: ['(1,-3)', '(1,-1)', '(1,0)', '(1,-2)']
Attempt: -2, 1

4
(ignore whitespace and other formatting which makes the 4th option match)

Options: ['-2,1']
Attempt: 1, -2

1
(likely a problem where multiple solutions are possible thus ignore order)

Options: ['11', '100', '50', '-5', '12', '10']
Attempt: ...$\boxed{12^{\mathrm{th}}}$.

5

Options: ['2516_8']
Attempt: 2516

1
(give benefit of the doubt for different bases)

Options: ['11\sqrt2']
Attempt: 11\sqrt{2}

1

Options: ['11,\! 111,\! 111,\! 100']
Attempt: 11111111100

1

Options: ['\text{Navin}']
Attempt: ...it is navin.

1

---

YOUR TASK


Respond with only the index of the matching option starting from 1 or -1 if there is absolutely no reasonable match. Do not include a rationale.

Options: %(expression1)s
Attempt: %(expression2)s
""".strip()
|
|
|
|
|
|
|
|
|
|
|
def extract_answer_idx(sampler, options: List[str], attempt: str):
    """Ask ``sampler`` which option, if any, matches ``attempt``.

    The grading prompt is built by substituting ``options`` and ``attempt``
    into ``EXTRACTION_TEMPLATE_IDX`` and is sent as a single user message.

    Args:
        sampler: Callable that accepts a list of chat messages and returns
            the model's reply.
        options: Candidate answers, rendered into the prompt with ``%s``.
        attempt: The extracted attempt to be matched against the options.

    Returns:
        Whatever ``sampler`` returns, unmodified.
    """
    substitutions = {"expression1": options, "expression2": attempt}
    user_turn = {"content": EXTRACTION_TEMPLATE_IDX % substitutions, "role": "user"}
    return sampler([user_turn])
|
|
|
|
|
|
|
|
import time |
|
|
from typing import Any |
|
|
|
|
|
import openai |
|
|
from openai import OpenAI |
|
|
|
|
|
|
|
|
class ChatCompletionSampler: |
|
|
""" |
|
|
Sample from OpenAI's chat completion API |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model: str = "gpt-4o-mini", |
|
|
system_message: str | None = None, |
|
|
temperature: float = 0.5, |
|
|
max_tokens: int = 1024, |
|
|
): |
|
|
self.api_key_name = "OPENAI_API_KEY" |
|
|
self.client = OpenAI() |
|
|
|
|
|
self.model = model |
|
|
self.system_message = system_message |
|
|
self.temperature = temperature |
|
|
self.max_tokens = max_tokens |
|
|
self.image_format = "url" |
|
|
|
|
|
def _handle_image(self, image: str, encoding: str = "base64", format: str = "png", fovea: int = 768): |
|
|
new_image = { |
|
|
"type": "image_url", |
|
|
"image_url": { |
|
|
"url": f"data:image/{format};{encoding},{image}", |
|
|
}, |
|
|
} |
|
|
return new_image |
|
|
|
|
|
def _handle_text(self, text: str): |
|
|
return {"type": "text", "text": text} |
|
|
|
|
|
def _pack_message(self, role: str, content): |
|
|
return {"role": str(role), "content": content} |
|
|
|
|
|
def __call__(self, message_list) -> str: |
|
|
if self.system_message: |
|
|
message_list = [self._pack_message("system", self.system_message)] + message_list |
|
|
trial = 0 |
|
|
while True: |
|
|
try: |
|
|
response = self.client.chat.completions.create( |
|
|
model=self.model, |
|
|
messages=message_list, |
|
|
temperature=self.temperature, |
|
|
max_tokens=self.max_tokens, |
|
|
) |
|
|
return response.choices[0].message.content |
|
|
|
|
|
except openai.BadRequestError as e: |
|
|
print("Bad Request Error", e) |
|
|
return "" |
|
|
except Exception as e: |
|
|
exception_backoff = 2**trial |
|
|
print( |
|
|
f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", |
|
|
e, |
|
|
) |
|
|
time.sleep(exception_backoff) |
|
|
trial += 1 |
|
|
|
|
|
|
|
|
|
|
|
def doc_to_text(doc: dict) -> str:
    """Render the prompt for ``doc`` using ``QUERY_TEMPLATE``.

    The question text is taken from the ``"problem"`` key when that key is
    present, otherwise from ``"question"`` (``None`` if neither exists).
    """
    question = doc["problem"] if "problem" in doc else doc.get("question")
    return QUERY_TEMPLATE.format(Question=question)
|
|
|
|
|
|
|
|
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Normalise every document to ``problem`` / ``solution`` / ``answer``.

    Each field is read from its primary key, falling back to the
    alternatives (``question``; ``orig_*`` then ``orig_orig_*``). A warning
    is printed when no solution can be found. A document that carries a
    non-``None`` ``few_shot`` marker gets ``few_shot=True`` in the output.

    Args:
        dataset: Object exposing ``map(fn)``; ``fn`` receives one document.

    Returns:
        The result of ``dataset.map`` applied with the normalising function.
    """

    def _process_doc(doc: dict) -> dict:
        solution = doc.get("solution", doc.get("orig_solution", doc.get("orig_orig_solution")))
        problem = doc.get("problem", doc.get("question"))
        answer = doc.get("answer", doc.get("orig_answer", doc.get("orig_orig_answer")))
        if solution is None:
            print("Warning: No solution found; DOC:", doc)
        out_doc = {
            "problem": problem,
            "solution": solution,
            "answer": answer,
        }
        # Documents are mappings, so the marker must be looked up as a key;
        # an attribute lookup alone never finds it on a dict. The attribute
        # is still consulted as a fallback for attribute-style documents.
        few_shot = doc.get("few_shot", getattr(doc, "few_shot", None))
        if few_shot is not None:
            out_doc["few_shot"] = True
        return out_doc

    return dataset.map(_process_doc)
|
|
|
|
|
|
|
|
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
    """Score the generations in ``results`` against ``doc["answer"]``.

    For every generation the final answer is extracted (chat markers are
    stripped, then the last boxed expression or, failing that, the text
    after "Answer:" is taken). Pure decimal answers are normalised through
    ``int``. Otherwise, when the ``PROCESSOR`` environment variable is
    ``gpt-4o-mini``, a grader model is asked to map the answer onto the
    ground truth or an earlier extracted answer.

    Args:
        doc: Document holding the ground truth under ``"answer"``.
        results: Model generations for this document.

    Returns:
        A dict with ``exact_match`` (score of the first generation, ``None``
        when ``results`` is empty), ``extracted_answers``, ``exact_matches``
        and, for each power of two ``k`` with ``2 <= k <= len(results)``,
        ``cov@k``, ``maj@k`` and ``avg@k`` computed over the first ``k``
        generations.
    """
    metrics = {"exact_match": None, "extracted_answers": [], "exact_matches": []}

    # Powers of two up to len(results); empty for zero or one generation.
    n_res_list = [2**k for k in range(1, len(results).bit_length())]
    for prefix in ("cov", "maj", "avg"):
        for n in n_res_list:
            metrics[f"{prefix}@{n}"] = -1

    if os.getenv("PROCESSOR", "") == "gpt-4o-mini":
        sampler = ChatCompletionSampler(model="gpt-4o-mini")
    else:
        print(f"Unknown processor: {os.getenv('PROCESSOR')}; set 'PROCESSOR=gpt-4o-mini' and 'OPENAI_API_KEY=YOUR_KEY' for best results.")
        sampler = None

    # str.isdecimal() is used instead of str.isdigit(): isdigit() also accepts
    # characters such as superscript digits that int() rejects with ValueError.
    raw_gt = doc["answer"]
    if isinstance(raw_gt, str) and raw_gt.isdecimal():
        gt = str(int(raw_gt))
    else:
        gt = str(raw_gt)

    answer_marker = "<|im_start|>answer\n"
    turn_marker = "<|im_start|>"

    for i, a in enumerate(results, start=1):
        if answer_marker in a:
            a = a.split(answer_marker)[-1]
        elif turn_marker in a:
            a = a.split(turn_marker)[-1]
            # Drop the first line after a bare turn marker.
            # NOTE(review): the source this was recovered from had lost its
            # indentation; this step is assumed to belong to this branch only.
            if "\n" in a:
                a = "\n".join(a.split("\n")[1:])

        unboxed = None
        box = last_boxed_only_string(a)
        if box is not None:
            try:
                unboxed = remove_boxed(box)
            except (AssertionError, IndexError):
                # Malformed boxed text in one generation must not abort the
                # whole evaluation; fall through to the "Answer:" pattern.
                print("Warning: Could not unwrap boxed expression:", box)
        if unboxed is not None:
            a = unboxed
        elif matches := re.findall(ANSWER_PATTERN, a, re.DOTALL):
            a = matches[-1]

        if a.isdecimal() and gt.isdecimal():
            a = str(int(a))
        elif sampler is not None:
            # Ground truth first, then distinct earlier answers in first-seen
            # order so the grader prompt is reproducible between runs.
            options = [gt] + [o for o in dict.fromkeys(metrics["extracted_answers"]) if o != gt]
            if len(options) > 7:
                print("Warning: Lots of options which may harm indexing performance:", options)

            options_str = "[" + ", ".join(["'" + str(o) + "'" for o in options]) + "]"

            idx = extract_answer_idx(sampler, options_str, a)
            # The grader may return None or pad its reply with whitespace.
            idx = "" if idx is None else str(idx).strip()
            if idx != "-1":
                if idx.isdecimal():
                    idx = int(idx) - 1
                    if 0 <= idx < len(options):
                        a = options[idx]
                    else:
                        print("Warning: Index out of bounds; leaving answer unchanged\n", a, "\noptions", options_str, "\ndoc['answer']", gt, "\nidx", idx)
                else:
                    print("Warning: Processing did not produce integer index\na", a, "\noptions", options_str, "\ndoc['answer']", gt, "\nidx", idx)

        metrics["extracted_answers"].append(a)
        correct = int(a == gt)
        if not correct:
            print("Marked incorrect\na " + a + "\ndoc['answer'] " + gt)
        if i == 1:
            metrics["exact_match"] = correct
        metrics["exact_matches"].append(correct)
        if i in n_res_list:
            metrics[f"cov@{i}"] = int(1 in metrics["exact_matches"])
            metrics[f"avg@{i}"] = sum(metrics["exact_matches"]) / i
            metrics[f"maj@{i}"] = int(gt == Counter(metrics["extracted_answers"]).most_common(1)[0][0])

    return metrics
|
|
|
|
|
|
|
|
def last_boxed_only_string(string: str) -> Optional[str]:
    """Return the last boxed expression in ``string``, or ``None``.

    If ``"\\boxed "`` (command followed by a space) occurs anywhere, the
    result is that marker plus the text after its last occurrence, cut at
    the first ``$``. Otherwise the scan starts at the last ``\\boxed`` (or,
    when there is none, the last ``\\fbox``) and the result runs up to the
    closing brace that brings the brace depth back to zero. ``None`` is
    returned when no command is found or no such closing brace exists.
    """
    spaced = "\\boxed "
    if spaced in string:
        tail = string.split(spaced)[-1]
        return spaced + tail.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None
|
|
|
|
|
|
|
|
def remove_boxed(s: str) -> str:
    """Strip the boxing command from ``s`` and return its content.

    Supported shapes are ``"\\boxed <content>"``, ``"\\boxed{<content>}"``
    and ``"\\fbox{<content>}"``.

    Args:
        s: A string that starts with one of the supported commands.

    Returns:
        The content with the command (and its braces, if any) removed.

    Raises:
        AssertionError: If ``s`` does not start with a supported command or
            a brace form does not end with ``}``. It is raised explicitly so
            the check still runs when assertions are disabled with ``-O``.
    """
    spaced = "\\boxed "
    if s.startswith(spaced):
        return s[len(spaced) :]

    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left):
            if not s.endswith("}"):
                raise AssertionError(f"boxed expression is not closed: {s!r}")
            return s[len(left) : -1]

    raise AssertionError(f"not a boxed expression: {s!r}")
|
|
|