File size: 1,405 Bytes
524e3cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""Post-process model output for GAIA exact-match submission."""

import re
from typing import Any, Union


# Leading "FINAL ANSWER" marker in any letter case and with any spacing between
# the words. It may be followed by an ASCII colon ":" or a full-width colon
# "：" (U+FF1A, common in CJK-style output) and by trailing whitespace. The
# full-width colon is written as an escape so the pattern survives re-encoding.
_FINAL_ANSWER_RE = re.compile(
    r"^\s*FINAL\s*ANSWER\s*[:\uff1a]?\s*",
    re.IGNORECASE,
)


def normalize_answer(raw: Union[str, int, float, None]) -> Union[str, int, float]:
    """
    Strip wrappers and forbidden prefixes from raw model output.

    ``int``/``float`` values (but not ``bool``) are returned unchanged.
    ``None`` and blank input become ``""``. Anything else is converted with
    ``str()``, whitespace-stripped, and cleaned in this order:

    1. a leading ``FINAL ANSWER`` marker (via ``_FINAL_ANSWER_RE``);
    2. leading verbal wrappers ``The answer is`` (with an optional colon) and
       ``Answer:``, matched case-insensitively;
    3. a Markdown code fence, including an optional language-tag line;
    4. otherwise, inline-code backticks (the closing one too, when present);
    5. one pair of surrounding double quotes.

    Prefer returning a string for API compatibility.

    Args:
        raw: Model output to clean.

    Returns:
        The cleaned string, or *raw* itself when it is an ``int``/``float``.
    """
    if raw is None:
        return ""
    # bool is a subclass of int; route it through str() ("True"/"False").
    if isinstance(raw, (int, float)) and not isinstance(raw, bool):
        return raw
    text = str(raw).strip()
    if not text:
        return ""
    text = _FINAL_ANSWER_RE.sub("", text, count=1).strip()

    # Verbal wrappers. "The answer is:" precedes "The answer is" so that its
    # colon is consumed as well. Every prefix is tried in turn (no break), so
    # stacked wrappers are peeled in list order. The exact slice being removed
    # is compared, so the comparison and the cut always agree.
    for prefix in ("The answer is:", "The answer is", "Answer:"):
        if text[: len(prefix)].lower() == prefix.lower():
            text = text[len(prefix) :].strip()

    # Fences must be handled before single backticks. Stripping a lone "`"
    # or "```" prefix first would leave the language tag and closing fence.
    if text.startswith("```"):
        # A language tag (e.g. "python", "c++") is recognized only when a
        # newline follows it, so "```42```" keeps its content.
        text = re.sub(r"^```(?:[\w+#.-]*[ \t]*\n)?", "", text)
        text = re.sub(r"\s*```$", "", text).strip()
    elif text.startswith("`"):
        # Inline code: remove the matching closing backtick when present.
        text = text[1:-1] if len(text) >= 2 and text.endswith("`") else text[1:]
        text = text.strip()

    if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
        text = text[1:-1].strip()
    return text.strip()


def maybe_numeric(text: str) -> Union[str, int, float]:
    """
    If the prompt expects a plain number, allow int/float submission.

    Surrounding whitespace is ignored. Only plain integers (``"42"``,
    ``"-7"``) and plain decimals (``"3.5"``) are considered. A number is
    returned only when the conversion is lossless, i.e. ``str(value)``
    reproduces the stripped text exactly. Spellings a number cannot preserve
    stay strings, so an exact-match comparison still sees them verbatim:

    - leading zeros (``"02134"``);
    - trailing decimal zeros (``"1.50"``);
    - ``"-0"``;
    - decimals beyond float precision.

    Args:
        text: Candidate answer string.

    Returns:
        The ``int`` or ``float`` value, or the original *text* (not stripped)
        when it is not a losslessly convertible number.
    """
    stripped = text.strip()
    if re.fullmatch(r"-?\d+", stripped):
        convert = int
    elif re.fullmatch(r"-?\d+\.\d+", stripped):
        convert = float
    else:
        return text
    try:
        value = convert(stripped)
        # On Python 3.11+, int<->str conversions beyond
        # sys.get_int_max_str_digits() raise ValueError; such inputs are
        # treated as non-numeric instead of crashing.
        round_trips = str(value) == stripped
    except ValueError:
        return text
    return value if round_trips else text