|
|
""" |
|
|
Base class for LangGraph-based agents that serves an interface for building, |
|
|
compiling, and executing custom agent graphs. |
|
|
|
|
|
Alternatively you can also use `create_agent`, which implements a ReAct agent by default. |
|
|
It may be of particular interest, since it enables MiddleWare like `context summarization` |
|
|
and `human in the loop`, `dynamic model selection` out of the box. |
|
|
links: |
|
|
- create_agent: https://docs.langchain.com/oss/python/langchain/agents |
|
|
- middleware: https://docs.langchain.com/oss/python/langchain/middleware |
|
|
""" |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
from langgraph.graph import StateGraph |
|
|
from langchain_core.tools import BaseTool |
|
|
from langchain_openai import ChatOpenAI |
|
|
from langchain_openrouter import ChatOpenRouter |
|
|
|
|
|
from src.backend.core.configs.agent import AgentConfig |
|
|
|
|
|
|
|
|
class BaseAgent(ABC): |
|
|
"""Abstract base class for all LangGraph-based agents. |
|
|
""" |
|
|
|
|
|
def __init__(self, config: AgentConfig) -> None: |
|
|
"""Initialize the agent with configuration. |
|
|
""" |
|
|
self.config = config |
|
|
self.name = config.name |
|
|
self.description = config.description |
|
|
self._graph: Optional[StateGraph] = None |
|
|
self._compiled_graph = None |
|
|
|
|
|
|
|
|
self.llm = self._init_model() |
|
|
|
|
|
|
|
|
@abstractmethod |
|
|
def build_graph(self) -> StateGraph: |
|
|
"""Build the agent's LangGraph structure. |
|
|
""" |
|
|
pass |
|
|
|
|
|
|
|
|
def _init_model(self) -> ChatOpenAI: |
|
|
"""Initialize LLM engine based on model provider. |
|
|
""" |
|
|
model_cfg = self.config.model_config |
|
|
provider = model_cfg.provider.lower() |
|
|
|
|
|
if provider == "openai": |
|
|
return ChatOpenAI( |
|
|
model=model_cfg.model_name, |
|
|
api_key=model_cfg.get_api_key(), |
|
|
temperature=model_cfg.temperature, |
|
|
max_tokens=model_cfg.max_tokens, |
|
|
base_url=model_cfg.api_base, |
|
|
) |
|
|
elif provider == "openrouter": |
|
|
return ChatOpenRouter( |
|
|
model=model_cfg.model_name, |
|
|
api_key=model_cfg.get_api_key(), |
|
|
temperature=model_cfg.temperature, |
|
|
max_tokens=model_cfg.max_tokens, |
|
|
base_url=model_cfg.api_base, |
|
|
) |
|
|
else: |
|
|
raise NotImplementedError( |
|
|
f"Provider '{provider}' not supported yet." |
|
|
) |
|
|
|
|
|
|
|
|
def bind_tools( |
|
|
self, |
|
|
tools: Optional[List[BaseTool]] = None, |
|
|
strict: bool = True |
|
|
) -> ChatOpenAI: |
|
|
""" |
|
|
Optionally bind tools to the initialized model. |
|
|
|
|
|
Args: |
|
|
tools: List of tools to bind. Defaults to `self.config.tools` if not provided. |
|
|
strict: Enforce schema validation for tools. |
|
|
""" |
|
|
if not hasattr(self, "llm"): |
|
|
raise RuntimeError("Model must be initialized before binding tools.") |
|
|
|
|
|
tools_to_bind = tools or self.config.tools |
|
|
if not tools_to_bind: |
|
|
return self.llm |
|
|
|
|
|
self.llm = self.llm.bind_tools(tools_to_bind, strict=strict) |
|
|
return self.llm |
|
|
|
|
|
|
|
|
|
|
|
def compile(self, checkpointer=None, store=None) -> StateGraph: |
|
|
"""Compile the agent graph for execution. |
|
|
""" |
|
|
if self._graph is None: |
|
|
self._graph = self.build_graph() |
|
|
|
|
|
self._compiled_graph = self._graph.compile(checkpointer=checkpointer, store=store) |
|
|
return self._compiled_graph |
|
|
|
|
|
|
|
|
def get_graph(self) -> StateGraph: |
|
|
"""Return compiled graph (compile if needed). |
|
|
""" |
|
|
if self._compiled_graph is None: |
|
|
self.compile() |
|
|
return self._compiled_graph |
|
|
|
|
|
|
|
|
def visualize(self, output_path: Optional[str] = None): |
|
|
"""Render the graph as a Mermaid diagram. |
|
|
""" |
|
|
if self._compiled_graph is None: |
|
|
self.compile() |
|
|
return self._compiled_graph.get_graph().draw_mermaid_png(output_file_path=output_path) |
|
|
|
|
|
|
|
|
def invoke( |
|
|
self, |
|
|
input_data: Dict[str, object], |
|
|
config: Optional[Dict[str, object]] = None |
|
|
) -> Dict[str, object]: |
|
|
"""Execute the compiled agent. |
|
|
""" |
|
|
if self._compiled_graph is None: |
|
|
self.compile() |
|
|
return self._compiled_graph.invoke(input_data, config) |
|
|
|
|
|
|
|
|
async def ainvoke( |
|
|
self, |
|
|
input_data: Dict[str, object], |
|
|
config: Optional[Dict[str, object]] = None |
|
|
) -> Dict[str, object]: |
|
|
"""Execute the agent asynchronously. |
|
|
""" |
|
|
if self._compiled_graph is None: |
|
|
self.compile() |
|
|
return await self._compiled_graph.ainvoke(input_data, config) |
|
|
|
|
|
|
|
|
def stream( |
|
|
self, |
|
|
input_data: Dict[str, object], |
|
|
config: Optional[Dict[str, object]] = None |
|
|
) -> Dict[str, object]: |
|
|
"""Stream agent execution results. |
|
|
""" |
|
|
if self._compiled_graph is None: |
|
|
self.compile() |
|
|
return self._compiled_graph.stream(input_data, config) |
|
|
|
|
|
|
|
|
def get_tools(self) -> List[BaseTool]: |
|
|
"""Return the tools this agent can use. |
|
|
""" |
|
|
return list(self.config.tools or []) |
|
|
|
|
|
|
|
|
def get_capabilities(self) -> List[str]: |
|
|
"""List of agent capabilities (override in subclasses). |
|
|
""" |
|
|
return [] |
|
|
|
|
|
|
|
|
@property |
|
|
def metadata(self) -> Dict[str, object]: |
|
|
"""Return agent metadata for discovery and routing. |
|
|
""" |
|
|
return { |
|
|
"name": self.name, |
|
|
"description": self.description, |
|
|
"tools": [tool.name for tool in self.get_tools()], |
|
|
"capabilities": self.get_capabilities(), |
|
|
} |
|
|
|
|
|
def __repr__(self) -> str: |
|
|
return f"{self.__class__.__name__}(name='{self.name}')" |
|
|
|