BiGuan committed on
Commit
951a5f7
·
verified ·
1 Parent(s): 72161cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -18
app.py CHANGED
@@ -35,6 +35,7 @@ QWEN_MODEL = "qwen3.5-35b-a3b"
35
  # 进度监控器
36
  # =============================================================================
37
  class ProgressMonitor:
 
38
  def __init__(self):
39
  self.current = 0
40
  self.total = 0
@@ -81,6 +82,7 @@ class ProgressMonitor:
81
  # Qwen LLM 封装
82
  # =============================================================================
83
  class QwenLLM:
 
84
  def __init__(self, model=QWEN_MODEL):
85
  self.model = model
86
  self.api_key = AGICTO_API_KEY
@@ -185,7 +187,7 @@ class QwenLLM:
185
  return formatted
186
 
187
  # =============================================================================
188
- # 工具定义(所有工具均附带 description)
189
  # =============================================================================
190
  api_url_tasks = DEFAULT_API_URL
191
 
@@ -195,6 +197,7 @@ def _get_api_base():
195
  base = base[:-3]
196
  return base
197
 
 
198
  @tool(description="搜索互联网信息,返回相关摘要。")
199
  def web_search(query: str) -> str:
200
  try:
@@ -312,10 +315,54 @@ def download_file_for_task(task_id: str) -> str:
312
  os.unlink(temp_path)
313
  return result
314
  else:
 
315
  return resp.text[:4000]
316
  except Exception as e:
317
  return f"文件下载失败: {e}"
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  # =============================================================================
320
  # LangGraph 状态与节点
321
  # =============================================================================
@@ -323,9 +370,20 @@ class AgentState(TypedDict):
323
  messages: Annotated[Sequence[BaseMessage], operator.add]
324
  final_answer: str
325
  task_id: str
326
- tool_attempts: int # 已执行工具调用次数
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
- tools = [web_search, web_scraper, calculator, analyze_image, transcribe_audio, get_youtube_transcript, download_file_for_task]
329
  tool_node = ToolNode(tools)
330
  llm = QwenLLM()
331
  functions = [convert_to_openai_function(t) for t in tools]
@@ -334,9 +392,13 @@ llm_with_tools = llm.bind_functions(functions)
334
  def agent_node(state: AgentState) -> dict:
335
  messages = state["messages"]
336
  task_id = state.get("task_id", "")
337
- sys_prompt = f"""You are a helpful assistant answering GAIA Level 1 questions.
338
- IMPORTANT: You MUST use at least one tool (e.g., web_search, web_scraper, download_file_for_task) to verify or retrieve information, even if you think you already know the answer.
339
- When you have the final answer, output only the answer string, without any extra text or "FINAL ANSWER:".
 
 
 
 
340
  Current task ID: {task_id}. If the question requires a file, use download_file_for_task with task_id="{task_id}"."""
341
  full = [SystemMessage(content=sys_prompt)] + list(messages)
342
  response = llm_with_tools.invoke(full)
@@ -346,28 +408,28 @@ def should_continue(state: AgentState) -> str:
346
  messages = state["messages"]
347
  last = messages[-1]
348
  tool_attempts = state.get("tool_attempts", 0)
349
- MAX_TOOL_CALLS = 5
350
 
351
- # 超过最大调用次数,强制结束
352
  if tool_attempts >= MAX_TOOL_CALLS:
353
  return "finish"
354
 
355
- # 如果 LLM 请求了工具调用,允许执行
356
  if hasattr(last, "additional_kwargs") and "function_call" in last.additional_kwargs:
357
  return "tools"
358
 
359
- # 尚未调用过任何工具?强制要求使用工具
360
  tool_msg_count = sum(1 for m in messages if isinstance(m, ToolMessage))
361
  if tool_msg_count == 0:
362
  return "force_tool"
363
 
364
- # 已经用过工具可以结束
 
 
 
 
365
  return "finish"
366
 
367
  def force_tool_node(state: AgentState) -> dict:
368
  new_msg = HumanMessage(
369
- content="You have not used any tools yet. Please use at least one tool to find or verify the answer. "
370
- "Search the web, download a file, or analyze an image if provided."
371
  )
372
  return {"messages": [new_msg]}
373
 
@@ -381,17 +443,15 @@ def finish_node(state: AgentState) -> dict:
381
  if "FINAL ANSWER:" in answer:
382
  answer = answer.split("FINAL ANSWER:")[-1].strip()
383
 
384
- # 若答案仍为空,尝试从历史消息中提取最后一条有内容的 AI 消息
385
  if not answer:
386
  for m in reversed(state["messages"]):
387
  if isinstance(m, AIMessage) and m.content.strip():
388
  answer = m.content.strip().split("\n")[-1].strip()
389
  break
390
 
391
- # 依然无答案时,输出原因
392
  if not answer:
393
- if state.get("tool_attempts", 0) >= 5:
394
- answer = "Unable to determine answer: reached maximum tool calls without conclusion."
395
  else:
396
  answer = "Unable to determine answer: insufficient information."
397
 
@@ -521,7 +581,8 @@ with gr.Blocks(title="GAIA Agent") as demo:
521
  gr.Markdown("""
522
  # 🤖 GAIA Level 1 Agent (LangGraph + Qwen)
523
  **模型:** Qwen3.5-35B-A3B | **API:** agicto.com
524
- 点击按钮获取题目,Agent 自动调用工具并回答,最后提交评分。
 
525
  """)
526
  gr.LoginButton()
527
  run_btn = gr.Button("🚀 运行评测并提交", variant="primary")
 
35
  # 进度监控器
36
  # =============================================================================
37
  class ProgressMonitor:
38
+ # ... 保持不变 ...
39
  def __init__(self):
40
  self.current = 0
41
  self.total = 0
 
82
  # Qwen LLM 封装
83
  # =============================================================================
84
  class QwenLLM:
85
+ # ... 保持不变 ...
86
  def __init__(self, model=QWEN_MODEL):
87
  self.model = model
88
  self.api_key = AGICTO_API_KEY
 
187
  return formatted
188
 
189
  # =============================================================================
190
+ # 工具定义
191
  # =============================================================================
192
  api_url_tasks = DEFAULT_API_URL
193
 
 
197
  base = base[:-3]
198
  return base
199
 
200
+ # --- 原有工具 ---
201
  @tool(description="搜索互联网信息,返回相关摘要。")
202
  def web_search(query: str) -> str:
203
  try:
 
315
  os.unlink(temp_path)
316
  return result
317
  else:
318
+ # 对于文本文件(包括 .py, .txt 等),直接返回文本内容
319
  return resp.text[:4000]
320
  except Exception as e:
321
  return f"文件下载失败: {e}"
322
 
323
+ # --- 新增:维基百科搜索工具 ---
324
@tool(description="在维基百科中搜索关键词,返回页面摘要或详细信息。")
def search_wikipedia(query: str) -> str:
    """Search English Wikipedia for ``query`` and return the top hit's intro text.

    Two requests are made against the MediaWiki Action API: ``opensearch`` to
    find the best-matching page title, then ``query`` with ``prop=extracts``
    to fetch the plain-text lead section of that page.

    Args:
        query: Free-text search keywords.

    Returns:
        A string starting with ``"Wikipedia - <title>:"`` followed by the
        extract (capped at 2000 characters), or a human-readable message when
        no page was found, the page has no extract, or the lookup failed.
        Exceptions are caught and reported as text rather than raised.
    """
    api_url = "https://en.wikipedia.org/w/api.php"
    # Wikimedia's User-Agent policy asks API clients to identify themselves
    # instead of relying on the HTTP library's default agent string.
    headers = {"User-Agent": "GAIA-Level1-Agent/1.0 (python-requests)"}
    try:
        # Step 1: resolve the query to a page title.
        search_resp = requests.get(
            api_url,
            params={
                "action": "opensearch",
                "search": query,
                "limit": 1,
                "format": "json",
            },
            headers=headers,
            timeout=10,
        )
        # Surface HTTP errors explicitly instead of trying to parse an error page.
        search_resp.raise_for_status()
        titles = search_resp.json()[1]  # second element is the list of titles
        if not titles:
            return "维基百科未找到相关页面。"
        title = titles[0]

        # Step 2: fetch the plain-text intro of that page.
        extract_resp = requests.get(
            api_url,
            params={
                "action": "query",
                "prop": "extracts",
                "exintro": 1,
                "explaintext": 1,
                # opensearch (format=json) may hand back a redirect title;
                # without this the redirect page itself yields no extract.
                "redirects": 1,
                "titles": title,
                "format": "json",
            },
            headers=headers,
            timeout=10,
        )
        extract_resp.raise_for_status()
        pages = extract_resp.json().get("query", {}).get("pages", {})
        for page_info in pages.values():
            extract = page_info.get("extract", "")
            if extract:
                # Prefer the resolved page title when a redirect was followed.
                shown_title = page_info.get("title", title)
                # Cap at 2000 characters to keep the tool output short.
                return f"Wikipedia - {shown_title}:\n{extract[:2000]}"
        return f"维基百科页面 '{title}' 未提供摘要。"
    except Exception as e:
        # Best-effort tool: report any failure as text, as the original did.
        return f"维基百科搜索失败: {e}"
365
+
366
  # =============================================================================
367
  # LangGraph 状态与节点
368
  # =============================================================================
 
370
  messages: Annotated[Sequence[BaseMessage], operator.add]
371
  final_answer: str
372
  task_id: str
373
+ tool_attempts: int
374
+
375
+ # 所有工具(包含新增的 search_wikipedia)
376
+ tools = [
377
+ search_wikipedia, # 优先搜索维基百科
378
+ web_search, # 备用网络搜索
379
+ web_scraper,
380
+ calculator,
381
+ analyze_image,
382
+ transcribe_audio,
383
+ get_youtube_transcript,
384
+ download_file_for_task
385
+ ]
386
 
 
387
  tool_node = ToolNode(tools)
388
  llm = QwenLLM()
389
  functions = [convert_to_openai_function(t) for t in tools]
 
392
  def agent_node(state: AgentState) -> dict:
393
  messages = state["messages"]
394
  task_id = state.get("task_id", "")
395
+ # 更新系统提示,强调维基百科、文件处理和 YouTube 工具的使用
396
+ sys_prompt = f"""You are a helpful assistant answering GAIA Level 1 questions.
397
+ IMPORTANT GUIDELINES:
398
+ - For fact-based questions, first try to find the answer using the `search_wikipedia` tool. Only if Wikipedia fails, use `web_search` or other tools.
399
+ - If the question provides a file (image, audio, or code), use `download_file_for_task` with the given task_id to retrieve it. The tool will automatically analyze images, transcribe audio, or return text for Python/text files.
400
+ - For YouTube links, use `get_youtube_transcript` to obtain the captions.
401
+ - When you have the final answer, output ONLY the answer string (a word, number, short phrase, or letter). Do NOT include any extra text, explanations, or "FINAL ANSWER:".
402
  Current task ID: {task_id}. If the question requires a file, use download_file_for_task with task_id="{task_id}"."""
403
  full = [SystemMessage(content=sys_prompt)] + list(messages)
404
  response = llm_with_tools.invoke(full)
 
408
  messages = state["messages"]
409
  last = messages[-1]
410
  tool_attempts = state.get("tool_attempts", 0)
411
+ MAX_TOOL_CALLS = 3 # 限制最多3次工具调用,避免循环
412
 
 
413
  if tool_attempts >= MAX_TOOL_CALLS:
414
  return "finish"
415
 
 
416
  if hasattr(last, "additional_kwargs") and "function_call" in last.additional_kwargs:
417
  return "tools"
418
 
 
419
  tool_msg_count = sum(1 for m in messages if isinstance(m, ToolMessage))
420
  if tool_msg_count == 0:
421
  return "force_tool"
422
 
423
+ # 如果 LLM 已经给出了一个简洁答案,结束
424
+ content = last.content
425
+ if "?" not in content and len(content.strip()) < 100:
426
+ return "finish"
427
+
428
  return "finish"
429
 
430
  def force_tool_node(state: AgentState) -> dict:
431
  new_msg = HumanMessage(
432
+ content="You haven't used any tool yet. Please use an appropriate tool (e.g., search_wikipedia, download_file_for_task) to find the answer."
 
433
  )
434
  return {"messages": [new_msg]}
435
 
 
443
  if "FINAL ANSWER:" in answer:
444
  answer = answer.split("FINAL ANSWER:")[-1].strip()
445
 
 
446
  if not answer:
447
  for m in reversed(state["messages"]):
448
  if isinstance(m, AIMessage) and m.content.strip():
449
  answer = m.content.strip().split("\n")[-1].strip()
450
  break
451
 
 
452
  if not answer:
453
+ if state.get("tool_attempts", 0) >= 3:
454
+ answer = "Unable to determine answer: max tool calls reached."
455
  else:
456
  answer = "Unable to determine answer: insufficient information."
457
 
 
581
  gr.Markdown("""
582
  # 🤖 GAIA Level 1 Agent (LangGraph + Qwen)
583
  **模型:** Qwen3.5-35B-A3B | **API:** agicto.com
584
+ 点击按钮获取题目,Agent 自动调用工具并回答,最后提交评分。
585
+ **新增维基百科搜索、文件处理(图片/音频/代码)、YouTube 字幕提取。**
586
  """)
587
  gr.LoginButton()
588
  run_btn = gr.Button("🚀 运行评测并提交", variant="primary")