Add trading_intelligence/decision_engine.py
Browse files
trading_intelligence/decision_engine.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Decision Engine Module
|
| 3 |
+
======================
|
| 4 |
+
Combines prediction model, risk model, and market regime
|
| 5 |
+
to produce final trading decisions.
|
| 6 |
+
|
| 7 |
+
Output: Buy / Sell / Hold with confidence and risk-adjusted sizing.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from typing import Dict, List, Optional
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from enum import Enum
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Signal(Enum):
|
| 18 |
+
STRONG_BUY = "STRONG_BUY"
|
| 19 |
+
BUY = "BUY"
|
| 20 |
+
HOLD = "HOLD"
|
| 21 |
+
SELL = "SELL"
|
| 22 |
+
STRONG_SELL = "STRONG_SELL"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class TradingDecision:
|
| 27 |
+
"""Complete trading decision with all context."""
|
| 28 |
+
signal: Signal
|
| 29 |
+
confidence: float # 0-1 overall confidence
|
| 30 |
+
direction_prob: float # Probability of upward move
|
| 31 |
+
expected_return: float # Expected return (decimal)
|
| 32 |
+
risk_score: float # 0-1 risk score
|
| 33 |
+
position_size_pct: float # Recommended position size (% of portfolio)
|
| 34 |
+
stop_loss_pct: float # Stop loss as % from entry
|
| 35 |
+
take_profit_pct: float # Take profit as % from entry
|
| 36 |
+
drawdown_risk: float # Probability of significant drawdown
|
| 37 |
+
market_regime: str # Current market regime
|
| 38 |
+
horizon: str # Prediction horizon label
|
| 39 |
+
reasoning: List[str] = field(default_factory=list) # Human-readable reasoning
|
| 40 |
+
alerts: List[Dict] = field(default_factory=list) # Active behavior alerts
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class DecisionEngine:
|
| 44 |
+
"""
|
| 45 |
+
Final decision layer that combines all model outputs.
|
| 46 |
+
|
| 47 |
+
Decision logic:
|
| 48 |
+
1. Get prediction from TradingTransformer (direction, return, uncertainty)
|
| 49 |
+
2. Get risk assessment from RiskModel (risk score, sizing, levels)
|
| 50 |
+
3. Check market regime (trending vs mean-reverting vs high-vol)
|
| 51 |
+
4. Apply personalization (adapt to trader profile)
|
| 52 |
+
5. Generate final signal with confidence
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
prediction_model=None,
|
| 58 |
+
risk_model=None,
|
| 59 |
+
personalization_engine=None,
|
| 60 |
+
confidence_threshold: float = 0.6,
|
| 61 |
+
strong_signal_threshold: float = 0.8,
|
| 62 |
+
):
|
| 63 |
+
self.prediction_model = prediction_model
|
| 64 |
+
self.risk_model = risk_model
|
| 65 |
+
self.personalization_engine = personalization_engine
|
| 66 |
+
self.confidence_threshold = confidence_threshold
|
| 67 |
+
self.strong_signal_threshold = strong_signal_threshold
|
| 68 |
+
|
| 69 |
+
# Horizon labels
|
| 70 |
+
self.horizon_labels = ['short_term', 'mid_term', 'long_term']
|
| 71 |
+
|
| 72 |
+
def make_decision(
|
| 73 |
+
self,
|
| 74 |
+
market_features: np.ndarray,
|
| 75 |
+
portfolio_state: Optional[Dict] = None,
|
| 76 |
+
trader_profile: Optional[Dict] = None,
|
| 77 |
+
behavior_alerts: Optional[Dict] = None,
|
| 78 |
+
current_atr: float = 0.01,
|
| 79 |
+
horizon_idx: int = 0,
|
| 80 |
+
) -> TradingDecision:
|
| 81 |
+
"""
|
| 82 |
+
Generate a complete trading decision.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
market_features: (1, num_features, seq_len) normalized features
|
| 86 |
+
portfolio_state: Current portfolio information
|
| 87 |
+
trader_profile: Trader's behavior profile
|
| 88 |
+
behavior_alerts: Current behavior alerts
|
| 89 |
+
current_atr: Current ATR for stop/take-profit calculation
|
| 90 |
+
horizon_idx: Which prediction horizon to use (0=short, 1=mid, 2=long)
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
TradingDecision with full context
|
| 94 |
+
"""
|
| 95 |
+
reasoning = []
|
| 96 |
+
|
| 97 |
+
# 1. Get market prediction
|
| 98 |
+
prediction = self._get_prediction(market_features, horizon_idx)
|
| 99 |
+
direction_prob = prediction['direction_prob']
|
| 100 |
+
expected_return = prediction['expected_return']
|
| 101 |
+
model_confidence = prediction['confidence']
|
| 102 |
+
|
| 103 |
+
reasoning.append(f"Direction probability: {direction_prob:.1%} up")
|
| 104 |
+
reasoning.append(f"Expected return: {expected_return:.2%}")
|
| 105 |
+
reasoning.append(f"Model confidence: {model_confidence:.1%}")
|
| 106 |
+
|
| 107 |
+
# 2. Determine market regime
|
| 108 |
+
regime = self._detect_regime(market_features)
|
| 109 |
+
reasoning.append(f"Market regime: {regime}")
|
| 110 |
+
|
| 111 |
+
# 3. Get risk assessment
|
| 112 |
+
risk_score = 0.5 # Default
|
| 113 |
+
position_size = 0.03 # Default 3% position
|
| 114 |
+
sl_mult = 2.0
|
| 115 |
+
tp_mult = 3.0
|
| 116 |
+
drawdown_risk = 0.1
|
| 117 |
+
|
| 118 |
+
if self.risk_model is not None and portfolio_state is not None:
|
| 119 |
+
risk_output = self._get_risk_assessment(market_features, portfolio_state)
|
| 120 |
+
risk_score = risk_output.get('risk_score', 0.5)
|
| 121 |
+
position_size = risk_output.get('adjusted_position_size', 0.03)
|
| 122 |
+
sl_mult = risk_output.get('stop_loss_atr_mult', 2.0)
|
| 123 |
+
tp_mult = risk_output.get('take_profit_atr_mult', 3.0)
|
| 124 |
+
drawdown_risk = risk_output.get('drawdown_risk', 0.1)
|
| 125 |
+
|
| 126 |
+
reasoning.append(f"Risk score: {risk_score:.2f}")
|
| 127 |
+
|
| 128 |
+
# 4. Apply personalization
|
| 129 |
+
if self.personalization_engine and trader_profile and behavior_alerts:
|
| 130 |
+
personal_params = self.personalization_engine.get_personalized_params(
|
| 131 |
+
trader_profile, behavior_alerts
|
| 132 |
+
)
|
| 133 |
+
# Cap position size
|
| 134 |
+
position_size = min(position_size, personal_params.get('max_position_pct', 0.05))
|
| 135 |
+
|
| 136 |
+
# Adjust confidence threshold
|
| 137 |
+
min_conf = personal_params.get('min_confidence', 0.6)
|
| 138 |
+
|
| 139 |
+
# Use personalized SL/TP if available
|
| 140 |
+
sl_mult = personal_params.get('sl_atr_mult', sl_mult)
|
| 141 |
+
tp_mult = personal_params.get('tp_atr_mult', tp_mult)
|
| 142 |
+
|
| 143 |
+
reasoning.append(f"Personalized min confidence: {min_conf:.1%}")
|
| 144 |
+
else:
|
| 145 |
+
min_conf = self.confidence_threshold
|
| 146 |
+
|
| 147 |
+
# 5. Generate signal
|
| 148 |
+
combined_confidence = model_confidence * (1 - 0.3 * risk_score)
|
| 149 |
+
|
| 150 |
+
# Apply regime adjustments
|
| 151 |
+
if regime == 'high_volatility':
|
| 152 |
+
combined_confidence *= 0.8
|
| 153 |
+
position_size *= 0.7
|
| 154 |
+
reasoning.append("High volatility: reduced confidence and position size")
|
| 155 |
+
elif regime == 'trending':
|
| 156 |
+
combined_confidence *= 1.1 # Slightly boost confidence in trends
|
| 157 |
+
reasoning.append("Trending market: slight confidence boost")
|
| 158 |
+
|
| 159 |
+
combined_confidence = np.clip(combined_confidence, 0, 1)
|
| 160 |
+
|
| 161 |
+
# Determine signal
|
| 162 |
+
if combined_confidence < min_conf:
|
| 163 |
+
signal = Signal.HOLD
|
| 164 |
+
reasoning.append(f"Confidence {combined_confidence:.1%} below threshold {min_conf:.1%} → HOLD")
|
| 165 |
+
elif direction_prob > 0.5:
|
| 166 |
+
if combined_confidence >= self.strong_signal_threshold and expected_return > 0.005:
|
| 167 |
+
signal = Signal.STRONG_BUY
|
| 168 |
+
else:
|
| 169 |
+
signal = Signal.BUY
|
| 170 |
+
reasoning.append(f"Bullish signal: {direction_prob:.1%} up probability")
|
| 171 |
+
else:
|
| 172 |
+
if combined_confidence >= self.strong_signal_threshold and expected_return < -0.005:
|
| 173 |
+
signal = Signal.STRONG_SELL
|
| 174 |
+
else:
|
| 175 |
+
signal = Signal.SELL
|
| 176 |
+
reasoning.append(f"Bearish signal: {1-direction_prob:.1%} down probability")
|
| 177 |
+
|
| 178 |
+
# Check behavior alerts
|
| 179 |
+
alerts = behavior_alerts.get('alerts', []) if behavior_alerts else []
|
| 180 |
+
if alerts:
|
| 181 |
+
for alert in alerts:
|
| 182 |
+
if alert.get('severity') == 'CRITICAL':
|
| 183 |
+
signal = Signal.HOLD
|
| 184 |
+
reasoning.append(f"CRITICAL ALERT: {alert['type']} - Overriding to HOLD")
|
| 185 |
+
|
| 186 |
+
# Compute SL/TP levels
|
| 187 |
+
stop_loss_pct = sl_mult * current_atr
|
| 188 |
+
take_profit_pct = tp_mult * current_atr
|
| 189 |
+
|
| 190 |
+
return TradingDecision(
|
| 191 |
+
signal=signal,
|
| 192 |
+
confidence=float(combined_confidence),
|
| 193 |
+
direction_prob=float(direction_prob),
|
| 194 |
+
expected_return=float(expected_return),
|
| 195 |
+
risk_score=float(risk_score),
|
| 196 |
+
position_size_pct=float(position_size),
|
| 197 |
+
stop_loss_pct=float(stop_loss_pct),
|
| 198 |
+
take_profit_pct=float(take_profit_pct),
|
| 199 |
+
drawdown_risk=float(drawdown_risk),
|
| 200 |
+
market_regime=regime,
|
| 201 |
+
horizon=self.horizon_labels[min(horizon_idx, len(self.horizon_labels)-1)],
|
| 202 |
+
reasoning=reasoning,
|
| 203 |
+
alerts=alerts,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
def _get_prediction(self, features: np.ndarray, horizon_idx: int) -> Dict:
|
| 207 |
+
"""Get prediction from model or return defaults."""
|
| 208 |
+
if self.prediction_model is not None:
|
| 209 |
+
x = torch.FloatTensor(features)
|
| 210 |
+
if x.dim() == 2:
|
| 211 |
+
x = x.unsqueeze(0)
|
| 212 |
+
result = self.prediction_model.predict_with_confidence(x)
|
| 213 |
+
return {
|
| 214 |
+
'direction_prob': float(result['direction_probs'][0, horizon_idx]),
|
| 215 |
+
'expected_return': float(result['expected_returns'][0, horizon_idx]),
|
| 216 |
+
'confidence': float(result['confidence'][0, horizon_idx]),
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
# Default values for testing
|
| 220 |
+
return {
|
| 221 |
+
'direction_prob': 0.55,
|
| 222 |
+
'expected_return': 0.002,
|
| 223 |
+
'confidence': 0.65,
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
def _detect_regime(self, features: np.ndarray) -> str:
|
| 227 |
+
"""Simple regime detection from features."""
|
| 228 |
+
# In production, this would use the regime features from FeatureEngine
|
| 229 |
+
return 'normal' # Placeholder
|
| 230 |
+
|
| 231 |
+
def _get_risk_assessment(self, features: np.ndarray, portfolio: Dict) -> Dict:
|
| 232 |
+
"""Get risk assessment from risk model."""
|
| 233 |
+
return {
|
| 234 |
+
'risk_score': 0.5,
|
| 235 |
+
'adjusted_position_size': 0.03,
|
| 236 |
+
'stop_loss_atr_mult': 2.0,
|
| 237 |
+
'take_profit_atr_mult': 3.0,
|
| 238 |
+
'drawdown_risk': 0.1,
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
def make_multi_horizon_decisions(
|
| 242 |
+
self,
|
| 243 |
+
market_features: np.ndarray,
|
| 244 |
+
portfolio_state: Optional[Dict] = None,
|
| 245 |
+
trader_profile: Optional[Dict] = None,
|
| 246 |
+
behavior_alerts: Optional[Dict] = None,
|
| 247 |
+
current_atr: float = 0.01,
|
| 248 |
+
) -> List[TradingDecision]:
|
| 249 |
+
"""Generate decisions for all horizons simultaneously."""
|
| 250 |
+
decisions = []
|
| 251 |
+
for i in range(len(self.horizon_labels)):
|
| 252 |
+
decision = self.make_decision(
|
| 253 |
+
market_features=market_features,
|
| 254 |
+
portfolio_state=portfolio_state,
|
| 255 |
+
trader_profile=trader_profile,
|
| 256 |
+
behavior_alerts=behavior_alerts,
|
| 257 |
+
current_atr=current_atr,
|
| 258 |
+
horizon_idx=i,
|
| 259 |
+
)
|
| 260 |
+
decisions.append(decision)
|
| 261 |
+
return decisions
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def format_decision(decision: TradingDecision) -> str:
|
| 265 |
+
"""Format a trading decision for display."""
|
| 266 |
+
lines = [
|
| 267 |
+
"═" * 60,
|
| 268 |
+
f" TRADING DECISION ({decision.horizon.upper()})",
|
| 269 |
+
"═" * 60,
|
| 270 |
+
f" Signal: {decision.signal.value}",
|
| 271 |
+
f" Confidence: {decision.confidence:.1%}",
|
| 272 |
+
f" Direction: {decision.direction_prob:.1%} probability UP",
|
| 273 |
+
f" Expected Ret: {decision.expected_return:.2%}",
|
| 274 |
+
f" Risk Score: {decision.risk_score:.2f}/1.00",
|
| 275 |
+
f" Position Size: {decision.position_size_pct:.1%} of portfolio",
|
| 276 |
+
f" Stop Loss: {decision.stop_loss_pct:.2%} from entry",
|
| 277 |
+
f" Take Profit: {decision.take_profit_pct:.2%} from entry",
|
| 278 |
+
f" Drawdown Risk: {decision.drawdown_risk:.1%}",
|
| 279 |
+
f" Market Regime: {decision.market_regime}",
|
| 280 |
+
"─" * 60,
|
| 281 |
+
" REASONING:",
|
| 282 |
+
]
|
| 283 |
+
for r in decision.reasoning:
|
| 284 |
+
lines.append(f" • {r}")
|
| 285 |
+
|
| 286 |
+
if decision.alerts:
|
| 287 |
+
lines.append("─" * 60)
|
| 288 |
+
lines.append(" ⚠️ ALERTS:")
|
| 289 |
+
for a in decision.alerts:
|
| 290 |
+
lines.append(f" [{a.get('severity', 'INFO')}] {a.get('type', '')}: {a.get('message', '')}")
|
| 291 |
+
|
| 292 |
+
lines.append("═" * 60)
|
| 293 |
+
return "\n".join(lines)
|