Spaces:
Runtime error
Runtime error
File size: 6,362 Bytes
03f4295 f37e95b 7cdcb1a 692b974 7cdcb1a 692b974 f37e95b 7cdcb1a f37e95b 03f4295 7cdcb1a 03f4295 7cdcb1a 03f4295 2665628 7cdcb1a 692b974 03f4295 692b974 f37e95b 03f4295 7cdcb1a f37e95b 7cdcb1a 03f4295 f37e95b f439125 f37e95b 692b974 7cdcb1a 03f4295 7cdcb1a f37e95b 692b974 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import os
from typing import Literal
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import tools_condition
from agent.nodes import call_model, tool_node
from langgraph.graph import MessagesState
from langchain_core.messages import AIMessage, HumanMessage, AIMessageChunk
from langgraph.checkpoint.memory import InMemorySaver
from agent.config import create_agent_config
from termcolor import colored, cprint
class OracleBot:
    """LangGraph-backed question-answering agent.

    Streams model tokens and tool activity to the terminal while the graph
    runs, and returns the final answer text when one is produced.
    """

    def __init__(self):
        print("Initializing OracleBot")
        self.name = "OracleBot"
        # TODO(review): thread_id is hard-coded, so every conversation shares
        # the same thread; make this configurable per session.
        self.thread_id = 1
        self.config = create_agent_config(self.name, self.thread_id)
        self.graph = self._build_agent(self.name)

    def answer_question(self, question: str, file_path: str | None = None):
        """Answer a question using the LangGraph agent, streaming to stdout.

        Args:
            question: The question to answer.
            file_path: Optional path to a file associated with this question.

        Returns:
            The final answer text, or ``None`` if the stream finished without
            producing a tool-call-free agent message (made explicit here; the
            previous version fell off the loop and returned ``None`` silently).
        """
        # Enhance question with file context if available.
        if file_path and os.path.exists(file_path):
            question = f"{question}\n\nNote: There is an associated file named {os.path.basename(file_path)}\nYou can use the file management tools to read and analyze this file."
        messages = [HumanMessage(content=question)]
        for mode, chunk in self.graph.stream(
            {"messages": messages},
            config=self.config,
            stream_mode=["messages", "updates"],
        ):  # type: ignore
            if mode == "messages":
                self._print_token_chunk(chunk)
            elif mode == "updates":
                final_answer = self._handle_update(chunk)
                if final_answer is not None:
                    return final_answer
        # Stream ended without a final (tool-call-free) agent message.
        return None

    @staticmethod
    def _print_token_chunk(chunk) -> None:
        """Print one streamed model token chunk, skipping tool-call chunks.

        "messages" stream mode may yield either a ``(message, metadata)``
        tuple or the message object directly; both shapes are handled.
        """
        message = chunk[0] if isinstance(chunk, tuple) and len(chunk) > 0 else chunk
        if isinstance(message, (AIMessageChunk, AIMessage)):
            # Only print chunks that have actual content (skip tool call chunks).
            if getattr(message, "content", None) and not getattr(message, "tool_calls", None):
                cprint(message.content, color="light_grey", attrs=["dark"], end="", flush=True)

    @staticmethod
    def _handle_update(chunk):
        """Report tool calls / tool outputs found in an "updates" chunk.

        Returns:
            The final answer text when this update carries a tool-call-free
            agent message with content; otherwise ``None``.
        """
        if not isinstance(chunk, dict):
            return None
        if 'agent' in chunk:
            agent_update = chunk['agent']
            for message in agent_update.get('messages') or []:
                if getattr(message, 'tool_calls', None):
                    for tool_call in message.tool_calls:
                        cprint(f"\n🔧 Using tool: {tool_call['name']} with args: {tool_call['args']}\n", color="yellow")
                # Final answer messages carry content but no tool calls.
                elif getattr(message, 'content', None):
                    cprint(f"\n{message.content}\n", color="black", on_color="on_white", attrs=["bold"])
                    return message.content
        elif 'tools' in chunk:
            tools_update = chunk['tools']
            for message in tools_update.get('messages') or []:
                if getattr(message, 'content', None):
                    cprint(f"\n📤 Tool output:\n{message.content}\n", color="green")
        return None

    async def answer_question_async(self, question: str, file_path: str | None = None) -> str:
        """Answer a question using the LangGraph agent asynchronously.

        Args:
            question: The question to answer.
            file_path: Optional path to a file associated with this question.

        Returns:
            The final answer as a string ("" if the graph produced none).
        """
        from langchain_core.runnables import RunnableConfig
        from typing import cast

        # Enhance question with file context if available.
        # NOTE(review): this variant embeds the full path while the sync
        # variant exposes only the basename — confirm which is intended.
        if file_path and os.path.exists(file_path):
            question = f"{question}\n\nNote: There is an associated file at: {file_path}\nYou can use the file management tools to read and analyze this file."
        messages = [HumanMessage(content=question)]
        # Use LangGraph's built-in ainvoke method.
        result = await self.graph.ainvoke({"messages": messages}, config=cast(RunnableConfig, self.config))  # type: ignore
        # Extract the content from the last message.
        if "messages" in result and result["messages"]:
            last_message = result["messages"][-1]
            if hasattr(last_message, 'content'):
                return last_message.content or ""
        return ""

    def _build_agent(self, name: str):
        """Build and compile the agent <-> tools LangGraph state machine.

        Args:
            name: Agent name. Currently unused in the build itself — kept
                for interface stability (callers pass ``self.name``).
        """
        class GraphConfig(TypedDict):
            name: str
            thread_id: int

        graph = StateGraph(state_schema=MessagesState, context_schema=GraphConfig)
        # Nodes: the model step and the tool-execution step.
        graph.add_node("agent", call_model)
        graph.add_node("tools", tool_node)
        # Edges: start at the agent; route to tools when the last AI message
        # requests a tool call (tools_condition), then loop back to the agent.
        graph.add_edge(START, "agent")
        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent")
        # NOTE(review): compiled without a checkpointer even though
        # InMemorySaver is imported and the config carries a thread_id —
        # confirm whether conversation persistence was intended.
        return graph.compile()
# Manual smoke test: run one question through the agent with tracing enabled.
if __name__ == "__main__":
    question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
    try:
        from config import start_phoenix

        start_phoenix()
        bot = OracleBot()
        bot.answer_question(question, None)
    except Exception as e:
        # Top-level boundary: report the failure instead of crashing with a
        # traceback. (The redundant local `import os` was removed — `os` is
        # already imported at module level.)
        print(f"Error running agent: {e}")