Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import re | |
| from pathlib import Path | |
| from typing import Any | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langgraph.graph import END, StateGraph | |
| from gaia_agent.answer import normalize_answer | |
| from gaia_agent.config import settings | |
| from gaia_agent.llms import create_chat_model | |
| from gaia_agent.observability import traced_step | |
| from gaia_agent.prompts import ( | |
| GAIA_AGENT_SYSTEM_PROMPT, | |
| GAIA_QUERY_PROMPT, | |
| GAIA_VERIFY_PROMPT, | |
| ) | |
| from gaia_agent.state import GaiaState | |
| from gaia_agent.tools.files import ( | |
| download_task_file, | |
| read_text_file, | |
| summarize_spreadsheet, | |
| ) | |
| from gaia_agent.tools.media import image_data_url, transcribe_audio_file | |
| from gaia_agent.tools.python_repl import run_python_file | |
| from gaia_agent.tools.web import ( | |
| extract_urls, | |
| fetch_url, | |
| get_youtube_transcript, | |
| web_search, | |
| ) | |
# Cap (in characters) applied by `_trim_evidence` to the joined evidence text
# before it is placed into a solver or verifier prompt.
MAX_EVIDENCE_CHARS: int = 36_000
# Number of distinct search-result URLs `solve_web` will attempt to fetch
# (failed fetches count toward the limit) across all of its queries.
MAX_WEB_PAGES: int = 4
def build_graph(trace=None, llm=None):
    """Build and compile the GAIA question-answering workflow.

    Flow: ``ingest_task`` -> ``classify_task`` -> exactly one ``solve_<type>``
    node selected by the state's ``task_type`` -> ``verify_answer`` ->
    ``normalize_final_answer`` -> END. Every node's body runs inside
    ``traced_step(trace, <node name>, run)``.

    Args:
        trace: Forwarded unchanged to ``traced_step`` for every node.
        llm: Chat model used by all nodes; when falsy, ``create_chat_model()``
            supplies one.

    Returns:
        The result of ``StateGraph.compile()``.
    """
    graph = StateGraph(GaiaState)
    chat_model = llm or create_chat_model()

    def answer_from_evidence(question: str, evidence: list[str]) -> str:
        """Ask the chat model to answer ``question`` using the gathered evidence."""
        return _invoke_text(
            chat_model,
            GAIA_AGENT_SYSTEM_PROMPT,
            _question_with_evidence(question, evidence),
        )

    def ingest_task(state: GaiaState) -> dict[str, Any]:
        """Fetch the task's attached file (if any) and record the outcome as evidence."""

        def run() -> dict[str, Any]:
            # `or []` also tolerates keys that are present but set to None.
            evidence = list(state.get("evidence") or [])
            output: dict[str, Any] = {
                "evidence": evidence,
                "tool_outputs": list(state.get("tool_outputs") or []),
            }
            # Nothing to fetch: a local file is already known, or no attachment is named.
            if state.get("file_path") or not state.get("file_name"):
                return output
            try:
                path = download_task_file(
                    settings.gaia_api_url,
                    state["task_id"],
                    state.get("file_name"),
                )
            except Exception as exc:
                # Best effort: every solver copes with a missing file_path itself.
                evidence.append(f"Could not download attached file: {exc}")
                output["error"] = str(exc)
                return output
            output["file_path"] = str(path)
            evidence.append(f"Downloaded attached file to {path}.")
            return output

        return traced_step(trace, "ingest_task", run)

    def classify_task(state: GaiaState) -> dict[str, str]:
        """Choose a solver route: attachment extension first, then question heuristics."""

        def run() -> dict[str, str]:
            question = state["question"].lower()
            # file_name may be present but None; `or ""` avoids calling None.lower().
            file_name = (state.get("file_name") or "").lower()
            if file_name.endswith((".xlsx", ".xls", ".csv")):
                task_type = "spreadsheet"
            elif file_name.endswith(".py"):
                task_type = "python_file"
            elif file_name.endswith((".mp3", ".wav", ".m4a", ".ogg", ".flac")):
                task_type = "audio"
            elif file_name.endswith((".png", ".jpg", ".jpeg", ".webp")):
                task_type = "image"
            elif "youtube.com" in question or "youtu.be" in question:
                task_type = "youtube"
            elif _looks_like_computation(question):
                task_type = "compute"
            elif _looks_like_direct(question):
                task_type = "direct"
            else:
                task_type = "web"
            return {"task_type": task_type}

        return traced_step(trace, "classify_task", run)

    def solve_direct(state: GaiaState) -> dict[str, Any]:
        """Answer from the question text alone, without tools or evidence."""

        def run() -> dict[str, Any]:
            answer = _invoke_text(
                chat_model,
                GAIA_AGENT_SYSTEM_PROMPT,
                f"Question:\n{state['question']}",
            )
            return {"draft_answer": answer}

        return traced_step(trace, "solve_direct", run)

    def solve_compute(state: GaiaState) -> dict[str, Any]:
        """Answer a self-contained reasoning/computation question with an explicit instruction."""

        def run() -> dict[str, Any]:
            answer = _invoke_text(
                chat_model,
                GAIA_AGENT_SYSTEM_PROMPT,
                (
                    "Solve this question carefully. If it includes a table or "
                    "formal rule, compute the requested value exactly.\n\n"
                    f"Question:\n{state['question']}"
                ),
            )
            return {"draft_answer": answer}

        return traced_step(trace, "solve_compute", run)

    def solve_spreadsheet(state: GaiaState) -> dict[str, Any]:
        """Summarize the attached spreadsheet and answer from that summary."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence") or [])
            path = state.get("file_path")
            if not path:
                evidence.append("Attached spreadsheet is unavailable.")
            else:
                # Same best-effort policy as the audio/image/YouTube solvers: a
                # tool failure becomes evidence instead of aborting the graph.
                try:
                    summary = summarize_spreadsheet(path)
                except Exception as exc:
                    evidence.append(f"Spreadsheet summary failed: {exc}")
                else:
                    evidence.append(f"Spreadsheet summary:\n{summary}")
            answer = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_spreadsheet", run)

    def solve_python_file(state: GaiaState) -> dict[str, Any]:
        """Run the attached Python file; trust clean stdout, otherwise ask the model."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence") or [])
            path = state.get("file_path")
            if not path:
                evidence.append("Attached Python file is unavailable.")
                answer = answer_from_evidence(state["question"], evidence)
                return {"evidence": evidence, "draft_answer": answer}
            try:
                source = read_text_file(path, max_chars=30_000)
                result = run_python_file(path)
            except Exception as exc:
                # Best effort, consistent with the other file-based solvers.
                evidence.append(f"Python file could not be read or executed: {exc}")
                answer = answer_from_evidence(state["question"], evidence)
                return {"evidence": evidence, "draft_answer": answer}
            stdout = str(result.get("stdout", ""))
            stderr = str(result.get("stderr", ""))
            evidence.append(f"Attached Python source:\n{source}")
            evidence.append(
                "Python execution result:\n"
                f"exit_code={result.get('exit_code')}\n"
                f"stdout:\n{stdout}\n"
                f"stderr:\n{stderr}"
            )
            clean_stdout = stdout.strip()
            if clean_stdout and not stderr.strip():
                # Output with no stderr: take the last printed line as the answer and
                # pre-fill verified_answer so verify_answer passes it through unchanged.
                answer = clean_stdout.splitlines()[-1]
                return {
                    "evidence": evidence,
                    "draft_answer": answer,
                    "verified_answer": answer,
                }
            answer = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_python_file", run)

    def solve_audio(state: GaiaState) -> dict[str, Any]:
        """Transcribe the attached audio file and answer from the transcript."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence") or [])
            path = state.get("file_path")
            if not path:
                evidence.append("Attached audio file is unavailable.")
            else:
                try:
                    transcript = transcribe_audio_file(path)
                except Exception as exc:
                    evidence.append(f"Audio transcription failed: {exc}")
                else:
                    evidence.append(f"Audio transcript:\n{transcript}")
            answer = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_audio", run)

    def solve_image(state: GaiaState) -> dict[str, Any]:
        """Answer by sending the attached image to the chat model; fall back to text."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence") or [])
            path = state.get("file_path")
            if not path:
                evidence.append("Attached image file is unavailable.")
                answer = answer_from_evidence(state["question"], evidence)
                return {"evidence": evidence, "draft_answer": answer}
            try:
                answer = _invoke_image(chat_model, state["question"], path)
            except Exception as exc:
                evidence.append(f"Image analysis failed: {exc}")
                answer = answer_from_evidence(state["question"], evidence)
            else:
                # NOTE(review): verify_answer only sees this text note, not the
                # image itself, when it re-checks the draft.
                evidence.append(f"Image analyzed from {path}.")
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_image", run)

    def solve_youtube(state: GaiaState) -> dict[str, Any]:
        """Fetch transcripts for YouTube links in the question and answer from them."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence") or [])
            for url in extract_urls(state["question"]):
                # classify_task matched on the lowercased question, so match the
                # URL case-insensitively too (e.g. "https://YouTube.com/...").
                lowered = url.lower()
                if "youtube.com" not in lowered and "youtu.be" not in lowered:
                    continue
                try:
                    transcript = get_youtube_transcript(url)
                except Exception as exc:
                    evidence.append(f"YouTube transcript failed for {url}: {exc}")
                else:
                    evidence.append(f"YouTube transcript for {url}:\n{transcript}")
            answer = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_youtube", run)

    def solve_web(state: GaiaState) -> dict[str, Any]:
        """Search the web, fetch up to MAX_WEB_PAGES result pages, and answer from them."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence") or [])
            queries = _build_search_queries(chat_model, state["question"])
            seen_urls: set[str] = set()
            for query in queries:
                try:
                    results = web_search(query, max_results=5)
                except Exception as exc:
                    evidence.append(f"Search failed for {query!r}: {exc}")
                    continue
                if results:
                    evidence.append(
                        "Search results for "
                        f"{query!r}:\n"
                        + "\n".join(f"- {item.title}: {item.url}" for item in results)
                    )
                for result in results:
                    # The page budget is shared across all queries; failed
                    # fetches count toward it as well.
                    if len(seen_urls) >= MAX_WEB_PAGES:
                        break
                    if result.url in seen_urls:
                        continue
                    seen_urls.add(result.url)
                    try:
                        page_text = fetch_url(result.url, max_chars=12_000)
                    except Exception as exc:
                        evidence.append(f"Fetch failed for {result.url}: {exc}")
                        continue
                    evidence.append(f"Page: {result.title}\nURL: {result.url}\n{page_text}")
            answer = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_web", run)

    def verify_answer(state: GaiaState) -> dict[str, str]:
        """Have the model re-check the draft against the evidence, unless already verified."""

        def run() -> dict[str, str]:
            if state.get("verified_answer"):
                return {"verified_answer": state["verified_answer"]}
            evidence = _trim_evidence(state.get("evidence") or [])
            verified = _invoke_text(
                chat_model,
                GAIA_VERIFY_PROMPT,
                (
                    f"Question:\n{state['question']}\n\n"
                    f"Evidence:\n{evidence}\n\n"
                    f"Draft answer:\n{state.get('draft_answer') or ''}"
                ),
            )
            return {"verified_answer": verified}

        return traced_step(trace, "verify_answer", run)

    def normalize_final_answer(state: GaiaState) -> dict[str, str]:
        """Normalize the verified answer (or the draft when verification is empty)."""

        def run() -> dict[str, str]:
            # `or ""` keeps a None draft from reaching normalize_answer.
            answer = state.get("verified_answer") or state.get("draft_answer") or ""
            return {"final_answer": normalize_answer(answer)}

        return traced_step(trace, "normalize_final_answer", run)

    # task_type -> solver; each solver is registered as node "solve_<task_type>".
    solvers = {
        "direct": solve_direct,
        "compute": solve_compute,
        "spreadsheet": solve_spreadsheet,
        "python_file": solve_python_file,
        "audio": solve_audio,
        "image": solve_image,
        "youtube": solve_youtube,
        "web": solve_web,
    }

    graph.add_node("ingest_task", ingest_task)
    graph.add_node("classify_task", classify_task)
    for task_type, solver in solvers.items():
        graph.add_node(f"solve_{task_type}", solver)
    graph.add_node("verify_answer", verify_answer)
    graph.add_node("normalize_final_answer", normalize_final_answer)

    graph.set_entry_point("ingest_task")
    graph.add_edge("ingest_task", "classify_task")
    graph.add_conditional_edges(
        "classify_task",
        lambda state: state.get("task_type") or "web",
        {task_type: f"solve_{task_type}" for task_type in solvers},
    )
    for task_type in solvers:
        graph.add_edge(f"solve_{task_type}", "verify_answer")
    graph.add_edge("verify_answer", "normalize_final_answer")
    graph.add_edge("normalize_final_answer", END)
    return graph.compile()
def _response_text(response: Any) -> str:
    """Return the text carried by a chat-model response message.

    LangChain message ``content`` is either a plain string or a list of
    content parts (strings, or dicts such as ``{"type": "text", "text": ...}``).
    For the list form the text parts are concatenated and non-text parts
    (images, tool calls, ...) are skipped, instead of stringifying the list
    into a Python repr that would then be treated as the answer.
    """
    content = response.content
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces: list[str] = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                pieces.append(str(part.get("text", "")))
        return "".join(pieces)
    return str(content)


def _invoke_text(chat_model, system_prompt: str, user_prompt: str) -> str:
    """Send one system + one user message to ``chat_model`` and return the reply text.

    Messages are passed as ``("system", ...)`` / ``("user", ...)`` tuples to
    ``chat_model.invoke``.
    """
    response = chat_model.invoke(
        [
            ("system", system_prompt),
            ("user", user_prompt),
        ]
    )
    return _response_text(response)


def _invoke_image(chat_model, question: str, path: str | Path) -> str:
    """Ask ``chat_model`` the question about the image at ``path``; return the reply text.

    The image is embedded as an ``image_url`` content part built by
    ``image_data_url(path)``, alongside the question as a text part, under
    ``GAIA_AGENT_SYSTEM_PROMPT``.
    """
    response = chat_model.invoke(
        [
            SystemMessage(content=GAIA_AGENT_SYSTEM_PROMPT),
            HumanMessage(
                content=[
                    {"type": "text", "text": question},
                    {
                        "type": "image_url",
                        "image_url": {"url": image_data_url(path)},
                    },
                ]
            ),
        ]
    )
    return _response_text(response)
def _build_search_queries(chat_model, question: str, max_queries: int = 3) -> list[str]:
    """Ask the chat model for web-search queries and clean them into a short list.

    Each line of the model's reply becomes a candidate after removing a leading
    list marker (bullets such as "-", "*", "•", or enumerations such as "1."
    and "2)") and surrounding quotes. Candidates of 3 characters or fewer and
    duplicates are dropped. The original question is appended as a fallback,
    so it only survives the final cut when fewer than ``max_queries`` model
    queries remain.

    Args:
        chat_model: Chat model forwarded to ``_invoke_text``.
        question: The task question.
        max_queries: Maximum number of queries returned (default 3).

    Returns:
        Up to ``max_queries`` queries, model-generated ones first.
    """
    raw_queries = _invoke_text(
        chat_model,
        GAIA_QUERY_PROMPT,
        f"Question:\n{question}",
    )
    queries: list[str] = []
    for line in raw_queries.splitlines():
        # Strip only genuine list markers. Bare digits are not a marker, so a
        # query that starts with a year or number ("1984 novel author") and a
        # decimal ("3.5 billion ...") keep their leading number.
        query = re.sub(r"^\s*(?:[-*•]+\s+|\d+[.)](?!\d)\s*)", "", line)
        query = query.strip().strip("\"'")
        if len(query) > 3 and query not in queries:
            queries.append(query)
    if question not in queries:
        queries.append(question)
    return queries[:max_queries]
def _question_with_evidence(question: str, evidence: list[str]) -> str:
    """Build a solver prompt: the question, a blank line, then the trimmed evidence.

    The evidence list is joined and length-capped by ``_trim_evidence``.
    """
    evidence_block = _trim_evidence(evidence)
    sections = (f"Question:\n{question}", f"Evidence:\n{evidence_block}")
    return "\n\n".join(sections)
def _trim_evidence(evidence: list[str], max_chars: int | None = None) -> str:
    """Join evidence entries with ``---`` separators and cap the result's length.

    Args:
        evidence: Evidence strings in the order they were gathered.
        max_chars: Maximum number of joined characters to keep. ``None`` (the
            default) uses the module-level ``MAX_EVIDENCE_CHARS``.

    Returns:
        The joined text if it fits, otherwise its first ``max_chars``
        characters followed by a ``[trimmed after N chars]`` marker.
    """
    limit = MAX_EVIDENCE_CHARS if max_chars is None else max_chars
    text = "\n\n---\n\n".join(evidence)
    if len(text) <= limit:
        return text
    # Keep the head: entries are appended in gathering order, so the most
    # recently gathered evidence is what gets cut.
    return f"{text[:limit]}\n\n[trimmed after {limit} chars]"
def _looks_like_computation(question: str) -> bool:
    """Return True if *question* contains a phrase typical of self-contained computation tasks.

    Matching is a case-sensitive substring test; ``classify_task`` passes the
    question already lowercased.
    """
    for phrase in (
        "given this table",
        "provide the subset",
        "counter-examples",
        "not commutative",
        "calculate",
        "numeric output",
    ):
        if phrase in question:
            return True
    return False
def _looks_like_direct(question: str) -> bool:
    """Return True if *question* should be answered directly, without tools.

    A question qualifies when it has at most 8 spaces, looks like reversed text
    (see ``_looks_reversed``), or contains one of a few known puzzle phrases
    (case-sensitive substring match).
    """
    is_short = question.count(" ") <= 8
    # `or` short-circuits, so _looks_reversed only runs for longer questions.
    if is_short or _looks_reversed(question):
        return True
    puzzle_phrases = ("grocery list", "categorizing things", "write the opposite")
    return any(phrase in question for phrase in puzzle_phrases)
def _looks_reversed(question: str) -> bool:
    """Return True if *question* appears to be written backwards.

    Considers runs of 4+ lowercase ASCII letters and requires at least three
    such words, of which at least two distinct ones must be known reversed
    English words ("answer", "word", "write", "sentence", "understand").
    """
    candidates = re.findall(r"[a-z]{4,}", question)
    if len(candidates) < 3:
        return False
    known_reversed = {"rewsna", "drow", "etirw", "ecnetnes", "dnatsrednu"}
    distinct_hits = {word for word in candidates if word in known_reversed}
    return len(distinct_hits) >= 2