agentic-rl-main / data_utils /chart /evaluator.py
Jack04810's picture
Add files using upload-large-folder tool
f5ab8aa verified
Raw
History Blame Contribute Delete
9.41 kB
import re
import string
from typing import Optional
def _normalize_string(s):
if (s.startswith('"') and s.endswith('"')) or (
s.startswith("'") and s.endswith("'")
):
return s[1:-1]
return s
def _remove_end_punctuation(unnormalized_string: str) -> str:
while (
unnormalized_string
and (
unnormalized_string[-1] in string.punctuation
or unnormalized_string[-1].isspace()
)
and unnormalized_string[-1] != "%"
):
unnormalized_string = unnormalized_string[:-1]
return unnormalized_string
class RelaxedCorrectness:
"""Relaxed correctness metrics.
The correctness tolerates certain error ratio defined by max_relative_change.
See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
"Following Methani et al. (2020), we use a relaxed accuracy measure for the
numeric answers to allow a minor inaccuracy that may result from the automatic
data extraction process. We consider an answer to be correct if it is within
5% of the gold answer. For non-numeric answers, we still need an exact match
to consider an answer to be correct."
"""
def _relaxed_correctness(
self, prediction: str, targets: list[str], max_relative_change: float = 0.05
) -> float:
def _to_float(text: str) -> tuple[float | None, bool]:
text = text.strip()
is_percent = text.endswith("%")
try:
value = float(text.rstrip("%"))
return value, is_percent
except ValueError:
return None, False
def _is_letter(text: str) -> bool:
return text.isalpha() and len(text) == 1
def _preprocess_text(text: str) -> str:
if not any(char.isdigit() for char in text):
return _normalize_string(text)
else:
return _remove_end_punctuation(text).replace(",", "").replace("$", "")
def calculate_relative_change(prediction: float, target: float) -> float:
return abs(prediction - target) / max(abs(target), 1e-10)
def _compare_numeric_values(
prediction: float, target: float, max_relative_change: float
) -> float:
relative_change = calculate_relative_change(prediction, target)
return 1.0 if relative_change <= max_relative_change else 0.0
def _compare_text_values(prediction: str, target: str) -> float:
while prediction and prediction[-1] in string.punctuation:
prediction = prediction[:-1]
return 1.0 if prediction.lower() == target.lower() else 0.0
def _to_decimal(value: float, is_percent: bool) -> float:
return value / 100 if is_percent else value
def _compare_numeric_with_percent(
prediction: float,
prediction_is_percent: bool,
target: float,
target_is_percent: bool,
max_relative_change: float,
) -> float:
# Compare as-is
value = _compare_numeric_values(prediction, target, max_relative_change)
# If not equal and one is percent, try other comparisons
if value != 1.0 and (prediction_is_percent or target_is_percent):
value = max(
value,
_compare_numeric_values(
_to_decimal(prediction, prediction_is_percent),
target,
max_relative_change,
),
_compare_numeric_values(
prediction,
_to_decimal(target, target_is_percent),
max_relative_change,
),
)
return value
prediction = _preprocess_text(prediction)
prediction_float, prediction_is_percent = _to_float(prediction)
value_list = []
for target in targets:
target = _preprocess_text(target)
target_float, target_is_percent = _to_float(target)
if prediction_float is not None and target_float is not None:
# Compare as numeric values
value = _compare_numeric_with_percent(
prediction_float,
prediction_is_percent,
target_float,
target_is_percent,
max_relative_change,
)
elif _is_letter(target) and len(prediction) > 0:
# Compare as multiple choice options: take first letter from prediction
value = 1.0 if prediction[0].lower() == target.lower() else 0.0
else:
# Compare as text values
value = _compare_text_values(prediction, target)
value_list.append(value)
return max(value_list)
def score(self, model_answer: str, reference_answer: str | list[str], max_relative_change=0.05) -> float:
reference_answer = (
reference_answer
if isinstance(reference_answer, list)
else [reference_answer]
)
return self._relaxed_correctness(model_answer, reference_answer, max_relative_change)
class ExplicitPromptRelaxedCorrectness(RelaxedCorrectness):
    """Relaxed correctness for explicit prompt.

    The final answer is extracted from the text following the last
    ``answer:`` marker in the generation and then scored with the relaxed
    correctness rules of the parent class.
    """

    @property
    def name(self) -> str:
        """Return the identifier of this metric."""
        return "explicit_prompt_relaxed_correctness"

    def _get_final_answer(self, generation: str) -> str:
        """Extract the text after the last ``answer:`` marker.

        Args:
            generation: The full model output.

        Returns:
            The first non-empty line after the last marker (matched
            case-insensitively), stripped of surrounding whitespace and of the
            characters ``* _ [ ] ( )``; an empty string when no marker exists.
        """
        # Strip extraneous markdown around the answer marker, e.g.
        # "Answer**:**" -> "Answer:". Matched case-insensitively so that it
        # agrees with the case-insensitive marker search below.
        generation = re.sub(
            r"(answer)\**:\**", r"\1:", generation, flags=re.IGNORECASE | re.ASCII
        )

        # Locate the last marker directly in the original string. Searching a
        # lowercased copy and reusing the index is unsafe: str.lower() can
        # change the length of the string for some characters, which would
        # misalign the slice below.
        marker_end = -1
        for match in re.finditer(r"answer:", generation, flags=re.IGNORECASE | re.ASCII):
            marker_end = match.end()
        if marker_end == -1:
            return ""

        # Split the remaining text into lines and take the first non-empty one
        lines = generation[marker_end:].split("\n")
        final_answer = next((line.strip() for line in lines if line.strip()), "")
        # Remove any markdown formatting
        return re.sub(r"[*_\[\]\(\)]", "", final_answer)

    def score(self, model_answer: str, reference_answer: str | list[str], max_relative_change=0.05) -> float:
        """Parse the final answer out of ``model_answer`` and score it.

        Args:
            model_answer: The full model output containing an answer marker.
            reference_answer: A single reference string or a list of
                acceptable reference strings.
            max_relative_change: Maximum relative error tolerated for numeric
                answers.

        Returns:
            0.0 when no answer could be parsed; otherwise the parent class's
            score for the parsed answer.
        """
        parsed_model_answer = self._get_final_answer(model_answer)
        if not parsed_model_answer:
            # Parsing failed.
            return 0.0
        return super().score(parsed_model_answer, reference_answer, max_relative_change)
def relaxed_correctness(target: str,
                        prediction: str,
                        max_relative_change: float = 0.05) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates certain error ratio defined by max_relative_change.
    See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
    "Following Methani et al. (2020), we use a relaxed accuracy measure for the
    numeric answers to allow a minor inaccuracy that may result from the automatic
    data extraction process. We consider an answer to be correct if it is within
    5% of the gold answer. For non-numeric answers, we still need an exact match
    to consider an answer to be correct."

    Args:
      target: Target string.
      prediction: Predicted string.
      max_relative_change: Maximum relative change.

    Returns:
      Whether the prediction was correct given the specified tolerance.
    """
    # Local import so this block does not depend on the module's import list.
    import math

    def _to_float(text: str) -> Optional[float]:
        """Parse ``text`` as a finite float; a trailing '%' divides by 100."""
        try:
            if text.endswith('%'):
                # Convert percentages to floats.
                value = float(text.rstrip('%')) / 100.0
            else:
                value = float(text)
        except ValueError:
            return None
        # "nan" / "inf" parse as floats but cannot be compared arithmetically;
        # treat them as text so identical strings still match.
        return value if math.isfinite(value) else None

    prediction = str(prediction)
    target = str(target)
    prediction_float = _to_float(prediction)
    target_float = _to_float(target)
    # Test against None explicitly: a target of 0 is a valid number and must
    # be compared numerically ("0" vs "0.0"), not as a string.
    if prediction_float is not None and target_float is not None:
        # The floor on the denominator avoids division by zero for target == 0.
        denominator = max(abs(target_float), 1e-10)
        relative_change = abs(prediction_float - target_float) / denominator
        return relative_change <= max_relative_change
    return prediction.lower() == target.lower()
def eval_one_chart(
    model_answer: str,
    reference_answer: str | list[str],
    max_relative_change: float = 0.05,
    answer_flag: str = 'answer:'
) -> float:
    """Evaluate one chart.

    When ``model_answer`` contains ``answer_flag`` (case-insensitive) it is
    scored with ``ExplicitPromptRelaxedCorrectness``; otherwise the whole
    model answer is compared to each reference with ``relaxed_correctness``.

    Args:
        model_answer: The model output to evaluate.
        reference_answer: One reference string or a list of acceptable
            reference strings. Each is lowercased and has ``answer_flag``
            removed before comparison.
        max_relative_change: Maximum relative error tolerated for numeric
            answers.
        answer_flag: Marker that introduces the final answer.
            NOTE(review): ``ExplicitPromptRelaxedCorrectness`` searches for
            the literal "answer:", so a different flag only affects the
            detection and the reference clean-up done here.

    Returns:
        1.0 if the model answer matches a reference, else 0.0.
    """
    flag = answer_flag.lower()

    def _clean_reference(reference: str) -> str:
        # Strip again after removing the flag so "Answer: 5" becomes "5"
        # rather than " 5".
        return reference.strip().lower().replace(flag, '').strip()

    model_answer = model_answer.strip()
    # The signature allows a list of references; normalize to a list so both
    # forms are handled (calling .strip() on a list would raise).
    if isinstance(reference_answer, list):
        references = [_clean_reference(ref) for ref in reference_answer]
    else:
        references = [_clean_reference(reference_answer)]
    if not references:
        return 0.0

    if flag not in model_answer.lower():
        # No explicit answer marker: compare the whole model answer.
        # relaxed_correctness takes (target, prediction), so the reference
        # goes first; the relative error is measured against the reference.
        return max(
            float(relaxed_correctness(ref, model_answer, max_relative_change))
            for ref in references
        )

    evaluator = ExplicitPromptRelaxedCorrectness()
    return evaluator.score(model_answer, references, max_relative_change)
if __name__ == "__main__":
    # Example usage: a bare answer without an answer marker.
    score = eval_one_chart('2009', '2010', 0.05)
    print(f"Score: {score}")

    # Example usage: an answer introduced by the "answer:" marker.
    model_answer = "The reasoning above leads to the following answer: 0.009"
    score = eval_one_chart(model_answer, '0.009', 0.05)
    print(f"Score: {score}")