arthi.kasturirangan@informa.com
Initial Push
560d5c2
# Define agent state
from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Optional, Sequence
from copilotkit import CopilotKitState # noqa: F401
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from langgraph.managed import IsLastStep, RemainingSteps
def merge_lists(a: list, b: list) -> list:
"""Merge two lists by extending the first with the second"""
return [*a, *b] if isinstance(a, list) and isinstance(b, list) else b
@dataclass
class InputState:
"""Defines the input state for the agent, representing a narrower interface to the outside world.
This class is used to define the initial state and structure of incoming data.
"""
messages: Annotated[Sequence[AnyMessage], add_messages] = field(default_factory=list)
"""
Messages tracking the primary execution state of the agent.
Typically accumulates a pattern of:
1. HumanMessage - user input
2. AIMessage with .tool_calls - agent picking tool(s) to use to collect information
3. ToolMessage(s) - the responses (or errors) from the executed tools
4. AIMessage without .tool_calls - agent responding in unstructured format to the user
5. HumanMessage - user responds with the next conversational turn
Steps 2-5 may repeat as needed.
The `add_messages` annotation ensures that new messages are merged with existing ones,
updating by ID to maintain an "append-only" state unless a message with the same ID is provided.
"""
@dataclass
class AgentState(InputState):
remaining_steps: RemainingSteps = 25
is_last_step: IsLastStep = field(default=False)
progress: Optional[str] = None
def items(self):
"""Make AgentState behave like a dictionary for CopilotKit compatibility.
This method returns key-value pairs for all attributes in the dataclass.
"""
return self.__dict__.items()
def __getitem__(self, key):
"""Support dictionary-like access."""
return getattr(self, key)
def get(self, key, default=None):
"""Provide dictionary-like get method with default support."""
return getattr(self, key, default)
@dataclass
class SQLAgentState(AgentState):
"""Extended state for SQL agent with query tracking."""
last_query: Optional[str] = None
query_attempts: int = 0
schema: Optional[Dict[str, List[str]]] = None