Spaces:
Runtime error
Runtime error
# agent.py – Gemini 2.0 Flash · LangGraph · multiple tools
| # ========================================================= | |
| import os, asyncio, base64, mimetypes, tempfile, functools, json | |
| from typing import Dict, Any, List, Optional | |
| from langgraph.graph import START, StateGraph, MessagesState, END | |
| from langgraph.prebuilt import tools_condition, ToolNode | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_core.tools import tool | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| # --------------------------------------------------------------------- | |
| # Konstanten / API-Keys | |
| # --------------------------------------------------------------------- | |
# API credentials are read from the process environment once, at import time.
# Each constant is None when its environment variable is not set.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
TAVILY_KEY = os.environ.get("TAVILY_API_KEY")
| # --------------------------------------------------------------------- | |
# Error wrapper – meant to keep the wrapped function's docstring (via functools.wraps)
| # --------------------------------------------------------------------- | |
| import functools | |
def error_guard(fn):
    """Decorate ``fn`` so that exceptions are returned as ``"ERROR: ..."`` text.

    Args:
        fn: Any callable.

    Returns:
        A wrapper with the same name, docstring and signature metadata as
        ``fn``. It returns whatever ``fn`` returns, or the string
        ``f"ERROR: {exc}"`` when ``fn`` raises an ``Exception``.
    """
    # functools.wraps copies __name__/__doc__ onto the wrapper; without it the
    # decorated function would lose its docstring.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            return f"ERROR: {e}"

    return wrapper
| # --------------------------------------------------------------------- | |
| # 1) fetch_gaia_file – Datei vom GAIA-Server holen | |
| # --------------------------------------------------------------------- | |
# Base URL of the GAIA scoring service's file endpoint; the task id is appended.
GAIA_FILE_ENDPOINT = "https://agents-course-unit4-scoring.hf.space/file"


def fetch_gaia_file(task_id: str) -> str:
    """Download the attachment for the given GAIA task_id and return local path."""
    # Contract: returns the path of the file written into the system temp
    # directory, or an "ERROR: could not fetch file – ..." string on any failure
    # (HTTP error, timeout, missing dependency, unwritable temp dir).
    url = f"{GAIA_FILE_ENDPOINT}/{task_id}"
    try:
        # Imported here because this module has no file-level `import requests`;
        # without it every call failed with a NameError.
        import requests

        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # The server may suggest a filename via the "x-gaia-filename" header;
        # otherwise the task id itself is used.
        suggested = response.headers.get("x-gaia-filename", f"{task_id}")
        # Keep only the final path component so a header such as
        # "../../etc/x" cannot write outside the temp directory.
        file_name = (
            os.path.basename(suggested)
            or os.path.basename(str(task_id))
            or "gaia_attachment"
        )
        tmp_path = os.path.join(tempfile.gettempdir(), file_name)
        with open(tmp_path, "wb") as f:
            f.write(response.content)
        return tmp_path
    except Exception as e:
        return f"ERROR: could not fetch file – {e}"
| # --------------------------------------------------------------------- | |
| # 2) CSV-Parser | |
| # --------------------------------------------------------------------- | |
| import pandas as pd | |
def parse_csv(file_path: str, query: str = "") -> str:
    """Load a CSV file and answer a quick pandas query (optional)."""
    # Without a query: returns a one-line summary (row count, column count)
    # followed by the column names.
    # With a query: returns the rows selected by DataFrame.query(query),
    # rendered as a markdown table.
    # Every failure is reported as an "ERROR ..." string instead of raising,
    # matching the other tools in this module.
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        # Missing/unreadable/malformed file – previously this propagated.
        return f"ERROR: could not read CSV – {e}"
    if not query:
        return f"Loaded CSV with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
    try:
        result = df.query(query)
        return result.to_markdown()
    except Exception as e:
        return f"ERROR in pandas query: {e}"
| # --------------------------------------------------------------------- | |
| # 3) Excel-Parser | |
| # --------------------------------------------------------------------- | |
def parse_excel(file_path: str, query: str = "") -> str:
    """Load an Excel file (first sheet) and answer a pandas query (optional)."""
    # Without a query: returns a one-line summary (row count, column count)
    # followed by the column names.
    # With a query: returns the rows selected by DataFrame.query(query),
    # rendered as a markdown table.
    # Every failure is reported as an "ERROR ..." string instead of raising,
    # matching the other tools in this module.
    try:
        # pd.read_excel reads the first sheet when no sheet is specified.
        df = pd.read_excel(file_path)
    except Exception as e:
        # Missing/unreadable file or missing Excel engine – previously this
        # propagated to the caller.
        return f"ERROR: could not read Excel file – {e}"
    if not query:
        return f"Loaded Excel with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
    try:
        result = df.query(query)
        return result.to_markdown()
    except Exception as e:
        return f"ERROR in pandas query: {e}"
| # --------------------------------------------------------------------- | |
| # 4) Gemini-Audio-Transkription | |
| # --------------------------------------------------------------------- | |
def gemini_transcribe_audio(file_path: str, prompt: str = "Transcribe the audio.") -> str:
    """Use Gemini to transcribe an audio file."""
    # The audio bytes are sent inline as base64 next to the text prompt.
    with open(file_path, "rb") as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode()

    # Fall back to MP3 when the MIME type cannot be guessed from the file name.
    guessed_mime, _ = mimetypes.guess_type(file_path)
    mime_type = guessed_mime if guessed_mime else "audio/mpeg"

    text_part = {"type": "text", "text": prompt}
    media_part = {"type": "media", "data": encoded_audio, "mime_type": mime_type}
    request = HumanMessage(content=[text_part, media_part])

    # safe_invoke is a coroutine; run it to completion on a fresh event loop.
    reply = asyncio.run(safe_invoke([request]))
    if hasattr(reply, "content"):
        return reply.content
    return str(reply)
| # --------------------------------------------------------------------- | |
| # 5) Bild-Beschreibung | |
| # --------------------------------------------------------------------- | |
def describe_image(file_path: str, prompt: str = "Describe this image.") -> str:
    """Gemini vision – describe an image."""
    # Message content parts must be strings or dicts, so a PIL Image object
    # cannot be placed in the content list directly. The file is instead sent
    # inline as a base64 data URI, the same way the audio tool sends its payload.
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    # Fall back to PNG when the MIME type cannot be guessed from the file name.
    mime = mimetypes.guess_type(file_path)[0] or "image/png"
    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}},
        ]
    )
    # safe_invoke is a coroutine; run it to completion on a fresh event loop.
    resp = asyncio.run(safe_invoke([message]))
    return resp.content if hasattr(resp, "content") else str(resp)
| # --------------------------------------------------------------------- | |
| # 6) OCR-Tool | |
| # --------------------------------------------------------------------- | |
def ocr_image(file_path: str, lang: str = "eng") -> str:
    """Extract text from an image via pytesseract."""
    # Returns the recognised text with surrounding whitespace removed,
    # "No text found." when nothing was recognised, or "ERROR: ..." on any
    # failure (including the OCR packages not being installed).
    try:
        # Imported lazily so a missing OCR stack becomes an error string
        # instead of breaking module import.
        import pytesseract
        from PIL import Image

        picture = Image.open(file_path)
        recognised = pytesseract.image_to_string(picture, lang=lang).strip()
        if recognised:
            return recognised
        return "No text found."
    except Exception as err:
        return f"ERROR: {err}"
| # --------------------------------------------------------------------- | |
| # 7) Tavily-Web-Suche | |
| # --------------------------------------------------------------------- | |
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web via Tavily and return a markdown list of results."""
    # Returns one "title – url" entry per hit, separated by blank lines,
    # "No results." when nothing came back, or "ERROR: ..." on failure.
    try:
        hits = TavilySearchResults(max_results=max_results, api_key=TAVILY_KEY).invoke(query)
    except Exception as e:
        # Same error-string convention as the other tools in this module.
        return f"ERROR: {e}"
    if not hits:
        return "No results."
    if isinstance(hits, str):
        # NOTE(review): the Tavily tool appears to return a plain string
        # (the repr of the exception) when the API call fails – confirm for
        # the installed version. Indexing it like a list of dicts would raise.
        return f"ERROR: {hits}"
    entries = []
    for h in hits:
        if not isinstance(h, dict):
            entries.append(str(h))
            continue
        # Not every library version includes a "title" key in each hit.
        title = h.get("title") or "(no title)"
        url = h.get("url", "")
        entries.append(f"{title} – {url}")
    return "\n\n".join(entries)
| # --------------------------------------------------------------------- | |
| # 8) Kleiner Rechner | |
| # --------------------------------------------------------------------- | |
def simple_calculator(operation: str, a: float, b: float) -> float:
    """Basic maths (add, subtract, multiply, divide)."""
    # All four results are computed up front, then the requested one is
    # returned. Division by zero yields +infinity rather than raising.
    # An unrecognised operation name returns an "ERROR: ..." string.
    total = a + b
    difference = a - b
    product = a * b
    if b:
        quotient = a / b
    else:
        quotient = float("inf")

    if operation == "add":
        return total
    if operation == "subtract":
        return difference
    if operation == "multiply":
        return product
    if operation == "divide":
        return quotient
    return f"ERROR: unknown op '{operation}'"
| # --------------------------------------------------------------------- | |
| # LLM + Semaphore-Throttle (Gemini 2.0 Flash) | |
| # --------------------------------------------------------------------- | |
# Shared chat model with every tool function bound, so Gemini can emit tool
# calls for them. temperature=0 keeps answers as deterministic as the API allows.
# NOTE(review): the former `return_named_tools=True` argument was dropped – it
# is not a documented bind_tools option, and unknown keyword arguments get
# bound onto every model call. Confirm against the installed
# langchain-google-genai version.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,
    max_output_tokens=2048,
).bind_tools([
    fetch_gaia_file, parse_csv, parse_excel,
    gemini_transcribe_audio, describe_image, ocr_image,
    web_search, simple_calculator,
])
# At most 2 Gemini requests in flight at once (the original note aimed at
# staying below roughly 15 requests per minute).
LLM_SEMA = asyncio.Semaphore(2)
# safe_invoke – new version (replaces the earlier implementation)
async def safe_invoke(msgs, tries: int = 4):
    """Call Gemini under the semaphore, retrying rate-limit errors with back-off.

    Args:
        msgs: The message list passed unchanged to ``gemini_llm.ainvoke``.
        tries: Maximum number of attempts (must be >= 1).

    Returns:
        The model response of the first successful attempt.

    Raises:
        ValueError: If ``tries`` is smaller than 1.
        Exception: The model's own exception, re-raised when it is not a
            rate-limit error or when the last attempt has failed.
    """
    if tries < 1:
        # Previously the loop body never ran and None was returned silently.
        raise ValueError("tries must be at least 1")
    delay = 4
    for attempt in range(tries):
        try:
            # Hold a semaphore slot only for the duration of the request.
            async with LLM_SEMA:
                return await gemini_llm.ainvoke(msgs)
        except Exception as e:
            # Only errors whose text mentions 429 / RateLimit are retried.
            rate_limited = "429" in str(e) or "RateLimit" in str(e)
            if not rate_limited or attempt == tries - 1:
                raise
        # Back off outside the semaphore so a sleeping caller does not block
        # other callers' requests: 4 s, 8 s, 16 s, ...
        await asyncio.sleep(delay)
        delay *= 2
| # --------------------------------------------------------------------- | |
| # System-Prompt | |
| # --------------------------------------------------------------------- | |
# System message prepended by the assistant node when the conversation does
# not already start with one. It asks the model to end with a line of the form
# "FINAL ANSWER: ..." – the assistant node checks for that prefix
# (case-insensitively) to decide whether the run is finished.
system_prompt = SystemMessage(content="""
You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
""")
| # --------------------------------------------------------------------- | |
| # LangGraph – Assistant-Node | |
| # --------------------------------------------------------------------- | |
def assistant(state: MessagesState):
    """Assistant node: send the conversation to Gemini and report if it is done.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.

    Returns:
        A state update with the model reply under ``"messages"`` and a
        ``"should_end"`` flag. The flag is true when the reply text starts
        with "final answer" (case-insensitive) or the reply has no tool calls.
    """
    msgs = state["messages"]
    # Prepend the system prompt unless the history already starts with one.
    # The emptiness check avoids an IndexError on an empty history.
    if not msgs or msgs[0].type != "system":
        msgs = [system_prompt] + msgs
    # safe_invoke is a coroutine; run it to completion on a fresh event loop.
    resp = asyncio.run(safe_invoke(msgs))

    text = resp.content
    if not isinstance(text, str):
        # NOTE(review): message content may be a list of parts (strings or
        # dicts with a "text" entry) rather than a string – confirm against
        # the model wrapper. Join the textual parts so .lower() cannot fail.
        pieces = []
        for part in text or []:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict):
                pieces.append(str(part.get("text", "")))
        text = "".join(pieces)

    finished = text.lower().lstrip().startswith("final answer") or not resp.tool_calls
    return {"messages": [resp], "should_end": finished}
def route(state):
    """Pick the edge to follow after the assistant node.

    Args:
        state: Graph state mapping. An explicit ``"should_end"`` flag is used
            when present; otherwise the decision is derived from the last
            entry of ``state["messages"]``.

    Returns:
        ``"END"`` when the run is finished, else ``"tools"``.
    """
    should_end = state.get("should_end")
    if should_end is None:
        # The flag is missing when the state schema does not declare a
        # "should_end" key; indexing it directly raised KeyError. Apply the
        # same rule to the last message instead: finished when its text starts
        # with "final answer" or it requests no tool calls.
        last = state["messages"][-1]
        content = getattr(last, "content", "")
        text = content if isinstance(content, str) else ""
        should_end = (
            text.lower().lstrip().startswith("final answer")
            or not getattr(last, "tool_calls", None)
        )
    return "END" if should_end else "tools"
| # --------------------------------------------------------------------- | |
| # Tools-Liste & Graph | |
| # --------------------------------------------------------------------- | |
# Tool functions executed by the graph's "tools" node.
tools = [
    fetch_gaia_file, parse_csv, parse_excel,
    gemini_transcribe_audio, describe_image, ocr_image,
    web_search, simple_calculator,
]
# Graph layout: START -> assistant -> (tools -> assistant)* -> END
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
# After each model turn, route() chooses between running tools and finishing.
builder.add_conditional_edges("assistant", route, {"tools": "tools", "END": END})
# Feed tool results back to the model. Without this edge the "tools" node had
# no successor, so tool output never reached the assistant.
builder.add_edge("tools", "assistant")
# Compile
agent_executor = builder.compile()