File size: 5,947 Bytes
c2446d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a5f5e9
c2446d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""GAIA exact-match ์ฑ„์ ์— ๋งž์ถ˜ ๋‹ต๋ณ€ ํฌ๋งท ํ›„์ฒ˜๋ฆฌ.

๋‘ ๋‹จ๊ณ„๋กœ ๊ตฌ์„ฑ:
1. final_format_pass(question, raw): LLM ํ•œ ๋ฒˆ ๋” ํ˜ธ์ถœํ•ด์„œ GAIA ํฌ๋งท์œผ๋กœ๋งŒ ๋ณ€ํ™˜.
   B ์นดํ…Œ๊ณ ๋ฆฌ(๋‚ด์šฉ ๋งž๊ณ  ํ˜•์‹ ์œ„๋ฐ˜) ํšŒ๋ณต์šฉ. ์งง์€ reformat ์ „์šฉ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ.
2. coerce_answer(question, ans): ๊ฒฐ์ •์  regex ํ›„์ฒ˜๋ฆฌ. yes/no, ์ˆซ์ž, ํ†ตํ™” ๋“ฑ
   ํ™•์‹คํ•œ ํŒจํ„ด๋งŒ ๊ฐ•์ œ. ๋งค์นญ ์‹คํŒจ ์‹œ ์›๋ณธ ์œ ์ง€(์ž˜๋ชป ๊ฐ•์ œํ•˜๋ฉด ๋” ๋ง์นจ).

์ˆœ์„œ: __call__์—์„œ raw โ†’ strip prefixes/quotes โ†’ final_format_pass โ†’ coerce_answer.
"""
import re
import unicodedata


# yes/no ์งˆ๋ฌธ ์‹œ์ž‘ ํ›„๋ณด ํ‚ค์›Œ๋“œ. ์˜์–ด ์˜๋ฌธ๋ฌธ์ด ์ด ๋ณด์กฐ๋™์‚ฌ๋กœ ์‹œ์ž‘ํ•˜๊ณ  ?๋กœ ๋๋‚˜๋ฉด
# ๋Œ€๊ฐœ yes/no ๋‹ต์„ ๊ธฐ๋Œ€ํ•˜๋Š” ํ˜•ํƒœ.
_YES_NO_STARTS = (
    "is ", "are ", "was ", "were ", "do ", "does ", "did ",
    "has ", "have ", "had ", "can ", "could ", "should ",
    "will ", "would ", "may ", "might ",
)


def _looks_yes_no(question: str) -> bool:
    q = question.strip().lower()
    if "yes or no" in q or "yes/no" in q:
        return True
    if not q.endswith("?"):
        return False
    return any(q.startswith(s) for s in _YES_NO_STARTS)


def _looks_numeric(question: str) -> bool:
    q = question.lower()
    return (
        "how many" in q
        or "what number" in q
        or "what is the number of" in q
        # "how much" ๋Š” ๋‹จ์œ„ ํฌํ•จ ๋‹ต์„ ์›ํ•  ์ˆ˜๋„ ์žˆ์–ด ์ œ์™ธ(์˜ˆ: "how much money" โ†’ "$1.5M").
    )


def coerce_answer(question: str, answer: str) -> str:
    """์งˆ๋ฌธ ํ˜•์‹ ํžŒํŠธ์— ๋งž์ถฐ LLM ๋‹ต์„ ๋ณด์ •. ํžŒํŠธ๊ฐ€ ์—†๊ฑฐ๋‚˜ ๋งค์นญ ์‹คํŒจ ์‹œ ์›๋ณธ ๋ฐ˜ํ™˜."""
    a = answer.strip()
    if not a:
        return a

    # 1) Yes/No ์งˆ๋ฌธ โ€” ์ฒซ ๋‹จ์–ด๋กœ ๊ฒฐ์ •.
    if _looks_yes_no(question):
        first = a.split(None, 1)[0].rstrip(",.").lower() if a.split() else ""
        if first == "yes":
            return "Yes"
        if first == "no":
            return "No"
        # ๋งค์นญ ์‹คํŒจ ์‹œ ์›๋ณธ ์œ ์ง€(์ž˜๋ชป ๊ฐ•์ œํ•˜๋ฉด ๋” ๋ง์นจ).
        return a

    # 2) ์ˆœ์ˆ˜ ์ˆซ์ž ์งˆ๋ฌธ โ€” ๋‹ต ์•ˆ์˜ ์ฒซ ์ •์ˆ˜/์‹ค์ˆ˜๋งŒ ์ถ”์ถœ.
    if _looks_numeric(question):
        m = re.search(r"-?\d+(?:\.\d+)?", a.replace(",", ""))
        if m:
            num = m.group(0)
            try:
                f = float(num)
                if f.is_integer():
                    return str(int(f))
                return num
            except ValueError:
                pass
        return a

    # 3) ๋‹ต์ด ํ†ตํ™”๊ธฐํ˜ธ+์ˆซ์ž ํŒจํ„ด์ด๋ฉด ๊ธฐํ˜ธ/์ฝค๋งˆ/๊ณต๋ฐฑ๋งŒ ์ œ๊ฑฐ.
    # "$1,234" โ†’ "1234", "1,234.5" โ†’ "1234.5"
    if re.fullmatch(r"\s*[\$โ‚ฌยฃยฅ]?\s*-?[\d,]+(?:\.\d+)?\s*", a):
        cleaned = re.sub(r"[\$โ‚ฌยฃยฅ,\s]", "", a)
        if cleaned:
            return cleaned

    return a


# Final-answer formatter pass์šฉ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ. ์งง๊ณ  ๋ถ€์ •ํ˜• ์ตœ์†Œํ™”.
_FORMAT_SYSTEM_PROMPT = """You reformat agent answers to match the GAIA benchmark
exact-match grading rules. You receive a question and a draft answer, and output the
final answer string ONLY (no explanation, no preamble).

Rules:
- Numbers: plain digits, no commas, no currency/units unless the question asks for them.
- Strings: minimal exact form. No articles ("the", "a"), no abbreviations unless
  abbreviation is the expected form. No surrounding quotes.
- Lists: comma + single space ("apple, banana, cherry"), in the order requested.
- Yes/no questions: exactly "Yes" or "No".
- "Give only the first name" โ†’ output only the first name, no surname.
- "Give only the city name" โ†’ only the city, no country/state.
- If the draft already matches all applicable rules, output it unchanged.
- If the draft is "UNKNOWN" or admits inability, output "UNKNOWN".

Output only the answer string, nothing else.
"""


def final_format_pass(
    question: str,
    raw_answer: str,
    model_id: str = "Qwen/Qwen2.5-72B-Instruct",
) -> str:
    """LLM ํ•œ ๋ฒˆ ๋” ํ˜ธ์ถœํ•ด raw ๋‹ต์„ GAIA ํฌ๋งท์œผ๋กœ๋งŒ ๋ณ€ํ™˜.

    ํ˜ธ์ถœ ์‹คํŒจ(rate-limit, ํƒ€์ž„์•„์›ƒ ๋“ฑ) ์‹œ raw_answer๋ฅผ ๊ทธ๋Œ€๋กœ ๋ฐ˜ํ™˜ โ€” graceful
    degrade. coerce_answer๊ฐ€ ๋งˆ์ง€๋ง‰ ์•ˆ์ „๋ง์ด๋ฏ€๋กœ ์ด ๋‹จ๊ณ„๊ฐ€ ์‹คํŒจํ•ด๋„ ํฐ ์†ํ•ด๋Š” ์—†์Œ.

    ์œ ๋‹ˆ์ฝ”๋“œ ์ •๊ทœํ™”(NFC)๋„ ๊ฐ™์ด ์ˆ˜ํ–‰ํ•ด์„œ ๋ณด์ด์ง€ ์•Š๋Š” ๋ณ€ํ˜• ๊ธ€์ž(์˜ˆ: ๊ฒฐํ•ฉ ๊ธ€์ž
    ๋ถ„ํ•ด๋œ ํ˜•ํƒœ)๋กœ ์ธํ•œ mismatch ๋ฐฉ์ง€.

    Args:
        question: ์›๋ณธ ์งˆ๋ฌธ ๋ณธ๋ฌธ.
        raw_answer: ์—์ด์ „ํŠธ๊ฐ€ final_answer๋กœ ๋„˜๊ธด raw ๋‹ต.
        model_id: ํฌ๋งท ๋ณ€ํ™˜์— ์“ธ ๋ชจ๋ธ (๊ธฐ๋ณธ์€ ๋ฉ”์ธ ๋ชจ๋ธ๊ณผ ๋™์ผ).

    Returns:
        ํฌ๋งท ์ •๋ฆฌ๋œ ๋‹ต ๋˜๋Š” raw_answer (ํ˜ธ์ถœ ์‹คํŒจ ์‹œ).
    """
    if not raw_answer or raw_answer.strip().upper() == "UNKNOWN":
        return raw_answer
    try:
        from huggingface_hub import InferenceClient
        client = InferenceClient(provider="auto")
        resp = client.chat_completion(
            model=model_id,
            messages=[
                {"role": "system", "content": _FORMAT_SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": f"Question: {question}\n\nDraft answer: {raw_answer}\n\nFinal answer:",
                },
            ],
            max_tokens=200,  # ๋‹ต๋ณ€ ์ž์ฒด๋Š” ์งง์Œ
        )
        formatted = (resp.choices[0].message.content or "").strip()
        if not formatted:
            return raw_answer
        # ์–‘๋ ๋”ฐ์˜ดํ‘œ ํ•œ ์Œ ์ œ๊ฑฐ (๋ชจ๋ธ์ด ์ข…์ข… "X" ํ˜•ํƒœ๋กœ ๋‘˜๋Ÿฌ์Œˆ)
        if len(formatted) >= 2 and (
            (formatted[0] == '"' and formatted[-1] == '"')
            or (formatted[0] == "'" and formatted[-1] == "'")
        ):
            formatted = formatted[1:-1].strip()
        # NFC ์ •๊ทœํ™”: ๊ฒฐํ•ฉ ๊ธ€์ž(์˜ˆ: ล‚, รฉ) ๋ณ€ํ˜• ํ†ต์ผ
        formatted = unicodedata.normalize("NFC", formatted)
        return formatted
    except Exception as e:
        print(f"final_format_pass failed (using raw): {e}")
        return raw_answer