File size: 6,362 Bytes
03f4295
f37e95b
 
 
 
 
7cdcb1a
692b974
7cdcb1a
 
692b974
f37e95b
7cdcb1a
 
 
 
 
 
 
f37e95b
03f4295
7cdcb1a
 
03f4295
 
 
 
7cdcb1a
03f4295
 
2665628
 
7cdcb1a
692b974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03f4295
 
692b974
 
 
 
 
 
 
f37e95b
03f4295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cdcb1a
 
 
 
f37e95b
7cdcb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03f4295
f37e95b
 
 
 
 
f439125
f37e95b
 
692b974
7cdcb1a
 
03f4295
7cdcb1a
f37e95b
692b974
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
from typing import Literal
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import tools_condition
from agent.nodes import call_model, tool_node
from langgraph.graph import MessagesState
from langchain_core.messages import AIMessage, HumanMessage, AIMessageChunk
from langgraph.checkpoint.memory import InMemorySaver
from agent.config import create_agent_config
from termcolor import colored, cprint

class OracleBot:
    def __init__(self):
        print("Initializing OracleBot")
        self.name = "OracleBot"
        self.thread_id = 1 #TODO fix
        self.config = create_agent_config(self.name, self.thread_id)
        self.graph = self._build_agent(self.name)

    def answer_question(self, question: str, file_path: str | None = None):
        """
        Answer a question using the LangGraph agent.
        
        Args:
            question: The question to answer
            file_path: Optional path to a file associated with this question
        """
        # Enhance question with file context if available
        if file_path and os.path.exists(file_path):
            question = f"{question}\n\nNote: There is an associated file named {os.path.basename(file_path)}\nYou can use the file management tools to read and analyze this file."

        messages = [HumanMessage(content=question)]

        for mode, chunk in self.graph.stream({"messages": messages}, config=self.config, stream_mode=["messages", "updates"]): # type: ignore
            if mode == "messages":
                if isinstance(chunk, tuple) and len(chunk) > 0:
                    message = chunk[0]
                    if isinstance(message, (AIMessageChunk, AIMessage)):
                        # Only print chunks that have actual content (skip tool call chunks)
                        if hasattr(message, 'content') and message.content and not (hasattr(message, 'tool_calls') and message.tool_calls):
                            cprint(message.content, color="light_grey", attrs=["dark"], end="", flush=True)
                # Handle case where chunk is directly the message
                elif isinstance(chunk, (AIMessageChunk, AIMessage)):
                    # Only print chunks that have actual content (skip tool call chunks)
                    if hasattr(chunk, 'content') and chunk.content and not (hasattr(chunk, 'tool_calls') and chunk.tool_calls):
                        cprint(chunk.content, color="light_grey", attrs=["dark"], end="", flush=True)
            elif mode == "updates":
                # Look for complete tool calls in updates
                if isinstance(chunk, dict) and 'agent' in chunk:
                    agent_update = chunk['agent']
                    if 'messages' in agent_update and agent_update['messages']:
                        for message in agent_update['messages']:
                            if hasattr(message, 'tool_calls') and message.tool_calls:
                                for tool_call in message.tool_calls:
                                    cprint(f"\n🔧 Using tool: {tool_call['name']} with args: {tool_call['args']}\n", color="yellow")
                            # Handle final answer messages (no tool calls)
                            elif hasattr(message, 'content') and message.content:
                                cprint(f"\n{message.content}\n", color="black", on_color="on_white", attrs=["bold"])
                                return message.content # Return final answer

                # Look for tool outputs in updates
                elif isinstance(chunk, dict) and 'tools' in chunk:
                    tools_update = chunk['tools']
                    if 'messages' in tools_update and tools_update['messages']:
                        for message in tools_update['messages']:
                            if hasattr(message, 'content') and message.content:
                                cprint(f"\n📤 Tool output:\n{message.content}\n", color="green")

    async def answer_question_async(self, question: str, file_path: str | None = None) -> str:
        """
        Answer a question using the LangGraph agent asynchronously.
        
        Args:
            question: The question to answer
            file_path: Optional path to a file associated with this question
            
        Returns the final answer as a string.
        """
        from langchain_core.runnables import RunnableConfig
        from typing import cast
        
        # Enhance question with file context if available
        if file_path and os.path.exists(file_path):
            question = f"{question}\n\nNote: There is an associated file at: {file_path}\nYou can use the file management tools to read and analyze this file."
        
        messages = [HumanMessage(content=question)]

        # Use LangGraph's built-in ainvoke method
        result = await self.graph.ainvoke({"messages": messages}, config=cast(RunnableConfig, self.config)) # type: ignore
        
        # Extract the content from the last message
        if "messages" in result and result["messages"]:
            last_message = result["messages"][-1]
            if hasattr(last_message, 'content'):
                return last_message.content or ""
        
        return ""

    def _build_agent(self, name: str):
        """
        Get our LangGraph agent with the given model and tools.
        """
    
        class GraphConfig(TypedDict):
            name: str;
            thread_id: int;

        graph = StateGraph(state_schema=MessagesState, context_schema=GraphConfig)

        # Add nodes
        graph.add_node("agent", call_model)
        graph.add_node("tools", tool_node)

        # Add edges
        graph.add_edge(START, "agent")
        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent")

        return graph.compile()

# test
if __name__ == "__main__":
    # Manual smoke test: start tracing, build the bot, and answer one sample question.
    import sys
    import traceback

    question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."

    try:
        # NOTE(review): imported as top-level `config`, while the module header uses
        # `agent.config`; which one resolves depends on how this script is launched.
        from config import start_phoenix
        start_phoenix()
        bot = OracleBot()
        bot.answer_question(question, None)

    except Exception as e:
        # Keep the short message, but also show the traceback and exit non-zero
        # so failures are not silently reported as success.
        print(f"Error running agent: {e}")
        traceback.print_exc()
        sys.exit(1)