GoshawkVortexAI commited on
Commit
39bdaba
·
verified ·
1 Parent(s): f952974

Create ml_filter.py

Browse files
Files changed (1) hide show
  1. ml_filter.py +234 -0
ml_filter.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ml_filter.py — Production inference wrapper for the trained probability filter.
3
+
4
+ Integration point in the pipeline:
5
+
6
+ Rule Engine Output
7
+
8
+
9
+ build_feature_dict() ← feature_builder.py
10
+
11
+
12
+ TradeFilter.predict() ← THIS MODULE
13
+
14
+ ├─► prob < threshold → SKIP (no trade)
15
+
16
+ └─► prob >= threshold → risk_engine.evaluate_risk()
17
+
18
+
19
+ Position sizing → Execution
20
+
21
+ Usage:
22
+ filter = TradeFilter.load()
23
+ result = filter.predict(regime_data, volume_data, scores)
24
+ if result.approved:
25
+ risk = evaluate_risk(..., regime_confidence=result.probability)
26
+ """
27
+
28
+ import json
29
+ import logging
30
+ from dataclasses import dataclass
31
+ from pathlib import Path
32
+ from typing import Dict, Any, Optional
33
+
34
+ import numpy as np
35
+
36
+ from ml_config import MODEL_PATH, THRESHOLD_PATH, DEFAULT_PROB_THRESHOLD, FEATURE_COLUMNS
37
+ from feature_builder import build_feature_dict, feature_dict_to_matrix, validate_features
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
@dataclass
class FilterResult:
    """Decision record emitted by the probability filter for one setup."""
    probability: float   # model's P(win), in [0, 1]
    threshold: float     # threshold in force when the call was made
    approved: bool       # True iff probability >= threshold
    feature_dict: Dict   # features fed to the model (kept for logging/debugging)
    reject_reason: str = ""  # populated only when approved is False

    def __str__(self) -> str:
        if self.approved:
            outcome = "APPROVED"
        else:
            outcome = f"REJECTED ({self.reject_reason})"
        return (
            f"FilterResult: p={self.probability:.4f} "
            f"thresh={self.threshold:.4f} → {outcome}"
        )
54
+
55
+
56
class TradeFilter:
    """
    Singleton-friendly inference wrapper around the trained win-probability model.

    Thread-safe for read operations (predict). Not safe for concurrent reloads.
    """

    def __init__(self, backend, threshold: float):
        # backend must expose predict_win_prob(X) -> array-like of P(win) values.
        self._backend = backend
        self._threshold = threshold
        # Lifetime counters for stats(); best-effort, not synchronized.
        self._n_calls = 0
        self._n_approved = 0

    @classmethod
    def load(
        cls,
        model_path: Path = MODEL_PATH,
        threshold_path: Path = THRESHOLD_PATH,
    ) -> Optional["TradeFilter"]:
        """
        Load model and threshold from disk.

        Falls back to DEFAULT_PROB_THRESHOLD if the threshold file is missing
        or unreadable.

        Returns:
            A ready TradeFilter, or None if the model file does not exist
            (i.e. the model has not been trained yet).
        """
        import joblib  # local import: only needed at load time

        if not model_path.exists():
            logger.warning(
                "Model file not found at %s. Run train.py first. "
                "TradeFilter.load() returns None until a model exists.",
                model_path,
            )
            return None

        backend = joblib.load(model_path)
        logger.info("Loaded model from %s", model_path)

        threshold = DEFAULT_PROB_THRESHOLD
        if threshold_path.exists():
            try:
                with open(threshold_path) as f:
                    data = json.load(f)
                threshold = float(data.get("threshold", DEFAULT_PROB_THRESHOLD))
                logger.info("Loaded threshold=%.4f from %s", threshold, threshold_path)
            except (json.JSONDecodeError, TypeError, ValueError) as e:
                # A corrupt threshold file should degrade to the default,
                # not prevent the (valid) model from loading.
                threshold = DEFAULT_PROB_THRESHOLD
                logger.warning(
                    "Could not parse threshold file %s (%s). Using default=%.4f",
                    threshold_path, e, threshold,
                )
        else:
            logger.warning("Threshold file not found. Using default=%.4f", threshold)

        return cls(backend=backend, threshold=threshold)

    @classmethod
    def load_or_none(cls) -> Optional["TradeFilter"]:
        """Convenience: returns None if model not yet trained (no crash)."""
        try:
            return cls.load()
        except Exception as e:  # boundary: any load failure degrades to "no filter"
            logger.warning("Could not load TradeFilter: %s", e)
            return None

    def _reject(self, feat: Dict, reason: str) -> FilterResult:
        """Build a rejection result with probability pinned to 0.0."""
        return FilterResult(
            probability=0.0,
            threshold=self._threshold,
            approved=False,
            feature_dict=feat,
            reject_reason=reason,
        )

    def predict(
        self,
        regime_data: Dict[str, Any],
        volume_data: Dict[str, Any],
        scores: Dict[str, Any],
    ) -> FilterResult:
        """
        Run the full inference pipeline for a single setup.

        Args:
            regime_data: Output of detect_regime()
            volume_data: Output of analyze_volume()
            scores: Output of score_token()

        Returns:
            FilterResult with probability and approval decision. Never raises:
            any failure is reported as a rejected result with a reject_reason.
        """
        self._n_calls += 1

        # Build and validate the feature vector.
        try:
            feat = build_feature_dict(regime_data, volume_data, scores)
        except (KeyError, TypeError, ValueError) as e:
            # Malformed upstream dicts (missing keys, wrong value types) must
            # not crash the pipeline — reject this setup instead.
            logger.error("Feature construction failed: %s", e)
            return self._reject({}, f"FEATURE_ERROR: {e}")

        if not validate_features(feat):
            return self._reject(feat, "INVALID_FEATURES (NaN or inf detected)")

        X = feature_dict_to_matrix(feat)

        try:
            prob = float(self._backend.predict_win_prob(X)[0])
        except Exception as e:  # backends can fail in many ways; contain all of them
            logger.error("Model inference error: %s", e)
            return self._reject(feat, f"INFERENCE_ERROR: {e}")

        approved = prob >= self._threshold
        if approved:
            self._n_approved += 1

        reject_reason = "" if approved else f"prob={prob:.4f} < threshold={self._threshold:.4f}"

        return FilterResult(
            probability=prob,
            threshold=self._threshold,
            approved=approved,
            feature_dict=feat,
            reject_reason=reject_reason,
        )

    def predict_batch(
        self,
        feature_dicts: list,
    ) -> np.ndarray:
        """
        Batch inference for 100+ symbols simultaneously.

        Args:
            feature_dicts: List of feature dicts (keys per FEATURE_COLUMNS).

        Returns:
            Array of probabilities in the same order as feature_dicts.
            Rows failing validation get probability 0.0 (always rejected).

        Much faster than calling predict() in a loop. Note: does not update
        the stats() counters.
        """
        valid_rows = []
        valid_indices = []

        for i, feat in enumerate(feature_dicts):
            if validate_features(feat):
                valid_rows.append([feat[k] for k in FEATURE_COLUMNS])
                valid_indices.append(i)

        # Invalid rows keep the 0.0 default, which no sane threshold approves.
        probs = np.zeros(len(feature_dicts), dtype=np.float64)

        if valid_rows:
            X = np.array(valid_rows, dtype=np.float64)
            batch_probs = self._backend.predict_win_prob(X)
            # Scatter results back to their original positions in one shot.
            probs[valid_indices] = np.asarray(batch_probs, dtype=np.float64)

        return probs

    def predict_trade_probability(self, feature_dict: Dict[str, float]) -> float:
        """
        Simple scalar interface: feature_dict → float.
        Matches the interface requested in the spec.
        Returns 0.0 on any error.
        """
        if not validate_features(feature_dict):
            return 0.0
        X = feature_dict_to_matrix(feature_dict)
        try:
            return float(self._backend.predict_win_prob(X)[0])
        except Exception:
            return 0.0

    @property
    def threshold(self) -> float:
        """Current approval threshold in (0, 1)."""
        return self._threshold

    @threshold.setter
    def threshold(self, value: float):
        if not 0.0 < value < 1.0:
            raise ValueError(f"Threshold must be in (0, 1), got {value}")
        self._threshold = value

    def stats(self) -> dict:
        """Return lifetime call/approval counters for monitoring dashboards."""
        approval_rate = self._n_approved / self._n_calls if self._n_calls > 0 else 0.0
        return {
            "n_calls": self._n_calls,
            "n_approved": self._n_approved,
            "approval_rate": round(approval_rate, 4),
            "threshold": self._threshold,
        }