# abtsousa
# Enhance file handling tools: save_file, read_file, analyze_csv, and extract_text_from_image functions
# 2665628
import os
from typing import Literal
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import tools_condition
from agent.nodes import call_model, tool_node
from langgraph.graph import MessagesState
from langchain_core.messages import AIMessage, HumanMessage, AIMessageChunk
from langgraph.checkpoint.memory import InMemorySaver
from agent.config import create_agent_config
from termcolor import colored, cprint
class OracleBot:
    """LangGraph-backed question-answering agent.

    Builds a two-node graph (agent <-> tools) at construction time and exposes
    a streaming synchronous entry point plus an async one.
    """

    def __init__(self):
        """Create the agent config and compile the LangGraph graph."""
        print("Initializing OracleBot")
        self.name = "OracleBot"
        # TODO: thread ids should be unique per conversation, not hard-coded.
        self.thread_id = 1
        self.config = create_agent_config(self.name, self.thread_id)
        self.graph = self._build_agent(self.name)

    @staticmethod
    def _print_ai_chunk(message) -> None:
        """Print a streamed AI content chunk; skip chunks that only carry tool calls."""
        if isinstance(message, (AIMessageChunk, AIMessage)):
            if getattr(message, 'content', None) and not getattr(message, 'tool_calls', None):
                cprint(message.content, color="light_grey", attrs=["dark"], end="", flush=True)

    def answer_question(self, question: str, file_path: str | None = None):
        """Answer a question using the LangGraph agent, streaming output to the terminal.

        Args:
            question: The question to answer.
            file_path: Optional path to a file associated with this question;
                if it exists, a note about the file is appended to the prompt.

        Returns:
            The final answer text when the agent emits a final (tool-call-free)
            message, otherwise None.
        """
        # Enhance question with file context if available
        if file_path and os.path.exists(file_path):
            question = f"{question}\n\nNote: There is an associated file named {os.path.basename(file_path)}\nYou can use the file management tools to read and analyze this file."
        messages = [HumanMessage(content=question)]
        for mode, chunk in self.graph.stream({"messages": messages}, config=self.config, stream_mode=["messages", "updates"]):  # type: ignore
            if mode == "messages":
                # Streamed chunks arrive either as (message, metadata) tuples
                # or as bare messages depending on the stream implementation.
                message = chunk[0] if isinstance(chunk, tuple) and chunk else chunk
                self._print_ai_chunk(message)
            elif mode == "updates":
                # Look for complete tool calls in agent updates
                if isinstance(chunk, dict) and 'agent' in chunk:
                    agent_update = chunk['agent']
                    if 'messages' in agent_update and agent_update['messages']:
                        for message in agent_update['messages']:
                            if getattr(message, 'tool_calls', None):
                                for tool_call in message.tool_calls:
                                    cprint(f"\n🔧 Using tool: {tool_call['name']} with args: {tool_call['args']}\n", color="yellow")
                            elif getattr(message, 'content', None):
                                # A content-bearing agent message without tool
                                # calls is the final answer.
                                cprint(f"\n{message.content}\n", color="black", on_color="on_white", attrs=["bold"])
                                return message.content
                # Look for tool outputs in updates
                elif isinstance(chunk, dict) and 'tools' in chunk:
                    tools_update = chunk['tools']
                    if 'messages' in tools_update and tools_update['messages']:
                        for message in tools_update['messages']:
                            if getattr(message, 'content', None):
                                cprint(f"\n📤 Tool output:\n{message.content}\n", color="green")
        # Stream ended without a final tool-call-free agent message.
        return None

    async def answer_question_async(self, question: str, file_path: str | None = None) -> str:
        """Answer a question using the LangGraph agent asynchronously.

        Args:
            question: The question to answer.
            file_path: Optional path to a file associated with this question;
                if it exists, a note about the file is appended to the prompt.

        Returns:
            The final answer as a string ("" when no answer was produced).
        """
        from langchain_core.runnables import RunnableConfig
        from typing import cast
        # Enhance question with file context if available
        if file_path and os.path.exists(file_path):
            question = f"{question}\n\nNote: There is an associated file at: {file_path}\nYou can use the file management tools to read and analyze this file."
        messages = [HumanMessage(content=question)]
        # Use LangGraph's built-in ainvoke method
        result = await self.graph.ainvoke({"messages": messages}, config=cast(RunnableConfig, self.config))  # type: ignore
        # Extract the content from the last message, defaulting to ""
        if "messages" in result and result["messages"]:
            last_message = result["messages"][-1]
            if hasattr(last_message, 'content'):
                return last_message.content or ""
        return ""

    def _build_agent(self, name: str):
        """Build and compile the LangGraph agent graph.

        Args:
            name: Agent name (currently unused; kept for interface stability).

        Returns:
            The compiled graph, ready for .stream()/.ainvoke().
        """
        class GraphConfig(TypedDict):
            name: str
            thread_id: int

        graph = StateGraph(state_schema=MessagesState, context_schema=GraphConfig)
        # Nodes: the model-calling agent and the tool executor.
        graph.add_node("agent", call_model)
        graph.add_node("tools", tool_node)
        # Edges: start at the agent, route to tools only when the last agent
        # message contains tool calls, and loop tool results back to the agent.
        graph.add_edge(START, "agent")
        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent")
        return graph.compile()
# Manual smoke test: run this module directly to exercise the agent end-to-end.
if __name__ == "__main__":
    question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
    try:
        from config import start_phoenix
        start_phoenix()
        bot = OracleBot()
        bot.answer_question(question, None)
    except Exception as e:
        # Top-level boundary: report the failure with its traceback instead of
        # discarding it — the one-line message alone hides the failure point.
        import traceback
        print(f"Error running agent: {e}")
        traceback.print_exc()