Spaces:
Sleeping
Sleeping
import base64
import datetime
import json
import mimetypes
import os
import subprocess
import sys
import tempfile
import time
import traceback
from typing import TypedDict, List, Dict, Any, Optional, Union

import requests
from langchain_core import tools
from langgraph.graph import StateGraph, START, END
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader
from ddgs import DDGS
from dotenv import load_dotenv
from groq import Groq
from langchain_groq import ChatGroq
from langchain_community.document_loaders.image import UnstructuredImageLoader
from langchain_community.document_loaders import WebBaseLoader
from langchain_google_genai import ChatGoogleGenerativeAI
| try: | |
| import cv2 | |
| except ImportError: | |
| cv2 = None | |
| # os.environ["USER_AGENT"] = "gaia-agent/1.0" | |
# Cached Whisper instance: populated on first use so the heavy model load
# (and the `whisper` import itself) only happens if audio is actually analyzed.
whisper_model = None

def get_whisper():
    """Return the shared Whisper speech-to-text model, loading it lazily.

    The "base" checkpoint is the smallest/fastest option; it is loaded
    exactly once and cached in the module-level ``whisper_model``.
    """
    global whisper_model
    if whisper_model is not None:
        return whisper_model
    import whisper  # deferred: optional, heavy dependency
    whisper_model = whisper.load_model("base")
    return whisper_model
# Load .env values into the process environment; override=True lets the .env
# file win over variables already exported in the shell.
load_dotenv(override=True)
# Base Hugging Face LLM used by the chat wrapper (kept for reference; unused).
# base_llm = HuggingFaceEndpoint(
#     repo_id="openai/gpt-oss-20b:hyperbolic",
#     # deepseek-ai/DeepSeek-OCR:novita
#     task="text-generation",
#     temperature=0.0,
#     huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
# )
# Model initializations moved to smart_invoke for lazy loading to prevent import errors if keys are missing.
def smart_invoke(msgs, use_tools=False, start_tier=0):
    """
    Invoke an LLM with tiered fallback: OpenRouter free models -> Gemini -> Groq.

    Retries the next model/tier on recoverable errors (429 rate limit, 402 out
    of credits, 404 model not found, 5xx overload). Any other error is
    re-raised immediately.

    Args:
        msgs: List of LangChain messages to send.
        use_tools: When True, bind the module-level `tools` list to the model.
        start_tier: Index into the tier table to start from (lets callers skip
            tiers that already failed earlier in the conversation).

    Returns:
        (response, tier_index) where tier_index is the tier that succeeded, or
        (None, 0) when no API key is configured for any tier.

    Raises:
        Exception: the last recoverable error if every tier failed, or the
            first non-recoverable error encountered.
    """
    # Gemini model names verified via list_models (REST API); tried in order
    # if the primary Gemini name returns 404.
    gemini_alternatives = ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-flash-latest", "gemini-pro-latest"]
    tiers_config = [
        {"name": "Qwen3-Next-80B", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "qwen/qwen3-next-80b-a3b-instruct:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "Gemma-3-27B", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "google/gemma-3-27b-it:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "NVIDIA-Nemotron-Super", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "nvidia/nemotron-3-super-120b-a12b:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "OpenRouter-FreeRouter", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "openrouter/free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "DeepSeek-R1", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "deepseek/deepseek-r1:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "Gemini-Flash", "key": "GOOGLE_API_KEY", "provider": "google", "model_name": "gemini-2.0-flash", "alternatives": gemini_alternatives},
        {"name": "Groq", "key": "GROQ_API_KEY", "provider": "groq", "model_name": "llama-3.3-70b-versatile"},
    ]

    def create_model_instance(m_name, provider, api_key, b_url=None):
        # Imports are deferred so a missing optional package only matters for
        # the provider actually being used.
        if provider == "openai":
            from langchain_openai import ChatOpenAI
            return ChatOpenAI(model=m_name, openai_api_key=api_key, openai_api_base=b_url, temperature=0)
        if provider == "google":
            from langchain_google_genai import ChatGoogleGenerativeAI
            return ChatGoogleGenerativeAI(model=m_name, temperature=0)
        if provider == "groq":
            from langchain_groq import ChatGroq
            return ChatGroq(model=m_name, temperature=0, max_retries=2)
        return None

    # Error substrings that mean "try the next model/tier" rather than crash.
    recoverable = ["rate_limit", "429", "500", "503", "overloaded", "not_found", "404", "402", "credits", "decommissioned", "invalid_request_error"]
    last_exception = None
    for i in range(start_tier, len(tiers_config)):
        tier = tiers_config[i]
        api_key = os.getenv(tier["key"])
        if not api_key:
            continue  # tier not configured; skip silently
        candidate_names = [tier["model_name"]] + tier.get("alternatives", [])
        models_to_try = []
        for m_name in candidate_names:
            model = create_model_instance(m_name, tier["provider"], api_key, tier.get("base_url"))
            if use_tools:
                model = model.bind_tools(tools)
            models_to_try.append(model)
        for current_model in models_to_try:
            try:
                model_name = getattr(current_model, "model", tier["name"])
                print(f"--- Calling {tier['name']} ({model_name}) ---")
                return current_model.invoke(msgs), i
            except Exception as e:
                err_str = str(e).lower()
                # Bug fix: compare by identity, not `!=` — LangChain models are
                # pydantic objects whose `==` does field comparison.
                is_last = current_model is models_to_try[-1]
                # On a 404 with alternatives remaining, try the next name.
                if any(x in err_str for x in ["not_found", "404"]) and not is_last:
                    print(f"--- {tier['name']} model {model_name} not found. Trying alternative... ---")
                    continue
                # Other fallback triggers: try the next model, then next tier.
                if any(x in err_str for x in recoverable):
                    print(f"--- {tier['name']} Error: {e}. Trying next model/tier... ---")
                    last_exception = e
                    if not is_last:
                        continue
                    break  # move to next tier
                raise  # non-recoverable: surface immediately with its traceback
    if last_exception:
        print("CRITICAL: All fallback tiers failed.")
        raise last_exception
    return None, 0
def web_search(keywords: str) -> str:
    """
    Search the web with DuckDuckGo and return the top 5 results.
    Use cases:
    - Identify personal information
    - Information search
    - Finding organisation information
    - Obtain the latest news
    Args:
        keywords: keywords used to search the web
    Returns:
        Search result (Header + body + url)
    """
    max_retries = 3
    for attempt in range(1, max_retries + 1):
        try:
            with DDGS() as ddgs:
                hits = ddgs.text(keywords, max_results=5)
                chunks = [
                    f"Results: {hit['title']}\n{hit['body']}\n{hit['href']}\n\n"
                    for hit in hits
                ]
                return "".join(chunks)
        except Exception as e:
            if attempt == max_retries:
                return f"Search failed after {max_retries} attempts: {str(e)}"
            # Exponential backoff: 1s, 2s, ... before the next attempt.
            time.sleep(2 ** (attempt - 1))
def wiki_search(query: str) -> str:
    """
    Search Wikipedia for a query and return up to 3 results.
    Use cases:
        When the question requires the use of information from wikipedia
    Args:
        query: The search query
    """
    docs = WikipediaLoader(query=query, load_max_docs=3, doc_content_chars_max=15000).load()
    if not docs:
        return "No Wikipedia results found."
    rendered = []
    for doc in docs:
        title = doc.metadata.get("title", "Unknown Title")
        rendered.append(
            f'<Document source="{doc.metadata["source"]}" page="{title}"/>\n{doc.page_content}\n</Document>'
        )
    # Separate documents so the model can tell them apart.
    return "\n\n---\n\n".join(rendered)
def get_vision_models():
    """Return configured vision-capable chat models, in order of preference.

    Each entry is ``{"name": <display name>, "model": <LangChain chat model>}``.
    Configs whose API-key environment variable is unset are skipped, so the
    result may be empty when no keys are configured.
    """
    configs = [
        {"name": "OpenRouter-Qwen3-VL", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "qwen/qwen3-vl-235b-thinking:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "NVIDIA-Nemotron-VL", "key": "NVIDIA_API_KEY", "provider": "openai", "model_name": "nvidia/nemotron-nano-2-vl:free", "base_url": "https://integrate.api.nvidia.com/v1"},
        {"name": "OpenRouter-Gemma-3-27b-it", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "google/gemma-3-27b-it:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "Google-Gemini-2.0-Flash", "key": "GOOGLE_API_KEY", "provider": "google", "model_name": "gemini-2.0-flash"},
        {"name": "Google-Gemini-Flash-Latest", "key": "GOOGLE_API_KEY", "provider": "google", "model_name": "gemini-flash-latest"},
    ]
    models = []
    for cfg in configs:
        api_key = os.getenv(cfg["key"])
        if not api_key:
            continue
        provider = cfg["provider"]
        if provider == "openai":
            from langchain_openai import ChatOpenAI
            m = ChatOpenAI(model=cfg["model_name"], openai_api_key=api_key, openai_api_base=cfg.get("base_url"), temperature=0)
        elif provider == "google":
            from langchain_google_genai import ChatGoogleGenerativeAI
            m = ChatGoogleGenerativeAI(model=cfg["model_name"], temperature=0)
        elif provider == "groq":
            from langchain_groq import ChatGroq
            m = ChatGroq(model=cfg["model_name"], temperature=0)
        else:
            # Bug fix: an unrecognized provider previously appended a stale `m`
            # from the prior iteration (or raised NameError on the first one).
            continue
        models.append({"name": cfg["name"], "model": m})
    return models
def analyze_image(image_path: str, question: str) -> str:
    """
    EXTERNAL SIGHT API: Sends an image path to a Vision Model to answer a specific question.
    YOU MUST CALL THIS TOOL ANY TIME an image (.png, .jpg, .jpeg) is attached to the prompt.
    NEVER claim you cannot see images. Use this tool instead.
    Args:
        image_path: The local path or URL to the image file.
        question: Specific question describing what you want the vision model to look for.
    """
    try:
        if image_path.startswith(("http://", "https://")):
            # Bug fix: the docstring promises URL support, but URLs were
            # previously rejected by the local-file existence check. Pass
            # them through directly; the vision APIs fetch image URLs.
            image_url = image_path
        else:
            if not os.path.exists(image_path):
                return f"Error: Image file not found at {image_path}"
            # Local file: embed it as a base64 data URL.
            with open(image_path, "rb") as image_file:
                encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
            # Bug fix: the data URL previously hard-coded image/jpeg even for
            # PNGs; guess the real MIME type and fall back to jpeg.
            mime_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"
            image_url = f"data:{mime_type};base64,{encoded_image}"
        message = HumanMessage(
            content=[
                {"type": "text", "text": question},
                {"type": "image_url", "image_url": {"url": image_url}},
            ]
        )
        vision_models = get_vision_models()
        if not vision_models:
            return "Error: No vision models configured (missing API keys)."
        last_err = None
        # Try each configured vision model until one answers.
        for item in vision_models:
            try:
                m_name = getattr(item['model'], 'model', 'unknown')
                print(f"--- Calling Vision Model: {item['name']} ({m_name}) ---")
                response = item['model'].invoke([message])
                return extract_text_from_content(response.content)
            except Exception as e:
                print(f"Vision Model {item['name']} failed.")
                traceback.print_exc()
                last_err = e
        return f"Error analyzing image: All vision models failed. Last error: {str(last_err)}"
    except Exception as e:
        traceback.print_exc()
        return f"Error reading/processing image: {str(e)}"
def analyze_audio(audio_path: str, question: str) -> str:
    """
    Transcribes an audio file (.mp3, .wav, .m4a) to answer questions about what is spoken.
    Args:
        audio_path: The local path to the audio file.
        question: The specific question to ask (the full transcript is
            returned so the calling model can answer it itself).
    """
    try:
        model = get_whisper()  # shared, lazily-loaded Whisper model
        result = model.transcribe(audio_path)
        transcript = result["text"]
        return f"Audio Transcript:\n{transcript}"
    except Exception as e:
        # Bug fix: corrected the grammar of the user-facing tip
        # ("You requires 'ffmpeg'" -> proper English).
        return f"Error analyzing audio: {str(e)}. Tip: 'ffmpeg' must be installed on your system."
def analyze_video(video_path: str, question: str) -> str:
    """
    EXTERNAL SIGHT/HEARING API: Sends a video file to an external Vision/Audio model.
    YOU MUST CALL THIS TOOL ANY TIME a video (.mp4, .avi) is attached to the prompt.
    NEVER claim you cannot analyze videos. Use this tool instead.
    Args:
        video_path: The local path to the video file (or an http(s) URL,
            downloaded via yt-dlp).
        question: Specific question describing what you want to extract from the video.
    """
    if cv2 is None:
        return "Error: cv2 is not installed. Please install opencv-python."
    temp_dir = tempfile.gettempdir()
    downloaded_video = None
    cap = None
    try:
        # Check if video_path is a URL
        if video_path.startswith("http"):
            print(f"Downloading video from URL: {video_path}")
            downloaded_video = os.path.join(temp_dir, f"video_{int(time.time())}.mp4")
            try:
                # Use yt-dlp to download the video
                # Note: --ffmpeg-location could be used if we knew where it was, but we assume it's in path or missing
                subprocess.run(["yt-dlp", "-f", "best[ext=mp4]/mp4", "-o", downloaded_video, video_path], check=True, timeout=120)
                video_path = downloaded_video
            except Exception as e:
                return f"Error downloading video from URL: {str(e)}. Tip: Check if yt-dlp is installed and the URL is valid."
        # 1. Extract frames evenly spaced throughout the video
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames == 0:
            # Bug fix: this early return previously leaked the VideoCapture;
            # cleanup now happens in `finally`.
            return "Error: Could not read video frames."
        # Take 5 frames as a summary
        frame_indices = [int(i * total_frames / 5) for i in range(5)]
        extracted_descriptions = []
        vision_models = get_vision_models()
        for idx_num, frame_idx in enumerate(frame_indices):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                # Convert frame to base64
                _, buffer = cv2.imencode('.jpg', frame)
                encoded_image = base64.b64encode(buffer).decode('utf-8')
                # Ask a vision model to describe the frame (with fallback)
                msg = HumanMessage(
                    content=[
                        {"type": "text", "text": f"Describe what is happening in this video frame concisely. Focus on aspects related to: {question}"},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
                    ]
                )
                desc = "No description available."
                for item in vision_models:
                    try:
                        print(f"--- Calling Vision Model for Frame {idx_num+1}: {item['name']} ---")
                        desc = item['model'].invoke([msg]).content
                        break
                    except Exception as e:
                        print(f"Vision Model {item['name']} failed for frame: {e}")
                        continue
                extracted_descriptions.append(f"Frame {idx_num + 1}: {desc}")
        # 2. Compile the context for the agent
        video_context = "\n".join(extracted_descriptions)
        # 3. Transcribe audio if possible
        try:
            whisper_mod = get_whisper()
            trans_result = whisper_mod.transcribe(video_path)
            transcript = trans_result.get("text", "")
            if transcript.strip():
                video_context += f"\n\nVideo Audio Transcript:\n{transcript}"
        except Exception as e:
            video_context += f"\n\n(No audio transcript generated: {e})"
        return f"Video Summary based on extracted frames and audio:\n{video_context}"
    except Exception as e:
        err_msg = str(e)
        if "No address associated with hostname" in err_msg or "Failed to resolve" in err_msg:
            return f"Error: The environment cannot access the internet (DNS failure). Please use 'web_search' or 'wiki_search' to find information about this video content instead of trying to download it."
        return f"Error analyzing video: {err_msg}"
    finally:
        # Bug fix: release the capture on every exit path (early return,
        # exception, or success), not just the happy path.
        if cap is not None:
            cap.release()
        if downloaded_video and os.path.exists(downloaded_video):
            try:
                os.remove(downloaded_video)
            except OSError:
                # Bug fix: narrowed a bare `except: pass` to OSError so real
                # bugs aren't silently swallowed during cleanup.
                pass
def read_url(url: str) -> str:
    """
    Reads and extracts text from a specific webpage URL.
    Use this if a web search snippet doesn't contain enough detail.
    """
    try:
        docs = WebBaseLoader(url).load()
        if not docs:
            return "No content could be extracted from this URL."
        # Cap at 15000 characters so the page fits in the model context.
        return docs[0].page_content[:15000]
    except Exception as e:
        return f"Error reading URL: {e}"
def run_python_script(code: str) -> str:
    """
    Executes a Python script locally and returns the stdout and stderr.
    Use this to perform complex math, data analysis (e.g. pandas), or file processing.
    When given a file path, you can write python code to read and analyze it.

    Args:
        code: Python source to run in a fresh interpreter process.
    Returns:
        Captured stdout (plus stderr, if any), truncated to 15000 characters,
        or an error/timeout message.
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
        f.write(code)
        temp_file_name = f.name
    try:
        result = subprocess.run(
            # Bug fix: use the running interpreter instead of whatever
            # "python" resolves to on PATH (may be absent or a different
            # version/environment).
            [sys.executable, temp_file_name],
            capture_output=True,
            text=True,
            timeout=60,
        )
        output = result.stdout
        if result.stderr:
            output += f"\nErrors:\n{result.stderr}"
        return (output or "Script executed successfully with no output.")[:15000]
    except subprocess.TimeoutExpired:
        return "Script execution timed out after 60 seconds."
    except Exception as e:
        return f"Failed to execute script: {str(e)}"
    finally:
        # Consolidated cleanup: the temp file is removed on every exit path
        # instead of three separate os.remove calls.
        if os.path.exists(temp_file_name):
            os.remove(temp_file_name)
def read_document(file_path: str) -> str:
    """
    Reads the text contents of a local document (.txt, .csv, .json, .md).
    For binary files like .xlsx or .pdf, use run_python_script to process them instead.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            text = handle.read()
        if len(text) > 15000:
            # Keep only the head so the document fits in the model context.
            return text[:15000] + "... (truncated)"
        return text
    except Exception as e:
        return f"Error reading document: {str(e)}. Tip: You can try running a python script to read it!"
# GAIA answer-format instructions.
# NOTE(review): this prompt appears unused in this file — `answer_message`
# builds its own system prompt inline; confirm external callers before removing.
system_prompt = """
You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
"""
class AgentState(TypedDict):
    """Shared LangGraph state: the conversation history threaded through nodes."""
    # Chronological message list; ToolMessage entries are appended at runtime
    # even though the annotation lists only human/AI/system messages.
    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
def read_message(state: AgentState) -> AgentState:
    """Entry node: log the incoming question and forward the state unchanged."""
    history = state["messages"]
    latest = history[-1].content if history else ''
    print(f"Processing question: {latest}")
    # No transformation here — downstream nodes get the same message list.
    return {"messages": history}
def restart_required(state: AgentState) -> AgentState:
    """Pass-through node (currently identical to read_message): log and forward state."""
    current = state["messages"]
    head = current[-1].content if current else ''
    print(f"Processing question: {head}")
    return {"messages": current}
| # def tool_message(state: AgentState) -> AgentState: | |
| # messages = state["messages"] | |
| # prompt = f""" | |
| # You are a GAIA question answering expert. | |
| # Your task is to decide whether to use a tool or not. | |
| # If you need to use a tool, answer ONLY: | |
| # CALL_TOOL: <your tool name> | |
| # If you do not need to use a tool, answer ONLY: | |
| # NO_TOOL | |
| # Here is the question: | |
| # {messages} | |
| # """ | |
| # return {"messages": messages} | |
| # response = model_with_tools.invoke(prompt) | |
| # return {"messages": messages + [response]} | |
# Augment the LLM with tools: registry consumed by the ReAct loop in `answer_message`.
tools = [web_search, wiki_search, analyze_image, analyze_audio, analyze_video, read_url, run_python_script, read_document]
# Bug fix: plain (undecorated) functions have no `.name` attribute, so the old
# `{tool.name: tool ...}` comprehension raised AttributeError at import time;
# fall back to `__name__` so both LangChain @tool objects and bare functions
# work. Renamed the loop variable to stop shadowing the imported `tool` decorator.
# NOTE(review): `answer_message` calls `.invoke(args)` on these entries, which
# only exists on LangChain tool objects — confirm the @tool decorators weren't
# lost from the tool functions; bare functions would need a direct call.
tools_by_name = {getattr(t, "name", t.__name__): t for t in tools}
def extract_text_from_content(content: Any) -> str:
    """Extracts a simple string from various possible AIMessage content formats.

    Strings pass through; list content has its string and text-dict parts
    concatenated; anything else is stringified.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)
    pieces = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict):
            if "text" in part:
                pieces.append(part["text"])
            elif part.get("type") == "text":
                # text-typed part without a "text" key contributes nothing
                pieces.append(part.get("text", ""))
    return "".join(pieces)
def answer_message(state: AgentState) -> AgentState:
    """Core ReAct node: run up to 12 reason/tool steps, then format the final answer.

    Three phases:
      1. Multi-step tool loop via `smart_invoke` with tools bound.
      2. If no answer emerged, force a best-effort draft without tools.
      3. A strict formatting pass that reduces the draft to the bare GAIA answer.
    Both the draft and the formatted final AIMessage are appended to the state.
    """
    messages = state["messages"]
    # Inject today's date so time-sensitive questions are answered correctly.
    current_date = datetime.datetime.now().strftime("%Y-%m-%d")
    prompt = [SystemMessage(f"""
You are a master of the GAIA benchmark, a general AI assistant designed to solve complex multi-step tasks.
Think carefully and logically. Use your tools effectively. Use your internal monologue to plan your steps.
TODAY'S EXACT DATE is {current_date}. Keep this in mind for all time-sensitive queries.
CRITICAL RULES:
1. If you see a path like `[Attached File Local Path: ...]` followed by an image, video, or audio file, YOU MUST USE THE CORRESPONDING TOOL (analyze_image, analyze_video, analyze_audio) IMMEDIATELY in your next step.
2. Plan your steps ahead. 12 steps is your LIMIT for the reasoning loop, so make every step count.
3. If a tool fails (e.g., 429 or 402), the system will automatically try another model for you, so just keep going!
4. Be concise and accurate. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list.
5. CHAIN-OF-THOUGHT: For complex questions, show your reasoning step by step before giving the final answer.
6. USE TOOLS AGGRESSIVELY: If a question requires computation, file reading, or web search, use the appropriate tools - don't try to answer from memory.
7. VERIFY YOUR ANSWER: Double-check calculations and facts using tools when uncertain.
""")]
    messages = prompt + messages
    # Force tool usage if image path is detected.
    # NOTE(review): this appends one reminder per matching human message, so
    # multiple attachments produce multiple reminders — confirm that's intended.
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage) and "[Attached File Local Path:" in msg.content:
            messages.append(HumanMessage(content="IMPORTANT: I see an image path in the message. I MUST call the analyze_image tool IMMEDIATELY in my next step to see it."))
    # Multi-step ReAct Loop (Up to 12 reasoning steps)
    max_steps = 12
    draft_response = None
    # Track the fallback tier that last succeeded so subsequent steps skip
    # tiers that already failed (smart_invoke returns the working tier index).
    current_tier = 0
    for step in range(max_steps):
        if step > 0:
            time.sleep(3)  # brief pause between steps to ease provider rate limits
        print(f"--- ReAct Step {step + 1} ---")
        # Max history truncation to avoid 413 Request Too Large errors:
        # keep the first 2 messages (system prompt + question) and the 6 newest.
        safe_messages = messages[:2] + messages[-6:] if len(messages) > 10 else messages
        ai_msg, current_tier = smart_invoke(safe_messages, use_tools=True, start_tier=current_tier)
        messages.append(ai_msg)
        # Check if the model requested tools
        tool_calls = getattr(ai_msg, "tool_calls", None) or []
        if not tool_calls:
            # Model decided it has enough info to answer
            draft_response = ai_msg
            print(f"Model found answer or stopped tools: {ai_msg.content}")
            break
        # Execute requested tools and append their text output into the conversation
        for tool_call in tool_calls:
            name = tool_call["name"]
            args = tool_call["args"]
            tool_call_id = tool_call.get("id")
            print(f"Calling tool: {name} with args: {args}")
            try:
                tool = tools_by_name[name]
                tool_result = tool.invoke(args)
            except Exception as e:
                # Feed tool failures back to the model instead of crashing the loop.
                tool_result = f"Error executing tool {name}: {str(e)}"
            # Using ToolMessage allows the model to map the result back perfectly to its request
            messages.append(ToolMessage(content=str(tool_result), tool_call_id=tool_call_id, name=name))
    # If we exhausted all steps without an answer, force a draft response
    if draft_response is None:
        print("Max reasoning steps reached. Forcing answer extraction.")
        forced_msg = HumanMessage(content="You have reached the maximum reasoning steps. Please provide your best final answer based on the current context without any more tool calls.")
        messages.append(forced_msg)
        draft_response, _ = smart_invoke(messages, use_tools=False)
    # Third pass: strict GAIA formatting extraction
    formatting_sys = SystemMessage(
        content=(
            "You are a strict output formatter for the GAIA benchmark. "
            "Given a verbose draft answer, extract ONLY the final exact answer required. "
            "Return nothing else. DO NOT include prefixes like 'The answer is'. "
            "Strip trailing whitespace only. "
            "If the answer is a number, just return the number. "
            "If the answer is a list or set of elements, return them as a COMMA-SEPARATED list (e.g., 'a, b, c'). "
            "Preserve necessary punctuation within answers (e.g., 'Dr. Smith' should keep the period)."
        )
    )
    final_response, _ = smart_invoke([formatting_sys, HumanMessage(content=extract_text_from_content(draft_response.content))], use_tools=False, start_tier=current_tier)
    print(f"Draft response: {draft_response.content}")
    print(f"Strict Final response: {final_response.content}")
    # Return messages including the final AIMessage so BasicAgent reads .content
    # Ensure final_response has string content for basic agents
    if not isinstance(final_response.content, str):
        final_response.content = extract_text_from_content(final_response.content)
    messages.append(draft_response)
    messages.append(final_response)
    return {"messages": messages}
def build_graph():
    """Build and compile the two-node LangGraph pipeline: read -> answer."""
    graph = StateGraph(AgentState)
    # Nodes
    graph.add_node("read_message", read_message)
    graph.add_node("answer_message", answer_message)
    # Linear flow: START -> read_message -> answer_message -> END
    graph.add_edge(START, "read_message")
    graph.add_edge("read_message", "answer_message")
    graph.add_edge("answer_message", END)
    # Compile into the executable graph consumed by app.py.
    return graph.compile()