avinashhm commited on
Commit
eaf3794
·
verified ·
1 Parent(s): dc5d45a

Add trading_intelligence/decision_engine.py

Browse files
trading_intelligence/decision_engine.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Decision Engine Module
3
+ ======================
4
+ Combines prediction model, risk model, and market regime
5
+ to produce final trading decisions.
6
+
7
+ Output: Buy / Sell / Hold with confidence and risk-adjusted sizing.
8
+ """
9
+
10
+ import torch
11
+ import numpy as np
12
+ from typing import Dict, List, Optional
13
+ from dataclasses import dataclass, field
14
+ from enum import Enum
15
+
16
+
17
+ class Signal(Enum):
18
+ STRONG_BUY = "STRONG_BUY"
19
+ BUY = "BUY"
20
+ HOLD = "HOLD"
21
+ SELL = "SELL"
22
+ STRONG_SELL = "STRONG_SELL"
23
+
24
+
25
+ @dataclass
26
+ class TradingDecision:
27
+ """Complete trading decision with all context."""
28
+ signal: Signal
29
+ confidence: float # 0-1 overall confidence
30
+ direction_prob: float # Probability of upward move
31
+ expected_return: float # Expected return (decimal)
32
+ risk_score: float # 0-1 risk score
33
+ position_size_pct: float # Recommended position size (% of portfolio)
34
+ stop_loss_pct: float # Stop loss as % from entry
35
+ take_profit_pct: float # Take profit as % from entry
36
+ drawdown_risk: float # Probability of significant drawdown
37
+ market_regime: str # Current market regime
38
+ horizon: str # Prediction horizon label
39
+ reasoning: List[str] = field(default_factory=list) # Human-readable reasoning
40
+ alerts: List[Dict] = field(default_factory=list) # Active behavior alerts
41
+
42
+
43
+ class DecisionEngine:
44
+ """
45
+ Final decision layer that combines all model outputs.
46
+
47
+ Decision logic:
48
+ 1. Get prediction from TradingTransformer (direction, return, uncertainty)
49
+ 2. Get risk assessment from RiskModel (risk score, sizing, levels)
50
+ 3. Check market regime (trending vs mean-reverting vs high-vol)
51
+ 4. Apply personalization (adapt to trader profile)
52
+ 5. Generate final signal with confidence
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ prediction_model=None,
58
+ risk_model=None,
59
+ personalization_engine=None,
60
+ confidence_threshold: float = 0.6,
61
+ strong_signal_threshold: float = 0.8,
62
+ ):
63
+ self.prediction_model = prediction_model
64
+ self.risk_model = risk_model
65
+ self.personalization_engine = personalization_engine
66
+ self.confidence_threshold = confidence_threshold
67
+ self.strong_signal_threshold = strong_signal_threshold
68
+
69
+ # Horizon labels
70
+ self.horizon_labels = ['short_term', 'mid_term', 'long_term']
71
+
72
+ def make_decision(
73
+ self,
74
+ market_features: np.ndarray,
75
+ portfolio_state: Optional[Dict] = None,
76
+ trader_profile: Optional[Dict] = None,
77
+ behavior_alerts: Optional[Dict] = None,
78
+ current_atr: float = 0.01,
79
+ horizon_idx: int = 0,
80
+ ) -> TradingDecision:
81
+ """
82
+ Generate a complete trading decision.
83
+
84
+ Args:
85
+ market_features: (1, num_features, seq_len) normalized features
86
+ portfolio_state: Current portfolio information
87
+ trader_profile: Trader's behavior profile
88
+ behavior_alerts: Current behavior alerts
89
+ current_atr: Current ATR for stop/take-profit calculation
90
+ horizon_idx: Which prediction horizon to use (0=short, 1=mid, 2=long)
91
+
92
+ Returns:
93
+ TradingDecision with full context
94
+ """
95
+ reasoning = []
96
+
97
+ # 1. Get market prediction
98
+ prediction = self._get_prediction(market_features, horizon_idx)
99
+ direction_prob = prediction['direction_prob']
100
+ expected_return = prediction['expected_return']
101
+ model_confidence = prediction['confidence']
102
+
103
+ reasoning.append(f"Direction probability: {direction_prob:.1%} up")
104
+ reasoning.append(f"Expected return: {expected_return:.2%}")
105
+ reasoning.append(f"Model confidence: {model_confidence:.1%}")
106
+
107
+ # 2. Determine market regime
108
+ regime = self._detect_regime(market_features)
109
+ reasoning.append(f"Market regime: {regime}")
110
+
111
+ # 3. Get risk assessment
112
+ risk_score = 0.5 # Default
113
+ position_size = 0.03 # Default 3% position
114
+ sl_mult = 2.0
115
+ tp_mult = 3.0
116
+ drawdown_risk = 0.1
117
+
118
+ if self.risk_model is not None and portfolio_state is not None:
119
+ risk_output = self._get_risk_assessment(market_features, portfolio_state)
120
+ risk_score = risk_output.get('risk_score', 0.5)
121
+ position_size = risk_output.get('adjusted_position_size', 0.03)
122
+ sl_mult = risk_output.get('stop_loss_atr_mult', 2.0)
123
+ tp_mult = risk_output.get('take_profit_atr_mult', 3.0)
124
+ drawdown_risk = risk_output.get('drawdown_risk', 0.1)
125
+
126
+ reasoning.append(f"Risk score: {risk_score:.2f}")
127
+
128
+ # 4. Apply personalization
129
+ if self.personalization_engine and trader_profile and behavior_alerts:
130
+ personal_params = self.personalization_engine.get_personalized_params(
131
+ trader_profile, behavior_alerts
132
+ )
133
+ # Cap position size
134
+ position_size = min(position_size, personal_params.get('max_position_pct', 0.05))
135
+
136
+ # Adjust confidence threshold
137
+ min_conf = personal_params.get('min_confidence', 0.6)
138
+
139
+ # Use personalized SL/TP if available
140
+ sl_mult = personal_params.get('sl_atr_mult', sl_mult)
141
+ tp_mult = personal_params.get('tp_atr_mult', tp_mult)
142
+
143
+ reasoning.append(f"Personalized min confidence: {min_conf:.1%}")
144
+ else:
145
+ min_conf = self.confidence_threshold
146
+
147
+ # 5. Generate signal
148
+ combined_confidence = model_confidence * (1 - 0.3 * risk_score)
149
+
150
+ # Apply regime adjustments
151
+ if regime == 'high_volatility':
152
+ combined_confidence *= 0.8
153
+ position_size *= 0.7
154
+ reasoning.append("High volatility: reduced confidence and position size")
155
+ elif regime == 'trending':
156
+ combined_confidence *= 1.1 # Slightly boost confidence in trends
157
+ reasoning.append("Trending market: slight confidence boost")
158
+
159
+ combined_confidence = np.clip(combined_confidence, 0, 1)
160
+
161
+ # Determine signal
162
+ if combined_confidence < min_conf:
163
+ signal = Signal.HOLD
164
+ reasoning.append(f"Confidence {combined_confidence:.1%} below threshold {min_conf:.1%} → HOLD")
165
+ elif direction_prob > 0.5:
166
+ if combined_confidence >= self.strong_signal_threshold and expected_return > 0.005:
167
+ signal = Signal.STRONG_BUY
168
+ else:
169
+ signal = Signal.BUY
170
+ reasoning.append(f"Bullish signal: {direction_prob:.1%} up probability")
171
+ else:
172
+ if combined_confidence >= self.strong_signal_threshold and expected_return < -0.005:
173
+ signal = Signal.STRONG_SELL
174
+ else:
175
+ signal = Signal.SELL
176
+ reasoning.append(f"Bearish signal: {1-direction_prob:.1%} down probability")
177
+
178
+ # Check behavior alerts
179
+ alerts = behavior_alerts.get('alerts', []) if behavior_alerts else []
180
+ if alerts:
181
+ for alert in alerts:
182
+ if alert.get('severity') == 'CRITICAL':
183
+ signal = Signal.HOLD
184
+ reasoning.append(f"CRITICAL ALERT: {alert['type']} - Overriding to HOLD")
185
+
186
+ # Compute SL/TP levels
187
+ stop_loss_pct = sl_mult * current_atr
188
+ take_profit_pct = tp_mult * current_atr
189
+
190
+ return TradingDecision(
191
+ signal=signal,
192
+ confidence=float(combined_confidence),
193
+ direction_prob=float(direction_prob),
194
+ expected_return=float(expected_return),
195
+ risk_score=float(risk_score),
196
+ position_size_pct=float(position_size),
197
+ stop_loss_pct=float(stop_loss_pct),
198
+ take_profit_pct=float(take_profit_pct),
199
+ drawdown_risk=float(drawdown_risk),
200
+ market_regime=regime,
201
+ horizon=self.horizon_labels[min(horizon_idx, len(self.horizon_labels)-1)],
202
+ reasoning=reasoning,
203
+ alerts=alerts,
204
+ )
205
+
206
+ def _get_prediction(self, features: np.ndarray, horizon_idx: int) -> Dict:
207
+ """Get prediction from model or return defaults."""
208
+ if self.prediction_model is not None:
209
+ x = torch.FloatTensor(features)
210
+ if x.dim() == 2:
211
+ x = x.unsqueeze(0)
212
+ result = self.prediction_model.predict_with_confidence(x)
213
+ return {
214
+ 'direction_prob': float(result['direction_probs'][0, horizon_idx]),
215
+ 'expected_return': float(result['expected_returns'][0, horizon_idx]),
216
+ 'confidence': float(result['confidence'][0, horizon_idx]),
217
+ }
218
+
219
+ # Default values for testing
220
+ return {
221
+ 'direction_prob': 0.55,
222
+ 'expected_return': 0.002,
223
+ 'confidence': 0.65,
224
+ }
225
+
226
+ def _detect_regime(self, features: np.ndarray) -> str:
227
+ """Simple regime detection from features."""
228
+ # In production, this would use the regime features from FeatureEngine
229
+ return 'normal' # Placeholder
230
+
231
+ def _get_risk_assessment(self, features: np.ndarray, portfolio: Dict) -> Dict:
232
+ """Get risk assessment from risk model."""
233
+ return {
234
+ 'risk_score': 0.5,
235
+ 'adjusted_position_size': 0.03,
236
+ 'stop_loss_atr_mult': 2.0,
237
+ 'take_profit_atr_mult': 3.0,
238
+ 'drawdown_risk': 0.1,
239
+ }
240
+
241
+ def make_multi_horizon_decisions(
242
+ self,
243
+ market_features: np.ndarray,
244
+ portfolio_state: Optional[Dict] = None,
245
+ trader_profile: Optional[Dict] = None,
246
+ behavior_alerts: Optional[Dict] = None,
247
+ current_atr: float = 0.01,
248
+ ) -> List[TradingDecision]:
249
+ """Generate decisions for all horizons simultaneously."""
250
+ decisions = []
251
+ for i in range(len(self.horizon_labels)):
252
+ decision = self.make_decision(
253
+ market_features=market_features,
254
+ portfolio_state=portfolio_state,
255
+ trader_profile=trader_profile,
256
+ behavior_alerts=behavior_alerts,
257
+ current_atr=current_atr,
258
+ horizon_idx=i,
259
+ )
260
+ decisions.append(decision)
261
+ return decisions
262
+
263
+
264
+ def format_decision(decision: TradingDecision) -> str:
265
+ """Format a trading decision for display."""
266
+ lines = [
267
+ "═" * 60,
268
+ f" TRADING DECISION ({decision.horizon.upper()})",
269
+ "═" * 60,
270
+ f" Signal: {decision.signal.value}",
271
+ f" Confidence: {decision.confidence:.1%}",
272
+ f" Direction: {decision.direction_prob:.1%} probability UP",
273
+ f" Expected Ret: {decision.expected_return:.2%}",
274
+ f" Risk Score: {decision.risk_score:.2f}/1.00",
275
+ f" Position Size: {decision.position_size_pct:.1%} of portfolio",
276
+ f" Stop Loss: {decision.stop_loss_pct:.2%} from entry",
277
+ f" Take Profit: {decision.take_profit_pct:.2%} from entry",
278
+ f" Drawdown Risk: {decision.drawdown_risk:.1%}",
279
+ f" Market Regime: {decision.market_regime}",
280
+ "─" * 60,
281
+ " REASONING:",
282
+ ]
283
+ for r in decision.reasoning:
284
+ lines.append(f" • {r}")
285
+
286
+ if decision.alerts:
287
+ lines.append("─" * 60)
288
+ lines.append(" ⚠️ ALERTS:")
289
+ for a in decision.alerts:
290
+ lines.append(f" [{a.get('severity', 'INFO')}] {a.get('type', '')}: {a.get('message', '')}")
291
+
292
+ lines.append("═" * 60)
293
+ return "\n".join(lines)