Spaces:
Runtime error
Runtime error
abtsousa
Enhance file handling tools: save_file, read_file, analyze_csv, and extract_text_from_image functions
2665628
| import os | |
| from typing import Literal | |
| from typing_extensions import TypedDict | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.prebuilt import tools_condition | |
| from agent.nodes import call_model, tool_node | |
| from langgraph.graph import MessagesState | |
| from langchain_core.messages import AIMessage, HumanMessage, AIMessageChunk | |
| from langgraph.checkpoint.memory import InMemorySaver | |
| from agent.config import create_agent_config | |
| from termcolor import colored, cprint | |
| class OracleBot: | |
| def __init__(self): | |
| print("Initializing OracleBot") | |
| self.name = "OracleBot" | |
| self.thread_id = 1 #TODO fix | |
| self.config = create_agent_config(self.name, self.thread_id) | |
| self.graph = self._build_agent(self.name) | |
| def answer_question(self, question: str, file_path: str | None = None): | |
| """ | |
| Answer a question using the LangGraph agent. | |
| Args: | |
| question: The question to answer | |
| file_path: Optional path to a file associated with this question | |
| """ | |
| # Enhance question with file context if available | |
| if file_path and os.path.exists(file_path): | |
| question = f"{question}\n\nNote: There is an associated file named {os.path.basename(file_path)}\nYou can use the file management tools to read and analyze this file." | |
| messages = [HumanMessage(content=question)] | |
| for mode, chunk in self.graph.stream({"messages": messages}, config=self.config, stream_mode=["messages", "updates"]): # type: ignore | |
| if mode == "messages": | |
| if isinstance(chunk, tuple) and len(chunk) > 0: | |
| message = chunk[0] | |
| if isinstance(message, (AIMessageChunk, AIMessage)): | |
| # Only print chunks that have actual content (skip tool call chunks) | |
| if hasattr(message, 'content') and message.content and not (hasattr(message, 'tool_calls') and message.tool_calls): | |
| cprint(message.content, color="light_grey", attrs=["dark"], end="", flush=True) | |
| # Handle case where chunk is directly the message | |
| elif isinstance(chunk, (AIMessageChunk, AIMessage)): | |
| # Only print chunks that have actual content (skip tool call chunks) | |
| if hasattr(chunk, 'content') and chunk.content and not (hasattr(chunk, 'tool_calls') and chunk.tool_calls): | |
| cprint(chunk.content, color="light_grey", attrs=["dark"], end="", flush=True) | |
| elif mode == "updates": | |
| # Look for complete tool calls in updates | |
| if isinstance(chunk, dict) and 'agent' in chunk: | |
| agent_update = chunk['agent'] | |
| if 'messages' in agent_update and agent_update['messages']: | |
| for message in agent_update['messages']: | |
| if hasattr(message, 'tool_calls') and message.tool_calls: | |
| for tool_call in message.tool_calls: | |
| cprint(f"\n🔧 Using tool: {tool_call['name']} with args: {tool_call['args']}\n", color="yellow") | |
| # Handle final answer messages (no tool calls) | |
| elif hasattr(message, 'content') and message.content: | |
| cprint(f"\n{message.content}\n", color="black", on_color="on_white", attrs=["bold"]) | |
| return message.content # Return final answer | |
| # Look for tool outputs in updates | |
| elif isinstance(chunk, dict) and 'tools' in chunk: | |
| tools_update = chunk['tools'] | |
| if 'messages' in tools_update and tools_update['messages']: | |
| for message in tools_update['messages']: | |
| if hasattr(message, 'content') and message.content: | |
| cprint(f"\n📤 Tool output:\n{message.content}\n", color="green") | |
| async def answer_question_async(self, question: str, file_path: str | None = None) -> str: | |
| """ | |
| Answer a question using the LangGraph agent asynchronously. | |
| Args: | |
| question: The question to answer | |
| file_path: Optional path to a file associated with this question | |
| Returns the final answer as a string. | |
| """ | |
| from langchain_core.runnables import RunnableConfig | |
| from typing import cast | |
| # Enhance question with file context if available | |
| if file_path and os.path.exists(file_path): | |
| question = f"{question}\n\nNote: There is an associated file at: {file_path}\nYou can use the file management tools to read and analyze this file." | |
| messages = [HumanMessage(content=question)] | |
| # Use LangGraph's built-in ainvoke method | |
| result = await self.graph.ainvoke({"messages": messages}, config=cast(RunnableConfig, self.config)) # type: ignore | |
| # Extract the content from the last message | |
| if "messages" in result and result["messages"]: | |
| last_message = result["messages"][-1] | |
| if hasattr(last_message, 'content'): | |
| return last_message.content or "" | |
| return "" | |
| def _build_agent(self, name: str): | |
| """ | |
| Get our LangGraph agent with the given model and tools. | |
| """ | |
| class GraphConfig(TypedDict): | |
| name: str; | |
| thread_id: int; | |
| graph = StateGraph(state_schema=MessagesState, context_schema=GraphConfig) | |
| # Add nodes | |
| graph.add_node("agent", call_model) | |
| graph.add_node("tools", tool_node) | |
| # Add edges | |
| graph.add_edge(START, "agent") | |
| graph.add_conditional_edges("agent", tools_condition) | |
| graph.add_edge("tools", "agent") | |
| return graph.compile() | |
# Manual smoke test: run this module directly to ask the agent one question.
if __name__ == "__main__":
    import traceback

    question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
    try:
        from config import start_phoenix

        start_phoenix()
        bot = OracleBot()
        bot.answer_question(question, None)
    except Exception as e:
        # Top-level boundary of a script: report the failure with its full
        # traceback rather than only the message, so the failing line is visible.
        print(f"Error running agent: {e}")
        traceback.print_exc()