# Hugging Face Spaces log artifact: the Space reported "Runtime error".
# The original source fails at import time — it contains a malformed,
# duplicated system_prompt assignment (unbalanced parentheses) and
# references an undefined simple_calculator tool.
import os
import tempfile

import pandas as pd
import requests
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# Load environment configuration (Google API key for the Gemini LLM).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# === Tool definitions ===
# Base URL of the GAIA scoring service that serves task attachments.
GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"
def fetch_gaia_file(task_id: str) -> str:
    """
    Download the file attached to a GAIA task and return the local file path.

    Args:
        task_id: The GAIA task_id (string in the JSON payload).

    Returns:
        Absolute path to the downloaded temp file, or a string starting with
        "ERROR:" describing the failure (the agent's error-handling protocol
        expects this string convention instead of an exception).
    """
    try:
        url = f"{GAIA_BASE_URL}/files/{task_id}"
        response = requests.get(url, timeout=20)
        response.raise_for_status()

        # The server provides the real filename in a custom header; fall back
        # to parsing Content-Disposition ('attachment; filename="x.csv"').
        # Cut at ';' so trailing parameters are not glued onto the name.
        filename = (
            response.headers.get("x-filename")
            or response.headers.get("content-disposition", "")
            .split("filename=")[-1]
            .split(";")[0]
            .strip('" ')
        )
        # Sanitize: drop any directory components a hostile header could
        # inject (path traversal), and never allow an empty filename.
        filename = os.path.basename(filename)
        if not filename:
            filename = f"{task_id}.bin"

        tmp_path = os.path.join(tempfile.gettempdir(), filename)
        with open(tmp_path, "wb") as f:
            f.write(response.content)
        return tmp_path
    except Exception as e:
        return f"ERROR: could not download file for task {task_id}: {e}"
def _simple_table_query(df, query: str):
    """
    Tiny shared query engine for parse_csv / parse_excel.

    Supports queries of the form
        "sum of <column> [where <column> ==/!= '<value>']"

    Args:
        df: the loaded DataFrame.
        query: natural-language instruction.

    Returns:
        The formatted result string, or None when the query is not supported.
    """
    query_lc = query.lower()
    if "sum" not in query_lc:
        return None

    # Locate "where" on the lower-cased copy, but slice the ORIGINAL query so
    # column names and filter values keep their case — df columns and cell
    # values are case-sensitive (the old code lower-cased everything, which
    # raised KeyError on capitalized columns and never matched 'Drinks').
    where_idx = query_lc.find("where")
    head = query if where_idx < 0 else query[:where_idx]
    cond_part = "" if where_idx < 0 else query[where_idx + len("where"):].strip()

    # The target column must be mentioned BEFORE "where"; scanning the whole
    # query could pick the filter column instead of the aggregation column.
    head_lc = head.lower()
    target = next((c for c in df.columns if c.lower() in head_lc), None)
    if target is None:
        return None

    series = df[target]
    if cond_part:
        # Very simple '!=' / '==' parsing, e.g. "Category != 'Drinks'".
        for op in ("!=", "=="):
            if op in cond_part:
                key_raw, val = [x.strip().strip("'\"") for x in cond_part.split(op, 1)]
                # Resolve the filter column case-insensitively as well.
                key = next((c for c in df.columns if c.lower() == key_raw.lower()), key_raw)
                mask = (df[key] != val) if op == "!=" else (df[key] == val)
                series = df.loc[mask, target]
                break
    return str(round(series.sum(), 2))


def parse_csv(file_path: str, query: str = "") -> str:
    """
    Load a CSV file from `file_path` and optionally run a simple analysis query.

    Args:
        file_path: absolute path to a CSV file (from fetch_gaia_file)
        query: optional natural-language instruction, e.g.
               "sum of column Sales where Category != 'Drinks'"

    Returns:
        A concise string with the answer, a preview of the dataframe when no
        query is given, or an "ERROR …" string on failure.
    """
    try:
        df = pd.read_csv(file_path)
        if not query:
            preview = df.head(5).to_markdown(index=False)
            return f"CSV loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"
        result = _simple_table_query(df, query)
        if result is None:
            return "Query type not supported by parse_csv."
        return result
    except Exception as e:
        return f"ERROR parsing CSV: {e}"


def parse_excel(file_path: str, query: str = "") -> str:
    """
    Identical to parse_csv, but for XLS/XLSX files.

    Args:
        file_path: absolute path to an Excel file (from fetch_gaia_file)
        query: optional natural-language instruction (see parse_csv).

    Returns:
        A concise string with the answer, a preview, or an "ERROR …" string.
    """
    try:
        df = pd.read_excel(file_path)
        if not query:
            preview = df.head(5).to_markdown(index=False)
            return f"Excel loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"
        result = _simple_table_query(df, query)
        if result is None:
            return "Query type not supported by parse_excel."
        return result
    except Exception as e:
        return f"ERROR parsing Excel: {e}"
def transcribe_audio(file_path: str, language: str = "en") -> str:
    """
    Transcribe an audio file (MP3/WAV/etc.) using Faster-Whisper.

    Args:
        file_path: absolute path to an audio file (from fetch_gaia_file)
        language: ISO language code, default "en"

    Returns:
        Full transcription as plain text, or "ERROR …"
    """
    try:
        from faster_whisper import WhisperModel

        # The tiny model (~75 MB) is sufficient for short voice memos.
        speech_model = WhisperModel("tiny", device="cpu", compute_type="int8")
        segments, _ = speech_model.transcribe(file_path, language=language)

        pieces = []
        for segment in segments:
            pieces.append(segment.text.strip())
        transcript = " ".join(pieces).strip()

        if transcript:
            return transcript
        return "ERROR: transcription empty."
    except Exception as e:
        return f"ERROR: audio transcription failed – {e}"
def multiply(a: int, b: int) -> int:
    """Multiplies two numbers."""
    return a * b


def add(a: int, b: int) -> int:
    """Adds two numbers."""
    return a + b


def subtract(a: int, b: int) -> int:
    """Subtracts two numbers."""
    return a - b


def divide(a: int, b: int) -> float:
    """Divides two numbers.

    Raises:
        ValueError: when b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


def modulo(a: int, b: int) -> int:
    """Returns the remainder of dividing two numbers."""
    return a % b


def simple_calculator(operation: str, a: float, b: float) -> float:
    """
    Perform a basic arithmetic operation on two numbers.

    This is the tool advertised in the system prompt and registered in the
    tools list; the original file referenced it without ever defining it,
    which raised NameError at import time.

    Args:
        operation: one of "add", "subtract", "multiply", "divide", "modulo"
        a: first operand
        b: second operand

    Returns:
        The numeric result of applying `operation` to a and b.

    Raises:
        ValueError: for an unknown operation, or division by zero.
    """
    operations = {
        "add": add,
        "subtract": subtract,
        "multiply": multiply,
        "divide": divide,
        "modulo": modulo,
    }
    op = operations.get(operation.strip().lower())
    if op is None:
        raise ValueError(f"Unknown operation: {operation!r}")
    return op(a, b)
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return the result."""
    # Load at most two articles and concatenate their full text.
    documents = WikipediaLoader(query=query, load_max_docs=2).load()
    contents = [document.page_content for document in documents]
    return "\n\n".join(contents)
def arxiv_search(query: str) -> str:
    """Search Arxiv for academic papers about a query."""
    documents = ArxivLoader(query=query, load_max_docs=3).load()
    # Cap each paper at 1000 characters to keep the LLM context small.
    snippets = [document.page_content[:1000] for document in documents]
    return "\n\n".join(snippets)
def web_search(query: str) -> str:
    """Perform a DuckDuckGo web search."""
    # Up to five results, returned as one formatted string by the wrapper.
    search = DuckDuckGoSearchAPIWrapper(max_results=5)
    return search.run(query)
# === System prompt ===
# NOTE(fix): the original file assigned system_prompt twice in one nested,
# unbalanced expression (a merge artifact) — a SyntaxError that crashed the
# Space at import time. Rebuilt as a single clean assignment; the prompt
# text itself is unchanged.
system_prompt = SystemMessage(
    content=(
        "You are a focused, factual AI agent competing on the GAIA evaluation.\n"
        "\n"
        "GENERAL RULES\n"
        "-------------\n"
        "1. Always try to answer every question.\n"
        "2. If you are NOT 100 % certain, prefer using a TOOL.\n"
        "3. Never invent facts.\n"
        "\n"
        "TOOLS\n"
        "-----\n"
        "- fetch_gaia_file(task_id): downloads any attachment for the current task.\n"
        "- parse_csv(file_path, query): analyse CSV files.\n"
        "- parse_excel(file_path, query): analyse Excel files.\n"
        "- transcribe_audio(file_path): transcribe MP3 / WAV audio.\n"
        "- wiki_search(query): query English Wikipedia.\n"
        "- arxiv_search(query): query arXiv.\n"
        "- web_search(query): DuckDuckGo web search.\n"
        "- simple_calculator(operation,a,b): basic maths.\n"
        "\n"
        "WHEN TO USE WHICH TOOL\n"
        "----------------------\n"
        "・If the prompt or GAIA metadata mentions an *attached* file, FIRST call "
        "fetch_gaia_file with the given task_id. Then:\n"
        " • CSV → parse_csv\n"
        " • XLS/XLSX → parse_excel\n"
        " • MP3/WAV → transcribe_audio (language auto-detect is OK)\n"
        " • Image → (currently unsupported) answer that image processing is unavailable\n"
        "・If you need factual data (dates, numbers, names) → wiki_search or web_search.\n"
        "・If you need a scientific paper → arxiv_search.\n"
        "・If a numeric operation is required → simple_calculator.\n"
        "\n"
        "ERROR HANDLING\n"
        "--------------\n"
        "If a tool call returns a string that starts with \"ERROR:\", IMMEDIATELY think of "
        "an alternative strategy: retry with a different tool or modified parameters. "
        "Do not repeat the same failing call twice.\n"
        "\n"
        "OUTPUT FORMAT\n"
        "-------------\n"
        "Follow the exact format asked in the question (e.g. single word, CSV, comma-list). "
        "Do not add extra commentary.\n"
    )
)
# === LLM definition ===
# NOTE(fix): ChatGoogleGenerativeAI does not accept a `system_message`
# constructor argument — passing one fails model validation. The system
# prompt is instead prepended to the message list at invocation time
# (see the assistant node).
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,  # deterministic answers for the GAIA benchmark
    max_output_tokens=2048,
)
# === Register tools with the LLM ===
tools = [
    fetch_gaia_file,
    parse_csv,
    parse_excel,
    transcribe_audio,
    wiki_search,
    arxiv_search,
    web_search,
    # NOTE(review): simple_calculator is referenced here and in the system
    # prompt but is not defined anywhere in this file — unless it is
    # implemented (or this entry removed), this line raises NameError at
    # import time.
    simple_calculator,
]
llm_with_tools = llm.bind_tools(tools)
def safe_llm_invoke(messages):
    """
    Invoke the LLM, retrying exactly once on a tool error.

    If the first response contains "ERROR:", a corrective system message is
    appended (so the LLM knows the previous tool call failed) and the call
    is repeated one more time.

    Args:
        messages: list of LangChain messages. NOT mutated — the original
            version appended to the caller's list (the LangGraph state),
            leaking the corrective message into the conversation state.

    Returns:
        The LLM's response message; after max_attempts the last (still
        failing) response is returned as-is.
    """
    # Work on a copy so retrying never mutates the caller's state.
    working = list(messages)
    max_attempts = 2
    result = None
    for _ in range(max_attempts):
        result = llm_with_tools.invoke(working)
        content = result.content if hasattr(result, "content") else ""
        if "ERROR:" not in content:
            return result
        # Failure: add a corrective system note and try again.
        working.append(
            SystemMessage(
                content="Previous tool call returned an ERROR. "
                "Try a different tool or revise the input."
            )
        )
    # Still failing after max_attempts → hand the last result back anyway.
    return result
# === LangGraph nodes ===
def assistant(state: MessagesState):
    """
    Assistant node with built-in retry on tool errors.

    Prepends the module-level system prompt on each turn (the
    ChatGoogleGenerativeAI constructor cannot carry a system message),
    unless the conversation already starts with a SystemMessage.

    Args:
        state: LangGraph MessagesState holding the conversation so far.

    Returns:
        Partial state update with the LLM's response appended.
    """
    messages = state["messages"]
    if not messages or not isinstance(messages[0], SystemMessage):
        messages = [system_prompt] + list(messages)
    result_msg = safe_llm_invoke(messages)
    return {"messages": [result_msg]}
# === Build the LangGraph ===
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
# tools_condition routes to the "tools" node when the assistant emitted
# tool calls, and ends the run otherwise.
builder.add_conditional_edges("assistant", tools_condition)
# After tools execute, control returns to the assistant with the results.
builder.add_edge("tools", "assistant")

# === Agent executor ===
agent_executor = builder.compile()