trading-tools / graph /state /agent_state.py
Deploy Bot
Deploy Trading Analysis Platform to HuggingFace Spaces
a1bf219
"""
State transition helpers for agent workflows.
Provides utility functions for safely updating workflow state,
adding agent messages, and managing state transitions.
"""
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional, Union

from graph.state.trading_state import (
    AgentMessage,
    FundamentalWorkflowState,
    TechnicalWorkflowState,
    UnifiedWorkflowState,
)
# Any of the three workflow state shapes; every helper in this module accepts each of them.
WorkflowState = Union[TechnicalWorkflowState, FundamentalWorkflowState, UnifiedWorkflowState]
def add_agent_message(
    state: WorkflowState,
    agent_name: str,
    content: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> WorkflowState:
    """
    Return a copy of ``state`` with one new agent message appended.

    The input state is never mutated: a deep copy is made and the message is
    appended to the copy's ``messages`` list. If the state has no
    ``messages`` key yet, an empty list is created first instead of raising
    KeyError.

    Args:
        state: Current workflow state
        agent_name: Name of the agent sending the message
        content: Message content
        metadata: Optional metadata (e.g., confidence scores, data sources)

    Returns:
        Updated state with new message
    """
    new_state = deepcopy(state)
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop tzinfo so the ISO string keeps its previous naive format
    # (no "+00:00" suffix) for any consumer that parses or compares it.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    message = AgentMessage(
        agent_name=agent_name,
        content=content,
        timestamp=timestamp,
        metadata=metadata,
    )
    new_state.setdefault("messages", []).append(message)
    return new_state
def get_agent_messages(
    state: WorkflowState,
    agent_name: Optional[str] = None,
) -> List[AgentMessage]:
    """
    Read the message list stored in the workflow state.

    Args:
        state: Workflow state to read from (not modified)
        agent_name: When truthy, keep only messages whose ``agent_name``
            equals it; when None or an empty string, no filtering is done

    Returns:
        The state's own ``messages`` list when unfiltered (an empty list if
        the key is missing); otherwise a new list of the matching messages
    """
    all_messages = state.get("messages", [])
    if not agent_name:
        return all_messages
    return [m for m in all_messages if m["agent_name"] == agent_name]
def update_analysis_result(
    state: WorkflowState,
    analysis_type: str,
    result: Dict[str, Any],
) -> WorkflowState:
    """
    Replace one top-level analysis section of the workflow state.

    The input state is not mutated; a deep copy is returned. The section is
    replaced wholesale (not merged) and only when ``analysis_type`` already
    exists as a key in ``state``; an unknown key leaves the copy unchanged.

    Args:
        state: Current workflow state
        analysis_type: Top-level key to replace (indicators, patterns,
            trends, decision, etc.)
        result: Analysis result dictionary (stored as given, not copied)

    Returns:
        Updated state
    """
    new_state = deepcopy(state)
    # Unknown keys are deliberately a no-op: a section is only replaced when
    # it is already part of the state.
    if analysis_type in new_state:
        new_state[analysis_type] = result
    return new_state
def get_analysis_result(
    state: WorkflowState,
    analysis_type: str,
) -> Optional[Dict[str, Any]]:
    """
    Look up one top-level section of the workflow state.

    Args:
        state: Workflow state to read from (not modified)
        analysis_type: Top-level key to look up

    Returns:
        The stored value itself (not a copy), or None when the key is absent
    """
    section = state.get(analysis_type, None)
    return section
def set_workflow_status(
    state: WorkflowState,
    status: Literal["pending", "in_progress", "completed", "failed"],
    current_agent: Optional[str] = None,
    error: Optional[str] = None,
) -> WorkflowState:
    """
    Return a deep copy of ``state`` with its status fields updated.

    ``workflow_status`` is always overwritten. ``current_agent`` and ``error``
    are written only when a non-None value is supplied, so passing None keeps
    whatever value the state already had (including an earlier error).

    Args:
        state: Current workflow state (not modified)
        status: New workflow status
        current_agent: Currently active agent (if in_progress)
        error: Error message (if failed)

    Returns:
        Updated state
    """
    updated = deepcopy(state)
    updated["workflow_status"] = status
    optional_fields = (("current_agent", current_agent), ("error", error))
    for field_name, field_value in optional_fields:
        if field_value is not None:
            updated[field_name] = field_value
    return updated
def get_workflow_status(state: WorkflowState) -> Dict[str, Any]:
    """
    Collect the workflow's progress fields into a small summary dict.

    Args:
        state: Workflow state to read from (not modified)

    Returns:
        Dict with keys ``status``, ``current_agent`` and ``error``; each value
        is None when the corresponding state key is missing
    """
    # (output key, state key) pairs, in output order.
    field_map = (
        ("status", "workflow_status"),
        ("current_agent", "current_agent"),
        ("error", "error"),
    )
    return {out_key: state.get(state_key) for out_key, state_key in field_map}
def merge_states(
    base_state: WorkflowState,
    update_state: Dict[str, Any],
) -> WorkflowState:
    """
    Deep-merge ``update_state`` into a copy of ``base_state``.

    Where both sides hold a dict under the same key, the dicts are merged
    recursively; any other value from ``update_state`` replaces the base
    value. ``base_state`` is not mutated. Values taken from ``update_state``
    are inserted as-is rather than copied.

    The top-level merge is the same algorithm as the recursive helper, so
    this delegates to it instead of repeating the copy-and-merge loop.

    Args:
        base_state: Base workflow state
        update_state: Update dictionary

    Returns:
        Merged state
    """
    return _deep_merge_dicts(base_state, update_state)
def _deep_merge_dicts(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge ``update`` into a deep copy of ``base`` and return the result.

    When both sides hold a dict under the same key, the two are merged
    recursively; otherwise the value from ``update`` wins. ``base`` is left
    untouched, but values taken from ``update`` are inserted as-is (not
    copied).

    Args:
        base: Dictionary providing the starting values
        update: Dictionary whose entries override or extend ``base``

    Returns:
        A new merged dictionary
    """
    result = deepcopy(base)
    for name, incoming in update.items():
        both_dicts = (
            isinstance(incoming, dict)
            and name in result
            and isinstance(result[name], dict)
        )
        result[name] = _deep_merge_dicts(result[name], incoming) if both_dicts else incoming
    return result
def extract_latest_value(state: WorkflowState, field_path: str) -> Optional[Any]:
    """
    Walk a dot-separated key path through nested dicts in the state.

    Example: extract_latest_value(state, "indicators.rsi.value")

    Only dict levels are traversed. If any step reaches a non-dict value or a
    missing key, None is returned, so a stored None cannot be told apart from
    a missing path.

    Args:
        state: Workflow state to read from (not modified)
        field_path: Keys separated by ".", outermost first

    Returns:
        The value at the end of the path, or None if the path does not resolve
    """
    node: Any = state
    for segment in field_path.split("."):
        if not isinstance(node, dict) or segment not in node:
            return None
        node = node[segment]
    return node
# Ordered agent sequences per workflow type, shared by the transition helpers
# below instead of being rebuilt on every call. Tuples keep the shared tables
# immutable. The unified workflow lists the coordinator first, then the
# technical agents, then the fundamental agents.
_TECHNICAL_AGENTS = (
    "indicator_agent",
    "pattern_agent",
    "trend_agent",
    "decision_agent",
)
_FUNDAMENTAL_AGENTS = (
    "valuation_agent",
    "news_agent",
    "risk_agent",
    "advisor_agent",
)
_WORKFLOW_AGENTS = {
    "technical": _TECHNICAL_AGENTS,
    "fundamental": _FUNDAMENTAL_AGENTS,
    "unified": ("coordinator_agent", *_TECHNICAL_AGENTS, *_FUNDAMENTAL_AGENTS),
}
# Workflow types whose agents must run strictly in the order listed above.
_SEQUENTIAL_WORKFLOWS = ("technical", "fundamental")


def validate_state_transition(
    current_agent: str,
    next_agent: str,
    workflow_type: Literal["technical", "fundamental", "unified"],
) -> bool:
    """
    Validate that agent transition is allowed in workflow.

    Rules:
        - Both agents must belong to the workflow type; an unknown workflow
          type has no agents and is therefore always rejected.
        - technical / fundamental: only a hop to the immediate successor in
          the sequence is allowed.
        - unified: every hop must start or end at ``coordinator_agent``.

    Args:
        current_agent: Current agent name
        next_agent: Next agent name
        workflow_type: Type of workflow

    Returns:
        True if transition is valid
    """
    valid_agents = _WORKFLOW_AGENTS.get(workflow_type, ())
    if current_agent not in valid_agents or next_agent not in valid_agents:
        return False

    if workflow_type in _SEQUENTIAL_WORKFLOWS:
        current_idx = valid_agents.index(current_agent)
        next_idx = valid_agents.index(next_agent)
        return next_idx == current_idx + 1

    # Only "unified" can reach this point (any other type returned above).
    return current_agent == "coordinator_agent" or next_agent == "coordinator_agent"


def get_next_agent(
    current_agent: Optional[str],
    workflow_type: Literal["technical", "fundamental", "unified"],
) -> Optional[str]:
    """
    Get the next agent in the workflow sequence.

    For technical / fundamental workflows this walks the fixed sequence. For
    the unified workflow only the entry point (``coordinator_agent``) is
    returned; after that this helper returns None because routing is left to
    the coordinator.

    Args:
        current_agent: Current agent name (None if starting)
        workflow_type: Type of workflow

    Returns:
        Next agent name, or None if the workflow is complete, the agent is not
        part of the workflow, or the workflow type is unknown
    """
    if workflow_type == "unified":
        # Unified workflow uses coordinator for routing
        return "coordinator_agent" if current_agent is None else None

    # "unified" was handled above, so this only resolves the sequential types.
    valid_agents = _WORKFLOW_AGENTS.get(workflow_type, ())
    if current_agent is None:
        return valid_agents[0] if valid_agents else None
    if current_agent not in valid_agents:
        return None
    next_idx = valid_agents.index(current_agent) + 1
    return valid_agents[next_idx] if next_idx < len(valid_agents) else None
def is_workflow_complete(state: WorkflowState) -> bool:
    """
    Tell whether the workflow has reached a terminal status.

    Args:
        state: Workflow state to inspect (not modified)

    Returns:
        True when ``workflow_status`` is "completed" or "failed"; False for
        any other value, including a missing status
    """
    terminal_statuses = ("completed", "failed")
    return state.get("workflow_status") in terminal_statuses
def format_state_summary(state: WorkflowState) -> str:
    """
    Render the key workflow fields as a multi-line, human-readable string.

    Ticker, status, current agent and message count are always included. The
    timeframe line appears only when the state has a ``timeframe`` key, and
    the error line only when ``error`` is truthy.

    Args:
        state: Workflow state to summarize (not modified)

    Returns:
        Summary lines joined with newlines
    """
    summary = [f"Ticker: {state.get('ticker')}"]
    if "timeframe" in state:
        summary.append(f"Timeframe: {state.get('timeframe')}")
    summary += [
        f"Status: {state.get('workflow_status')}",
        f"Current Agent: {state.get('current_agent') or 'None'}",
    ]
    if state.get("error"):
        summary.append(f"Error: {state['error']}")
    summary.append(f"Messages: {len(state.get('messages', []))}")
    return "\n".join(summary)