Spaces:
Runtime error
Runtime error
| # agent.py | |
| import os | |
| import time | |
| import functools | |
| import pandas as pd | |
| from typing import Dict, Any, List | |
| import re | |
| from langgraph.graph import StateGraph, START, END, MessagesState | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_core.tools import tool | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.utilities.wikipedia import WikipediaAPIWrapper | |
| try: | |
| from langchain_experimental.tools.python.tool import PythonAstREPLTool | |
| except ImportError: | |
| from langchain.tools.python.tool import PythonAstREPLTool | |
# ---------------------------------------------------------------------
# LangSmith (optional)
# ---------------------------------------------------------------------
def configure_langsmith() -> bool:
    """Enable LangSmith tracing when ``LANGCHAIN_API_KEY`` is set.

    Forces ``LANGCHAIN_TRACING_V2`` to ``"true"``. ``LANGCHAIN_ENDPOINT`` and
    ``LANGCHAIN_PROJECT`` are only filled in when they are not already
    configured, so a deployment-specific endpoint/project is respected.

    Returns:
        True if tracing was enabled, False if no API key is configured.
    """
    if not os.getenv("LANGCHAIN_API_KEY"):
        return False
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    # Previously overwritten unconditionally, clobbering e.g. a regional endpoint.
    if not os.getenv("LANGCHAIN_ENDPOINT"):
        os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    os.environ.setdefault("LANGCHAIN_PROJECT", "gaia-agent")
    print("📱 LangSmith tracing enabled.")
    return True


configure_langsmith()
# ---------------------------------------------------------------------
# Error wrapper
# ---------------------------------------------------------------------
def error_guard(fn):
    """Wrap *fn* so that any exception is returned as an ``"ERROR: ..."`` string.

    ``functools.wraps`` keeps the wrapped function's ``__name__``,
    ``__doc__`` and signature, so a decorated function can still be
    introspected (e.g. when it is converted into a tool).

    Args:
        fn: The callable to guard.

    Returns:
        A callable with the same signature that returns ``fn``'s result, or
        ``f"ERROR: {exc}"`` if ``fn`` raised.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kw):
        try:
            return fn(*args, **kw)
        except Exception as e:  # deliberate catch-all: the error becomes the result
            return f"ERROR: {e}"
    return wrapper
# ---------------------------------------------------------------------
# Custom tools
# ---------------------------------------------------------------------
def parse_csv(file_path: str, query: str = "") -> str:
    """Summarize a CSV file or return the rows matching a pandas query.

    Args:
        file_path: Path to the CSV file.
        query: Optional ``DataFrame.query`` expression, e.g. ``"age > 30"``.
            When empty, only the row count and column names are returned.

    Returns:
        ``"Rows=<n>, Cols=[...]"`` when no query is given; otherwise the
        matching rows as a Markdown table (plain text if the optional
        ``tabulate`` package is not installed).
    """
    df = pd.read_csv(file_path)
    if not query:
        return f"Rows={len(df)}, Cols={list(df.columns)}"
    matches = df.query(query)
    try:
        return matches.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown depends on the optional `tabulate` package.
        return matches.to_string(index=False)
| def parse_excel(file_path: str, sheet: str | int | None = None, query: str = "") -> str: | |
| sheet_arg = int(sheet) if isinstance(sheet, str) and sheet.isdigit() else sheet or 0 | |
| df = pd.read_excel(file_path, sheet_name=sheet_arg) | |
| if not query: | |
| return f"Rows={len(df)}, Cols={list(df.columns)}" | |
| return df.query(query).to_markdown(index=False) | |
def _format_search_hits(hits) -> str:
    """Render Tavily search hits as text for the model.

    Each hit becomes a ``title – url`` line (just the url if there is no
    title), followed by its ``content`` snippet when present. A string
    result (e.g. an error message) is passed through unchanged.
    """
    if isinstance(hits, str):
        return hits or "No results."
    if not hits:
        return "No results."
    lines = []
    for hit in hits:
        title = hit.get("title")
        url = hit.get("url", "")
        lines.append(f"{title} – {url}" if title else url)
        snippet = (hit.get("content") or "").strip()
        if snippet:
            lines.append(snippet)
    return "\n".join(lines)


def web_search(query: str, max_results: int = 5) -> str:
    """Search the web with Tavily and return the top results.

    Args:
        query: The search query.
        max_results: Maximum number of results to request.

    Returns:
        One ``title – url`` line per hit followed by its content snippet,
        or ``"No results."`` if nothing was found.
    """
    api_key = os.getenv("TAVILY_API_KEY")
    # NOTE(review): confirm TavilySearchResults honours `api_key`; the Tavily
    # wrapper can also read TAVILY_API_KEY from the environment.
    hits = TavilySearchResults(max_results=max_results, api_key=api_key).invoke(query)
    return _format_search_hits(hits)
def wiki_search(query: str, sentences: int = 3) -> str:
    """Look up *query* on Wikipedia and return the beginning of the top match.

    Args:
        query: The search term.
        sentences: How many ``". "``-separated sentences of the result to keep.

    Returns:
        Up to ``sentences`` sentences, one per line, or
        ``"No article found."`` if the lookup returned an empty string.
    """
    wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=4000)
    res = wrapper.run(query)
    if not res:
        return "No article found."
    # split(". ") consumes the period at each boundary; put it back on re-join.
    return ".\n".join(res.split(". ")[:sentences])
# Python tool
# NOTE(review): PythonAstREPLTool evaluates Python source supplied in tool
# calls inside this process; confirm that is acceptable for this deployment.
python_repl = PythonAstREPLTool()

# ---------------------------------------------------------------------
# Gemini LLM
# ---------------------------------------------------------------------
# Chat model used by the planner node: temperature 0, replies capped at
# 2048 output tokens, API key taken from the GOOGLE_API_KEY env variable.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0,
    max_output_tokens=2048,
    google_api_key=os.getenv("GOOGLE_API_KEY"),
)
# System instructions, one instruction per line.
SYSTEM_PROMPT = SystemMessage(
    content="\n".join(
        [
            "You are a helpful assistant with access to tools.",
            "Use tools when appropriate using tool calls.",
            "If the answer is clear, return it directly without explanation.",
        ]
    )
)

# Tools available to the agent graph.
TOOLS = [
    web_search,
    wiki_search,
    parse_csv,
    parse_excel,
    python_repl,
]
# ---------------------------------------------------------------------
# LangGraph nodes
# ---------------------------------------------------------------------
def planner(state: MessagesState):
    """Ask Gemini for the next step given the conversation so far.

    The system prompt is prepended only for the model call (unless the state
    already contains a system message); it is not written back into the
    state. Only the new AI message is returned, and the ``MessagesState``
    reducer appends it to the history.
    """
    history = state["messages"]
    prompt = history
    if not any(m.type == "system" for m in history):
        prompt = [SYSTEM_PROMPT, *history]
    # Without bind_tools the model is never told about the tools, so it can
    # never emit tool_calls and the "tools" branch of the graph is dead.
    llm_with_tools = gemini_llm.bind_tools(TOOLS)
    response = llm_with_tools.invoke(prompt)
    # Returning the whole history (plus the shared SYSTEM_PROMPT object) made
    # the reducer append the system message after the user's message.
    return {"messages": [response]}
def should_end(state: MessagesState) -> bool:
    """Return True when the most recent message requests no tool calls."""
    latest_calls = getattr(state["messages"][-1], "tool_calls", None)
    return not latest_calls
# ---------------------------------------------------------------------
# Build graph
# ---------------------------------------------------------------------
def _route_after_planner(state: MessagesState) -> str:
    """Pick the next node after the planner: "END" when done, else "tools"."""
    if should_end(state):
        return "END"
    return "tools"


# planner -> (tools -> planner)* -> END
graph = StateGraph(MessagesState)
graph.add_node("planner", planner)
graph.add_node("tools", ToolNode(TOOLS))
graph.add_edge(START, "planner")
graph.add_conditional_edges(
    "planner",
    _route_after_planner,
    {"tools": "tools", "END": END},
)
graph.add_edge("tools", "planner")
agent_executor = graph.compile()
# ---------------------------------------------------------------------
# Public class
# ---------------------------------------------------------------------
class GaiaAgent:
    """Callable wrapper that answers one question with the compiled agent graph."""

    def __init__(self):
        print("✅ GaiaAgent initialised (LangGraph)")

    def __call__(self, task_id: str, question: str) -> str:
        """Run the agent on *question* and return the final answer text.

        Args:
            task_id: Task identifier; accepted for interface compatibility
                but not used here.
            question: The question to answer.

        Returns:
            The text of the last message in the final graph state, with
            surrounding whitespace stripped.
        """
        state = {"messages": [HumanMessage(content=question)]}
        final = agent_executor.invoke(state)
        return self._message_text(final["messages"][-1].content)

    @staticmethod
    def _message_text(content) -> str:
        """Flatten message content to stripped text.

        LangChain message content may be a plain string or a list of parts
        (strings or dicts such as ``{"type": "text", "text": ...}``); a bare
        ``.strip()`` would fail on the list form.
        """
        if isinstance(content, str):
            return content.strip()
        pieces = []
        for part in content or []:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type", "text") == "text":
                pieces.append(part.get("text", ""))
        return "".join(pieces).strip()