Spaces:
Sleeping
Sleeping
| # agent.py | |
| import os | |
| import re | |
| import google.generativeai as genai | |
| from dotenv import load_dotenv | |
# Load variables from a local .env file, if one exists.
load_dotenv()

# Configure the Gemini client from the GOOGLE_API_KEY environment variable.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Treat an empty or whitespace-only value (e.g. a blank `GOOGLE_API_KEY=` line
# left in a template .env) the same as a missing one, instead of handing an
# unusable key to the client.
if GOOGLE_API_KEY is None or not GOOGLE_API_KEY.strip():
    raise ValueError("GOOGLE_API_KEY not set in environment variables!")
genai.configure(api_key=GOOGLE_API_KEY)
class Agent:
    """Callable question-answering agent backed by a Gemini model.

    Calling an instance with a question builds a strict answer-only prompt,
    sends it to the model, and post-processes the reply into a short final
    answer string.
    """

    # Text following the first "FINAL ANSWER:" marker, up to the end of line.
    _FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER:\s*([^\n]+)", re.IGNORECASE)
    # Whole-word cues that the question expects a number. Word boundaries keep
    # "count" from firing on "country", "account", "discount", ...
    _NUMERIC_QUESTION_RE = re.compile(
        r"\b(?:how many|number of|counts?|totals?|amounts?)\b"
    )
    # A number with optional thousands separators and decimal part.
    _NUMBER_RE = re.compile(r"\d+(?:,\d{3}(?!\d))*(?:\.\d+)?")
    # An option marker such as "a)" / "(b)" that is not the tail of a longer
    # word (so "(included)" does not count), or the word "option(s)".
    _CHOICE_QUESTION_RE = re.compile(
        r"(?<![a-z])[a-d]\)|\boptions?\b", re.IGNORECASE
    )
    # A stand-alone option letter in the model's answer.
    _CHOICE_LETTER_RE = re.compile(r"\b[a-d]\b", re.IGNORECASE)
    _YES_RE = re.compile(r"\byes\b", re.IGNORECASE)
    _NO_RE = re.compile(r"\bno\b", re.IGNORECASE)

    def __init__(self):
        """Create the underlying generative model.

        Tries 'gemini-2.0-flash' first and falls back to 'gemini-1.5-pro' if
        constructing the first model raises.
        """
        # NOTE(review): the fallback only triggers if the constructor itself
        # raises; confirm whether the SDK validates the model name here.
        try:
            self.model = genai.GenerativeModel('gemini-2.0-flash')
        except Exception:
            self.model = genai.GenerativeModel('gemini-1.5-pro')

    def __call__(self, question: str) -> str:
        """Answer ``question`` and return the cleaned final answer.

        Args:
            question: The question text to send to the model.

        Returns:
            The post-processed answer, or "Unknown" when the model returned
            no usable text.
        """
        # Standardised prompt; the original author notes it was checked
        # against a 30% score threshold.
        prompt = self._build_prompt(question)
        response = self.model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=0.1,
                max_output_tokens=512,
            ),
        )
        # Per the google-generativeai docs, the `text` accessor raises
        # ValueError when the response carries no text part (for example a
        # blocked prompt). Degrade to "Unknown" instead of crashing the run.
        try:
            raw_text = response.text
        except ValueError:
            raw_text = None
        return self._clean_answer(raw_text, question)

    def _build_prompt(self, question: str) -> str:
        """Return the answer-only prompt that embeds ``question``."""
        # Strict output format so the final answer can be extracted reliably.
        return f"""You are a strict answer bot for GAIA evaluation.
Only provide the FINAL concise answer to the following question.
Rules:
- No explanations, no extra info, no reasoning.
- If the question asks for a number, answer only the number.
- If the question is multiple choice, answer only the letter(s).
- If the question is yes/no, answer 'yes' or 'no'.
- If you are not sure, answer 'Unknown'.
- For lists, answer with comma-separated values.
End your answer with: FINAL ANSWER: [your answer]
Question: {question}
FINAL ANSWER:
"""

    def _clean_answer(self, answer: str, question: str) -> str:
        """Extract and normalise the final answer from the model's reply.

        Args:
            answer: Raw model text; ``None`` or blank yields "Unknown".
            question: The original question, used to decide which
                normalisation (number, option letters, yes/no) applies.

        Returns:
            The normalised answer string, or "Unknown" when nothing usable
            was found.
        """
        if answer is None or not answer.strip():
            return "Unknown"

        # Prefer the text after "FINAL ANSWER:"; otherwise use the last line.
        marker = self._FINAL_ANSWER_RE.search(answer)
        if marker:
            ans = marker.group(1).strip()
        else:
            ans = answer.strip().splitlines()[-1].strip()
        if not ans:
            return "Unknown"

        q_lower = question.lower()

        # Numeric questions: keep only the first number, written without
        # thousands separators ("1,234" -> "1234"; "3.5" stays "3.5").
        if self._NUMERIC_QUESTION_RE.search(q_lower):
            number = self._NUMBER_RE.search(ans)
            if number:
                return number.group(0).replace(",", "")

        # Multiple choice: keep only stand-alone option letters, so letters
        # inside ordinary words ("London") are not mistaken for choices.
        if self._CHOICE_QUESTION_RE.search(question):
            letters = self._CHOICE_LETTER_RE.findall(ans)
            if letters:
                return ",".join(letter.lower() for letter in letters)

        # Yes/no questions: match whole words only, so "Unknown" or "not"
        # are not read as "no" and "eyes" is not read as "yes".
        if q_lower.lstrip().startswith(("is ", "are ", "does ")):
            if self._YES_RE.search(ans):
                return "yes"
            if self._NO_RE.search(ans):
                return "no"

        return ans
# Manual smoke test: ask one sample question and print the agent's answer.
if __name__ == "__main__":
    sample_question = (
        "How many studio albums were published by Mercedes Sosa "
        "between 2000 and 2009 (included)?"
    )
    print(Agent()(sample_question))