# HR-Assistant: src/backend/core/base_agent.py
"""
Base class for LangGraph-based agents that serves an interface for building,
compiling, and executing custom agent graphs.
Alternatively you can also use `create_agent`, which implements a ReAct agent by default.
It may be of particular interest, since it enables MiddleWare like `context summarization`
and `human in the loop`, `dynamic model selection` out of the box.
links:
- create_agent: https://docs.langchain.com/oss/python/langchain/agents
- middleware: https://docs.langchain.com/oss/python/langchain/middleware
"""
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Optional

from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langchain_openrouter import ChatOpenRouter
from langgraph.graph import StateGraph

from src.backend.core.configs.agent import AgentConfig
class BaseAgent(ABC):
    """Abstract base class for all LangGraph-based agents.

    Handles model initialization, optional tool binding, lazy graph
    compilation/caching, and execution helpers. Subclasses only need to
    implement :meth:`build_graph`.
    """

    def __init__(self, config: AgentConfig) -> None:
        """Initialize the agent with its configuration.

        Args:
            config: Agent configuration providing the name, description,
                model settings, and the default tool list.
        """
        self.config = config
        self.name = config.name
        self.description = config.description
        self._graph: Optional[StateGraph] = None
        # Lazily populated by compile(); holds the executable graph.
        self._compiled_graph = None
        # Initialize model (tools are bound optionally via bind_tools)
        self.llm = self._init_model()

    # ~~~ ABSTRACT METHODS ~~~

    @abstractmethod
    def build_graph(self) -> StateGraph:
        """Build and return the agent's LangGraph structure."""

    # ~~~ MODEL INITIALIZATION ~~~

    def _init_model(self):
        """Initialize the LLM engine based on the configured provider.

        Returns:
            A ``ChatOpenAI`` or ``ChatOpenRouter`` instance, depending on
            ``config.model_config.provider``.

        Raises:
            NotImplementedError: If the provider is not supported.
        """
        model_cfg = self.config.model_config
        provider = model_cfg.provider.lower()
        # Both providers accept identical constructor keywords, so dispatch
        # on a class lookup instead of duplicating the constructor call.
        provider_classes = {
            "openai": ChatOpenAI,
            "openrouter": ChatOpenRouter,
        }
        model_cls = provider_classes.get(provider)
        if model_cls is None:
            raise NotImplementedError(
                f"Provider '{provider}' not supported yet."
            )
        return model_cls(
            model=model_cfg.model_name,
            api_key=model_cfg.get_api_key(),
            temperature=model_cfg.temperature,
            max_tokens=model_cfg.max_tokens,
            base_url=model_cfg.api_base,
        )

    def bind_tools(
        self,
        tools: Optional[List[BaseTool]] = None,
        strict: bool = True
    ):
        """Optionally bind tools to the initialized model.

        Args:
            tools: Tools to bind. Falls back to ``self.config.tools`` when
                ``None`` (or an empty list) is given.
            strict: Enforce schema validation for the bound tools.

        Returns:
            The model, tool-bound when any tools were available. Note that
            ``bind_tools`` on the underlying model returns a Runnable
            wrapper, not necessarily a ``ChatOpenAI`` instance.

        Raises:
            RuntimeError: If the model has not been initialized yet.
        """
        if not hasattr(self, "llm"):
            raise RuntimeError("Model must be initialized before binding tools.")
        tools_to_bind = tools or self.config.tools
        if not tools_to_bind:
            return self.llm  # no-op: nothing to bind
        # Rebind self.llm so subsequent graph nodes use the tool-aware model.
        self.llm = self.llm.bind_tools(tools_to_bind, strict=strict)
        return self.llm

    # ~~~ GRAPH MANAGEMENT ~~~

    def compile(self, checkpointer=None, store=None):
        """Compile the agent graph for execution.

        The built ``StateGraph`` is cached; calling ``compile`` again
        recompiles it (e.g. with a different checkpointer/store).

        Args:
            checkpointer: Optional LangGraph checkpointer for persistence.
            store: Optional LangGraph store for cross-thread memory.

        Returns:
            The compiled, executable graph.
        """
        if self._graph is None:
            self._graph = self.build_graph()
        self._compiled_graph = self._graph.compile(checkpointer=checkpointer, store=store)
        return self._compiled_graph

    def get_graph(self):
        """Return the compiled graph, compiling first if needed."""
        if self._compiled_graph is None:
            self.compile()
        return self._compiled_graph

    def visualize(self, output_path: Optional[str] = None):
        """Render the compiled graph as a Mermaid PNG diagram.

        Args:
            output_path: Optional file path to write the PNG to.

        Returns:
            The result of ``draw_mermaid_png`` (PNG bytes when no path is
            given).
        """
        if self._compiled_graph is None:
            self.compile()
        return self._compiled_graph.get_graph().draw_mermaid_png(output_file_path=output_path)

    # ~~~ EXECUTION ~~~

    def invoke(
        self,
        input_data: Dict[str, object],
        config: Optional[Dict[str, object]] = None
    ) -> Dict[str, object]:
        """Execute the compiled agent synchronously and return final state.

        NOTE: if the graph has not been compiled yet, it is compiled here
        WITHOUT a checkpointer/store; call :meth:`compile` explicitly
        beforehand when persistence is required.
        """
        if self._compiled_graph is None:
            self.compile()
        return self._compiled_graph.invoke(input_data, config)

    async def ainvoke(
        self,
        input_data: Dict[str, object],
        config: Optional[Dict[str, object]] = None
    ) -> Dict[str, object]:
        """Execute the agent asynchronously and return final state.

        Compiles lazily (without checkpointer/store) like :meth:`invoke`.
        """
        if self._compiled_graph is None:
            self.compile()
        return await self._compiled_graph.ainvoke(input_data, config)

    def stream(
        self,
        input_data: Dict[str, object],
        config: Optional[Dict[str, object]] = None
    ) -> Iterator[Dict[str, object]]:
        """Stream agent execution results.

        Returns:
            An iterator over intermediate state updates (the previous
            ``Dict[str, object]`` annotation was inaccurate: ``stream``
            yields chunks rather than returning a single mapping).
        """
        if self._compiled_graph is None:
            self.compile()
        return self._compiled_graph.stream(input_data, config)

    # ~~~ UTILITIES ~~~

    def get_tools(self) -> List[BaseTool]:
        """Return a copy of the tool list this agent can use."""
        return list(self.config.tools or [])

    def get_capabilities(self) -> List[str]:
        """List of agent capabilities (override in subclasses)."""
        return []

    @property
    def metadata(self) -> Dict[str, object]:
        """Return agent metadata for discovery and routing."""
        return {
            "name": self.name,
            "description": self.description,
            "tools": [tool.name for tool in self.get_tools()],
            "capabilities": self.get_capabilities(),
        }

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name='{self.name}')"