zongowo111 commited on
Commit
2386f4c
·
verified ·
1 Parent(s): bb132fe

Upload bot predictor module

Browse files
Files changed (1) hide show
  1. bot_predictor.py +314 -0
bot_predictor.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Bot Predictor V8 - For Discord Bot Integration
5
+
6
+ Direct prediction module for trading signals with automatic bias correction
7
+
8
+ Usage:
9
+ from bot_predictor import BotPredictor
10
+
11
+ bot = BotPredictor()
12
+ prediction = bot.predict('BTC')
13
+ print(f"Corrected Price: {prediction['corrected_price']}")
14
+ """
15
+
16
+ import os
17
+ import json
18
+ import numpy as np
19
+ import pandas as pd
20
+ import torch
21
+ from sklearn.preprocessing import MinMaxScaler
22
+
23
+ import ccxt
24
+ import logging
25
+
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
28
+
29
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
+
31
+
32
+ class RegressionLSTM(torch.nn.Module):
33
+ """V8 LSTM Model"""
34
+ def __init__(self, input_size=44, hidden_size=64, num_layers=2, dropout=0.3, bidirectional=True):
35
+ super(RegressionLSTM, self).__init__()
36
+
37
+ self.lstm = torch.nn.LSTM(
38
+ input_size=input_size,
39
+ hidden_size=hidden_size,
40
+ num_layers=num_layers,
41
+ dropout=dropout if num_layers > 1 else 0,
42
+ bidirectional=bidirectional,
43
+ batch_first=True
44
+ )
45
+
46
+ lstm_output_size = hidden_size * (2 if bidirectional else 1)
47
+
48
+ self.regressor = torch.nn.Sequential(
49
+ torch.nn.Linear(lstm_output_size, 64),
50
+ torch.nn.ReLU(),
51
+ torch.nn.Dropout(dropout),
52
+ torch.nn.Linear(64, 32),
53
+ torch.nn.ReLU(),
54
+ torch.nn.Linear(32, 1)
55
+ )
56
+
57
+ def forward(self, x):
58
+ lstm_out, _ = self.lstm(x)
59
+ last_out = lstm_out[:, -1, :]
60
+ price = self.regressor(last_out)
61
+ return price
62
+
63
+
64
+ class BotPredictor:
65
+ """Bot Prediction Engine with Bias Correction"""
66
+
67
+ def __init__(self, model_dir='models/saved', bias_config_path='models/bias_corrections_v8.json'):
68
+ self.model_dir = model_dir
69
+ self.device = device
70
+ self.exchange = ccxt.binance({'enableRateLimit': True})
71
+ self.model_cache = {}
72
+ self.scaler_cache = {}
73
+
74
+ # Load bias corrections
75
+ self.bias_corrections = {}
76
+ if os.path.exists(bias_config_path):
77
+ try:
78
+ with open(bias_config_path, 'r') as f:
79
+ bias_config = json.load(f)
80
+ self.bias_corrections = bias_config.get('corrections', {})
81
+ logger.info(f"Loaded bias corrections for {len(self.bias_corrections)} symbols")
82
+ except Exception as e:
83
+ logger.warning(f"Could not load bias corrections: {e}")
84
+
85
+ def _detect_model_config(self, state_dict):
86
+ """Detect model architecture from weights"""
87
+ try:
88
+ weight_ih = state_dict.get('lstm.weight_ih_l0')
89
+ hidden_size = weight_ih.shape[0] // 4 if weight_ih is not None else 64
90
+ bidirectional = 'lstm.weight_ih_l0_reverse' in state_dict
91
+
92
+ num_layers = 1
93
+ layer = 1
94
+ while f'lstm.weight_ih_l{layer}' in state_dict:
95
+ num_layers += 1
96
+ layer += 1
97
+
98
+ return {
99
+ 'hidden_size': hidden_size,
100
+ 'num_layers': num_layers,
101
+ 'bidirectional': bidirectional,
102
+ 'dropout': 0.3,
103
+ }
104
+ except:
105
+ return {'hidden_size': 64, 'num_layers': 2, 'bidirectional': True, 'dropout': 0.3}
106
+
107
+ def _fetch_data(self, symbol, limit=1000):
108
+ """Fetch latest OHLCV data"""
109
+ try:
110
+ symbol_pair = f"{symbol}/USDT"
111
+ ohlcv = self.exchange.fetch_ohlcv(symbol_pair, '1h', limit=limit)
112
+
113
+ df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
114
+ df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
115
+ return df.sort_values('timestamp').reset_index(drop=True)
116
+ except Exception as e:
117
+ logger.error(f"Error fetching {symbol}: {e}")
118
+ return None
119
+
120
+ def _add_indicators(self, df):
121
+ """Add 44 technical indicators"""
122
+ try:
123
+ df['high-low'] = df['high'] - df['low']
124
+ df['close-open'] = df['close'] - df['open']
125
+ df['returns'] = df['close'].pct_change()
126
+
127
+ for period in [14, 21]:
128
+ delta = df['close'].diff()
129
+ gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
130
+ loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
131
+ rs = gain / loss
132
+ df[f'rsi_{period}'] = 100 - (100 / (1 + rs))
133
+
134
+ ema12 = df['close'].ewm(span=12).mean()
135
+ ema26 = df['close'].ewm(span=26).mean()
136
+ df['macd'] = ema12 - ema26
137
+ df['macd_signal'] = df['macd'].ewm(span=9).mean()
138
+ df['macd_hist'] = df['macd'] - df['macd_signal']
139
+
140
+ sma20 = df['close'].rolling(window=20).mean()
141
+ std20 = df['close'].rolling(window=20).std()
142
+ df['bb_upper'] = sma20 + (std20 * 2)
143
+ df['bb_middle'] = sma20
144
+ df['bb_lower'] = sma20 - (std20 * 2)
145
+
146
+ tr1 = df['high'] - df['low']
147
+ tr2 = abs(df['high'] - df['close'].shift())
148
+ tr3 = abs(df['low'] - df['close'].shift())
149
+ tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
150
+ df['atr'] = tr.rolling(window=14).mean()
151
+
152
+ df['momentum'] = df['close'].diff(10)
153
+ tp = (df['high'] + df['low'] + df['close']) / 3
154
+ df['cci'] = (tp - tp.rolling(window=20).mean()) / (0.015 * tp.rolling(window=20).std())
155
+
156
+ df['sma5'] = df['close'].rolling(window=5).mean()
157
+ df['sma10'] = df['close'].rolling(window=10).mean()
158
+ df['sma20'] = df['close'].rolling(window=20).mean()
159
+ df['sma50'] = df['close'].rolling(window=50).mean()
160
+
161
+ df['volume_sma'] = df['volume'].rolling(window=20).mean()
162
+ df['volume_ratio'] = df['volume'] / df['volume_sma']
163
+
164
+ df = df.ffill().bfill()
165
+ df = df.replace([np.inf, -np.inf], np.nan).ffill().bfill()
166
+
167
+ return df
168
+ except Exception as e:
169
+ logger.error(f"Error adding indicators: {e}")
170
+ return None
171
+
172
+ def _load_model(self, symbol):
173
+ """Load model from cache or disk"""
174
+ if symbol in self.model_cache:
175
+ return self.model_cache[symbol]
176
+
177
+ # Find model file
178
+ possible_names = [f'{symbol}_model_v8.pth', f'{symbol}_model.pth', f'{symbol}.pth']
179
+ model_path = None
180
+
181
+ for name in possible_names:
182
+ path = os.path.join(self.model_dir, name)
183
+ if os.path.exists(path):
184
+ model_path = path
185
+ break
186
+
187
+ if not model_path:
188
+ logger.error(f"Model not found for {symbol}")
189
+ return None
190
+
191
+ try:
192
+ state_dict = torch.load(model_path, map_location=self.device)
193
+ config = self._detect_model_config(state_dict)
194
+
195
+ model = RegressionLSTM(
196
+ input_size=44,
197
+ hidden_size=config['hidden_size'],
198
+ num_layers=config['num_layers'],
199
+ dropout=config['dropout'],
200
+ bidirectional=config['bidirectional']
201
+ )
202
+ model.to(self.device)
203
+ model.load_state_dict(state_dict)
204
+ model.eval()
205
+
206
+ self.model_cache[symbol] = model
207
+ return model
208
+ except Exception as e:
209
+ logger.error(f"Error loading model for {symbol}: {e}")
210
+ return None
211
+
212
+ def predict(self, symbol, apply_correction=True):
213
+ """
214
+ Predict next price for symbol
215
+
216
+ Returns:
217
+ dict with keys:
218
+ - raw_price: 未校正的預測價格
219
+ - correction: 校正值
220
+ - corrected_price: 校正後的預測價格 (推薦用這個)
221
+ - current_price: 當前價格
222
+ - direction: 'UP' 或 'DOWN'
223
+ - confidence: 0-1 信心指數
224
+ """
225
+ try:
226
+ # Fetch data
227
+ df = self._fetch_data(symbol)
228
+ if df is None or len(df) < 100:
229
+ logger.error(f"Insufficient data for {symbol}")
230
+ return None
231
+
232
+ current_price = df['close'].iloc[-1]
233
+
234
+ # Add indicators
235
+ df = self._add_indicators(df)
236
+ if df is None:
237
+ return None
238
+
239
+ # Prepare features
240
+ feature_cols = [col for col in df.columns if col not in ['timestamp', 'close']]
241
+ X = df[feature_cols].values
242
+ X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
243
+
244
+ # Normalize
245
+ scaler_X = MinMaxScaler()
246
+ X_scaled = scaler_X.fit_transform(X)
247
+
248
+ if X_scaled.shape[1] > 44:
249
+ X_scaled = X_scaled[:, :44]
250
+ elif X_scaled.shape[1] < 44:
251
+ padding = np.zeros((X_scaled.shape[0], 44 - X_scaled.shape[1]))
252
+ X_scaled = np.hstack([X_scaled, padding])
253
+
254
+ # Prepare sequence
255
+ lookback = 60
256
+ if len(X_scaled) < lookback + 1:
257
+ logger.error(f"Insufficient sequence data for {symbol}")
258
+ return None
259
+
260
+ X_seq = X_scaled[-lookback:].reshape(1, lookback, 44)
261
+
262
+ # Load model and predict
263
+ model = self._load_model(symbol)
264
+ if model is None:
265
+ return None
266
+
267
+ with torch.no_grad():
268
+ X_tensor = torch.tensor(X_seq, dtype=torch.float32).to(self.device)
269
+ price_scaled = model(X_tensor).cpu().numpy()[0][0]
270
+
271
+ # Inverse transform price
272
+ y_scaler = MinMaxScaler()
273
+ y_scaler.fit(df['close'].values.reshape(-1, 1))
274
+ raw_price = y_scaler.inverse_transform([[price_scaled]])[0][0]
275
+
276
+ # Apply bias correction
277
+ correction = self.bias_corrections.get(symbol, 0)
278
+ corrected_price = raw_price + correction if apply_correction else raw_price
279
+
280
+ # Direction
281
+ direction = 'UP' if corrected_price > current_price else 'DOWN'
282
+ change_pct = abs(corrected_price - current_price) / current_price * 100
283
+ confidence = min(change_pct / 2, 1.0) # Simple confidence metric
284
+
285
+ return {
286
+ 'symbol': symbol,
287
+ 'current_price': float(current_price),
288
+ 'raw_price': float(raw_price),
289
+ 'correction': float(correction),
290
+ 'corrected_price': float(corrected_price),
291
+ 'direction': direction,
292
+ 'change_pct': float(change_pct),
293
+ 'confidence': float(confidence),
294
+ 'model_version': 'v8',
295
+ }
296
+
297
+ except Exception as e:
298
+ logger.error(f"Error predicting {symbol}: {e}")
299
+ return None
300
+
301
+
302
+ if __name__ == '__main__':
303
+ # Test
304
+ bot = BotPredictor()
305
+
306
+ test_symbols = ['BTC', 'ETH', 'SOL']
307
+ for symbol in test_symbols:
308
+ prediction = bot.predict(symbol)
309
+ if prediction:
310
+ print(f"\n{symbol}:")
311
+ print(f" Current: ${prediction['current_price']:.2f}")
312
+ print(f" Predicted: ${prediction['corrected_price']:.2f}")
313
+ print(f" Direction: {prediction['direction']}")
314
+ print(f" Confidence: {prediction['confidence']*100:.1f}%")