Spaces:
Runtime error
Runtime error
abtsousa
commited on
Commit
·
692b974
1
Parent(s):
2b9cce2
Enhance OracleBot: improve message handling and output formatting in answer_question method
Browse files- agent/agent.py +38 -6
agent/agent.py
CHANGED
|
@@ -4,9 +4,10 @@ from langgraph.graph import StateGraph, START, END
|
|
| 4 |
from langgraph.prebuilt import tools_condition
|
| 5 |
from agent.nodes import call_model, tool_node
|
| 6 |
from langgraph.graph import MessagesState
|
| 7 |
-
from langchain_core.messages import AIMessage, HumanMessage
|
| 8 |
from langgraph.checkpoint.memory import InMemorySaver
|
| 9 |
from agent.config import create_agent_config
|
|
|
|
| 10 |
|
| 11 |
class OracleBot:
|
| 12 |
def __init__(self):
|
|
@@ -21,8 +22,40 @@ class OracleBot:
|
|
| 21 |
Answer a question using the LangGraph agent.
|
| 22 |
"""
|
| 23 |
messages = [HumanMessage(content=question)]
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def _build_agent(self, name: str):
|
| 28 |
"""
|
|
@@ -56,11 +89,10 @@ if __name__ == "__main__":
|
|
| 56 |
question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
|
| 57 |
|
| 58 |
try:
|
| 59 |
-
from
|
| 60 |
start_phoenix()
|
| 61 |
bot = OracleBot()
|
| 62 |
bot.answer_question(question)
|
| 63 |
|
| 64 |
except Exception as e:
|
| 65 |
-
print(f"Error running agent: {e}")
|
| 66 |
-
print("Make sure your API key is correct and the service is accessible.")
|
|
|
|
| 4 |
from langgraph.prebuilt import tools_condition
|
| 5 |
from agent.nodes import call_model, tool_node
|
| 6 |
from langgraph.graph import MessagesState
|
| 7 |
+
from langchain_core.messages import AIMessage, HumanMessage, AIMessageChunk
|
| 8 |
from langgraph.checkpoint.memory import InMemorySaver
|
| 9 |
from agent.config import create_agent_config
|
| 10 |
+
from termcolor import colored, cprint
|
| 11 |
|
| 12 |
class OracleBot:
|
| 13 |
def __init__(self):
|
|
|
|
| 22 |
Answer a question using the LangGraph agent.
|
| 23 |
"""
|
| 24 |
messages = [HumanMessage(content=question)]
|
| 25 |
+
|
| 26 |
+
for mode, chunk in self.graph.stream({"messages": messages}, config=self.config, stream_mode=["messages", "updates"]): # type: ignore
|
| 27 |
+
if mode == "messages":
|
| 28 |
+
if isinstance(chunk, tuple) and len(chunk) > 0:
|
| 29 |
+
message = chunk[0]
|
| 30 |
+
if isinstance(message, (AIMessageChunk, AIMessage)):
|
| 31 |
+
# Only print chunks that have actual content (skip tool call chunks)
|
| 32 |
+
if hasattr(message, 'content') and message.content and not (hasattr(message, 'tool_calls') and message.tool_calls):
|
| 33 |
+
cprint(message.content, color="light_grey", attrs=["dark"], end="", flush=True)
|
| 34 |
+
# Handle case where chunk is directly the message
|
| 35 |
+
elif isinstance(chunk, (AIMessageChunk, AIMessage)):
|
| 36 |
+
# Only print chunks that have actual content (skip tool call chunks)
|
| 37 |
+
if hasattr(chunk, 'content') and chunk.content and not (hasattr(chunk, 'tool_calls') and chunk.tool_calls):
|
| 38 |
+
cprint(chunk.content, color="light_grey", attrs=["dark"], end="", flush=True)
|
| 39 |
+
elif mode == "updates":
|
| 40 |
+
# Look for complete tool calls in updates
|
| 41 |
+
if isinstance(chunk, dict) and 'agent' in chunk:
|
| 42 |
+
agent_update = chunk['agent']
|
| 43 |
+
if 'messages' in agent_update and agent_update['messages']:
|
| 44 |
+
for message in agent_update['messages']:
|
| 45 |
+
if hasattr(message, 'tool_calls') and message.tool_calls:
|
| 46 |
+
for tool_call in message.tool_calls:
|
| 47 |
+
cprint(f"\n🔧 Using tool: {tool_call['name']} with args: {tool_call['args']}\n", color="yellow")
|
| 48 |
+
# Handle final answer messages (no tool calls)
|
| 49 |
+
elif hasattr(message, 'content') and message.content:
|
| 50 |
+
cprint(f"\n{message.content}\n", color="black", on_color="on_white", attrs=["bold"])
|
| 51 |
+
|
| 52 |
+
# Look for tool outputs in updates
|
| 53 |
+
elif isinstance(chunk, dict) and 'tools' in chunk:
|
| 54 |
+
tools_update = chunk['tools']
|
| 55 |
+
if 'messages' in tools_update and tools_update['messages']:
|
| 56 |
+
for message in tools_update['messages']:
|
| 57 |
+
if hasattr(message, 'content') and message.content:
|
| 58 |
+
cprint(f"\n📤 Tool output:\n{message.content}\n", color="green")
|
| 59 |
|
| 60 |
def _build_agent(self, name: str):
|
| 61 |
"""
|
|
|
|
| 89 |
question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
|
| 90 |
|
| 91 |
try:
|
| 92 |
+
from config import start_phoenix
|
| 93 |
start_phoenix()
|
| 94 |
bot = OracleBot()
|
| 95 |
bot.answer_question(question)
|
| 96 |
|
| 97 |
except Exception as e:
|
| 98 |
+
print(f"Error running agent: {e}")
|
|
|