"""GAIA exact-match 채점에 맞춘 답변 포맷 후처리. 두 단계로 구성: 1. final_format_pass(question, raw): LLM 한 번 더 호출해서 GAIA 포맷으로만 변환. B 카테고리(내용 맞고 형식 위반) 회복용. 짧은 reformat 전용 시스템 프롬프트. 2. coerce_answer(question, ans): 결정적 regex 후처리. yes/no, 숫자, 통화 등 확실한 패턴만 강제. 매칭 실패 시 원본 유지(잘못 강제하면 더 망침). 순서: __call__에서 raw → strip prefixes/quotes → final_format_pass → coerce_answer. """ import re import unicodedata # yes/no 질문 시작 후보 키워드. 영어 의문문이 이 보조동사로 시작하고 ?로 끝나면 # 대개 yes/no 답을 기대하는 형태. _YES_NO_STARTS = ( "is ", "are ", "was ", "were ", "do ", "does ", "did ", "has ", "have ", "had ", "can ", "could ", "should ", "will ", "would ", "may ", "might ", ) def _looks_yes_no(question: str) -> bool: q = question.strip().lower() if "yes or no" in q or "yes/no" in q: return True if not q.endswith("?"): return False return any(q.startswith(s) for s in _YES_NO_STARTS) def _looks_numeric(question: str) -> bool: q = question.lower() return ( "how many" in q or "what number" in q or "what is the number of" in q # "how much" 는 단위 포함 답을 원할 수도 있어 제외(예: "how much money" → "$1.5M"). ) def coerce_answer(question: str, answer: str) -> str: """질문 형식 힌트에 맞춰 LLM 답을 보정. 힌트가 없거나 매칭 실패 시 원본 반환.""" a = answer.strip() if not a: return a # 1) Yes/No 질문 — 첫 단어로 결정. if _looks_yes_no(question): first = a.split(None, 1)[0].rstrip(",.").lower() if a.split() else "" if first == "yes": return "Yes" if first == "no": return "No" # 매칭 실패 시 원본 유지(잘못 강제하면 더 망침). return a # 2) 순수 숫자 질문 — 답 안의 첫 정수/실수만 추출. if _looks_numeric(question): m = re.search(r"-?\d+(?:\.\d+)?", a.replace(",", "")) if m: num = m.group(0) try: f = float(num) if f.is_integer(): return str(int(f)) return num except ValueError: pass return a # 3) 답이 통화기호+숫자 패턴이면 기호/콤마/공백만 제거. # "$1,234" → "1234", "1,234.5" → "1234.5" if re.fullmatch(r"\s*[\$€£¥]?\s*-?[\d,]+(?:\.\d+)?\s*", a): cleaned = re.sub(r"[\$€£¥,\s]", "", a) if cleaned: return cleaned return a # Final-answer formatter pass용 시스템 프롬프트. 짧고 부정형 최소화. _FORMAT_SYSTEM_PROMPT = """You reformat agent answers to match the GAIA benchmark exact-match grading rules. You receive a question and a draft answer, and output the final answer string ONLY (no explanation, no preamble). Rules: - Numbers: plain digits, no commas, no currency/units unless the question asks for them. - Strings: minimal exact form. No articles ("the", "a"), no abbreviations unless abbreviation is the expected form. No surrounding quotes. - Lists: comma + single space ("apple, banana, cherry"), in the order requested. - Yes/no questions: exactly "Yes" or "No". - "Give only the first name" → output only the first name, no surname. - "Give only the city name" → only the city, no country/state. - If the draft already matches all applicable rules, output it unchanged. - If the draft is "UNKNOWN" or admits inability, output "UNKNOWN". Output only the answer string, nothing else. """ def final_format_pass( question: str, raw_answer: str, model_id: str = "Qwen/Qwen2.5-72B-Instruct", ) -> str: """LLM 한 번 더 호출해 raw 답을 GAIA 포맷으로만 변환. 호출 실패(rate-limit, 타임아웃 등) 시 raw_answer를 그대로 반환 — graceful degrade. coerce_answer가 마지막 안전망이므로 이 단계가 실패해도 큰 손해는 없음. 유니코드 정규화(NFC)도 같이 수행해서 보이지 않는 변형 글자(예: 결합 글자 분해된 형태)로 인한 mismatch 방지. Args: question: 원본 질문 본문. raw_answer: 에이전트가 final_answer로 넘긴 raw 답. model_id: 포맷 변환에 쓸 모델 (기본은 메인 모델과 동일). Returns: 포맷 정리된 답 또는 raw_answer (호출 실패 시). """ if not raw_answer or raw_answer.strip().upper() == "UNKNOWN": return raw_answer try: from huggingface_hub import InferenceClient client = InferenceClient(provider="auto") resp = client.chat_completion( model=model_id, messages=[ {"role": "system", "content": _FORMAT_SYSTEM_PROMPT}, { "role": "user", "content": f"Question: {question}\n\nDraft answer: {raw_answer}\n\nFinal answer:", }, ], max_tokens=200, # 답변 자체는 짧음 ) formatted = (resp.choices[0].message.content or "").strip() if not formatted: return raw_answer # 양끝 따옴표 한 쌍 제거 (모델이 종종 "X" 형태로 둘러쌈) if len(formatted) >= 2 and ( (formatted[0] == '"' and formatted[-1] == '"') or (formatted[0] == "'" and formatted[-1] == "'") ): formatted = formatted[1:-1].strip() # NFC 정규화: 결합 글자(예: ł, é) 변형 통일 formatted = unicodedata.normalize("NFC", formatted) return formatted except Exception as e: print(f"final_format_pass failed (using raw): {e}") return raw_answer