| """ |
| Decision Engine Module |
| ====================== |
| Combines prediction model, risk model, and market regime |
| to produce final trading decisions. |
| |
Output: a signal (Strong Buy / Buy / Hold / Sell / Strong Sell) with confidence and risk-adjusted sizing.
| """ |
|
|
| import torch |
| import numpy as np |
| from typing import Dict, List, Optional |
| from dataclasses import dataclass, field |
| from enum import Enum |
|
|
|
|
class Signal(Enum):
    """Discrete trading action produced by the decision engine.

    Members are listed from most bullish to most bearish. Each member's
    value is its own name as a string.
    """
    STRONG_BUY = "STRONG_BUY"
    BUY = "BUY"
    HOLD = "HOLD"
    SELL = "SELL"
    STRONG_SELL = "STRONG_SELL"
|
|
|
|
@dataclass
class TradingDecision:
    """Complete trading decision with all context.

    Instances are produced by ``DecisionEngine.make_decision`` and rendered
    by ``format_decision``. Fraction-valued fields are plain floats
    (e.g. ``0.03`` means 3%) and are displayed with ``%`` formatting.
    """
    signal: Signal                # final action after thresholds and alert overrides
    confidence: float             # combined (model x risk x regime) confidence, clipped to [0, 1]
    direction_prob: float         # model probability that price moves up
    expected_return: float        # model's expected return for the chosen horizon
    risk_score: float             # risk score from the risk assessment (displayed as x/1.00)
    position_size_pct: float      # suggested position size as a fraction of the portfolio
    stop_loss_pct: float          # stop-loss distance from entry (ATR multiple x current ATR)
    take_profit_pct: float        # take-profit distance from entry (ATR multiple x current ATR)
    drawdown_risk: float          # drawdown risk from the risk assessment
    market_regime: str            # regime label, e.g. 'normal', 'trending', 'high_volatility'
    horizon: str                  # horizon label, e.g. 'short_term', 'mid_term', 'long_term'
    reasoning: List[str] = field(default_factory=list)  # human-readable decision trail
    alerts: List[Dict] = field(default_factory=list)    # behaviour-alert dicts considered for this decision
|
|
|
|
class DecisionEngine:
    """
    Final decision layer that combines all model outputs.

    Decision logic:
    1. Get prediction from TradingTransformer (direction, return, uncertainty)
    2. Get risk assessment from RiskModel (risk score, sizing, levels)
    3. Check market regime (trending vs mean-reverting vs high-vol)
    4. Apply personalization (adapt to trader profile)
    5. Generate final signal with confidence

    NOTE(review): steps 2 and 3 are currently stubs. ``_get_risk_assessment``
    returns fixed defaults without consulting ``risk_model``, and
    ``_detect_regime`` always returns ``'normal'``. As a result, the
    regime-specific adjustments in ``make_decision`` never fire today.
    """

    # Fallback risk parameters. They are used when no risk model or portfolio
    # state is available, and as per-key defaults for anything the risk
    # assessment omits.
    _DEFAULT_RISK = {
        'risk_score': 0.5,
        'adjusted_position_size': 0.03,
        'stop_loss_atr_mult': 2.0,
        'take_profit_atr_mult': 3.0,
        'drawdown_risk': 0.1,
    }

    def __init__(
        self,
        prediction_model=None,
        risk_model=None,
        personalization_engine=None,
        confidence_threshold: float = 0.6,
        strong_signal_threshold: float = 0.8,
        strong_return_threshold: float = 0.005,
    ):
        """
        Args:
            prediction_model: Object exposing ``predict_with_confidence(x)``.
                If None, a fixed placeholder prediction is used.
            risk_model: Risk model (currently unused; see class NOTE).
            personalization_engine: Object exposing
                ``get_personalized_params(trader_profile, behavior_alerts)``.
            confidence_threshold: Minimum combined confidence required to act.
                This is used unless personalization supplies its own minimum.
            strong_signal_threshold: Combined confidence at or above which a
                STRONG_BUY / STRONG_SELL may be emitted.
            strong_return_threshold: Absolute expected return that must also be
                exceeded for a STRONG_* signal.
        """
        self.prediction_model = prediction_model
        self.risk_model = risk_model
        self.personalization_engine = personalization_engine
        self.confidence_threshold = confidence_threshold
        self.strong_signal_threshold = strong_signal_threshold
        self.strong_return_threshold = strong_return_threshold

        # Index i of the prediction outputs corresponds to horizon_labels[i].
        self.horizon_labels = ['short_term', 'mid_term', 'long_term']

    def make_decision(
        self,
        market_features: np.ndarray,
        portfolio_state: Optional[Dict] = None,
        trader_profile: Optional[Dict] = None,
        behavior_alerts: Optional[Dict] = None,
        current_atr: float = 0.01,
        horizon_idx: int = 0,
    ) -> TradingDecision:
        """
        Generate a complete trading decision.

        Args:
            market_features: (1, num_features, seq_len) normalized features
                (a 2-D array gets a batch dimension added).
            portfolio_state: Current portfolio information.
            trader_profile: Trader's behavior profile.
            behavior_alerts: Current behavior alerts. The optional ``'alerts'``
                key holds a list of alert dicts. Any alert whose
                ``'severity'`` is ``'CRITICAL'`` forces HOLD.
            current_atr: Current ATR for stop/take-profit calculation.
            horizon_idx: Which prediction horizon to use (0=short, 1=mid,
                2=long). Values above the last index are clamped to it.

        Returns:
            TradingDecision with full context.
        """
        reasoning: List[str] = []

        # Clamp once so the prediction and the reported label always refer to
        # the same horizon.
        horizon_idx = min(horizon_idx, len(self.horizon_labels) - 1)

        # 1. Prediction.
        prediction = self._get_prediction(market_features, horizon_idx)
        direction_prob = prediction['direction_prob']
        expected_return = prediction['expected_return']
        model_confidence = prediction['confidence']

        reasoning.append(f"Direction probability: {direction_prob:.1%} up")
        reasoning.append(f"Expected return: {expected_return:.2%}")
        reasoning.append(f"Model confidence: {model_confidence:.1%}")

        # 2. Market regime.
        regime = self._detect_regime(market_features)
        reasoning.append(f"Market regime: {regime}")

        # 3. Risk assessment. Missing keys fall back to the defaults.
        risk = dict(self._DEFAULT_RISK)
        if self.risk_model is not None and portfolio_state is not None:
            risk.update(self._get_risk_assessment(market_features, portfolio_state))
        risk_score = risk['risk_score']
        position_size = risk['adjusted_position_size']
        sl_mult = risk['stop_loss_atr_mult']
        tp_mult = risk['take_profit_atr_mult']
        drawdown_risk = risk['drawdown_risk']

        reasoning.append(f"Risk score: {risk_score:.2f}")

        # 4. Personalization.
        min_conf = self.confidence_threshold
        if self.personalization_engine and trader_profile and behavior_alerts:
            personal_params = self.personalization_engine.get_personalized_params(
                trader_profile, behavior_alerts
            )
            # Personalization can only shrink the position, never enlarge it.
            position_size = min(position_size, personal_params.get('max_position_pct', 0.05))
            # Default to the engine's configured threshold, not a hard-coded one.
            min_conf = personal_params.get('min_confidence', self.confidence_threshold)
            sl_mult = personal_params.get('sl_atr_mult', sl_mult)
            tp_mult = personal_params.get('tp_atr_mult', tp_mult)
            reasoning.append(f"Personalized min confidence: {min_conf:.1%}")

        # 5. Combined confidence. Higher risk scores discount model confidence
        # (by 30% at risk_score == 1; assumes risk_score is in [0, 1]).
        combined_confidence = model_confidence * (1 - 0.3 * risk_score)

        if regime == 'high_volatility':
            combined_confidence *= 0.8
            position_size *= 0.7
            reasoning.append("High volatility: reduced confidence and position size")
        elif regime == 'trending':
            combined_confidence *= 1.1
            reasoning.append("Trending market: slight confidence boost")

        combined_confidence = float(np.clip(combined_confidence, 0.0, 1.0))

        # 6. Signal from confidence and direction.
        signal = self._select_signal(
            direction_prob, expected_return, combined_confidence, min_conf, reasoning
        )

        # 7. Critical behaviour alerts override everything. Copy the list so
        # decisions built from the same behavior_alerts do not share it.
        alerts = list(behavior_alerts.get('alerts') or []) if behavior_alerts else []
        signal = self._apply_critical_overrides(signal, alerts, reasoning)

        # Stop/take-profit distances are ATR multiples.
        stop_loss_pct = sl_mult * current_atr
        take_profit_pct = tp_mult * current_atr

        return TradingDecision(
            signal=signal,
            confidence=float(combined_confidence),
            direction_prob=float(direction_prob),
            expected_return=float(expected_return),
            risk_score=float(risk_score),
            position_size_pct=float(position_size),
            stop_loss_pct=float(stop_loss_pct),
            take_profit_pct=float(take_profit_pct),
            drawdown_risk=float(drawdown_risk),
            market_regime=regime,
            horizon=self.horizon_labels[horizon_idx],
            reasoning=reasoning,
            alerts=alerts,
        )

    def _select_signal(
        self,
        direction_prob: float,
        expected_return: float,
        confidence: float,
        min_conf: float,
        reasoning: List[str],
    ) -> Signal:
        """Map direction, return and confidence to a Signal and log why."""
        if confidence < min_conf:
            reasoning.append(f"Confidence {confidence:.1%} below threshold {min_conf:.1%} → HOLD")
            return Signal.HOLD

        # A coin-flip direction carries no edge in either direction.
        if direction_prob == 0.5:
            reasoning.append(f"No directional edge: {direction_prob:.1%} up probability → HOLD")
            return Signal.HOLD

        strong = confidence >= self.strong_signal_threshold
        if direction_prob > 0.5:
            reasoning.append(f"Bullish signal: {direction_prob:.1%} up probability")
            if strong and expected_return > self.strong_return_threshold:
                return Signal.STRONG_BUY
            return Signal.BUY

        reasoning.append(f"Bearish signal: {1-direction_prob:.1%} down probability")
        if strong and expected_return < -self.strong_return_threshold:
            return Signal.STRONG_SELL
        return Signal.SELL

    @staticmethod
    def _apply_critical_overrides(signal, alerts: List[Dict], reasoning: List[str]):
        """Force HOLD for every CRITICAL alert and record each override."""
        for alert in alerts:
            if alert.get('severity') == 'CRITICAL':
                signal = Signal.HOLD
                reasoning.append(
                    f"CRITICAL ALERT: {alert.get('type', 'UNKNOWN')} - Overriding to HOLD"
                )
        return signal

    def _get_prediction(self, features: np.ndarray, horizon_idx: int) -> Dict:
        """Get prediction from model or return defaults.

        Returns a dict with float ``direction_prob``, ``expected_return`` and
        ``confidence`` for ``horizon_idx``. The model output arrays are
        indexed as ``[batch=0, horizon_idx]``.
        """
        if self.prediction_model is None:
            # Fixed placeholder used when no model is attached.
            return {
                'direction_prob': 0.55,
                'expected_return': 0.002,
                'confidence': 0.65,
            }

        x = torch.as_tensor(features, dtype=torch.float32)
        if x.dim() == 2:
            x = x.unsqueeze(0)  # add batch dimension
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            result = self.prediction_model.predict_with_confidence(x)
        return {
            'direction_prob': float(result['direction_probs'][0, horizon_idx]),
            'expected_return': float(result['expected_returns'][0, horizon_idx]),
            'confidence': float(result['confidence'][0, horizon_idx]),
        }

    def _detect_regime(self, features: np.ndarray) -> str:
        """Simple regime detection from features.

        TODO: stub. It always returns ``'normal'``, so the 'high_volatility'
        and 'trending' adjustments in ``make_decision`` are never applied.
        """
        return 'normal'

    def _get_risk_assessment(self, features: np.ndarray, portfolio: Dict) -> Dict:
        """Get risk assessment from risk model.

        TODO: stub. ``self.risk_model`` is not consulted yet. This returns a
        fresh copy of the default risk parameters.
        """
        return dict(self._DEFAULT_RISK)

    def make_multi_horizon_decisions(
        self,
        market_features: np.ndarray,
        portfolio_state: Optional[Dict] = None,
        trader_profile: Optional[Dict] = None,
        behavior_alerts: Optional[Dict] = None,
        current_atr: float = 0.01,
    ) -> List[TradingDecision]:
        """Generate decisions for all horizons, in ``horizon_labels`` order.

        Note: ``make_decision`` runs the prediction model once per horizon.
        """
        return [
            self.make_decision(
                market_features=market_features,
                portfolio_state=portfolio_state,
                trader_profile=trader_profile,
                behavior_alerts=behavior_alerts,
                current_atr=current_atr,
                horizon_idx=idx,
            )
            for idx in range(len(self.horizon_labels))
        ]
|
|
|
|
def format_decision(decision: TradingDecision) -> str:
    """Render a decision as a boxed, human-readable multi-line report.

    The report has a header naming the horizon, one line per decision
    field, a bulleted reasoning section and, only when alerts are present,
    an alerts section. Lines are joined with newlines (no trailing newline).
    """
    heavy_rule = "═" * 60
    light_rule = "─" * 60

    report = [heavy_rule]
    report.append(f" TRADING DECISION ({decision.horizon.upper()})")
    report.append(heavy_rule)
    report.extend((
        f" Signal: {decision.signal.value}",
        f" Confidence: {decision.confidence:.1%}",
        f" Direction: {decision.direction_prob:.1%} probability UP",
        f" Expected Ret: {decision.expected_return:.2%}",
        f" Risk Score: {decision.risk_score:.2f}/1.00",
        f" Position Size: {decision.position_size_pct:.1%} of portfolio",
        f" Stop Loss: {decision.stop_loss_pct:.2%} from entry",
        f" Take Profit: {decision.take_profit_pct:.2%} from entry",
        f" Drawdown Risk: {decision.drawdown_risk:.1%}",
        f" Market Regime: {decision.market_regime}",
    ))
    report.append(light_rule)
    report.append(" REASONING:")
    report.extend(f" • {item}" for item in decision.reasoning)

    if decision.alerts:
        report.extend((light_rule, " ⚠️ ALERTS:"))
        report.extend(
            f" [{alert.get('severity', 'INFO')}] {alert.get('type', '')}: {alert.get('message', '')}"
            for alert in decision.alerts
        )

    report.append(heavy_rule)
    return "\n".join(report)
|
|