File size: 1,958 Bytes
31616e2
 
 
6065fb1
31616e2
 
d55b613
6065fb1
 
31616e2
 
6065fb1
 
31616e2
 
 
 
 
6065fb1
31616e2
6065fb1
 
d55b613
31616e2
6065fb1
 
d55b613
31616e2
 
6065fb1
31616e2
d55b613
31616e2
 
 
 
 
6065fb1
31616e2
 
 
 
 
6065fb1
31616e2
 
 
 
d55b613
 
31616e2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from typing import TypedDict, Literal, Annotated
from langgraph.graph import StateGraph, START, END, add_messages
from langgraph.prebuilt import ToolNode
from langchain_ollama import ChatOllama
from src.tools import TOOLMAP
from src.prompts import REASON
from pathlib import Path

class AgentSchema(TypedDict):
  """State carried between the nodes of the ReAct workflow graph."""

  # Conversation history. Updates returned by nodes are merged into this
  # list through the `add_messages` reducer rather than replacing it.
  messages: Annotated[list, add_messages]
  # Number of `reason` turns completed so far in the current run.
  niter: int

class ReActAgent():
  """ReAct-style agent: an Ollama chat model alternating with a tool node.

  The compiled LangGraph workflow starts at ``reason``. While the model's
  latest message requests tool calls and the reasoning budget (``maxreason``)
  is not exhausted, control goes to ``toolnode`` and then back to
  ``reason``. Otherwise the run ends.

  Args:
    modelid: Ollama model name passed to ``ChatOllama``.
    verbose: If True, print the model's reasoning text and raw response on
      every ``reason`` turn.
    maxreason: Maximum number of ``reason`` turns per run. Once reached,
      pending tool calls are no longer executed and the run ends.
  """

  def __init__(self, modelid: str, verbose: bool = False, maxreason: int = 5):
    # One list serves both the model binding and the tool node, so the two
    # always agree on the available tools.
    tools = list(TOOLMAP.values())
    self.verbose = verbose
    self.maxreason = maxreason
    self.brain = ChatOllama(
      model=modelid, temperature=0.2, validate_model_on_init=True, reasoning=True,
    ).bind_tools(tools)

    workflow = StateGraph(AgentSchema)

    # nodes #
    workflow.add_node("reason", self.reason)
    workflow.add_node("toolnode", ToolNode(tools))

    # edges #
    workflow.add_edge(START, "reason")
    # Targets are inferred from next_step's Literal return annotation.
    workflow.add_conditional_edges("reason", self.next_step)
    workflow.add_edge("toolnode", "reason")

    # compile #
    self.workflow = workflow.compile()

    # Save a picture of the graph to assets/agent_graph.png. The PNG bytes
    # are produced before the file is opened, so if rendering fails any
    # previous image is left intact instead of being deleted or truncated.
    imagepath = Path("assets", "agent_graph.png")
    imagepath.parent.mkdir(exist_ok=True)
    imagepath.write_bytes(self.workflow.get_graph().draw_mermaid_png())

  def __call__(self, query: str):
    """Run the agent on ``query`` and return the final workflow state."""
    initial_state = {
      "messages": [
        {'role': 'system', 'content': REASON},
        {'role': 'user', 'content': query},
      ],
      "niter": 0,
    }
    return self.workflow.invoke(initial_state)

  # nodes #
  def reason(self, state: AgentSchema) -> AgentSchema:
    """Invoke the model on the conversation and count one reasoning turn.

    Returns a partial state: the model's response (appended to ``messages``
    by the reducer) and the incremented ``niter``.
    """
    response = self.brain.invoke(state.get("messages", []))
    if self.verbose:
      # Not every response carries reasoning text; avoid a KeyError then.
      reasoning = response.additional_kwargs.get('reasoning_content', '')
      print(f"Reasoning:\n{reasoning}\n\nResponse:{response}")
    # A missing/None counter is treated as zero.
    return {"messages": [response], "niter": int(state.get("niter") or 0) + 1}

  # edges #
  def next_step(self, state: AgentSchema) -> Literal["toolnode", "__end__"]:
    """Route to the tool node if tools were requested and budget remains.

    Ends the run when the last message has no tool calls, or when
    ``maxreason`` reasoning turns have already been spent. In the latter
    case the final message may still contain unexecuted tool calls.
    """
    last = state.get("messages")[-1]
    budget_left = int(state.get("niter") or 0) < self.maxreason
    if getattr(last, "tool_calls", None) and budget_left:
      return "toolnode"
    return END