"""
State transition helpers for agent workflows.

Provides utility functions for safely updating workflow state, adding agent
messages, and managing state transitions.

Helpers that return an updated state deep-copy the incoming state first, so
the caller's state object is not modified in place.
"""

from copy import deepcopy
from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from graph.state.trading_state import (
    AgentMessage,
    FundamentalWorkflowState,
    TechnicalWorkflowState,
    UnifiedWorkflowState,
)

WorkflowState = Union[
    TechnicalWorkflowState, FundamentalWorkflowState, UnifiedWorkflowState
]

# Agent that routes work in the unified workflow.
_COORDINATOR_AGENT = "coordinator_agent"

# Fixed agent order for the workflows that run strictly sequentially.
# Shared by validate_state_transition and get_next_agent so the two can
# never disagree about the sequence.
_SEQUENTIAL_WORKFLOWS: Dict[str, Tuple[str, ...]] = {
    "technical": (
        "indicator_agent",
        "pattern_agent",
        "trend_agent",
        "decision_agent",
    ),
    "fundamental": (
        "valuation_agent",
        "news_agent",
        "risk_agent",
        "advisor_agent",
    ),
}

# Every agent allowed to participate in the unified workflow:
# the coordinator followed by all technical and fundamental agents.
_UNIFIED_AGENTS: Tuple[str, ...] = (
    _COORDINATOR_AGENT,
    *_SEQUENTIAL_WORKFLOWS["technical"],
    *_SEQUENTIAL_WORKFLOWS["fundamental"],
)


def _utc_now_iso() -> str:
    """
    Return the current UTC time as a naive ISO-8601 string.

    ``datetime.utcnow()`` is deprecated since Python 3.12. This produces the
    same string format it did (no ``+00:00`` suffix), so consumers of the
    message ``timestamp`` field see no change.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def add_agent_message(
    state: WorkflowState,
    agent_name: str,
    content: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> WorkflowState:
    """
    Add an agent message to the workflow state.

    Args:
        state: Current workflow state (not modified)
        agent_name: Name of the agent sending the message
        content: Message content
        metadata: Optional metadata (e.g., confidence scores, data sources)

    Returns:
        Deep copy of ``state`` with the new message appended to ``messages``.
        A ``messages`` list is created if the state has none.
    """
    new_state = deepcopy(state)
    message = AgentMessage(
        agent_name=agent_name,
        content=content,
        timestamp=_utc_now_iso(),
        metadata=metadata,
    )
    # get_agent_messages treats "messages" as optional; be consistent here
    # instead of raising KeyError on a state that has no messages yet.
    new_state.setdefault("messages", []).append(message)
    return new_state


def get_agent_messages(
    state: WorkflowState,
    agent_name: Optional[str] = None,
) -> List[AgentMessage]:
    """
    Get messages from the workflow state.

    Args:
        state: Current workflow state
        agent_name: Optional filter by agent name (falsy means "no filter")

    Returns:
        New list of messages, filtered to ``agent_name`` if one was given.
        The returned list is never the state's own list.
    """
    messages = state.get("messages", [])
    if agent_name:
        return [msg for msg in messages if msg["agent_name"] == agent_name]
    # Copy so callers appending to / removing from the result cannot
    # silently mutate the state's message history.
    return list(messages)


def update_analysis_result(
    state: WorkflowState,
    analysis_type: str,
    result: Dict[str, Any],
) -> WorkflowState:
    """
    Set a specific analysis result in the state.

    Args:
        state: Current workflow state (not modified)
        analysis_type: Type of analysis (indicators, patterns, trends, decision, etc.)
        result: Analysis result dictionary

    Returns:
        Deep copy of ``state`` with ``state[analysis_type]`` replaced by
        ``result``. The key is created if it was not present.
    """
    new_state = deepcopy(state)
    # Previously an absent key fell into an unreachable "nested update" branch
    # and the result was silently dropped; always store it.
    new_state[analysis_type] = result
    return new_state


def get_analysis_result(
    state: WorkflowState,
    analysis_type: str,
) -> Optional[Dict[str, Any]]:
    """
    Get a specific analysis result from the state.

    Args:
        state: Current workflow state
        analysis_type: Type of analysis to retrieve

    Returns:
        Analysis result or None if not found
    """
    return state.get(analysis_type)


def set_workflow_status(
    state: WorkflowState,
    status: Literal["pending", "in_progress", "completed", "failed"],
    current_agent: Optional[str] = None,
    error: Optional[str] = None,
) -> WorkflowState:
    """
    Update workflow status and, optionally, current agent and error.

    Args:
        state: Current workflow state (not modified)
        status: New workflow status
        current_agent: Currently active agent; left unchanged when None
        error: Error message; left unchanged when None

    Returns:
        Deep copy of ``state`` with the requested fields updated.
    """
    new_state = deepcopy(state)
    new_state["workflow_status"] = status
    if current_agent is not None:
        new_state["current_agent"] = current_agent
    if error is not None:
        new_state["error"] = error
    return new_state


def get_workflow_status(state: WorkflowState) -> Dict[str, Any]:
    """
    Get workflow status information.

    Args:
        state: Current workflow state

    Returns:
        Dict with ``status``, ``current_agent`` and ``error`` (each None if absent)
    """
    return {
        "status": state.get("workflow_status"),
        "current_agent": state.get("current_agent"),
        "error": state.get("error"),
    }


def _merge_into(target: Dict[str, Any], update: Dict[str, Any]) -> None:
    """Recursively merge ``update`` into ``target`` in place (dicts merge, others overwrite)."""
    for key, value in update.items():
        existing = target.get(key)
        if isinstance(existing, dict) and isinstance(value, dict):
            _merge_into(existing, value)
        else:
            target[key] = value


def merge_states(
    base_state: WorkflowState,
    update_state: Dict[str, Any],
) -> WorkflowState:
    """
    Merge update dictionary into base state.

    Performs a deep merge for nested dictionaries: when both sides hold a
    dict under the same key, their contents are merged recursively; any
    other value in ``update_state`` overwrites the base value.

    Args:
        base_state: Base workflow state (not modified)
        update_state: Update dictionary

    Returns:
        Merged state
    """
    # One deep copy of the base up front. The merge itself then works in
    # place, instead of re-copying every nested level on each recursion.
    new_state = deepcopy(base_state)
    _merge_into(new_state, update_state)
    return new_state


def _deep_merge_dicts(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively merge two dictionaries.

    Args:
        base: Base dictionary (not modified)
        update: Update dictionary

    Returns:
        New merged dictionary (deep copy of ``base`` with ``update`` merged in)
    """
    merged = deepcopy(base)
    _merge_into(merged, update)
    return merged


def extract_latest_value(state: WorkflowState, field_path: str) -> Optional[Any]:
    """
    Extract a value from nested state using dot notation.

    Example:
        extract_latest_value(state, "indicators.rsi.value")

    Args:
        state: Current workflow state
        field_path: Dot-separated path to field

    Returns:
        Field value or None if any path segment is missing or traverses a
        non-dict value
    """
    current: Any = state
    for field in field_path.split("."):
        if not isinstance(current, dict) or field not in current:
            return None
        current = current[field]
    return current


def validate_state_transition(
    current_agent: str,
    next_agent: str,
    workflow_type: Literal["technical", "fundamental", "unified"],
) -> bool:
    """
    Validate that agent transition is allowed in workflow.

    Technical and fundamental workflows only allow stepping to the next
    agent in their fixed sequence. In the unified workflow, one side of
    every transition must be the coordinator.

    Args:
        current_agent: Current agent name
        next_agent: Next agent name
        workflow_type: Type of workflow

    Returns:
        True if transition is valid; False for unknown agents or workflow types
    """
    if workflow_type == "unified":
        if current_agent not in _UNIFIED_AGENTS or next_agent not in _UNIFIED_AGENTS:
            return False
        # The coordinator may route to any agent and any agent may hand back
        # to it; direct worker-to-worker hops are rejected.
        return _COORDINATOR_AGENT in (current_agent, next_agent)

    sequence = _SEQUENTIAL_WORKFLOWS.get(workflow_type, ())
    if current_agent not in sequence or next_agent not in sequence:
        return False
    # Sequential workflows: only the immediately following agent is allowed.
    return sequence.index(next_agent) == sequence.index(current_agent) + 1


def get_next_agent(
    current_agent: Optional[str],
    workflow_type: Literal["technical", "fundamental", "unified"],
) -> Optional[str]:
    """
    Get the next agent in the workflow sequence.

    Args:
        current_agent: Current agent name (None if starting)
        workflow_type: Type of workflow

    Returns:
        Next agent name, or None if the workflow is complete, the agent is
        unknown, or (for the unified workflow) routing is up to the coordinator
    """
    if workflow_type == "unified":
        # Unified routing is decided by the coordinator at runtime, so the
        # only statically known step is the first one.
        return _COORDINATOR_AGENT if current_agent is None else None

    sequence = _SEQUENTIAL_WORKFLOWS.get(workflow_type, ())
    if current_agent is None:
        return sequence[0] if sequence else None
    if current_agent not in sequence:
        return None
    next_idx = sequence.index(current_agent) + 1
    return sequence[next_idx] if next_idx < len(sequence) else None


def is_workflow_complete(state: WorkflowState) -> bool:
    """
    Check if workflow has completed (successfully or with error).

    Args:
        state: Current workflow state

    Returns:
        True if ``workflow_status`` is "completed" or "failed"
    """
    return state.get("workflow_status") in ("completed", "failed")


def format_state_summary(state: WorkflowState) -> str:
    """
    Format workflow state as human-readable summary.

    Args:
        state: Current workflow state

    Returns:
        Multi-line summary string
    """
    lines = [f"Ticker: {state.get('ticker')}"]
    if "timeframe" in state:
        lines.append(f"Timeframe: {state.get('timeframe')}")
    lines.append(f"Status: {state.get('workflow_status')}")
    lines.append(f"Current Agent: {state.get('current_agent') or 'None'}")
    if state.get("error"):
        lines.append(f"Error: {state['error']}")
    lines.append(f"Messages: {len(state.get('messages', []))}")
    return "\n".join(lines)