Spaces:
Sleeping
Sleeping
File size: 5,947 Bytes
c2446d5 4a5f5e9 c2446d5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | """Answer-format post-processing tuned for GAIA exact-match grading.

Two stages:
1. final_format_pass(question, raw): call the LLM once more, solely to convert the
   answer into GAIA's format. Meant to recover category-B failures (content correct,
   format wrong). Uses a short, reformat-only system prompt.
2. coerce_answer(question, ans): deterministic regex post-processing. Forces only
   unambiguous patterns such as yes/no, numbers and currency. If nothing matches,
   the original is kept (a wrong coercion does more harm than none).

Order: in __call__, raw → strip prefixes/quotes → final_format_pass → coerce_answer.
"""
import re
import unicodedata
# Auxiliary verbs that typically open an English yes/no question. A question
# that starts with one of these and ends with "?" usually expects a yes/no answer.
_YES_NO_STARTS = (
    "is ", "are ", "was ", "were ", "do ", "does ", "did ",
    "has ", "have ", "had ", "can ", "could ", "should ",
    "will ", "would ", "may ", "might ",
)


def _looks_yes_no(question: str) -> bool:
    """Return True when *question* looks like it expects a Yes/No answer.

    Explicit "yes or no" / "yes/no" wording always counts. Otherwise the
    question must end with "?" and start with one of ``_YES_NO_STARTS``.
    """
    q = question.strip().lower()
    if "yes or no" in q or "yes/no" in q:
        return True
    return q.endswith("?") and q.startswith(_YES_NO_STARTS)


def _looks_numeric(question: str) -> bool:
    """Return True when *question* asks for a bare count/number."""
    q = question.lower()
    # "how much" is deliberately excluded: its answer may carry a unit or a
    # currency (e.g. "how much money" -> "$1.5M").
    return any(
        phrase in q
        for phrase in ("how many", "what number", "what is the number of")
    )


# Punctuation/markup peeled off the first word before comparing it to yes/no,
# so "Yes!", "'No'" and "**Yes**" are recognized too.
_YES_NO_STRIP = ",.!;:\"'*"

# First integer/decimal in an answer. A "-" counts as a minus sign only when it
# is not glued to a preceding word character, so "Model-3" does not become -3.
_FIRST_NUMBER_RE = re.compile(r"(?:(?<!\w)-)?\d+(?:\.\d+)?")

# A whole answer that is a number, optionally prefixed by $, € (U+20AC),
# £ (U+00A3) or ¥ (U+00A5). Commas are accepted only as thousands separators
# ("1,234"), so a comma-joined list such as "1,2,3" is not collapsed to "123".
_CURRENCY_NUMBER_RE = re.compile(
    r"\s*[$\u20ac\u00a3\u00a5]?\s*-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?\s*"
)
# Characters dropped from a matched currency answer: symbol, commas, spaces.
_CURRENCY_CLEAN_RE = re.compile(r"[$\u20ac\u00a3\u00a5,\s]")


def coerce_answer(question: str, answer: str) -> str:
    """Deterministically normalize *answer* according to the type of *question*.

    Only unambiguous patterns are rewritten. Whenever nothing matches, the
    stripped answer is returned unchanged, because a wrong coercion does more
    harm than none.

    Args:
        question: Original question text, used only as a type hint.
        answer: Answer string produced by the agent/LLM.

    Returns:
        - For yes/no questions: "Yes"/"No" when the answer starts with yes/no.
        - For count-style questions: the first number in the answer (integer
          values without a fractional part).
        - Otherwise, for answers that are a currency/thousands-formatted
          number: the bare number.
        - In all other cases: ``answer.strip()``.
    """
    a = answer.strip()
    if not a:
        return a

    # 1) Yes/No question -> decide from the first word only.
    if _looks_yes_no(question):
        first = a.split(None, 1)[0].strip(_YES_NO_STRIP).lower()
        if first == "yes":
            return "Yes"
        if first == "no":
            return "No"
        # No clear yes/no: keep the original rather than force a guess.
        return a

    # 2) Count/number question -> keep only the first integer/decimal.
    if _looks_numeric(question):
        m = _FIRST_NUMBER_RE.search(a.replace(",", ""))
        if m:
            num = m.group(0)
            int_part, _, frac = num.partition(".")
            # "3" / "3.00" -> "3" via exact int parsing (a float round-trip
            # would corrupt integers above 2**53); "3.50" is kept verbatim.
            if not frac.strip("0"):
                return str(int(int_part))
            return num
        return a

    # 3) Answer is a (currency-prefixed) number -> drop symbol, commas, spaces.
    #    "$1,234" -> "1234", "1,234.5" -> "1234.5"
    if _CURRENCY_NUMBER_RE.fullmatch(a):
        return _CURRENCY_CLEAN_RE.sub("", a)
    return a
# Final-answer formatter pass์ฉ ์์คํ
ํ๋กฌํํธ. ์งง๊ณ ๋ถ์ ํ ์ต์ํ.
_FORMAT_SYSTEM_PROMPT = """You reformat agent answers to match the GAIA benchmark
exact-match grading rules. You receive a question and a draft answer, and output the
final answer string ONLY (no explanation, no preamble).
Rules:
- Numbers: plain digits, no commas, no currency/units unless the question asks for them.
- Strings: minimal exact form. No articles ("the", "a"), no abbreviations unless
abbreviation is the expected form. No surrounding quotes.
- Lists: comma + single space ("apple, banana, cherry"), in the order requested.
- Yes/no questions: exactly "Yes" or "No".
- "Give only the first name" โ output only the first name, no surname.
- "Give only the city name" โ only the city, no country/state.
- If the draft already matches all applicable rules, output it unchanged.
- If the draft is "UNKNOWN" or admits inability, output "UNKNOWN".
Output only the answer string, nothing else.
"""
# Pairs of wrapping characters the formatter model sometimes puts around the
# whole answer ("X", 'X', curly “X”, `X`). Only one outer pair is removed.
_WRAPPING_QUOTES = (
    ('"', '"'),
    ("'", "'"),
    ("\u201c", "\u201d"),
    ("`", "`"),
)

# The user message ends with "Final answer:"; some models echo that label back.
_FINAL_ANSWER_LABEL_RE = re.compile(r"^final\s+answer\s*:\s*", re.IGNORECASE)


def _clean_formatted_answer(text: str) -> str:
    """Tidy the formatter LLM's reply: drop an echoed label and wrapping quotes, then NFC-normalize."""
    text = _FINAL_ANSWER_LABEL_RE.sub("", text.strip())
    for opener, closer in _WRAPPING_QUOTES:
        if len(text) >= 2 and text[0] == opener and text[-1] == closer:
            text = text[1:-1].strip()
            break
    # NFC so visually identical strings (precomposed vs. decomposed accents,
    # e.g. "é" vs "e" + U+0301) compare equal under exact match.
    return unicodedata.normalize("NFC", text)


def final_format_pass(
    question: str,
    raw_answer: str,
    model_id: str = "Qwen/Qwen2.5-72B-Instruct",
) -> str:
    """Call an LLM once more to rewrite *raw_answer* into GAIA answer format only.

    Best effort: if the call fails (client library missing, rate limit,
    timeout, ...) or the reply is empty after cleaning, *raw_answer* is
    returned unchanged. coerce_answer is meant as the last safety net, so a
    failure here costs little. The reformatted text is NFC-normalized to avoid
    mismatches caused by invisible character variants.

    Args:
        question: Original question text.
        raw_answer: Raw answer the agent passed to final_answer.
        model_id: Model used for the reformat call (by default the same as the
            main model).

    Returns:
        The reformatted answer; or *raw_answer* itself when it is empty or
        "UNKNOWN", when the call fails, or when the reply is empty.
    """
    # Nothing to reformat: skip the remote call entirely.
    if not raw_answer or raw_answer.strip().upper() == "UNKNOWN":
        return raw_answer
    try:
        # Imported lazily so a missing dependency degrades to the raw answer.
        from huggingface_hub import InferenceClient

        client = InferenceClient(provider="auto")
        resp = client.chat_completion(
            model=model_id,
            messages=[
                {"role": "system", "content": _FORMAT_SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": f"Question: {question}\n\nDraft answer: {raw_answer}\n\nFinal answer:",
                },
            ],
            max_tokens=200,  # the answer itself is short
        )
        content = resp.choices[0].message.content or ""
    except Exception as e:  # deliberate best-effort boundary: any failure -> raw
        print(f"final_format_pass failed (using raw): {e}")
        return raw_answer

    formatted = _clean_formatted_answer(content)
    # An empty result (e.g. the model replied with just '""') is worse than
    # the unformatted draft.
    return formatted or raw_answer
|