Spaces:
Sleeping
Sleeping
| # Define agent state | |
| from dataclasses import dataclass, field | |
| from typing import Annotated, Dict, List, Optional, Sequence | |
| from copilotkit import CopilotKitState # noqa: F401 | |
| from langchain_core.messages import AnyMessage | |
| from langgraph.graph import add_messages | |
| from langgraph.managed import IsLastStep, RemainingSteps | |
def merge_lists(a: list, b: list) -> list:
    """Concatenate two lists into a brand-new list.

    Neither input is modified. When either argument is not a ``list``
    instance, ``b`` is handed back unchanged (the same object, not a copy).
    """
    both_are_lists = isinstance(a, list) and isinstance(b, list)
    if not both_are_lists:
        # Fallback: the second argument wins outright.
        return b

    combined: list = []
    combined.extend(a)
    combined.extend(b)
    return combined
@dataclass
class InputState:
    """Input state for the agent: a narrower interface to the outside world.

    This class defines the initial state and the structure of incoming data.
    """

    # ``@dataclass`` is required for the declaration below to work:
    # ``field(default_factory=list)`` is only interpreted by the dataclass
    # machinery. Without the decorator the class attribute would be a raw
    # ``dataclasses.Field`` object instead of a fresh empty list per instance.
    messages: Annotated[Sequence[AnyMessage], add_messages] = field(default_factory=list)
    """
    Messages tracking the primary execution state of the agent.

    Typically accumulates a pattern of:
    1. HumanMessage - user input
    2. AIMessage with .tool_calls - agent picking tool(s) to use to collect information
    3. ToolMessage(s) - the responses (or errors) from the executed tools
    4. AIMessage without .tool_calls - agent responding in unstructured format to the user
    5. HumanMessage - user responds with the next conversational turn

    Steps 2-5 may repeat as needed.

    The `add_messages` annotation ensures that new messages are merged with existing ones,
    updating by ID to maintain an "append-only" state unless a message with the same ID is provided.
    """
@dataclass
class AgentState(InputState):
    """Agent state: the input state plus step bookkeeping and a progress note.

    Besides its fields, the class offers a small read-only, dictionary-style
    surface (``items``, ``__getitem__``, ``get``) for CopilotKit compatibility.
    """

    # ``@dataclass`` turns ``field(default=False)`` into an actual default
    # value (otherwise ``is_last_step`` would be a truthy ``Field`` object) and
    # makes ``__init__`` store every field in the instance ``__dict__``, which
    # is what ``items()`` below reads.
    remaining_steps: RemainingSteps = 25
    is_last_step: IsLastStep = field(default=False)
    progress: Optional[str] = None

    def items(self):
        """Return the instance's attribute name/value pairs, like ``dict.items()``.

        The pairs come from the instance ``__dict__``, i.e. the values stored
        on this particular instance.
        """
        return self.__dict__.items()

    def __getitem__(self, key):
        """Support dictionary-like access: ``state[key]`` reads attribute ``key``.

        Raises:
            AttributeError: if no attribute named ``key`` exists. Note this is
                not ``KeyError``, because the lookup is a plain ``getattr``.
        """
        return getattr(self, key)

    def get(self, key, default=None):
        """Return attribute ``key``, or ``default`` when it does not exist."""
        return getattr(self, key, default)
@dataclass
class SQLAgentState(AgentState):
    """Extended state for the SQL agent with query tracking."""

    # ``@dataclass`` makes these three fields constructor arguments and stores
    # them on each instance, so they show up in the inherited ``items()`` view
    # (which reads the instance ``__dict__``) instead of living only as
    # class-level defaults.
    last_query: Optional[str] = None
    query_attempts: int = 0
    # NOTE(review): presumably maps table name -> column names; confirm with
    # the code that populates it.
    schema: Optional[Dict[str, List[str]]] = None