File size: 5,922 Bytes
3370983 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
"""
Base class for LangGraph-based agents. It provides an interface for building,
compiling, and executing custom agent graphs.
Alternatively, you can use `create_agent`, which implements a ReAct agent by default.
It may be of particular interest because it supports middleware such as
`context summarization`, `human in the loop`, and `dynamic model selection` out of the box.
links:
- create_agent: https://docs.langchain.com/oss/python/langchain/agents
- middleware: https://docs.langchain.com/oss/python/langchain/middleware
"""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional

from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langchain_openrouter import ChatOpenRouter
from langgraph.graph import StateGraph

from src.backend.core.configs.agent import AgentConfig

if TYPE_CHECKING:
    # Needed for annotations only; never evaluated at runtime.
    from langchain_core.language_models import BaseChatModel
    from langchain_core.runnables import Runnable
    from langgraph.graph.state import CompiledStateGraph
class BaseAgent(ABC):
    """Abstract base class for all LangGraph-based agents.

    Subclasses implement `build_graph`. This class initializes the chat model
    from the agent configuration and provides helpers to compile, visualize,
    and execute the resulting graph.
    """

    def __init__(self, config: AgentConfig) -> None:
        """Initialize the agent with configuration.

        Args:
            config: Agent settings. This class reads `name`, `description`,
                `tools`, and `model_config` from it.
        """
        self.config = config
        self.name = config.name
        self.description = config.description
        self._graph: Optional[StateGraph] = None
        self._compiled_graph: Optional[CompiledStateGraph] = None
        # Initialize model (tools are bound optionally via bind_tools)
        self.llm = self._init_model()

    # ~~~ ABSTRACT METHODS ~~~

    @abstractmethod
    def build_graph(self) -> StateGraph:
        """Build the agent's LangGraph structure.

        Returns:
            The uncompiled graph; `compile` calls this once and caches it.
        """

    # ~~~ MODEL INITIALIZATION ~~~

    def _init_model(self) -> BaseChatModel:
        """Initialize LLM engine based on model provider.

        Returns:
            A `ChatOpenAI` or `ChatOpenRouter` instance configured from
            `self.config.model_config`.

        Raises:
            NotImplementedError: If the provider is neither "openai" nor
                "openrouter" (compared case-insensitively).
        """
        # NOTE(review): if AgentConfig is a pydantic v2 model, `model_config`
        # is a reserved attribute name there -- confirm this resolves to the
        # intended settings object.
        model_cfg = self.config.model_config
        provider = model_cfg.provider.lower()

        if provider == "openai":
            model_cls = ChatOpenAI
        elif provider == "openrouter":
            model_cls = ChatOpenRouter
        else:
            raise NotImplementedError(
                f"Provider '{provider}' not supported yet."
            )

        # Both supported providers take the same constructor arguments.
        return model_cls(
            model=model_cfg.model_name,
            api_key=model_cfg.get_api_key(),
            temperature=model_cfg.temperature,
            max_tokens=model_cfg.max_tokens,
            base_url=model_cfg.api_base,
        )

    def bind_tools(
        self,
        tools: Optional[List[BaseTool]] = None,
        strict: bool = True,
    ) -> Runnable:
        """Optionally bind tools to the initialized model.

        Args:
            tools: List of tools to bind. Defaults to `self.config.tools` if
                not provided. An explicitly passed empty list binds nothing.
            strict: Forwarded to the model's `bind_tools` to enforce schema
                validation for tools.

        Returns:
            The object stored on `self.llm`: the tool-bound model when there
            was something to bind, otherwise the unchanged model.

        Raises:
            RuntimeError: If `self.llm` has not been set.
        """
        if not hasattr(self, "llm"):
            raise RuntimeError("Model must be initialized before binding tools.")

        # Only fall back to the configured tools when the caller passed
        # nothing; `tools or ...` would also override an explicit empty list.
        tools_to_bind = self.config.tools if tools is None else tools
        if not tools_to_bind:
            return self.llm  # no-op

        self.llm = self.llm.bind_tools(tools_to_bind, strict=strict)
        return self.llm

    # ~~~ GRAPH MANAGEMENT ~~~

    def compile(self, checkpointer: Any = None, store: Any = None) -> CompiledStateGraph:
        """Compile the agent graph for execution.

        The graph is built via `build_graph` on the first call and cached.
        Every call recompiles it with the given arguments and replaces the
        cached compiled graph.

        Args:
            checkpointer: Forwarded to the graph's `compile`.
            store: Forwarded to the graph's `compile`.

        Returns:
            The freshly compiled graph.
        """
        if self._graph is None:
            self._graph = self.build_graph()
        self._compiled_graph = self._graph.compile(checkpointer=checkpointer, store=store)
        return self._compiled_graph

    def _ensure_compiled(self) -> CompiledStateGraph:
        """Return the compiled graph, compiling with defaults on first use."""
        if self._compiled_graph is None:
            self.compile()
        return self._compiled_graph

    def get_graph(self) -> CompiledStateGraph:
        """Return compiled graph (compile if needed).
        """
        return self._ensure_compiled()

    def visualize(self, output_path: Optional[str] = None) -> bytes:
        """Render the graph as a Mermaid diagram.

        Args:
            output_path: Passed to `draw_mermaid_png` as `output_file_path`.

        Returns:
            Whatever `draw_mermaid_png` returns (PNG bytes per the LangChain
            documentation).
        """
        return self._ensure_compiled().get_graph().draw_mermaid_png(output_file_path=output_path)

    # ~~~ EXECUTION ~~~

    def invoke(
        self,
        input_data: Dict[str, object],
        config: Optional[Dict[str, object]] = None,
    ) -> Dict[str, object]:
        """Execute the compiled agent.

        Args:
            input_data: Input passed to the compiled graph.
            config: Optional run configuration passed to the compiled graph.

        Returns:
            The result of the compiled graph's `invoke`.
        """
        return self._ensure_compiled().invoke(input_data, config)

    async def ainvoke(
        self,
        input_data: Dict[str, object],
        config: Optional[Dict[str, object]] = None,
    ) -> Dict[str, object]:
        """Execute the agent asynchronously.

        Args:
            input_data: Input passed to the compiled graph.
            config: Optional run configuration passed to the compiled graph.

        Returns:
            The awaited result of the compiled graph's `ainvoke`.
        """
        return await self._ensure_compiled().ainvoke(input_data, config)

    def stream(
        self,
        input_data: Dict[str, object],
        config: Optional[Dict[str, object]] = None,
    ) -> Iterator[Any]:
        """Stream agent execution results.

        Args:
            input_data: Input passed to the compiled graph.
            config: Optional run configuration passed to the compiled graph.

        Returns:
            The iterator produced by the compiled graph's `stream`.
        """
        return self._ensure_compiled().stream(input_data, config)

    # ~~~ UTILITIES ~~~

    def get_tools(self) -> List[BaseTool]:
        """Return the tools this agent can use.

        Returns:
            A new list copied from `self.config.tools`; empty when no tools
            are configured.
        """
        return list(self.config.tools or [])

    def get_capabilities(self) -> List[str]:
        """List of agent capabilities (override in subclasses).
        """
        return []

    @property
    def metadata(self) -> Dict[str, object]:
        """Return agent metadata for discovery and routing.

        Returns:
            A dict with the keys "name", "description", "tools" (tool names),
            and "capabilities".
        """
        return {
            "name": self.name,
            "description": self.description,
            "tools": [tool.name for tool in self.get_tools()],
            "capabilities": self.get_capabilities(),
        }

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name='{self.name}')"
|