Spaces:
Building
Building
| from typing import TypedDict, Literal, Annotated | |
| from langgraph.graph import StateGraph, START, END, add_messages | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_ollama import ChatOllama | |
| from src.tools import TOOLMAP | |
| from src.prompts import REASON | |
| from pathlib import Path | |
class AgentSchema(TypedDict):
    """Graph state shared by every node of the ReAct agent.

    Keys:
        messages: Conversation history. Updates returned by nodes are merged
            into it via the ``add_messages`` reducer.
        niter: Number of ``reason`` steps taken so far for the current query.
    """

    # The add_messages annotation tells LangGraph to merge each node's returned
    # messages into the existing history instead of replacing the list.
    messages: Annotated[list, add_messages]
    niter: int
class ReActAgent:
    """ReAct-style agent: a tool-bound LLM "reason" node looping through a tool node.

    Graph shape: ``START -> reason``; after every ``reason`` step the agent either
    sends the model's tool calls to ``toolnode`` (which routes back to
    ``reason``) or finishes.

    Args:
        modelid: Ollama model name passed to ``ChatOllama``.
        verbose: If True, print the model's reasoning and response after each
            ``reason`` step.
        maxreason: Maximum number of ``reason`` steps per query. Once reached,
            the graph ends even if the latest response still requests tools.
    """

    def __init__(self, modelid: str, verbose: bool = False, maxreason: int = 5):
        tools = list(TOOLMAP.values())
        self.verbose = verbose
        self.maxreason = maxreason
        # reasoning=True: the verbose print in `reason` reads the model's
        # reasoning from additional_kwargs["reasoning_content"].
        self.brain = ChatOllama(
            model=modelid,
            temperature=0.2,
            validate_model_on_init=True,
            reasoning=True,
        ).bind_tools(tools)

        workflow = StateGraph(AgentSchema)
        # nodes #
        workflow.add_node("reason", self.reason)
        workflow.add_node("toolnode", ToolNode(tools))
        # edges #
        workflow.add_edge(START, "reason")
        # Explicit path map: next_step only ever returns one of these, so the
        # compiled graph (and its diagram) gets exactly these two branches.
        workflow.add_conditional_edges("reason", self.next_step, ["toolnode", END])
        workflow.add_edge("toolnode", "reason")
        # compile #
        self.workflow = workflow.compile()
        self._save_graph_image(Path("assets", "agent_graph.png"))

    def _save_graph_image(self, imagepath: Path) -> None:
        """Render the compiled graph as a Mermaid PNG and write it to ``imagepath``."""
        # Render before touching the file so a rendering failure cannot leave
        # a truncated/empty PNG behind.
        # NOTE(review): LangGraph's default Mermaid renderer may call a remote
        # service; confirm that is acceptable at construction time.
        png = self.workflow.get_graph().draw_mermaid_png()
        imagepath.parent.mkdir(parents=True, exist_ok=True)
        imagepath.write_bytes(png)  # overwrites any previous image

    def __call__(self, query: str):
        """Run the agent on one user query and return the final graph state."""
        initial_state = {
            "messages": [
                {'role': 'system', 'content': REASON},
                {'role': 'user', 'content': query},
            ],
            "niter": 0,
        }
        # Each reason -> toolnode round trip costs two graph steps; make sure
        # the step budget never cuts the run short before maxreason does.
        step_budget = max(25, 2 * self.maxreason + 1)
        return self.workflow.invoke(initial_state, config={"recursion_limit": step_budget})

    # nodes #
    def reason(self, state: AgentSchema) -> AgentSchema:
        """Call the tool-bound model on the conversation so far.

        Returns a state update holding the model's reply (merged into
        ``messages`` by the schema's reducer) and the incremented ``niter``.
        """
        response = self.brain.invoke(state.get("messages", []))
        if self.verbose:
            # reasoning_content may be missing (e.g. no thinking was emitted),
            # so look it up defensively instead of indexing.
            thoughts = response.additional_kwargs.get("reasoning_content", "")
            print(f"Reasoning:\n{thoughts}\n\nResponse:{response}")
        return {"messages": [response], "niter": int(state.get("niter", 0)) + 1}

    # edges #
    def next_step(self, state: AgentSchema) -> Literal["toolnode", "__end__"]:
        """Route after ``reason``: execute requested tools, or finish.

        Finishes (returns ``END``, i.e. ``"__end__"``) when ``maxreason``
        reasoning steps have been taken or the latest message has no tool calls.
        """
        if state.get("niter", 0) >= self.maxreason:
            return END
        last_message = state.get("messages")[-1]
        if getattr(last_message, "tool_calls", None):
            return "toolnode"
        return END