Spaces:
Sleeping
Sleeping
"""
Technical Analysis Workflow using LangGraph.

This workflow orchestrates the technical analysis team. The steps run
strictly one after another (no parallel branches) to avoid concurrent
state updates:
1. Fetch market data
2. Indicator Agent -> calculate technical indicators
3. Pattern Agent -> identify patterns
4. Trend Agent -> analyze trends
5. Decision Agent -> make trading decision
6. Generate charts with annotations
"""
import json
import logging
import time
import traceback
from datetime import datetime, timedelta
from typing import Any, Dict, Literal, Optional

import pandas as pd
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph

from agents.technical import DecisionAgent, IndicatorAgent, PatternAgent, TrendAgent
from config.default_config import DEFAULT_CONFIG
from data.providers.provider_factory import ProviderFactory
from data.schemas.market_data import check_minimum_data, validate_ohlc
from graph.state.agent_state import add_agent_message, set_workflow_status
from graph.state.trading_state import (
    TechnicalWorkflowState,
    create_initial_technical_state,
)
from utils.charts.annotations import ChartAnnotations
from utils.charts.chart_generator import ChartGenerator

# Module-level logger. Defined after all imports so that no import statement
# follows module-level code (PEP 8, flake8 E402).
logger = logging.getLogger(__name__)
class TechnicalWorkflow:
    """
    Technical analysis workflow orchestrator.

    This workflow implements User Story 1 (P1 MVP):
    - Fetch OHLC data from data providers
    - Run technical analysis through 4-agent pipeline
    - Generate annotated candlestick charts
    - Produce trading recommendation

    Nodes run strictly in sequence: fetch_market_data -> indicator_agent ->
    pattern_agent -> trend_agent -> decision_agent -> generate_charts -> END.
    Each node catches its own exceptions and returns a state rather than
    raising.
    """

    # Minimum number of bars passed to ``check_minimum_data`` after fetching.
    _MIN_BARS = 30
    # Frames with fewer rows than this lose 0.2 from the data-quality score.
    _FULL_HISTORY_BARS = 50

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize technical workflow.

        Args:
            config: Optional configuration override. Any falsy value (``None``
                or an empty dict) falls back to ``DEFAULT_CONFIG``.
        """
        self.config = config or DEFAULT_CONFIG

        # Initialize agents (all share the same configuration)
        self.indicator_agent = IndicatorAgent(config=self.config)
        self.pattern_agent = PatternAgent(config=self.config)
        self.trend_agent = TrendAgent(config=self.config)
        self.decision_agent = DecisionAgent(config=self.config)

        # Initialize providers and chart generator
        self.provider_factory = ProviderFactory(config=self.config)
        self.chart_generator = ChartGenerator()

        # Build and compile the workflow graph once per instance
        self.graph = self._build_graph()

    def _build_graph(self) -> StateGraph:
        """
        Build and compile the LangGraph graph for the technical workflow.

        Returns:
            The result of ``workflow.compile()``. NOTE(review): that is the
            compiled graph, not the ``StateGraph`` builder named in the
            return annotation.
        """
        workflow = StateGraph(TechnicalWorkflowState)

        # (node name, node callable) in execution order
        pipeline = [
            ("fetch_market_data", self._fetch_market_data_node),
            ("indicator_agent", self._indicator_agent_node),
            ("pattern_agent", self._pattern_agent_node),
            ("trend_agent", self._trend_agent_node),
            ("decision_agent", self._decision_agent_node),
            ("generate_charts", self._generate_charts_node),
        ]
        for node_name, node_fn in pipeline:
            workflow.add_node(node_name, node_fn)

        node_names = [node_name for node_name, _ in pipeline]
        workflow.set_entry_point(node_names[0])

        # Sequential workflow to avoid concurrent state updates: each node
        # feeds exactly the next one, and the last one ends the graph.
        for upstream, downstream in zip(node_names, node_names[1:]):
            workflow.add_edge(upstream, downstream)
        workflow.add_edge(node_names[-1], END)

        return workflow.compile()

    def run(
        self,
        ticker: str,
        timeframe: str = "1d",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        user_query: Optional[str] = None,
    ) -> TechnicalWorkflowState:
        """
        Run technical analysis workflow for a single timeframe.

        Args:
            ticker: Stock ticker symbol
            timeframe: Analysis timeframe (1m, 5m, 1h, 1d, etc.)
            start_date: Start date (YYYY-MM-DD). Defaults to 90 days before
                today for "1d"/"1w" timeframes and 30 days before today for
                all other timeframes.
            end_date: End date (YYYY-MM-DD), default: today
            user_query: Optional user question/query

        Returns:
            Final workflow state with analysis results. If invoking the graph
            raises, the initial state marked "failed" is returned instead.
        """
        # Read the clock once so both default dates derive from the same instant
        now = datetime.now()
        if end_date is None:
            end_date = now.strftime("%Y-%m-%d")
        if start_date is None:
            # Longer lookback for daily/weekly bars, shorter for intraday
            days_back = 90 if timeframe in ("1d", "1w") else 30
            start_date = (now - timedelta(days=days_back)).strftime("%Y-%m-%d")

        # Create initial state
        initial_state = create_initial_technical_state(ticker, timeframe, user_query)
        initial_state["market_data"]["start_date"] = start_date
        initial_state["market_data"]["end_date"] = end_date
        initial_state = set_workflow_status(initial_state, "in_progress")

        try:
            final_state = self.graph.invoke(initial_state)
            # NOTE(review): nodes report failures by returning a state marked
            # "failed" while the linear graph keeps running; this unconditional
            # "completed" may overwrite such a status. Confirm how
            # set_workflow_status stores status/error.
            return set_workflow_status(final_state, "completed")
        except Exception as e:
            self._log_event(
                logging.ERROR,
                {
                    "workflow": "technical_workflow",
                    "action": "error",
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            return set_workflow_status(initial_state, "failed", error=str(e))

    @staticmethod
    def _log_event(level: int, payload: Dict[str, Any]) -> None:
        """
        Emit ``payload`` as a single JSON log line with a trailing Unix timestamp.

        ``stacklevel=2`` attributes the log record to the calling method rather
        than to this helper.
        """
        logger.log(level, json.dumps({**payload, "timestamp": time.time()}), stacklevel=2)

    def _fetch_market_data_node(
        self, state: TechnicalWorkflowState
    ) -> TechnicalWorkflowState:
        """
        Fetch, validate and score OHLC data, storing it serialized in the state.

        Args:
            state: Current workflow state; reads ``ticker``, ``timeframe`` and
                ``market_data.start_date`` / ``market_data.end_date``.

        Returns:
            Updated state with ``market_data.ohlc_data`` and
            ``market_data.data_quality_score`` set. On any exception, the
            input state marked "failed".
        """
        try:
            ticker = state["ticker"]
            timeframe = state["timeframe"]
            market_data = state["market_data"]
            start_date = market_data["start_date"]
            end_date = market_data["end_date"]

            provider = self.provider_factory.get_provider("ohlc")
            df = provider.fetch_ohlc(
                ticker=ticker,
                timeframe=timeframe,
                start_date=start_date,
                end_date=end_date,
            )

            df = validate_ohlc(df)
            check_minimum_data(df, min_bars=self._MIN_BARS, timeframe=timeframe)

            quality_score = self._calculate_data_quality(df)

            # state.copy() is shallow, so writing into state["market_data"]
            # would mutate the caller's nested dict; build a fresh one instead.
            new_state = state.copy()
            new_state["market_data"] = {
                **market_data,
                "ohlc_data": self._serialize_dataframe(df),
                "data_quality_score": quality_score,
            }
            return add_agent_message(
                new_state,
                "data_fetcher",
                f"Successfully fetched {len(df)} bars of {timeframe} data for {ticker} (quality: {quality_score:.2f})",
                metadata={"bars": len(df), "quality": quality_score},
            )
        except Exception as e:
            self._log_event(
                logging.ERROR,
                {
                    "node": "fetch_market_data",
                    "action": "error",
                    "ticker": state.get("ticker"),
                    "timeframe": state.get("timeframe"),
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            return set_workflow_status(
                state,
                "failed",
                error=f"Failed to fetch market data: {e}",
            )

    def _run_agent_node(
        self,
        state: TechnicalWorkflowState,
        node_name: str,
        agent: Any,
        label: str,
    ) -> TechnicalWorkflowState:
        """
        Mark ``node_name`` as the current agent, then return ``agent.run(state)``.

        Any exception is logged as structured JSON and converted into a
        "failed" workflow status with the message "<label> failed: <error>".

        Args:
            state: Current workflow state
            node_name: Graph node name, also used as ``current_agent`` and in logs
            agent: Agent object exposing ``run(state)``
            label: Human-readable agent name for the error message

        Returns:
            The agent's resulting state, or the input state marked "failed".
        """
        try:
            in_progress = set_workflow_status(
                state, "in_progress", current_agent=node_name
            )
            return agent.run(in_progress)
        except Exception as e:
            self._log_event(
                logging.ERROR,
                {
                    "node": node_name,
                    "action": "error",
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            return set_workflow_status(state, "failed", error=f"{label} failed: {e}")

    def _indicator_agent_node(
        self, state: TechnicalWorkflowState
    ) -> TechnicalWorkflowState:
        """Run the Indicator Agent (technical indicator calculation)."""
        return self._run_agent_node(
            state, "indicator_agent", self.indicator_agent, "Indicator agent"
        )

    def _pattern_agent_node(
        self, state: TechnicalWorkflowState
    ) -> TechnicalWorkflowState:
        """Run the Pattern Agent (pattern analysis)."""
        return self._run_agent_node(
            state, "pattern_agent", self.pattern_agent, "Pattern agent"
        )

    def _trend_agent_node(
        self, state: TechnicalWorkflowState
    ) -> TechnicalWorkflowState:
        """Run the Trend Agent (trend analysis)."""
        return self._run_agent_node(
            state, "trend_agent", self.trend_agent, "Trend agent"
        )

    def _decision_agent_node(
        self, state: TechnicalWorkflowState
    ) -> TechnicalWorkflowState:
        """Run the Decision Agent (trading decision)."""
        return self._run_agent_node(
            state, "decision_agent", self.decision_agent, "Decision agent"
        )

    def _generate_charts_node(
        self, state: TechnicalWorkflowState
    ) -> TechnicalWorkflowState:
        """
        Generate an annotated candlestick chart for the fetched OHLC data.

        Chart generation is best-effort. With no market data the step is
        skipped. Any exception is logged as a warning and recorded as an agent
        message instead of failing the workflow.

        Args:
            state: Current workflow state

        Returns:
            Updated state with ``chart_path`` set on success; otherwise the
            input state plus an explanatory agent message.
        """
        try:
            # A failed fetch may leave no serialized OHLC data; that
            # deserializes to an empty frame and takes the "skipped" path.
            ohlc_data = (state.get("market_data") or {}).get("ohlc_data")
            df = self._deserialize_dataframe(ohlc_data)

            if df.empty:
                self._log_event(
                    logging.WARNING,
                    {
                        "node": "generate_charts",
                        "action": "skipped",
                        "ticker": state["ticker"],
                        "reason": "market_data is empty",
                    },
                )
                return add_agent_message(
                    state,
                    "chart_generator",
                    "Skipped chart generation: no market data available",
                    metadata={"skipped": True},
                )

            # Indicator overlays: RSI is drawn in its own panel when present
            indicators_to_plot = []
            rsi = (state.get("indicators") or {}).get("rsi") or {}
            rsi_values = rsi.get("series")
            if rsi_values:
                indicators_to_plot.append(
                    {
                        "data": pd.Series(rsi_values),
                        "panel": 1,  # Separate panel
                        "ylabel": "RSI",
                        "color": "purple",
                    }
                )

            fig, chart_path = self.chart_generator.generate_candlestick_chart(
                df=df,
                ticker=state["ticker"],
                timeframe=state["timeframe"],
                volume=True,
                save=True,
                indicators=indicators_to_plot or None,
            )
            # Close figure to free memory
            self.chart_generator.close_figure(fig)

            new_state = state.copy()
            new_state["chart_path"] = chart_path
            return add_agent_message(
                new_state,
                "chart_generator",
                f"Generated chart: {chart_path}",
                metadata={"chart_path": chart_path},
            )
        except Exception as e:
            self._log_event(
                logging.WARNING,
                {
                    "node": "generate_charts",
                    "action": "error_non_fatal",
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            # Chart generation failure shouldn't fail the whole workflow
            return add_agent_message(
                state,
                "chart_generator",
                f"Failed to generate chart: {e}",
                metadata={"error": True},
            )

    def _calculate_data_quality(self, df: pd.DataFrame) -> float:
        """
        Calculate a heuristic data quality score in [0, 1].

        Starting from 1.0, the score is reduced by:
        - 0.5 x the fraction of missing cells across all columns,
        - 0.2 when there are fewer than ``_FULL_HISTORY_BARS`` (50) rows,
        - 0.3 x (OHLC violations / rows). A row counts one violation when
          ``low`` exceeds ``open`` or ``close``, and another when ``high`` is
          below ``open`` or ``close``.

        A frame with no rows (or no columns) scores 0.0.

        Args:
            df: OHLC DataFrame with ``open``, ``high``, ``low`` and ``close`` columns

        Returns:
            Quality score clamped to [0, 1]
        """
        n_rows, n_cols = df.shape
        if n_rows == 0 or n_cols == 0:
            # Without this guard the missing-data ratio divides by zero. With
            # NumPy scalars that yields NaN, which the final clamp turns into
            # 1.0, reporting empty data as perfect.
            return 0.0

        score = 1.0

        # Penalize for missing data
        missing_pct = df.isnull().sum().sum() / (n_rows * n_cols)
        score -= missing_pct * 0.5

        # Penalize for insufficient data
        if n_rows < self._FULL_HISTORY_BARS:
            score -= 0.2

        # Penalize for OHLC inconsistencies (vectorized; comparisons with NaN
        # evaluate to False, so missing values never count as violations)
        low_violations = (df["low"] > df["open"]) | (df["low"] > df["close"])
        high_violations = (df["high"] < df["open"]) | (df["high"] < df["close"])
        violations = int(low_violations.sum()) + int(high_violations.sum())
        if violations > 0:
            score -= (violations / n_rows) * 0.3

        return max(0.0, min(1.0, float(score)))

    def _serialize_dataframe(self, df: pd.DataFrame) -> Dict[str, Any]:
        """
        Serialize a DataFrame into a ``{column: [values...]}`` dict for state storage.

        The index is reset first, so a named index (e.g. ``date``) becomes a
        regular column that ``_deserialize_dataframe`` can turn back into
        ``timestamp``.

        Args:
            df: pandas DataFrame

        Returns:
            Column-oriented dict produced by ``to_dict(orient="list")``
        """
        return df.reset_index().to_dict(orient="list")

    def _deserialize_dataframe(self, data: Optional[Dict[str, Any]]) -> pd.DataFrame:
        """
        Rebuild a DataFrame from the dict produced by ``_serialize_dataframe``.

        The first of ``timestamp``, ``date`` or ``datetime`` present among the
        columns is parsed with ``pd.to_datetime`` and exposed as ``timestamp``.
        ``None`` yields an empty DataFrame.

        Args:
            data: Serialized data

        Returns:
            pandas DataFrame
        """
        df = pd.DataFrame(data)
        for column in ("timestamp", "date", "datetime"):
            if column in df.columns:
                df[column] = pd.to_datetime(df[column])
                if column != "timestamp":
                    df = df.rename(columns={column: "timestamp"})
                break
        return df