lethaq's picture
Create agent.py
916dd0e verified
raw
history blame
3.25 kB
# agent.py
import os
import re
import google.generativeai as genai
from dotenv import load_dotenv
# Load variables from a .env file into the environment, if one is present.
load_dotenv()

# Configure the Gemini API key (read from the GOOGLE_API_KEY environment variable).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Reject an empty value as well as a missing one: a bare "GOOGLE_API_KEY="
# would otherwise pass this check and presumably only fail once the API is called.
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY not set in environment variables!")
genai.configure(api_key=GOOGLE_API_KEY)
class Agent:
    """Question-answering agent for GAIA-style evaluation backed by Gemini.

    Calling an instance sends a strictly formatted prompt to the configured
    Gemini model and post-processes the reply into a short final answer.
    """

    # Text following a "FINAL ANSWER:" marker; \s* may cross a newline, so
    # "FINAL ANSWER:\n42" still captures "42".
    _FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER:\s*([^\n]+)", re.IGNORECASE)
    # Question phrases that signal a numeric answer. Word boundaries stop
    # "count" matching inside "country"/"account" and "total" inside "totally".
    _NUMERIC_QUESTION_RE = re.compile(r"\b(?:how many|number of|counts?|total|amount)\b")
    # First number in the answer, keeping a leading minus sign, thousands
    # separators ("1,234") and a decimal part ("3.5").
    _NUMBER_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")
    # An option marker such as "a)", "(b)" or "C)" at the start of a token, so a
    # word that merely ends in a-d before ")" (e.g. "(Wikipedia)") is not an option.
    _CHOICE_MARKER_RE = re.compile(r"(?:^|[\s(])[a-d]\)", re.IGNORECASE)
    # "option"/"options" as a whole word (not e.g. "optional").
    _OPTION_WORD_RE = re.compile(r"\boptions?\b")
    # A standalone option letter: an a-d that is not part of a longer word.
    _CHOICE_LETTER_RE = re.compile(r"(?<![A-Za-z])[a-d](?![A-Za-z])", re.IGNORECASE)
    # Whole-word yes/no, so "Unknown", "not" or "eyes" are not misread.
    _YES_RE = re.compile(r"\byes\b")
    _NO_RE = re.compile(r"\bno\b")
    # Question openings treated as yes/no questions.
    _YES_NO_PREFIXES = ("is ", "are ", "does ")

    def __init__(
        self,
        model_name: str = "gemini-2.0-flash",
        fallback_model_name: str = "gemini-1.5-pro",
    ) -> None:
        """Create the Gemini model client.

        Args:
            model_name: Gemini model tried first.
            fallback_model_name: Model used if constructing ``model_name`` raises.
        """
        # Default to the flash model; fall back to pro if construction fails.
        # NOTE(review): constructing a GenerativeModel probably does not contact
        # the API, so an unavailable model may only surface in generate_content()
        # and this fallback may rarely trigger — confirm against the client library.
        try:
            self.model = genai.GenerativeModel(model_name)
        except Exception:
            self.model = genai.GenerativeModel(fallback_model_name)

    def __call__(self, question: str) -> str:
        """Answer *question* and return the cleaned final answer.

        Returns "Unknown" when the model reply carries no usable text.
        """
        # Standardized prompt; the author reports it cleared the 30% score threshold.
        prompt = self._build_prompt(question)
        response = self.model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                # Low temperature for stable answers; only a short reply is expected.
                temperature=0.1,
                max_output_tokens=512,
            ),
        )
        try:
            text = response.text
        except ValueError:
            # The .text accessor raises ValueError when the response holds no
            # valid Part (e.g. the candidate was blocked); treat as no answer.
            text = None
        return self._clean_answer(text, question)

    def _build_prompt(self, question: str) -> str:
        """Build a prompt that asks for a bare answer after a "FINAL ANSWER:" marker.

        Strict format rules are given to the LLM to maximize the accuracy of
        automatic answer scoring.
        """
        return f"""You are a strict answer bot for GAIA evaluation.
Only provide the FINAL concise answer to the following question.
Rules:
- No explanations, no extra info, no reasoning.
- If the question asks for a number, answer only the number.
- If the question is multiple choice, answer only the letter(s).
- If the question is yes/no, answer 'yes' or 'no'.
- If you are not sure, answer 'Unknown'.
- For lists, answer with comma-separated values.
End your answer with: FINAL ANSWER: [your answer]
Question: {question}
FINAL ANSWER:
"""

    def _clean_answer(self, answer: str, question: str) -> str:
        """Reduce the raw model reply to the final answer string.

        Args:
            answer: Raw model text; ``None`` or blank text yields "Unknown".
            question: The original question, used to decide whether the answer
                should be narrowed to a number, option letters, or yes/no.

        Returns:
            The cleaned answer.
        """
        if answer is None or not answer.strip():
            return "Unknown"
        ans = self._extract_final_answer(answer)
        q_lower = question.lower()

        # Numeric question: keep only the first number, without thousands separators.
        if self._NUMERIC_QUESTION_RE.search(q_lower):
            num_match = self._NUMBER_RE.search(ans)
            if num_match:
                return num_match.group(0).replace(",", "")

        # Multiple-choice question: keep only standalone option letters.
        if self._CHOICE_MARKER_RE.search(question) or self._OPTION_WORD_RE.search(q_lower):
            letters = self._CHOICE_LETTER_RE.findall(ans)
            if letters:
                return ",".join(letter.lower() for letter in letters)

        # Yes/no question: whole-word match only.
        if q_lower.startswith(self._YES_NO_PREFIXES):
            ans_lower = ans.lower()
            if self._YES_RE.search(ans_lower):
                return "yes"
            if self._NO_RE.search(ans_lower):
                return "no"
        return ans

    def _extract_final_answer(self, text: str) -> str:
        """Return the text after the last non-empty "FINAL ANSWER:" marker.

        Falls back to the last line of *text* when no marker carries an answer.
        Expects *text* to be non-blank.
        """
        # The model may echo the prompt's "FINAL ANSWER: [your answer]" template
        # before its real answer, so the last non-empty match wins.
        candidates = [m.strip() for m in self._FINAL_ANSWER_RE.findall(text)]
        candidates = [c for c in candidates if c]
        if candidates:
            return candidates[-1]
        return text.strip().splitlines()[-1]
# Manual smoke test: answers one sample question through Agent.
if __name__ == "__main__":
    sample_question = (
        "How many studio albums were published by Mercedes Sosa "
        "between 2000 and 2009 (included)?"
    )
    print(Agent()(sample_question))