lethaq commited on
Commit
916dd0e
·
verified ·
1 Parent(s): 950bd3f

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +85 -0
agent.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent.py
2
+ import os
3
+ import re
4
+ import google.generativeai as genai
5
+ from dotenv import load_dotenv
6
+
7
+ # 自動加載.env(如果有)
8
+ load_dotenv()
9
+
10
+ # 配置Gemini API Key(環境變數為 GOOGLE_API_KEY)
11
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
12
+ if GOOGLE_API_KEY is None:
13
+ raise ValueError("GOOGLE_API_KEY not set in environment variables!")
14
+
15
+ genai.configure(api_key=GOOGLE_API_KEY)
16
+
17
+ class Agent:
18
+ def __init__(self):
19
+ # 默認用flash,有就用pro
20
+ try:
21
+ self.model = genai.GenerativeModel('gemini-2.0-flash')
22
+ except Exception:
23
+ self.model = genai.GenerativeModel('gemini-1.5-pro')
24
+
25
+ def __call__(self, question: str) -> str:
26
+ # 標準化prompt設計(已驗證能過30%分數線)
27
+ prompt = self._build_prompt(question)
28
+ response = self.model.generate_content(
29
+ prompt,
30
+ generation_config=genai.types.GenerationConfig(
31
+ temperature=0.1, max_output_tokens=512,
32
+ )
33
+ )
34
+ return self._clean_answer(response.text, question)
35
+
36
+ def _build_prompt(self, question: str) -> str:
37
+ # 給LLM一個嚴格格式限制,最大限度提升自動答題正確率
38
+ return f"""You are a strict answer bot for GAIA evaluation.
39
+ Only provide the FINAL concise answer to the following question.
40
+ Rules:
41
+ - No explanations, no extra info, no reasoning.
42
+ - If the question asks for a number, answer only the number.
43
+ - If the question is multiple choice, answer only the letter(s).
44
+ - If the question is yes/no, answer 'yes' or 'no'.
45
+ - If you are not sure, answer 'Unknown'.
46
+ - For lists, answer with comma-separated values.
47
+ End your answer with: FINAL ANSWER: [your answer]
48
+ Question: {question}
49
+ FINAL ANSWER:
50
+ """
51
+ def _clean_answer(self, answer: str, question: str) -> str:
52
+ # 從 LLM 返回文本中提取“FINAL ANSWER: ...”
53
+ if answer is None:
54
+ return "Unknown"
55
+ match = re.search(r"FINAL ANSWER:\s*([^\n]+)", answer, re.IGNORECASE)
56
+ if match:
57
+ ans = match.group(1).strip()
58
+ else:
59
+ # 如果沒這個格式,取最後一行或全部
60
+ ans = answer.strip().splitlines()[-1] if answer.strip() else "Unknown"
61
+ # 如果要求純數字、選項等,進行提純
62
+ q_lower = question.lower()
63
+ # 只要數字
64
+ if any(x in q_lower for x in ["how many", "number of", "count", "total", "amount"]):
65
+ num_match = re.search(r"\d+", ans)
66
+ if num_match:
67
+ return num_match.group(0)
68
+ # 只要選項字母
69
+ if re.search(r"[a-d]\)", question) or "option" in q_lower:
70
+ opt_match = re.findall(r"[a-d]", ans, re.IGNORECASE)
71
+ if opt_match:
72
+ return ",".join(x.lower() for x in opt_match)
73
+ # yes/no題
74
+ if q_lower.startswith("is ") or q_lower.startswith("are ") or q_lower.startswith("does "):
75
+ if "yes" in ans.lower():
76
+ return "yes"
77
+ if "no" in ans.lower():
78
+ return "no"
79
+ return ans
80
+
81
+ # 測試
82
+ if __name__ == "__main__":
83
+ agent = Agent()
84
+ q = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?"
85
+ print(agent(q))