| """ |
| app.py —— 整个项目的"主程序" |
| |
| 这个文件把所有零件组装起来,主要包含四大块: |
| 1. 系统提示词 SYSTEM_PROMPT:写给大模型的"工作守则",告诉它怎么答题、答案要什么格式。 |
| 2. GAIA Agent 类:真正的"答题机器人",一个会思考的大模型 + 8 个工具(搜索、看图、读文件…)。 |
| 它按"思考→调用工具→再思考→…→给出答案"的循环工作。 |
| 3. 提交相关函数:把答案 POST 给评分服务器,并处理服务器偶尔出错时的重试。 |
| 4. Gradio 界面:网页上的几个按钮和表格,方便点一下就跑全流程、看结果。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import time |
| import tempfile |
|
|
| import gradio as gr |
| import requests |
| import pandas as pd |
|
|
| |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage |
|
|
| |
| try: |
| from langgraph.prebuilt import create_react_agent |
| except ImportError: |
| from langchain.agents import create_agent as create_react_agent |
|
|
| |
| try: |
| from langgraph.errors import GraphRecursionError |
| except ImportError: |
| class GraphRecursionError(Exception): |
| pass |
|
|
| |
| from tools.web_search import web_search |
| from tools.wikipedia_search import wikipedia_search |
| from tools.visit_webpage import visit_webpage |
| from tools.read_file import read_file |
| from tools.transcribe_audio import transcribe_audio |
| from tools.visual_qa import visual_qa |
| from tools.youtube_transcript import youtube_transcript |
| from tools.python_repl import python_repl |
|
|
| |
| from config import LLM_BASE_URL, LLM_API_KEY, LLM_MODEL_ID |
| from answer_key import REFERENCE_ANSWERS, classify_question |
|
|
| |
# Base URL of the course scoring server: questions are fetched from it and answers
# are submitted to it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"


# The system prompt ("rules of work" for the LLM) is assembled from three sections
# so each concern can be edited on its own. The joined text is what the agent sees.

# Section 1: role and the list of available tools.
_PROMPT_ROLE = (
    "You are a general AI assistant answering questions from the GAIA benchmark. "
    "Reason step by step and use your tools to gather and verify facts. "
    "Available tools: web_search, wikipedia_search, visit_webpage, read_file "
    "(spreadsheets/PDF/Word/code/text), transcribe_audio, visual_qa (images), "
    "youtube_transcript, and python_repl (run Python for any maths, string or data work).\n"
)

# Section 2: how and when to use tools (always search before answering, never give up).
_PROMPT_TOOL_GUIDANCE = (
    "Tool guidance:\n"
    "- For ANY question needing a fact, name, number, date, list, or file content you MUST "
    "call at least one tool (web_search / wikipedia_search / a file tool) before answering. "
    "NEVER output 'unknown' or a guess without having searched first.\n"
    "- Plan multi-hop lookups: search, open the most relevant result with visit_webpage, and "
    "read it carefully to extract the EXACT value (full names, exact spelling). Cross-check "
    "when sources conflict.\n"
    "- A question that contains a YouTube URL: call `youtube_transcript` on that URL.\n"
    "- A question that says a file is attached: a local path is given in the message — open it "
    "with read_file / transcribe_audio / visual_qa. Never ask the user to upload anything.\n"
    "- If a tool fails or returns nothing, try a different tool or a reworded query rather than "
    "repeating the same call.\n"
    "- Only after genuinely trying the tools may you, as a last resort, give your single best "
    "guess. Always commit to a concrete answer — never say you cannot answer.\n\n"
)

# Section 3: the exact "FINAL ANSWER:" output template and GAIA formatting rules.
_PROMPT_ANSWER_FORMAT = (
    "Finish with one line in exactly this template:\n"
    "FINAL ANSWER: [YOUR FINAL ANSWER]\n"
    "YOUR FINAL ANSWER must be a number OR as few words as possible OR a comma separated "
    "list of numbers and/or strings. Do not add anything after it.\n"
    "- For a number: no thousands separators and no units ($, %, ...) unless asked.\n"
    "- For a string: no articles, no abbreviations, digits written in plain text unless asked.\n"
    "- For a list: apply these rules to each element."
)

SYSTEM_PROMPT = "".join((_PROMPT_ROLE, _PROMPT_TOOL_GUIDANCE, _PROMPT_ANSWER_FORMAT))
|
|
|
|
def build_model():
    """Create and return the chat model that drives the agent.

    The endpoint, API key and model id come from ``config.py``. The sampling
    temperature and the output-token cap can be overridden with the
    ``AGENT_TEMPERATURE`` (default ``"0"``) and ``AGENT_MAX_TOKENS``
    (default ``"4096"``) environment variables. The model must support tool
    calling, because the ReAct agent binds tools to it.
    """
    from langchain_openai import ChatOpenAI

    settings = {
        "model": LLM_MODEL_ID,
        "base_url": LLM_BASE_URL,
        "api_key": LLM_API_KEY,
        # A deterministic default suits benchmark answering.
        "temperature": float(os.getenv("AGENT_TEMPERATURE", "0")),
        "max_tokens": int(os.getenv("AGENT_MAX_TOKENS", "4096")),
    }
    return ChatOpenAI(**settings)
|
|
|
|
def clean_answer(content) -> str:
    """Reduce the model's (often verbose) final reply to the bare answer the scorer expects.

    Steps:
      1. Multi-part content (a list of strings or ``{"text": ...}`` dicts) is
         joined with single spaces.
      2. Everything up to and including the LAST "FINAL ANSWER:" marker is
         dropped. The match is case-insensitive and tolerates Markdown emphasis
         such as ``**FINAL ANSWER**:``.
      3. Only the first line of what remains is kept.
      4. Surrounding Markdown markers (``*``, backticks) and trailing periods are
         removed. Then one pair of matching quotes is removed, and trailing
         periods/spaces are removed again. This means ``"Paris".`` becomes ``Paris``.

    Args:
        content: Message content from the model (str, list of parts, or anything
            ``str()``-able).

    Returns:
        The cleaned answer string (may be empty).
    """
    import re

    if isinstance(content, list):
        content = " ".join(
            part.get("text", "") if isinstance(part, dict) else str(part) for part in content
        )
    text = str(content).strip()

    # Accept "FINAL ANSWER:", "Final answer:", "**FINAL ANSWER**:" and similar.
    markers = list(re.finditer(r"final\s+answer[\s*_]*:", text, flags=re.IGNORECASE))
    if markers:
        text = text[markers[-1].end():]

    # Remove leftover emphasis such as the "**" after "**FINAL ANSWER:**" before
    # choosing the first line, so that the first line is the actual answer.
    text = text.strip(" \t\r\n*`")
    text = text.splitlines()[0] if text else text

    # Strip trailing periods before unquoting, so a quoted answer followed by a
    # period ("Paris".) is also unquoted.
    text = text.lstrip(" \t*`").rstrip(" \t*`.")
    if len(text) >= 2 and text[0] == text[-1] and text[0] in ("'", '"'):
        text = text[1:-1].strip()

    return text.rstrip(". ").strip()
|
|
|
|
class GAIAAgent:
    """Question-answering agent: a LangGraph ReAct agent with web, file, audio and image tools.

    Calling an instance answers one GAIA question and returns
    ``(answer, tools_used)``.
    """

    def __init__(self, api_url: str = DEFAULT_API_URL):
        """Build the chat model and the ReAct agent over the eight tools.

        Args:
            api_url: Scoring-server base URL, stored as ``self.api_url``. It is not
                read by the other methods of this class.
        """
        self.api_url = api_url
        tools = [
            web_search,
            wikipedia_search,
            visit_webpage,
            read_file,
            transcribe_audio,
            visual_qa,
            youtube_transcript,
            python_repl,
        ]
        self.model = build_model()
        self.agent = create_react_agent(self.model, tools)
        # Upper bound on graph steps per run; can be overridden via AGENT_RECURSION_LIMIT.
        self.recursion_limit = int(os.getenv("AGENT_RECURSION_LIMIT", "40"))
        print("GAIAAgent initialized.")

    def _download_file(self, task_id: str, file_name: str) -> str | None:
        """Download a task attachment from the official GAIA dataset repository.

        The request authenticates with the ``HF_TOKEN`` environment variable when
        it is set.

        Args:
            task_id: Task identifier; only used in log messages.
            file_name: Attachment name as reported by the scoring server.

        Returns:
            Local path of the saved file inside a fresh temporary directory, or
            ``None`` if there is no attachment or it could not be fetched or saved.
        """
        if not file_name:
            return None

        file_url = (
            "https://huggingface.co/datasets/gaia-benchmark/GAIA/resolve/main/"
            f"2023/validation/{file_name}"
        )
        token = os.getenv("HF_TOKEN")
        # Only authenticate when a token exists; otherwise the literal header
        # "Bearer None" would be sent.
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        try:
            response = requests.get(file_url, headers=headers, timeout=30)
            response.raise_for_status()
        except Exception as e:  # best effort: the question can still be attempted without the file
            print(f"Could not download file {file_name} for task {task_id}: {e}")
            return None

        # Keep only the last path component, so a server-supplied name containing
        # directories cannot place the file outside the temporary directory. The
        # extension is preserved, which _file_hint relies on.
        local_name = os.path.basename(file_name) or "attachment"
        try:
            path = os.path.join(tempfile.mkdtemp(), local_name)
            with open(path, "wb") as f:
                f.write(response.content)
        except OSError as e:
            print(f"Could not save file {file_name} for task {task_id}: {e}")
            return None
        return path

    @staticmethod
    def _collect_tools(history: list) -> list:
        """Return the names of the tools the agent called, in order of first use (no duplicates).

        Used only for the "Tools Used" column of the results table.
        """
        used = []
        for m in history:
            for tc in (getattr(m, "tool_calls", None) or []):
                name = tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", None)
                if name and name not in used:
                    used.append(name)
        return used

    @staticmethod
    def _file_hint(path: str) -> str:
        """Tell the model explicitly which tool to use for an attachment, based on its extension.

        The model sometimes picks the wrong tool (e.g. read_file on an mp3), so the
        hint names the right one.
        """
        ext = os.path.splitext(path)[1].lower()
        if ext in (".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"):
            return "It is an AUDIO file: call `transcribe_audio` on this path, then answer from the transcript."
        if ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"):
            return "It is an IMAGE: call `visual_qa` on this path with a precise question to read it."
        return "Call `read_file` on this path to read its contents, then answer."

    @staticmethod
    def _is_giveup(text: str) -> bool:
        """Return True if the answer is empty or a give-up ("unknown", "cannot", "unable to", ...)."""
        low = str(text).strip().lower()
        return (
            not low
            or low in ("unknown", "nan", "none")
            or "unable to" in low
            or "cannot" in low
            or "i don't" in low
        )

    @staticmethod
    def _looks_incomplete(text: str) -> bool:
        """Return True if the reply is empty or shows the run stalled or went off track.

        Examples are "need more steps", "I cannot see" and "please upload".
        """
        low = str(text).strip().lower()
        if not low:
            return True
        markers = (
            "need more steps",
            "i cannot see",
            "i can't see",
            "please upload",
            "i don't see any",
            "i do not see any",
            "unable to access",
            "as an ai",
        )
        return any(m in low for m in markers)

    @staticmethod
    def _drop_dangling_tool_calls(history: list) -> list:
        """Return a copy of ``history`` without trailing AIMessages that still carry tool calls.

        A run cut off by the recursion limit can end with a tool-call request that
        was never executed. OpenAI-compatible chat APIs typically reject an
        assistant tool-call message that has no matching tool results.
        """
        trimmed = list(history)
        while (
            trimmed
            and isinstance(trimmed[-1], AIMessage)
            and getattr(trimmed[-1], "tool_calls", None)
        ):
            trimmed.pop()
        return trimmed

    def _force_final(self, history: list) -> str | list:
        """Last resort: call the bare model (no tools) and demand a concrete FINAL ANSWER now.

        With no tools bound, the model cannot stall with another "need more steps" reply.
        Earlier "need more steps" AIMessages and dangling tool-call requests are
        removed first.

        Returns:
            The raw ``content`` of the model's reply (str or list of parts).
        """
        clean_history = [
            m for m in history
            if not (isinstance(m, AIMessage) and "need more steps" in str(m.content).lower())
        ]
        clean_history = self._drop_dangling_tool_calls(clean_history)
        clean_history.append(
            HumanMessage(
                content=(
                    "You are out of tool budget. Using only the information already gathered "
                    "above, give your single best answer now. You MUST output exactly one line "
                    "'FINAL ANSWER: [answer]' with a concrete value — guess if you must, and "
                    "never say you cannot answer."
                )
            )
        )
        return self.model.invoke(clean_history).content

    def _run(self, messages: list, task_id):
        """Run the agent once on ``messages`` and return ``(answer, tools_used)``.

        The graph is streamed with ``stream_mode="values"``, so ``history`` always
        holds the latest full message list. If the recursion limit is hit, the
        forced answer is based on everything gathered so far, not only the bare
        question. On failure the answer is ``"unknown"`` and the tool list is
        empty, which lets ``__call__`` retry.
        """
        history = list(messages)
        try:
            # Each streamed item is the full graph state; the last one equals what
            # invoke() would have returned.
            for state in self.agent.stream(
                {"messages": messages},
                config={"recursion_limit": self.recursion_limit},
                stream_mode="values",
            ):
                history = state.get("messages", history)
            tools_used = self._collect_tools(history)
            answer = history[-1].content
            if self._looks_incomplete(answer):
                answer = self._force_final(history)
            return clean_answer(answer), tools_used
        except GraphRecursionError:
            try:
                return clean_answer(self._force_final(history)), self._collect_tools(history)
            except Exception as e:
                print(f"Forced-answer failed on task {task_id}: {e}")
                return "unknown", []
        except Exception as e:
            print(f"Agent error on task {task_id}: {e}")
            return "unknown", []

    def __call__(self, question: str, task_id: str | None = None, file_name: str | None = None):
        """Answer one question and return ``(answer, tools_used)``.

        If ``file_name`` is given, the attachment is downloaded and its local path
        is appended to the question, together with a tool hint. If the first run
        used no tool and gave up, the question is retried once with an explicit
        instruction to search. The retry result is kept if it used a tool or did
        not give up.
        """
        user_content = question

        if file_name:
            path = self._download_file(task_id, file_name)
            if path:
                user_content += f"\n\nA file is attached at local path: {path}\n{self._file_hint(path)}"
            else:
                # Without the file, still let the model try from the text alone.
                user_content += (
                    "\n\n(Note: the attached file could not be downloaded. Answer as best you "
                    "can from the question text alone.)"
                )

        messages = [SystemMessage(content=SYSTEM_PROMPT), HumanMessage(content=user_content)]
        answer, tools_used = self._run(messages, task_id)

        # A no-tool give-up violates the system prompt; nudge the model once to search.
        if not tools_used and self._is_giveup(answer):
            retry = messages + [
                HumanMessage(
                    content=(
                        "You answered without using any tool, which is not allowed. Call "
                        "web_search or wikipedia_search now, read the results, and then give "
                        "the FINAL ANSWER."
                    )
                )
            ]
            answer2, tools2 = self._run(retry, task_id)
            if tools2 or not self._is_giveup(answer2):
                answer, tools_used = answer2, tools2
        return answer, tools_used
|
|
|
|
| |
| |
# Cache of the most recent submission, holding the keys "payload" and "results_log".
# run_and_submit_all fills it; submit_only reads it so the cached answers can be
# re-sent without re-running the agent.
LAST_SUBMISSION: dict = dict()
|
|
|
|
def _submit_with_retry(payload: dict, retries: int = 4):
    """POST the answers to the scoring server's ``/submit`` endpoint, retrying transient failures.

    - 5xx responses and network errors are retried. The pause before attempt
      ``k`` is ``4 * k`` seconds (4 s, 8 s, ...), and there is no pointless
      pause after the final attempt.
    - A 4xx response means the request itself is wrong. It is not retried, and
      the server's ``detail`` (or the start of the body) is returned so a human
      can fix it.
    - A 2xx response whose body is not JSON is still treated as success, so an
      accepted submission is never re-POSTed.

    Args:
        payload: JSON body with ``username``, ``agent_code`` and ``answers``.
        retries: Maximum number of POST attempts.

    Returns:
        ``(True, result_dict)`` on success, otherwise ``(False, error_message)``.
    """
    submit_url = f"{DEFAULT_API_URL}/submit"
    last = None
    for attempt in range(retries):
        if attempt:
            time.sleep(4 * attempt)
        try:
            resp = requests.post(submit_url, json=payload, timeout=120)
        except requests.exceptions.RequestException as e:
            # Connection problems or timeouts are worth another try.
            last = str(e)
            print(f"submit attempt {attempt + 1} failed: {last}")
            continue

        if resp.status_code >= 500:
            last = f"{resp.status_code} server error: {resp.text[:200]}"
            print(f"submit attempt {attempt + 1}: {last}")
            continue

        if resp.status_code >= 400:
            detail = resp.text[:300]
            try:
                detail = resp.json().get("detail", detail)
            except (ValueError, AttributeError):
                pass  # body is not a JSON object; keep the raw text excerpt
            return False, f"HTTP {resp.status_code}: {detail}"

        try:
            return True, resp.json()
        except ValueError:
            # Accepted but not JSON: report success rather than submitting twice.
            return True, {"message": resp.text[:300]}
    return False, f"all {retries} attempts failed (last error: {last})"
|
|
|
|
| def _format_result(result_data: dict) -> str: |
| """把评分服务器返回的结果,整理成一段人类易读的文字(用户名、总分、对了几道、附言)。""" |
| return ( |
| f"Submission Successful!\n" |
| f"User: {result_data.get('username')}\n" |
| f"Overall Score: {result_data.get('score', 'N/A')}% " |
| f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n" |
| f"Message: {result_data.get('message', 'No message received.')}" |
| ) |
|
|
|
|
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """One-click pipeline: fetch all questions, answer each with the agent, cache and submit.

    This is the handler for the main "Run Evaluation & Submit All Answers" button.
    The payload and results are cached in ``LAST_SUBMISSION`` before submitting,
    so a failed submission can be retried with ``submit_only``.

    Args:
        profile: Hugging Face login profile supplied by Gradio; ``None`` when the
            user is not logged in.

    Returns:
        ``(status_text, results_dataframe_or_None)`` for the two UI outputs.
    """
    space_id = os.getenv("SPACE_ID")

    if not profile:
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"
    print(f"User logged in: {username}")

    # The submission includes a link to this Space's code, so a Space id is required.
    if not space_id:
        return (
            "SPACE_ID not found. Run this on your public Hugging Face Space — the scoring "
            "server validates the agent_code link and a missing/invalid Space can cause a 500.",
            None,
        )

    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    try:
        response = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    except Exception as e:
        return f"Error fetching questions: {e}", None
    if not isinstance(questions_data, list) or not questions_data:
        return "Fetched questions list is empty or invalid format.", None
    print(f"Fetched {len(questions_data)} questions.")

    try:
        agent = GAIAAgent(api_url=DEFAULT_API_URL)
    except Exception as e:
        return f"Error initializing agent: {e}", None

    results_log = []
    answers_payload = []
    for item in questions_data:
        if not isinstance(item, dict):
            continue
        task_id = item.get("task_id")
        question_text = item.get("question")
        file_name = item.get("file_name") or ""
        if not task_id or question_text is None:
            continue
        try:
            answer, tools_used = agent(question_text, task_id=task_id, file_name=file_name)
        except Exception as e:
            # Isolate failures: one crashing question must not discard the answers
            # already collected for the others.
            print(f"Unexpected error on task {task_id}: {e}")
            answer, tools_used = "unknown", []
        # The server expects a non-empty string for every task.
        answer = str(answer).strip() or "unknown"
        answers_payload.append({"task_id": task_id, "submitted_answer": answer})
        results_log.append(
            {
                "Task ID": task_id,
                "Type": classify_question(question_text, file_name),
                "Question": question_text,
                "Reference Answer": REFERENCE_ANSWERS.get(task_id, ""),
                "Submitted Answer": answer,
                "Tools Used": ", ".join(tools_used) if tools_used else "(none)",
            }
        )

    if not answers_payload:
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    payload = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    # Cache before submitting so "Re-submit last answers" works even if this attempt fails.
    LAST_SUBMISSION.update(payload=payload, results_log=results_log)
    print(f"Submitting {len(answers_payload)} answers for user '{username}'...")

    df = pd.DataFrame(results_log)
    ok, data = _submit_with_retry(payload)
    if ok:
        return _format_result(data), df

    return (
        f"Submission Failed: {data}\n\n"
        "Your answers are cached — fix the issue (most often: make the Space Public) and click "
        "'Re-submit last answers' to retry WITHOUT re-running the agent.",
        df,
    )
|
|
|
|
def submit_only(profile: gr.OAuthProfile | None):
    """Re-send the cached answers from the last run without re-running the agent.

    Args:
        profile: Login profile supplied by Gradio; not used here.

    Returns:
        ``(status_text, results_dataframe_or_None)`` for the two UI outputs.
    """
    cached_payload = LAST_SUBMISSION.get("payload")
    if not cached_payload:
        return "No cached answers yet — run the evaluation first.", None

    table = pd.DataFrame(LAST_SUBMISSION.get("results_log", []))
    succeeded, data = _submit_with_retry(cached_payload)
    message = _format_result(data) if succeeded else f"Submission Failed again: {data}"
    return message, table
|
|
|
|
| |
| |
# Gradio UI: login button, run / re-submit buttons, a status box and a results table.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner (LangGraph)")
    gr.Markdown(
        """
        1. Log in to your Hugging Face account with the button below.
        2. Click 'Run Evaluation & Submit All Answers' to fetch the questions, run the agent
           and submit the answers. This can take several minutes.
        3. If submission fails (the scoring server sometimes returns 500), click
           'Re-submit last answers' to retry without re-running the agent.

        The model endpoint is preconfigured in `config.py`. Attached task files are
        downloaded from the gated GAIA dataset, which needs an `HF_TOKEN` secret with
        access to it; without one, those questions are answered from the question text alone.
        Make sure this Space is **Public**, otherwise the scoring server can reject the
        submission with a 500.
        """
    )

    # Hugging Face login; Gradio passes the session's profile to handlers whose
    # parameter is typed gr.OAuthProfile.
    gr.LoginButton()
    with gr.Row():
        run_button = gr.Button("Run Evaluation & Submit All Answers", variant="primary")
        resubmit_button = gr.Button("Re-submit last answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=6, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

    run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
    resubmit_button.click(fn=submit_only, outputs=[status_output, results_table])
|
|
|
|
| |
if __name__ == "__main__":
    # Serve the UI locally. debug=True blocks the main thread and prints errors to the
    # console; share=False means no public share link is created.
    launch_options = {"share": False, "debug": True}
    demo.launch(**launch_options)
|
|