# Leftover Hugging Face web-UI text (author avatar label / commit message /
# commit hash) pasted into the source; commented out so the module is importable.
# Paperbag's picture
# claude fix
# 21be703
import os
import base64
import requests
import json
import traceback
import datetime
import subprocess
import tempfile
import time
from typing import TypedDict, List, Dict, Any, Optional, Union
from langchain_core import tools
from langgraph.graph import StateGraph, START, END
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader
from ddgs import DDGS
from dotenv import load_dotenv
from groq import Groq
from langchain_groq import ChatGroq
from langchain_community.document_loaders.image import UnstructuredImageLoader
from langchain_community.document_loaders import WebBaseLoader
from langchain_google_genai import ChatGoogleGenerativeAI
try:
import cv2
except ImportError:
cv2 = None
# os.environ["USER_AGENT"] = "gaia-agent/1.0"
# Module-level cache so the Whisper model is loaded at most once per process.
whisper_model = None

def get_whisper():
    """Return the shared Whisper model, loading the "base" checkpoint lazily."""
    global whisper_model
    if whisper_model is not None:
        return whisper_model
    # Deferred import: whisper is heavy and only needed by the audio/video tools.
    import whisper
    whisper_model = whisper.load_model("base")
    return whisper_model
# Load .env secrets; override=True lets .env entries win over pre-set env vars.
load_dotenv(override=True)
# Base Hugging Face LLM used by the chat wrapper (kept for reference).
# base_llm = HuggingFaceEndpoint(
#     repo_id="openai/gpt-oss-20b:hyperbolic",
#     # deepseek-ai/DeepSeek-OCR:novita
#     task="text-generation",
#     temperature=0.0,
#     huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
# )
# Model initializations moved to smart_invoke for lazy loading to prevent import errors if keys are missing.
def smart_invoke(msgs, use_tools=False, start_tier=0):
    """
    Invoke a chat model with tiered provider fallback.

    Tier order (see tiers_config below): OpenRouter free models ->
    Google Gemini -> Groq. On a retryable error (429 rate limit,
    402 credits, 404 model not found, 5xx/overloaded, decommissioned
    model) the next alternative model name, then the next tier, is tried.

    Args:
        msgs: List of LangChain messages to send.
        use_tools: If True, bind the module-level `tools` list to the model.
        start_tier: Index of the first tier to try, so callers can resume
            at the tier that last succeeded.

    Returns:
        A (response, tier_index) tuple on success, or (None, 0) when no
        tier has an API key configured (or start_tier is past the end).

    Raises:
        The last provider exception if every configured tier fails, or
        immediately for non-retryable errors.
    """
    # Adaptive Gemini names verified via list_models (REST API).
    gemini_alternatives = ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-flash-latest", "gemini-pro-latest"]
    tiers_config = [
        {"name": "Qwen3-Next-80B", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "qwen/qwen3-next-80b-a3b-instruct:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "Gemma-3-27B", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "google/gemma-3-27b-it:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "NVIDIA-Nemotron-Super", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "nvidia/nemotron-3-super-120b-a12b:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "OpenRouter-FreeRouter", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "openrouter/free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "DeepSeek-R1", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "deepseek/deepseek-r1:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "Gemini-Flash", "key": "GOOGLE_API_KEY", "provider": "google", "model_name": "gemini-2.0-flash", "alternatives": gemini_alternatives},
        {"name": "Groq", "key": "GROQ_API_KEY", "provider": "groq", "model_name": "llama-3.3-70b-versatile"},
    ]
    last_exception = None
    for i in range(start_tier, len(tiers_config)):
        tier = tiers_config[i]
        api_key = os.getenv(tier["key"])
        if not api_key:
            continue  # Tier not configured; skip it silently.

        def create_model_instance(m_name, provider, b_url=None):
            # Lazy imports keep an uninstalled optional provider from
            # breaking the tiers that don't need it.
            if provider == "openai":
                from langchain_openai import ChatOpenAI
                return ChatOpenAI(model=m_name, openai_api_key=api_key, openai_api_base=b_url, temperature=0)
            if provider == "google":
                from langchain_google_genai import ChatGoogleGenerativeAI
                return ChatGoogleGenerativeAI(model=m_name, temperature=0)
            if provider == "groq":
                from langchain_groq import ChatGroq
                return ChatGroq(model=m_name, temperature=0, max_retries=2)
            return None  # Unknown provider.

        # Build the primary model plus alternatives. Fix: skip None results
        # (unknown provider) instead of calling .bind_tools on None.
        models_to_try = []
        for candidate_name in [tier["model_name"]] + tier.get("alternatives", []):
            model = create_model_instance(candidate_name, tier["provider"], tier.get("base_url"))
            if model is None:
                continue
            if use_tools:
                model = model.bind_tools(tools)
            models_to_try.append(model)

        for current_model in models_to_try:
            # Bound before the try so the except branch can never see an
            # unbound model_name.
            model_name = getattr(current_model, "model", tier["name"])
            try:
                print(f"--- Calling {tier['name']} ({model_name}) ---")
                return current_model.invoke(msgs), i
            except Exception as e:
                err_str = str(e).lower()
                is_last_alternative = current_model is models_to_try[-1]
                # 404 with alternatives remaining: try the next model name.
                if any(x in err_str for x in ["not_found", "404"]) and not is_last_alternative:
                    print(f"--- {tier['name']} model {model_name} not found. Trying alternative... ---")
                    continue
                # Other retryable failures: fall through to the next model/tier.
                if any(x in err_str for x in ["rate_limit", "429", "500", "503", "overloaded", "not_found", "404", "402", "credits", "decommissioned", "invalid_request_error"]):
                    print(f"--- {tier['name']} Error: {e}. Trying next model/tier... ---")
                    last_exception = e
                    if not is_last_alternative:
                        continue
                    break  # Exhausted this tier; move to the next one.
                raise  # Non-retryable: re-raise, preserving the traceback.
    if last_exception:
        print("CRITICAL: All fallback tiers failed.")
        raise last_exception
    return None, 0
@tool
def web_search(keywords: str) -> str:
    """
    Uses duckduckgo to search the top 5 result on web
    Use cases:
    - Identify personal information
    - Information search
    - Finding organisation information
    - Obtain the latest news
    Args:
        keywords: keywords used to search the web
    Returns:
        Search result (Header + body + url)
    """
    max_retries = 3
    for attempt in range(max_retries):
        try:
            with DDGS() as ddgs:
                hits = ddgs.text(keywords, max_results=5)
                # Concatenate title/body/url for each hit into one text blob.
                return "".join(
                    f"Results: {hit['title']}\n{hit['body']}\n{hit['href']}\n\n"
                    for hit in hits
                )
        except Exception as e:
            if attempt == max_retries - 1:
                return f"Search failed after {max_retries} attempts: {str(e)}"
            time.sleep(2 ** attempt)  # Exponential backoff before retrying.
@tool
def wiki_search(query: str) -> str:
    """
    Search Wikipedia for a query and return up to 3 results.
    Use cases:
    When the question requires the use of information from wikipedia
    Args:
        query: The search query
    """
    docs = WikipediaLoader(query=query, load_max_docs=3, doc_content_chars_max=15000).load()
    if not docs:
        return "No Wikipedia results found."
    # Wrap each article in a pseudo-XML <Document> envelope for the LLM.
    rendered = []
    for doc in docs:
        title = doc.metadata.get("title", "Unknown Title")
        rendered.append(
            f'<Document source="{doc.metadata["source"]}" page="{title}"/>\n{doc.page_content}\n</Document>'
        )
    return "\n\n---\n\n".join(rendered)
def get_vision_models():
    """Returns a list of vision models to try, in order of preference.

    Only providers whose API key is present in the environment are included.

    Returns:
        List of dicts {"name": <label>, "model": <chat model instance>};
        empty when no provider key is configured.
    """
    configs = [
        {"name": "OpenRouter-Qwen3-VL", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "qwen/qwen3-vl-235b-thinking:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "NVIDIA-Nemotron-VL", "key": "NVIDIA_API_KEY", "provider": "openai", "model_name": "nvidia/nemotron-nano-2-vl:free", "base_url": "https://integrate.api.nvidia.com/v1"},
        {"name": "OpenRouter-Gemma-3-27b-it", "key": "OPENROUTER_API_KEY", "provider": "openai", "model_name": "google/gemma-3-27b-it:free", "base_url": "https://openrouter.ai/api/v1"},
        {"name": "Google-Gemini-2.0-Flash", "key": "GOOGLE_API_KEY", "provider": "google", "model_name": "gemini-2.0-flash"},
        {"name": "Google-Gemini-Flash-Latest", "key": "GOOGLE_API_KEY", "provider": "google", "model_name": "gemini-flash-latest"},
    ]
    models = []
    for cfg in configs:
        api_key = os.getenv(cfg["key"])
        if not api_key:
            continue  # Provider not configured.
        # Imports are lazy so missing optional packages don't break module import.
        if cfg["provider"] == "openai":
            from langchain_openai import ChatOpenAI
            m = ChatOpenAI(model=cfg["model_name"], openai_api_key=api_key, openai_api_base=cfg.get("base_url"), temperature=0)
        elif cfg["provider"] == "google":
            from langchain_google_genai import ChatGoogleGenerativeAI
            m = ChatGoogleGenerativeAI(model=cfg["model_name"], temperature=0)
        elif cfg["provider"] == "groq":
            from langchain_groq import ChatGroq
            m = ChatGroq(model=cfg["model_name"], temperature=0)
        else:
            # Bug fix: an unknown provider previously left `m` unbound
            # (NameError) or re-appended the previous model; skip instead.
            continue
        models.append({"name": cfg["name"], "model": m})
    return models
@tool
def analyze_image(image_path: str, question: str) -> str:
    """
    EXTERNAL SIGHT API: Sends an image path to a Vision Model to answer a specific question.
    YOU MUST CALL THIS TOOL ANY TIME an image (.png, .jpg, .jpeg) is attached to the prompt.
    NEVER claim you cannot see images. Use this tool instead.
    Args:
        image_path: The local path or URL to the image file.
        question: Specific question describing what you want the vision model to look for.
    """
    try:
        # NOTE(review): despite "or URL" in the docstring, a remote URL fails
        # this existence check — only local file paths actually work here.
        if not os.path.exists(image_path):
            return f"Error: Image file not found at {image_path}"
        # If it's a local file, we encode it to base64
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
        # Multimodal message: the question text plus the image as a data URL.
        # The jpeg MIME type is used regardless of the actual file format.
        message = HumanMessage(
            content=[
                {"type": "text", "text": question},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                },
            ]
        )
        vision_models = get_vision_models()
        if not vision_models:
            return "Error: No vision models configured (missing API keys)."
        # Try each configured vision model in preference order; first success wins.
        last_err = None
        for item in vision_models:
            try:
                m_name = getattr(item['model'], 'model', 'unknown')
                print(f"--- Calling Vision Model: {item['name']} ({m_name}) ---")
                response = item['model'].invoke([message])
                return extract_text_from_content(response.content)
            except Exception as e:
                print(f"Vision Model {item['name']} failed.")
                traceback.print_exc()
                last_err = e
        return f"Error analyzing image: All vision models failed. Last error: {str(last_err)}"
    except Exception as e:
        # File-read or encoding failure (as opposed to a model failure above).
        traceback.print_exc()
        return f"Error reading/processing image: {str(e)}"
@tool
def analyze_audio(audio_path: str, question: str) -> str:
    """
    Transcribes an audio file (.mp3, .wav, .m4a) to answer questions about what is spoken.
    Args:
        audio_path: The local path to the audio file.
        question: The specific question to ask.
    """
    # NOTE: `question` is not passed to the transcriber; the full transcript is
    # returned and the calling model answers the question from it.
    try:
        model = get_whisper()
        result = model.transcribe(audio_path)
        transcript = result["text"]
        return f"Audio Transcript:\n{transcript}"
    except Exception as e:
        # Fixed the previously ungrammatical tip ("You requires 'ffmpeg'...").
        return f"Error analyzing audio: {str(e)}. Tip: This requires 'ffmpeg' to be installed on your system."
@tool
def analyze_video(video_path: str, question: str) -> str:
    """
    EXTERNAL SIGHT/HEARING API: Sends a video file to an external Vision/Audio model.
    YOU MUST CALL THIS TOOL ANY TIME a video (.mp4, .avi) is attached to the prompt.
    NEVER claim you cannot analyze videos. Use this tool instead.
    Args:
        video_path: The local path to the video file.
        question: Specific question describing what you want to extract from the video.
    """
    if cv2 is None:
        return "Error: cv2 is not installed. Please install opencv-python."
    temp_dir = tempfile.gettempdir()
    downloaded_video = None  # Set only when we download from a URL; cleaned up in finally.
    try:
        # Check if video_path is a URL
        if video_path.startswith("http"):
            print(f"Downloading video from URL: {video_path}")
            downloaded_video = os.path.join(temp_dir, f"video_{int(time.time())}.mp4")
            try:
                # Use yt-dlp to download the video
                # Note: --ffmpeg-location could be used if we knew where it was, but we assume it's in path or missing
                subprocess.run(["yt-dlp", "-f", "best[ext=mp4]/mp4", "-o", downloaded_video, video_path], check=True, timeout=120)
                video_path = downloaded_video
            except Exception as e:
                return f"Error downloading video from URL: {str(e)}. Tip: Check if yt-dlp is installed and the URL is valid."
        # 1. Extract frames evenly spaced throughout the video
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames == 0:
            # NOTE(review): cap is not released on this early return.
            return "Error: Could not read video frames."
        # Take 5 frames as a summary
        frame_indices = [int(i * total_frames / 5) for i in range(5)]
        extracted_descriptions = []
        vision_models = get_vision_models()
        # Ensure Groq-Llama is at the front for video if preferred, but we'll use the default order for now.
        for idx_num, frame_idx in enumerate(frame_indices):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                # Convert frame to base64
                _, buffer = cv2.imencode('.jpg', frame)
                encoded_image = base64.b64encode(buffer).decode('utf-8')
                # Ask a vision model to describe the frame (with fallback)
                msg = HumanMessage(
                    content=[
                        {"type": "text", "text": f"Describe what is happening in this video frame concisely. Focus on aspects related to: {question}"},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
                    ]
                )
                desc = "No description available."
                # First vision model that succeeds for this frame wins.
                for item in vision_models:
                    try:
                        print(f"--- Calling Vision Model for Frame {idx_num+1}: {item['name']} ---")
                        desc = item['model'].invoke([msg]).content
                        break
                    except Exception as e:
                        print(f"Vision Model {item['name']} failed for frame: {e}")
                        continue
                extracted_descriptions.append(f"Frame {idx_num + 1}: {desc}")
        cap.release()
        # 2. Compile the context for the agent
        video_context = "\n".join(extracted_descriptions)
        # 3. Transcribe audio if possible
        try:
            whisper_mod = get_whisper()
            trans_result = whisper_mod.transcribe(video_path)
            transcript = trans_result.get("text", "")
            if transcript.strip():
                video_context += f"\n\nVideo Audio Transcript:\n{transcript}"
        except Exception as e:
            # Transcription is best-effort; a failure is reported inline, not raised.
            video_context += f"\n\n(No audio transcript generated: {e})"
        return f"Video Summary based on extracted frames and audio:\n{video_context}"
    except Exception as e:
        err_msg = str(e)
        # DNS failure means the sandbox is offline; steer the agent to search tools.
        if "No address associated with hostname" in err_msg or "Failed to resolve" in err_msg:
            return f"Error: The environment cannot access the internet (DNS failure). Please use 'web_search' or 'wiki_search' to find information about this video content instead of trying to download it."
        return f"Error analyzing video: {err_msg}"
    finally:
        # Remove the downloaded temp video on every exit path.
        if downloaded_video and os.path.exists(downloaded_video):
            try:
                os.remove(downloaded_video)
            except:
                pass
@tool
def read_url(url: str) -> str:
    """
    Reads and extracts text from a specific webpage URL.
    Use this if a web search snippet doesn't contain enough detail.
    """
    try:
        docs = WebBaseLoader(url).load()
        if not docs:
            return "No content could be extracted from this URL."
        # Cap at 15k characters so the page fits in the model context.
        return docs[0].page_content[:15000]
    except Exception as e:
        return f"Error reading URL: {e}"
@tool
def run_python_script(code: str) -> str:
    """
    Executes a Python script locally and returns the stdout and stderr.
    Use this to perform complex math, data analysis (e.g. pandas), or file processing.
    When given a file path, you can write python code to read and analyze it.
    """
    import sys  # Local import: sys is not imported at module level.
    # Write the code to a temp file; delete=False so the subprocess can open it.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
        f.write(code)
        temp_file_name = f.name
    try:
        result = subprocess.run(
            # sys.executable runs the script under the same interpreter as the
            # agent (the bare "python" on PATH may be a different install).
            [sys.executable, temp_file_name],
            capture_output=True,
            text=True,
            timeout=60,
        )
        output = result.stdout
        if result.stderr:
            output += f"\nErrors:\n{result.stderr}"
        return (output or "Script executed successfully with no output.")[:15000]
    except subprocess.TimeoutExpired:
        return "Script execution timed out after 60 seconds."
    except Exception as e:
        return f"Failed to execute script: {str(e)}"
    finally:
        # Single cleanup point for every exit path (fixes a leak when
        # subprocess.run raised before the per-branch removes ran).
        try:
            os.remove(temp_file_name)
        except OSError:
            pass
@tool
def read_document(file_path: str) -> str:
    """
    Reads the text contents of a local document (.txt, .csv, .json, .md).
    For binary files like .xlsx or .pdf, use run_python_script to process them instead.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Keep the result within the model's context budget.
        if len(text) > 15000:
            text = text[:15000] + "... (truncated)"
        return text
    except Exception as e:
        return f"Error reading document: {str(e)}. Tip: You can try running a python script to read it!"
# Generic QA system prompt enforcing the GAIA "FINAL ANSWER:" output contract.
# NOTE(review): not referenced anywhere in this file; presumably imported by
# app.py — confirm before removing.
system_prompt = """
You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
"""
class AgentState(TypedDict):
    """Graph state: the conversation history threaded between nodes."""
    # Full message list; answer_message appends AI/tool messages to it.
    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
def read_message(state: AgentState) -> AgentState:
    """Entry node: log the incoming question and forward state unchanged."""
    msgs = state["messages"]
    latest = msgs[-1].content if msgs else ''
    print(f"Processing question: {latest}")
    return {"messages": msgs}
def restart_required(state: AgentState) -> AgentState:
    """Pass-through node that logs the latest message.

    NOTE(review): not wired into build_graph — looks like dead code; confirm
    before removing.
    """
    history = state["messages"]
    latest = history[-1].content if history else ''
    print(f"Processing question: {latest}")
    return {"messages": history}
# def tool_message(state: AgentState) -> AgentState:
# messages = state["messages"]
# prompt = f"""
# You are a GAIA question answering expert.
# Your task is to decide whether to use a tool or not.
# If you need to use a tool, answer ONLY:
# CALL_TOOL: <your tool name>
# If you do not need to use a tool, answer ONLY:
# NO_TOOL
# Here is the question:
# {messages}
# """
# return {"messages": messages}
# response = model_with_tools.invoke(prompt)
# return {"messages": messages + [response]}
# Augment the LLM with tools
# NOTE(review): this rebinds the name `tools`, shadowing the
# `from langchain_core import tools` module import at the top of the file.
tools = [web_search, wiki_search, analyze_image, analyze_audio, analyze_video, read_url, run_python_script, read_document]
# Name -> tool lookup used by the ReAct loop to dispatch tool calls.
tools_by_name = {tool.name: tool for tool in tools}
def extract_text_from_content(content: Any) -> str:
    """Flatten an AIMessage `content` payload (str, list of parts, or any
    other value) into a plain string."""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        # Fall back to the string form for unexpected payload types.
        return str(content)
    pieces: List[str] = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict):
            if "text" in part:
                pieces.append(part["text"])
            elif part.get("type") == "text":
                # A text part missing its "text" key contributes nothing.
                pieces.append(part.get("text", ""))
    return "".join(pieces)
def answer_message(state: AgentState) -> AgentState:
    """ReAct loop node: answer the question with up to 12 tool-using steps,
    then run a strict formatting pass to produce the GAIA final answer.

    Appends both the draft and the strictly formatted final AIMessage to the
    message list, so callers can read `.content` of the last message.
    """
    messages = state["messages"]
    current_date = datetime.datetime.now().strftime("%Y-%m-%d")
    # The system prompt is rebuilt on every call so the injected date is current.
    prompt = [SystemMessage(f"""
You are a master of the GAIA benchmark, a general AI assistant designed to solve complex multi-step tasks.
Think carefully and logically. Use your tools effectively. Use your internal monologue to plan your steps.
TODAY'S EXACT DATE is {current_date}. Keep this in mind for all time-sensitive queries.
CRITICAL RULES:
1. If you see a path like `[Attached File Local Path: ...]` followed by an image, video, or audio file, YOU MUST USE THE CORRESPONDING TOOL (analyze_image, analyze_video, analyze_audio) IMMEDIATELY in your next step.
2. Plan your steps ahead. 12 steps is your LIMIT for the reasoning loop, so make every step count.
3. If a tool fails (e.g., 429 or 402), the system will automatically try another model for you, so just keep going!
4. Be concise and accurate. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list.
5. CHAIN-OF-THOUGHT: For complex questions, show your reasoning step by step before giving the final answer.
6. USE TOOLS AGGRESSIVELY: If a question requires computation, file reading, or web search, use the appropriate tools - don't try to answer from memory.
7. VERIFY YOUR ANSWER: Double-check calculations and facts using tools when uncertain.
""")]
    messages = prompt + messages
    # Force tool usage if image path is detected
    for msg in state["messages"]:
        # NOTE(review): assumes msg.content is a str; multimodal list content
        # would make the `in` test raise TypeError — confirm upstream format.
        if isinstance(msg, HumanMessage) and "[Attached File Local Path:" in msg.content:
            messages.append(HumanMessage(content="IMPORTANT: I see an image path in the message. I MUST call the analyze_image tool IMMEDIATELY in my next step to see it."))
    # Multi-step ReAct Loop (Up to 12 reasoning steps)
    max_steps = 12
    draft_response = None
    current_tier = 0  # Remembers which fallback tier last worked so dead tiers are skipped.
    for step in range(max_steps):
        if step > 0:
            time.sleep(3)  # Crude pacing between steps to avoid provider rate limits.
        print(f"--- ReAct Step {step + 1} ---")
        # Max history truncation to avoid 413 Request Too Large errors
        # (keeps the first 2 messages plus the 6 newest once history exceeds 10).
        safe_messages = messages[:2] + messages[-6:] if len(messages) > 10 else messages
        ai_msg, current_tier = smart_invoke(safe_messages, use_tools=True, start_tier=current_tier)
        messages.append(ai_msg)
        # Check if the model requested tools
        tool_calls = getattr(ai_msg, "tool_calls", None) or []
        if not tool_calls:
            # Model decided it has enough info to answer
            draft_response = ai_msg
            print(f"Model found answer or stopped tools: {ai_msg.content}")
            break
        # Execute requested tools and append their text output into the conversation
        for tool_call in tool_calls:
            name = tool_call["name"]
            args = tool_call["args"]
            tool_call_id = tool_call.get("id")
            print(f"Calling tool: {name} with args: {args}")
            try:
                # NOTE(review): the local name `tool` shadows the imported
                # @tool decorator within this loop body.
                tool = tools_by_name[name]
                tool_result = tool.invoke(args)
            except Exception as e:
                # A failed tool is reported back to the model rather than raised.
                tool_result = f"Error executing tool {name}: {str(e)}"
            # Using ToolMessage allows the model to map the result back perfectly to its request
            messages.append(ToolMessage(content=str(tool_result), tool_call_id=tool_call_id, name=name))
    # If we exhausted all steps without an answer, force a draft response
    if draft_response is None:
        print("Max reasoning steps reached. Forcing answer extraction.")
        forced_msg = HumanMessage(content="You have reached the maximum reasoning steps. Please provide your best final answer based on the current context without any more tool calls.")
        messages.append(forced_msg)
        draft_response, _ = smart_invoke(messages, use_tools=False)
    # Third pass: strict GAIA formatting extraction
    formatting_sys = SystemMessage(
        content=(
            "You are a strict output formatter for the GAIA benchmark. "
            "Given a verbose draft answer, extract ONLY the final exact answer required. "
            "Return nothing else. DO NOT include prefixes like 'The answer is'. "
            "Strip trailing whitespace only. "
            "If the answer is a number, just return the number. "
            "If the answer is a list or set of elements, return them as a COMMA-SEPARATED list (e.g., 'a, b, c'). "
            "Preserve necessary punctuation within answers (e.g., 'Dr. Smith' should keep the period)."
        )
    )
    final_response, _ = smart_invoke([formatting_sys, HumanMessage(content=extract_text_from_content(draft_response.content))], use_tools=False, start_tier=current_tier)
    print(f"Draft response: {draft_response.content}")
    print(f"Strict Final response: {final_response.content}")
    # Return messages including the final AIMessage so BasicAgent reads .content
    # Ensure final_response has string content for basic agents
    if not isinstance(final_response.content, str):
        final_response.content = extract_text_from_content(final_response.content)
    messages.append(draft_response)
    messages.append(final_response)
    return {"messages": messages}
def build_graph():
    """Build and compile the two-node pipeline: START -> read_message ->
    answer_message -> END. The compiled graph is consumed by app.py."""
    graph = StateGraph(AgentState)
    # Nodes.
    graph.add_node("read_message", read_message)
    graph.add_node("answer_message", answer_message)
    # Linear edges through the pipeline.
    graph.add_edge(START, "read_message")
    graph.add_edge("read_message", "answer_message")
    graph.add_edge("answer_message", END)
    return graph.compile()