File size: 3,581 Bytes
686a009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

from utils.model_loader import ModelLoader
from prompt_library.prompt import SYSTEM_PROMPT
from langgraph.graph import StateGraph, MessagesState, END, START
from langgraph.graph.message import add_messages # Is reducer function that just aggregrates the messages.
from langgraph.prebuilt import ToolNode, tools_condition
from tools.weather_info_tool import WeatherInfoTool
from tools.place_search_tool import PlaceSearchTool
from tools.expense_calculator_tool import CalculatorTool
from tools.currency_conversion_tool import CurrencyConverterTool
from typing import TypedDict, Sequence, Optional, Annotated
from langchain_core.messages import BaseMessage, HumanMessage
from datetime import datetime

class AgentState(TypedDict):
    # State schema with a single key, `messages`.
    #
    # * `Sequence[BaseMessage]`: any ordered collection of BaseMessage objects
    #   (a list or a tuple, for example).
    # * `Annotated[..., add_messages]`: attaches `add_messages` to the type as
    #   metadata. Python itself ignores this metadata; frameworks such as
    #   LangGraph can read it. `add_messages` is the reducer function that
    #   aggregates the messages.
    messages: Annotated[Sequence[BaseMessage], add_messages]

class GraphBuilder():
    """Build the agent/tool loop as a compiled LangGraph graph.

    On construction the builder loads an LLM through ``ModelLoader``, collects
    the tool lists exposed by the weather, place-search, calculator and
    currency-converter helpers, and binds those tools to the LLM.
    ``build_graph`` (or calling the instance) wires an ``"agent"`` node and a
    ``"tools"`` node together and returns the compiled graph.
    """

    def __init__(self, model_provider: str = "openai"):
        """Load the model and gather every tool the agent may call.

        Args:
            model_provider: Forwarded unchanged to ``ModelLoader``.
        """
        self.model_loader = ModelLoader(model_provider=model_provider)
        self.llm = self.model_loader.load_llm()

        self.weather_tools = WeatherInfoTool()
        self.place_search_tools = PlaceSearchTool()
        self.calculator_tools = CalculatorTool()
        self.currency_converter_tools = CurrencyConverterTool()

        # One flat list: the same tools are bound to the LLM here and executed
        # by the ToolNode created in build_graph().
        self.tools = [
            *self.weather_tools.weather_tool_list,
            *self.place_search_tools.place_search_tool_list,
            *self.calculator_tools.calculator_tool_list,
            *self.currency_converter_tools.currency_converter_tool_list,
        ]

        self.llm_with_tools = self.llm.bind_tools(tools=self.tools)

        # Populated by build_graph().
        self.graph = None

        self.system_prompt = SYSTEM_PROMPT

    def agent_function(self, state: "MessagesState"):
        """Run the tool-bound LLM on the conversation so far.

        The system prompt is placed in front of the messages held in the
        state and the combined list is passed to ``llm_with_tools.invoke``.

        Args:
            state: Graph state; only ``state["messages"]`` is read.

        Returns:
            ``{"messages": [response]}`` where ``response`` is whatever the
            LLM call returned.
        """
        # Unpacking (rather than ``list + messages``) accepts any sequence of
        # messages, not only a ``list``.
        conversation = [self.system_prompt, *state["messages"]]
        response = self.llm_with_tools.invoke(conversation)
        return {"messages": [response]}

    def should_continue(self, state: "MessagesState") -> str:
        """Choose the edge to follow after the agent node.

        Args:
            state: Graph state; only ``state["messages"]`` is read.

        Returns:
            ``"continue"`` when the last message carries a non-empty
            ``tool_calls`` value, otherwise ``"end"``. ``"end"`` is also
            returned when there are no messages, or when the last message has
            no ``tool_calls`` attribute at all.
        """
        messages = state["messages"]
        if not messages:
            return "end"
        # getattr: a message object without a ``tool_calls`` attribute is
        # treated the same as one that requested no tools.
        if getattr(messages[-1], "tool_calls", None):
            return "continue"
        return "end"

    def build_graph(self):
        """Create, compile, store and return the agent graph.

        Layout: ``START -> agent``; after ``agent`` the result of
        ``should_continue`` selects either ``tools`` or ``END``; ``tools``
        always leads back to ``agent``. The compiled graph is also kept on
        ``self.graph``. Each call builds and compiles a fresh graph.
        """
        graph_builder = StateGraph(MessagesState)
        graph_builder.add_node("agent", self.agent_function)
        graph_builder.add_node("tools", ToolNode(tools=self.tools))
        graph_builder.add_edge(START, "agent")
        graph_builder.add_conditional_edges(
            source="agent",
            path=self.should_continue,
            path_map={
                # value returned by should_continue -> destination node
                "end": END,
                "continue": "tools",
            },
        )
        graph_builder.add_edge("tools", "agent")
        self.graph = graph_builder.compile()
        return self.graph

    def __call__(self):
        """Shorthand for ``build_graph()``."""
        return self.build_graph()