Spaces:
Sleeping
Sleeping
| """ | |
| Module Name: SkeletonGraphAgent | |
| Description: This module contains the Langgraph's Agent class that provides a flexible agent framework | |
| using langgraph where all configuration is dynamically loaded from user input. This includes | |
| system prompts, rules, input variables, LLM configuration, output structure, and tools. | |
| Author: Abhishek Singh | |
| Last Modified: 2025-05-29 | |
| """ | |
| import logging, re, ast, json, os, aiohttp | |
| from tqdm import tqdm | |
| from typing import Any, Dict, List, Optional, Union, TypedDict, Annotated, Sequence | |
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage | |
| from pydantic import BaseModel, Field, create_model | |
| from langgraph.graph.message import add_messages | |
| from langgraph.graph import StateGraph, START, END, MessagesState | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langchain_community.callbacks import get_openai_callback | |
| from core.llms.base_llm import get_llm | |
| from langchain_mcp_adapters.client import MultiServerMCPClient | |
| from urllib.parse import urlparse | |
| import asyncio | |
# Module-level logger; no handlers or levels are configured in this module.
logger = logging.getLogger(__name__)


class AgentState(TypedDict):
    """
    Shape of the state dictionary passed between the nodes of the agent graph.
    """

    # Conversation so far; node updates are merged in by the ``add_messages`` reducer.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Input variables supplied for the current run.
    input_variables: Dict[str, Any]
    # Structured result written by the respond node (None until it runs).
    final_response: Optional[Dict[str, Any]]
class SkeletonGraphAgent:
    """
    A LangGraph agent whose behaviour is configured entirely from user metadata.

    The metadata supplies:
    - System prompt
    - Rules appended to the system prompt
    - Input variables
    - LLM configuration
    - Output structure (turned into a Pydantic model for structured output)
    - Tools (local tools from ``core.tools`` and, via :meth:`create`, MCP tools)

    This agent serves as a foundation for creating custom agents without modifying code.

    Build instances with ``await SkeletonGraphAgent.create(metadata)``: the plain
    constructor neither connects the MCP client nor compiles the graph, so an
    instance created with ``SkeletonGraphAgent(metadata)`` has no ``workflow``.

    Args:
        metadata (Dict[str, Any]): Configuration dictionary containing:
            - model (Dict | str): LLM model configuration; the model name is read
              from its ``"input"`` key (a plain string is used as the name).
              Default ``"gpt-4o"``.
            - temperature (float): Temperature for the LLM. Default 0.
            - tokens (int): Max tokens for the LLM response. Default 1000.
            - api_key (str): Optional API key passed to ``get_llm``.
            - system_prompt (str): Base system prompt for the agent.
            - input_variables (List[Dict]): Input variables with names and default values.
            - outputType (List[Dict] | str): Expected output fields, each with
              ``outputKey``, ``outputKeyType`` and ``outputDescription``.
            - rules (List[Dict] | str): Rules with ``rulename`` and ``ruleDescription``
              appended to the system prompt.
            - tools (List | str): Tool names importable as ``core.tools.<name>.<name>``,
              or tool objects/callables.
    """

    def __init__(self, metadata: Dict[str, Any]):
        """
        Configure the synchronous parts of the agent from ``metadata``.

        MCP tools and the compiled graph are set up by :meth:`create`.

        Args:
            metadata (Dict[str, Any]): Configuration dictionary for the agent.
        """
        # MCP state is filled in later by the async factory.
        self.client = None
        self.mcp_tools = []
        # Order matters: the Pydantic model needs the parsed output fields, and the
        # LLM configuration needs both the tools and the Pydantic model.
        self._configure_llm_parameters(metadata)
        self._parse_structured_output(metadata)
        self._create_pydantic_model()
        self._configure_agents_tools(metadata)
        self._configure_llm()

    def _configure_llm_parameters(self, metadata: Dict[str, Any]) -> None:
        """
        Read the LLM settings, system prompt, rules and input variables.

        Each rule that has ``rulename`` and ``ruleDescription`` is appended to the
        system prompt as ``"\\n<rulename>: <ruleDescription>"``; other entries are
        skipped with a warning.

        Args:
            metadata (Dict[str, Any]): Configuration dictionary containing LLM parameters.
        """
        model_config = metadata.get("model") or {}
        if isinstance(model_config, dict):
            self.model_name = model_config.get("input", "gpt-4o")
        else:
            # Also accept a bare model name instead of {"input": name}.
            self.model_name = str(model_config)
        self.temperature = metadata.get("temperature", 0)
        self.max_tokens = metadata.get("tokens", 1000)
        self.system_prompt = metadata.get("system_prompt", "You are a helpful AI assistant.")
        self.rules = self._parse_literal(metadata.get("rules", "[]"), [])
        self.input_variables = metadata.get("input_variables", [{"name": "input", "input": ""}])
        self.api_key = metadata.get("api_key", None)

        for rule in self.rules or []:
            if isinstance(rule, dict) and "rulename" in rule and "ruleDescription" in rule:
                self.system_prompt += f"\n{rule['rulename']}: {rule['ruleDescription']}"
            else:
                logger.warning(f"Skipping malformed rule: {rule}")

    def _parse_structured_output(self, metadata: Dict[str, Any]) -> None:
        """
        Parse ``metadata["outputType"]`` into ``self.output_fields``.

        ``self.output_fields`` maps each ``outputKey`` to ``{"type": ..., "description": ...}``.
        It is always set; parsing stops at the first malformed entry and the
        fields parsed before it are kept.

        Args:
            metadata (Dict[str, Any]): The metadata containing the output structure.
        """
        # Set up-front so later steps can always rely on the attribute existing.
        self.output_fields = {}
        try:
            self.output_type = self._parse_literal(metadata.get("outputType", "[]"), [])
            for field in self.output_type or []:
                self.output_fields[field["outputKey"]] = {
                    "type": field["outputKeyType"],
                    "description": field["outputDescription"],
                }
        except (KeyError, ValueError, TypeError) as e:
            logger.warning(f"Failed to parse output structure: {str(e)}")

    def _create_pydantic_model(self) -> None:
        """
        Build ``self.pydantic_model`` from ``self.output_fields``.

        ``self.pydantic_model`` is ``None`` (structured output disabled) when no
        output fields are defined or when model creation fails.
        """
        self.pydantic_model = None
        if not self.output_fields:
            logger.warning("No output fields defined. Structured output is disabled; the raw agent output will be returned.")
            return
        try:
            field_definitions = {
                field_name: (field_info["type"], Field(description=field_info["description"]))
                for field_name, field_info in self.output_fields.items()
            }
            self.pydantic_model = create_model(
                "OutputFormat",
                __doc__="Dynamically generated Pydantic model for agent output.",
                **field_definitions,
            )
            logger.debug(f"Created Pydantic model with fields: {list(self.output_fields.keys())}")
        except Exception as e:
            logger.error(f"Failed to create Pydantic model: {str(e)}")

    @staticmethod
    def _tool_label(tool: Any) -> str:
        """Best-effort display name for a tool object or callable."""
        return getattr(tool, "name", None) or getattr(tool, "__name__", None) or tool.__class__.__name__

    def _configure_agents_tools(self, metadata: Dict[str, Any]) -> None:
        """
        Load the agent's local tools into ``self.tools``.

        Each entry of ``metadata["tools"]`` may be:
        - a callable (e.g. an ``@tool`` function or tool object), used as-is;
        - a string ``name``: attribute ``name`` of module ``core.tools.name`` is
          loaded, and instantiated if it is a class;
        - any other object, used as-is.

        Entries that fail to load are logged and skipped.

        Args:
            metadata (Dict[str, Any]): Configuration dictionary containing the tools.
        """
        import importlib

        raw_tools = metadata.get("tools", "[]")
        # A single bare tool name (e.g. "search") is not a Python literal; treat it as one item.
        fallback = [raw_tools.strip()] if isinstance(raw_tools, str) and raw_tools.strip().isidentifier() else []
        tools_config = self._parse_literal(raw_tools, fallback) or []

        self.tools = []
        for tool in tools_config:
            if not tool:
                continue
            try:
                if callable(tool):
                    self.tools.append(tool)
                    logger.info(f"Successfully loaded tool function: {self._tool_label(tool)}")
                elif isinstance(tool, str):
                    # Names come from user input: restrict them to plain identifiers
                    # so the import stays a direct child of core.tools.
                    if not tool.isidentifier():
                        raise ValueError(f"invalid tool name {tool!r}")
                    module = importlib.import_module(f"core.tools.{tool}")
                    tool_obj = getattr(module, tool)
                    # Classes need instantiating; functions and tool instances are used directly.
                    tool_instance = tool_obj() if isinstance(tool_obj, type) else tool_obj
                    self.tools.append(tool_instance)
                    logger.info(f"Successfully loaded tool: {tool}")
                else:
                    self.tools.append(tool)
                    logger.info(f"Successfully loaded tool instance: {tool.__class__.__name__}")
            except Exception as e:
                logger.error(f"Failed to load tool {tool}: {str(e)}")

    def _load_mcp_config(self, config_path: str = "core/config/mcp_config.json") -> Dict[str, Any]:
        """
        Load the MCP configuration from a JSON file.

        Args:
            config_path (str): Path to the MCP configuration file.

        Returns:
            Dict[str, Any]: The parsed configuration, or ``{}`` if the file is
            missing or cannot be read/parsed.
        """
        if not os.path.exists(config_path):
            logger.warning(f"MCP config file not found at {config_path}. Using empty config.")
            return {}
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        except Exception as e:
            logger.error(f"Failed to load MCP config from {config_path}: {str(e)}")
            return {}
        logger.info(f"Loaded MCP configuration from {config_path}")
        return config

    async def _check_mcp_server_health(self, url: str, timeout: int = 5) -> bool:
        """
        Check whether an MCP server accepts connections.

        Any HTTP response to a GET on the server root counts as "up". If that
        request raises ``aiohttp.ClientError``, a plain TCP connection is tried.
        A URL without an explicit port uses 80/443 for http/https.

        Args:
            url (str): The URL of the MCP server.
            timeout (int): Timeout in seconds for each probe.

        Returns:
            bool: True if the server is reachable, False otherwise.
        """
        try:
            parsed_url = urlparse(url)
            host = parsed_url.hostname
            default_port = {"http": 80, "https": 443}.get(parsed_url.scheme)
            port = parsed_url.port or default_port
            if not host or not port:
                logger.warning(f"Invalid URL format: {url}")
                return False
            scheme = "https" if parsed_url.scheme == "https" else "http"
            # IPv6 literals must be bracketed inside a URL.
            url_host = f"[{host}]" if ":" in host else host
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
                try:
                    async with session.get(f"{scheme}://{url_host}:{port}") as response:
                        logger.debug(f"MCP server at {url} is responding (status: {response.status})")
                        return True
                except aiohttp.ClientError:
                    # HTTP failed; fall back to a bare TCP connect.
                    try:
                        _, writer = await asyncio.wait_for(
                            asyncio.open_connection(host, port),
                            timeout=timeout,
                        )
                        writer.close()
                        await writer.wait_closed()
                        logger.debug(f"MCP server at {url} is reachable via TCP")
                        return True
                    except Exception:
                        logger.debug(f"MCP server at {url} is not reachable")
                        return False
        except Exception as e:
            logger.debug(f"Health check failed for {url}: {str(e)}")
            return False

    async def _configure_mcp_client(self, metadata: Dict[str, Any]) -> None:
        """
        Connect to the MCP servers from the MCP config file and add their tools.

        On success ``self.client`` is the open client, ``self.mcp_tools`` its
        tools, and those tools are appended to ``self.tools``. On failure a
        partially opened client is closed and both attributes are reset.

        Args:
            metadata (Dict[str, Any]): Agent configuration (not read here; servers
                come from the config file).
        """
        try:
            mcp_config = self._load_mcp_config()
            if not mcp_config:
                logger.info("No MCP configuration found. Skipping MCP client setup.")
                return
            self.client = MultiServerMCPClient(mcp_config)
            await self.client.__aenter__()
            self.mcp_tools = self.client.get_tools()
            self.tools.extend(self.mcp_tools)
            logger.info(f"MCP client configured successfully with {len(mcp_config)} servers and {len(self.mcp_tools)} tools.")
        except Exception as e:
            logger.error(f"Failed to configure MCP client: {str(e)}")
            if self.client is not None:
                # Release any connections opened before the failure.
                try:
                    await self.client.__aexit__(None, None, None)
                except Exception as close_error:
                    logger.debug(f"Error closing MCP client after failed setup: {str(close_error)}")
            self.client = None
            self.mcp_tools = []

    @classmethod
    async def create(cls, metadata: Dict[str, Any]) -> "SkeletonGraphAgent":
        """
        Async factory: build, connect MCP tools, and compile the graph.

        Args:
            metadata (Dict[str, Any]): Configuration dictionary.

        Returns:
            SkeletonGraphAgent: A fully configured instance with ``workflow`` compiled.
        """
        agent = cls(metadata)
        await agent._configure_mcp_client(metadata)
        if agent.mcp_tools:
            # Re-bind so the LLM sees the local tools and the MCP tools together.
            agent.main_llm = agent.main_llm.bind_tools(agent.tools)
        # Built last because the tools node needs the MCP tools.
        agent._build_state_graph()
        return agent

    def _configure_llm(self) -> None:
        """
        Create the main LLM (tool-bound when tools exist) and, when structured
        output is required, a second LLM wrapped with ``with_structured_output``.

        Errors are logged, leaving the corresponding attributes unset.
        """
        try:
            self.main_llm = get_llm(
                model_name=self.model_name,
                provider="openai",  # Default provider, can be changed if needed
                api_key=self.api_key,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            if self.tools:
                self.main_llm = self.main_llm.bind_tools(self.tools)
            logger.info(f"LLM configured with model: {self.model_name}, temperature: {self.temperature}, max tokens: {self.max_tokens}")
            if self.pydantic_model:
                # Tool-free LLM for the structured response; it must use the same
                # provider and credentials as the main LLM.
                self.llm_for_response = get_llm(
                    model_name=self.model_name,
                    provider="openai",
                    api_key=self.api_key,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                self.llm_with_structured_output = self.llm_for_response.with_structured_output(self.pydantic_model)
        except Exception as e:
            logger.error(f"Failed to configure LLM: {str(e)}")

    def _parse_literal(self, value: Any, default: Any) -> Any:
        """
        Convert a metadata value into a Python object.

        The value is stringified, ``<class 'X'>`` references are reduced to ``X``,
        an unquoted value after a ``"type":`` key is quoted, and the result is
        evaluated with ``ast.literal_eval``. If that fails:
        - a non-string value (already a Python object, e.g. a list of tool
          objects) is returned unchanged;
        - a string containing commas is split on commas (items stripped);
        - a string containing spaces is split on whitespace;
        - otherwise ``default`` is returned.

        Args:
            value (Any): The value to parse.
            default (Any): Returned when a string cannot be parsed.

        Returns:
            Any: The parsed object, the original object, or ``default``.
        """
        try:
            cleaned_value = re.sub(r"<class '(\w+)'>", r"\1", str(value))
            cleaned_value = re.sub(r'"type":\s*(\w+)', lambda m: f'"type": "{m.group(1)}"', cleaned_value)
            return ast.literal_eval(cleaned_value)
        except (ValueError, SyntaxError, TypeError) as e:
            if not isinstance(value, str):
                # Objects such as tools or real type references cannot round-trip
                # through literal_eval; they are already in usable form.
                return value
            logger.debug(f"Failed to parse literal value: {value}. Error: {str(e)}")
            if ',' in value:
                return [item.strip() for item in value.split(',')]
            if ' ' in value:
                return value.split()
            return default

    def _build_state_graph(self) -> None:
        """
        Build and compile the agent graph into ``self.workflow``.

        Topology:
        - agent_node is the entry point;
        - with tools: agent_node -> tools -> agent_node while the LLM requests
          tool calls, then END (or respond, when structured output is required);
        - without tools: agent_node -> respond -> END, or agent_node -> END.
        """
        try:
            self.graph = StateGraph(AgentState)
            self.graph.add_node("agent_node", self._agent_node)
            self.graph.set_entry_point("agent_node")

            # _configure_mcp_client extends self.tools with the MCP tools, so this covers both.
            has_tools = bool(self.tools)
            has_structured_output = bool(self.pydantic_model)

            if has_structured_output:
                self.graph.add_node("respond", self._respond)
                self.graph.add_edge("respond", END)
            if has_tools:
                self.graph.add_node("tools", ToolNode(self.tools))
                self.graph.add_edge("tools", "agent_node")

            if has_tools:
                routes = {"continue": "tools"}
                if has_structured_output:
                    routes["respond"] = "respond"
                else:
                    routes["end"] = END
                self.graph.add_conditional_edges("agent_node", self._should_continue, routes)
            elif has_structured_output:
                self.graph.add_edge("agent_node", "respond")
            else:
                self.graph.add_edge("agent_node", END)

            self.workflow = self.graph.compile()
            logger.info("State graph built successfully with tools and initial system message.")
        except Exception as e:
            logger.error(f"Failed to build state graph: {str(e)}")

    def _with_system_prompt(self, messages: List) -> List:
        """Return ``messages`` with the configured system prompt prepended unless one is present."""
        if any(isinstance(msg, SystemMessage) for msg in messages):
            return list(messages)
        return [SystemMessage(content=self.system_prompt)] + list(messages)

    def _agent_node(self, state: "AgentState") -> "AgentState":
        """
        Call the main LLM and record its reply.

        The system prompt is prepended to the prompt sent to the LLM but is not
        stored in the state. The input variables are sent as a JSON HumanMessage
        only on the first pass; when the node is re-entered after the tools node
        (the last message is a ToolMessage) they are not repeated.

        Returns:
            AgentState: Update holding only the new messages, which the
            ``add_messages`` reducer appends to the conversation.
        """
        history = list(state.get("messages", []))
        prompt = self._with_system_prompt(history)

        new_messages = []
        resuming_after_tools = bool(history) and isinstance(history[-1], ToolMessage)
        if not resuming_after_tools:
            input_variables = state.get("input_variables", {})
            if input_variables:
                new_messages.append(HumanMessage(content=json.dumps(input_variables)))
            else:
                new_messages.append(HumanMessage(content="No input variables provided."))

        response = self.main_llm.invoke(prompt + new_messages)
        new_messages.append(response)
        return {
            "messages": new_messages,
            "input_variables": state.get("input_variables", {}),
            "final_response": state.get("final_response"),
        }

    def _respond(self, state: "AgentState") -> "AgentState":
        """
        Produce the structured response with the structured-output LLM.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            AgentState: ``final_response`` set to the parsed model instance, plus an
            AIMessage holding its string form (appended by the message reducer).
        """
        prompt = self._with_system_prompt(state.get("messages", []))
        response = self.llm_with_structured_output.invoke(prompt)
        ai_message = AIMessage(content=str(response))
        return {
            "final_response": response,
            "messages": [ai_message],
            "input_variables": state.get("input_variables", {}),
        }

    def _should_continue(self, state: "AgentState") -> str:
        """
        Route after agent_node.

        Returns:
            str: ``"continue"`` if the last message requests tool calls, otherwise
            ``"respond"`` when structured output is configured, else ``"end"``
            (also ``"end"`` when there are no messages).
        """
        messages = state.get("messages")
        if not messages:
            return "end"
        if getattr(messages[-1], "tool_calls", None):
            return "continue"
        return "respond" if self.pydantic_model else "end"

    def _show_graph(self):
        """
        Display the compiled graph as a Mermaid PNG (requires IPython).
        """
        from IPython.display import Image, display
        display(Image(self.workflow.get_graph().draw_mermaid_png()))

    async def _execute(self, input: str, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Run the agent on one input.

        Args:
            input (str): The primary input text, appended as the last HumanMessage.
            metadata (Optional[Dict[str, Any]]): Optional run data:
                - chat_history: prior messages (BaseMessage objects, (role, content)
                  pairs, or {"role": ..., "content": ...} dicts)
                - input_variables: list of {"name": ..., "input": ...} dicts; when a
                  variable named "data" is present, each of its items is processed
                  as a separate run.

        Returns:
            Dict[str, Any]: ``success`` flag, ``message`` and execution metadata.
        """
        # None default instead of a shared mutable default.
        metadata = metadata or {}
        processed_history = self._parse_history(metadata.get("chat_history") or [])
        messages = processed_history + [HumanMessage(content=input)]
        input_variables = self._parse_input_variables(metadata.get("input_variables") or [])
        try:
            if "data" in input_variables:
                result = await self._process_batch_data(input_variables, messages)
            else:
                result = await self._process_single_input(input_variables, messages)
        except Exception as e:
            logger.error(f"Error processing input: {str(e)}")
            result = {
                "success": False,
                "message": f"Error processing input: {str(e)}",
                "raw_response": None,
            }
        return result

    def _parse_history(self, chat_history: List):
        """
        Convert chat history entries into LangChain message objects.

        BaseMessage objects are kept. ``(role, content)`` pairs (tuple or list) and
        ``{"role": ..., "content": ...}`` dicts map user/human, assistant/ai and
        system roles to the matching message type; other roles become a
        HumanMessage ``"<role>: <content>"``. Anything else becomes a HumanMessage
        of its string form.

        Args:
            chat_history (List): The chat history to parse.

        Returns:
            List[BaseMessage]: The converted messages.
        """
        role_to_message = {
            "user": HumanMessage,
            "human": HumanMessage,
            "assistant": AIMessage,
            "ai": AIMessage,
            "system": SystemMessage,
        }
        parsed_history = []
        for msg in chat_history or []:
            if isinstance(msg, BaseMessage):
                parsed_history.append(msg)
                continue
            if isinstance(msg, (tuple, list)) and len(msg) == 2:
                role, content = msg
            elif isinstance(msg, dict) and "role" in msg and "content" in msg:
                role, content = msg["role"], msg["content"]
            else:
                parsed_history.append(HumanMessage(content=str(msg)))
                continue
            message_cls = role_to_message.get(str(role).lower())
            if message_cls is None:
                parsed_history.append(HumanMessage(content=f"{role}: {content}"))
            else:
                parsed_history.append(message_cls(content=content))
        return parsed_history

    def _parse_input_variables(self, input_variables: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Turn ``[{"name": n, "input": v}, ...]`` into ``{n: v, ...}``.

        A missing ``input`` defaults to ``""``; entries without ``name`` are
        logged and skipped.

        Args:
            input_variables (List[Dict[str, Any]]): List of input variable definitions.

        Returns:
            Dict[str, Any]: Variable names mapped to their values.
        """
        parsed_variables = {}
        for var in input_variables:
            if isinstance(var, dict) and "name" in var:
                parsed_variables[var["name"]] = var.get("input", "")
            else:
                logger.warning(f"Invalid input variable format: {var}")
        return parsed_variables

    def _process_structured_output(self, output: Dict[str, Any]) -> Union[Dict[str, Any], str]:
        """
        Extract the structured fields from a final graph state.

        Args:
            output (Dict[str, Any]): The final graph state.

        Returns:
            Union[Dict[str, Any], str]: ``{field: value}`` read from
            ``output["final_response"]`` when a Pydantic model is configured, the
            state unchanged otherwise, or ``str(output)`` if extraction fails.
        """
        try:
            if self.pydantic_model:
                final_response = output["final_response"]
                return {key: getattr(final_response, key) for key in self.output_fields}
            return output
        except Exception as e:
            logger.error(f"Error processing structured output: {str(e)}")
            return str(output)

    @staticmethod
    def _final_message_content(state: Dict[str, Any]) -> Any:
        """Return the content of the last message in a graph state, or None if there are none."""
        messages = state.get("messages") or []
        return messages[-1].content if messages else None

    async def _process_batch_data(self, execution_inputs: Dict[str, Any], messages) -> Dict[str, Any]:
        """
        Run the graph once per item of ``execution_inputs["data"]``.

        Args:
            execution_inputs (Dict[str, Any]): Input variables including the ``data``
                items (a list, or a string parsed with ``_parse_literal``).
            messages: Initial conversation messages shared by every run.

        Returns:
            Dict[str, Any]: ``success``, a JSON ``message`` listing each item's
            structured output (or final reply text when no output structure is
            configured), and token/cost ``metainfo``.
        """
        with get_openai_callback() as cb:
            responses = []
            try:
                base_inputs = dict(execution_inputs)
                data = base_inputs.pop("data")
                if isinstance(data, str):
                    data = self._parse_literal(data, [])
                total_docs = len(data)
                logger.info(f"Processing batch of {total_docs} documents")
                with tqdm(total=total_docs, desc="Processing documents") as pbar:
                    for doc in data:
                        # Fresh dict per item so runs never share (and overwrite) inputs.
                        initial_state = {
                            "messages": messages,
                            "input_variables": {**base_inputs, "data": doc},
                            "final_response": None,
                        }
                        result = await self.workflow.ainvoke(initial_state)
                        if self.pydantic_model:
                            output_response = self._process_structured_output(result)
                        else:
                            # The raw state holds message objects json.dumps cannot
                            # encode; keep only the agent's final reply.
                            output_response = self._final_message_content(result)
                        responses.append(output_response)
                        pbar.update(1)
                result = {
                    "success": True,
                    "message": json.dumps(responses),
                    "metainfo": self._get_callback_metadata(cb),
                }
            except Exception as e:
                logger.error(f"Error processing batch data: {str(e)}")
                result = {
                    "success": False,
                    "message": f"Error processing batch data: {str(e)}",
                    "metainfo": self._get_callback_metadata(cb),
                }
        return result

    async def _process_single_input(self, execution_inputs: Dict[str, Any], messages) -> Dict[str, Any]:
        """
        Run the graph once.

        Args:
            execution_inputs (Dict[str, Any]): The input variables.
            messages: Initial conversation messages.

        Returns:
            Dict[str, Any]: ``success``, ``message`` (string form of the processed
            output), the raw final ``response`` state and token/cost ``metainfo``.
        """
        with get_openai_callback() as cb:
            try:
                initial_state = {
                    "messages": messages,
                    "input_variables": execution_inputs,
                    "final_response": None,
                }
                response = await self.workflow.ainvoke(initial_state)
                dict_data = self._process_structured_output(response)
                result = {
                    "success": True,
                    "message": str(dict_data),
                    "response": response,
                    "metainfo": self._get_callback_metadata(cb),
                }
            except Exception as e:
                logger.error(f"Error processing input: {str(e)}")
                result = {
                    "success": False,
                    "message": f"Error processing input: {str(e)}",
                    "metainfo": self._get_callback_metadata(cb),
                }
        return result

    def _get_callback_metadata(self, cb) -> Dict[str, Any]:
        """
        Collect token and cost counters from an OpenAI callback handler.

        Args:
            cb: The callback object yielded by ``get_openai_callback``.

        Returns:
            Dict[str, Any]: prompt/completion/total tokens and total cost.
        """
        return {
            "prompt_tokens": cb.prompt_tokens,
            "completion_tokens": cb.completion_tokens,
            "total_tokens": cb.total_tokens,
            "total_cost": cb.total_cost,
        }

    async def close(self):
        """
        Close the MCP client connection, if any, and reset the MCP attributes.
        """
        if self.client:
            try:
                await self.client.__aexit__(None, None, None)
                logger.info("MCP client connection closed successfully.")
            except Exception as e:
                logger.error(f"Error closing MCP client: {str(e)}")
            finally:
                self.client = None
                self.mcp_tools = []

    def __del__(self):
        """
        Best-effort MCP cleanup on garbage collection.

        Schedules :meth:`close` on the running event loop if there is one,
        otherwise runs it with ``asyncio.run``.
        """
        # getattr: __init__ may not have run (e.g. instances made via __new__).
        if not getattr(self, "client", None):
            return
        import asyncio
        try:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None
            if loop is not None:
                loop.create_task(self.close())
            else:
                asyncio.run(self.close())
        except Exception:
            logger.warning("Could not properly close MCP client in destructor")