import math
import re
import string
from typing import Optional
|
|
|
|
| def _normalize_string(s): |
| if (s.startswith('"') and s.endswith('"')) or ( |
| s.startswith("'") and s.endswith("'") |
| ): |
| return s[1:-1] |
| return s |
|
|
|
|
| def _remove_end_punctuation(unnormalized_string: str) -> str: |
| while ( |
| unnormalized_string |
| and ( |
| unnormalized_string[-1] in string.punctuation |
| or unnormalized_string[-1].isspace() |
| ) |
| and unnormalized_string[-1] != "%" |
| ): |
| unnormalized_string = unnormalized_string[:-1] |
| return unnormalized_string |
|
|
|
|
| class RelaxedCorrectness: |
| """Relaxed correctness metrics. |
| |
| The correctness tolerates certain error ratio defined by max_relative_change. |
| See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1: |
| "Following Methani et al. (2020), we use a relaxed accuracy measure for the |
| numeric answers to allow a minor inaccuracy that may result from the automatic |
| data extraction process. We consider an answer to be correct if it is within |
| 5% of the gold answer. For non-numeric answers, we still need an exact match |
| to consider an answer to be correct." |
| """ |
|
|
| def _relaxed_correctness( |
| self, prediction: str, targets: list[str], max_relative_change: float = 0.05 |
| ) -> float: |
| def _to_float(text: str) -> tuple[float | None, bool]: |
| text = text.strip() |
| is_percent = text.endswith("%") |
| try: |
| value = float(text.rstrip("%")) |
| return value, is_percent |
| except ValueError: |
| return None, False |
|
|
| def _is_letter(text: str) -> bool: |
| return text.isalpha() and len(text) == 1 |
|
|
| def _preprocess_text(text: str) -> str: |
| if not any(char.isdigit() for char in text): |
| return _normalize_string(text) |
| else: |
| return _remove_end_punctuation(text).replace(",", "").replace("$", "") |
|
|
| def calculate_relative_change(prediction: float, target: float) -> float: |
| return abs(prediction - target) / max(abs(target), 1e-10) |
|
|
| def _compare_numeric_values( |
| prediction: float, target: float, max_relative_change: float |
| ) -> float: |
| relative_change = calculate_relative_change(prediction, target) |
| return 1.0 if relative_change <= max_relative_change else 0.0 |
|
|
| def _compare_text_values(prediction: str, target: str) -> float: |
| while prediction and prediction[-1] in string.punctuation: |
| prediction = prediction[:-1] |
| return 1.0 if prediction.lower() == target.lower() else 0.0 |
|
|
| def _to_decimal(value: float, is_percent: bool) -> float: |
| return value / 100 if is_percent else value |
|
|
| def _compare_numeric_with_percent( |
| prediction: float, |
| prediction_is_percent: bool, |
| target: float, |
| target_is_percent: bool, |
| max_relative_change: float, |
| ) -> float: |
| |
| value = _compare_numeric_values(prediction, target, max_relative_change) |
|
|
| |
| if value != 1.0 and (prediction_is_percent or target_is_percent): |
| value = max( |
| value, |
| _compare_numeric_values( |
| _to_decimal(prediction, prediction_is_percent), |
| target, |
| max_relative_change, |
| ), |
| _compare_numeric_values( |
| prediction, |
| _to_decimal(target, target_is_percent), |
| max_relative_change, |
| ), |
| ) |
| return value |
|
|
| prediction = _preprocess_text(prediction) |
| prediction_float, prediction_is_percent = _to_float(prediction) |
|
|
| value_list = [] |
| for target in targets: |
| target = _preprocess_text(target) |
| target_float, target_is_percent = _to_float(target) |
|
|
| if prediction_float is not None and target_float is not None: |
| |
| value = _compare_numeric_with_percent( |
| prediction_float, |
| prediction_is_percent, |
| target_float, |
| target_is_percent, |
| max_relative_change, |
| ) |
| elif _is_letter(target) and len(prediction) > 0: |
| |
| value = 1.0 if prediction[0].lower() == target.lower() else 0.0 |
| else: |
| |
| value = _compare_text_values(prediction, target) |
|
|
| value_list.append(value) |
|
|
| return max(value_list) |
|
|
| def score(self, model_answer: str, reference_answer: str | list[str], max_relative_change=0.05) -> float: |
| reference_answer = ( |
| reference_answer |
| if isinstance(reference_answer, list) |
| else [reference_answer] |
| ) |
| return self._relaxed_correctness(model_answer, reference_answer, max_relative_change) |
|
|
|
|
class ExplicitPromptRelaxedCorrectness(RelaxedCorrectness):
    """Relaxed correctness for explicit prompt.

    The model output is expected to contain an ``answer:`` marker. Only the
    first non-empty line after the last marker is scored.
    """

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "explicit_prompt_relaxed_correctness"

    def _get_final_answer(self, generation: str) -> str:
        """Extract the text following the last ``answer:`` marker.

        Args:
            generation: Full model output.

        Returns:
            The first non-empty line after the last case-insensitive
            ``answer:`` marker, stripped of surrounding whitespace and of
            the characters ``* _ [ ] ( )``. Returns the empty string when no
            marker is present or nothing follows it.
        """
        # Collapse markdown emphasis around the colon, e.g. "Answer**:**" -> "Answer:".
        generation = re.sub(r"([aA]nswer)\**:\**", "\\1:", generation)

        # Locate the last marker directly in the original string. Searching a
        # lower-cased copy is unsafe because str.lower() can change the string
        # length for some characters, which misaligns the index used to slice.
        # re.ASCII restricts case-insensitivity to ASCII letters.
        last_marker = None
        for last_marker in re.finditer(
            r"answer:", generation, flags=re.IGNORECASE | re.ASCII
        ):
            pass

        if last_marker is None:
            return ""

        # Take the first non-empty line after the marker.
        lines = generation[last_marker.end():].split("\n")
        final_answer = next((line.strip() for line in lines if line.strip()), "")

        # Drop markdown emphasis and bracket characters.
        final_answer = re.sub(r"[*_\[\]\(\)]", "", final_answer)
        return final_answer

    def score(
        self,
        model_answer: str,
        reference_answer: str | list[str],
        max_relative_change: float = 0.05,
    ) -> float:
        """Score the final answer extracted from ``model_answer``.

        Args:
            model_answer: Full model output containing an ``answer:`` marker.
            reference_answer: A single reference or a list of acceptable
                references.
            max_relative_change: Largest tolerated relative error for numeric
                answers.

        Returns:
            0.0 when no final answer can be extracted; otherwise the score
            computed by the parent class for the extracted answer.
        """
        parsed_model_answer = self._get_final_answer(model_answer)
        if not parsed_model_answer:
            # No marker, or nothing after it: nothing to score.
            return 0.0
        return super().score(parsed_model_answer, reference_answer, max_relative_change)
|
|
def relaxed_correctness(target: str,
                        prediction: str,
                        max_relative_change: float = 0.05) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates certain error ratio defined by max_relative_change.
    See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
    "Following Methani et al. (2020), we use a relaxed accuracy measure for the
    numeric answers to allow a minor inaccuracy that may result from the automatic
    data extraction process. We consider an answer to be correct if it is within
    5% of the gold answer. For non-numeric answers, we still need an exact match
    to consider an answer to be correct."

    Values ending in ``%`` are converted to fractions before comparison, so
    ``"50%"`` and ``"0.5"`` are numerically equal. Non-numeric answers are
    compared case-insensitively.

    Args:
      target: Target string.
      prediction: Predicted string.
      max_relative_change: Maximum relative change.

    Returns:
      Whether the prediction was correct given the specified tolerance.
    """

    def _to_float(text: str) -> Optional[float]:
        """Parse ``text`` as a finite number, or return None."""
        try:
            if text.endswith('%'):
                # Convert percentages to fractions.
                value = float(text.rstrip('%')) / 100.0
            else:
                value = float(text)
        except ValueError:
            return None
        # float() accepts "nan"/"inf"; those are textual answers here, and a
        # NaN relative change would make identical answers compare unequal.
        return value if math.isfinite(value) else None

    prediction = str(prediction)
    target = str(target)
    prediction_float = _to_float(prediction)
    target_float = _to_float(target)
    if prediction_float is not None and target_float is not None:
        if target_float == 0:
            # Relative change is undefined for a zero target; only a zero
            # prediction can be correct (so "0.0" matches "0").
            return prediction_float == 0
        relative_change = abs(prediction_float - target_float) / abs(target_float)
        return relative_change <= max_relative_change
    return prediction.lower() == target.lower()
|
|
def eval_one_chart(
    model_answer: str,
    reference_answer: str | list[str],
    max_relative_change: float = 0.05,
    answer_flag: str = 'answer:'
) -> float:
    """Evaluate one chart.

    When ``model_answer`` contains ``answer_flag`` the final answer is
    extracted and scored by ``ExplicitPromptRelaxedCorrectness``. Otherwise
    the whole model answer is compared with each reference using
    ``relaxed_correctness``.

    Args:
        model_answer: Raw model output.
        reference_answer: A single reference or a list of acceptable
            references. Each is stripped, lower-cased and has ``answer_flag``
            removed before comparison.
        max_relative_change: Largest tolerated relative error for numeric
            answers.
        answer_flag: Marker text; it is searched for in lower-cased text.

    Returns:
        1.0 if the model answer matches any reference, else 0.0.
    """
    model_answer = model_answer.strip()

    # Accept both a single reference and a list of references.
    references = (
        reference_answer
        if isinstance(reference_answer, list)
        else [reference_answer]
    )
    # Strip again after removing the flag so that "answer: x" becomes "x"
    # rather than " x", which would break text comparison.
    references = [
        ref.strip().lower().replace(answer_flag, '').strip() for ref in references
    ]

    if answer_flag not in model_answer.lower():
        # relaxed_correctness takes (target, prediction): the reference is
        # the target, so tolerance is measured relative to the gold answer.
        return float(
            any(
                relaxed_correctness(ref, model_answer, max_relative_change)
                for ref in references
            )
        )

    evaluator = ExplicitPromptRelaxedCorrectness()
    return evaluator.score(model_answer, references, max_relative_change)
|
|
if __name__ == "__main__":
    # Smoke test: a bare numeric answer (no "answer:" marker) compared with
    # a reference, using a 5% tolerance.
    demo_score = eval_one_chart('2009', '2010', 0.05)
    print(f"Score: {demo_score}")