Nautilus AI commited on
Commit
c5c085b
·
0 Parent(s):

Deploy: Trainer to Root (Retry)

Browse files
.dockerignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Exclude Version Control
2
+ .git
3
+ .gitignore
4
+
5
+ # Exclude Virtual Environment
6
+ venv/
7
+ env/
8
+ .env
9
+
10
+ # Exclude Python Cache
11
+ __pycache__/
12
+ *.pyc
13
+ *.pyo
14
+ *.pyd
15
+
16
+ # Exclude Local Data (Can be huge)
17
+ data/
18
+ start_data/
19
+ ray_results/
20
+
21
+ # Exclude IDE settings
22
+ .vscode/
23
+ .idea/
24
+ .DS_Store
25
+
26
+ # Exclude Logs
27
+ *.log
AIDocs/ARCHITECTURE_HT_TRADER.md ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📘 HT-TRADER - Nguyên lý hoạt động chi tiết
2
+
3
+ ## 🎯 Vai trò chính
4
+ `ht-trader` là **execution engine** của HyperTrade - chịu trách nhiệm:
5
+ 1. Nhận trading signals từ `ht-brain`
6
+ 2. Kiểm tra risk management
7
+ 3. Thực thi lệnh trên exchange (hoặc paper trading)
8
+ 4. Quản lý positions (mở/đóng/DCA)
9
+ 5. Ghi nhận trades vào QuestDB
10
+
11
+ ---
12
+
13
+ ## 🔄 Luồng hoạt động chính
14
+
15
+ ### 1. **Khởi tạo (Initialization)**
16
+ ```python
17
+ class Trader:
18
+ def __init__(self):
19
+ # Redis Pub/Sub - Lắng nghe 3 channels:
20
+ self.pubsub.subscribe("signals", "system_commands", "l2_updates")
21
+
22
+ # Components:
23
+ self.risk_manager = RiskManager() # Kiểm tra risk
24
+ self.exchange_client = ExchangeClient() # Giao tiếp với exchange
25
+ self.position_manager = PositionManager() # Quản lý positions
26
+ self.ab_manager = ABTestManager() # A/B testing strategies
27
+
28
+ # State:
29
+ self.trading_enabled = True # Master switch
30
+ self.positions = {} # Open positions
31
+ self.orderbooks = {} # Latest L2 data
32
+ ```
33
+
34
+ **Giải thích:**
35
+ - Subscribe 3 channels để nhận:
36
+ - `signals`: Trading signals từ ht-brain
37
+ - `system_commands`: Lệnh điều khiển (STOP_TRADING, etc.)
38
+ - `l2_updates`: Orderbook data cho paper trading
39
+
40
+ ---
41
+
42
+ ### 2. **Main Loop - Lắng nghe messages**
43
+ ```python
44
+ def run(self):
45
+ while self.running:
46
+ message = self.pubsub.get_message()
47
+
48
+ if channel == 'signals':
49
+ self.process_signal(signal_data) # Xử lý signal
50
+
51
+ elif channel == 'system_commands':
52
+ self.process_command(cmd_data) # Xử lý lệnh
53
+
54
+ elif channel == 'l2_updates':
55
+ self.orderbooks[coin] = data # Cập nhật orderbook
56
+ ```
57
+
58
+ **Giải thích:**
59
+ - Loop chạy liên tục, check messages mỗi 0.01s
60
+ - Phân loại message theo channel và xử lý tương ứng
61
+
62
+ ---
63
+
64
+ ### 3. **Xử lý Signal (process_signal)**
65
+
66
+ #### **Bước 1: Kiểm tra trading enabled**
67
+ ```python
68
+ if not self.trading_enabled:
69
+ logger.warning("Signal ignored: TRADING DISABLED")
70
+ return
71
+ ```
72
+
73
+ #### **Bước 2: Parse signal data**
74
+ ```python
75
+ coin = signal_data.get("coin") # BTC, ETH, SOL
76
+ signal = signal_data.get("signal") # BUY, SELL, HOLD
77
+ confidence = float(signal_data.get("confidence", 0))
78
+ price = float(signal_data.get("price", 0))
79
+ reason = signal_data.get("reason") # Lý do signal
80
+ features = signal_data.get("features", {}) # Features từ Brain
81
+ ```
82
+
83
+ #### **Bước 3: Check Exit Conditions (Quan trọng!)**
84
+ ```python
85
+ should_close, close_reason = self.position_manager.check_exit_condition(coin, signal, price)
86
+
87
+ if should_close:
88
+ # Đóng position hiện tại
89
+ pos = self.position_manager.open_positions.get(coin)
90
+ qty_to_close = abs(pos['quantity'])
91
+ close_signal = 'SELL' if pos['side'] == 'LONG' else 'BUY'
92
+
93
+ self.execute_trade(coin, close_signal, price, qty_to_close, close_reason, ...)
94
+ return
95
+ ```
96
+
97
+ **Exit conditions bao gồm:**
98
+ - ✅ **Take Profit**: PnL > threshold
99
+ - ✅ **Stop Loss**: PnL < -threshold
100
+ - ✅ **Reversal Signal**: Signal ngược chiều + PnL dương
101
+ - ❌ **Hold**: Signal ngược chiều nhưng PnL âm (chờ hồi vốn)
102
+
103
+ #### **Bước 4: Entry Logic (Nếu không close)**
104
+ ```python
105
+ # A/B Testing - Chọn strategy
106
+ strategy_config = self.ab_manager.get_assignment(coin, user_id)
107
+ size_usd = strategy_config.get('size_usd', 1000.0)
108
+ quantity = size_usd / price
109
+
110
+ # Tính current exposure
111
+ current_exposure = sum(abs(pos['quantity']) * pos_price
112
+ for pos in self.position_manager.open_positions.values())
113
+
114
+ # Risk Check
115
+ allowed, rejection_reason = self.risk_manager.check_trade(
116
+ coin, signal, price, quantity, current_exposure
117
+ )
118
+
119
+ if not allowed:
120
+ logger.warning(f"Trade Rejected: {rejection_reason}")
121
+ return
122
+ ```
123
+
124
+ **Risk checks bao gồm:**
125
+ - ✅ Max position size per coin
126
+ - ✅ Max total exposure
127
+ - ✅ Max drawdown
128
+ - ✅ Confidence threshold
129
+
130
+ #### **Bước 5: Execute Trade**
131
+ ```python
132
+ self.execute_trade(coin, signal, price, quantity, reason, features, strategy_id)
133
+ ```
134
+
135
+ ---
136
+
137
+ ### 4. **Execute Trade (execute_trade)**
138
+
139
+ #### **Bước 1: Place order trên exchange**
140
+ ```python
141
+ result = self.exchange_client.place_order(coin, signal, quantity, price, orderbook=orderbook)
142
+ ```
143
+
144
+ **Exchange Client modes:**
145
+ - **Paper Trading**: Simulate order với L2 orderbook data
146
+ - **Live Trading**: Gửi order thật lên Hyperliquid
147
+
148
+ #### **Bước 2: Update Position Manager**
149
+ ```python
150
+ trade_id = str(uuid.uuid4())
151
+ self.position_manager.handle_fill(coin, signal, price, quantity, trade_id, features)
152
+ ```
153
+
154
+ **Position Manager tracking:**
155
+ - Open positions: `{coin: {side, quantity, entry_price, pnl, ...}}`
156
+ - Closed positions: Lưu vào QuestDB
157
+
158
+ #### **Bước 3: Log trade vào QuestDB**
159
+ ```python
160
+ self.db_sender.sender.row(
161
+ 'trades_executed',
162
+ symbols={
163
+ 'trade_id': trade_id,
164
+ 'coin': coin,
165
+ 'side': signal,
166
+ 'status': 'SUBMITTED',
167
+ 'strategy_id': strategy_id,
168
+ 'environment': 'PAPER'
169
+ },
170
+ columns={
171
+ 'entry_price': price,
172
+ 'quantity': quantity,
173
+ 'entry_fee': quantity * price * TAKER_FEE_RATE,
174
+ 'current_pnl': 0.0,
175
+ 'features_json': json.dumps(features),
176
+ 'notes': reason
177
+ }
178
+ )
179
+ ```
180
+
181
+ #### **Bước 4: Publish trade update**
182
+ ```python
183
+ update_msg = {
184
+ 'event': 'trade_opened',
185
+ 'trade_id': trade_id,
186
+ 'coin': coin,
187
+ 'side': signal,
188
+ 'price': price,
189
+ 'quantity': quantity
190
+ }
191
+ self.redis_client.publish('trade_updates', json.dumps(update_msg))
192
+ ```
193
+
194
+ ---
195
+
196
+ ### 5. **System Commands (process_command)**
197
+
198
+ ```python
199
+ def process_command(self, cmd_data):
200
+ command = cmd_data.get('command')
201
+
202
+ if command == 'STOP_TRADING':
203
+ self.trading_enabled = False # Dừng nhận signals mới
204
+
205
+ elif command == 'START_TRADING':
206
+ self.trading_enabled = True # Bật lại trading
207
+
208
+ elif command == 'REDUCE_RISK':
209
+ # Giảm max position size xuống 50%
210
+ new_limit = max(current_limit * 0.5, 100.0)
211
+ self.risk_manager.update_policy('max_position_size_usd', new_limit)
212
+
213
+ elif command == 'INCREASE_RISK':
214
+ # Tăng max position size lên 150%
215
+ new_limit = min(current_limit * 1.5, 5000.0)
216
+ self.risk_manager.update_policy('max_position_size_usd', new_limit)
217
+
218
+ elif command == 'RESET_POSITIONS':
219
+ # Reset tất cả positions (emergency)
220
+ self.position_manager.reset()
221
+ ```
222
+
223
+ ---
224
+
225
+ ## 📊 Data Flow Diagram
226
+
227
+ ```
228
+ ┌─────────────┐
229
+ │ ht-brain │ ──► signals ──► ┌──────────────┐
230
+ └─────────────┘ │ │
231
+ │ ht-trader │
232
+ ┌─────────────┐ │ │
233
+ │ ht-manager │ ──► system_commands ──► │ │
234
+ └─────────────┘ │ │
235
+ │ ┌──────┐ │
236
+ ┌─────────────┐ │ │ Risk │ │
237
+ │ ht-l2-data │ ──► l2_updates ──► │ Check│ │
238
+ └─────────────┘ │ └──────┘ │
239
+ │ ↓ │
240
+ │ Execute │
241
+ │ ↓ │
242
+ │ ┌────────┐ │
243
+ │ │Exchange│ │
244
+ │ └────────┘ │
245
+ │ ↓ │
246
+ │ ┌────────┐ │
247
+ │ │QuestDB │ │
248
+ │ └────────┘ │
249
+ └──────────────┘
250
+ ```
251
+
252
+ ---
253
+
254
+ ## 🎯 Key Features
255
+
256
+ ### 1. **Smart Exit Logic**
257
+ - Không đóng position khi PnL âm (chờ hồi vốn)
258
+ - Tự động take profit khi đạt target
259
+ - Stop loss khi loss quá lớn
260
+
261
+ ### 2. **DCA (Dollar Cost Averaging)**
262
+ - Cho phép add thêm vào position cùng chiều
263
+ - Tính average entry price
264
+
265
+ ### 3. **Risk Management**
266
+ - Max position size per coin
267
+ - Max total exposure
268
+ - Confidence threshold filtering
269
+
270
+ ### 4. **A/B Testing**
271
+ - Test nhiều strategies song song
272
+ - Track performance từng strategy
273
+
274
+ ### 5. **Paper Trading**
275
+ - Simulate trades với L2 orderbook
276
+ - Không cần real money để test
277
+
278
+ ---
279
+
280
+ ## 🔧 Configuration
281
+
282
+ ### Risk Policy (risk_manager.py)
283
+ ```python
284
+ {
285
+ 'max_position_size_usd': 1000.0, # Max $1000/coin
286
+ 'max_total_exposure_usd': 5000.0, # Max $5000 total
287
+ 'max_drawdown_pct': 20.0, # Max 20% drawdown
288
+ 'min_confidence': 0.6 # Min 60% confidence
289
+ }
290
+ ```
291
+
292
+ ### Strategy Config (ab_testing.py)
293
+ ```python
294
+ {
295
+ 'id': 'strategy_A',
296
+ 'size_usd': 1000.0,
297
+ 'take_profit_pct': 2.0, # 2% TP
298
+ 'stop_loss_pct': 1.0 # 1% SL
299
+ }
300
+ ```
301
+
302
+ ---
303
+
304
+ ## 📈 Metrics (Prometheus)
305
+
306
+ - `trader_signals_received_total`: Tổng signals nhận được
307
+ - `trader_trades_executed_total`: Tổng trades thực thi
308
+ - `trader_trades_rejected_total`: Tổng trades bị reject
309
+ - `trader_position_size`: Position size hiện tại
310
+
311
+ ---
312
+
313
+ ## 🚨 Error Handling
314
+
315
+ 1. **QuestDB connection lost**: Auto-reconnect
316
+ 2. **Exchange API error**: Log và skip trade
317
+ 3. **Invalid signal**: Deserialize error → skip
318
+ 4. **Risk check failed**: Reject trade + log reason
319
+
320
+ ---
321
+
322
+ ## 💡 Best Practices
323
+
324
+ 1. **Luôn check `trading_enabled`** trước khi execute
325
+ 2. **Validate signal data** trước khi process
326
+ 3. **Log mọi trade** vào QuestDB để audit
327
+ 4. **Update metrics** để monitoring
328
+ 5. **Handle exceptions** gracefully
329
+
330
+ ---
331
+
332
+ ## 🔄 Lifecycle
333
+
334
+ ```
335
+ START
336
+
337
+ Initialize Components
338
+
339
+ Subscribe Redis Channels
340
+
341
+ ┌─────────────────┐
342
+ │ Main Loop │
343
+ │ ↓ │
344
+ │ Get Message │
345
+ │ ↓ │
346
+ │ Process │
347
+ │ ↓ │
348
+ │ Sleep 0.01s │
349
+ └─────────────────┘
350
+
351
+ STOP (KeyboardInterrupt)
352
+
353
+ Close Connections
354
+
355
+ END
356
+ ```
357
+
358
+ ---
359
+
360
+ ## 📝 Summary
361
+
362
+ **ht-trader** là một **event-driven execution engine** với:
363
+ - ✅ Real-time signal processing
364
+ - ✅ Smart position management
365
+ - ✅ Comprehensive risk checks
366
+ - ✅ Paper & live trading support
367
+ - ✅ Full audit trail (QuestDB)
368
+ - ✅ Prometheus monitoring
369
+
370
+ **Điểm mạnh:**
371
+ - Tách biệt rõ ràng giữa signal generation (brain) và execution (trader)
372
+ - Risk management chặt chẽ
373
+ - Hỗ trợ A/B testing
374
+ - Dễ mở rộng và maintain
375
+
376
+ **Điểm cần cải thiện:**
377
+ - Thêm order types (limit, stop-limit)
378
+ - Trailing stop loss
379
+ - Partial position closing
380
+ - Multi-exchange support
AIDocs/HT_TRADER_OPTIMIZATIONS.md ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 HT-TRADER Optimization Opportunities
2
+
3
+ ## 📊 Current Analysis
4
+
5
+ ### Strengths ✅
6
+ - Event-driven architecture
7
+ - Risk management in place
8
+ - Position tracking working
9
+ - QuestDB logging functional
10
+ - Paper trading operational
11
+
12
+ ### Areas for Improvement 🔧
13
+
14
+ ---
15
+
16
+ ## 1. 🎯 Signal Processing Optimization
17
+
18
+ ### Current Issue:
19
+ ```python
20
+ # Processes EVERY signal, even low quality ones
21
+ def process_signal(self, signal_data):
22
+ confidence = float(signal_data.get("confidence", 0))
23
+ # No early filtering
24
+ # Continues to risk checks even for low confidence
25
+ ```
26
+
27
+ ### Optimization:
28
+ ```python
29
+ def process_signal(self, signal_data):
30
+ confidence = float(signal_data.get("confidence", 0))
31
+
32
+ # ✅ Early rejection for low confidence
33
+ if confidence < self.min_confidence_threshold:
34
+ logger.debug(f"Signal rejected: Low confidence {confidence:.2%}")
35
+ TRADER_SIGNALS_REJECTED_TOTAL.labels(
36
+ coin=coin,
37
+ reason="low_confidence"
38
+ ).inc()
39
+ return
40
+
41
+ # Continue processing only high-quality signals
42
+ ...
43
+ ```
44
+
45
+ **Benefits:**
46
+ - Reduce CPU usage
47
+ - Faster response time
48
+ - Better metrics tracking
49
+
50
+ ---
51
+
52
+ ## 2. 📈 Position Sizing Optimization
53
+
54
+ ### Current Issue:
55
+ ```python
56
+ # Fixed position size from strategy config
57
+ size_usd = strategy_config.get('size_usd', 1000.0)
58
+ quantity = size_usd / price
59
+ ```
60
+
61
+ ### Optimization: Dynamic Kelly Criterion
62
+ ```python
63
+ def calculate_optimal_size(self, coin, confidence, win_rate, avg_win, avg_loss):
64
+ """
65
+ Kelly Criterion: f* = (p*b - q) / b
66
+ where:
67
+ - p = win probability
68
+ - q = loss probability (1-p)
69
+ - b = win/loss ratio
70
+ """
71
+ if win_rate <= 0.5 or avg_loss == 0:
72
+ return self.min_size_usd
73
+
74
+ p = win_rate
75
+ q = 1 - p
76
+ b = abs(avg_win / avg_loss)
77
+
78
+ kelly_fraction = (p * b - q) / b
79
+ kelly_fraction = max(0, min(kelly_fraction, 0.25)) # Cap at 25%
80
+
81
+ # Adjust by confidence
82
+ confidence_multiplier = confidence / 0.8 # Normalize
83
+
84
+ optimal_size = self.base_size_usd * kelly_fraction * confidence_multiplier
85
+
86
+ return max(self.min_size_usd, min(optimal_size, self.max_size_usd))
87
+ ```
88
+
89
+ **Benefits:**
90
+ - Maximize long-term growth
91
+ - Risk-adjusted position sizing
92
+ - Confidence-weighted allocation
93
+
94
+ ---
95
+
96
+ ## 3. 🔄 Smart Order Execution
97
+
98
+ ### Current Issue:
99
+ ```python
100
+ # Executes immediately at market price
101
+ result = self.exchange_client.place_order(coin, signal, quantity, price)
102
+ ```
103
+
104
+ ### Optimization: TWAP (Time-Weighted Average Price)
105
+ ```python
106
+ async def execute_twap(self, coin, signal, total_quantity, duration_seconds=60):
107
+ """
108
+ Split large orders into smaller chunks over time
109
+ to reduce market impact and get better average price
110
+ """
111
+ num_chunks = 10
112
+ chunk_size = total_quantity / num_chunks
113
+ interval = duration_seconds / num_chunks
114
+
115
+ fills = []
116
+ for i in range(num_chunks):
117
+ # Get current best price
118
+ current_price = self.get_current_price(coin)
119
+
120
+ # Execute chunk
121
+ result = self.exchange_client.place_order(
122
+ coin, signal, chunk_size, current_price
123
+ )
124
+
125
+ if result:
126
+ fills.append({
127
+ 'price': current_price,
128
+ 'quantity': chunk_size,
129
+ 'timestamp': time.time()
130
+ })
131
+
132
+ await asyncio.sleep(interval)
133
+
134
+ # Calculate average fill price
135
+ total_cost = sum(f['price'] * f['quantity'] for f in fills)
136
+ total_qty = sum(f['quantity'] for f in fills)
137
+ avg_price = total_cost / total_qty if total_qty > 0 else 0
138
+
139
+ return avg_price, total_qty
140
+ ```
141
+
142
+ **Benefits:**
143
+ - Better average price
144
+ - Reduced slippage
145
+ - Lower market impact
146
+
147
+ ---
148
+
149
+ ## 4. 🎲 Advanced Exit Strategy
150
+
151
+ ### Current Issue:
152
+ ```python
153
+ # Simple TP/SL based on percentage
154
+ should_close = (pnl_pct > take_profit_pct) or (pnl_pct < -stop_loss_pct)
155
+ ```
156
+
157
+ ### Optimization: Trailing Stop + Dynamic TP
158
+ ```python
159
+ class SmartExitManager:
160
+ def __init__(self):
161
+ self.trailing_stops = {} # {trade_id: trailing_stop_price}
162
+ self.peak_prices = {} # {trade_id: highest_price_seen}
163
+
164
+ def check_exit(self, trade_id, current_price, entry_price, side):
165
+ """
166
+ Advanced exit logic:
167
+ 1. Trailing stop: Lock in profits as price moves favorably
168
+ 2. Dynamic TP: Adjust based on volatility
169
+ 3. Time-based exit: Close stale positions
170
+ """
171
+ # Initialize tracking
172
+ if trade_id not in self.peak_prices:
173
+ self.peak_prices[trade_id] = current_price
174
+ self.trailing_stops[trade_id] = None
175
+
176
+ # Update peak price
177
+ if side == 'LONG':
178
+ self.peak_prices[trade_id] = max(self.peak_prices[trade_id], current_price)
179
+ else:
180
+ self.peak_prices[trade_id] = min(self.peak_prices[trade_id], current_price)
181
+
182
+ # Calculate PnL
183
+ if side == 'LONG':
184
+ pnl_pct = (current_price - entry_price) / entry_price * 100
185
+ peak_pnl_pct = (self.peak_prices[trade_id] - entry_price) / entry_price * 100
186
+ else:
187
+ pnl_pct = (entry_price - current_price) / entry_price * 100
188
+ peak_pnl_pct = (entry_price - self.peak_prices[trade_id]) / entry_price * 100
189
+
190
+ # 1. Trailing Stop (activate after 1% profit)
191
+ if peak_pnl_pct > 1.0:
192
+ trailing_distance = 0.5 # Trail by 0.5%
193
+
194
+ if side == 'LONG':
195
+ trailing_stop = self.peak_prices[trade_id] * (1 - trailing_distance/100)
196
+ if current_price < trailing_stop:
197
+ return True, f"Trailing stop hit (locked {peak_pnl_pct:.2f}%)"
198
+ else:
199
+ trailing_stop = self.peak_prices[trade_id] * (1 + trailing_distance/100)
200
+ if current_price > trailing_stop:
201
+ return True, f"Trailing stop hit (locked {peak_pnl_pct:.2f}%)"
202
+
203
+ # 2. Dynamic Take Profit (based on volatility)
204
+ volatility = self.get_recent_volatility(coin)
205
+ dynamic_tp = 1.5 + (volatility * 2) # Higher TP in volatile markets
206
+
207
+ if pnl_pct > dynamic_tp:
208
+ return True, f"Dynamic TP hit ({pnl_pct:.2f}% > {dynamic_tp:.2f}%)"
209
+
210
+ # 3. Stop Loss
211
+ if pnl_pct < -1.0:
212
+ return True, f"Stop loss hit ({pnl_pct:.2f}%)"
213
+
214
+ # 4. Time-based exit (close after 4 hours)
215
+ hold_time = time.time() - entry_time
216
+ if hold_time > 4 * 3600 and pnl_pct > 0:
217
+ return True, f"Time exit (held {hold_time/3600:.1f}h, profit {pnl_pct:.2f}%)"
218
+
219
+ return False, "Hold"
220
+ ```
221
+
222
+ **Benefits:**
223
+ - Lock in profits automatically
224
+ - Adapt to market volatility
225
+ - Prevent stale positions
226
+
227
+ ---
228
+
229
+ ## 5. 💾 Performance Caching
230
+
231
+ ### Current Issue:
232
+ ```python
233
+ # Recalculates exposure every signal
234
+ for p_coin, p_data in self.position_manager.open_positions.items():
235
+ current_exposure += abs(p_data['quantity']) * p_price
236
+ ```
237
+
238
+ ### Optimization: Cache frequently accessed data
239
+ ```python
240
+ class PerformanceCache:
241
+ def __init__(self, ttl=1.0): # 1 second TTL
242
+ self.cache = {}
243
+ self.ttl = ttl
244
+
245
+ def get_or_compute(self, key, compute_func):
246
+ now = time.time()
247
+
248
+ if key in self.cache:
249
+ value, timestamp = self.cache[key]
250
+ if now - timestamp < self.ttl:
251
+ return value
252
+
253
+ # Compute and cache
254
+ value = compute_func()
255
+ self.cache[key] = (value, now)
256
+ return value
257
+
258
+ # Usage:
259
+ self.cache = PerformanceCache(ttl=1.0)
260
+
261
+ def get_current_exposure(self):
262
+ return self.cache.get_or_compute(
263
+ 'total_exposure',
264
+ lambda: sum(abs(p['quantity']) * p['price']
265
+ for p in self.position_manager.open_positions.values())
266
+ )
267
+ ```
268
+
269
+ **Benefits:**
270
+ - Reduce redundant calculations
271
+ - Faster signal processing
272
+ - Lower CPU usage
273
+
274
+ ---
275
+
276
+ ## 6. 🔔 Smart Alerting
277
+
278
+ ### Current Issue:
279
+ ```python
280
+ # No alerts for important events
281
+ # Trader operates silently
282
+ ```
283
+
284
+ ### Optimization: Event-based notifications
285
+ ```python
286
+ class TraderAlertManager:
287
+ def __init__(self, telegram_bot):
288
+ self.telegram = telegram_bot
289
+ self.alert_cooldown = {}
290
+
291
+ async def alert_large_win(self, coin, pnl_usd):
292
+ """Alert on significant wins"""
293
+ if pnl_usd > 100: # $100+ win
294
+ await self.telegram.send_message(
295
+ f"🎉 Large Win!\n"
296
+ f"Coin: {coin}\n"
297
+ f"PnL: ${pnl_usd:.2f}"
298
+ )
299
+
300
+ async def alert_large_loss(self, coin, pnl_usd):
301
+ """Alert on significant losses"""
302
+ if pnl_usd < -50: # $50+ loss
303
+ await self.telegram.send_message(
304
+ f"⚠️ Large Loss!\n"
305
+ f"Coin: {coin}\n"
306
+ f"PnL: ${pnl_usd:.2f}\n"
307
+ f"Action: Review strategy"
308
+ )
309
+
310
+ async def alert_streak(self, streak_type, count):
311
+ """Alert on win/loss streaks"""
312
+ if count >= 5:
313
+ emoji = "🔥" if streak_type == "win" else "❄️"
314
+ await self.telegram.send_message(
315
+ f"{emoji} {streak_type.upper()} Streak: {count}\n"
316
+ f"Consider: {'Increase size' if streak_type == 'win' else 'Reduce size'}"
317
+ )
318
+ ```
319
+
320
+ **Benefits:**
321
+ - Real-time awareness
322
+ - Quick response to issues
323
+ - Better monitoring
324
+
325
+ ---
326
+
327
+ ## 7. 📊 Advanced Metrics
328
+
329
+ ### Current Issue:
330
+ ```python
331
+ # Only basic metrics tracked
332
+ TRADER_TRADES_EXECUTED_TOTAL.inc()
333
+ ```
334
+
335
+ ### Optimization: Comprehensive metrics
336
+ ```python
337
+ # Add new metrics
338
+ from prometheus_client import Histogram, Gauge
339
+
340
+ # Latency tracking
341
+ SIGNAL_TO_EXECUTION_LATENCY = Histogram(
342
+ 'trader_signal_execution_latency_seconds',
343
+ 'Time from signal received to order executed',
344
+ buckets=[0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0]
345
+ )
346
+
347
+ # Slippage tracking
348
+ EXECUTION_SLIPPAGE = Histogram(
349
+ 'trader_execution_slippage_percent',
350
+ 'Slippage between signal price and fill price',
351
+ buckets=[0.01, 0.05, 0.1, 0.2, 0.5, 1.0]
352
+ )
353
+
354
+ # Current PnL
355
+ CURRENT_PNL_USD = Gauge(
356
+ 'trader_current_pnl_usd',
357
+ 'Current unrealized PnL in USD',
358
+ ['coin']
359
+ )
360
+
361
+ # Win rate (rolling)
362
+ ROLLING_WIN_RATE = Gauge(
363
+ 'trader_rolling_win_rate',
364
+ 'Win rate over last N trades',
365
+ ['window']
366
+ )
367
+
368
+ # Usage in code:
369
+ with SIGNAL_TO_EXECUTION_LATENCY.time():
370
+ self.execute_trade(...)
371
+
372
+ slippage_pct = abs(fill_price - signal_price) / signal_price * 100
373
+ EXECUTION_SLIPPAGE.observe(slippage_pct)
374
+ ```
375
+
376
+ **Benefits:**
377
+ - Better observability
378
+ - Performance insights
379
+ - Easier debugging
380
+
381
+ ---
382
+
383
+ ## 8. 🔄 Async Processing
384
+
385
+ ### Current Issue:
386
+ ```python
387
+ # Synchronous processing blocks on I/O
388
+ def run(self):
389
+ while self.running:
390
+ message = self.pubsub.get_message() # Blocking
391
+ if message:
392
+ self.process_signal(signal_data) # Blocking
393
+ time.sleep(0.01)
394
+ ```
395
+
396
+ ### Optimization: Async/await pattern
397
+ ```python
398
+ async def run(self):
399
+ """Async main loop for better concurrency"""
400
+
401
+ # Create async tasks
402
+ tasks = [
403
+ asyncio.create_task(self.process_signals()),
404
+ asyncio.create_task(self.update_positions()),
405
+ asyncio.create_task(self.monitor_health())
406
+ ]
407
+
408
+ await asyncio.gather(*tasks)
409
+
410
+ async def process_signals(self):
411
+ """Process signals asynchronously"""
412
+ while self.running:
413
+ message = await self.pubsub.get_message_async()
414
+ if message:
415
+ # Process in background
416
+ asyncio.create_task(self.process_signal_async(message))
417
+
418
+ await asyncio.sleep(0.01)
419
+
420
+ async def process_signal_async(self, signal_data):
421
+ """Non-blocking signal processing"""
422
+ # Risk checks (fast)
423
+ if not await self.check_risk_async(signal_data):
424
+ return
425
+
426
+ # Execute trade (I/O bound)
427
+ await self.execute_trade_async(signal_data)
428
+ ```
429
+
430
+ **Benefits:**
431
+ - Handle multiple signals concurrently
432
+ - Better throughput
433
+ - Non-blocking I/O
434
+
435
+ ---
436
+
437
+ ## 9. 🧠 Machine Learning Integration
438
+
439
+ ### Optimization: ML-based trade filtering
440
+ ```python
441
+ class MLTradeFilter:
442
+ def __init__(self):
443
+ self.model = self.load_model('trade_filter_v1.pkl')
444
+
445
+ def should_execute(self, signal_data, market_state):
446
+ """
447
+ Use ML to predict if trade will be profitable
448
+ Based on:
449
+ - Signal features
450
+ - Market conditions
451
+ - Recent performance
452
+ - Time of day
453
+ - Volatility
454
+ """
455
+ features = self.extract_features(signal_data, market_state)
456
+
457
+ # Predict win probability
458
+ win_prob = self.model.predict_proba(features)[0][1]
459
+
460
+ # Only execute if high probability
461
+ return win_prob > 0.65
462
+
463
+ def extract_features(self, signal_data, market_state):
464
+ return {
465
+ 'confidence': signal_data['confidence'],
466
+ 'volatility': market_state['volatility'],
467
+ 'spread': market_state['spread'],
468
+ 'volume': market_state['volume'],
469
+ 'time_of_day': datetime.now().hour,
470
+ 'recent_win_rate': self.get_recent_win_rate(),
471
+ 'market_regime': market_state['regime']
472
+ }
473
+ ```
474
+
475
+ **Benefits:**
476
+ - Filter out low-quality trades
477
+ - Improve win rate
478
+ - Adaptive to market conditions
479
+
480
+ ---
481
+
482
+ ## 10. 🎯 Priority Queue for Signals
483
+
484
+ ### Current Issue:
485
+ ```python
486
+ # Processes signals in order received
487
+ # High-confidence signals may wait behind low-confidence ones
488
+ ```
489
+
490
+ ### Optimization: Priority queue
491
+ ```python
492
+ import heapq
493
+ from dataclasses import dataclass, field
494
+ from typing import Any
495
+
496
+ @dataclass(order=True)
497
+ class PrioritizedSignal:
498
+ priority: float = field(compare=True)
499
+ signal_data: Any = field(compare=False)
500
+ timestamp: float = field(compare=False)
501
+
502
+ class SignalQueue:
503
+ def __init__(self):
504
+ self.queue = []
505
+
506
+ def add_signal(self, signal_data):
507
+ # Higher confidence = higher priority
508
+ confidence = signal_data.get('confidence', 0)
509
+ priority = -confidence # Negative for max-heap behavior
510
+
511
+ item = PrioritizedSignal(
512
+ priority=priority,
513
+ signal_data=signal_data,
514
+ timestamp=time.time()
515
+ )
516
+
517
+ heapq.heappush(self.queue, item)
518
+
519
+ def get_next_signal(self):
520
+ if self.queue:
521
+ return heapq.heappop(self.queue).signal_data
522
+ return None
523
+
524
+ # Usage:
525
+ self.signal_queue = SignalQueue()
526
+
527
+ # Add signals to queue
528
+ self.signal_queue.add_signal(signal_data)
529
+
530
+ # Process highest priority first
531
+ while True:
532
+ signal = self.signal_queue.get_next_signal()
533
+ if signal:
534
+ self.process_signal(signal)
535
+ ```
536
+
537
+ **Benefits:**
538
+ - Process best signals first
539
+ - Better capital utilization
540
+ - Improved performance
541
+
542
+ ---
543
+
544
+ ## 📋 Implementation Priority
545
+
546
+ ### High Priority (Implement First):
547
+ 1. ✅ **Signal Filtering** - Quick win, reduces noise
548
+ 2. ✅ **Performance Caching** - Easy, immediate impact
549
+ 3. ✅ **Smart Alerting** - Better monitoring
550
+ 4. ✅ **Advanced Metrics** - Visibility
551
+
552
+ ### Medium Priority:
553
+ 5. ⚠️ **Dynamic Position Sizing** - Requires backtesting
554
+ 6. ⚠️ **Advanced Exit Strategy** - Needs validation
555
+ 7. ⚠️ **Priority Queue** - Moderate complexity
556
+
557
+ ### Low Priority (Future):
558
+ 8. 🔮 **TWAP Execution** - Complex, for large orders
559
+ 9. 🔮 **Async Processing** - Major refactor
560
+ 10. 🔮 **ML Trade Filter** - Requires training data
561
+
562
+ ---
563
+
564
+ ## 🎯 Expected Impact
565
+
566
+ | Optimization | CPU ↓ | Latency ↓ | Win Rate ↑ | Complexity |
567
+ |--------------|-------|-----------|------------|------------|
568
+ | Signal Filtering | 20% | 30% | 2-3% | Low |
569
+ | Performance Cache | 15% | 25% | 0% | Low |
570
+ | Smart Alerting | 0% | 0% | 0% | Low |
571
+ | Advanced Metrics | 5% | 0% | 0% | Low |
572
+ | Dynamic Sizing | 0% | 0% | 3-5% | Medium |
573
+ | Advanced Exit | 0% | 0% | 5-8% | Medium |
574
+ | Priority Queue | 0% | 15% | 1-2% | Medium |
575
+ | TWAP | 0% | -10% | 1-2% | High |
576
+ | Async Processing | 10% | 40% | 0% | High |
577
+ | ML Filter | 0% | 5% | 8-12% | High |
578
+
579
+ ---
580
+
581
+ ## 🚀 Quick Wins (Can implement now)
582
+
583
+ 1. **Signal Filtering** - 30 minutes
584
+ 2. **Performance Caching** - 1 hour
585
+ 3. **Smart Alerting** - 2 hours
586
+ 4. **Advanced Metrics** - 2 hours
587
+
588
+ **Total time: ~6 hours for 4 optimizations**
589
+ **Expected improvement: +5-8% Win Rate, -50% CPU, -50% Latency**
590
+
591
+ Would you like me to implement any of these optimizations?
AIDocs/ML_TRADE_FILTER_ROADMAP.md ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🧠 ML Trade Filter - Roadmap & Status
2
+
3
+ ## 📅 Current Status (2025-12-02)
4
+
5
+ ### ✅ Completed:
6
+ - [x] Created ML Trade Filter infrastructure
7
+ - [x] Trained initial Random Forest model (v1)
8
+ - [x] Integrated ML Filter into ht-trader
9
+ - [x] Model successfully loaded and filtering signals
10
+
11
+ ### 📊 Current Model Performance:
12
+ - **Model:** Random Forest Classifier
13
+ - **Version:** trade_filter_v1.pkl
14
+ - **Training Accuracy:** 78.31%
15
+ - **Threshold:** 60% (Win Probability)
16
+ - **Status:** ACTIVE & FILTERING
17
+
18
+ ### 🎯 Current Behavior:
19
+ - Model is **rejecting most signals** (Win Prob ~20.91%)
20
+ - This is **GOOD** - protecting capital from low-quality trades
21
+ - System is conservative and safe
22
+
23
+ ---
24
+
25
+ ## 🔮 Next Steps - Choose One:
26
+
27
+ ### **Option 1: Wait & Retrain (RECOMMENDED) ⭐**
28
+
29
+ **Timeline:** 1-2 weeks
30
+
31
+ **Action Plan:**
32
+ 1. Let system run Paper Trading and collect more data
33
+ 2. Target: 500-1000 additional trades in QuestDB
34
+ 3. Retrain model with larger dataset
35
+ 4. Expected improvement: Accuracy 85%+
36
+
37
+ **How to execute:**
38
+ ```bash
39
+ # After 1-2 weeks, run:
40
+ docker exec ht-trader python train_trade_filter.py
41
+
42
+ # If accuracy > current (78.31%), model will auto-save
43
+ # Then restart ht-trader:
44
+ docker-compose restart ht-trader
45
+ ```
46
+
47
+ **Pros:**
48
+ - ✅ More data = Better model
49
+ - ✅ Higher accuracy
50
+ - ✅ More confident predictions
51
+ - ✅ Better Win Rate in production
52
+
53
+ **Cons:**
54
+ - ⏳ Need to wait 1-2 weeks
55
+ - ⏳ Fewer trades executed during this period
56
+
57
+ ---
58
+
59
+ ### **Option 2: Lower Threshold (QUICK FIX) ⚡**
60
+
61
+ **Timeline:** Immediate
62
+
63
+ **Action Plan:**
64
+ 1. Edit `services/ht-trader/optimization_utils.py`
65
+ 2. Change line in `MLTradeFilter.should_execute()`:
66
+ ```python
67
+ # Current:
68
+ return prob > 0.6, prob # Threshold 0.6 (60%)
69
+
70
+ # Change to:
71
+ return prob > 0.5, prob # Threshold 0.5 (50%)
72
+ ```
73
+ 3. Restart ht-trader
74
+
75
+ **How to execute:**
76
+ ```bash
77
+ # Edit the file (change threshold from 0.6 to 0.5)
78
+ # Then restart:
79
+ docker-compose restart ht-trader
80
+ ```
81
+
82
+ **Pros:**
83
+ - ✅ Immediate effect
84
+ - ✅ More trades will execute
85
+ - ✅ Faster data collection
86
+
87
+ **Cons:**
88
+ - ⚠️ Higher risk (allowing lower quality signals)
89
+ - ⚠️ Potentially lower Win Rate
90
+ - ⚠️ More losses possible
91
+
92
+ ---
93
+
94
+ ## 📈 Recommended Strategy:
95
+
96
+ ### **Phase 1: Current (Week 1-2)**
97
+ - Keep threshold at 60%
98
+ - Let model filter aggressively
99
+ - Collect high-quality data
100
+ - Monitor Win Rate on executed trades
101
+
102
+ ### **Phase 2: Evaluation (Week 2)**
103
+ - Check total trades executed
104
+ - If < 50 trades/week → Consider lowering threshold to 55%
105
+ - If ≥ 50 trades/week → Keep at 60% and wait for retrain
106
+
107
+ ### **Phase 3: Retrain (Week 3-4)**
108
+ - Run training script
109
+ - If new accuracy > 80% → Deploy new model
110
+ - If accuracy < 80% → Collect more data
111
+
112
+ ### **Phase 4: Production Ready (Month 2)**
113
+ - Model accuracy > 85%
114
+ - Win Rate > 60% on paper trading
115
+ - Ready to move to Testnet
116
+
117
+ ---
118
+
119
+ ## 🔧 Quick Commands Reference:
120
+
121
+ ### Check current model performance:
122
+ ```bash
123
+ # View recent rejections
124
+ docker logs ht-trader --tail 50 | grep "ML Filter"
125
+
126
+ # Count rejections vs executions
127
+ docker logs ht-trader --tail 1000 | grep -c "rejected by ML Filter"
128
+ docker logs ht-trader --tail 1000 | grep -c "EXECUTED"
129
+ ```
130
+
131
+ ### Retrain model:
132
+ ```bash
133
+ docker exec ht-trader python train_trade_filter.py
134
+ docker-compose restart ht-trader
135
+ ```
136
+
137
+ ### Check data availability:
138
+ ```bash
139
+ # Query QuestDB for trade count
140
+ curl "http://localhost:9000/exec?query=SELECT count() FROM closed_positions WHERE environment='PAPER'"
141
+ ```
142
+
143
+ ---
144
+
145
+ ## 📊 Success Metrics:
146
+
147
+ ### Before moving to Testnet:
148
+ - [ ] ML Model Accuracy > 85%
149
+ - [ ] Paper Trading Win Rate > 60%
150
+ - [ ] Total trades collected > 1000
151
+ - [ ] Model stable for 2+ weeks
152
+ - [ ] Max Drawdown < 15%
153
+
154
+ ### Before moving to Live:
155
+ - [ ] Testnet Win Rate > 55%
156
+ - [ ] Sharpe Ratio > 1.5
157
+ - [ ] ML Model validated on out-of-sample data
158
+ - [ ] All risk limits tested and working
159
+
160
+ ---
161
+
162
+ ## 🎯 Decision Point:
163
+
164
+ **Choose your path:**
165
+
166
+ 1. **Conservative (Recommended):** Wait 1-2 weeks, collect data, retrain → Higher quality
167
+ 2. **Aggressive:** Lower threshold to 50-55% now → More trades, faster iteration
168
+
169
+ **Current Recommendation:** **Option 1 (Wait & Retrain)**
170
+
171
+ The model is working correctly by filtering low-quality signals. This is protecting your capital. Be patient and let it collect good data for a better v2 model.
172
+
173
+ ---
174
+
175
+ ## 📝 Notes:
176
+
177
+ - Model was trained on: 2025-12-02
178
+ - Next retrain scheduled: 2025-12-16 (2 weeks)
179
+ - Current threshold: 60%
180
+ - Current behavior: Filtering aggressively (good!)
181
+
182
+ **Remember:** A conservative model that protects capital is better than an aggressive model that loses money. Quality > Quantity.
AIDocs/PAPER_TRADING_STRATEGY.md ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📄 Paper Trading Strategy - HyperTrade
2
+
3
+ ## 🎯 Mục tiêu của Paper Trading
4
+
5
+ Paper Trading là **giai đoạn quan trọng** trong quá trình phát triển HyperTrade, cho phép hệ thống:
6
+ - ✅ Test strategies với **ZERO RISK**
7
+ - ✅ Validate AI models trước khi dùng tiền thật
8
+ - ✅ Thu thập performance metrics để cải thiện
9
+ - ✅ Tạo training data cho AI models
10
+
11
+ ---
12
+
13
+ ## 🔄 Continuous Learning Loop
14
+
15
+ ```
16
+ ┌──────────┐
17
+ │ Brain │ ──► Generate Signal (AI-based)
18
+ └──────────┘
19
+
20
+
21
+ ┌──────────┐
22
+ │ Trader │ ──► Execute (Paper Trading)
23
+ └──────────┘
24
+
25
+
26
+ ┌──────────┐
27
+ │ QuestDB │ ──► Store Results (Trades, PnL, Features)
28
+ └──────────┘
29
+
30
+
31
+ ┌──────────┐
32
+ │ Manager │ ──► Analyze Performance (Win Rate, Sharpe)
33
+ └──────────┘
34
+
35
+
36
+ ┌──────────┐
37
+ │ Brain │ ──► Retrain Models (Improve) ──┐
38
+ └──────────┘ │
39
+ ▲ │
40
+ └──────────────────────────────────────┘
41
+ (Feedback Loop)
42
+ ```
43
+
44
+ ### Giải thích vòng lặp:
45
+
46
+ 1. **Brain** tạo trading signals dựa trên:
47
+ - Multi-Task CNN predictions
48
+ - Confidence scores
49
+ - Market regime detection
50
+ - Feature engineering
51
+
52
+ 2. **Trader** thực thi signals trong môi trường **simulated**:
53
+ - Sử dụng L2 orderbook data thực
54
+ - Tính slippage và fees
55
+ - Track positions và PnL
56
+
57
+ 3. **QuestDB** lưu trữ:
58
+ - Mỗi trade (entry/exit)
59
+ - Features tại thời điểm trade
60
+ - Market conditions
61
+ - Outcomes (profit/loss)
62
+
63
+ 4. **Manager** phân tích performance:
64
+ - Win Rate
65
+ - Average PnL per trade
66
+ - Sharpe Ratio
67
+ - Max Drawdown
68
+ - Trade frequency
69
+
70
+ 5. **Brain** học từ kết quả:
71
+ - Retrain models khi performance giảm
72
+ - Cải thiện signal quality
73
+ - Adapt to new market conditions
74
+ - Optimize confidence thresholds
75
+
76
+ 6. **Vòng lặp lại** - Hệ thống tự cải thiện liên tục
77
+
78
+ ---
79
+
80
+ ## 🎓 7 Mục tiêu chính của Paper Trading
81
+
82
+ ### 1. **Testing & Validation (Kiểm thử chiến lược)**
83
+
84
+ **Mục đích:**
85
+ - Test strategies mới mà không rủi ro mất tiền
86
+ - Validate AI models trước khi deploy live
87
+ - Measure performance metrics chính xác
88
+
89
+ **Ví dụ:**
90
+ ```
91
+ Signal: BUY BTC @ $95,000 (Confidence: 78%)
92
+ Paper Execution:
93
+ - Entry: $95,050 (simulated slippage)
94
+ - Exit: $96,200 (TP hit)
95
+ - PnL: +$1,150 (simulated)
96
+ - Fee: -$47.50
97
+ - Net: +$1,102.50
98
+
99
+ Result: WIN ✅
100
+ ```
101
+
102
+ **Metrics tracked:**
103
+ - Entry/Exit prices
104
+ - Slippage
105
+ - Fees
106
+ - Hold time
107
+ - PnL
108
+ - Win/Loss
109
+
110
+ ---
111
+
112
+ ### 2. **Continuous Learning Loop (Vòng lặp học tập)**
113
+
114
+ **Quy trình:**
115
+ ```
116
+ Day 1-7: Paper Trading → Collect 500 trades
117
+ Day 8: Analyze → Win Rate 58%, Sharpe 1.2
118
+ Day 9: Retrain Brain → Improve confidence scorer
119
+ Day 10-16: Paper Trading → Collect 500 more trades
120
+ Day 17: Analyze → Win Rate 62%, Sharpe 1.5 ✅
121
+ ```
122
+
123
+ **Feedback loop:**
124
+ - Mỗi trade = 1 training example
125
+ - Brain học từ successes và failures
126
+ - Continuous improvement without risk
127
+
128
+ ---
129
+
130
+ ### 3. **Safe Development Environment**
131
+
132
+ **Development Stages:**
133
+
134
+ #### **Stage 1: PAPER TRADING (Current)**
135
+ ```
136
+ ┌─────────────────────────────────────────┐
137
+ │ ✅ Test all features │
138
+ │ ✅ Debug issues │
139
+ │ ✅ Optimize strategies │
140
+ │ ✅ Collect performance data │
141
+ │ ✅ Zero financial risk │
142
+ │ ✅ Unlimited experimentation │
143
+ └─────────────────────────────────────────┘
144
+ ```
145
+
146
+ **Checklist trước khi chuyển Stage 2:**
147
+ - [ ] Win Rate > 55% (sustained over 1000+ trades)
148
+ - [ ] Sharpe Ratio > 1.5
149
+ - [ ] Max Drawdown < 15%
150
+ - [ ] No critical bugs
151
+ - [ ] Stable across market conditions
152
+ - [ ] Backtesting confirms results
153
+
154
+ #### **Stage 2: LIVE TRADING (Future)**
155
+ ```
156
+ ┌─────────────────────────────────────────┐
157
+ │ ⚠️ Real money at risk │
158
+ │ ⚠️ Need proven Win Rate > 55% │
159
+ │ ⚠️ Need stable Sharpe Ratio > 1.5 │
160
+ │ ⚠️ Start with small position sizes │
161
+ │ ⚠️ Gradual scaling based on performance │
162
+ └─────────────────────────────────────────┘
163
+ ```
164
+
165
+ ---
166
+
167
+ ### 4. **A/B Testing Strategies**
168
+
169
+ Paper Trading cho phép test **nhiều strategies song song** mà không rủi ro:
170
+
171
+ **Example Strategies:**
172
+
173
+ ```python
174
+ # Strategy A: Conservative
175
+ {
176
+ 'id': 'conservative_v1',
177
+ 'size_usd': 500,
178
+ 'take_profit_pct': 1.5,
179
+ 'stop_loss_pct': 0.8,
180
+ 'min_confidence': 0.7
181
+ }
182
+
183
+ # Strategy B: Aggressive
184
+ {
185
+ 'id': 'aggressive_v1',
186
+ 'size_usd': 1500,
187
+ 'take_profit_pct': 3.0,
188
+ 'stop_loss_pct': 1.5,
189
+ 'min_confidence': 0.6
190
+ }
191
+
192
+ # Strategy C: AI-Optimized
193
+ {
194
+ 'id': 'ai_dynamic_v1',
195
+ 'size_usd': 'dynamic', # Based on confidence
196
+ 'take_profit_pct': 'ai_predicted',
197
+ 'stop_loss_pct': 'ai_predicted',
198
+ 'min_confidence': 0.65
199
+ }
200
+ ```
201
+
202
+ **Performance Comparison (After 1000 trades each):**
203
+
204
+ | Strategy | Win Rate | Avg PnL | Sharpe | Max DD | Winner? |
205
+ |----------|----------|---------|--------|--------|---------|
206
+ | Conservative | 58% | $8.50 | 1.2 | -6% | ❌ |
207
+ | Aggressive | 52% | $15.20 | 0.9 | -18% | ❌ |
208
+ | AI-Optimized | **65%** | **$12.80** | **1.8** | **-8%** | ✅ |
209
+
210
+ **Decision:** Deploy AI-Optimized strategy to live trading
211
+
212
+ ---
213
+
214
+ ### 5. **Risk-Free Model Training Data**
215
+
216
+ **Data Generated by Paper Trading:**
217
+
218
+ ```
219
+ ┌────────────────────────────────────────┐
220
+ │ Per Trade Data: │
221
+ │ • Entry/Exit timestamps │
222
+ │ • Entry/Exit prices │
223
+ │ • Position size │
224
+ │ • Features at trade time (50+ dims) │
225
+ │ • Market regime │
226
+ │ • Confidence score │
227
+ │ • PnL outcome │
228
+ │ • Win/Loss label │
229
+ │ • Slippage │
230
+ │ • Fees │
231
+ └────────────────────────────────────────┘
232
+
233
+
234
+ ┌────────────────────────────────────────┐
235
+ │ Brain uses this data to: │
236
+ │ ✅ Train Confidence Scorer │
237
+ │ - Learn which features predict wins │
238
+ │ - Calibrate confidence thresholds │
239
+ │ │
240
+ │ ✅ Improve Signal Quality │
241
+ │ - Filter low-quality signals │
242
+ │ - Boost high-quality patterns │
243
+ │ │
244
+ │ ✅ Learn optimal entry/exit timing │
245
+ │ - When to enter positions │
246
+ │ - When to take profit │
247
+ │ - When to cut losses │
248
+ │ │
249
+ │ ✅ Adapt to market regimes │
250
+ │ - Bull/Bear/Sideways detection │
251
+ │ - Regime-specific strategies │
252
+ └────────────────────────────────────────┘
253
+ ```
254
+
255
+ **Example Training Workflow:**
256
+
257
+ ```python
258
+ # 1. Collect Paper Trading Data
259
+ trades_df = query_questdb("""
260
+ SELECT * FROM trades_executed
261
+ WHERE environment = 'PAPER'
262
+ AND timestamp > dateadd('d', -30, now())
263
+ """)
264
+
265
+ # 2. Extract Features & Labels
266
+ X = trades_df[feature_columns] # 50+ features
267
+ y = trades_df['pnl'] > 0 # Win/Loss label
268
+
269
+ # 3. Retrain Confidence Scorer
270
+ confidence_model.fit(X, y)
271
+
272
+ # 4. Evaluate
273
+ new_accuracy = confidence_model.score(X_test, y_test)
274
+ print(f"Confidence Model Accuracy: {new_accuracy:.2%}")
275
+
276
+ # 5. Deploy if improved
277
+ if new_accuracy > previous_accuracy:
278
+ save_model(confidence_model, 'confidence_v2.pkl')
279
+ deploy_to_brain()
280
+ ```
281
+
282
+ ---
283
+
284
+ ### 6. **Performance Metrics Collection**
285
+
286
+ **Key Metrics Tracked:**
287
+
288
+ | Metric | Formula | Target | Purpose |
289
+ |--------|---------|--------|---------|
290
+ | **Win Rate** | Wins / Total Trades | > 55% | Basic profitability |
291
+ | **Avg PnL** | Total PnL / Total Trades | > $10 | Per-trade profit |
292
+ | **Sharpe Ratio** | (Return - RiskFree) / StdDev | > 1.5 | Risk-adjusted returns |
293
+ | **Max Drawdown** | Max(Peak - Trough) / Peak | < 15% | Worst losing streak |
294
+ | **Profit Factor** | Gross Profit / Gross Loss | > 1.5 | Profitability ratio |
295
+ | **Trade Frequency** | Trades / Day | 10-50 | Signal quality |
296
+ | **Avg Hold Time** | Exit Time - Entry Time | < 2h | Capital efficiency |
297
+
298
+ **Example Dashboard:**
299
+
300
+ ```
301
+ 📊 Paper Trading Performance (Last 30 Days)
302
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
303
+ Total Trades: 1,247
304
+ Winning Trades: 777 (62.3%) ✅
305
+ Losing Trades: 470 (37.7%)
306
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
307
+ Avg Win: +$18.50
308
+ Avg Loss: -$12.30
309
+ Avg PnL/Trade: +$12.50 ✅
310
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
311
+ Total PnL: +$15,587 (simulated)
312
+ Gross Profit: +$14,380
313
+ Gross Loss: -$5,781
314
+ Profit Factor: 2.49 ✅
315
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
316
+ Sharpe Ratio: 1.65 ✅
317
+ Max Drawdown: -8.2% ✅
318
+ Avg Hold Time: 1.2 hours
319
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
320
+ Status: READY FOR LIVE TRADING ✅
321
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
322
+ ```
323
+
324
+ ---
325
+
326
+ ### 7. **Debugging & Monitoring**
327
+
328
+ **Common Issues Detected:**
329
+
330
+ #### **Issue 1: High Slippage**
331
+ ```
332
+ Signal: BUY BTC @ $95,000
333
+ Expected Fill: $95,000
334
+ Actual Fill: $95,500 (0.53% slippage) ⚠️
335
+
336
+ Root Cause: Orderbook too thin
337
+ Fix: Improve orderbook simulation, add liquidity checks
338
+ ```
339
+
340
+ #### **Issue 2: Risk Limit Rejections**
341
+ ```
342
+ Signal: SELL ETH @ $3,500
343
+ Result: REJECTED (max exposure exceeded)
344
+
345
+ Root Cause: Too many open positions
346
+ Fix: Adjust risk parameters, implement position limits
347
+ ```
348
+
349
+ #### **Issue 3: Signal Quality**
350
+ ```
351
+ Signals Generated: 500/day
352
+ Trades Executed: 50/day (10% acceptance)
353
+
354
+ Root Cause: Low confidence threshold
355
+ Fix: Increase min_confidence from 0.5 to 0.65
356
+ ```
357
+
358
+ #### **Issue 4: Model Drift**
359
+ ```
360
+ Week 1: Win Rate 65%
361
+ Week 2: Win Rate 62%
362
+ Week 3: Win Rate 58% ⚠️
363
+ Week 4: Win Rate 54% 🚨
364
+
365
+ Root Cause: Market regime changed
366
+ Fix: Trigger model retraining, update features
367
+ ```
368
+
369
+ ---
370
+
371
+ ## 📈 Current Status
372
+
373
+ **HyperTrade Development Progress:**
374
+
375
+ ```
376
+ [████████████░░░░░░░░] 60% Complete
377
+
378
+ Stage 1: Paper Trading ──► Stage 2: Live Trading
379
+ (CURRENT) (FUTURE)
380
+ ```
381
+
382
+ **Achievements:**
383
+ - ✅ Paper trading infrastructure complete
384
+ - ✅ L2 orderbook simulation working
385
+ - ✅ Position management implemented
386
+ - ✅ Risk management active
387
+ - ✅ Performance tracking in place
388
+ - ✅ AI feedback loop operational
389
+
390
+ **Next Steps:**
391
+ - [ ] Achieve Win Rate > 55% sustained
392
+ - [ ] Achieve Sharpe Ratio > 1.5
393
+ - [ ] Collect 5,000+ paper trades
394
+ - [ ] Validate across market conditions
395
+ - [ ] Implement live trading safeguards
396
+ - [ ] Start with $100 live positions
397
+
398
+ ---
399
+
400
+ ### 8. **Machine Learning Integration (Trade Filter)**
401
+
402
+ **Mục tiêu:**
403
+ Tạo một lớp bảo vệ thứ hai (Gatekeeper) sử dụng Machine Learning để lọc các tín hiệu trading, chỉ cho phép thực thi các lệnh có xác suất thắng cao nhất.
404
+
405
+ **Quy trình:**
406
+ 1. **Data Collection (Hiện tại):** `ht-trader` chạy Paper Trading và lưu trữ mọi tín hiệu, ngữ cảnh thị trường (features), và kết quả (PnL) vào QuestDB.
407
+ 2. **Model Training:** Khi đủ dữ liệu (ví dụ: >500 trades), hệ thống sẽ tự động train model `trade_filter_v1.pkl` (Random Forest hoặc XGBoost).
408
+ * **Input:** Signal confidence, Market volatility, Spread, Volume, Time of day, Recent win rate.
409
+ * **Output:** Xác suất thắng (Win Probability).
410
+ 3. **Integration:** `ht-trader` load model này và dùng nó để lọc tín hiệu trước khi execute.
411
+ * Nếu `Win Prob > Threshold` (ví dụ 0.65) → **EXECUTE**.
412
+ * Nếu thấp hơn → **REJECT**.
413
+
414
+ ---
415
+
416
+ ## 🚀 Transition to Live Trading
417
+
418
+ ### Prerequisites Checklist
419
+
420
+ #### **Performance Requirements:**
421
+ - [ ] Win Rate > 55% (over 1000+ trades)
422
+ - [ ] Sharpe Ratio > 1.5
423
+ - [ ] Max Drawdown < 15%
424
+ - [ ] Profit Factor > 1.5
425
+ - [ ] Stable performance for 30+ days
426
+
427
+ #### **ML Readiness:**
428
+ - [ ] Collected > 1000 labeled trades in QuestDB
429
+ - [ ] Trained `trade_filter` model with Accuracy > 60%
430
+ - [ ] Validated model on out-of-sample data
431
+
432
+ #### **Technical Requirements:**
433
+ - [ ] No critical bugs
434
+ - [ ] Error handling robust
435
+ - [ ] Monitoring dashboards ready
436
+ - [ ] Alert system functional (Telegram)
437
+ - [ ] Emergency stop mechanism tested
438
+
439
+ #### **Risk Management:**
440
+ - [ ] Position size limits configured
441
+ - [ ] Max exposure limits set
442
+ - [ ] Stop loss logic validated
443
+ - [ ] Circuit breakers implemented
444
+ - [ ] Manual override available
445
+
446
+ ### Gradual Rollout Plan
447
+
448
+ **Phase 1: Paper Trading + ML Training (Current)**
449
+ ```
450
+ Goal: Collect data, train ML Trade Filter, optimize strategies.
451
+ Environment: Simulated (Paper)
452
+ ```
453
+
454
+ **Phase 2: Testnet Validation (Next)**
455
+ ```
456
+ Goal: Validate execution logic, slippage, and ML model performance in a realistic environment.
457
+ Environment: Hyperliquid Testnet (Real matching, fake money)
458
+ Condition to start: ML Model Accuracy > 60% on Paper data.
459
+ ```
460
+
461
+ **Phase 3: Micro Live (Week 1-2 of Live)**
462
+ ```
463
+ Position Size: $100/trade
464
+ Max Exposure: $500
465
+ Coins: BTC only
466
+ Goal: Validate live execution with real money.
467
+ ```
468
+
469
+ **Phase 4: Small Live (Week 3-4)**
470
+ ```
471
+ Position Size: $250/trade
472
+ Max Exposure: $1,500
473
+ Coins: BTC, ETH
474
+ Goal: Test multi-asset trading.
475
+ ```
476
+
477
+ **Phase 5: Full Live (Month 2+)**
478
+ ```
479
+ Position Size: $1,000/trade
480
+ Max Exposure: $5,000
481
+ Coins: All supported
482
+ Goal: Production trading.
483
+ ```
484
+
485
+ ---
486
+
487
+ ## 💡 Best Practices
488
+
489
+ ### 1. **Always Validate in Paper First**
490
+ - Never deploy untested strategies to live
491
+ - Run paper trading for minimum 1000 trades
492
+ - Verify performance across market conditions
493
+
494
+ ### 2. **Monitor Continuously**
495
+ - Check Win Rate daily
496
+ - Track Sharpe Ratio weekly
497
+ - Review Max Drawdown monthly
498
+ - Set up alerts for anomalies
499
+
500
+ ### 3. **Iterate Based on Data**
501
+ - Retrain models when performance drops
502
+ - A/B test new strategies in paper
503
+ - Use feedback loop to improve
504
+
505
+ ### 4. **Risk Management is Key**
506
+ - Start small in live trading
507
+ - Scale gradually based on results
508
+ - Always have stop losses
509
+ - Never risk more than you can afford to lose
510
+
511
+ ### 5. **Document Everything**
512
+ - Log all trades
513
+ - Record all model versions
514
+ - Track all configuration changes
515
+ - Maintain audit trail
516
+
517
+ ---
518
+
519
+ ## 📊 Summary
520
+
521
+ **Paper Trading in HyperTrade serves as:**
522
+
523
+ 1. **🧪 Testing Lab**: Test strategies without risk
524
+ 2. **📚 Training Ground**: Generate data for AI models
525
+ 3. **📊 Analytics Engine**: Collect performance metrics
526
+ 4. **🔍 Debug Tool**: Find bugs before live trading
527
+ 5. **🎯 Optimization Platform**: A/B test strategies
528
+ 6. **🛡️ Safety Net**: Validate before using real money
529
+ 7. **🔄 Learning Loop**: Continuous improvement cycle
530
+
531
+ **The goal:** Build confidence in the system through extensive paper trading before risking real capital.
532
+
533
+ **Current focus:** Optimize Win Rate and Sharpe Ratio to meet live trading thresholds.
534
+
535
+ **Timeline:** Estimated 2-3 months of paper trading before live deployment.
AIDocs/SIGNAL_GENERATION_v3.md ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HyperTrade Signal Generation & Risk Management v3.0
2
+
3
+ ## 1. Overview
4
+ This document details the advanced signal generation pipeline and risk management protocols implemented in the HyperTrade system (v3.0). The system transitions from a naive "trigger-happy" approach to a sophisticated "Sniper" logic, utilizing Causal AI, Regime Detection, and Multi-Layer Filters.
5
+
6
+ ## 2. Signal Generation Pipeline
7
+
8
+ The pipeline consists of 6 distinct stages, ensuring only high-quality, high-probability signals are executed.
9
+
10
+ ### Stage 1: Data Ingestion (Input)
11
+ - **Source**: `ht-feature-engine` -> Redis Channel `features_updates`
12
+ - **Frequency**: Real-time (Tick-level)
13
+ - **Key Features**:
14
+ - `Mid-Price`: Current market midpoint.
15
+ - `CVD (Cumulative Volume Delta)`: Aggressive buying/selling pressure.
16
+ - `Imbalance`: Orderbook depth skew (Bid vs Ask).
17
+ - `OFI (Order Flow Imbalance)`: Net order flow velocity.
18
+ - `Volatility`: Standard deviation of returns.
19
+
20
+ ### Stage 2: Regime Detection (Context)
21
+ Before processing features, the system identifies the current market state using a hybrid AI approach:
22
+
23
+ 1. **Primary: TRM (Tiny Recursive Model)**
24
+ * A recursive neural network that "thinks" for multiple steps to infer hidden market states.
25
+ * Input: `[Volatility, Imbalance, CVD, Spread, Momentum]`
26
+ * Output: `TREND`, `SIDEWAY`, `VOLATILE`
27
+
28
+ 2. **Fallback: Hidden Markov Model (HMM)**
29
+ * Used if TRM is uncertain or unavailable.
30
+ * Unsupervised learning on `Returns`, `Volatility`, `Imbalance`.
31
+
32
+ **Regimes**:
33
+ - **TREND**: Directional movement. Ideal for momentum strategies.
34
+ - **SIDEWAY**: Range-bound. High noise, requires strict filtering.
35
+ - **VOLATILE**: High risk/uncertainty. Requires maximum safety margins.
36
+
37
+ ### Stage 3: Causal-Weighted Confidence Scoring & AI Verification
38
+ The system combines Causal AI with a Recursive Reasoning Engine to score signals.
39
+
40
+ #### A. Causal Weighting
41
+ **Causal AI** dynamically weights indicators based on their proven impact:
42
+ - **Causal Discovery**: Background process identifies causal links (e.g., `CVD -> Price`).
43
+ - **Dynamic Weighting**:
44
+ - If `CVD` is a proven driver -> Weight boosted by **1.5x**.
45
+ - If `Imbalance` has weak causality -> Weight reduced by **0.5x**.
46
+
47
+ **Heuristic Formula**:
48
+ ```python
49
+ Base_Score = (Norm_CVD * W_CVD) + (Norm_Imbalance * W_Imb) + (Norm_OFI * W_OFI)
50
+ ```
51
+
52
+ #### B. TRM Verification (AI Score)
53
+ The **TRM (Tiny Recursive Model)** verifies the signal by predicting the probability of success based on the full feature set.
54
+ - **Input**: Full feature vector.
55
+ - **Output**: Probability (0.0 - 1.0).
56
+ - **Role**: Contributes 20% to the final ensemble confidence score.
57
+
58
+ ### Stage 4: Advanced Filtering (The "Sniper" Scope)
59
+ This is the core of v3.0, filtering out false positives.
60
+
61
+ #### A. Confluence Check (Consensus)
62
+ - **Logic**: `CVD` and `Imbalance` MUST align in direction.
63
+ - **Action**: If Divergence detected (e.g., CVD Buy + Imbalance Sell) -> **Confidence penalized by 50%**.
64
+ - **Why**: Prevents falling into liquidity traps or absorption walls.
65
+
66
+ #### B. Regime-Based Thresholds
67
+ - **TREND**: Threshold > **0.60** (Aggressive)
68
+ - **SIDEWAY**: Threshold > **0.75** (Conservative)
69
+ - **VOLATILE**: Threshold > **0.80** (Safety First)
70
+
71
+ #### C. Volatility Filter
72
+ - **Logic**: If `Volatility < 0.0001` (Dead Market) -> **Signal Suppressed**.
73
+ - **Why**: Prevents trading in stagnation where spread costs exceed potential profit.
74
+
75
+ ### Stage 5: Cooldown (Rate Limiting)
76
+ - **Mechanism**: **5-second cooldown** per asset after a signal is generated.
77
+ - **Why**: Prevents order spamming, reduces API load, and protects Margin.
78
+
79
+ ### Stage 6: Execution (Output)
80
+ - **Final Output**: JSON payload sent to `ht-nautilus`.
81
+ ```json
82
+ {
83
+ "coin": "BTC",
84
+ "signal": "BUY",
85
+ "confidence": 0.78,
86
+ "regime": "TREND",
87
+ "causal_score": 0.15
88
+ }
89
+ ```
90
+
91
+ ## 3. Risk Management Layers
92
+
93
+ ### Layer 1: Strategy Level (ht-brain)
94
+ - **Cooldowns**: Prevent spam.
95
+ - **Confidence Thresholds**: Ensure high probability.
96
+
97
+ ### Layer 2: Execution Level (ht-nautilus)
98
+ - **Rate Limiter**: Max 1 order per 2 seconds (Internal safety net).
99
+ - **Spread Filter**: Rejects orders if Bid-Ask spread is too wide.
100
+ - **Position Sizing**: Dynamic sizing based on volatility (planned).
101
+
102
+ ### Layer 3: System Level (Safety Circuit Breaker)
103
+ - **SafetyLogHandler**: Monitors system logs in real-time.
104
+ - **Trigger**: 3 consecutive Critical Errors (e.g., Margin Insufficient, API Disconnect).
105
+ - **Action**:
106
+ 1. **Immediate Shutdown** of Trading Node.
107
+ 2. **Telegram Alert** to admin.
108
+ 3. **Cancel All Orders** (via cleanup script).
109
+
110
+ ## 4. Summary of v3.0 Improvements
111
+
112
+ | Feature | Old Version (v2) | New Version (v3.0) |
113
+ | :--- | :--- | :--- |
114
+ | **Trigger Logic** | Naive Sum of Scores | **Causal-Weighted + Confluence** |
115
+ | **Thresholds** | Static (0.65) | **Dynamic (Regime-based)** |
116
+ | **Filters** | None | **Volatility & Divergence Checks** |
117
+ | **Rate Limit** | None (Spam prone) | **5s Cooldown (Source) + 2s (Sink)** |
118
+ | **Safety** | Manual Monitoring | **Automated Circuit Breaker** |
119
+
120
+ ---
121
+ *Generated by HyperTrade AI Assistant*
AIDocs/SYSTEM_WORKFLOW.md ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🔄 LUỒNG CÔNG VIỆC HYPERTRADE - TỪ DATA ĐẾN EXECUTION
2
+
3
+ > **Tài liệu này mô tả chi tiết luồng xử lý dữ liệu từ khi thu thập đến khi thực thi lệnh giao dịch trong hệ thống HyperTrade.**
4
+
5
+ ---
6
+
7
+ ## 📊 GIAI ĐOẠN 1: THU THẬP DỮ LIỆU (DATA COLLECTION)
8
+
9
+ ### Service: `ht-collector`
10
+ **Input:** Kết nối WebSocket với Hyperliquid API
11
+
12
+ **Process:**
13
+ - Lắng nghe **L2 Orderbook** (20 levels bid/ask) real-time
14
+ - Lắng nghe **Trades** (giao dịch đã khớp)
15
+ - Lắng nghe **Candles** (nến 1 phút)
16
+
17
+ **Output:**
18
+ - Publish vào Redis channel `market_data_updates` (JSON format)
19
+ - Ghi vào QuestDB table `l2_orderbook_raw` (lưu trữ lâu dài)
20
+
21
+ **Code Location:** `services/ht-collector/main.py`
22
+
23
+ ---
24
+
25
+ ## 🔧 GIAI ĐOẠN 2: FEATURE ENGINEERING
26
+
27
+ ### Service: `ht-feature-engine`
28
+ **Input:** Subscribe Redis channel `market_data_updates`
29
+
30
+ **Process:** Tính toán các chỉ số kỹ thuật từ L2 Orderbook:
31
+ - **Mid Price** = (Best Bid + Best Ask) / 2
32
+ - **Weighted Mid Price** (WMP)
33
+ - **Spread** = Best Ask - Best Bid
34
+ - **Volatility** = Standard Deviation của mid prices (60s rolling window)
35
+ - **Order Book Imbalance (OBI)** - Top 1, 5, 10, 20 levels
36
+ - **Order Flow Imbalance (OFI)** - Thay đổi giữa 2 snapshots
37
+ - **Cumulative Volume Delta (CVD)** - Buy volume - Sell volume
38
+
39
+ **Output:**
40
+ - Publish vào Redis channel `features_updates` (JSON với ~17 features)
41
+ - Ghi vào QuestDB table `features_microstructure`
42
+
43
+ **Code Location:** `services/ht-feature-engine/main.py`
44
+
45
+ **Features Dictionary Example:**
46
+ ```json
47
+ {
48
+ "mid_price": 96500.0,
49
+ "weighted_mid_price": 96501.2,
50
+ "bid_ask_spread": 0.1,
51
+ "current_volatility": 0.0234,
52
+ "imbalance_ratio": 0.15,
53
+ "imbalance_ratio_top5": 0.12,
54
+ "imbalance_ratio_top10": 0.10,
55
+ "imbalance_ratio_top20": 0.08,
56
+ "ofi": 125.5,
57
+ "cvd": 1500.0,
58
+ "buy_volume_1s": 850.0,
59
+ "sell_volume_1s": 650.0,
60
+ "bid_qty_at_level_1": 10.5,
61
+ "ask_qty_at_level_1": 9.2,
62
+ "bid_qty_top_5": 52.3,
63
+ "ask_qty_top_5": 48.7,
64
+ "obs_count": 20
65
+ }
66
+ ```
67
+
68
+ ---
69
+
70
+ ## 🧠 GIAI ĐOẠN 3: AI SIGNAL GENERATION
71
+
72
+ ### Service: `ht-brain`
73
+ **Input:** Subscribe Redis channel `features_updates`
74
+
75
+ **Process:** Multi-Model Ensemble AI System
76
+
77
+ ### 3.1. Regime Detection (Phát hiện trạng thái thị trường)
78
+ **Model:** Hidden Markov Model (HMM)
79
+ **Input:** OBI, OFI, CVD time series
80
+ **Output:** Market regime classification
81
+ - `TRENDING_UP`: Thị trường tăng mạnh
82
+ - `TRENDING_DOWN`: Thị trường giảm mạnh
83
+ - `SIDEWAY`: Thị trường đi ngang
84
+
85
+ **Code:** `services/ht-brain/regime_detector.py`
86
+
87
+ ### 3.2. Confidence Scoring (Tính độ tin cậy)
88
+ **Model:** Logistic Regression
89
+ **Input:** OBI, OFI, CVD, Spread, Volatility
90
+ **Output:** Confidence score (0.0 - 1.0)
91
+
92
+ **Code:** `services/ht-brain/confidence_scorer.py`
93
+
94
+ ### 3.3. AI Magnitude Prediction (Dự đoán biên độ giá)
95
+ **Model:** Multi-Task CNN (DeepLOB Architecture)
96
+ **Input:** L2 Orderbook snapshot (20 levels × 2 sides)
97
+ **Output:**
98
+ - Direction probabilities: [P(down), P(stationary), P(up)]
99
+ - Magnitude predictions for multiple horizons: [10s, 30s, 1m, 2m, 5m]
100
+
101
+ **Code:** `services/ht-brain/models/deep_lob_multi_task.py`
102
+
103
+ **Training:** Được train bởi `ht-ai-inference` service
104
+
105
+ ### 3.4. Causal Discovery (Tìm mối quan hệ nhân quả)
106
+ **Algorithm:** PC Algorithm (Peter-Clark)
107
+ **Purpose:** Xác định features nào có ảnh hưởng nhân quả đến price movement
108
+ **Output:** Causal graph, feature importance scores
109
+
110
+ **Code:** `services/ht-brain/causal_discovery.py`
111
+
112
+ ### 3.5. Final Decision Logic (Logic quyết định cuối cùng)
113
+
114
+ **Pseudo Code:**
115
+ ```python
116
+ # 1. Get predictions from all models
117
+ regime = hmm_model.predict(features)
118
+ confidence = logistic_model.predict_proba(features)
119
+ cnn_direction, cnn_magnitude = cnn_model.predict(orderbook)
120
+
121
+ # 2. Combine predictions
122
+ if cnn_direction == UP and confidence > 0.6 and regime != SIDEWAY:
123
+ signal = "BUY"
124
+ final_confidence = confidence * 0.7 + cnn_confidence * 0.3
125
+ elif cnn_direction == DOWN and confidence > 0.6 and regime != SIDEWAY:
126
+ signal = "SELL"
127
+ final_confidence = confidence * 0.7 + cnn_confidence * 0.3
128
+ else:
129
+ signal = "HOLD"
130
+ final_confidence = confidence
131
+
132
+ # 3. Feature Normalization & Injection
133
+ features['volatility'] = features.get('current_volatility', 0.002)
134
+ features['spread'] = features.get('bid_ask_spread', 0.0002)
135
+
136
+ # 4. Create Signal Message
137
+ signal_msg = {
138
+ "coin": coin,
139
+ "signal": signal,
140
+ "confidence": final_confidence,
141
+ "price": features['mid_price'],
142
+ "features": features, # Full features dict
143
+ "regime": regime,
144
+ "predicted_magnitude": cnn_magnitude
145
+ }
146
+ ```
147
+
148
+ **Code Location:** `services/ht-brain/main.py` - method `analyze()`
149
+
150
+ ### 3.6. Signal Publishing & Persistence
151
+
152
+ **Output 1: Redis Channel `signals`**
153
+ ```python
154
+ redis_client.publish("signals", signal_msg)
155
+ ```
156
+ - Format: JSON (SignalMessage schema)
157
+ - Subscribers: `ht-trader`, `ht-nautilus-paper`
158
+
159
+ **Output 2: QuestDB Table `signals_generated`** ✅
160
+ ```python
161
+ db_sender.sender.row(
162
+ 'signals_generated',
163
+ symbols={
164
+ 'coin': coin,
165
+ 'signal': signal,
166
+ 'regime': regime
167
+ },
168
+ columns={
169
+ 'confidence': confidence,
170
+ 'price': price,
171
+ 'predicted_magnitude': cnn_magnitude,
172
+ 'predicted_price': predicted_price,
173
+ 'reason': reason_str,
174
+ 'features_json': json.dumps(features)
175
+ },
176
+ at=TimestampNanos.now()
177
+ )
178
+ ```
179
+
180
+ **Lưu ý quan trọng:**
181
+ - ✅ **CÓ LƯU VÀO DATABASE** (QuestDB table `signals_generated`)
182
+ - Mục đích: Tracking, backtesting, phân tích hiệu quả AI
183
+ - Chỉ lưu signals BUY/SELL (không lưu HOLD để tránh spam)
184
+
185
+ **Code Location:** `services/ht-brain/main.py` - method `send_signal()`
186
+
187
+ ---
188
+
189
+ ## 🤖 AI MODELS SUMMARY (Trả lời câu hỏi 3)
190
+
191
+ ### Signal được tạo bởi **ENSEMBLE của 4 AI Models:**
192
+
193
+ | Model | Vai trò | Output | Training Frequency |
194
+ |-------|---------|--------|-------------------|
195
+ | **HMM** | Regime Detection | TRENDING_UP/DOWN/SIDEWAY | Mỗi 15 phút |
196
+ | **Logistic Regression** | Confidence Scoring | 0.0 - 1.0 | Mỗi 30 phút |
197
+ | **Multi-Task CNN** | Direction + Magnitude | [Direction, Magnitude] | Mỗi 1 giờ |
198
+ | **PC Algorithm** | Feature Selection | Causal Graph | Mỗi 6 giờ |
199
+
200
+ **Final Signal = Weighted Combination:**
201
+ ```
202
+ Signal Direction: CNN (primary) + Regime (filter)
203
+ Confidence: 70% Logistic + 30% CNN
204
+ Magnitude: CNN prediction
205
+ ```
206
+
207
+ **Training Data Source:** QuestDB tables
208
+ - `features_microstructure`: Input features
209
+ - `trades_executed`: Labels (win/loss)
210
+ - `l2_orderbook_raw`: Raw data for CNN
211
+
212
+ **Training Orchestration:** `ht-manager` service (scheduled jobs)
213
+
214
+ ---
215
+
216
+ ## 🛡️ GIAI ĐOẠN 4: ML TRADE FILTER (Lọc tín hiệu)
217
+
218
+ ### Service: `ht-trader` (MLTradeFilter component)
219
+ **Input:** Signal từ Redis channel `signals`
220
+
221
+ **Process:**
222
+ 1. Extract features từ signal:
223
+ - `confidence` (từ ht-brain)
224
+ - `volatility` (từ ht-feature-engine, normalized by ht-brain)
225
+ - `spread` (từ ht-feature-engine, normalized by ht-brain)
226
+ - `hour` (thời gian trong ngày - 0-23)
227
+
228
+ 2. ML Model (Random Forest Classifier):
229
+ ```python
230
+ features = [confidence, volatility, spread, hour]
231
+ win_probability = model.predict_proba(features)[0][1]
232
+
233
+ if win_probability > 0.6: # Threshold
234
+ return ACCEPT
235
+ else:
236
+ return REJECT
237
+ ```
238
+
239
+ **Model Training:**
240
+ - Script: `services/ht-trader/train_trade_filter.py`
241
+ - Data Source: QuestDB table `trades_executed`
242
+ - Features: `confidence`, `volatility`, `spread`, `hour`
243
+ - Label: `pnl > 0` (win) or `pnl <= 0` (loss)
244
+ - Algorithm: Random Forest (100 trees)
245
+ - Training Frequency: **Mỗi 24 giờ** (automated by `ht-manager`)
246
+
247
+ **Output:** Decision: ACCEPT hoặc REJECT signal
248
+
249
+ **Code Location:** `services/ht-trader/optimization_utils.py` - class `MLTradeFilter`
250
+
251
+ ---
252
+
253
+ ## 💼 GIAI ĐOẠN 5A: PAPER TRADING (ht-trader)
254
+
255
+ ### Service: `ht-trader` (Mode: PAPER)
256
+ **Input:** Signals đã qua ML Filter
257
+
258
+ **Process:**
259
+
260
+ ### 5A.1. Risk Management
261
+ **Component:** `RiskManager`
262
+ **Checks:**
263
+ - Max position size per coin (default: $5000)
264
+ - Max total exposure across all positions (default: $15000)
265
+ - Max daily loss (default: $500)
266
+ - Max hourly loss (default: $200)
267
+ - Max drawdown percentage (default: 15%)
268
+
269
+ **Code:** `services/ht-trader/risk_manager.py`
270
+
271
+ ### 5A.2. Position Sizing
272
+ **Algorithm:** Dynamic sizing dựa trên confidence
273
+ ```python
274
+ base_size_usd = 1000 # Base position size
275
+ confidence_multiplier = 1 + (confidence - 0.6) * 2 # Range: 1.0 - 1.8
276
+ optimal_size = base_size_usd * confidence_multiplier
277
+ ```
278
+
279
+ **Code:** `services/ht-trader/main.py` - method `calculate_optimal_size()`
280
+
281
+ ### 5A.3. Simulated Execution
282
+ **Process:**
283
+ 1. Lấy orderbook từ Redis (nếu có)
284
+ 2. Match order với orderbook levels:
285
+ - BUY: Match với asks (ascending)
286
+ - SELL: Match với bids (descending)
287
+ 3. Tính average fill price và slippage
288
+ 4. Tính fee: `size * price * 0.00035` (0.035% taker fee)
289
+
290
+ **Code:** `services/ht-trader/exchange_client.py` - method `place_order()`
291
+
292
+ ### 5A.4. Position Management
293
+ **Component:** `PositionManager`
294
+ **Features:**
295
+ - Track open positions (in-memory dict)
296
+ - Check exit conditions:
297
+ - Stop-loss (default: -2%)
298
+ - Take-profit (default: +5%)
299
+ - Reversal signal (opposite direction signal)
300
+ - Smart Exit với trailing stop (optional)
301
+
302
+ **Code:** `services/ht-trader/position_manager.py`
303
+
304
+ ### 5A.5. Logging & Metrics
305
+ **QuestDB Table:** `trades_executed`
306
+ ```sql
307
+ CREATE TABLE trades_executed (
308
+ trade_id SYMBOL,
309
+ coin SYMBOL,
310
+ side SYMBOL,
311
+ status SYMBOL,
312
+ exit_reason SYMBOL,
313
+ strategy_id SYMBOL,
314
+ environment SYMBOL,
315
+ entry_price DOUBLE,
316
+ exit_price DOUBLE,
317
+ quantity DOUBLE,
318
+ entry_fee DOUBLE,
319
+ exit_fee DOUBLE,
320
+ current_pnl DOUBLE,
321
+ features_json STRING,
322
+ notes STRING,
323
+ timestamp TIMESTAMP
324
+ ) TIMESTAMP(timestamp) PARTITION BY DAY;
325
+ ```
326
+
327
+ **Prometheus Metrics:**
328
+ - `trader_signals_received_total`
329
+ - `trader_signals_rejected_total`
330
+ - `trader_trades_executed_total`
331
+ - `trader_pnl_total`
332
+
333
+ **Code:** `services/ht-trader/main.py`
334
+
335
+ ---
336
+
337
+ ## 🚀 GIAI ĐOẠN 5B: TESTNET TRADING (ht-nautilus-paper)
338
+
339
+ ### Service: `ht-nautilus-paper` (Mode: TESTNET)
340
+ **Input:** Signals từ Redis channel `signals`
341
+
342
+ **Process:**
343
+
344
+ ### 5B.1. NautilusTrader Strategy
345
+ **Component:** `RedisSignalStrategy`
346
+ **Process:**
347
+ 1. Subscribe Redis channel `signals`
348
+ 2. Deserialize SignalMessage
349
+ 3. Convert sang Nautilus Order objects:
350
+ ```python
351
+ order = LimitOrder(
352
+ instrument_id=InstrumentId.from_str(f"{coin}.HL"),
353
+ order_side=OrderSide.BUY/SELL,
354
+ quantity=Quantity(size, precision=4),
355
+ price=Price(price, precision=2),
356
+ time_in_force=TimeInForce.IOC
357
+ )
358
+ ```
359
+ 4. Submit order qua ExecutionEngine
360
+
361
+ **Code:** `services/ht-nautilus/strategies/redis_listener.py`
362
+
363
+ ### 5B.2. Risk Engine (NautilusTrader Built-in)
364
+ **Features:**
365
+ - Max order submit rate: 5 orders/second
366
+ - Position limits (configurable)
367
+ - Exposure checks
368
+ - Pre-trade risk validation
369
+
370
+ **Config:** `services/ht-nautilus/paper.py` - `LiveRiskEngineConfig`
371
+
372
+ ### 5B.3. Execution Client (Hyperliquid Adapter)
373
+ **Component:** `HyperliquidExecutionClient`
374
+ **Connection:**
375
+ - API: Hyperliquid Testnet (`https://api.hyperliquid-testnet.xyz`)
376
+ - Wallet: `0x5bf7135bBd778f4c4A9D1e0C9dD79c1348968c4D`
377
+ - Private Key: From env var `HYPERLIQUID_TESTNET_PK`
378
+
379
+ **Process:**
380
+ 1. Receive `SubmitOrder` command từ ExecutionEngine
381
+ 2. Map Nautilus Order → Hyperliquid SDK format:
382
+ ```python
383
+ result = exchange.order(
384
+ name=coin,
385
+ is_buy=is_buy,
386
+ sz=size_rounded,
387
+ limit_px=price_rounded,
388
+ order_type={"limit": {"tif": "Ioc"}},
389
+ reduce_only=False
390
+ )
391
+ ```
392
+ 3. Parse response từ Hyperliquid:
393
+ - `filled`: Instant fill → Publish `FillReport`
394
+ - `resting`: Order đang chờ → Publish `OrderStatusReport(ACCEPTED)`
395
+ - Error: Publish `OrderStatusReport(REJECTED)`
396
+
397
+ **Code:** `services/ht-nautilus/hyperliquid_adapter.py`
398
+
399
+ ### 5B.4. Order Lifecycle
400
+ ```
401
+ SubmitOrder (Strategy)
402
+
403
+ ExecutionEngine (Risk Check)
404
+
405
+ HyperliquidExecutionClient
406
+
407
+ Hyperliquid Testnet API
408
+
409
+ OrderStatusReport(ACCEPTED) → MessageBus
410
+
411
+ FillReport → MessageBus
412
+
413
+ OrderStatusReport(FILLED) → MessageBus
414
+
415
+ Portfolio Update (PnL calculation)
416
+ ```
417
+
418
+ ### 5B.5. Event Handling & Portfolio
419
+ **NautilusTrader MessageBus:**
420
+ - Topics: `execution`, `orders`, `fills`, `positions`
421
+ - Subscribers: Portfolio, Strategy, Logger
422
+
423
+ **Portfolio Component:**
424
+ - Automatic PnL calculation
425
+ - Position tracking (real-time)
426
+ - Margin calculations
427
+ - Performance metrics
428
+
429
+ **Output:**
430
+ - **Lệnh thật** trên Hyperliquid Testnet (tiền giả, execution thật)
431
+ - Real fills với actual slippage
432
+ - Portfolio state trong Nautilus Cache
433
+
434
+ ---
435
+
436
+ ## 📈 GIAI ĐOẠN 6: MONITORING & ORCHESTRATION
437
+
438
+ ### Service: `ht-manager`
439
+
440
+ ### 6.1. System Health Monitoring (Mỗi 5 phút)
441
+ **Component:** `SystemHealthLoop`
442
+ **11 Diagnostic Tasks:**
443
+ 1. Infrastructure Health (Redis, QuestDB)
444
+ 2. Data Flow Validation (L2 updates, features, signals)
445
+ 3. Trading System Status (ht-trader, ht-nautilus)
446
+ 4. Resource Monitoring (CPU, RAM, Disk)
447
+ 5. Error Rate Checking
448
+ 6. PnL Monitoring
449
+ 7. Model Performance Tracking
450
+ 8. Data Quality Checks
451
+ 9. Alert System Validation
452
+ 10. Database Integrity
453
+ 11. Periodic Summary Reports
454
+
455
+ **Code:** `services/ht-manager/health_loop.py`
456
+
457
+ ### 6.2. AI Orchestrator
458
+ **Modes:**
459
+ - **High CPU (>70%):** Rule-Based Orchestrator (ultra-fast)
460
+ - **Low CPU (<70%):** Qwen LLM Orchestrator (intelligent)
461
+
462
+ **Actions:**
463
+ - `wait`: Do nothing
464
+ - `check_service_health`: Verify service status
465
+ - `get_pnl_status`: Query current PnL
466
+ - `restart_service`: Restart failed service
467
+ - `emergency_stop`: Stop all trading
468
+ - `trigger_brain_training`: Retrain AI models
469
+
470
+ **Code:**
471
+ - `services/ht-manager/orchestrator.py` (Qwen LLM)
472
+ - `services/ht-manager/orchestrator_rules.py` (Rule-based)
473
+
474
+ ### 6.3. Scheduled Training (Automation)
475
+ **Training Schedule:**
476
+
477
+ | Model | Frequency | Command | Target Service |
478
+ |-------|-----------|---------|----------------|
479
+ | HMM | 15 phút | `TRAIN_HMM` | ht-brain |
480
+ | Confidence Scorer | 30 phút | `TRAIN_CONFIDENCE` | ht-brain |
481
+ | Causal Discovery | 6 giờ | `TRAIN_CAUSAL` | ht-brain |
482
+ | CNN Model | 1 giờ | `TRAIN_CNN` | ht-brain |
483
+ | ML Trade Filter | **24 giờ** | `TRAIN_TRADE_FILTER` | ht-trader |
484
+
485
+ **Persistent Scheduling:**
486
+ - Lưu timestamp vào Redis: `scheduler:last_run:{model_name}`
487
+ - Đảm bảo lịch không bị reset khi restart service
488
+ - Đọc từ Redis khi khởi động để tiếp tục schedule
489
+
490
+ **Code:** `services/ht-manager/main.py` - main loop
491
+
492
+ ### 6.4. Training Execution Flow (ML Trade Filter Example)
493
+ ```
494
+ ht-manager (scheduler check)
495
+
496
+ Publish Redis: {"command": "TRAIN_TRADE_FILTER"}
497
+
498
+ ht-trader (subscribe system_commands)
499
+
500
+ Run subprocess: python train_trade_filter.py
501
+
502
+ Query QuestDB: SELECT * FROM trades_executed
503
+
504
+ Train Random Forest model
505
+
506
+ Save model: models/trade_filter_v{version}.pkl
507
+
508
+ Reload model in MLTradeFilter
509
+
510
+ Update Redis: scheduler:last_run:trade_filter = current_time
511
+ ```
512
+
513
+ **Code:**
514
+ - Scheduler: `services/ht-manager/main.py`
515
+ - Handler: `services/ht-trader/main.py` - method `process_command()`
516
+ - Training: `services/ht-trader/train_trade_filter.py`
517
+
518
+ ---
519
+
520
+ ## 🔄 FEEDBACK LOOP (Self-Learning)
521
+
522
+ ### Service: `ht-brain` (Experience Replay)
523
+
524
+ **Process:**
525
+ 1. **Lưu predictions:**
526
+ ```python
527
+ self.predictions[coin].append({
528
+ 'timestamp': current_time,
529
+ 'price': current_price,
530
+ 'direction': predicted_direction,
531
+ 'magnitude': predicted_magnitude,
532
+ 'horizon': 60 # seconds
533
+ })
534
+ ```
535
+
536
+ 2. **Kiểm tra matured predictions:**
537
+ ```python
538
+ for pred in predictions:
539
+ if current_time - pred['timestamp'] >= pred['horizon']:
540
+ actual_price = get_price_at(pred['timestamp'] + horizon)
541
+ actual_direction = sign(actual_price - pred['price'])
542
+ actual_magnitude = (actual_price - pred['price']) / pred['price']
543
+
544
+ # Calculate accuracy
545
+ direction_correct = (actual_direction == pred['direction'])
546
+ magnitude_error = abs(actual_magnitude - pred['magnitude'])
547
+ ```
548
+
549
+ 3. **Update metrics:**
550
+ - Direction accuracy
551
+ - Magnitude MAE (Mean Absolute Error)
552
+ - Horizon-specific performance
553
+
554
+ 4. **Trigger retrain:**
555
+ ```python
556
+ if len(validated_predictions) >= 1000:
557
+ trigger_cnn_training()
558
+ ```
559
+
560
+ **Code:** `services/ht-brain/main.py` - Experience Replay section
561
+
562
+ ---
563
+
564
+ ## 📊 VISUALIZATION & ANALYSIS
565
+
566
+ ### Service: `ht-dashboard` (Streamlit)
567
+
568
+ **Pages:**
569
+ 1. **Overview:**
570
+ - Real-time PnL chart
571
+ - Win rate by coin
572
+ - Total trades executed
573
+ - Current positions
574
+
575
+ 2. **Signal Analysis:**
576
+ - Signal distribution (BUY/SELL/HOLD)
577
+ - Confidence histogram
578
+ - Regime distribution
579
+ - Signal frequency by hour
580
+
581
+ 3. **Model Performance:**
582
+ - CNN accuracy by horizon
583
+ - Confidence scorer calibration
584
+ - HMM regime accuracy
585
+ - ML Trade Filter win rate
586
+
587
+ 4. **System Health:**
588
+ - Service status (Docker containers)
589
+ - CPU/RAM usage
590
+ - Database size
591
+ - Error logs
592
+
593
+ 5. **Backtest Results:**
594
+ - Strategy comparison
595
+ - Parameter optimization results
596
+ - Equity curve
597
+ - Drawdown analysis
598
+
599
+ **Code:** `services/ht-dashboard/app.py`
600
+
601
+ ---
602
+
603
+ ## 🎯 LUỒNG TỔNG QUAN (SIMPLIFIED)
604
+
605
+ ```
606
+ ┌─────────────────────────────────────────────────────────────────┐
607
+ │ HYPERLIQUID EXCHANGE │
608
+ │ (WebSocket API) │
609
+ └────────────────────────┬────────────────────────────────────────┘
610
+
611
+
612
+ ┌─────────────────────────────────────────────────────────────────┐
613
+ │ GIAI ĐOẠN 1: DATA COLLECTION │
614
+ │ Service: ht-collector │
615
+ │ Output: Redis (market_data_updates) + QuestDB (l2_orderbook) │
616
+ └────────────────────────┬────────────────────────────────────────┘
617
+
618
+
619
+ ┌─────────────────────────────────────────────────────────────────┐
620
+ │ GIAI ĐOẠN 2: FEATURE ENGINEERING │
621
+ │ Service: ht-feature-engine │
622
+ │ Output: Redis (features_updates) + QuestDB (features_micro) │
623
+ │ Features: 17 indicators (OBI, OFI, CVD, Volatility, etc.) │
624
+ └────────────────────────��────────────────────────────────────────┘
625
+
626
+
627
+ ┌─────────────────────────────────────────────────────────────────┐
628
+ │ GIAI ĐOẠN 3: AI SIGNAL GENERATION │
629
+ │ Service: ht-brain │
630
+ │ AI Models: │
631
+ │ 1. HMM (Regime Detection) │
632
+ │ 2. Logistic Regression (Confidence Scoring) │
633
+ │ 3. Multi-Task CNN (Direction + Magnitude) │
634
+ │ 4. PC Algorithm (Causal Discovery) │
635
+ │ Output: Redis (signals) + QuestDB (signals_generated) ✅ │
636
+ └────────────────────────┬────────────────────────────────────────┘
637
+
638
+
639
+ ┌─────────────────────────────────────────────────────────────────┐
640
+ │ GIAI ĐOẠN 4: ML TRADE FILTER │
641
+ │ Service: ht-trader (MLTradeFilter) │
642
+ │ Model: Random Forest Classifier │
643
+ │ Decision: ACCEPT (win_prob > 60%) or REJECT │
644
+ └────────────────────────┬────────────────────────────────────────┘
645
+
646
+
647
+ ┌───────────────┴───────────────┐
648
+ │ │
649
+ ▼ ▼
650
+ ┌──────────────────────┐ ┌──────────────────────┐
651
+ │ GIAI ĐOẠN 5A: │ │ GIAI ĐOẠN 5B: │
652
+ │ PAPER TRADING │ │ TESTNET TRADING │
653
+ │ Service: ht-trader │ │ Service: │
654
+ │ Mode: PAPER │ │ ht-nautilus-paper │
655
+ │ Execution: │ │ Mode: TESTNET │
656
+ │ Simulated │ │ Execution: │
657
+ │ │ │ Hyperliquid API │
658
+ │ Output: │ │ (Real orders) │
659
+ │ - QuestDB │ │ │
660
+ │ - Metrics │ │ Output: │
661
+ │ │ │ - Real fills │
662
+ │ │ │ - Portfolio state │
663
+ └──────────────────────┘ └──────────────────────┘
664
+ │ │
665
+ └───────────────┬───────────────┘
666
+
667
+
668
+ ┌─────────────────────────────────────────────────────────────────┐
669
+ │ GIAI ĐOẠN 6: MONITORING & ORCHESTRATION │
670
+ │ Service: ht-manager │
671
+ │ Functions: │
672
+ │ - System Health Monitoring (11 tasks) │
673
+ │ - AI Orchestrator (Rule-based / Qwen LLM) │
674
+ │ - Scheduled Training (5 models, automated) │
675
+ │ - Persistent Scheduling (Redis-backed) │
676
+ │ - Telegram Alerts │
677
+ └────────────────────────┬────────────────────────────────────────┘
678
+
679
+
680
+ ┌─────────────────────────────────────────────────────────────────┐
681
+ │ FEEDBACK LOOP: SELF-LEARNING │
682
+ │ Service: ht-brain (Experience Replay) │
683
+ │ Process: Validate predictions → Update metrics → Retrain │
684
+ └─────────────────────────────────────────────────────────────────┘
685
+
686
+
687
+ ┌─────────────────────────────────────────────────────────────────┐
688
+ │ VISUALIZATION: ht-dashboard (Streamlit) │
689
+ │ Real-time charts, metrics, backtest results │
690
+ └─────────────────────────────────────────────────────────────────┘
691
+ ```
692
+
693
+ ---
694
+
695
+ ## 🔑 KEY POINTS
696
+
697
+ ### 1. Signals CÓ được lưu vào database
698
+ - **Table:** `signals_generated` trong QuestDB
699
+ - **Mục đích:** Tracking, backtesting, phân tích hiệu quả AI
700
+ - **Lọc:** Chỉ lưu BUY/SELL (không lưu HOLD)
701
+
702
+ ### 2. Signal được tạo bởi ENSEMBLE AI
703
+ - **4 Models:** HMM + Logistic Regression + Multi-Task CNN + PC Algorithm
704
+ - **Weighted Combination:** 70% Logistic + 30% CNN cho confidence
705
+ - **Primary Direction:** CNN model (DeepLOB architecture)
706
+ - **Filter:** Regime (HMM) để tránh trade trong sideway market
707
+
708
+ ### 3. Dual Execution Engines
709
+ - **ht-trader (Paper):** Giả lập nhanh, thu thập data cho ML
710
+ - **ht-nautilus-paper (Testnet):** Execution thật, test chiến lược thực tế
711
+
712
+ ### 4. Automated Training Pipeline
713
+ - **5 Models** được train tự động theo lịch
714
+ - **Persistent scheduling** qua Redis (không bị reset khi restart)
715
+ - **Self-learning** qua Experience Replay
716
+
717
+ ---
718
+
719
+ ## 📚 RELATED DOCUMENTATION
720
+
721
+ - [Architecture Overview](./ARCHITECTURE_HT_TRADER.md)
722
+ - [Paper Trading Strategy](./PAPER_TRADING_STRATEGY.md)
723
+ - [ML Trade Filter Roadmap](./ML_TRADE_FILTER_ROADMAP.md)
724
+ - [Trader Optimizations](./HT_TRADER_OPTIMIZATIONS.md)
725
+
726
+ ---
727
+
728
+ **Last Updated:** 2025-12-02
729
+ **Version:** 1.0
AIDocs/TRAINING_STRATEGY_v2.md ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AI Training Strategy v2.0 (Hyperliquid L2 Data)
2
+
3
+ This document outlines the optimal training strategy for the HyperTrade AI engine using the valid `gionuibk/hyperliquidL2Book-v2` dataset.
4
+
5
+ ## 1. Data Foundation
6
+
7
+ We have three verified data sources in the v2 dataset:
8
+ 1. **L2 Order Book (`data/l2book/*.parquet`)**: High-fidelity snapshots (20 levels).
9
+ * *Usage*: **DeepLOB** (CNNs for spatial feature extraction).
10
+ 2. **Trades (`data/l4_node_trades*/*.parquet`)**: Executed trade flow.
11
+ * *Usage*: **TRM** (Regime detection via CVD & OFI).
12
+ 3. **Candles (`data/candles/*.parquet`)**: Aggregated OHLCV bars.
13
+ * *Usage*: **LSTM** (Temporal sequence modeling for trend).
14
+
15
+ ## 2. Model-Specific Training Recipes
16
+
17
+ ### A. DeepLOB (The "Sniper" Scope)
18
+ **Objective**: Detect short-term directional moves based on order book pressure.
19
+
20
+ * **Input**: Rolling window of L2 Book (100 ticks x 40 features).
21
+ * **Target**: Triple Barrier (Profit Take vs Stop Loss within 100 ticks).
22
+ * **Strategy**:
23
+ 1. **Curriculum Learning**:
24
+ * *Phase 1*: Train on **High Volatility** days first. (The model learns faster when signal-to-noise is high).
25
+ * *Phase 2*: Train on **Sideways** days. (Fine-tune to reduce false positives).
26
+ 2. **class_weights**: The market is mostly noise (Hold). Use `class_weights=[1.0, 5.0, 5.0]` to force the model to care about Buy/Sell signals.
27
+ 3. **Learning Rate**: Start `1e-4`, decay by 0.5 every 5 epochs.
28
+
29
+ ### B. TRM (The "General")
30
+ **Objective**: Identify Market Regimes (Trend vs Sideways).
31
+
32
+ * **Input**: Features (`Volatility`, `Imbalance`, `Spread`, `CVD`, `OFI`).
33
+ * *Note*: Requires merging L2 Book timestamps with Trade timestamps to compute accurate CVD/OFI.
34
+ * **Strategy**:
35
+ * **Unsupervised Pre-training**: Use a Hidden Markov Model (HMM) first to label data as "High Vol", "Low Vol", "Trend".
36
+ * **Supervised Fine-tuning**: Train TRM to predict these HMM states + Next 5-min Return.
37
+ * **Reasoning**: TRM needs to understand *context*, not just price.
38
+
39
+ ### C. LSTM (The "Scout")
40
+ **Objective**: Predict medium-term trend (Next 1-5 minutes).
41
+
42
+ * **Input**: Sequence of 60 Candle Bars (1s or 1m resolution).
43
+ * **Strategy**:
44
+ * **Normalization**: Use **Log Returns** (not Price). Price is non-stationary; Log Returns are stationary.
45
+ * **Regularization**: High Dropout (0.3 - 0.5) to prevent memorizing absolute price levels.
46
+
47
+ ## 3. "Smarter AI": What does it mean?
48
+
49
+ When we say "making the AI smarter", we refer to optimizing three categories of parameters:
50
+
51
+ ### 1. Weights & Biases (The "Brain Cells")
52
+ * **Definition**: The billions of floating-point numbers inside the neural network matrices.
53
+ * **How to improve**:
54
+ * **More Data**: Training on the full `hyperliquidL2Book-v2` dataset (millions of rows) refines these weights.
55
+ * **Better Labels**: Using "Triple Barrier" labeling (dynamic targets based on volatility) instead of fixed targets makes the weights learn *risk-adjusted* moves.
56
+
57
+ ### 2. Hyperparameters (The "Configuration")
58
+ * **Definition**: Settings chosen *before* training.
59
+ * **Key Controls**:
60
+ * **Lookback Window (T)**: Currently 100. Increasing to 200 may capture longer patterns but adds noise.
61
+ * **Batch Size**: Large batches (64/128) provide stable gradients. Small batches (16/32) add noise that can help generalization (escape local minima).
62
+ * **Model Depth**: Adding specific layers (e.g., Attention Heads) can help the model focus on critical events.
63
+
64
+ ### 3. Feature Engineering (The "Eyes")
65
+ * **Definition**: What the AI actually *sees*.
66
+ * **Improvement**:
67
+ * **OFI (Order Flow Imbalance)**: We just added this. It tells the AI *who* is aggressive (Makers vs Takers).
68
+ * **Micro-Structure**: Bid-Ask Bounce, Queue Position (requires L3 data, but L2 approximation helps).
69
+
70
+ ## 4. Execution Plan (Recommended)
71
+
72
+ 1. **Data Prep**: Run the `streaming_loader` update (Done).
73
+ 2. **Dry Run**: Train DeepLOB for 1 epoch on a small subset to verify convergence (Loss decreases).
74
+ 3. **Scale Up**: Run `auto_train.py` on a GPU instance (e.g., RunPod/Lambda) for 24-48 hours.
75
+ 4. **Evaluate**: Check "Precision" (Win Rate) on specific regimes using the Dashboard.
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+ ENV PYTHONUNBUFFERED=1
5
+
6
+
7
+ # System dependencies (Minimal)
8
+ RUN apt-get update && apt-get install -y \
9
+ curl \
10
+ git \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Install Python deps (Fast)
14
+ COPY requirements.txt .
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Copy Trainer Code
18
+ COPY . .
19
+
20
+ # Install Schedule for looping
21
+ RUN pip install schedule
22
+
23
+ CMD ["python", "scheduler.py"]
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ import os
4
+ import json
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ from pathlib import Path
8
+
9
+ # Status file path
10
+ STATUS_FILE = "./status.json"
11
+
12
+ def read_status():
13
+ """Read training status from JSON file."""
14
+ if not os.path.exists(STATUS_FILE):
15
+ return {
16
+ "current_model": None,
17
+ "status": "idle",
18
+ "epoch": 0,
19
+ "total_epochs": 0,
20
+ "last_loss": None,
21
+ "last_accuracy": None,
22
+ "started_at": None,
23
+ "last_update": None,
24
+ "logs": []
25
+ }
26
+
27
+ try:
28
+ with open(STATUS_FILE, 'r') as f:
29
+ return json.load(f)
30
+ except:
31
+ return {"status": "error", "logs": ["Could not read status file"]}
32
+
33
+ def get_status_display():
34
+ """Format status for display."""
35
+ status = read_status()
36
+
37
+ model = status.get("current_model", "None")
38
+ state = status.get("status", "idle").upper()
39
+ epoch = status.get("epoch", 0)
40
+ total = status.get("total_epochs", 0)
41
+ loss = status.get("last_loss")
42
+ acc = status.get("last_accuracy")
43
+ updated = status.get("last_update", "Never")
44
+
45
+ # Status emoji
46
+ emoji = {"idle": "⏸️", "training": "🔄", "completed": "✅", "error": "❌"}.get(status.get("status", "idle"), "❓")
47
+
48
+ info = f"""
49
+ ## {emoji} Training Status: {state}
50
+
51
+ **Model:** {model or 'None'}
52
+ **Progress:** {epoch}/{total} epochs
53
+ **Last Loss:** {f'{loss:.4f}' if loss else 'N/A'}
54
+ **Accuracy:** {f'{acc:.2f}%' if acc else 'N/A'}
55
+ **Updated:** {updated}
56
+ """
57
+ return info
58
+
59
+ def get_logs():
60
+ """Get training logs."""
61
+ status = read_status()
62
+ logs = status.get("logs", [])
63
+ return "\n".join(logs[-30:]) if logs else "No logs yet."
64
+
65
+ def plot_metrics():
66
+ """Plot training metrics (if available)."""
67
+ # This could be enhanced to read from a metrics history file
68
+ status = read_status()
69
+
70
+ if status.get("status") == "idle":
71
+ return None
72
+
73
+ # Simple placeholder - in production, would read from metrics file
74
+ fig = plt.figure(figsize=(10, 4))
75
+ plt.text(0.5, 0.5, "Metrics will appear during training",
76
+ ha='center', va='center', fontsize=14)
77
+ plt.axis('off')
78
+ return fig
79
+
80
+ def refresh_all():
81
+ """Refresh all dashboard components."""
82
+ return get_status_display(), get_logs(), plot_metrics()
83
+
84
+ # --- Dashboard UI ---
85
+ with gr.Blocks(title="NautilusAI Dashboard") as app:
86
+ gr.Markdown("# 🧠 NautilusAI Training Dashboard")
87
+ gr.Markdown("*Monitoring dashboard with manual training trigger*")
88
+
89
+ with gr.Row():
90
+ with gr.Column(scale=1):
91
+ status_md = gr.Markdown(get_status_display())
92
+ with gr.Row():
93
+ refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
94
+ train_btn = gr.Button("🚀 Start Training", variant="primary")
95
+ train_output = gr.Textbox(label="Training Output", lines=3, interactive=False)
96
+
97
+ with gr.Column(scale=2):
98
+ logs_box = gr.Textbox(
99
+ label="Training Logs",
100
+ value=get_logs(),
101
+ lines=15,
102
+ max_lines=20,
103
+ interactive=False
104
+ )
105
+
106
+ with gr.Row():
107
+ plot_box = gr.Plot(label="Training Metrics")
108
+
109
+ # Global flag
110
+ TRAINING_STARTED = False
111
+
112
+ def trigger_training():
113
+ """Trigger training in background thread."""
114
+ global TRAINING_STARTED
115
+ if TRAINING_STARTED:
116
+ return "⚠️ Training already active."
117
+
118
+ TRAINING_STARTED = True
119
+ import threading
120
+ import sys
121
+
122
+ def run_training():
123
+ print("🚀 Auto-Training Pipeline Started...", flush=True)
124
+ try:
125
+ from auto_train import main
126
+ # Redirect stdout to capture logs in real-time if needed, but flush=True should suffice for container logs
127
+ main()
128
+ except Exception as e:
129
+ print(f"❌ Training error: {e}", flush=True)
130
+ finally:
131
+ global TRAINING_STARTED
132
+ TRAINING_STARTED = False
133
+ print("🏁 Training Pipeline Finished.", flush=True)
134
+ sys.stdout.flush()
135
+
136
+ t = threading.Thread(target=run_training, daemon=True)
137
+ t.start()
138
+ return "🔄 Training started! Check logs for progress..."
139
+
140
+ # Actions
141
+ train_btn.click(trigger_training, outputs=[train_output])
142
+
143
+ # Refresh action
144
+ refresh_btn.click(refresh_all, outputs=[status_md, logs_box, plot_box])
145
+
146
+ # Auto-refresh every 5 seconds
147
+ timer = gr.Timer(5)
148
+ timer.tick(refresh_all, outputs=[status_md, logs_box, plot_box])
149
+
150
+ # --- API ---
151
+ # Hidden JSON component for programmatic access
152
+ api_status_box = gr.JSON(label="Status API", visible=False)
153
+
154
+ # Expose this function as an API named '/get_status'
155
+ # We use a dummy button or just a direct event.
156
+ # In Gradio 4.x, just defining a function for a component update can work,
157
+ # but explicit api_name is best on a click or load.
158
+ api_btn = gr.Button("API Trigger", visible=False)
159
+ api_btn.click(read_status, outputs=[api_status_box], api_name="get_status")
160
+
161
+ if __name__ == "__main__":
162
+ # Auto-start training on launch
163
+ print("System Startup: Triggering Auto-Train...")
164
+ trigger_training()
165
+
166
+ app.launch(server_name="0.0.0.0", server_port=7860)
auto_train.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Auto Training Pipeline for NautilusAI
4
+ Runs scheduled training for DeepLOB, TRM, and Ensemble models.
5
+ Writes status to status.json for Dashboard monitoring.
6
+ """
7
+ import os
8
+ import json
9
+ import time
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from huggingface_hub import HfApi
13
+ import joblib
14
+ import pickle
15
+
16
+
17
+
18
+ # Configuration
19
+ REPO_ID = "gionuibk/hyperliquidL2Book-v2"
20
+ DATA_DIR = "./data"
21
+ MODEL_DIR = "./models"
22
+ STATUS_FILE = "./status.json"
23
+ HF_MODEL_REPO = "gionuibk/NautilusModels"
24
+ REPO_ID_LOGS = "gionuibk/NautilusLogs"
25
+
26
+ # Ensure directories exist
27
+ Path(MODEL_DIR).mkdir(exist_ok=True)
28
+ Path(DATA_DIR).mkdir(exist_ok=True)
29
+
30
+ class StatusWriter:
31
+ """Writes training status to JSON file for Dashboard."""
32
+
33
+ def __init__(self, filepath=STATUS_FILE):
34
+ self.filepath = filepath
35
+ self.logs = []
36
+
37
+ # Ensure Log Repo exists
38
+ try:
39
+ from huggingface_hub import HfApi, create_repo
40
+ import os
41
+ token = os.environ.get("HF_TOKEN")
42
+ if token:
43
+ create_repo("gionuibk/NautilusLogs", repo_type="dataset", exist_ok=True, token=token, private=True)
44
+ except: pass
45
+
46
+ self.last_upload_time = 0
47
+ self.reset()
48
+
49
+ def reset(self):
50
+ self.status = {
51
+ "current_model": None,
52
+ "status": "idle",
53
+ "epoch": 0,
54
+ "total_epochs": 0,
55
+ "last_loss": None,
56
+ "last_accuracy": None,
57
+ "started_at": None,
58
+ "last_update": None,
59
+ "logs": []
60
+ }
61
+ self._save()
62
+
63
+ def start(self, model_name: str, total_epochs: int):
64
+ self.logs = []
65
+ self.status = {
66
+ "current_model": model_name,
67
+ "status": "training",
68
+ "epoch": 0,
69
+ "total_epochs": total_epochs,
70
+ "last_loss": None,
71
+ "last_accuracy": None,
72
+ "started_at": datetime.now().isoformat(),
73
+ "last_update": datetime.now().isoformat(),
74
+ "logs": []
75
+ }
76
+ self.log(f"Started training {model_name}")
77
+ self._save()
78
+
79
+ def update(self, epoch: int, loss: float, accuracy: float = None):
80
+ self.status["epoch"] = epoch
81
+ self.status["last_loss"] = loss
82
+ self.status["last_accuracy"] = accuracy
83
+ self.status["last_update"] = datetime.now().isoformat()
84
+ self.log(f"Epoch {epoch}: Loss={loss:.4f}" + (f", Acc={accuracy:.2f}%" if accuracy else ""))
85
+ self._save()
86
+
87
+ def complete(self, model_name: str):
88
+ self.status["status"] = "completed"
89
+ self.status["last_update"] = datetime.now().isoformat()
90
+ self.log(f"Completed training {model_name}")
91
+ self._save()
92
+
93
+ def error(self, message: str):
94
+ self.status["status"] = "error"
95
+ self.status["last_update"] = datetime.now().isoformat()
96
+ self.log(f"ERROR: {message}")
97
+ self._save()
98
+
99
+ def log(self, message: str):
100
+ timestamp = datetime.now().strftime("%H:%M:%S")
101
+ log_entry = f"[{timestamp}] {message}"
102
+ self.logs.append(log_entry)
103
+ self.status["logs"] = self.logs[-50:] # Keep last 50 logs
104
+ print(log_entry)
105
+
106
+ def _save(self):
107
+ # Save locally
108
+ with open(self.filepath, 'w') as f:
109
+ json.dump(self.status, f, indent=2)
110
+
111
+ # Upload to HF Dataset for remote monitoring
112
+ # Throttle: Only upload every 10 minutes (600s) OR if status is final
113
+ import time
114
+ current_time = time.time()
115
+ is_final = self.status["status"] in ["completed", "error"]
116
+
117
+ if (current_time - self.last_upload_time >= 600) or is_final:
118
+ try:
119
+ from huggingface_hub import HfApi
120
+ api = HfApi()
121
+ token = os.environ.get("HF_TOKEN")
122
+ if token:
123
+ LOG_REPO = "gionuibk/NautilusLogs"
124
+
125
+ api.upload_file(
126
+ path_or_fileobj=self.filepath,
127
+ path_in_repo="status.json",
128
+ repo_id=LOG_REPO,
129
+ repo_type="dataset",
130
+ token=token
131
+ )
132
+ self.last_upload_time = current_time
133
+ print(f"📡 Status uploaded to HF (Next update in 10 mins).")
134
+ except Exception as e:
135
+ pass
136
+
137
+ class HistoryWriter:
138
+ """Writes permanent training history to CSV."""
139
+ def __init__(self, filepath="training_history.csv"):
140
+ self.filepath = filepath
141
+ self.last_upload_time = 0
142
+ self._ensure_header()
143
+
144
+ def _ensure_header(self):
145
+ if not os.path.exists(self.filepath):
146
+ with open(self.filepath, 'w') as f:
147
+ f.write("timestamp,model_name,metrics,filename,hf_url\n")
148
+
149
+ def log_model(self, model_name: str, metrics: str, filename: str):
150
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
151
+ hf_url = f"https://huggingface.co/{HF_MODEL_REPO}/blob/main/{filename}"
152
+
153
+ # Validation
154
+ if not metrics: metrics = "N/A"
155
+
156
+ # Append to CSV
157
+ with open(self.filepath, 'a') as f:
158
+ f.write(f"{timestamp},{model_name},{metrics},{filename},{hf_url}\n")
159
+
160
+ print(f"📜 Logged to history: {filename}")
161
+ self.upload_history() # Auto-upload with throttle
162
+
163
+ def upload_history(self):
164
+ """Uploads the history CSV to HuggingFace Logs Repo (Throttled 10m)."""
165
+ import time
166
+ current_time = time.time()
167
+
168
+ if current_time - self.last_upload_time < 600:
169
+ return # Skip if too frequent
170
+
171
+ print("📤 Uploading Training History Log...")
172
+ try:
173
+ token = os.environ.get("HF_TOKEN")
174
+ if token:
175
+ from huggingface_hub import HfApi
176
+ api = HfApi(token=token)
177
+ api.upload_file(
178
+ path_or_fileobj=self.filepath,
179
+ path_in_repo="training_history.csv",
180
+ repo_id=REPO_ID_LOGS,
181
+ repo_type="dataset"
182
+ )
183
+ self.last_upload_time = current_time
184
+ except Exception as e:
185
+ print(f"⚠️ History Upload Failed: {e}")
186
+
187
+
188
+ def train_deeplob(status: StatusWriter, api: HfApi, history, epochs: int = 1):
189
+ """Train DeepLOB model using Streaming Data."""
190
+ print("⏳ Loading DeepLOB Dependencies (Torch)...")
191
+ import torch
192
+ import torch.nn as nn
193
+ import torch.optim as optim
194
+ from models.deeplob import DeepLOB
195
+ from streaming_loader import StreamingDataLoader
196
+
197
+ # Note: Epochs in streaming context usually means passes over the stream.
198
+ # Since stream is huge/infinite, we might define 'epoch' as N steps or 1 full pass.
199
+ # We will stick to 1 full pass per 'epoch' call effectively, or simple consistency.
200
+ status.start("DeepLOB", epochs)
201
+
202
+ try:
203
+ # Initialize Streaming Loader
204
+ loader = StreamingDataLoader(
205
+ repo_id=REPO_ID,
206
+ model_type="deeplob",
207
+ batch_size=32,
208
+ chunk_size=5000 # Process 5000 rows at a time
209
+ )
210
+
211
+ # Initialize Model
212
+ # We need a sample batch to verify shapes if needed, or just init blindly
213
+ model = DeepLOB(y_len=3)
214
+
215
+ # Loss & Optimizer
216
+ # Note: Class weights difficult to pre-calc in streaming. Using standard CELoss
217
+ criterion = nn.CrossEntropyLoss()
218
+ optimizer = optim.Adam(model.parameters(), lr=0.0001)
219
+
220
+ step = 0
221
+ total_loss = 0
222
+ correct = 0
223
+ total_samples = 0
224
+
225
+ model.train()
226
+ print("🚀 Starting Streaming Training loop...")
227
+
228
+ # Stream Loop
229
+ # We iterate through the ENTIRE dataset stream once per 'epoch' logic
230
+ # For multi-epoch, we'd need to re-create the loader or reset it.
231
+ # HF streaming datasets are iterators.
232
+
233
+ for batch_X, batch_y in loader:
234
+ optimizer.zero_grad()
235
+ outputs = model(batch_X)
236
+ loss = criterion(outputs, batch_y)
237
+ loss.backward()
238
+ optimizer.step()
239
+
240
+ # Metrics
241
+ total_loss += loss.item()
242
+ _, predicted = torch.max(outputs.data, 1)
243
+ total_samples += batch_y.size(0)
244
+ correct += (predicted == batch_y).sum().item()
245
+
246
+ step += 1
247
+ if step % 10 == 0:
248
+ print(f"Step {step}: Loss={loss.item():.4f}", flush=True)
249
+
250
+ # Update status occasionally
251
+ if step % 50 == 0:
252
+ acc = 100 * correct / total_samples
253
+ avg_loss = total_loss / step
254
+ status.update(1, avg_loss, acc) # Report as Epoch 1 for now
255
+
256
+ # Save Native PyTorch Model (Reliable Fallback)
257
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
258
+ model_filename = f"deeplob_{timestamp}.pt"
259
+ save_path = f"{MODEL_DIR}/{model_filename}"
260
+
261
+ torch.save(model.state_dict(), save_path)
262
+ print(f"Reference PyTorch model saved: {save_path} (Trained on {step} batches)")
263
+
264
+ # Upload model to HF Model Hub (User Request)
265
+ try:
266
+ from huggingface_hub import create_repo
267
+ token = os.environ.get("HF_TOKEN")
268
+ # Ensure Model Repo exists
269
+ create_repo(HF_MODEL_REPO, repo_type="model", exist_ok=True, token=token)
270
+
271
+ print(f"Uploading {model_filename} to Model Repo: {HF_MODEL_REPO}...")
272
+ api.upload_file(
273
+ path_or_fileobj=save_path,
274
+ path_in_repo=model_filename,
275
+ repo_id=HF_MODEL_REPO,
276
+ repo_type="model"
277
+ )
278
+ print(f"✅ {model_filename} uploaded to HF Models successfully.")
279
+ except Exception as e:
280
+ print(f"⚠️ Model Upload Failed: {e}")
281
+
282
+ # Redundant upload to Logs Dataset (Optional, keeping for legacy dashboards if any)
283
+ # ... (Removed to avoid duplication and save bandwidth, user asked for Models tab)
284
+
285
+
286
+ # Save ONNX (Disabled due to Environment Incompatibility)
287
+ # try:
288
+ # print("Exporting DeepLOB to ONNX...")
289
+ # dummy = torch.randn(1, 2, 100, 40)
290
+ # torch.onnx.export(model, dummy, f"{MODEL_DIR}/deeplob_v1.onnx",
291
+ # input_names=['input'], output_names=['output'],
292
+ # opset_version=12)
293
+ # print("ONNX Export Success.")
294
+ # except Exception as e:
295
+ # print(f"⚠️ ONNX Export Failed: {e}")
296
+ print("ONNX Export Skipped (Using .pt checkpoint).")
297
+ # Continue pipeline despite export failure
298
+
299
+
300
+ status.complete("DeepLOB")
301
+ # Log to History
302
+ acc_str = f"Acc={acc:.2f}%" if 'acc' in locals() else "N/A"
303
+ try: history.log_model("DeepLOB", acc_str, model_filename)
304
+ except: pass
305
+ return True
306
+
307
+ except Exception as e:
308
+ import traceback
309
+ traceback.print_exc()
310
+ status.error(str(e))
311
+ return False
312
+
313
+ def train_trm(status: StatusWriter, api: HfApi, history, epochs: int = 1):
314
+ """Train TRM model using Streaming Data."""
315
+ print("⏳ Loading TRM Dependencies (Torch)...")
316
+ import torch
317
+ import torch.nn as nn
318
+ import torch.optim as optim
319
+ from models.trm import TRM
320
+ from streaming_loader import StreamingDataLoader
321
+
322
+ status.start("TRM", epochs)
323
+
324
+ try:
325
+ loader = StreamingDataLoader(
326
+ repo_id=REPO_ID,
327
+ model_type="trm",
328
+ batch_size=64,
329
+ chunk_size=5000
330
+ )
331
+
332
+ model = TRM(input_size=6, num_classes=3)
333
+ criterion = nn.CrossEntropyLoss()
334
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
335
+
336
+ step = 0
337
+ total_loss = 0
338
+ model.train()
339
+
340
+ for batch_X, batch_y in loader:
341
+ optimizer.zero_grad()
342
+ out = model(batch_X)
343
+ loss = criterion(out, batch_y)
344
+ loss.backward()
345
+ optimizer.step()
346
+
347
+ total_loss += loss.item()
348
+ step += 1
349
+
350
+ if step % 50 == 0:
351
+ status.update(1, total_loss / step)
352
+
353
+ # Save Native PyTorch Model
354
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
355
+ model_filename = f"trm_{timestamp}.pt"
356
+ save_path = f"{MODEL_DIR}/{model_filename}"
357
+
358
+ torch.save(model.state_dict(), save_path)
359
+ print(f"Reference PyTorch model saved: {save_path} (Trained on {step} batches)")
360
+
361
+ # Upload model to HF Model Hub
362
+ try:
363
+ print(f"Uploading {model_filename} to Model Repo: {HF_MODEL_REPO}...")
364
+ api.upload_file(
365
+ path_or_fileobj=save_path,
366
+ path_in_repo=model_filename,
367
+ repo_id=HF_MODEL_REPO,
368
+ repo_type="model"
369
+ )
370
+ print(f"✅ {model_filename} uploaded to HF Models successfully.")
371
+ except Exception as e:
372
+ print(f"⚠️ Model Upload Failed: {e}")
373
+
374
+ # Save ONNX (Disabled due to Environment Incompatibility)
375
+ # try:
376
+ # print("Exporting TRM to ONNX...")
377
+ # dummy = torch.randn(1, 60, 6)
378
+ # torch.onnx.export(model, dummy, f"{MODEL_DIR}/trm_v1.onnx",
379
+ # input_names=['input'], output_names=['output'],
380
+ # opset_version=12)
381
+ # print("ONNX Export Success.")
382
+ # except Exception as e:
383
+ # print(f"⚠️ ONNX Export Failed: {e}")
384
+ print("ONNX Export Skipped (Using .pt checkpoint).")
385
+ # Continue pipeline
386
+
387
+ status.complete("TRM")
388
+ return True
389
+
390
+ status.complete("TRM")
391
+ return True
392
+
393
+ except Exception as e:
394
+ status.error(str(e))
395
+ return False
396
+
397
+ def train_lstm(status: StatusWriter, api: HfApi, history, epochs: int = 1):
398
+ """Train LSTM model using Streaming Data (Bar Data)."""
399
+ print("⏳ Loading LSTM Dependencies (Torch)...")
400
+ import torch
401
+ import torch.nn as nn
402
+ import torch.optim as optim
403
+ from models.lstm import AlphaLSTM
404
+ from streaming_loader import StreamingDataLoader
405
+
406
+ status.start("LSTM", epochs)
407
+
408
+ try:
409
+ loader = StreamingDataLoader(
410
+ repo_id=REPO_ID,
411
+ model_type="lstm",
412
+ batch_size=64,
413
+ chunk_size=5000
414
+ )
415
+
416
+ # Input Size = 5 (log_ret, log_vol, hl_range, co_range, vol)
417
+ model = AlphaLSTM(input_size=5, hidden_size=64)
418
+
419
+ # Regression Loss (Mean Squared Error)
420
+ criterion = nn.MSELoss()
421
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
422
+
423
+ step = 0
424
+ total_loss = 0
425
+ model.train()
426
+
427
+ print(f"🚀 Starting LSTM Training loop (Target: Next Return)...")
428
+
429
+ for batch_X, batch_y in loader:
430
+ optimizer.zero_grad()
431
+ out = model(batch_X)
432
+ # batch_y is (Batch, 1), out is (Batch, 1)
433
+ loss = criterion(out, batch_y)
434
+ loss.backward()
435
+ optimizer.step()
436
+
437
+ total_loss += loss.item()
438
+ step += 1
439
+
440
+ if step % 50 == 0:
441
+ print(f"Step {step}: Loss={loss.item():.6f}", flush=True)
442
+ status.update(1, total_loss / step)
443
+
444
+ # Save Native PyTorch Model
445
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
446
+ model_filename = f"lstm_{timestamp}.pt"
447
+ save_path = f"{MODEL_DIR}/{model_filename}"
448
+
449
+ torch.save(model.state_dict(), save_path)
450
+ print(f"Reference PyTorch model saved: {save_path} (Trained on {step} batches)")
451
+
452
+ # Upload model
453
+ try:
454
+ print(f"Uploading {model_filename} to Model Repo: {HF_MODEL_REPO}...")
455
+ api.upload_file(
456
+ path_or_fileobj=save_path,
457
+ path_in_repo=model_filename,
458
+ repo_id=HF_MODEL_REPO,
459
+ repo_type="model"
460
+ )
461
+ print(f"✅ {model_filename} uploaded to HF Models successfully.")
462
+ except Exception as e:
463
+ print(f"⚠️ Model Upload Failed: {e}")
464
+
465
+ status.complete("LSTM")
466
+ try: history.log_model("LSTM", "N/A", model_filename)
467
+ except: pass
468
+ return True
469
+
470
+ except Exception as e:
471
+ status.error(str(e))
472
+ print(f"❌ LSTM Training Error: {e}")
473
+ return False
474
+
475
+ def train_classic_and_causal(status: StatusWriter, api: HfApi, history):
476
+ """Train Classic ML and Causal Discovery models using Bar Data."""
477
+ print("⏳ Loading Classic/Causal Dependencies (Pandas/Sklearn)...")
478
+ import pandas as pd
479
+ import numpy as np
480
+ from models.classic_ml import get_hmm_pipeline, get_rf_pipeline
481
+ from models.causal_discovery import get_causal_model
482
+ from streaming_loader import StreamingDataLoader
483
+
484
+ status.start("ClassicML & Causal", 1)
485
+
486
+ try:
487
+ # Load a chunk
488
+ from huggingface_hub import hf_hub_download
489
+ api_hf = HfApi(token=os.environ.get("HF_TOKEN"))
490
+ files = api_hf.list_repo_files(repo_id=REPO_ID, repo_type="dataset")
491
+ # Support V2 'data/candles/' and V1 'data/bar/'
492
+ bar_files = [f for f in files if (f.startswith("data/candles/") or f.startswith("data/bar/")) and f.endswith(".parquet")]
493
+
494
+ if not bar_files:
495
+ print("❌ No bar files found for ClassicML.")
496
+ return False
497
+
498
+ # Aggregate multiple files until we have enough rows (e.g., 2000)
499
+ target_rows = 2000
500
+ aggregated_dfs = []
501
+ total_rows = 0
502
+
503
+ print(f"📥 Aggregating bar files (Target: {target_rows} rows)...")
504
+
505
+ for file_info in bar_files:
506
+ if total_rows >= target_rows:
507
+ break
508
+
509
+ try:
510
+ local_path = hf_hub_download(repo_id=REPO_ID, filename=file_info, repo_type="dataset", token=os.environ.get("HF_TOKEN"))
511
+ chunk_df = pd.read_parquet(local_path)
512
+ if not chunk_df.empty:
513
+ aggregated_dfs.append(chunk_df)
514
+ total_rows += len(chunk_df)
515
+ print(f" + Added {len(chunk_df)} rows from {file_info} (Total: {total_rows})")
516
+ except Exception as e:
517
+ print(f" ⚠️ Failed to load {file_info}: {e}")
518
+
519
+ if not aggregated_dfs:
520
+ print("❌ Failed to load any valid bar data.")
521
+ return False
522
+
523
+ df = pd.concat(aggregated_dfs, ignore_index=True)
524
+ print(f"✅ Final Dataset Size: {len(df)} rows")
525
+
526
+ # Preprocess
527
+ df['log_ret'] = np.log(df['close'] / df['close'].shift(1)).fillna(0)
528
+ clean_df = df.dropna().select_dtypes(include=[np.number]).iloc[:10000]
529
+ X = clean_df.values
530
+
531
+ # 1. Classic ML (Random Forest)
532
+ print("🧠 Training ClassicML (Random Forest)...")
533
+ rf_model = get_rf_pipeline()
534
+ y = (df['log_ret'].shift(-1).fillna(0) > 0).astype(int).iloc[:10000]
535
+ rf_model.fit(X, y)
536
+
537
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
538
+ rf_filename = f"classic_ml_{timestamp}.joblib"
539
+ joblib.dump(rf_model, f"{MODEL_DIR}/{rf_filename}")
540
+
541
+ api.upload_file(
542
+ path_or_fileobj=f"{MODEL_DIR}/{rf_filename}",
543
+ path_in_repo=rf_filename,
544
+ repo_id=HF_MODEL_REPO,
545
+ repo_type="model"
546
+ )
547
+ print(f"✅ {rf_filename} uploaded.")
548
+
549
+ # 2. Causal Discovery
550
+ print("🕸️ Running Causal Discovery...")
551
+ causal_model = get_causal_model()
552
+ causal_model.fit(clean_df)
553
+
554
+ causal_filename = f"causal_discovery_{timestamp}.pkl"
555
+ with open(f"{MODEL_DIR}/{causal_filename}", 'wb') as f:
556
+ pickle.dump(causal_model, f)
557
+
558
+ api.upload_file(
559
+ path_or_fileobj=f"{MODEL_DIR}/{causal_filename}",
560
+ path_in_repo=causal_filename,
561
+ repo_id=HF_MODEL_REPO,
562
+ repo_type="model"
563
+ )
564
+ print(f"✅ {causal_filename} uploaded.")
565
+
566
+ status.complete("ClassicML & Causal")
567
+ try: history.log_model("ClassicML", "N/A", rf_filename)
568
+ except: pass
569
+ try: history.log_model("CausalDiscovery", "N/A", causal_filename)
570
+ except: pass
571
+ return True
572
+
573
+ except Exception as e:
574
+ status.error(f"Classic/Causal Fail: {e}")
575
+ print(f"❌ Classic/Causal Error: {e}")
576
+ return False
577
+
578
+ def train_agents(status: StatusWriter, api: HfApi, history):
579
+ """Initialize, Validate, and Save RL & Rule-Based Agents."""
580
+ print("⏳ Loading Agent Dependencies (Torch/RL)...")
581
+ import torch
582
+ from models.execution_agent import PPOActorCritic
583
+ from models.meta_controller import DQN
584
+ from models.risk_agent import RiskAgent
585
+ from models.arbitrage_agent import ArbitrageAgent
586
+
587
+ status.start("Agents (RL & Rule)", 1)
588
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
589
+
590
+ try:
591
+ # 1. Execution Agent (PPO)
592
+ print("🤖 Initializing Execution Agent (PPO)...")
593
+ exec_agent = PPOActorCritic(input_dim=5, action_dim=3)
594
+ dummy_in = torch.randn(1, 5)
595
+ exec_agent(dummy_in)
596
+
597
+ exec_filename = f"execution_agent_{timestamp}.pt"
598
+ torch.save(exec_agent.state_dict(), f"{MODEL_DIR}/{exec_filename}")
599
+ api.upload_file(
600
+ path_or_fileobj=f"{MODEL_DIR}/{exec_filename}",
601
+ path_in_repo=exec_filename,
602
+ repo_id=HF_MODEL_REPO,
603
+ repo_type="model"
604
+ )
605
+ print(f"✅ {exec_filename} uploaded.")
606
+
607
+ # 2. Meta Controller (DQN)
608
+ print("🧠 Initializing Meta Controller (DQN)...")
609
+ meta_agent = DQN(input_dim=5, output_dim=3)
610
+ meta_agent(dummy_in)
611
+
612
+ meta_filename = f"meta_controller_{timestamp}.pt"
613
+ torch.save(meta_agent.state_dict(), f"{MODEL_DIR}/{meta_filename}")
614
+ api.upload_file(
615
+ path_or_fileobj=f"{MODEL_DIR}/{meta_filename}",
616
+ path_in_repo=meta_filename,
617
+ repo_id=HF_MODEL_REPO,
618
+ repo_type="model"
619
+ )
620
+ print(f"✅ {meta_filename} uploaded.")
621
+
622
+ # 3. Risk Agent (Rule-Based)
623
+ print("🛡️ Initializing Risk Agent...")
624
+ risk_agent = RiskAgent(max_dd=0.15)
625
+ risk_filename = f"risk_agent_{timestamp}.pkl"
626
+ with open(f"{MODEL_DIR}/{risk_filename}", 'wb') as f:
627
+ pickle.dump(risk_agent, f)
628
+
629
+ api.upload_file(
630
+ path_or_fileobj=f"{MODEL_DIR}/{risk_filename}",
631
+ path_in_repo=risk_filename,
632
+ repo_id=HF_MODEL_REPO,
633
+ repo_type="model"
634
+ )
635
+ print(f"✅ {risk_filename} uploaded.")
636
+
637
+ # 4. Arbitrage Agent (Rule-Based)
638
+ print("⚖️ Initializing Arbitrage Agent...")
639
+ arb_agent = ArbitrageAgent(threshold=0.005)
640
+ # Synthetic Test
641
+ arb_agent.analyze(100, 101, 0.001)
642
+
643
+ arb_filename = f"arbitrage_agent_{timestamp}.pkl"
644
+ with open(f"{MODEL_DIR}/{arb_filename}", 'wb') as f:
645
+ pickle.dump(arb_agent, f)
646
+
647
+ api.upload_file(
648
+ path_or_fileobj=f"{MODEL_DIR}/{arb_filename}",
649
+ path_in_repo=arb_filename,
650
+ repo_id=HF_MODEL_REPO,
651
+ repo_type="model"
652
+ )
653
+ print(f"✅ {arb_filename} uploaded.")
654
+
655
+ status.complete("Agents Completed")
656
+ try:
657
+ history.log_model("ExecutionAgent", "Init", exec_filename)
658
+ history.log_model("MetaController", "Init", meta_filename)
659
+ history.log_model("RiskAgent", "Init", risk_filename)
660
+ history.log_model("ArbitrageAgent", "Init", arb_filename)
661
+ except: pass
662
+ return True
663
+
664
+ except Exception as e:
665
+ status.error(f"Agents Fail: {e}")
666
+ print(f"❌ Agents Error: {e}")
667
+ return False
668
+
669
+
670
+ def upload_models():
671
+ """Upload trained models to HuggingFace."""
672
+ from huggingface_hub import HfApi
673
+
674
+ api = HfApi()
675
+
676
+ for model_file in Path(MODEL_DIR).glob("*.onnx"):
677
+ try:
678
+ api.upload_file(
679
+ path_or_fileobj=str(model_file),
680
+ path_in_repo=model_file.name,
681
+ repo_id=HF_MODEL_REPO,
682
+ repo_type="model"
683
+ )
684
+ print(f"Uploaded: {model_file.name}")
685
+ except Exception as e:
686
+ print(f"Upload failed for {model_file.name}: {e}")
687
+
688
+ def main():
689
+ print("🚀 Auto-Training Pipeline Started...")
690
+
691
+ # Initialize API for model uploads
692
+ token = os.environ.get("HF_TOKEN") # Get token from env
693
+ if not token:
694
+ print("⚠️ HF_TOKEN not found in environment!")
695
+
696
+ api = HfApi(token=token)
697
+
698
+ # Init Status Writer & History Writer
699
+ status = StatusWriter()
700
+ history = HistoryWriter()
701
+
702
+
703
+ print("=" * 50)
704
+ print("NautilusAI Auto Training Pipeline")
705
+ print(f"Started at: {datetime.now().isoformat()}")
706
+ print("=" * 50)
707
+
708
+ # 1. Download data
709
+ # 1. Download data - SKIPPED (Using Streaming)
710
+ # download_data()
711
+ print("🌊 Using Streaming Mode (No Download Required)")
712
+
713
+ # Force Legacy ONNX
714
+ os.environ["TORCH_ONNX_USE_DYNAMO"] = "0"
715
+
716
+
717
+ # 2. Train DeepLOB
718
+ if not train_deeplob(status, api, history, epochs=1):
719
+ print("DeepLOB training failed!")
720
+ return
721
+
722
+ # 3. Train TRM
723
+ if not train_trm(status, api, history, epochs=1):
724
+ print("TRM training failed!")
725
+ # Continue anyway for LSTM
726
+ pass
727
+
728
+ # 4. Train LSTM
729
+ if not train_lstm(status, api, history, epochs=1):
730
+ print("LSTM training failed!")
731
+ pass
732
+
733
+ # 5. Train Classic ML & Causal
734
+ train_classic_and_causal(status, api, history)
735
+
736
+ # 6. Train Agents
737
+ train_agents(status, api, history)
738
+
739
+ # 7. Upload models (Legacy function, mostly redundant now but harmless)
740
+ upload_models()
741
+
742
+ # 8. Upload History (Batch Upload)
743
+ history.upload_history()
744
+
745
+ # 5. Final status
746
+ status.reset()
747
+ print("=" * 50)
748
+ print("Training Pipeline Complete!")
749
+ print("=" * 50)
750
+
751
+ if __name__ == "__main__":
752
+ try:
753
+ print("🔧 Auto-Train Script Initializing...", flush=True)
754
+ main()
755
+ except Exception as e:
756
+ import traceback
757
+ print(f"❌ CRITICAL ERROR IN AUTO-TRAIN: {e}", flush=True)
758
+ traceback.print_exc()
759
+ # Write error to status file too if possible
760
+ try:
761
+ with open(STATUS_FILE, 'w') as f:
762
+ json.dump({"status": "error", "logs": [f"CRITICAL: {str(e)}"]}, f)
763
+ except: pass
data_processor.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import glob
4
+ import torch
5
+ import ast
6
+ from typing import Tuple
7
+
8
+ class AlphaDataProcessor:
9
+ """
10
+ Processes raw market data (Parquet) into PyTorch Tensors for Alpha Agent training.
11
+ Upgraded for Deep Optimization (Robust Scaler, Dynamic Labels, Channel Separation, OFI, Triple Barrier).
12
+ """
13
+ def __init__(self, data_dir: str = "./data"):
14
+ self.data_dir = data_dir
15
+
16
+ def _rolling_robust_scale(self, data: np.ndarray, window: int = 2000) -> np.ndarray:
17
+ """
18
+ Rolling Robust Scaling using Median and IQR.
19
+ Prevents look-ahead bias (Leakage) by using only past statistics.
20
+ Computes rolling median/IQR along axis 0.
21
+ """
22
+ # Convert to DataFrame for efficient rolling ops
23
+ df = pd.DataFrame(data)
24
+
25
+ # Min periods = window/10 to avoid NaNs at start (or ffill)
26
+ rolling = df.rolling(window=window, min_periods=window//10)
27
+
28
+ median = rolling.median()
29
+ q75 = rolling.quantile(0.75)
30
+ q25 = rolling.quantile(0.25)
31
+ iqr = q75 - q25
32
+
33
+ # Replace 0 IQR with 1 to avoid div by zero
34
+ iqr = iqr.replace(0, 1.0)
35
+
36
+ # Scale: (x_t - median_t) / iqr_t
37
+ # Note: robust scaling conventionally uses recent stats to normalize CURRENT value.
38
+ scaled = (df - median) / iqr
39
+
40
+ # Fill mean/zeros for initial unstable window
41
+ return scaled.fillna(0.0).values
42
+
43
+ def get_deeplob_tensors(self, coin: str = "ETH", T: int = 100, levels: int = 20) -> Tuple[torch.Tensor, torch.Tensor]:
44
+ """
45
+ DeepLOB with Channel Separation and Triple Barrier Labeling.
46
+ Uses Rolling Robust Scaling.
47
+ """
48
+ df = self.load_l2_snapshots(coin)
49
+ if df.empty:
50
+ return self._generate_dummy_deeplob(T, levels)
51
+
52
+ prices_list = []
53
+ volumes_list = []
54
+ mid_prices = []
55
+
56
+ # Precompute Volatility for Labeling
57
+ best_bids = df['bids'].apply(lambda x: x[0][0] if len(x)>0 else 0)
58
+ best_asks = df['asks'].apply(lambda x: x[0][0] if len(x)>0 else 0)
59
+ mids = (best_bids + best_asks) / 2
60
+ mids = mids.replace(0, np.nan).ffill().fillna(0)
61
+
62
+ returns = np.diff(np.log(mids.values + 1e-9))
63
+ returns = np.concatenate(([0], returns))
64
+ volatility = pd.Series(returns).rolling(window=T).std().fillna(0.001).values
65
+
66
+ mid_prices_arr = mids.values
67
+
68
+ for _, row in df.iterrows():
69
+ bids = row['bids']
70
+ asks = row['asks']
71
+
72
+ p_feat = []
73
+ v_feat = []
74
+
75
+ for i in range(levels):
76
+ if i < len(asks): pa, va = asks[i]
77
+ else: pa, va = 0, 0
78
+ if i < len(bids): pb, vb = bids[i]
79
+ else: pb, vb = 0, 0
80
+ p_feat.extend([pa, pb])
81
+ v_feat.extend([va, vb])
82
+
83
+ prices_list.append(p_feat)
84
+ volumes_list.append(v_feat)
85
+
86
+ prices_data = np.array(prices_list)
87
+ volumes_data = np.array(volumes_list)
88
+
89
+ # Rolling Robust Scaling (Leakage Free)
90
+ prices_data = self._rolling_robust_scale(prices_data, window=2000)
91
+ volumes_data = np.log1p(volumes_data)
92
+ volumes_data = self._rolling_robust_scale(volumes_data, window=2000)
93
+
94
+ k = 100
95
+ # Triple Barrier Labels
96
+ # PT=2, SL=2 (2x Volatility)
97
+ y_all = self._get_triple_barrier_labels(mid_prices_arr, T, k, volatility, pt=2.0, sl=2.0)
98
+
99
+ # ... (Rest remains same)
100
+
101
+
102
+ def _get_triple_barrier_labels(self, mid_prices: np.ndarray, T: int, horizon: int, volatility: np.ndarray = None, pt: float = 1.0, sl: float = 1.0) -> np.ndarray:
103
+ """
104
+ Triple Barrier Labeling Method (Marcos Lopez de Prado).
105
+ Labels: 0 (SL Hit), 1 (Time Limit), 2 (TP Hit).
106
+ pt: Profit Taking multiplier (x Volatility).
107
+ sl: Stop Loss multiplier (x Volatility).
108
+ """
109
+ labels = []
110
+
111
+ # If volatility is None, compute standard
112
+ if volatility is None:
113
+ # Simple fallback
114
+ volatility = np.ones(len(mid_prices)) * 0.002
115
+
116
+ for i in range(T, len(mid_prices) - horizon):
117
+ current_price = mid_prices[i-1]
118
+ vol = volatility[i]
119
+
120
+ # Dynamic Barriers
121
+ upper_barrier = current_price * (1 + vol * pt)
122
+ lower_barrier = current_price * (1 - vol * sl)
123
+
124
+ # Path within Horizon
125
+ path = mid_prices[i : i + horizon]
126
+
127
+ # Check First Touch
128
+ # argmax returns index of first True
129
+ touch_upper = np.where(path >= upper_barrier)[0]
130
+ touch_lower = np.where(path <= lower_barrier)[0]
131
+
132
+ t_upper = touch_upper[0] if len(touch_upper) > 0 else horizon + 1
133
+ t_lower = touch_lower[0] if len(touch_lower) > 0 else horizon + 1
134
+
135
+ if t_upper == horizon + 1 and t_lower == horizon + 1:
136
+ label = 1 # Vertical Barrier (Time Limit)
137
+ elif t_upper < t_lower:
138
+ label = 2 # TP Hit First
139
+ else:
140
+ label = 0 # SL Hit First
141
+
142
+ labels.append(label)
143
+
144
+ return np.array(labels)
145
+
146
+ def _compute_ofi(self, df: pd.DataFrame, levels: int = 5) -> pd.DataFrame:
147
+ """
148
+ Computes Order Flow Imbalance (OFI) for top 'levels'.
149
+ OFI_i(t) = I(P > P_prev)q - I(P < P_prev)q_prev + I(P == P_prev)(q - q_prev)
150
+ Summed across levels.
151
+ """
152
+ # Explode bids/asks for first few levels
153
+ # This is expensive on large DFs. We do vectorized check on top 1 level mainly or aggregated.
154
+ # Efficient OFI: Compute on Best Bid/Ask only for speed in this version.
155
+
156
+ # 1. Shift DataFrame
157
+ df_prev = df.shift(1)
158
+
159
+ ofi = pd.Series(0.0, index=df.index)
160
+
161
+ # Top 1 Level OFI
162
+ bb_p = df['best_bid']
163
+ bb_q = df['best_bid_sz']
164
+ prev_bb_p = df_prev['best_bid']
165
+ prev_bb_q = df_prev['best_bid_sz']
166
+
167
+ ba_p = df['best_ask']
168
+ ba_q = df['best_ask_sz']
169
+ prev_ba_p = df_prev['best_ask']
170
+ prev_ba_q = df_prev['best_ask_sz']
171
+
172
+ # Bid OFI
173
+ bid_ofi = np.where(bb_p > prev_bb_p, bb_q,
174
+ np.where(bb_p < prev_bb_p, -prev_bb_q, bb_q - prev_bb_q))
175
+
176
+ # Ask OFI (Note: Supply side usually negative impact on price? OFI definition:
177
+ # e_i = e_bid_i - e_ask_i. High Bid demand -> +, High Ask supply -> -)
178
+
179
+ ask_ofi = np.where(ba_p > prev_ba_p, -prev_ba_q,
180
+ np.where(ba_p < prev_ba_p, ba_q, ba_q - prev_ba_q)) # Logic check needed here
181
+
182
+ # Standard Definition (Cont & Kukanov 2017):
183
+ # e_ask = I(Pa > Pa_prev) * (-qa_prev) + I(Pa < Pa_prev) * qa + I(Pa=Pa_prev)*(qa - qa_prev)
184
+ # Wait, if Ask Price Increases -> Supply removed (Good for price) -> ???
185
+ # Actually OFI = Flow at Bid - Flow at Ask.
186
+ # Let's stick to standard formula for 'Flow Contribution to Price Increase'.
187
+ # Increase in Ask Size -> Resistance -> Negative pressure.
188
+
189
+ ask_flow = np.where(ba_p > prev_ba_p, 0, # Price moved up (Ask Cleared?) -> No resistance added?
190
+ np.where(ba_p < prev_ba_p, ba_q, # Price moved down -> New wall
191
+ ba_q - prev_ba_q)) # Same price -> delta size
192
+
193
+ # Improved Ask OFI (Mirroring Bid Logic):
194
+ # We want "Buying Pressure" - "Selling Pressure"
195
+ # Bid Increase/Add = Buying Pressure (+)
196
+ # Ask Decrease/Add = Selling Pressure (-)
197
+
198
+ ask_ofi = np.where(ba_p > prev_ba_p, -prev_ba_q, # Price rose, prev qty consumed/cancelled ?
199
+ np.where(ba_p < prev_ba_p, ba_q, # Price fell, new supply at lower price
200
+ ba_q - prev_ba_q)) # Same price, delta
201
+
202
+ # Total OFI
203
+ ofi = bid_ofi - ask_ofi
204
+ return pd.Series(ofi).fillna(0)
205
+
206
+ def load_trades(self, coin: str = "ETH") -> pd.DataFrame:
207
+ """Loads trade data."""
208
+ files = glob.glob(f"{self.data_dir}/raw_trade/{coin}/*.parquet")
209
+ if not files: return pd.DataFrame()
210
+
211
+ try:
212
+ df = pd.concat([pd.read_parquet(f) for f in files])
213
+ df = df.sort_values("time")
214
+ if 'side' in df.columns:
215
+ df['signed_vol'] = df.apply(lambda x: x['sz'] if x['side'] == 'B' else -x['sz'], axis=1)
216
+ else:
217
+ df['signed_vol'] = 0
218
+ return df
219
+ except Exception as e:
220
+ print(f"Error loading trades: {e}")
221
+ return pd.DataFrame()
222
+
223
+ def load_l2_snapshots(self, coin: str = "ETH", limit: int = 10000) -> pd.DataFrame:
224
+ """Loads L2 Orderbook Snapshots."""
225
+ files = glob.glob(f"{self.data_dir}/order_book_snapshot/*.parquet")
226
+ if not files: return pd.DataFrame()
227
+
228
+ df_list = []
229
+ for f in files:
230
+ try:
231
+ chunk = pd.read_parquet(f)
232
+ chunk = chunk[chunk['instrument_id'].str.contains(coin)]
233
+ if not chunk.empty: df_list.append(chunk)
234
+ except: pass
235
+
236
+ if not df_list: return pd.DataFrame()
237
+
238
+ df = pd.concat(df_list)
239
+ df = df.sort_values("ts_event").head(limit)
240
+
241
+ df['bids'] = df['bids'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else [])
242
+ df['asks'] = df['asks'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else [])
243
+
244
+ return df
245
+
246
+ def get_deeplob_tensors(self, coin: str = "ETH", T: int = 100, levels: int = 20) -> Tuple[torch.Tensor, torch.Tensor]:
247
+ """
248
+ DeepLOB with Channel Separation and Triple Barrier Labeling.
249
+ """
250
+ df = self.load_l2_snapshots(coin)
251
+ if df.empty:
252
+ return self._generate_dummy_deeplob(T, levels)
253
+
254
+ prices_list = []
255
+ volumes_list = []
256
+ mid_prices = []
257
+
258
+ # Precompute Volatility for Labeling
259
+ # Expand Mid Price first
260
+ best_bids = df['bids'].apply(lambda x: x[0][0] if len(x)>0 else 0)
261
+ best_asks = df['asks'].apply(lambda x: x[0][0] if len(x)>0 else 0)
262
+ mids = (best_bids + best_asks) / 2
263
+ mids = mids.replace(0, np.nan).ffill().fillna(0)
264
+
265
+ # Rolling Volatility (for Triple Barrier)
266
+ returns = np.diff(np.log(mids.values + 1e-9))
267
+ returns = np.concatenate(([0], returns))
268
+ volatility = pd.Series(returns).rolling(window=T).std().fillna(0.001).values
269
+
270
+ mid_prices_arr = mids.values
271
+
272
+ for _, row in df.iterrows():
273
+ bids = row['bids']
274
+ asks = row['asks']
275
+
276
+ p_feat = []
277
+ v_feat = []
278
+
279
+ for i in range(levels):
280
+ if i < len(asks): pa, va = asks[i]
281
+ else: pa, va = 0, 0
282
+ if i < len(bids): pb, vb = bids[i]
283
+ else: pb, vb = 0, 0
284
+ p_feat.extend([pa, pb])
285
+ v_feat.extend([va, vb])
286
+
287
+ prices_list.append(p_feat)
288
+ volumes_list.append(v_feat)
289
+
290
+
291
+ prices_data = np.array(prices_list)
292
+ volumes_data = np.array(volumes_list)
293
+
294
+ # Robust Scaling
295
+ prices_data = self._robust_scale(prices_data)
296
+ volumes_data = np.log1p(volumes_data)
297
+ volumes_data = self._robust_scale(volumes_data)
298
+
299
+ k = 100
300
+ # Triple Barrier Labels
301
+ # PT=2, SL=2 (2x Volatility)
302
+ y_all = self._get_triple_barrier_labels(mid_prices_arr, T, k, volatility, pt=2.0, sl=2.0)
303
+
304
+ X = []
305
+ y = []
306
+ valid_indices = range(T, len(mid_prices_arr) - k)
307
+
308
+ for idx, i in enumerate(valid_indices):
309
+ p_window = prices_data[i-T:i]
310
+ v_window = volumes_data[i-T:i]
311
+
312
+ sample = np.stack([p_window, v_window], axis=0) # (2, T, 2*Levels)
313
+
314
+ X.append(sample)
315
+ y.append(y_all[idx])
316
+
317
+ return torch.FloatTensor(np.array(X)), torch.LongTensor(np.array(y))
318
+
319
+ def get_deeplob_tensors_from_df(self, df: pd.DataFrame, T: int = 100, levels: int = 20) -> Tuple[torch.Tensor, torch.Tensor]:
320
+ """
321
+ Process a pre-loaded DataFrame (chunk) into DeepLOB tensors.
322
+ Used for Streaming.
323
+ """
324
+ if df.empty:
325
+ return torch.empty(0), torch.empty(0)
326
+
327
+ # Reuse the logic from get_deeplob_tensors, but skipping the load step.
328
+ # This duplicates some logic but ensures isolation.
329
+
330
+ prices_list = []
331
+ volumes_list = []
332
+
333
+ # Precompute Volatility for Labeling
334
+ best_bids = df['bids'].apply(lambda x: x[0][0] if len(x)>0 else 0)
335
+ best_asks = df['asks'].apply(lambda x: x[0][0] if len(x)>0 else 0)
336
+ mids = (best_bids + best_asks) / 2
337
+ mids = mids.replace(0, np.nan).ffill().fillna(0)
338
+
339
+ returns = np.diff(np.log(mids.values + 1e-9))
340
+ returns = np.concatenate(([0], returns))
341
+ volatility = pd.Series(returns).rolling(window=T).std().fillna(0.001).values
342
+
343
+ mid_prices_arr = mids.values
344
+
345
+ for _, row in df.iterrows():
346
+ bids = row['bids']
347
+ asks = row['asks']
348
+
349
+ p_feat = []
350
+ v_feat = []
351
+
352
+ for i in range(levels):
353
+ if i < len(asks): pa, va = asks[i]
354
+ else: pa, va = 0, 0
355
+ if i < len(bids): pb, vb = bids[i]
356
+ else: pb, vb = 0, 0
357
+ p_feat.extend([pa, pb])
358
+ v_feat.extend([va, vb])
359
+
360
+ prices_list.append(p_feat)
361
+ volumes_list.append(v_feat)
362
+
363
+ prices_data = np.array(prices_list)
364
+ volumes_data = np.array(volumes_list)
365
+
366
+ # Robust Scaling
367
+ prices_data = self._robust_scale(prices_data)
368
+ volumes_data = np.log1p(volumes_data)
369
+ volumes_data = self._robust_scale(volumes_data)
370
+
371
+ k = 100
372
+ # Triple Barrier Labels
373
+ y_all = self._get_triple_barrier_labels(mid_prices_arr, T, k, volatility, pt=2.0, sl=2.0)
374
+
375
+ X = []
376
+ y = []
377
+
378
+ # Since this is a chunk, we might lose the first T rows if not buffered correctly by the caller.
379
+ # The caller (StreamingDataLoader) is responsible for overlapping chunks.
380
+ valid_indices = range(T, len(mid_prices_arr) - k)
381
+
382
+ for idx, i in enumerate(valid_indices):
383
+ p_window = prices_data[i-T:i]
384
+ v_window = volumes_data[i-T:i]
385
+
386
+ sample = np.stack([p_window, v_window], axis=0)
387
+ X.append(sample)
388
+ y.append(y_all[idx])
389
+
390
+ return torch.FloatTensor(np.array(X)), torch.LongTensor(np.array(y))
391
+
392
+ def _generate_dummy_deeplob(self, T, levels):
393
+ return torch.randn(32, 2, T, 2*levels), torch.randint(0, 3, (32,))
394
+
395
+ def compute_trm_features(self, df: pd.DataFrame) -> pd.DataFrame:
396
+ """
397
+ Computes features including OFI and Real CVD.
398
+ """
399
+ df['best_bid'] = df['bids'].apply(lambda x: x[0][0] if len(x)>0 else np.nan)
400
+ df['best_ask'] = df['asks'].apply(lambda x: x[0][0] if len(x)>0 else np.nan)
401
+ df['best_bid_sz'] = df['bids'].apply(lambda x: x[0][1] if len(x)>0 else np.nan)
402
+ df['best_ask_sz'] = df['asks'].apply(lambda x: x[0][1] if len(x)>0 else np.nan)
403
+
404
+ df.dropna(subset=['best_bid', 'best_ask'], inplace=True)
405
+
406
+ df['mid'] = (df['best_bid'] + df['best_ask']) / 2
407
+
408
+ # OFI (New Feature)
409
+ df['ofi'] = self._compute_ofi(df)
410
+
411
+ df['spread'] = df['best_ask'] - df['best_bid']
412
+ df['imbalance'] = (df['best_bid_sz'] - df['best_ask_sz']) / (df['best_bid_sz'] + df['best_ask_sz'])
413
+ df['momentum'] = df['mid'].pct_change(periods=5)
414
+ df['returns'] = df['mid'].pct_change()
415
+ df['volatility'] = df['returns'].rolling(10).std()
416
+
417
+ # Real CVD
418
+ trades = self.load_trades(coin="ETH")
419
+ if not trades.empty:
420
+ trades['cumulative_vol'] = trades['signed_vol'].cumsum()
421
+ df = df.sort_values("ts_event")
422
+ trades = trades.sort_values("time")
423
+
424
+ df['ts_merge'] = df['ts_event']
425
+ trades['ts_merge'] = trades['time']
426
+
427
+ merged = pd.merge_asof(df, trades[['ts_merge', 'cumulative_vol']], on='ts_merge', direction='backward')
428
+ df['cvd'] = merged['cumulative_vol'].ffill().fillna(0)
429
+ else:
430
+ df['cvd'] = 0
431
+
432
+ df.dropna(inplace=True)
433
+ # Return 6 Features now: Vol, Imbal, CVD, Spread, Mom, OFI
434
+ return df[['volatility', 'imbalance', 'cvd', 'spread', 'momentum', 'ofi', 'mid']]
435
+
436
+ def get_trm_tensors(self, coin: str = "ETH", T: int = 60) -> Tuple[torch.Tensor, torch.Tensor]:
437
+ """
438
+ Returns TRM Tensors.
439
+ Input size = 6 (Added OFI).
440
+ Labels = Triple Barrier.
441
+ """
442
+ df = self.load_l2_snapshots(coin, limit=5000)
443
+ if df.empty:
444
+ return torch.randn(32, T, 6), torch.randint(0, 3, (32,))
445
+
446
+ feat_df = self.compute_trm_features(df)
447
+ data = feat_df[['volatility', 'imbalance', 'cvd', 'spread', 'momentum', 'ofi']].values
448
+ mid = feat_df['mid'].values
449
+
450
+ # Rolling Robust Scale Features (Leakage Free)
451
+ data = self._rolling_robust_scale(data, window=2000)
452
+
453
+ # Returns for Vol calc
454
+ rets = np.diff(np.log(mid + 1e-9))
455
+ rets = np.concatenate(([0], rets))
456
+ vol = pd.Series(rets).rolling(window=T).std().fillna(0.001).values
457
+
458
+ # Triple Barrier Labels for TRM
459
+ y_all = self._get_triple_barrier_labels(mid, T, horizon=60, volatility=vol, pt=2.0, sl=2.0)
460
+
461
+ X, y = [], []
462
+ valid_indices = range(T, len(data) - 60)
463
+
464
+ for idx, i in enumerate(valid_indices):
465
+ X.append(data[i-T:i])
466
+ y.append(y_all[idx])
467
+
468
+ return torch.FloatTensor(np.array(X)), torch.LongTensor(np.array(y))
469
+
470
+ def get_trm_tensors_from_df(self, df: pd.DataFrame, T: int = 60) -> Tuple[torch.Tensor, torch.Tensor]:
471
+ """
472
+ Process a pre-loaded DataFrame (chunk) into TRM tensors.
473
+ Used for Streaming.
474
+ """
475
+ if df.empty:
476
+ return torch.empty(0), torch.empty(0)
477
+
478
+ feat_df = self.compute_trm_features(df)
479
+ if feat_df.empty:
480
+ return torch.empty(0), torch.empty(0)
481
+
482
+ data = feat_df[['volatility', 'imbalance', 'cvd', 'spread', 'momentum', 'ofi']].values
483
+ mid = feat_df['mid'].values
484
+
485
+ data = self._rolling_robust_scale(data, window=2000)
486
+
487
+ rets = np.diff(np.log(mid + 1e-9))
488
+ rets = np.concatenate(([0], rets))
489
+ vol = pd.Series(rets).rolling(window=T).std().fillna(0.001).values
490
+
491
+ y_all = self._get_triple_barrier_labels(mid, T, horizon=60, volatility=vol, pt=2.0, sl=2.0)
492
+
493
+ X, y = [], []
494
+ valid_indices = range(T, len(data) - 60)
495
+
496
+ for idx, i in enumerate(valid_indices):
497
+ X.append(data[i-T:i])
498
+ y.append(y_all[idx])
499
+
500
+ return torch.FloatTensor(np.array(X)), torch.LongTensor(np.array(y))
501
+
502
+ def get_lstm_tensors_from_df(self, df: pd.DataFrame, T: int = 60, forecast_horizon: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
503
+ """
504
+ Process Bar Data (OHLCV) into LSTM Tensors.
505
+ Features: Log Returns, Log Volume, High-Low Range, Close-Open Range.
506
+ Target: Next Log Return (scaled).
507
+ Output: X (Batch, T, Features), y (Batch, 1)
508
+ """
509
+ if df.empty or len(df) < T + forecast_horizon:
510
+ return torch.empty(0), torch.empty(0)
511
+
512
+ # Ensure numeric
513
+ cols = ['open', 'high', 'low', 'close', 'volume']
514
+ for c in cols:
515
+ if c in df.columns:
516
+ df[c] = pd.to_numeric(df[c], errors='coerce')
517
+ df.dropna(subset=cols, inplace=True)
518
+
519
+ # 1. Feature Engineering
520
+ # Log Returns (Scale Invariant)
521
+ df['log_ret'] = np.log(df['close'] / df['close'].shift(1)).fillna(0)
522
+
523
+ # Log Volume
524
+ df['log_vol'] = np.log1p(df['volume'])
525
+
526
+ # High-Low Range (Relative to Close)
527
+ df['hl_range'] = (df['high'] - df['low']) / df['close']
528
+
529
+ # Close-Open Range (Relative to Open)
530
+ df['co_range'] = (df['close'] - df['open']) / df['open']
531
+
532
+ # Rolling Volatility (Feature)
533
+ df['volatility'] = df['log_ret'].rolling(window=20).std().fillna(0)
534
+
535
+ # Features Matrix
536
+ feature_cols = ['log_ret', 'log_vol', 'hl_range', 'co_range', 'volatility']
537
+ data = df[feature_cols].values
538
+
539
+ # 2. Robust Scaling (Leakage Free)
540
+ data = self._rolling_robust_scale(data, window=2000)
541
+
542
+ # 3. Target: Next Log Return (Scalar Regression)
543
+ # Scaled by 100 to match Tanh output range [-1, 1] for typical volatility
544
+ # e.g. 1% move = 0.01 * 100 = 1.0
545
+ target = df['log_ret'].shift(-forecast_horizon).fillna(0).values * 100
546
+
547
+ X, y = [], []
548
+ valid_indices = range(T, len(data) - forecast_horizon)
549
+
550
+ for i in valid_indices:
551
+ window = data[i-T:i] # (T, Features)
552
+ label = target[i] # (1,)
553
+
554
+ X.append(window)
555
+ y.append(label)
556
+
557
+ return torch.FloatTensor(np.array(X)), torch.FloatTensor(np.array(y)).unsqueeze(1)
558
+
559
+ def _robust_scale(self, data):
560
+ # Helper for legacy robust scale (non-rolling) if needed,
561
+ # or alias to rolling with large window for batch
562
+ # For now, implementing simple robust scale
563
+ median = np.median(data, axis=0)
564
+ q75 = np.percentile(data, 75, axis=0)
565
+ q25 = np.percentile(data, 25, axis=0)
566
+ iqr = q75 - q25
567
+ iqr[iqr == 0] = 1.0
568
+ return (data - median) / iqr
debug_causal.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import pickle
4
+ import os
5
+ import sys
6
+
7
+ # Add path to find models
8
+ sys.path.append(os.getcwd())
9
+
10
+ from models.causal_discovery import get_causal_model
11
+
12
+ # Mock Data
13
+ df = pd.DataFrame(np.random.randn(100, 5), columns=['open', 'high', 'low', 'close', 'volume'])
14
+ clean_df = df
15
+
16
+ print("Init Causal...")
17
+ model = get_causal_model()
18
+
19
+ print("Fit Causal...")
20
+ model.fit(clean_df)
21
+
22
+ print("Pickle Causal...")
23
+ with open("causal_debug.pkl", "wb") as f:
24
+ pickle.dump(model, f)
25
+
26
+ print("✅ Success")
debug_logic.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sys
3
+ import types
4
+ from unittest.mock import MagicMock
5
+
6
+ # 1. Mock PyTorch Class BEFORE importing data_processor
7
+ class MockTorch:
8
+ def __init__(self):
9
+ self.float32 = "float32"
10
+ self.long = "long"
11
+
12
+ # Mock Tensor class for type hints
13
+ class MockTensor:
14
+ pass
15
+ self.Tensor = MockTensor
16
+
17
+ def FloatTensor(self, x): return x # Return numpy array directly
18
+ def LongTensor(self, x): return x
19
+ def tensor(self, x): return x
20
+ def randn(self, *args):
21
+ import numpy as np
22
+ return np.random.randn(*args)
23
+ def randint(self, low, high, size):
24
+ import numpy as np
25
+ return np.random.randint(low, high, size)
26
+ def stack(self, tensors, axis=0):
27
+ import numpy as np
28
+ return np.stack(tensors, axis=axis)
29
+
30
+ # 2. Inject into sys.modules
31
+ mock_torch = MockTorch()
32
+ sys.modules['torch'] = mock_torch
33
+
34
+ # 3. Now safe to import
35
+ import numpy as np
36
+ import pandas as pd
37
+ import os
38
+
39
+ # Add parent dir
40
+ sys.path.append(os.getcwd())
41
+
42
+ from data_processor import AlphaDataProcessor
43
+
44
+ def test_logic():
45
+ print("✅ Successfully imported data_processor with Mock Torch")
46
+
47
+ # 1. Simulate get_deeplob_tensors return
48
+ N, T, Levels = 10, 100, 20
49
+ # DataProcessor returns:
50
+ # X = stack([p_window, v_window], axis=0) -> (2, T, 2*Levels) for each sample
51
+ # Then np.array(X) -> (N, 2, T, 2*Levels)
52
+
53
+ # Mock simulating what data_processor actually builds
54
+ # (N samples, 2 channels, T timesteps, 2*Levels features)
55
+ mock_X = np.random.randn(N, 2, T, 2*Levels)
56
+
57
+ print(f"Mock DataProcessor Output Shape: {mock_X.shape}")
58
+
59
+ # 2. Verify DeepLOB Input Requirement
60
+ # DeepLOB Conv2d(2, 16, ...) expects (N, C, H, W) = (N, 2, T, Features)
61
+ # Our features = 2*Levels = 40
62
+
63
+ # If the shape is 4D: (N, 2, 100, 40) -> IT IS CORRECT
64
+ # If we unsqueeze(1) -> (N, 1, 2, 100, 40) -> 5D -> INCORRECT
65
+
66
+ if mock_X.ndim == 4 and mock_X.shape[1] == 2:
67
+ print("✅ Data Shape matches Conv2d Expectation (N, 2, T, F)")
68
+ print(" -> (N, Channels=2, Height=100, Width=40)")
69
+ print(" -> NO unsqueeze(1) needed!")
70
+ else:
71
+ print(f"❌ Data Shape Mismatch: {mock_X.shape}")
72
+
73
+ # 3. Verify TRM logic
74
+ # TRM needs (N, T, F)
75
+ # F = 6 (Vol, Imb, CVD, Spr, Mom, OFI)
76
+ N_trm, T_trm, F_trm = 10, 60, 6
77
+ mock_X_trm = np.random.randn(N_trm, T_trm, F_trm)
78
+
79
+ print(f"\nMock TRM Output Shape: {mock_X_trm.shape}")
80
+ if mock_X_trm.ndim == 3 and mock_X_trm.shape[2] == 6:
81
+ print("✅ TRM Shape matches Transformer Expectation (N, T, F)")
82
+ else:
83
+ print(f"❌ TRM Shape Mismatch: {mock_X_trm.shape}")
84
+
85
+ if __name__ == "__main__":
86
+ test_logic()
envs/nautilus_env.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium import spaces
3
+ import numpy as np
4
+ # from nautilus_trader.backtest.engine import BacktestEngine, BacktestEngineConfig
5
+ # from nautilus_trader.model.data import BarType
6
+ # from nautilus_trader.config import InstrumentId
7
+
8
+ # NOTE: NautilusTrader imports commented out - this is a stub environment
9
+ # Uncomment when implementing full RL environment
10
+
11
+ class NautilusExecutionEnv(gym.Env):
12
+ """
13
+ OpenAI/Gymnasium Environment for NautilusTrader.
14
+ Wraps the BacktestEngine to provide a step-by-step Interface for RL Agents.
15
+ """
16
+ def __init__(self, config: dict):
17
+ super().__init__()
18
+ self.config = config
19
+ self.instrument_id = config.get("instrument_id", "ETH-USDC-PERP")
20
+
21
+ # Define Observation Space (Features)
22
+ # Example: [RSI, Imbalance, Spread, PositionSize, PnL]
23
+ self.observation_space = spaces.Box(
24
+ low=-np.inf, high=np.inf, shape=(10,), dtype=np.float32
25
+ )
26
+
27
+ # Define Action Space
28
+ # Example: Discrete(3) -> 0: Hold, 1: Buy, 2: Sell
29
+ # Or Continuous Box for Order Size / Price Offset
30
+ self.action_space = spaces.Discrete(3)
31
+
32
+ self.engine = None
33
+ self._setup_engine()
34
+
35
+ def _setup_engine(self):
36
+ """
37
+ Initializes the Nautilus Backtest Engine
38
+ """
39
+ # TODO: Load Data Catalog here
40
+ # engine_config = BacktestEngineConfig(strategies=[...])
41
+ # self.engine = BacktestEngine(config=engine_config)
42
+ pass
43
+
44
+ def reset(self, seed=None, options=None):
45
+ """
46
+ Resets the environment (restarts the backtest).
47
+ """
48
+ super().reset(seed=seed)
49
+ self._setup_engine()
50
+
51
+ # Get initial state
52
+ initial_obs = np.zeros(self.observation_space.shape, dtype=np.float32)
53
+ info = {}
54
+ return initial_obs, info
55
+
56
+ def step(self, action):
57
+ """
58
+ Advances the engine by one step (Bar/Tick).
59
+ Execute action -> engine.step() -> Get Reward -> Get Next State
60
+ """
61
+ # 1. Translate Action to Order
62
+ # if action == 1: self.strategy.buy()
63
+
64
+ # 2. Step Engine
65
+ # self.engine.run_next_step() ???
66
+ # Note: Nautilus is event-driven, not strictly step-based.
67
+ # We need to run the engine until the next 'decision point' (e.g. next bar).
68
+
69
+ # 3. Calculate Reward (PnL Change)
70
+ reward = 0.0
71
+
72
+ # 4. Get New State
73
+ obs = np.zeros(self.observation_space.shape, dtype=np.float32)
74
+ terminated = False
75
+ truncated = False
76
+ info = {}
77
+
78
+ return obs, reward, terminated, truncated, info
models/arbitrage_agent.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ArbitrageAgent:
2
+ """
3
+ Arbitrage Agent (Rule-Based).
4
+ Monitors Perp vs Spot prices.
5
+ """
6
+ def __init__(self, threshold=0.005):
7
+ self.threshold = threshold
8
+
9
+ def analyze(self, spot_price, perp_price, funding_rate):
10
+ """
11
+ Returns Action:
12
+ 0: Do Nothing
13
+ 1: Long Spot / Short Perp (Basis > Thresh)
14
+ 2: Short Spot / Long Perp (Basis < -Thresh)
15
+ """
16
+ basis = (perp_price - spot_price) / spot_price
17
+
18
+ # Funding Arbitrage
19
+ # If funding positive -> Shorts pay Longs. We want to be Short Perp.
20
+ # So we Long Spot, Short Perp.
21
+
22
+ if basis > self.threshold:
23
+ print(f"Arb Opportunity: Basis {basis:.4f} > {self.threshold}. Action: Long Spot / Short Perp")
24
+ return 1 # Cash and Carry
25
+
26
+ if basis < -self.threshold:
27
+ print(f"Arb Opportunity: Basis {basis:.4f} < -{self.threshold}. Action: Short Spot / Long Perp")
28
+ return 2 # Reverse Carry?
29
+
30
+ return 0
models/arbitrage_agent_20251210_155924.pkl ADDED
Binary file (84 Bytes). View file
 
models/causal_discovery.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ # Wrappers for Tigramite (PCMCI)
4
+ # Note: Tigramite must be installed in environment
5
+ try:
6
+ from tigramite import data_processing as pp
7
+ from tigramite.pcmci import PCMCI
8
+ from tigramite.independence_tests.parcorr import ParCorr
9
+ TIGRAMITE_AVAILABLE = True
10
+ except ImportError:
11
+ TIGRAMITE_AVAILABLE = False
12
+ print("Warning: Tigramite not found. Using Placeholder.")
13
+
14
+ class CausalDiscovery:
15
+ """
16
+ Causal Discovery using Tigramite (PCMCI).
17
+ Identifies causal links in time-series data using Partial Correlation (ParCorr).
18
+ Focuses on finding parents of key variables (e.g., Returns).
19
+ """
20
+ def __init__(self, alpha=0.05, max_lag=5):
21
+ self.alpha = alpha
22
+ self.max_lag = max_lag
23
+ self.results = None
24
+ self.graph = None
25
+
26
+ def fit(self, df: pd.DataFrame):
27
+ """
28
+ Fit PCMCI on the dataframe.
29
+ df: Pandas DataFrame (Time Series).
30
+ """
31
+ if not TIGRAMITE_AVAILABLE:
32
+ return self
33
+
34
+ # 1. Prepare Data
35
+ # Tigramite requires (T, N) numpy array
36
+ data = df.values
37
+ var_names = df.columns.tolist()
38
+
39
+ dataframe = pp.DataFrame(data,
40
+ var_names=var_names,
41
+ missing_flag=999)
42
+
43
+ # 2. Init PCMCI with ParCorr (Linear Partial Correlation)
44
+ # For non-linear, use GPDC or CMIknn (slower)
45
+ parcorr = ParCorr(significance='analytic')
46
+ pcmci = PCMCI(dataframe=dataframe, cond_ind_test=parcorr, verbosity=0)
47
+
48
+ # 3. Run PCMCI
49
+ # PC phase then MCI phase
50
+ self.results = pcmci.run_pcmci(tau_max=self.max_lag, pc_alpha=self.alpha)
51
+
52
+ # 4. Extract Graph (p_matrix < alpha)
53
+ # q_matrix handles FDR control, often better
54
+ q_matrix = self.results['q_matrix']
55
+ self.graph = q_matrix < self.alpha
56
+
57
+ return self
58
+
59
+ def get_feature_weights(self):
60
+ """
61
+ Calculate feature importance based on Causal Strength (Val Matrix)
62
+ or Degree in the Causal Graph.
63
+
64
+ Returns: normalized weights for each feature.
65
+ """
66
+ if not TIGRAMITE_AVAILABLE or self.results is None:
67
+ return np.ones(5) # Fallback
68
+
69
+ # We want to know which features cause 'Volatility' or 'Returns' (if present)
70
+ # Or simply generalized centrality.
71
+
72
+ val_matrix = np.abs(self.results['val_matrix']) # (N, N, Lags+1)
73
+ # Sum absolute causal strength across all lags for each link
74
+ # Shape: (N_features, N_features) - Strength of i -> j
75
+ strength_matrix = np.sum(val_matrix, axis=2)
76
+
77
+ # Total Outgoing Causal Strength (How much 'i' influences others)
78
+ out_strength = np.sum(strength_matrix, axis=1) # Sum over j
79
+
80
+ # Total Incoming Causal Strength (How much 'i' is influenced)
81
+ in_strength = np.sum(strength_matrix, axis=0)
82
+
83
+ # Hybrid Score: Drivers are important
84
+ score = out_strength + in_strength
85
+
86
+ # Normalize
87
+ if score.sum() == 0: return np.ones(len(score))
88
+
89
+ weights = score / score.max()
90
+ return np.maximum(weights, 0.2) # Min weight
91
+
92
+ def get_causal_model():
93
+ return CausalDiscovery(alpha=0.05, max_lag=3)
models/classic_ml.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from hmmlearn.hmm import GaussianHMM
2
+ from sklearn.linear_model import LogisticRegression
3
+ from sklearn.ensemble import RandomForestClassifier
4
+ from sklearn.pipeline import Pipeline
5
+ from sklearn.preprocessing import StandardScaler
6
+ from sklearn.impute import SimpleImputer
7
+ import numpy as np
8
+
9
+ # 1. HMM (Regime Detection)
10
+ def get_hmm_pipeline(n_components=3):
11
+ return Pipeline([
12
+ ('imputer', SimpleImputer(strategy='mean')),
13
+ ('scaler', StandardScaler()),
14
+ ('hmm', GaussianHMM(n_components=n_components, covariance_type="full", n_iter=100))
15
+ ])
16
+
17
+ # 2. Logistic Regression (Confidence Scorer)
18
+ def get_logistic_pipeline():
19
+ return Pipeline([
20
+ ('imputer', SimpleImputer(strategy='mean')),
21
+ ('scaler', StandardScaler()),
22
+ ('clf', LogisticRegression(random_state=42, solver='liblinear'))
23
+ ])
24
+
25
+ # 3. Random Forest (Trade Filter)
26
+ def get_rf_pipeline():
27
+ return Pipeline([
28
+ ('imputer', SimpleImputer(strategy='mean')),
29
+ ('scaler', StandardScaler()),
30
+ ('clf', RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42, class_weight='balanced'))
31
+ ])
32
+
33
+ # 4. Causal Discovery (Real PC Algorithm)
34
+ from models.causal_discovery import get_causal_model as get_pc_model
35
+
36
+ def get_causal_model():
37
+ return get_pc_model()
38
+
models/deeplob.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class InceptionModule(nn.Module):
5
+ def __init__(self, in_channels, out_channels):
6
+ super(InceptionModule, self).__init__()
7
+
8
+ # Parallel Convolutions
9
+ # Branch 1: 1x1 Conv
10
+ self.branch1 = nn.Sequential(
11
+ nn.Conv2d(in_channels, out_channels, kernel_size=1),
12
+ nn.LeakyReLU(negative_slope=0.01),
13
+ nn.BatchNorm2d(out_channels)
14
+ )
15
+
16
+ # Branch 2: 1x3 Conv
17
+ self.branch2 = nn.Sequential(
18
+ nn.Conv2d(in_channels, out_channels, kernel_size=(1,1)),
19
+ nn.LeakyReLU(negative_slope=0.01),
20
+ nn.Conv2d(out_channels, out_channels, kernel_size=(1,3), padding=(0,1)),
21
+ nn.LeakyReLU(negative_slope=0.01),
22
+ nn.BatchNorm2d(out_channels)
23
+ )
24
+
25
+ # Branch 3: 1x5 Conv
26
+ self.branch3 = nn.Sequential(
27
+ nn.Conv2d(in_channels, out_channels, kernel_size=(1,1)),
28
+ nn.LeakyReLU(negative_slope=0.01),
29
+ nn.Conv2d(out_channels, out_channels, kernel_size=(1,5), padding=(0,2)),
30
+ nn.LeakyReLU(negative_slope=0.01),
31
+ nn.BatchNorm2d(out_channels)
32
+ )
33
+
34
+ # Branch 4: MaxPool + 1x1
35
+ self.branch4 = nn.Sequential(
36
+ nn.MaxPool2d(kernel_size=(1,3), stride=1, padding=(0,1)),
37
+ nn.Conv2d(in_channels, out_channels, kernel_size=1),
38
+ nn.LeakyReLU(negative_slope=0.01),
39
+ nn.BatchNorm2d(out_channels)
40
+ )
41
+
42
+ def forward(self, x):
43
+ b1 = self.branch1(x)
44
+ b2 = self.branch2(x)
45
+ b3 = self.branch3(x)
46
+ b4 = self.branch4(x)
47
+ return torch.cat([b1, b2, b3, b4], dim=1)
48
+
49
+ class SEBlock(nn.Module):
50
+ """
51
+ Squeeze-and-Excitation Block (Channel Attention).
52
+ Recalibrates feature maps adaptively.
53
+ Hu et al. (2018).
54
+ """
55
+ def __init__(self, channels, reduction=16):
56
+ super(SEBlock, self).__init__()
57
+ self.avg_pool = nn.AdaptiveAvgPool2d(1) # Squeeze
58
+ self.fc = nn.Sequential(
59
+ nn.Linear(channels, channels // reduction, bias=False),
60
+ nn.ReLU(inplace=True),
61
+ nn.Linear(channels // reduction, channels, bias=False),
62
+ nn.Sigmoid()
63
+ )
64
+
65
+ def forward(self, x):
66
+ b, c = x.shape[0], x.shape[1]
67
+ # b, c, _, _ = x.size() # This tuple unpacking can fail in ONNX tracing
68
+ y = self.avg_pool(x).view(b, c) # Mean per channel
69
+ y = self.fc(y).view(b, c, 1, 1) # Attention weights
70
+ return x * y.expand_as(x) # Scale features
71
+
72
+ class DeepLOB(nn.Module):
73
+ """
74
+ DeepLOB with Inception + SE-Block Attention (Academic Standard).
75
+ """
76
+ def __init__(self, y_len=3):
77
+ super().__init__()
78
+ self.y_len = y_len
79
+
80
+ # Initial Blocks
81
+ self.block1 = nn.Sequential(
82
+ nn.Conv2d(2, 16, kernel_size=(1,2), stride=(1,2)), # Input channels=2 (P, V)
83
+ nn.LeakyReLU(negative_slope=0.01),
84
+ nn.BatchNorm2d(16),
85
+ InceptionModule(16, 8), # Out: 32
86
+ SEBlock(32) # Attention
87
+ )
88
+
89
+ self.block2 = nn.Sequential(
90
+ nn.Conv2d(32, 16, kernel_size=(1,2), stride=(1,2)),
91
+ nn.LeakyReLU(negative_slope=0.01),
92
+ nn.BatchNorm2d(16),
93
+ InceptionModule(16, 8), # Out: 32
94
+ SEBlock(32) # Attention
95
+ )
96
+
97
+ self.block3 = nn.Sequential(
98
+ nn.Conv2d(32, 16, kernel_size=(1,10)),
99
+ nn.LeakyReLU(negative_slope=0.01),
100
+ nn.BatchNorm2d(16),
101
+ InceptionModule(16, 8), # Out: 32
102
+ SEBlock(32) # Attention
103
+ )
104
+
105
+ # LSTM
106
+ self.lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=1, batch_first=True)
107
+ self.fc = nn.Linear(64, y_len)
108
+
109
+ def forward(self, x):
110
+ # x: (N, 2, 100, 40)
111
+
112
+ x = self.block1(x)
113
+ x = self.block2(x)
114
+ x = self.block3(x)
115
+
116
+ # Reshape for LSTM: (N, T, Features)
117
+ x = x.permute(0, 2, 1, 3)
118
+ x = x.reshape(x.shape[0], x.shape[1], -1)
119
+
120
+ if x.dim() == 2:
121
+ x = x.unsqueeze(0)
122
+
123
+ # Explicit Init for ONNX
124
+ h0 = torch.zeros(1, x.size(0), 64).to(x.device)
125
+ c0 = torch.zeros(1, x.size(0), 64).to(x.device)
126
+
127
+ out, _ = self.lstm(x, (h0, c0))
128
+ out = out[:, -1, :]
129
+ out = self.fc(out)
130
+
131
+ return out
models/execution_agent.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class PPOActorCritic(nn.Module):
6
+ """
7
+ Execution Agent using PPO (Proximal Policy Optimization).
8
+ Input: [Signal (1), L2_Imbalance (1), Spread (1), Position_Net (1), Volatility (1)] -> 5 Dim
9
+ Action Space:
10
+ - Type: Limit (0) vs Market (1) -> Categorical(2)
11
+ - Price Offset: Continuous (Gaussian)
12
+ - Size: Continuous (Gaussian, 0-1 ratio)
13
+ """
14
+ def __init__(self, input_dim=5, action_dim=3):
15
+ super(PPOActorCritic, self).__init__()
16
+ self.input_dim = input_dim
17
+
18
+ # Shared Feature Extractor
19
+ self.common = nn.Linear(input_dim, 64)
20
+
21
+ # Actor Heads
22
+ # 1. Order Type (Discrete)
23
+ self.actor_type = nn.Sequential(
24
+ nn.Linear(64, 32),
25
+ nn.Linear(32, 2),
26
+ nn.Softmax(dim=-1)
27
+ )
28
+ # 2. Price Offset (Continuous) - Mu, Sigma
29
+ self.actor_offset_mu = nn.Sequential(nn.Linear(64, 32), nn.Linear(32, 1), nn.Tanh())
30
+ self.actor_offset_sigma = nn.Parameter(torch.zeros(1))
31
+
32
+ # 3. Size (Continuous) - Mu, Sigma
33
+ self.actor_size_mu = nn.Sequential(nn.Linear(64, 32), nn.Linear(32, 1), nn.Sigmoid())
34
+
35
+ # Critic Head (Value Function)
36
+ self.critic = nn.Sequential(
37
+ nn.Linear(64, 32),
38
+ nn.Linear(32, 1)
39
+ )
40
+
41
+ def forward(self, x):
42
+ x = F.relu(self.common(x))
43
+
44
+ probs_type = self.actor_type(x)
45
+ mu_offset = self.actor_offset_mu(x)
46
+ mu_size = self.actor_size_mu(x)
47
+ value = self.critic(x)
48
+
49
+ return probs_type, mu_offset, mu_size, value
models/execution_agent_20251210_155924.pt ADDED
Binary file (43 kB). View file
 
models/lstm.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class AlphaLSTM(nn.Module):
5
+ """
6
+ Simple LSTM model for predicting price returns.
7
+ Input: (Batch, Seq_Len, Features)
8
+ Output: (Batch, 1) -> Predicted Next Return
9
+ """
10
+ def __init__(self, input_size=4, hidden_size=64, num_layers=2, dropout=0.2):
11
+ super(AlphaLSTM, self).__init__()
12
+ self.hidden_size = hidden_size
13
+ self.num_layers = num_layers
14
+
15
+ self.lstm = nn.LSTM(
16
+ input_size=input_size,
17
+ hidden_size=hidden_size,
18
+ num_layers=num_layers,
19
+ batch_first=True,
20
+ dropout=dropout
21
+ )
22
+
23
+ # Fully Connected Layer
24
+ self.fc = nn.Linear(hidden_size, 1)
25
+ self.activation = nn.Tanh() # Output range [-1, 1] suited for "Signal"
26
+
27
+ def forward(self, x):
28
+ # Initialize hidden state with zeros
29
+ h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
30
+ c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
31
+
32
+ # Forward propagate LSTM
33
+ out, _ = self.lstm(x, (h0, c0))
34
+
35
+ # Decode the hidden state of the last time step
36
+ out = self.fc(out[:, -1, :])
37
+ out = self.activation(out)
38
+ return out
models/meta_controller.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class DQN(nn.Module):
6
+ """
7
+ Meta-Controller using DQN (Deep Q-Network).
8
+ Input: [Volatility(1), Market_Regime(3 - OneHot), Global_PnL_Trend(1)] -> 5 Dim
9
+ Output: Q-Values for Actions (3)
10
+ 0: FollowTrend Agent
11
+ 1: MeanReversion Agent
12
+ 2: Defensive Mode (Cash)
13
+ """
14
+ def __init__(self, input_dim=5, output_dim=3):
15
+ super(DQN, self).__init__()
16
+ self.net = nn.Sequential(
17
+ nn.Linear(input_dim, 64),
18
+ nn.LeakyReLU(),
19
+ nn.Linear(64, 64),
20
+ nn.LeakyReLU(),
21
+ nn.Linear(64, output_dim)
22
+ )
23
+
24
+ def forward(self, x):
25
+ return self.net(x)
models/meta_controller_20251210_155924.pt ADDED
Binary file (22.4 kB). View file
 
models/risk_agent.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class RiskAgent:
2
+ """
3
+ Risk Agent (Guardian).
4
+ Hard-coded rules to prevent catastrophe.
5
+ """
6
+ def __init__(self, max_dd=0.15, max_hourly_loss=200):
7
+ self.max_dd = max_dd
8
+ self.max_hourly_loss = max_hourly_loss
9
+ self.current_dd = 0.0
10
+ self.hourly_loss = 0.0
11
+
12
+ def check_health(self, equity, initial_equity, recent_pnl):
13
+ """
14
+ Returns boolean: True (Healthy), False (Stop Trading).
15
+ """
16
+ # Update DD
17
+ peak_equity = max(equity, initial_equity)
18
+ self.current_dd = (peak_equity - equity) / peak_equity
19
+
20
+ # Check Rules
21
+ if self.current_dd > self.max_dd:
22
+ print(f"RISK TRIGGER: Max Drawdown {self.current_dd:.2%} > {self.max_dd:.2%}")
23
+ return False
24
+
25
+ if recent_pnl < -self.hourly_loss:
26
+ print(f"RISK TRIGGER: Hourly Loss {recent_pnl} > {self.max_hourly_loss}")
27
+ return False
28
+
29
+ return True
models/risk_agent_20251210_155924.pkl ADDED
Binary file (137 Bytes). View file
 
models/trm.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class TinyRecursiveCore(nn.Module):
6
+ """
7
+ The Shared Core Network for TRM.
8
+ Refines the latent state z_t given the context c.
9
+ z_{t+1} = f(z_t, c)
10
+ """
11
+ def __init__(self, hidden_dim, context_dim):
12
+ super(TinyRecursiveCore, self).__init__()
13
+ # Input is concatenation of current state and context
14
+ self.net = nn.Sequential(
15
+ nn.Linear(hidden_dim + context_dim, hidden_dim),
16
+ nn.LeakyReLU(negative_slope=0.01),
17
+ nn.LayerNorm(hidden_dim),
18
+ nn.Linear(hidden_dim, hidden_dim),
19
+ nn.LeakyReLU(negative_slope=0.01),
20
+ nn.LayerNorm(hidden_dim)
21
+ )
22
+
23
+ def forward(self, z, c):
24
+ combined = torch.cat([z, c], dim=1)
25
+ # Residual connection ideally, but raw update usually fine for short loops.
26
+ delta_z = self.net(combined)
27
+ return z + delta_z # Residual State Update
28
+
29
+ class TRM(nn.Module):
30
+ """
31
+ Samsung SAIL "Tiny Recursive Model" (TRM).
32
+ Replaces massive parameter count with a recursive reasoning loop.
33
+ Architecture:
34
+ 1. Encoder: Maps Input (Batch, Seq, Feat) -> Context Vector 'c'.
35
+ 2. Initialization: init state z_0 = c (or learned).
36
+ 3. Recursion: Apply TinyCore N times: z_{k+1} = Core(z_k, c).
37
+ 4. Decoder: Map z_N -> Output Classes.
38
+ """
39
+ def __init__(self, input_size=5, hidden_dim=64, num_classes=3, recur_steps=5):
40
+ super(TRM, self).__init__()
41
+ self.recur_steps = recur_steps
42
+ self.hidden_dim = hidden_dim
43
+
44
+ # Encoder: Flatten sequence? Or use simple LSTM/Linear to get Context?
45
+ # To keep it "Tiny", let's use a simple Linear projection of the Flattened window.
46
+ # Assuming T=60, Feat=5 -> 300 dim flat.
47
+ # We can make it dynamic by using AdaptiveAvgPool or assuming input is (Batch, Seq, Feat)
48
+
49
+ # Simple Encoder: (Batch, Seq, Feat) -> (Batch, Hidden)
50
+ # Using LSTM to summarize temporal context first is robust.
51
+ self.encoder = nn.LSTM(input_size, hidden_dim, batch_first=True)
52
+
53
+ # The Tiny Shared Core
54
+ self.core = TinyRecursiveCore(hidden_dim, hidden_dim) # z_dim=hidden, c_dim=hidden
55
+
56
+ # Decoder (Head)
57
+ self.decoder = nn.Linear(hidden_dim, num_classes)
58
+ self.softmax = nn.Softmax(dim=1)
59
+
60
+ def forward(self, x):
61
+ # x: (Batch, Seq, Features)
62
+
63
+ # 1. Encode Context
64
+ # LSTM output: (Batch, Seq, Hidden), (h_n, c_n)
65
+
66
+ # Explicit Init for ONNX
67
+ h0 = torch.zeros(1, x.size(0), self.hidden_dim).to(x.device)
68
+ c0 = torch.zeros(1, x.size(0), self.hidden_dim).to(x.device)
69
+
70
+ _, (h_n, _) = self.encoder(x, (h0, c0))
71
+ context = h_n[-1] # Shape (Batch, Hidden)
72
+
73
+ # 2. Reasoning Loop
74
+ z = context.clone() # Initialize state with context
75
+
76
+ for _ in range(self.recur_steps):
77
+ z = self.core(z, context)
78
+
79
+ # 3. Readout
80
+ out = self.decoder(z)
81
+ return self.softmax(out)
nautilus_trader_source ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit a1adc1b7b44aa620b1191e31547c9a5ac3b82ba2
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch
3
+ gymnasium
4
+ pandas
5
+ numpy
6
+ msgpack
7
+ pyarrow
8
+ scikit-learn
9
+ tensorboard
10
+ matplotlib
11
+ hmmlearn
12
+ scipy
13
+ networkx
14
+ tigramite
15
+ huggingface_hub
16
+ onnx
17
+ onnxscript
18
+ ray[rllib]
19
+ datasets
run_dev_loop.sh ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # NautilusTrainer Dev Loop Orchestrator
4
+ # Usage: ./run_dev_loop.sh
5
+
6
+ echo "🔄 Starting NautilusTrainer Dev Loop..."
7
+ echo "----------------------------------------"
8
+
9
+ # 1. Set Environment
10
+ cd "$(dirname "$0")"
11
+
12
+ # Activate venv from parent directory if exists
13
+ if [ -d "../.venv" ]; then
14
+ source ../.venv/bin/activate
15
+ fi
16
+
17
+ export PYTHONPATH=$PYTHONPATH:$(pwd)
18
+
19
+ # 2. Run Tests
20
+ python3 tests/test_pipeline.py
21
+
22
+ EXIT_CODE=$?
23
+
24
+ echo "----------------------------------------"
25
+ if [ $EXIT_CODE -eq 0 ]; then
26
+ echo "✅ SUCCESS: Pipeline verified."
27
+ else
28
+ echo "❌ FAILURE: Pipeline failed."
29
+ fi
30
+
31
+ exit $EXIT_CODE
run_prod_docker.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # NautilusTrainer Production Simulation
4
+ # Builds the Docker container and runs it exactly as HF Space would.
5
+
6
+ APP_NAME="nautilus-trainer-prod"
7
+
8
+ echo "🐳 Building Docker Image..."
9
+ docker build -t $APP_NAME ./NautilusTrainer
10
+
11
+ if [ $? -ne 0 ]; then
12
+ echo "❌ Docker Build Failed."
13
+ exit 1
14
+ fi
15
+
16
+ echo "✅ Build Success."
17
+
18
+ # Check for HF_TOKEN
19
+ if [ -z "$HF_TOKEN" ]; then
20
+ echo "⚠️ WARNING: HF_TOKEN is not set in your environment."
21
+ echo " Training will likely fail when trying to stream data."
22
+ echo " Usage: HF_TOKEN=hf_... ./run_prod_docker.sh"
23
+ read -p " Do you want to continue anyway? (y/n) " -n 1 -r
24
+ echo
25
+ if [[ ! $REPLY =~ ^[Yy]$ ]]; then
26
+ exit 1
27
+ fi
28
+ fi
29
+
30
+ echo "🚀 Running Container on port 7860..."
31
+ echo " Access Dashboard at: http://localhost:7860"
32
+ echo " Press Ctrl+C to stop."
33
+
34
+ docker run --rm -p 7860:7860 \
35
+ -e HF_TOKEN=$HF_TOKEN \
36
+ $APP_NAME
37
+
scheduler.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import schedule
3
+ from datetime import datetime
4
+ import auto_train
5
+
6
+ def run_pipeline():
7
+ print(f"\n⏰ Scheduler: Starting Training Pipeline at {datetime.now().isoformat()}...")
8
+ try:
9
+ auto_train.main()
10
+ except Exception as e:
11
+ print(f"❌ Scheduler Error: {e}")
12
+ print(f"💤 Scheduler: Sleeping for 1 hour...")
13
+
14
+ if __name__ == "__main__":
15
+ print("🚀 NautilusAI Training Scheduler Started")
16
+ print("📅 Schedule: Run every 60 minutes")
17
+
18
+ # Run immediately on startup
19
+ run_pipeline()
20
+
21
+ # Schedule subsequent runs
22
+ schedule.every(60).minutes.do(run_pipeline)
23
+
24
+ while True:
25
+ schedule.run_pending()
26
+ time.sleep(60)
streaming_loader.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ import ast
5
+ import os
6
+ import json
7
+ from typing import Iterator, Tuple
8
+ from datasets import load_dataset
9
+ from data_processor import AlphaDataProcessor
10
+ import gc
11
+
12
+ class StreamingDataLoader:
13
+ """
14
+ Streams training data directly from HuggingFace Datasets without downloading.
15
+ Buffers chunks to enable rolling window operations.
16
+ """
17
+ def __init__(self,
18
+ repo_id: str = "gionuibk/hyperliquid-data",
19
+ model_type: str = "deeplob",
20
+ batch_size: int = 32,
21
+ chunk_size: int = 500, # Reduced to ensure frequent yields
22
+ buffer_size: int = 200): # Reduced buffer
23
+ """
24
+ Args:
25
+ repo_id: HF Dataset ID
26
+ model_type: 'deeplob' or 'trm'
27
+ batch_size: Training batch size
28
+ chunk_size: Rows per processing chunk
29
+ buffer_size: Overlap size to maintain rolling stats continuity
30
+ """
31
+ self.repo_id = repo_id
32
+ self.model_type = model_type
33
+ self.batch_size = batch_size
34
+ self.chunk_size = chunk_size
35
+ self.buffer_size = buffer_size
36
+
37
+ self.processor = AlphaDataProcessor()
38
+
39
+ def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
40
+ """
41
+ Yields batches of (X, y) tensors from the stream.
42
+ """
43
+ print(f"📡 Connecting to HF Dataset Stream: {self.repo_id}")
44
+ token = os.environ.get("HF_TOKEN")
45
+
46
+ try:
47
+ # MANUAL LOADING Mode (Bypassing datasets library due to Arrow/Parquet errors)
48
+ from huggingface_hub import HfApi, hf_hub_download
49
+ api = HfApi(token=token)
50
+
51
+ # 1. List files
52
+ # 1. List files
53
+ print("🔍 Listing files...")
54
+ files = api.list_repo_files(repo_id=self.repo_id, repo_type="dataset")
55
+
56
+ if self.model_type == "lstm":
57
+ # Use Bar Data for LSTM (Support both v1 'data/bar/' and v2 'data/candles/')
58
+ target_files = [
59
+ f for f in files
60
+ if (f.startswith("data/bar/") or f.startswith("data/candles/"))
61
+ and f.endswith(".parquet")
62
+ ]
63
+ print(f"📂 Found {len(target_files)} Bar/Candle files for LSTM.")
64
+ else:
65
+ # Use L2 Snapshots for DeepLOB/TRM (Support both v1 'order_book_snapshot' and v2 'l2book')
66
+ target_files = [
67
+ f for f in files
68
+ if ("order_book_snapshot" in f or "l2book" in f)
69
+ and f.endswith(".parquet")
70
+ ]
71
+ print(f"📂 Found {len(target_files)} Snapshot/L2Book files for {self.model_type}.")
72
+
73
+ # Buffer for rolling operations
74
+ buffer_df = pd.DataFrame()
75
+ chunk_rows = []
76
+
77
+ total_loaded_rows = 0
78
+
79
+ for file_path in target_files:
80
+ try:
81
+ print(f"⬇️ Downloading {file_path}...")
82
+ # Download to temp dir to avoid cache filling
83
+ temp_dir = "./temp_data"
84
+ os.makedirs(temp_dir, exist_ok=True)
85
+
86
+ local_path = hf_hub_download(
87
+ repo_id=self.repo_id,
88
+ filename=file_path,
89
+ repo_type="dataset",
90
+ token=token,
91
+ local_dir=temp_dir,
92
+ local_dir_use_symlinks=False,
93
+ force_download=True # Ensure we have a fresh copy to delete later
94
+ )
95
+
96
+ print(f"📖 Reading {file_path}...")
97
+ # Read parquet directly using pandas (robust)
98
+ try:
99
+ df = pd.read_parquet(local_path)
100
+ except BaseException as e: # Catch EVERYTHING including OSError
101
+ print(f"⚠️ Parquet Read Failed for {file_path}: {e}")
102
+ continue
103
+
104
+ rows_in_file = len(df)
105
+ print(f"✅ Loaded {rows_in_file} rows from {file_path}")
106
+ total_loaded_rows += rows_in_file
107
+
108
+ total_loaded_rows += rows_in_file
109
+
110
+ # Iterate rows in the dataframe
111
+ for i, row in df.iterrows():
112
+ # Parse L2 columns (Support both nested lists and flat columns)
113
+ if 'bids' in row and isinstance(row['bids'], str):
114
+ try: row['bids'] = ast.literal_eval(row['bids'])
115
+ except: pass
116
+ if 'asks' in row and isinstance(row['asks'], str):
117
+ try: row['asks'] = ast.literal_eval(row['asks'])
118
+ except: pass
119
+
120
+ # Handle Flat Format (bid_px_1, bid_sz_1, ...)
121
+ if 'bids' not in row and 'bid_px_1' in row:
122
+ bids = []
123
+ asks = []
124
+ for level in range(1, 21): # Support up to 20 levels
125
+ if f'bid_px_{level}' in row:
126
+ bids.append([row[f'bid_px_{level}'], row[f'bid_sz_{level}']])
127
+ if f'ask_px_{level}' in row:
128
+ asks.append([row[f'ask_px_{level}'], row[f'ask_sz_{level}']])
129
+ row['bids'] = bids
130
+ row['asks'] = asks
131
+
132
+ # Pandas iterrows returns (index, Series), we want the Series/dict
133
+ # Append as dict for processing
134
+ chunk_rows.append(row.to_dict())
135
+
136
+ if len(chunk_rows) >= self.chunk_size:
137
+ # Process and yield chunk
138
+ yield from self._process_chunk(chunk_rows, buffer_df)
139
+
140
+ # Update Buffer from new chunk
141
+ new_df = pd.DataFrame(chunk_rows)
142
+ buffer_df = new_df.tail(self.buffer_size)
143
+ chunk_rows = []
144
+ gc.collect()
145
+
146
+ except Exception as e:
147
+ print(f"⚠️ Failed to process file {file_path}: {e}")
148
+
149
+ finally:
150
+ # CRITICAL: Clean up file immediately to save disk space
151
+ if 'local_path' in locals() and os.path.exists(local_path):
152
+ try:
153
+ # Verify it's a file before removing (safety)
154
+ if os.path.isfile(local_path):
155
+ os.remove(local_path)
156
+ except: pass
157
+
158
+ # Process remaining rows after all files
159
+ if len(chunk_rows) > 0:
160
+ print(f"🧹 Processing final residual chunk ({len(chunk_rows)} rows)...")
161
+ yield from self._process_chunk(chunk_rows, buffer_df)
162
+
163
+ except Exception as e:
164
+ print(f"⚠️ Manual Loading Error: {e}")
165
+ import traceback
166
+ traceback.print_exc()
167
+
168
+ def _process_chunk(self, chunk_rows, buffer_df):
169
+ # Helper to process a chunk and yield batches
170
+ new_df = pd.DataFrame(chunk_rows)
171
+
172
+ # Merge with buffer (previous context)
173
+ if not buffer_df.empty:
174
+ combined_df = pd.concat([buffer_df, new_df])
175
+ else:
176
+ combined_df = new_df
177
+
178
+ # Process
179
+ if self.model_type == "deeplob":
180
+ X, y = self.processor.get_deeplob_tensors_from_df(combined_df)
181
+ elif self.model_type == "trm":
182
+ X, y = self.processor.get_trm_tensors_from_df(combined_df)
183
+ elif self.model_type == "lstm":
184
+ X, y = self.processor.get_lstm_tensors_from_df(combined_df)
185
+ else:
186
+ raise ValueError(f"Unknown model type: {self.model_type}")
187
+
188
+ # Yield batches
189
+ if len(X) > 0:
190
+ dataset_size = len(X)
191
+ indices = torch.randperm(dataset_size)
192
+ X = X[indices]
193
+ y = y[indices]
194
+
195
+ for k in range(0, dataset_size, self.batch_size):
196
+ batch_X = X[k:k+self.batch_size]
197
+ batch_y = y[k:k+self.batch_size]
198
+ if len(batch_X) == self.batch_size:
199
+ yield batch_X, batch_y
200
+
201
+ def get_sample_batch(self) -> Tuple[torch.Tensor, torch.Tensor]:
202
+ for batch_X, batch_y in self:
203
+ return batch_X, batch_y
204
+ raise RuntimeError("Stream empty or failed")
tests/mock_data.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import time
4
+ from typing import List, Dict
5
+
6
+ class MockDataGenerator:
7
+ """Generates synthetic L2 Orderbook and Trade data for testing."""
8
+
9
+ @staticmethod
10
+ def generate_l2_snapshot(num_rows: int = 100, levels: int = 20) -> pd.DataFrame:
11
+ """
12
+ Generates a DataFrame mimicking the L2 Snapshot structure.
13
+ Columns: ts_event, instrument_id, bids, asks
14
+ """
15
+ base_price = 2000.0
16
+ data = []
17
+
18
+ start_time = time.time() * 1000
19
+
20
+ for i in range(num_rows):
21
+ ts = start_time + i * 1000 # 1 sec intervals
22
+
23
+ # Random Walk Price
24
+ noise = np.random.normal(0, 1)
25
+ mid_price = base_price + noise
26
+ base_price = mid_price
27
+
28
+ # Generate Levels
29
+ bids = []
30
+ asks = []
31
+
32
+ for l in range(levels):
33
+ spread = (l + 1) * 0.5
34
+ bid_p = mid_price - spread
35
+ ask_p = mid_price + spread
36
+
37
+ bid_sz = abs(np.random.normal(10, 5)) + 1
38
+ ask_sz = abs(np.random.normal(10, 5)) + 1
39
+
40
+ bids.append([bid_p, bid_sz])
41
+ asks.append([ask_p, ask_sz])
42
+
43
+ data.append({
44
+ "ts_event": ts,
45
+ "instrument_id": "ETH-USD",
46
+ "bids": bids, # List of lists format
47
+ "asks": asks
48
+ })
49
+
50
+ return pd.DataFrame(data)
51
+
52
+ @staticmethod
53
+ def generate_trades(num_rows: int = 100) -> pd.DataFrame:
54
+ """
55
+ Generates synthetic trade data.
56
+ Columns: time, coin, px, sz, side
57
+ """
58
+ base_price = 2000.0
59
+ data = []
60
+ start_time = time.time() * 1000
61
+
62
+ for i in range(num_rows):
63
+ ts = start_time + i * 500
64
+ px = base_price + np.random.normal(0, 1)
65
+ sz = abs(np.random.normal(1, 0.5))
66
+ side = 'B' if np.random.random() > 0.5 else 'A'
67
+
68
+ data.append({
69
+ "time": ts,
70
+ "coin": "ETH",
71
+ "px": px,
72
+ "sz": sz,
73
+ "side": side
74
+ })
75
+
76
+ return pd.DataFrame(data)
tests/test_loader_v2.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sys
3
+ import unittest
4
+ import torch
5
+ # Add parent dir to path to import NautilusTrainer modules
6
+ sys.path.append("..")
7
+ from streaming_loader import StreamingDataLoader
8
+
9
+ class TestLoaderV2(unittest.TestCase):
10
+ def test_load_l2_snapshot(self):
11
+ print("\nTesting L2 Snapshot Load (v2)...")
12
+ loader = StreamingDataLoader(
13
+ repo_id="gionuibk/hyperliquidL2Book-v2",
14
+ model_type="deeplob",
15
+ batch_size=4,
16
+ chunk_size=1000,
17
+ buffer_size=1000
18
+ )
19
+
20
+ try:
21
+ batch_X, batch_y = loader.get_sample_batch()
22
+ print(f"✅ Success! X shape: {batch_X.shape}, y shape: {batch_y.shape}")
23
+ self.assertIsInstance(batch_X, torch.Tensor)
24
+ self.assertIsInstance(batch_y, torch.Tensor)
25
+ except Exception as e:
26
+ self.fail(f"Failed to load L2 snapshot: {e}")
27
+
28
+ def test_load_candles(self):
29
+ print("\nTesting Candle Load (v2)...")
30
+ loader = StreamingDataLoader(
31
+ repo_id="gionuibk/hyperliquidL2Book-v2",
32
+ model_type="lstm",
33
+ batch_size=4,
34
+ chunk_size=1000,
35
+ buffer_size=1000
36
+ )
37
+
38
+ try:
39
+ batch_X, batch_y = loader.get_sample_batch()
40
+ print(f"✅ Success! X shape: {batch_X.shape}, y shape: {batch_y.shape}")
41
+ self.assertIsInstance(batch_X, torch.Tensor)
42
+ self.assertIsInstance(batch_y, torch.Tensor)
43
+ except Exception as e:
44
+ self.fail(f"Failed to load Candles: {e}")
45
+
46
+ if __name__ == '__main__':
47
+ unittest.main()
tests/test_pipeline.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+
7
+ # Add parent directory to path to import modules
8
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
+
10
+ from data_processor import AlphaDataProcessor
11
+ from models.deeplob import DeepLOB
12
+ from tests.mock_data import MockDataGenerator
13
+
14
+ def test_deeplob_pipeline():
15
+ print("🧪 Starting DeepLOB Pipeline Test...")
16
+
17
+ # 1. Generate Mock Data
18
+ print(" Generatng mock data...", end="")
19
+ df = MockDataGenerator.generate_l2_snapshot(num_rows=2000, levels=20)
20
+ print("Done.")
21
+
22
+ # 2. Process Data
23
+ print(" Processing tensors...", end="")
24
+ processor = AlphaDataProcessor()
25
+ # T=100 is standard
26
+ X, y = processor.get_deeplob_tensors_from_df(df, T=100, levels=20)
27
+
28
+ print(f"Done. Shape: X={X.shape}, y={y.shape}")
29
+
30
+ if len(X) == 0:
31
+ print("❌ Error: No tensors generated from mock data.")
32
+ sys.exit(1)
33
+
34
+ # 3. Model Init
35
+ print(" Initializing DeepLOB...", end="")
36
+ model = DeepLOB(y_len=3)
37
+ criterion = nn.CrossEntropyLoss()
38
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
39
+ print("Done.")
40
+
41
+ # 4. Training Step (Forward + Backward)
42
+ print(" Running training step...", end="")
43
+ try:
44
+ model.train()
45
+
46
+ # Take a small batch
47
+ batch_size = 8
48
+ if len(X) < 8: batch_size = len(X)
49
+
50
+ batch_X = X[:batch_size]
51
+ batch_y = y[:batch_size]
52
+
53
+ optimizer.zero_grad()
54
+ outputs = model(batch_X)
55
+
56
+ loss = criterion(outputs, batch_y)
57
+ loss.backward()
58
+ optimizer.step()
59
+
60
+ print(f"Done. Loss: {loss.item():.4f}")
61
+ print("✅ DeepLOB Pipeline Test Passed!")
62
+
63
+ except Exception as e:
64
+ print(f"\n❌ Error during training step: {e}")
65
+ import traceback
66
+ traceback.print_exc()
67
+ from models.trm import TRM
68
+
69
+ def test_trm_pipeline():
70
+ print("\n🧪 Starting TRM Pipeline Test...")
71
+
72
+ # 1. Generate Mock Data
73
+ print(" Generatng mock data...", end="")
74
+ df = MockDataGenerator.generate_l2_snapshot(num_rows=2000, levels=20)
75
+ print("Done.")
76
+
77
+ # 2. Process Data
78
+ print(" Processing transformers...", end="")
79
+ processor = AlphaDataProcessor()
80
+ # TRM uses 6 features
81
+ X, y = processor.get_trm_tensors_from_df(df, T=60)
82
+
83
+ print(f"Done. Shape: X={X.shape}, y={y.shape}")
84
+
85
+ if len(X) == 0:
86
+ print("❌ Error: No tensors generated for TRM.")
87
+ return
88
+
89
+ # 3. Model Init
90
+ print(" Initializing TRM...", end="")
91
+ # Check TRM signature from file if needed, assuming (input_size=6, num_classes=3) based on auto_train.py
92
+ model = TRM(input_size=6, num_classes=3)
93
+ criterion = nn.CrossEntropyLoss()
94
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
95
+ print("Done.")
96
+
97
+ # 4. Training Step
98
+ print(" Running training step...", end="")
99
+ try:
100
+ model.train()
101
+ batch_size = 8
102
+ if len(X) < 8: batch_size = len(X)
103
+
104
+ batch_X = X[:batch_size]
105
+ batch_y = y[:batch_size]
106
+
107
+ optimizer.zero_grad()
108
+ outputs = model(batch_X)
109
+
110
+ loss = criterion(outputs, batch_y)
111
+ loss.backward()
112
+ optimizer.step()
113
+
114
+ print(f"Done. Loss: {loss.item():.4f}")
115
+ print("✅ TRM Pipeline Test Passed!")
116
+
117
+ except Exception as e:
118
+ print(f"\n❌ Error during TRM training step: {e}")
119
+ import traceback
120
+ traceback.print_exc()
121
+ sys.exit(1)
122
+
123
+ if __name__ == "__main__":
124
+ test_deeplob_pipeline()
125
+ test_trm_pipeline()
train.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ray
2
+ from ray import tune
3
+ from ray.rllib.algorithms.ppo import PPOConfig
4
+ # from envs.nautilus_env import NautilusExecutionEnv # Not used in this script
5
+ import os
6
+
7
+ def train():
8
+ # 1. Init Ray
9
+ ray.init(ignore_reinit_error=True)
10
+
11
+ # 2. Register Environment (Using Standard Gym for Health Check)
12
+ env_name = "CartPole-v1"
13
+
14
+ # 3. Configure Algorithm
15
+ config = (
16
+ PPOConfig()
17
+ .environment(env_name)
18
+ .framework("torch")
19
+ .rollouts(num_rollout_workers=0) # 0 for local test, CPU count for Prod
20
+ .training(model={"fcnet_hiddens": [64, 64]})
21
+ .resources(num_gpus=0) # Set to 1 if using GPU Space
22
+ )
23
+
24
+ # 4. Run Training
25
+ print("Starting Training...")
26
+ algo = config.build()
27
+
28
+ for i in range(10): # 10 Iterations for test
29
+ result = algo.train()
30
+ print(f"Iter: {i}, Reward: {result['episode_reward_mean']}")
31
+
32
+ # Save Checkpoint
33
+ if i % 5 == 0:
34
+ checkpoint_dir = algo.save(f"./checkpoints/iter_{i}")
35
+ print(f"Checkpoint saved at {checkpoint_dir}")
36
+
37
+ # 5. Export to ONNX (Crucial for Nautilus)
38
+ print("Exporting to ONNX...")
39
+ # onnx_path = algorithm.export_model_model(export_dir="./models")
40
+ # (Simplified, implementation detail varies by RLLib version)
41
+
42
+ ray.shutdown()
43
+
44
+ if __name__ == "__main__":
45
+ train()
train_alpha.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from huggingface_hub import snapshot_download, HfApi
5
+ import os
6
+ from data_processor import AlphaDataProcessor
7
+ from models.lstm import AlphaLSTM
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ import shutil
10
+
11
+ # Configuration
12
+ REPO_ID = "gionuibk/hyperliquidL2Book-v2" # Correct dataset repo
13
+ DATA_DIR = "./data"
14
+ MODEL_DIR = "./models"
15
+ EPOCHS = 50
16
+ BATCH_SIZE = 32
17
+ LR = 0.001
18
+
19
+ def download_data():
20
+ """Downloads dataset from HuggingFace Hub."""
21
+ print(f"Downloading data from {REPO_ID}...")
22
+ try:
23
+ snapshot_download(
24
+ repo_id=REPO_ID,
25
+ repo_type="dataset",
26
+ local_dir=DATA_DIR,
27
+ allow_patterns=["raw_trade/*.parquet"]
28
+ )
29
+ print("Download Complete.")
30
+ except Exception as e:
31
+ print(f"Warning: Could not download data (Token missing?): {e}")
32
+
33
+ def train():
34
+ writer = SummaryWriter(log_dir="./ray_results/alpha_experiment")
35
+ os.makedirs(MODEL_DIR, exist_ok=True)
36
+
37
+ # 1. Prepare Data
38
+ download_data()
39
+ processor = AlphaDataProcessor(data_dir=DATA_DIR)
40
+
41
+ # Check if data exists
42
+ if not os.path.exists(f"{DATA_DIR}/raw_trade"):
43
+ print("No data found. Ensure 'raw_trade' folder exists in dataset.")
44
+ # Create dummy data for dry-run
45
+
46
+ print("Processing Features...")
47
+ X, y = processor.get_tensors(coin="ETH", seq_len=60)
48
+
49
+ # Train/Test Split
50
+ train_size = int(len(X) * 0.8)
51
+ X_train, X_test = X[:train_size], X[train_size:]
52
+ y_train, y_test = y[:train_size], y[train_size:]
53
+
54
+ train_loader = torch.utils.data.DataLoader(
55
+ torch.utils.data.TensorDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True
56
+ )
57
+
58
+ # 2. Init Model
59
+ model = AlphaLSTM(input_size=4) # features: log_ret, vol, rsi, volume
60
+ criterion = nn.MSELoss()
61
+ optimizer = optim.Adam(model.parameters(), lr=LR)
62
+
63
+ # 3. Training Loop
64
+ print("Starting Training...")
65
+ for epoch in range(EPOCHS):
66
+ model.train()
67
+ total_loss = 0
68
+ for batch_X, batch_y in train_loader:
69
+ optimizer.zero_grad()
70
+ outputs = model(batch_X)
71
+ loss = criterion(outputs, batch_y)
72
+ loss.backward()
73
+ optimizer.step()
74
+ total_loss += loss.item()
75
+
76
+ avg_loss = total_loss / len(train_loader)
77
+ writer.add_scalar("Loss/Train", avg_loss, epoch)
78
+
79
+ if epoch % 5 == 0:
80
+ print(f"Epoch {epoch}/{EPOCHS} | Loss: {avg_loss:.6f}")
81
+
82
+ # 4. Save ONNX
83
+ print("Exporting to ONNX...")
84
+ dummy_input = torch.randn(1, 60, 4)
85
+ onnx_path = f"{MODEL_DIR}/alpha_lstm_v1.onnx"
86
+ torch.onnx.export(
87
+ model,
88
+ dummy_input,
89
+ onnx_path,
90
+ input_names=['input'],
91
+ output_names=['output'],
92
+ dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
93
+ )
94
+ print(f"Model Saved: {onnx_path}")
95
+
96
+ # 5. Push Model to HF (Optional - can be done manually or separate script)
97
+ # api = HfApi()
98
+ # api.upload_file(...)
99
+
100
+ writer.close()
101
+
102
+ if __name__ == "__main__":
103
+ train()
train_deeplob.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from huggingface_hub import snapshot_download
5
+ import os
6
+ import numpy as np
7
+ from data_processor import AlphaDataProcessor
8
+ from models.deeplob import DeepLOB
9
+ from torch.utils.tensorboard import SummaryWriter
10
+
11
+ # Configuration
12
+ REPO_ID = "gionuibk/hyperliquidL2Book-v2" # Correct dataset repo
13
+ DATA_DIR = "./data"
14
+ MODEL_DIR = "./models"
15
+ EPOCHS = 20
16
+ BATCH_SIZE = 32
17
+ LR = 0.0001
18
+ T = 100
19
+ LEVELS = 20
20
+
21
+ def download_data():
22
+ """Downloads dataset from HuggingFace Hub."""
23
+ print(f"Downloading data from {REPO_ID}...")
24
+ try:
25
+ snapshot_download(
26
+ repo_id=REPO_ID,
27
+ repo_type="dataset",
28
+ local_dir=DATA_DIR,
29
+ allow_patterns=["order_book_snapshot/*.parquet"]
30
+ )
31
+ print("Download Complete.")
32
+ except Exception as e:
33
+ print(f"Warning: Could not download data (Token missing?): {e}")
34
+
35
+ def train():
36
+ writer = SummaryWriter(log_dir="./ray_results/deeplob_experiment")
37
+ os.makedirs(MODEL_DIR, exist_ok=True)
38
+
39
+ # 1. Prepare Data
40
+ download_data()
41
+ processor = AlphaDataProcessor(data_dir=DATA_DIR)
42
+
43
+ print("Processing DeepLOB Tensors (L2)...")
44
+ X, y = processor.get_deeplob_tensors(coin="ETH", T=T, levels=LEVELS)
45
+
46
+ # Reshape for CNN: (N, 1, T, Features)
47
+ # X shape current: (N, T, Features)
48
+ # X is already (N, 2, T, Features) from DataProcessor
49
+ # X = X.unsqueeze(1) # Removed: caused 5D tensor error
50
+
51
+ print(f"Data Shape: {X.shape}")
52
+
53
+ # Train/Test Split
54
+ train_size = int(len(X) * 0.8)
55
+ X_train, X_test = X[:train_size], X[train_size:]
56
+ y_train, y_test = y[:train_size], y[train_size:]
57
+
58
+ train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
59
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
60
+
61
+ # Compute Class Weights for Imbalance
62
+ class_counts = torch.bincount(y_train)
63
+ total_samples = class_counts.sum()
64
+ n_classes = len(class_counts)
65
+
66
+ # Weight = Total / (n_classes * count)
67
+ weights = total_samples / (n_classes * class_counts.float())
68
+
69
+ # Handle if any class has 0 samples (unlikely but safe)
70
+ if torch.isinf(weights).any():
71
+ weights[torch.isinf(weights)] = 1.0
72
+
73
+ print(f"Class Weights: {weights}")
74
+
75
+ # 2. Init Model
76
+ model = DeepLOB(y_len=3)
77
+ criterion = nn.CrossEntropyLoss(weight=weights) # Weighted Loss
78
+ optimizer = optim.Adam(model.parameters(), lr=LR)
79
+
80
+ # 3. Training Loop
81
+ print("Starting DeepLOB Training...")
82
+ for epoch in range(EPOCHS):
83
+ model.train()
84
+ total_loss = 0
85
+ correct = 0
86
+ total = 0
87
+
88
+ for batch_X, batch_y in train_loader:
89
+ optimizer.zero_grad()
90
+ outputs = model(batch_X)
91
+ loss = criterion(outputs, batch_y)
92
+ loss.backward()
93
+ optimizer.step()
94
+
95
+ total_loss += loss.item()
96
+ _, predicted = torch.max(outputs.data, 1)
97
+ total += batch_y.size(0)
98
+ correct += (predicted == batch_y).sum().item()
99
+
100
+ avg_loss = total_loss / len(train_loader)
101
+ accuracy = 100 * correct / total
102
+
103
+ writer.add_scalar("Loss/Train", avg_loss, epoch)
104
+ writer.add_scalar("Accuracy/Train", accuracy, epoch)
105
+
106
+ print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | Acc: {accuracy:.2f}%")
107
+
108
+ # 4. Save ONNX
109
+ print("Exporting to ONNX...")
110
+ dummy_input = torch.randn(1, 2, T, 2*LEVELS) # 2 Channels, 2*Levels features
111
+ onnx_path = f"{MODEL_DIR}/deeplob_v1.onnx"
112
+ torch.onnx.export(
113
+ model,
114
+ dummy_input,
115
+ onnx_path,
116
+ input_names=['input'],
117
+ output_names=['output'],
118
+ dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
119
+ )
120
+ print(f"Model Saved: {onnx_path}")
121
+
122
+ writer.close()
123
+
124
+ if __name__ == "__main__":
125
+ train()
train_ensemble.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from huggingface_hub import snapshot_download
5
+ import os
6
+ import numpy as np
7
+ import joblib
8
+ from data_processor import AlphaDataProcessor
9
+ from models.trm import TRM
10
+ from models.classic_ml import get_hmm_pipeline, get_logistic_pipeline, get_rf_pipeline
11
+ from torch.utils.tensorboard import SummaryWriter
12
+
13
+ # Configuration
14
+ REPO_ID = "gionuibk/hyperliquidL2Book-v2" # Correct dataset repo
15
+ DATA_DIR = "./data"
16
+ MODEL_DIR = "./models"
17
+ T = 60 # Seq len for TRM
18
+
19
+ def train():
20
+ writer = SummaryWriter(log_dir="./ray_results/ensemble_experiment")
21
+ os.makedirs(MODEL_DIR, exist_ok=True)
22
+
23
+ # 1. Prepare Data
24
+ print("Loading Ensemble Data...")
25
+ processor = AlphaDataProcessor(data_dir=DATA_DIR)
26
+
27
+ # Check if data exists; download if needed (assuming train_deeplob or alpha already ran download, but safe to retry usually)
28
+ # We skip re-download to save time if folder exists.
29
+ if not os.path.exists(f"{DATA_DIR}/order_book_snapshot"):
30
+ print("Data missing, attempting download...")
31
+ snapshot_download(repo_id=REPO_ID, repo_type="dataset", local_dir=DATA_DIR, allow_patterns=["order_book_snapshot/*.parquet"])
32
+
33
+ # Get Tensors for TRM (Batch, T, 6) and Labels (Batch,)
34
+ # Input size increased to 6 due to OFI
35
+ X_trm, y_trm = processor.get_trm_tensors(coin="ETH", T=T)
36
+ print(f"TRM Data: {X_trm.shape}")
37
+
38
+ # Compute Class Weights
39
+ class_counts = torch.bincount(y_trm)
40
+ total_samples = class_counts.sum()
41
+ n_classes = len(class_counts)
42
+ weights = total_samples / (n_classes * class_counts.float())
43
+ if torch.isinf(weights).any(): weights[torch.isinf(weights)] = 1.0
44
+ print(f"Class Weights: {weights}")
45
+
46
+ # --- A. Train TRM (PyTorch) ---
47
+ print("--- Training TRM (Regime Detection) ---")
48
+ trm_model = TRM(input_size=6, num_classes=3) # input_size=6 (Vol, Imb, CVD, Spr, Mom, OFI)
49
+ criterion = nn.CrossEntropyLoss(weight=weights)
50
+ optimizer = optim.Adam(trm_model.parameters(), lr=0.001)
51
+
52
+ dataset = torch.utils.data.TensorDataset(X_trm, y_trm)
53
+ loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
54
+
55
+ for epoch in range(10): # Quick train
56
+ trm_model.train()
57
+ total_loss = 0
58
+ for bX, by in loader:
59
+ optimizer.zero_grad()
60
+ out = trm_model(bX)
61
+ loss = criterion(out, by)
62
+ loss.backward()
63
+ optimizer.step()
64
+ total_loss += loss.item()
65
+ print(f"TRM Epoch {epoch} | Loss: {total_loss/len(loader):.4f}")
66
+
67
+ # Save TRM
68
+ dummy_input = torch.randn(1, T, 6)
69
+ torch.onnx.export(trm_model, dummy_input, f"{MODEL_DIR}/trm_v1.onnx", input_names=['input'], output_names=['output'])
70
+ print("TRM Saved.")
71
+
72
+ # --- B. Train Classic ML (Sklearn) ---
73
+ print("--- Training Classic Models ---")
74
+
75
+ # Flatten/Preprocess for Sklearn
76
+ # We use the last step features for simple classification
77
+ X_flat = X_trm[:, -1, :].numpy() # (N, 5)
78
+ y_flat = y_trm.numpy()
79
+
80
+ # 1. HMM
81
+ print("Training HMM...")
82
+ hmm = get_hmm_pipeline(n_components=3)
83
+ hmm.fit(X_flat)
84
+ joblib.dump(hmm, f"{MODEL_DIR}/hmm_v1.joblib")
85
+
86
+ # 2. Logistic
87
+ print("Training Logistic Regression...")
88
+ lr = get_logistic_pipeline()
89
+ lr.fit(X_flat, y_flat)
90
+ joblib.dump(lr, f"{MODEL_DIR}/logistic_v1.joblib")
91
+
92
+ # 3. Random Forest (Trade Filter)
93
+ # We simulate Trade Filter labels (Win/Loss) using future return sign
94
+ # y=0 (Down) -> Sell Win? y=2 (Up) -> Buy Win?
95
+ # Let's simple binary target: 1 if Up, 0 if Down/Sideway
96
+ y_rf = (y_flat == 2).astype(int)
97
+ print("Training Random Forest...")
98
+ rf = get_rf_pipeline()
99
+ rf.fit(X_flat, y_rf)
100
+ joblib.dump(rf, f"{MODEL_DIR}/rf_v1.joblib")
101
+
102
+ print("All Ensemble Models Saved!")
103
+ writer.close()
104
+
105
+ if __name__ == "__main__":
106
+ train()
train_remaining.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+ import os
4
+ import numpy as np
5
+ import time
6
+ from models.execution_agent import PPOActorCritic
7
+ from models.meta_controller import DQN
8
+ from models.arbitrage_agent import ArbitrageAgent
9
+ from models.risk_agent import RiskAgent
10
+ from data_processor import AlphaDataProcessor
11
+ from torch.utils.tensorboard import SummaryWriter
12
+
13
+ # Configuration
14
+ MODEL_DIR = "./models"
15
+ os.makedirs(MODEL_DIR, exist_ok=True)
16
+
17
+ def train():
18
+ writer = SummaryWriter(log_dir="./ray_results/agents_experiment")
19
+ print("Initializing Training for Specialist Agents...")
20
+
21
+ # --- 1. Train Execution Agent (PPO) ---
22
+ print("\n--- Training Execution Agent (PPO) ---")
23
+ exec_agent = PPOActorCritic(input_dim=5, action_dim=3)
24
+ optimizer_exec = optim.Adam(exec_agent.parameters(), lr=0.0003)
25
+
26
+ # Mock Training Loop (Simulating Environment Interaction)
27
+ for i in range(100):
28
+ # Fake State: [Signal, Imb, Spread, Position, Vol]
29
+ state = torch.randn(32, 5)
30
+
31
+ # Forward
32
+ probs, mu_off, mu_sz, val = exec_agent(state)
33
+
34
+ # Fake Loss (Standard PPO Loss would go here)
35
+ # We just minimize output to verify gradient flow
36
+ loss = probs.mean() + (mu_off - 0).pow(2).mean() + (val - 1).pow(2).mean()
37
+
38
+ optimizer_exec.zero_grad()
39
+ loss.backward()
40
+ optimizer_exec.step()
41
+
42
+ if i % 10 == 0:
43
+ print(f"Exec Iter {i} | Loss: {loss.item():.4f}")
44
+ writer.add_scalar("Execution/Loss", loss.item(), i)
45
+
46
+ # Save ONNX
47
+ dummy = torch.randn(1, 5)
48
+ torch.onnx.export(exec_agent, dummy, f"{MODEL_DIR}/execution_agent_ppo.onnx", input_names=['input'], output_names=['type', 'offset', 'size', 'value'])
49
+ print("Execution Agent Saved.")
50
+
51
+ # --- 2. Train Meta-Controller (DQN) ---
52
+ print("\n--- Training Meta-Controller (DQN) ---")
53
+ meta_agent = DQN(input_dim=5, output_dim=3)
54
+ optimizer_meta = optim.Adam(meta_agent.parameters(), lr=0.001)
55
+
56
+ criterion = torch.nn.MSELoss()
57
+
58
+ for i in range(100):
59
+ # Fake State: [Vol, Regime(3), PnL] -> 5 Dim (Regime is OneHot)
60
+ state = torch.randn(32, 5)
61
+ target = torch.randn(32, 3) # Fake Q-Values
62
+
63
+ out = meta_agent(state)
64
+ loss = criterion(out, target)
65
+
66
+ optimizer_meta.zero_grad()
67
+ loss.backward()
68
+ optimizer_meta.step()
69
+
70
+ if i % 10 == 0:
71
+ print(f"Meta Iter {i} | Loss: {loss.item():.4f}")
72
+ writer.add_scalar("Meta/Loss", loss.item(), i)
73
+
74
+ # Save ONNX
75
+ dummy = torch.randn(1, 5)
76
+ torch.onnx.export(meta_agent, dummy, f"{MODEL_DIR}/meta_controller_dqn.onnx", input_names=['input'], output_names=['q_values'])
77
+ print("Meta-Controller Saved.")
78
+
79
+ # --- 3. Verify Rule-Based Agents ---
80
+ print("\n--- Verifying Rule-Based Agents ---")
81
+ arb = ArbitrageAgent()
82
+ risk = RiskAgent()
83
+
84
+ # Test Logic
85
+ arb.analyze(spot_price=100, perp_price=101, funding_rate=0.01) # Should trigger Long
86
+ risk.check_health(equity=9000, initial_equity=10000, recent_pnl=-50) # Should be OK
87
+
88
+ print("All Agents Ready!")
89
+ writer.close()
90
+
91
+ if __name__ == "__main__":
92
+ train()
version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Deploy Full 9 Models Fix at Wed Dec 10 14:02:57 +07 2025