Spaces:
Sleeping
Sleeping
"""
State transition helpers for agent workflows.

Provides utility functions for safely updating workflow state,
adding agent messages, and managing state transitions.
"""
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional, Union

from graph.state.trading_state import (
    AgentMessage,
    FundamentalWorkflowState,
    TechnicalWorkflowState,
    UnifiedWorkflowState,
)
# Type alias: any of the three workflow state shapes accepted by the
# helpers in this module.
WorkflowState = Union[
    TechnicalWorkflowState,
    FundamentalWorkflowState,
    UnifiedWorkflowState,
]
def add_agent_message(
    state: WorkflowState,
    agent_name: str,
    content: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> WorkflowState:
    """
    Add an agent message to the workflow state.

    The input ``state`` is not modified: the message is appended to a deep
    copy, which is returned.

    Args:
        state: Current workflow state
        agent_name: Name of the agent sending the message
        content: Message content
        metadata: Optional metadata (e.g., confidence scores, data sources)

    Returns:
        Updated state with new message
    """
    new_state = deepcopy(state)
    message = AgentMessage(
        agent_name=agent_name,
        content=content,
        # datetime.utcnow() is deprecated since Python 3.12. This expression
        # yields the same naive-UTC ISO string, so consumers of the
        # timestamp see no format change.
        timestamp=datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        metadata=metadata,
    )
    # Tolerate a missing or None "messages" entry instead of raising. This
    # matches the `.get("messages", [])` defaults used by this module's readers.
    messages = new_state.get("messages")
    if messages is None:
        messages = []
        new_state["messages"] = messages
    messages.append(message)
    return new_state
def get_agent_messages(
    state: WorkflowState,
    agent_name: Optional[str] = None,
) -> List[AgentMessage]:
    """
    Get messages from the workflow state.

    Always returns a new list. Callers may modify the result without
    mutating the message history stored in ``state``.

    Args:
        state: Current workflow state
        agent_name: Optional filter by agent name; a falsy value (None or "")
            returns every message

    Returns:
        List of messages (filtered if agent_name provided); empty when the
        state has no messages
    """
    # `or []` also covers a "messages" key that is present but set to None.
    messages = state.get("messages") or []
    if agent_name:
        return [msg for msg in messages if msg["agent_name"] == agent_name]
    # Copy so the caller does not receive (and cannot mutate) the state's list.
    return list(messages)
def update_analysis_result(
    state: WorkflowState,
    analysis_type: str,
    result: Dict[str, Any],
) -> WorkflowState:
    """
    Update a specific analysis result in the state.

    The value stored under ``analysis_type`` is replaced wholesale by
    ``result``. The key is created if it is not present yet. The input
    ``state`` is not modified; a deep copy is returned. ``result`` itself is
    stored by reference.

    Args:
        state: Current workflow state
        analysis_type: Type of analysis (indicators, patterns, trends, decision, etc.)
        result: Analysis result dictionary

    Returns:
        Updated state
    """
    new_state = deepcopy(state)
    # Assign unconditionally so that a first-time result (key absent) is
    # stored instead of being silently dropped.
    new_state[analysis_type] = result
    return new_state
def get_analysis_result(
    state: WorkflowState,
    analysis_type: str,
) -> Optional[Dict[str, Any]]:
    """
    Look up one analysis section of the workflow state.

    Args:
        state: Current workflow state
        analysis_type: Key of the analysis section to read

    Returns:
        The value stored under ``analysis_type``, or None when the key is absent
    """
    section = state.get(analysis_type, None)
    return section
def set_workflow_status(
    state: WorkflowState,
    status: Literal["pending", "in_progress", "completed", "failed"],
    current_agent: Optional[str] = None,
    error: Optional[str] = None,
) -> WorkflowState:
    """
    Return a copy of the state with its status fields updated.

    ``workflow_status`` is always overwritten. ``current_agent`` and
    ``error`` are written only when a non-None value is supplied; otherwise
    their existing values are carried over unchanged.

    Args:
        state: Current workflow state (left unmodified)
        status: New workflow status
        current_agent: Currently active agent (if in_progress)
        error: Error message (if failed)

    Returns:
        Deep copy of ``state`` with the requested fields applied
    """
    updated = deepcopy(state)
    updated["workflow_status"] = status
    optional_fields = (("current_agent", current_agent), ("error", error))
    for field_name, field_value in optional_fields:
        if field_value is None:
            continue
        updated[field_name] = field_value
    return updated
def get_workflow_status(state: WorkflowState) -> Dict[str, Any]:
    """
    Collect the workflow's progress fields into a small report dict.

    Args:
        state: Current workflow state

    Returns:
        Dict with keys "status", "current_agent" and "error"; each value is
        None when the corresponding state field is absent
    """
    # (report key, state key) pairs
    field_map = (
        ("status", "workflow_status"),
        ("current_agent", "current_agent"),
        ("error", "error"),
    )
    return {report_key: state.get(state_key) for report_key, state_key in field_map}
def merge_states(
    base_state: WorkflowState,
    update_state: Dict[str, Any],
) -> WorkflowState:
    """
    Merge update dictionary into base state.

    Performs a recursive merge. Where both sides hold a dict under the same
    key, the dicts are merged; otherwise the update value replaces the base
    value. ``base_state`` is not modified. Non-merged values taken from
    ``update_state`` are inserted by reference, not copied.

    Args:
        base_state: Base workflow state
        update_state: Update dictionary

    Returns:
        Merged state
    """
    # The top-level merge is exactly the recursive merge. Delegate rather
    # than duplicating the loop, so the two implementations cannot drift apart.
    return _deep_merge_dicts(base_state, update_state)


def _deep_merge_dicts(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively merge two dictionaries.

    Args:
        base: Base dictionary (deep-copied, never mutated)
        update: Update dictionary; its values win on conflicting non-dict keys

    Returns:
        Merged dictionary
    """
    merged = deepcopy(base)
    for key, value in update.items():
        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
            # Both sides are dicts: merge them instead of replacing.
            merged[key] = _deep_merge_dicts(merged[key], value)
        else:
            merged[key] = value
    return merged
def extract_latest_value(state: WorkflowState, field_path: str) -> Optional[Any]:
    """
    Walk nested dicts in the state following a dot-separated key path.

    Example: extract_latest_value(state, "indicators.rsi.value")

    Only dict levels are traversed. If any step is missing, or an
    intermediate value is not a dict, the lookup stops and returns None.
    Note that a stored value of None is indistinguishable from a missing key.

    Args:
        state: Current workflow state
        field_path: Dot-separated path to field

    Returns:
        Field value or None if not found
    """
    node: Any = state
    for key in field_path.split("."):
        if not isinstance(node, dict) or key not in node:
            return None
        node = node[key]
    return node
def validate_state_transition(
    current_agent: str,
    next_agent: str,
    workflow_type: Literal["technical", "fundamental", "unified"],
) -> bool:
    """
    Check whether moving from ``current_agent`` to ``next_agent`` is allowed.

    Rules:
      * Both agents must belong to the workflow's agent set; unknown workflow
        types therefore always yield False.
      * "technical" and "fundamental" are strict linear pipelines: only a
        step to the immediately following agent is allowed.
      * "unified" is hub-and-spoke: any transition that starts or ends at
        "coordinator_agent" is allowed.

    Args:
        current_agent: Current agent name
        next_agent: Next agent name
        workflow_type: Type of workflow

    Returns:
        True if transition is valid
    """
    technical_pipeline = (
        "indicator_agent",
        "pattern_agent",
        "trend_agent",
        "decision_agent",
    )
    fundamental_pipeline = (
        "valuation_agent",
        "news_agent",
        "risk_agent",
        "advisor_agent",
    )
    agents_by_workflow = {
        "technical": technical_pipeline,
        "fundamental": fundamental_pipeline,
        "unified": ("coordinator_agent",) + technical_pipeline + fundamental_pipeline,
    }
    allowed = agents_by_workflow.get(workflow_type, ())

    both_known = current_agent in allowed and next_agent in allowed
    if not both_known:
        return False

    if workflow_type == "unified":
        # Every hop must go through the coordinator.
        return "coordinator_agent" in (current_agent, next_agent)

    # Linear pipelines: only a single step forward is permitted.
    return allowed.index(next_agent) - allowed.index(current_agent) == 1
def get_next_agent(
    current_agent: Optional[str],
    workflow_type: Literal["technical", "fundamental", "unified"],
) -> Optional[str]:
    """
    Determine which agent should run after ``current_agent``.

    Technical and fundamental workflows are fixed linear pipelines. Unified
    workflows delegate routing to the coordinator, so this function only
    supplies the coordinator as the starting agent and returns None for any
    non-None ``current_agent``.

    Args:
        current_agent: Agent that just ran, or None when the workflow is starting
        workflow_type: Type of workflow

    Returns:
        Name of the next agent. None when the pipeline is finished, the agent
        is not part of it, or the workflow type is unknown.
    """
    if workflow_type == "unified":
        if current_agent is None:
            return "coordinator_agent"
        return None

    pipelines = {
        "technical": (
            "indicator_agent",
            "pattern_agent",
            "trend_agent",
            "decision_agent",
        ),
        "fundamental": (
            "valuation_agent",
            "news_agent",
            "risk_agent",
            "advisor_agent",
        ),
    }
    pipeline = pipelines.get(workflow_type, ())

    if current_agent is None:
        return next(iter(pipeline), None)

    # Map each stage to its successor; the final stage has no entry, and
    # agents outside the pipeline are simply not found.
    successor_of = dict(zip(pipeline, pipeline[1:]))
    return successor_of.get(current_agent)
def is_workflow_complete(state: WorkflowState) -> bool:
    """
    Report whether the workflow has reached a terminal status.

    Args:
        state: Current workflow state

    Returns:
        True when ``workflow_status`` is "completed" or "failed"; False
        otherwise, including when the status is absent
    """
    terminal_statuses = ("completed", "failed")
    return state.get("workflow_status") in terminal_statuses
def format_state_summary(state: WorkflowState) -> str:
    """
    Render the key workflow fields as a newline-separated text summary.

    The "Timeframe" line appears only when the state has a "timeframe" key.
    The "Error" line appears only when the error value is truthy.

    Args:
        state: Current workflow state

    Returns:
        Summary string with one "Label: value" pair per line
    """
    summary = [f"Ticker: {state.get('ticker')}"]
    if "timeframe" in state:
        summary.append(f"Timeframe: {state.get('timeframe')}")
    summary.extend(
        [
            f"Status: {state.get('workflow_status')}",
            f"Current Agent: {state.get('current_agent') or 'None'}",
        ]
    )
    error_message = state.get("error")
    if error_message:
        summary.append(f"Error: {error_message}")
    message_count = len(state.get("messages", []))
    summary.append(f"Messages: {message_count}")
    return "\n".join(summary)