File size: 4,029 Bytes
62ed5af
6bae69e
 
 
 
 
 
62ed5af
 
 
6bae69e
 
 
da66358
 
62ed5af
 
6bae69e
 
 
 
 
62ed5af
6bae69e
 
 
 
 
 
da66358
 
6bae69e
62ed5af
6bae69e
 
 
 
62ed5af
 
6bae69e
 
4ad672b
6bae69e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ad672b
 
6bae69e
 
 
4ad672b
 
 
 
 
 
6bae69e
 
 
 
 
 
 
 
 
 
 
62ed5af
6bae69e
62ed5af
6bae69e
 
62ed5af
6bae69e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
from typing import Any, List, TypedDict
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_openai import ChatOpenAI
from tools import (
    describe_image_tool,
    parse_excel_tool,
    webpage_extraction_tool,
    brave_web_search,
    python_code_interpreter_tool,
    audio_file_transcriber,
    get_youtube_transcript
)

class AgentState(TypedDict):
    messages: List[Any]
    question: str
    file_path: str
    final_answer: str

tools = [
    describe_image_tool,
    parse_excel_tool,
    webpage_extraction_tool,
    brave_web_search,
    python_code_interpreter_tool,
    audio_file_transcriber,
    get_youtube_transcript
]

rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,  # <-- Can only make a request once every 10 seconds!!
    check_every_n_seconds=0.1,  # Wake up every 100 ms to check whether allowed to make a request,
    max_bucket_size=10,  # Controls the maximum burst size.
)

class LangGraphAgent:
    def __init__(self, model_name: str = "gpt-4o",):
        self.llm = ChatOpenAI(model=model_name, max_tokens=2000, temperature=0, rate_limiter=rate_limiter)
        self.llm_with_tools = self.llm.bind_tools(tools)
        self.tool_node = ToolNode(tools)
        self.graph = self.create_graph().compile()

    def create_graph(self) -> StateGraph:
        """Creates a state graph for the agent's workflow."""
        # Define the tools and their respective states
        graph = StateGraph(AgentState)
        graph.add_node("agent", self.agent_node)
        graph.add_node("tools", self.tool_node)

        graph.set_entry_point("agent")

        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent") 
         
        return graph

    def agent_node(self, state: AgentState):
        """Creates a node for the agent that uses the model to respond to user queries."""
        messages = state['messages']

        if len(messages) == 1:
            system_prompt = ("You are a helpful assistant that can answer questions using various tools. "
            "You must answer the given question using as few words as possible, or the given format, if any."
            "If the answer is a number, you must return the number only, do not include symbols or commas."
            "If you need to search the web for information and aren't given a URL, always use a search tool before using a wepbage extraction tool so you always have a legit website."
            "If given a Python file, execute it with the code interpreter tool (riza_exec_python)")
            messages = [SystemMessage(system_prompt)] + messages
        
        MAX_HISTORY = 3  # tune as needed

        # Keep system + last N messages
        if len(messages) > MAX_HISTORY:
            messages = [m for m in messages if isinstance(m, SystemMessage)] + messages[-MAX_HISTORY:]

        while True:
            response = self.llm_with_tools.invoke(messages)
            messages.append(response)

            if not response.tool_calls:
                break

            tool_outputs = []
            for call in response.tool_calls:
                tool_output = self.tool_node.invoke({"messages": [response], "tool_call": call})
                tool_outputs.extend(tool_output["messages"])

            messages.extend(tool_outputs)

        state["messages"] = messages
        return {"messages": messages}

    def run(self, question: str) -> str:
        state = AgentState(messages=[HumanMessage(content=question)], question=question, final_answer=None)
        result = self.graph.invoke(state)
        
        final_message = result["messages"][-1]
        if hasattr(final_message, 'content'):
            result['final_answer'] = final_message.content
        else:
            result['final_answer'] = str(final_message)
        return result['final_answer']