"""Answer-format post-processing tailored to GAIA exact-match grading.

Consists of two stages:
1. final_format_pass(question, raw): calls the LLM one more time, only to convert the
   answer into GAIA format. Meant to recover "category B" failures (content correct,
   format violated). Uses a short, reformat-only system prompt.
2. coerce_answer(question, ans): deterministic regex post-processing. Forces only
   unambiguous patterns such as yes/no, numbers, and currency. When nothing matches,
   the original is kept (a wrong coercion makes things worse).

Order: in __call__, raw -> strip prefixes/quotes -> final_format_pass -> coerce_answer.
"""
| import re |
| import unicodedata |
|
|
|
|
| |
| |
| _YES_NO_STARTS = ( |
| "is ", "are ", "was ", "were ", "do ", "does ", "did ", |
| "has ", "have ", "had ", "can ", "could ", "should ", |
| "will ", "would ", "may ", "might ", |
| ) |
|
|
|
|
| def _looks_yes_no(question: str) -> bool: |
| q = question.strip().lower() |
| if "yes or no" in q or "yes/no" in q: |
| return True |
| if not q.endswith("?"): |
| return False |
| return any(q.startswith(s) for s in _YES_NO_STARTS) |
|
|
|
|
| def _looks_numeric(question: str) -> bool: |
| q = question.lower() |
| return ( |
| "how many" in q |
| or "what number" in q |
| or "what is the number of" in q |
| |
| ) |
|
|
|
|
def coerce_answer(question: str, answer: str) -> str:
    """Coerce an LLM answer into the exact-match form hinted at by the question.

    Only unambiguous patterns are rewritten; whenever a pattern does not match,
    the whitespace-stripped answer is returned unchanged.

    * Yes/no question (per ``_looks_yes_no``): if the first token is "yes" or
      "no" (ignoring case and punctuation/markdown wrapped around the token),
      return exactly "Yes" or "No".
    * Numeric question (per ``_looks_numeric``): return the first number found
      in the answer, with thousands separators removed. Integer-valued numbers
      ("3", "3.0", "007") are returned as a plain integer string; other
      decimals are returned as written.
    * Otherwise: if the whole answer is a single number, optionally preceded by
      a currency symbol ($, euro, pound, yen), strip the symbol, thousands
      separators and whitespace.

    Args:
        question: The question text, used only to pick which rule applies.
        answer: The draft answer to normalise.

    Returns:
        The coerced answer, or the stripped original when no rule applies.
    """
    a = answer.strip()
    if not a:
        return a

    if _looks_yes_no(question):
        # Compare only the first token, ignoring punctuation or markdown that
        # the model wrapped around it ("Yes!", "**No**", '"yes",').
        first = a.split(None, 1)[0].strip("\"'`*_.,;:!?()[]").lower()
        if first == "yes":
            return "Yes"
        if first == "no":
            return "No"
        return a

    if _looks_numeric(question):
        # - The minus sign counts only when it is not glued to a preceding
        #   word character, so "Top-3" yields 3 rather than -3.
        # - Commas are accepted only as well-formed thousands separators, so a
        #   list such as "10,20" is not merged into 1020.
        m = re.search(
            r"(?:(?<!\w)-)?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?", a
        )
        if m is None:
            return a
        num = m.group(0).replace(",", "")
        int_part, _, frac = num.partition(".")
        if frac.strip("0"):
            # Genuine decimal: keep the digits exactly as written.
            return num
        try:
            # Parse the integer part directly instead of going through float,
            # which silently corrupts integers above 2**53.
            return str(int(int_part))
        except ValueError:
            # e.g. the interpreter's int-from-string digit limit.
            return a

    # No question hint: only rewrite when the entire answer is one plain number
    # with an optional currency symbol. The symbols are written as escapes
    # (euro, pound, yen) so the pattern survives source re-encoding.
    if re.fullmatch(
        r"[$\u20ac\u00a3\u00a5]?\s*-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?", a
    ):
        return re.sub(r"[$\u20ac\u00a3\u00a5,\s]", "", a)

    return a
|
|
|
|
| |
| _FORMAT_SYSTEM_PROMPT = """You reformat agent answers to match the GAIA benchmark |
| exact-match grading rules. You receive a question and a draft answer, and output the |
| final answer string ONLY (no explanation, no preamble). |
| |
| Rules: |
| - Numbers: plain digits, no commas, no currency/units unless the question asks for them. |
| - Strings: minimal exact form. No articles ("the", "a"), no abbreviations unless |
| abbreviation is the expected form. No surrounding quotes. |
| - Lists: comma + single space ("apple, banana, cherry"), in the order requested. |
| - Yes/no questions: exactly "Yes" or "No". |
| - "Give only the first name" โ output only the first name, no surname. |
| - "Give only the city name" โ only the city, no country/state. |
| - If the draft already matches all applicable rules, output it unchanged. |
| - If the draft is "UNKNOWN" or admits inability, output "UNKNOWN". |
| |
| Output only the answer string, nothing else. |
| """ |
|
|
|
|
def final_format_pass(
    question: str,
    raw_answer: str,
    model_id: str = "Qwen/Qwen2.5-72B-Instruct",
) -> str:
    """Call the LLM once more, only to convert the raw answer into GAIA format.

    If the call fails for any reason (rate limit, timeout, missing dependency,
    ...), ``raw_answer`` is returned as-is: a graceful degrade. The original
    author's note says coerce_answer acts as the last safety net, so a failure
    of this stage costs little.

    On success the reply is also NFC-normalised, so that visually identical but
    differently encoded characters (e.g. decomposed combining sequences) do not
    cause an exact-match mismatch.

    Args:
        question: The original question text.
        raw_answer: The raw answer the agent passed as its final answer.
        model_id: Model used for the reformatting call (the original note says
            the default matches the main agent model; not verifiable here).

    Returns:
        The reformatted answer, or ``raw_answer`` unchanged when it is empty,
        is "UNKNOWN" (case-insensitive, ignoring surrounding whitespace), the
        call fails, or the model returns nothing usable.
    """
    if not raw_answer or raw_answer.strip().upper() == "UNKNOWN":
        return raw_answer
    try:
        # Imported lazily so that a missing dependency degrades to the raw
        # answer instead of breaking module import.
        from huggingface_hub import InferenceClient

        client = InferenceClient(provider="auto")
        resp = client.chat_completion(
            model=model_id,
            messages=[
                {"role": "system", "content": _FORMAT_SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": f"Question: {question}\n\nDraft answer: {raw_answer}\n\nFinal answer:",
                },
            ],
            max_tokens=200,
        )
        formatted = (resp.choices[0].message.content or "").strip()

        # Drop one layer of matching straight quotes around the whole reply.
        if (
            len(formatted) >= 2
            and formatted[0] == formatted[-1]
            and formatted[0] in "\"'"
        ):
            formatted = formatted[1:-1].strip()

        # Checked AFTER unquoting: a reply consisting only of quotes ('""')
        # must fall back to the raw answer rather than yield an empty string.
        if not formatted:
            return raw_answer

        return unicodedata.normalize("NFC", formatted)
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and keep the raw answer.
        print(f"final_format_pass failed (using raw): {e}")
        return raw_answer
|
|