# Source: TrackMate-AI — core/agents/SkeletonGraphAgent.py
# (GitHub listing residue removed: author "Abhishek", commit 2a902a6 "Adding the application")
"""
Module Name: SkeletonGraphAgent
Description: This module contains a LangGraph-based agent class that provides a flexible
agent framework in which all configuration is dynamically loaded from user input. This
includes system prompts, rules, input variables, LLM configuration, output structure, and tools.
Author: Abhishek Singh
Last Modified: 2025-05-29
"""
# Standard library imports
import ast
import asyncio
import importlib
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Union, TypedDict, Annotated, Sequence
from urllib.parse import urlparse

# Third-party imports
import aiohttp
from tqdm import tqdm
from pydantic import BaseModel, Field, create_model
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage
from langchain_community.callbacks import get_openai_callback
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

# Local application imports
from core.llms.base_llm import get_llm
# Configure logging
logger = logging.getLogger(__name__)
class AgentState(TypedDict):
    """
    State carried through every node of the langgraph state graph.

    The ``add_messages`` reducer on ``messages`` merges/appends new messages
    returned by a node rather than replacing the list wholesale.
    """
    messages: Annotated[Sequence[BaseMessage], add_messages]  # Conversation so far (system/human/AI/tool messages)
    input_variables: Dict[str, Any]  # User-supplied variables injected into the conversation by the agent node
    final_response: Optional[Dict[str, Any]]  # Structured output set by the respond node (None until produced)
class SkeletonGraphAgent:
    """
    A flexible, fully configuration-driven agent built with langgraph.

    All behavior comes from user-supplied metadata rather than code changes:
    - System prompt, plus rules appended to it
    - Input variables
    - LLM configuration (model, temperature, max tokens, API key)
    - Output structure (compiled into a dynamic Pydantic model for
      structured responses)
    - Tools (callables, ``core.tools`` entries by name, or MCP-served tools)

    Construction: prefer the async ``create()`` factory, which runs the
    synchronous ``__init__`` and then performs the async MCP client setup
    and state-graph compilation.

    Args:
        metadata (Dict[str, Any]): Configuration dictionary containing:
            - model (Dict): LLM model configuration with name and parameters
            - temperature (float): Temperature for LLM (0-1)
            - tokens (int): Max tokens for LLM response
            - system_prompt (str): Base system prompt for the agent
            - input_variables (List[Dict]): Input variables with names and
              default values
            - outputType (List[Dict]): Structure of the expected output
            - rules (List[Dict]): Rules to be applied in the system prompt
            - tools (List[str]): Tool names to be loaded and used
            - api_key (str, optional): API key forwarded to the LLM provider
    """
    def __init__(self, metadata: Dict[str, Any]):
        """
        Initializes the SkeletonGraphAgent with the provided metadata.

        Note: only the synchronous parts of setup happen here. MCP client
        wiring and state-graph compilation are async and are performed by
        the ``create()`` factory, which is the intended entry point.

        Args:
            metadata (Dict[str, Any]): Configuration dictionary for the agent.
        """
        # Initialize MCP-related attributes (filled in by _configure_mcp_client)
        self.client = None
        self.mcp_tools = []
        # Extract and set LLM configuration parameters (model name,
        # temperature, max tokens, system prompt + rules, api key, ...)
        self._configure_llm_parameters(metadata)
        # Parse the output structure for structured responses
        self._parse_structured_output(metadata)
        # Create a pydantic model of the output structure (None when no fields)
        self._create_pydantic_model()
        # Configure the agent's tools from metadata
        self._configure_agents_tools(metadata)
        # Configure the LLM(s): one for generation and, when structured
        # output is configured, a second one for formatting the response.
        # Must run last — it binds the tools configured above.
        self._configure_llm()
def _configure_llm_parameters(self, metadata: Dict[str, Any]):
"""
Configures the LLM parameters from the provided metadata.
Args:
metadata (Dict[str, Any]): Configuration dictionary containing LLM parameters.
Returns:
str: The name of the configured LLM.
"""
# LLM Configuration
self.model_name = metadata.get("model", {}).get("input", "gpt-4o")
self.temperature = metadata.get("temperature", 0)
self.max_tokens = metadata.get("tokens", 1000)
self.system_prompt = metadata.get("system_prompt", "You are a helpful AI assistant.")
self.rules = self._parse_literal(metadata.get("rules", "[]"), [])
self.input_variables = metadata.get("input_variables", [{"name": "input", "input": ""}])
self.api_key = metadata.get("api_key", None)
# If rules are provided, append them to the system prompt
if self.rules:
for rule in self.rules:
self.system_prompt += f"\n{rule['rulename']}: {rule['ruleDescription']}"
def _parse_structured_output(self, metadata: Dict[str, Any]):
"""
Parse the outputType metadata into a dictionary of field definitions.
This defines the structure of the agent's output.
Args:
metadata (Dict[str, Any]): The metadata containing output structure
default (Any): The default value to return if parsing fails
Returns:
Dict[str, Any]: Dictionary of output fields with their types and descriptions
"""
try:
# Parse the outputType from metadata
self.output_type = self._parse_literal(metadata.get("outputType", "[]"), [])
# Initialize output_fields as an empty dictionary
self.output_fields = {}
# Populate output_fields with the parsed outputType
for field in self.output_type:
self.output_fields[field["outputKey"]] = {
"type": field["outputKeyType"],
"description": field["outputDescription"]
}
except (ValueError, TypeError) as e:
logger.warning(f"Failed to parse output structure: {str(e)}")
def _create_pydantic_model(self):
"""
Dynamically create a Pydantic class based on user-provided fields.
This model defines the structure of the agent's output.
"""
# Check if output_fields is empty
if not self.output_fields:
logger.warning("No output fields defined. Using default model with a single 'output' field.")
try:
self.pydantic_model = None
if self.output_fields:
field_definitions = {
field_name: (field_info['type'], Field(
description=field_info['description']))
for field_name, field_info in self.output_fields.items()
}
self.pydantic_model = create_model(
'OutputFormat',
__doc__="Dynamically generated Pydantic model for agent output.",
**field_definitions
)
logger.debug(f"Created Pydantic model with fields: {list(self.output_fields.keys())}")
except Exception as e:
logger.error(f"Failed to create Pydantic model: {str(e)}")
def _configure_agents_tools(self, metadata: Dict[str, Any]):
"""
Configures the agent's tools and output structure based on the provided metadata.
Args:
metadata (Dict[str, Any]): Configuration dictionary containing tools and output structure.
"""
# Get tools from metadata
tools_config = self._parse_literal(metadata.get("tools", "[]"), [])
# Initialize tools list
self.tools = []
# Handle both tool names and tool instances
for tool in tools_config:
if not tool:
continue
try:
if callable(tool):
# If tool is a function/callable (e.g. @tool decorated function), use it directly
self.tools.append(tool)
logger.info(f"Successfully loaded tool function: {tool.__name__}")
elif isinstance(tool, str):
# If tool is a string, try to import from core.tools
module_path = f"core.tools.{tool}"
module = __import__(module_path, fromlist=[tool])
tool_class = getattr(module, tool)
# Check if it's already a tool instance
if callable(tool_class):
tool_instance = tool_class
else:
tool_instance = tool_class()
self.tools.append(tool_instance)
logger.info(f"Successfully loaded tool: {tool}")
else:
# If it's already a tool instance
self.tools.append(tool)
logger.info(f"Successfully loaded tool instance: {tool.__class__.__name__}")
except Exception as e:
logger.error(f"Failed to load tool {tool}: {str(e)}")
def _load_mcp_config(self, config_path: str = "core/config/mcp_config.json") -> Dict[str, Any]:
"""
Load MCP configuration from a JSON file.
Args:
config_path (str): Path to the MCP configuration file
Returns:
Dict[str, Any]: MCP configuration dictionary
"""
try:
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Loaded MCP configuration from {config_path}")
return config
else:
logger.warning(f"MCP config file not found at {config_path}. Using empty config.")
return {}
except Exception as e:
logger.error(f"Failed to load MCP config from {config_path}: {str(e)}")
return {}
    async def _check_mcp_server_health(self, url: str, timeout: int = 5) -> bool:
        """
        Check whether an MCP server is reachable.

        Tries an HTTP GET first and, if that fails at the aiohttp client
        level, falls back to a raw TCP connect. Any HTTP response —
        regardless of status code — counts as "up".

        NOTE(review): the probe always uses plain ``http`` and drops the
        URL's original scheme and path; an https-only or path-routed server
        would be probed incorrectly. Confirm this is intentional.

        Args:
            url (str): The URL of the MCP server.
            timeout (int): Timeout in seconds for the health check.

        Returns:
            bool: True if the server is reachable, False otherwise.
        """
        try:
            parsed_url = urlparse(url)
            host = parsed_url.hostname
            port = parsed_url.port
            # Both a hostname and an explicit port are required for the probe
            if not host or not port:
                logger.warning(f"Invalid URL format: {url}")
                return False
            # First attempt: plain HTTP GET against host:port
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
                try:
                    async with session.get(f"http://{host}:{port}") as response:
                        # Any response at all means the server is up
                        logger.debug(f"MCP server at {url} is responding (status: {response.status})")
                        return True
                except aiohttp.ClientError:
                    # HTTP failed — fall back to a bare TCP connection
                    try:
                        reader, writer = await asyncio.wait_for(
                            asyncio.open_connection(host, port),
                            timeout=timeout
                        )
                        writer.close()
                        await writer.wait_closed()
                        logger.debug(f"MCP server at {url} is reachable via TCP")
                        return True
                    except Exception:
                        logger.debug(f"MCP server at {url} is not reachable")
                        return False
        except Exception as e:
            # Covers parse errors, timeouts outside the inner handlers, etc.
            logger.debug(f"Health check failed for {url}: {str(e)}")
            return False
    async def _configure_mcp_client(self, metadata: Dict[str, Any]):
        """
        Create and connect the MultiServerMCPClient and merge its tools
        into ``self.tools``.

        No-op when no MCP configuration file is found. On any failure the
        client is reset to None so the agent still works without MCP.

        NOTE(review): this enters the client's async context by calling
        ``__aenter__`` directly and uses a synchronous ``get_tools()``;
        newer langchain-mcp-adapters releases changed this API (sessions
        instead of a context-manager client, async ``get_tools``). Verify
        against the pinned adapter version.

        Args:
            metadata (Dict[str, Any]): Agent configuration (not consulted
                here; servers come from the MCP config file).
        """
        try:
            # Load MCP server definitions from the JSON config file
            mcp_config = self._load_mcp_config()
            if not mcp_config:
                logger.info("No MCP configuration found. Skipping MCP client setup.")
                return
            # Create the MCP client with available servers
            self.client = MultiServerMCPClient(mcp_config)
            # Start the client connection; close() performs the matching __aexit__
            await self.client.__aenter__()
            # Discover tools exposed by all configured servers and make them
            # available alongside the locally configured tools
            self.mcp_tools = self.client.get_tools()
            self.tools.extend(self.mcp_tools)
            logger.info(f"MCP client configured successfully with {len(mcp_config)} servers and {len(self.mcp_tools)} tools.")
        except Exception as e:
            logger.error(f"Failed to configure MCP client: {str(e)}")
            self.client = None
            self.mcp_tools = []
@classmethod
async def create(cls, metadata: Dict[str, Any]):
"""
Async factory method to create and configure a SkeletonGraphAgent instance.
Args:
metadata (Dict[str, Any]): Configuration dictionary.
Returns:
SkeletonGraphAgent: Configured instance.
"""
self = cls(metadata) # Call __init__ with metadata
await self._configure_mcp_client(metadata) # Configure async parts
if self.mcp_tools:
self.main_llm = self.main_llm.bind_tools(self.tools) # Bind MCP tools to the main LLM
# Building the state graph for the agent, we build the state graph after the MCP client is configured
self._build_state_graph()
return self
def _configure_llm(self):
"""
Configures the LLM for the agent based on the provided metadata.
Args:
metadata (Dict[str, Any]): Configuration dictionary containing LLM parameters.
"""
try:
# Initialize the LLM with the specified model and parameters
self.main_llm = get_llm(
model_name=self.model_name,
provider="openai", # Default provider, can be changed if needed
api_key=self.api_key,
temperature=self.temperature,
max_tokens=self.max_tokens
)
# If tools are configured, bind them to the main LLM
if self.tools:
self.main_llm = self.main_llm.bind_tools(self.tools)
logger.info(f"LLM configured with model: {self.model_name}, temperature: {self.temperature}, max tokens: {self.max_tokens}")
# If a structured output is required, configure the LLM for structured output
if self.pydantic_model:
# If a second LLM is needed for structured output, configure it similarly
self.llm_for_response = get_llm(
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_tokens
)
self.llm_with_structured_output = self.llm_for_response.with_structured_output(self.pydantic_model)
except Exception as e:
logger.error(f"Failed to configure LLM: {str(e)}")
def _parse_literal(self, value: str, default: Any) -> Any:
"""
Parse a string value into a Python object.
Handles various string formats including lists, dictionaries, and type references.
Args:
value (str): The string value to parse
default (Any): The default value to return if parsing fails
Returns:
Any: The parsed Python object or the default value
"""
try:
# Handle type references in the string
cleaned_value = re.sub(r"<class '(\w+)'>", r"\1", str(value))
# Handle type references without quotes
cleaned_value = re.sub(r'"type":\s*(\w+)', lambda m: f'"type": "{m.group(1)}"', cleaned_value)
return ast.literal_eval(cleaned_value)
except (ValueError, SyntaxError) as e:
logger.debug(f"Failed to parse literal value: {value}. Error: {str(e)}")
# Handle comma-separated values
if isinstance(value, str):
if ',' in value:
return [item.strip() for item in value.split(',')]
elif ' ' in value:
return value.split()
return default
    def _build_state_graph(self):
        """
        Assemble and compile the langgraph state graph.

        Topology, depending on configuration:
          - ``agent_node`` is always the entry point;
          - a ``tools`` ToolNode is added when tools exist, and always loops
            back to ``agent_node`` after execution;
          - a ``respond`` node is added when structured output is required,
            followed by END;
          - ``_should_continue`` routes ``agent_node`` to tools/respond/END.

        The compiled graph is stored as ``self.workflow``.
        """
        try:
            # Initialize the state graph over the shared AgentState schema
            self.graph = StateGraph(AgentState)
            # Main node: runs the LLM over the conversation
            self.graph.add_node("agent_node", self._agent_node)
            self.graph.set_entry_point("agent_node")
            if self.pydantic_model:
                # Structured output requested: add the formatting node
                self.graph.add_node("respond", self._respond)
                self.graph.add_edge("respond", END)
            if self.tools:
                # Tool execution node; results always flow back to the agent
                self.graph.add_node("tools", ToolNode(self.tools))
                self.graph.add_edge("tools", "agent_node")
            # Wire agent_node's outgoing edges. NOTE(review): mcp_tools are
            # already merged into self.tools by _configure_mcp_client, so
            # the "or self.mcp_tools" checks below look redundant — confirm.
            if self.pydantic_model and (self.tools or self.mcp_tools):
                self.graph.add_conditional_edges(
                    "agent_node",
                    self._should_continue,
                    {
                        "continue": "tools",  # Pending tool calls: execute them
                        "respond": "respond",  # Done: format the structured answer
                        # "end": END  # (disabled) end the conversation directly
                    }
                )
            elif self.pydantic_model and not self.tools:
                # No tools: go straight from generation to formatting
                self.graph.add_edge(
                    "agent_node",
                    "respond"
                )
            elif not self.pydantic_model and (self.tools or self.mcp_tools):
                self.graph.add_conditional_edges(
                    "agent_node",
                    self._should_continue,
                    {
                        "continue": "tools",  # Pending tool calls: execute them
                        "end": END  # Free-form answer: finish
                    }
                )
            else:
                # Neither structured output nor tools: single-shot agent
                self.graph.add_edge("agent_node", END)
            self.workflow = self.graph.compile()
            logger.info("State graph built successfully with tools and initial system message.")
        except Exception as e:
            logger.error(f"Failed to build state graph: {str(e)}")
    def _agent_node(self, state: AgentState) -> AgentState:
        """
        Core graph node: run the main LLM over the conversation.

        Prepends the system prompt once (when no SystemMessage is present)
        and appends the current input variables as a HumanMessage before
        invoking the LLM.

        NOTE(review): this node runs again after every tool round-trip, and
        each pass appends another input-variables HumanMessage — so the
        input is duplicated in multi-step tool conversations. Also,
        ``messages.append`` mutates the list taken from state in place.
        Confirm this interacts as intended with the add_messages reducer.
        """
        # Current conversation from the graph state
        messages = state.get('messages', [])
        # Inject the system prompt only once, at the front
        if not messages or not any(isinstance(msg, SystemMessage) for msg in messages):
            messages = [SystemMessage(content=self.system_prompt)] + messages
        # Surface the input variables to the model as a human message
        input_variables = state.get("input_variables", {})
        if input_variables:
            input_message = HumanMessage(content=json.dumps(input_variables))
            messages.append(input_message)
        else:
            input_message = HumanMessage(content="No input variables provided.")
            messages.append(input_message)
        response = self.main_llm.invoke(messages)
        # Return the full state so downstream nodes keep their context
        return {
            "messages": messages + [response],
            "input_variables": state.get("input_variables", {}),  # Preserve input variables
            "final_response": state.get("final_response")  # Preserve any existing final response
        }
def _respond(self, state: AgentState) -> AgentState:
"""
The Respond node, will be called if response is required in a Structured Format.
Args:
state (AgentState): The current state of the agent.
Returns:
AgentState: The updated state after processing the input.
"""
# Get the current messages from the state
messages = state.get("messages", [])
response = self.llm_with_structured_output.invoke(messages)
# Create an AIMessage with the structured response
ai_message = AIMessage(content=str(response))
# Preserve existing messages and append new message
return {
"final_response": response,
"messages": state.get("messages", []) + [ai_message],
"input_variables": state.get("input_variables", {}) # Preserve input variables
}
def _should_continue(self, state: AgentState) -> str:
"""
Determines whether the agent should continue processing based on the state.
"""
if not state.get("messages"):
return "end"
last_message = state["messages"][-1]
# Check if the last message is a ToolMessage or AIMessage
if not last_message.tool_calls:
if self.pydantic_model:
return "respond"
return "end"
return "continue"
def _show_graph(self):
"""
Displays the state graph of the agent.
This is useful for debugging and understanding the flow of the agent.
"""
from IPython.display import Image, display
display(Image(self.workflow.get_graph().draw_mermaid_png()))
async def _execute(self, input: str, metadata: Dict[str, Any] = {None}) -> Dict[str, Any]:
"""
Execute the agent with the provided inputs.
Args:
input (str): The primary input text
metadata (Dict[str, Any]): Additional metadata including:
- chat_history: Optional chat history for context
- input_variables: Values for the input variables defined during initialization
Returns:
Dict[str, Any]: Dictionary containing the results and execution metadata
"""
# Get the chat history from metadata, if provided
chat_history = metadata.get("chat_history", [])
# Convert history to BaseMessage objects if they're not already
processed_history = self._parse_history(chat_history)
# Add the main input to the messages
messages = processed_history + [HumanMessage(content=input)]
# Parse input variables from metadata
input_variables = self._parse_input_variables(metadata.get("input_variables", []))
try:
# Check if we're processing batch data
if "data" in input_variables:
result = await self._process_batch_data(input_variables, messages)
else:
result = await self._process_single_input(input_variables, messages)
except Exception as e:
logger.error(f"Error processing input: {str(e)}")
result = {
"success": False,
"message": f"Error processing input: {str(e)}",
"raw_response": None,
}
return result
def _parse_history(self, chat_history: List):
"""
Parses the chat history to convert it into a list of BaseMessage objects.
Args:
chat_history (List): The chat history to parse
Returns:
List[BaseMessage]: List of BaseMessage objects representing the chat history
"""
parsed_history = []
for msg in chat_history:
if isinstance(msg, BaseMessage):
parsed_history.append(msg)
elif isinstance(msg, tuple) and len(msg) == 2:
role, content = msg
if role.lower() == "user" or role.lower() == "human":
parsed_history.append(HumanMessage(content=content))
elif role.lower() == "assistant" or role.lower() == "ai":
parsed_history.append(AIMessage(content=content))
elif role.lower() == "system":
parsed_history.append(SystemMessage(content=content))
else:
parsed_history.append(HumanMessage(content=f"{role}: {content}"))
else:
# Default to human message
parsed_history.append(HumanMessage(content=str(msg)))
return parsed_history
def _parse_input_variables(self, input_variables: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Parses the input variables from the provided list into a dictionary.
Args:
input_variables (List[Dict[str, Any]]): List of input variable definitions
Returns:
Dict[str, Any]: Dictionary of input variables with their names and values
"""
parsed_variables = {}
for var in input_variables:
if isinstance(var, dict) and "name" in var:
name = var["name"]
value = var.get("input", "")
parsed_variables[name] = value
else:
logger.warning(f"Invalid input variable format: {var}")
return parsed_variables
def _process_structured_output(self, output: Dict[str, Any]) -> Union[Dict[str, Any], str]:
"""
Process the structured output from the agent.
Args:
output (Dict[str, Any]): The structured output from the agent
Returns:
Union[Dict[str, Any], str]: Processed structured output
"""
try:
# If a Pydantic model is defined, validate and return the structured output
if self.pydantic_model:
return {key: getattr(output['final_response'], key) for key in self.output_fields.keys()}
else:
# If no structured output is defined, return the raw output
return output
except Exception as e:
logger.error(f"Error processing structured output: {str(e)}")
return str(output)
    async def _process_batch_data(self, execution_inputs: Dict[str, Any], messages) -> Dict[str, Any]:
        """
        Run the workflow once per item in the "data" input and collect the
        results.

        The "data" entry (a list, or a string parseable into one) is popped
        from the inputs; each element is substituted back as "data" and the
        graph is invoked per element, with token usage aggregated across
        the whole batch by a single OpenAI callback.

        NOTE(review): the same ``messages`` list is passed to every
        iteration, and _agent_node appends to it in place — later items see
        earlier items' conversations. Confirm this accumulation is intended.

        Args:
            execution_inputs (Dict[str, Any]): Inputs including the "data"
                array.
            messages: Seed message history shared by all runs.

        Returns:
            Dict[str, Any]: success flag, JSON-encoded list of per-item
            results, and aggregated token/cost metadata.
        """
        with get_openai_callback() as cb:
            response = []
            try:
                # Work on a copy so the caller's dict keeps its "data" key
                data_inputs = execution_inputs.copy()
                data = data_inputs.pop("data")
                # "data" may arrive serialized as a string
                if isinstance(data, str):
                    data = self._parse_literal(data, [])
                total_docs = len(data)
                logger.info(f"Processing batch of {total_docs} documents")
                # Process each data item with a progress bar
                with tqdm(total=total_docs, desc="Processing documents") as pbar:
                    for doc in data:
                        # Substitute the current item as the "data" variable
                        data_inputs["data"] = doc
                        # Seed state for this item's graph run
                        initial_state = {
                            "messages": messages,
                            "input_variables": data_inputs,
                            "final_response": None
                        }
                        # Invoke the agent graph for this item
                        result = await self.workflow.ainvoke(initial_state)
                        # Reduce to structured fields when a schema exists
                        if self.pydantic_model:
                            output_response = self._process_structured_output(result)
                        else:
                            # No schema: keep the raw graph result
                            output_response = result
                        response.append(output_response)
                        pbar.update(1)
                # Final result with aggregated usage metadata
                result = {
                    "success": True,
                    "message": json.dumps(response),
                    "metainfo": self._get_callback_metadata(cb)
                }
            except Exception as e:
                logger.error(f"Error processing batch data: {str(e)}")
                result = {
                    "success": False,
                    "message": f"Error processing batch data: {str(e)}",
                    "metainfo": self._get_callback_metadata(cb)
                }
            return result
async def _process_single_input(self, execution_inputs: Dict[str, Any], messages) -> Dict[str, Any]:
"""
Process a single input.
Args:
execution_inputs (Dict[str, Any]): The execution inputs
Returns:
Dict[str, Any]: Result of processing
"""
with get_openai_callback() as cb:
try:
# Initialize the initial state with messages and input variables
initial_state = {
"messages": messages,
"input_variables": execution_inputs,
"final_response": None
}
# Invoke the agent
response = await self.workflow.ainvoke(initial_state)
# Process the result based on whether fields were provided
dict_data = self._process_structured_output(response)
# Create the final result with metadata
result = {
"success": True,
"message": str(dict_data),
"response": response,
"metainfo": self._get_callback_metadata(cb)
}
except Exception as e:
logger.error(f"Error processing input: {str(e)}")
result = {
"success": False,
"message": f"Error processing input: {str(e)}",
"metainfo": self._get_callback_metadata(cb)
}
return result
def _get_callback_metadata(self, cb) -> Dict[str, Any]:
"""
Get metadata from the OpenAI callback.
Args:
cb: The OpenAI callback object
Returns:
Dict[str, Any]: Metadata about the API call
"""
return {
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"total_tokens": cb.total_tokens,
"total_cost": cb.total_cost
}
async def close(self):
"""
Clean up resources, particularly the MCP client connection.
"""
if self.client:
try:
await self.client.__aexit__(None, None, None)
logger.info("MCP client connection closed successfully.")
except Exception as e:
logger.error(f"Error closing MCP client: {str(e)}")
finally:
self.client = None
self.mcp_tools = []
def __del__(self):
"""
Destructor to ensure cleanup when the object is garbage collected.
"""
if self.client:
import asyncio
try:
# Try to close the client if there's an active event loop
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self.close())
else:
asyncio.run(self.close())
except Exception:
# If we can't close properly, at least log it
logger.warning("Could not properly close MCP client in destructor")