# Source: Hugging Face Space file "Agent / app.py", commit f7f36de ("Update app.py", by BiGuan, verified), 24 kB.
"""
app.py —— 整个项目的"主程序"
这个文件把所有零件组装起来,主要包含四大块:
1. 系统提示词 SYSTEM_PROMPT:写给大模型的"工作守则",告诉它怎么答题、答案要什么格式。
2. GAIA Agent 类:真正的"答题机器人",一个会思考的大模型 + 8 个工具(搜索、看图、读文件…)。
它按"思考→调用工具→再思考→…→给出答案"的循环工作。
3. 提交相关函数:把答案 POST 给评分服务器,并处理服务器偶尔出错时的重试。
4. Gradio 界面:网页上的几个按钮和表格,方便点一下就跑全流程、看结果。
"""
from __future__ import annotations
import os
import time
import tempfile
import gradio as gr
import requests
import pandas as pd
# LangChain 里三种"消息"类型:系统消息(给AI定规则)、人类消息(用户的话)、AI消息(AI的回复)。
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
# 导入"创建 ReAct agent"的函数。
try:
from langgraph.prebuilt import create_react_agent
except ImportError: # newer langchain/langgraph layouts
from langchain.agents import create_agent as create_react_agent
# GraphRecursionError is the exception GAIAAgent._run() catches when the agent runs past its
# step limit. If the installed langgraph does not expose it, define a placeholder class so the
# `except GraphRecursionError` clause below still resolves to a real exception type.
try:
    from langgraph.errors import GraphRecursionError
except ImportError:
    class GraphRecursionError(Exception):
        pass
# 导入我们自己写的 8 个工具(每个都在 tools/ 文件夹里)。
from tools.web_search import web_search
from tools.wikipedia_search import wikipedia_search
from tools.visit_webpage import visit_webpage
from tools.read_file import read_file
from tools.transcribe_audio import transcribe_audio
from tools.visual_qa import visual_qa
from tools.youtube_transcript import youtube_transcript
from tools.python_repl import python_repl
# 导入配置(模型地址/密钥/名字)和标准答案表/题目分类器。
from config import LLM_BASE_URL, LLM_API_KEY, LLM_MODEL_ID
from answer_key import REFERENCE_ANSWERS, classify_question
# Base URL of the course scoring server. This file calls its /questions and /submit endpoints.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# System prompt: the standing instructions sent to the LLM with every question.
# In short: reason step by step, always consult a tool before answering, never answer
# "unknown" or guess without having searched, and finish with a single terse line of the form
# "FINAL ANSWER: ..." (clean_answer() relies on that marker to extract the answer).
SYSTEM_PROMPT = (
    "You are a general AI assistant answering questions from the GAIA benchmark. "
    "Reason step by step and use your tools to gather and verify facts. "
    "Available tools: web_search, wikipedia_search, visit_webpage, read_file "
    "(spreadsheets/PDF/Word/code/text), transcribe_audio, visual_qa (images), "
    "youtube_transcript, and python_repl (run Python for any maths, string or data work).\n"
    "Tool guidance:\n"
    "- For ANY question needing a fact, name, number, date, list, or file content you MUST "
    "call at least one tool (web_search / wikipedia_search / a file tool) before answering. "
    "NEVER output 'unknown' or a guess without having searched first.\n"
    "- Plan multi-hop lookups: search, open the most relevant result with visit_webpage, and "
    "read it carefully to extract the EXACT value (full names, exact spelling). Cross-check "
    "when sources conflict.\n"
    "- A question that contains a YouTube URL: call `youtube_transcript` on that URL.\n"
    "- A question that says a file is attached: a local path is given in the message — open it "
    "with read_file / transcribe_audio / visual_qa. Never ask the user to upload anything.\n"
    "- If a tool fails or returns nothing, try a different tool or a reworded query rather than "
    "repeating the same call.\n"
    "- Only after genuinely trying the tools may you, as a last resort, give your single best "
    "guess. Always commit to a concrete answer — never say you cannot answer.\n\n"
    "Finish with one line in exactly this template:\n"
    "FINAL ANSWER: [YOUR FINAL ANSWER]\n"
    "YOUR FINAL ANSWER must be a number OR as few words as possible OR a comma separated "
    "list of numbers and/or strings. Do not add anything after it.\n"
    "- For a number: no thousands separators and no units ($, %, ...) unless asked.\n"
    "- For a string: no articles, no abbreviations, digits written in plain text unless asked.\n"
    "- For a list: apply these rules to each element."
)
def build_model():
    """Build the chat model that drives the agent.

    Connection details (model id, base URL, API key) come from config.py. Sampling
    temperature and the per-reply token cap can be overridden with the
    AGENT_TEMPERATURE and AGENT_MAX_TOKENS environment variables.

    Returns:
        A `ChatOpenAI` instance. The agent binds tools to it, so the endpoint is
        expected to support tool calling.
    """
    # Imported lazily so that merely importing this module does not need langchain_openai.
    from langchain_openai import ChatOpenAI

    settings = {
        "model": LLM_MODEL_ID,
        "base_url": LLM_BASE_URL,
        "api_key": LLM_API_KEY,
        # Default 0: answers should be as deterministic as possible.
        "temperature": float(os.getenv("AGENT_TEMPERATURE", "0")),
        # Upper bound on the length of a single model reply.
        "max_tokens": int(os.getenv("AGENT_MAX_TOKENS", "4096")),
    }
    return ChatOpenAI(**settings)
def clean_answer(content) -> str:
    """Reduce a verbose model reply to the bare answer string that gets submitted.

    Steps, in order:
      1. Join list-style content (a list of text parts / dicts with a "text" key).
      2. Keep only what follows the last "FINAL ANSWER:" marker. The marker is
         matched case-insensitively, so "Final Answer:" is recognised as well.
      3. Keep only the first line of that remainder.
      4. Drop Markdown emphasis left around the answer (e.g. "**FINAL ANSWER:** x").
      5. Drop one pair of matching surrounding quotes.
      6. Drop trailing periods and spaces.

    Args:
        content: The reply content; a string, or a list of strings / dict parts.

    Returns:
        The cleaned answer, possibly an empty string.
    """
    # Local import: `re` is not among this module's top-level imports.
    import re

    if isinstance(content, list):
        # `or ""` guards against a part whose "text" value is None.
        content = " ".join(
            str(part.get("text") or "") if isinstance(part, dict) else str(part)
            for part in content
        )
    text = str(content).strip()

    # Everything after the LAST marker is the answer; earlier mentions are reasoning.
    pieces = re.split(r"FINAL\s+ANSWER\s*:", text, flags=re.IGNORECASE)
    if len(pieces) > 1:
        text = pieces[-1].strip()

    # Only the first line counts; anything after it is extra explanation.
    lines = text.splitlines()
    text = lines[0].strip() if lines else ""

    # Models often bold the marker or wrap the answer in backticks; remove that residue.
    text = text.strip("*`").strip()

    # Remove one pair of matching quotes around the whole answer.
    if len(text) >= 2 and text[0] == text[-1] and text[0] in ("'", '"'):
        text = text[1:-1].strip()

    # Trailing periods and spaces are removed from the submitted answer.
    return text.rstrip(". ").strip()
class GAIAAgent:
    """Question-answering agent: a ReAct agent built from the chat model plus eight tools.

    Calling an instance with a question (and optionally a task id and an attachment
    name) returns `(answer, tools_used)`. The agent is given the system prompt and
    the question, runs until it answers or hits the step limit, and several
    fallbacks try to turn an incomplete or give-up reply into a concrete answer.
    """

    def __init__(self, api_url: str = DEFAULT_API_URL):
        """Build the model and the tool-using agent.

        Args:
            api_url: Base URL of the scoring server; stored on the instance.
        """
        self.api_url = api_url
        tools = [
            web_search,
            wikipedia_search,
            visit_webpage,
            read_file,
            transcribe_audio,
            visual_qa,
            youtube_transcript,
            python_repl,
        ]
        self.model = build_model()
        self.agent = create_react_agent(self.model, tools)
        # Step limit passed to the agent graph so it cannot loop on tool calls forever.
        # Overridable via AGENT_RECURSION_LIMIT; defaults to 40.
        self.recursion_limit = int(os.getenv("AGENT_RECURSION_LIMIT", "40"))
        print("GAIAAgent initialized.")

    def _download_file(self, task_id: str, file_name: str) -> str | None:
        """Download a task attachment from the GAIA dataset repository.

        Args:
            task_id: Task identifier, used only in the failure log message.
            file_name: Attachment name as given by the question record.

        Returns:
            Local path of the downloaded file inside a fresh temp directory, or
            None when there is no attachment or the download/write failed.
        """
        if not file_name:
            return None
        file_url = f"https://huggingface.co/datasets/gaia-benchmark/GAIA/resolve/main/2023/validation/{file_name}"
        # Send the auth header only when a token exists; otherwise the header would
        # literally read "Bearer None".
        token = os.getenv("HF_TOKEN")
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        # file_name comes from the server: keep only its last path component so the
        # local file always lands directly inside the temp directory.
        local_name = os.path.basename(file_name) or "attachment"
        try:
            response = requests.get(file_url, headers=headers, timeout=30)
            response.raise_for_status()
            path = os.path.join(tempfile.mkdtemp(), local_name)
            # The write is inside the try so a disk error degrades to "no attachment"
            # instead of aborting the whole batch.
            with open(path, "wb") as f:
                f.write(response.content)
        except Exception as e:
            print(f"Could not download file {file_name} for task {task_id}: {e}")
            return None
        return path

    @staticmethod
    def _collect_tools(history: list) -> list:
        """Return the names of the tools called in `history`, in first-use order, without duplicates."""
        used = []
        for m in history:
            # Only messages carrying a `tool_calls` attribute contribute.
            for tc in (getattr(m, "tool_calls", None) or []):
                # A tool call may be a dict or an object with a `name` attribute.
                name = tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", None)
                if name and name not in used:
                    used.append(name)
        return used

    @staticmethod
    def _drop_unanswered_tool_calls(history: list) -> list:
        """Return a copy of `history` without trailing messages that still carry tool calls.

        When a run is cut off, the last message can be a model turn requesting tools
        that were never executed. That turn is removed before the history is replayed
        to the model without tools.
        """
        trimmed = list(history)
        while trimmed and getattr(trimmed[-1], "tool_calls", None):
            trimmed.pop()
        return trimmed

    @staticmethod
    def _file_hint(path: str) -> str:
        """Return an instruction telling the model which tool to use for `path`, based on its extension.

        Audio extensions map to `transcribe_audio`, image extensions to `visual_qa`,
        and everything else to `read_file`.
        """
        ext = os.path.splitext(path)[1].lower()
        if ext in (".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"):
            return "It is an AUDIO file: call `transcribe_audio` on this path, then answer from the transcript."
        if ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"):
            return "It is an IMAGE: call `visual_qa` on this path with a precise question to read it."
        return "Call `read_file` on this path to read its contents, then answer."

    @staticmethod
    def _is_giveup(text: str) -> bool:
        """Return True when `text` is empty or reads like a refusal ("unknown", "cannot", ...)."""
        low = str(text).strip().lower()
        return (
            not low
            or low in ("unknown", "nan", "none")
            or "unable to" in low
            or "cannot" in low
            or "i don't" in low
        )

    @staticmethod
    def _looks_incomplete(text: str) -> bool:
        """Return True when `text` is empty or contains a phrase signalling an unfinished answer."""
        low = str(text).strip().lower()
        if not low:
            return True
        markers = (
            "need more steps",
            "i cannot see",
            "i can't see",
            "please upload",
            "i don't see any",
            "i do not see any",
            "unable to access",
            "as an ai",
        )
        return any(m in low for m in markers)

    def _force_final(self, history: list) -> str:
        """Ask the bare model (no tools) for a definite answer based on `history`.

        Returns:
            The content of the model's reply.
        """
        # Drop earlier "need more steps" replies so they do not steer the model.
        clean_history = [
            m for m in history
            if not (isinstance(m, AIMessage) and "need more steps" in str(m.content).lower())
        ]
        clean_history.append(
            HumanMessage(
                content=(
                    "You are out of tool budget. Using only the information already gathered "
                    "above, give your single best answer now. You MUST output exactly one line "
                    "'FINAL ANSWER: [answer]' with a concrete value — guess if you must, and "
                    "never say you cannot answer."
                )
            )
        )
        # self.model is the plain chat model, so no tool can be called here.
        return self.model.invoke(clean_history).content

    def _run(self, messages: list, task_id):
        """Run the agent once on `messages`.

        Args:
            messages: Initial conversation (system prompt plus user message(s)).
            task_id: Task identifier, used only in log messages.

        Returns:
            `(answer, tools_used)`. On an unexpected error the answer is "unknown"
            and the tool list is empty, so one failing task cannot abort the batch.
        """
        # Latest conversation snapshot seen so far. It is kept outside the try so that
        # work done before a step-limit error is still available to the fallback.
        history = list(messages)
        try:
            # stream_mode="values" yields the full state after every step; the last
            # snapshot is the same final state that invoke() would return.
            for state in self.agent.stream(
                {"messages": messages},
                config={"recursion_limit": self.recursion_limit},
                stream_mode="values",
            ):
                history = state["messages"]
            tools_used = self._collect_tools(history)
            answer = history[-1].content
            # An unfinished-looking reply gets one forced, tool-free final attempt.
            if self._looks_incomplete(answer):
                answer = self._force_final(history)
            return clean_answer(answer), tools_used
        except GraphRecursionError:
            # Step limit exceeded: force an answer from what was gathered so far.
            partial = self._drop_unanswered_tool_calls(history)
            try:
                return clean_answer(self._force_final(partial)), self._collect_tools(partial)
            except Exception as e:
                print(f"Forced-answer failed on task {task_id}: {e}")
                return "unknown", []
        except Exception as e:
            print(f"Agent error on task {task_id}: {e}")
            return "unknown", []

    def __call__(self, question: str, task_id: str | None = None, file_name: str | None = None):
        """Answer one question.

        Args:
            question: The question text.
            task_id: Task identifier (used for logging).
            file_name: Name of an attached file, if any.

        Returns:
            `(answer, tools_used)`.
        """
        user_content = question
        if file_name:
            path = self._download_file(task_id, file_name)
            if path:
                # Give the model the local path plus an explicit hint about which tool fits.
                user_content += f"\n\nA file is attached at local path: {path}\n{self._file_hint(path)}"
            else:
                user_content += (
                    "\n\n(Note: the attached file could not be downloaded. Answer as best you "
                    "can from the question text alone.)"
                )
        messages = [SystemMessage(content=SYSTEM_PROMPT), HumanMessage(content=user_content)]
        answer, tools_used = self._run(messages, task_id)
        # Second chance: the model gave up without calling any tool, so run again with
        # an explicit instruction to search first.
        if not tools_used and self._is_giveup(answer):
            retry = messages + [
                HumanMessage(
                    content=(
                        "You answered without using any tool, which is not allowed. Call "
                        "web_search or wikipedia_search now, read the results, and then give "
                        "the FINAL ANSWER."
                    )
                )
            ]
            answer2, tools2 = self._run(retry, task_id)
            # Adopt the retry only if it used a tool or produced a non-give-up answer.
            if tools2 or not self._is_giveup(answer2):
                answer, tools_used = answer2, tools2
        return answer, tools_used
# Cache of the most recent run. run_and_submit_all() stores the submission payload and the
# result rows here, so that submit_only() can POST the same answers again after a failed
# submission without re-running the agent.
LAST_SUBMISSION: dict = {}
def _submit_with_retry(payload: dict, retries: int = 4):
    """POST `payload` to the scoring server's /submit endpoint, retrying transient failures.

    5xx responses and non-HTTP errors (network problems etc.) are retried with a
    growing delay of 4 s, 8 s, 12 s, ... between attempts. An HTTP error raised by
    `raise_for_status()` on a non-5xx response is not retried; its detail is returned
    so the caller can show what to fix.

    Args:
        payload: JSON-serialisable submission body.
        retries: Maximum number of attempts.

    Returns:
        `(True, parsed_json)` on success, otherwise `(False, error_message)`.
    """
    submit_url = f"{DEFAULT_API_URL}/submit"
    last = None
    for attempt in range(retries):
        try:
            resp = requests.post(submit_url, json=payload, timeout=120)
            if resp.status_code < 500:
                resp.raise_for_status()
                return True, resp.json()
            # 5xx: the server itself failed, so another attempt is worthwhile.
            last = f"{resp.status_code} server error: {resp.text[:200]}"
            print(f"submit attempt {attempt + 1}: {last}")
        except requests.exceptions.HTTPError as e:
            # Client-side problem: retrying the same request would not help.
            detail = e.response.text[:300]
            try:
                # Prefer the structured "detail" field when the body is JSON.
                detail = e.response.json().get("detail", detail)
            except Exception:
                pass
            return False, f"HTTP {e.response.status_code}: {detail}"
        except Exception as e:
            last = str(e)
            print(f"submit attempt {attempt + 1} failed: {last}")
        # Back off only when another attempt follows; sleeping after the final
        # failure would just delay the error report.
        if attempt + 1 < retries:
            time.sleep(4 * (attempt + 1))
    return False, f"all {retries} attempts failed (last error: {last})"
def _format_result(result_data: dict) -> str:
    """Render the scoring server's response as a short human-readable summary.

    Missing fields fall back to placeholders ("N/A", "?", "No message received.").
    """
    username = result_data.get('username')
    score = result_data.get('score', 'N/A')
    correct = result_data.get('correct_count', '?')
    attempted = result_data.get('total_attempted', '?')
    message = result_data.get('message', 'No message received.')
    return (
        f"Submission Successful!\n"
        f"User: {username}\n"
        f"Overall Score: {score}% "
        f"({correct}/{attempted} correct)\n"
        f"Message: {message}"
    )
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, answer each with the agent, cache the answers and submit them.

    Args:
        profile: The logged-in Hugging Face profile, or None when not logged in.

    Returns:
        `(status_text, results_table)`; the table is None when the run stops before
        any question is answered.
    """
    space_id = os.getenv("SPACE_ID")

    # Step 1: a login is required; the submission is recorded under this username.
    if not profile:
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"
    print(f"User logged in: {username}")

    # Step 2: SPACE_ID is needed to build the code link sent with the answers.
    if not space_id:
        return (
            "SPACE_ID not found. Run this on your public Hugging Face Space — the scoring "
            "server validates the agent_code link and a missing/invalid Space can cause a 500.",
            None,
        )
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    # Step 3: fetch the question list.
    try:
        response = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
        response.raise_for_status()
        questions = response.json()
        if not questions:
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions)} questions.")
    except Exception as e:
        return f"Error fetching questions: {e}", None

    # Step 4: build the agent.
    try:
        agent = GAIAAgent(api_url=DEFAULT_API_URL)
    except Exception as e:
        return f"Error initializing agent: {e}", None

    # Step 5: answer question by question. `rows` feeds the on-screen table,
    # `submitted` is what gets sent to the server.
    rows = []
    submitted = []
    for entry in questions:
        task_id = entry.get("task_id")
        question_text = entry.get("question")
        file_name = entry.get("file_name") or ""
        # Skip records that lack a task id or a question.
        if not task_id or question_text is None:
            continue
        answer, tools_used = agent(question_text, task_id=task_id, file_name=file_name)
        # Never submit an empty answer; fall back to "unknown".
        answer = str(answer).strip() or "unknown"
        submitted.append({"task_id": task_id, "submitted_answer": answer})
        row = {
            "Task ID": task_id,
            "Type": classify_question(question_text, file_name),
            "Question": question_text,
            "Reference Answer": REFERENCE_ANSWERS.get(task_id, ""),
            "Submitted Answer": answer,
            "Tools Used": ", ".join(tools_used) if tools_used else "(none)",
        }
        rows.append(row)

    if not submitted:
        return "Agent did not produce any answers to submit.", pd.DataFrame(rows)

    # Step 6: cache the payload first, so a failed submission can be retried later.
    payload = {"username": username.strip(), "agent_code": agent_code, "answers": submitted}
    LAST_SUBMISSION.update(payload=payload, results_log=rows)
    print(f"Submitting {len(submitted)} answers for user '{username}'...")
    table = pd.DataFrame(rows)
    succeeded, outcome = _submit_with_retry(payload)
    if not succeeded:
        return (
            f"Submission Failed: {outcome}\n\n"
            "Your answers are cached — fix the issue (most often: make the Space Public) and click "
            "'Re-submit last answers' to retry WITHOUT re-running the agent.",
            table,
        )
    return _format_result(outcome), table
def submit_only(profile: gr.OAuthProfile | None):
    """Re-submit the cached answers from the last run without re-running the agent.

    Args:
        profile: The logged-in Hugging Face profile, or None; not read by this function.

    Returns:
        `(status_text, results_table)`; the table is None when nothing is cached yet.
    """
    cached_payload = LAST_SUBMISSION.get("payload")
    if not cached_payload:
        return "No cached answers yet — run the evaluation first.", None

    table = pd.DataFrame(LAST_SUBMISSION.get("results_log", []))
    succeeded, outcome = _submit_with_retry(cached_payload)
    if not succeeded:
        return f"Submission Failed again: {outcome}", table
    return _format_result(outcome), table
# --- Gradio web UI ---
# A small page: instructions, a login button, two action buttons, a status box and a
# results table.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner (LangGraph)")
    gr.Markdown(
        """
1. Log in to your Hugging Face account with the button below.
2. Click 'Run Evaluation & Submit All Answers' to fetch the questions, run the agent
and submit the answers. This can take several minutes.
3. If submission fails (the scoring server sometimes returns 500), click
'Re-submit last answers' to retry without re-running the agent.
The model endpoint is preconfigured in `config.py`, so no secrets are required.
Make sure this Space is **Public**, otherwise the scoring server can reject the
submission with a 500.
"""
    )
    gr.LoginButton()  # Hugging Face login button
    with gr.Row():  # lay the two action buttons out side by side
        run_button = gr.Button("Run Evaluation & Submit All Answers", variant="primary")
        resubmit_button = gr.Button("Re-submit last answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=6, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
    # Wire the buttons: the primary one runs the full pipeline, the other only re-submits
    # the cached answers. Both write to the status box and the results table.
    run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
    resubmit_button.click(fn=submit_only, outputs=[status_output, results_table])
# Launch the web app only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo.launch(debug=True, share=False)