Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,1165 +1,596 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
#
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
-
import gradio as gr
|
| 11 |
import requests
|
| 12 |
-
import
|
| 13 |
-
from datetime import datetime, timedelta
|
| 14 |
-
import warnings
|
| 15 |
-
warnings.filterwarnings('ignore')
|
| 16 |
-
|
| 17 |
-
# Machine Learning
|
| 18 |
-
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, AdaBoostRegressor
|
| 19 |
-
from sklearn.linear_model import Ridge, Lasso, ElasticNet
|
| 20 |
-
from sklearn.svm import SVR
|
| 21 |
-
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
|
| 22 |
-
from sklearn.model_selection import train_test_split
|
| 23 |
-
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
| 24 |
|
| 25 |
# Visualization
|
| 26 |
-
import
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
#
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
})
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
try:
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
'bar': bar,
|
| 59 |
-
'limit': str(limit)
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
response = self.session.get(endpoint, params=params, timeout=10)
|
| 63 |
-
|
| 64 |
-
if response.status_code == 200:
|
| 65 |
-
data = response.json()
|
| 66 |
-
|
| 67 |
-
if data['code'] == '0':
|
| 68 |
-
candles = data['data']
|
| 69 |
-
|
| 70 |
-
df = pd.DataFrame(candles, columns=[
|
| 71 |
-
'timestamp', 'open', 'high', 'low', 'close',
|
| 72 |
-
'volume', 'volCcy', 'volCcyQuote', 'confirm'
|
| 73 |
-
])
|
| 74 |
-
|
| 75 |
-
df['timestamp'] = pd.to_datetime(df['timestamp'].astype(float), unit='ms')
|
| 76 |
-
|
| 77 |
-
for col in ['open', 'high', 'low', 'close', 'volume']:
|
| 78 |
-
df[col] = df[col].astype(float)
|
| 79 |
-
|
| 80 |
-
df = df.sort_values('timestamp').reset_index(drop=True)
|
| 81 |
-
|
| 82 |
-
return df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
|
| 83 |
-
else:
|
| 84 |
-
print(f"API Error: {data['msg']}")
|
| 85 |
-
return None
|
| 86 |
else:
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
params = {'instId': instId}
|
| 99 |
-
|
| 100 |
-
response = self.session.get(endpoint, params=params, timeout=10)
|
| 101 |
-
|
| 102 |
-
if response.status_code == 200:
|
| 103 |
-
data = response.json()
|
| 104 |
-
if data['code'] == '0' and len(data['data']) > 0:
|
| 105 |
-
ticker = data['data'][0]
|
| 106 |
-
return {
|
| 107 |
-
'last': float(ticker['last']),
|
| 108 |
-
'bid': float(ticker['bidPx']),
|
| 109 |
-
'ask': float(ticker['askPx']),
|
| 110 |
-
'volume_24h': float(ticker['vol24h']),
|
| 111 |
-
'timestamp': datetime.now()
|
| 112 |
-
}
|
| 113 |
-
return None
|
| 114 |
-
|
| 115 |
-
except Exception as e:
|
| 116 |
-
print(f"Error fetching ticker: {str(e)}")
|
| 117 |
-
return None
|
| 118 |
-
# app.py - PART 2/5
|
| 119 |
-
# ========================================================================
|
| 120 |
-
# Feature Engineering Module
|
| 121 |
-
# ========================================================================
|
| 122 |
-
|
| 123 |
-
class FeatureEngineer:
    """Feature-engineering helpers for OHLCV candle data.

    Every method is static, works on a copy of the frame it receives and
    returns that copy with extra columns appended; the caller's frame is
    left untouched.
    """

    @staticmethod
    def add_technical_indicators(df):
        """Return a copy of ``df`` with technical-indicator columns appended.

        Reads the ``open``, ``high``, ``low``, ``close`` and ``volume``
        columns. Rolling-window columns are NaN until their window has
        filled; no rows are dropped or filled here.
        """
        df = df.copy()
        close = df['close']
        open_ = df['open']
        high = df['high']
        low = df['low']
        volume = df['volume']

        # Basic candle features
        df['returns'] = close.pct_change()
        df['log_returns'] = np.log(close / close.shift(1))
        df['price_range'] = high - low
        df['price_change'] = close - open_
        df['body'] = (close - open_).abs()
        df['upper_shadow'] = high - df[['open', 'close']].max(axis=1)
        df['lower_shadow'] = df[['open', 'close']].min(axis=1) - low

        # Simple / exponential moving averages and price relative to the SMA
        for window in (5, 10, 20, 50, 100):
            sma = close.rolling(window=window).mean()
            df[f'sma_{window}'] = sma
            df[f'ema_{window}'] = close.ewm(span=window, adjust=False).mean()
            df[f'price_to_sma_{window}'] = close / sma

        # MACD (12/26 EMAs, 9-period signal line)
        ema_fast = close.ewm(span=12, adjust=False).mean()
        ema_slow = close.ewm(span=26, adjust=False).mean()
        df['macd'] = ema_fast - ema_slow
        df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
        df['macd_diff'] = df['macd'] - df['macd_signal']

        # RSI from simple rolling means of gains and losses.
        # The price delta does not depend on the period, so compute it once.
        delta = close.diff()
        gains = delta.where(delta > 0, 0)
        losses = -delta.where(delta < 0, 0)
        for period in (14, 28):
            avg_gain = gains.rolling(window=period).mean()
            avg_loss = losses.rolling(window=period).mean()
            rs = avg_gain / avg_loss
            df[f'rsi_{period}'] = 100 - (100 / (1 + rs))

        # Bollinger Bands (mean +/- 2 standard deviations)
        for window in (20, 50):
            rolling_mean = close.rolling(window=window).mean()
            rolling_std = close.rolling(window=window).std()
            upper = rolling_mean + (rolling_std * 2)
            lower = rolling_mean - (rolling_std * 2)
            df[f'bb_upper_{window}'] = upper
            df[f'bb_lower_{window}'] = lower
            df[f'bb_width_{window}'] = upper - lower
            df[f'bb_position_{window}'] = (close - lower) / df[f'bb_width_{window}']

        # ATR: 14-period mean of the true range
        prev_close = close.shift()
        ranges = pd.concat(
            [high - low, (high - prev_close).abs(), (low - prev_close).abs()],
            axis=1,
        )
        true_range = ranges.max(axis=1)
        df['atr_14'] = true_range.rolling(14).mean()

        # Stochastic Oscillator
        low_14 = low.rolling(window=14).min()
        high_14 = high.rolling(window=14).max()
        df['stoch_k'] = 100 * ((close - low_14) / (high_14 - low_14))
        df['stoch_d'] = df['stoch_k'].rolling(window=3).mean()

        # Volume features
        df['volume_sma_20'] = volume.rolling(window=20).mean()
        df['volume_ratio'] = volume / df['volume_sma_20']
        df['volume_price_trend'] = volume * df['returns']

        # OBV: cumulative volume signed by the direction of the close change
        df['obv'] = (np.sign(close.diff()) * volume).fillna(0).cumsum()

        # Momentum
        for period in (5, 10, 20):
            df[f'momentum_{period}'] = close.diff(period)
            df[f'roc_{period}'] = close.pct_change(period)

        # Volatility (rolling standard deviation of returns)
        for window in (5, 10, 20, 30):
            df[f'volatility_{window}'] = df['returns'].rolling(window=window).std()

        # Statistical features
        for window in (10, 20):
            rolling_returns = df['returns'].rolling(window=window)
            df[f'skew_{window}'] = rolling_returns.skew()
            df[f'kurt_{window}'] = rolling_returns.kurt()

        return df

    @staticmethod
    def add_lag_features(df, n_lags=5):
        """Return a copy of ``df`` with lagged close, volume and returns.

        For each lag ``1..n_lags`` adds ``close_lag_k``, ``volume_lag_k`` and
        ``returns_lag_k``. If ``df`` has no ``returns`` column the returns
        are derived from ``close`` so the method also works on raw candles
        (previously this raised ``KeyError``).
        """
        df = df.copy()

        if 'returns' in df.columns:
            returns = df['returns']
        else:
            returns = df['close'].pct_change()

        for lag in range(1, n_lags + 1):
            df[f'close_lag_{lag}'] = df['close'].shift(lag)
            df[f'volume_lag_{lag}'] = df['volume'].shift(lag)
            df[f'returns_lag_{lag}'] = returns.shift(lag)

        return df

    @staticmethod
    def add_time_features(df):
        """Return a copy of ``df`` with calendar features from ``timestamp``.

        Adds ``hour``, ``day_of_week``, ``day_of_month`` and ``month`` plus
        sine/cosine encodings of the hour (period 24) and weekday (period 7).
        ``timestamp`` must be a datetime column (it is read through ``.dt``).
        """
        df = df.copy()
        stamp = df['timestamp'].dt

        df['hour'] = stamp.hour
        df['day_of_week'] = stamp.dayofweek
        df['day_of_month'] = stamp.day
        df['month'] = stamp.month

        # Cyclical encoding so that e.g. hour 23 and hour 0 end up adjacent
        hour_angle = 2 * np.pi * df['hour'] / 24
        day_angle = 2 * np.pi * df['day_of_week'] / 7
        df['hour_sin'] = np.sin(hour_angle)
        df['hour_cos'] = np.cos(hour_angle)
        df['day_sin'] = np.sin(day_angle)
        df['day_cos'] = np.cos(day_angle)

        return df
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
# app.py - PART 3/5
|
| 240 |
-
# ========================================================================
|
| 241 |
-
# Ensemble Model
|
| 242 |
-
# ========================================================================
|
| 243 |
-
|
| 244 |
-
class EnsemblePredictor:
    """Weighted ensemble of scikit-learn regressors for BTC/USDT prediction.

    Six regressors are trained on scaled features; their predictions are
    combined with weights derived from each model's validation MSE.
    """

    # (model key, scaler key, label used in progress messages), in the order
    # the models are trained and combined. Keeping the pairing in one place
    # guarantees train() and predict() feed each model the same scaling.
    _MODEL_SPECS = (
        ('random_forest', 'standard', 'Random Forest'),
        ('gradient_boosting', 'standard', 'Gradient Boosting'),
        ('adaboost', 'standard', 'AdaBoost'),
        ('ridge', 'robust', 'Ridge'),
        ('lasso', 'robust', 'Lasso'),
        ('elastic_net', 'robust', 'Elastic Net'),
    )

    def __init__(self):
        self.models = {}              # model key -> estimator
        self.weights = {}             # model key -> ensemble weight
        self.scalers = {}             # scaler key -> fitted scaler
        self.feature_columns = None   # column names used by the last prepare_data()
        self.is_trained = False

    def initialize_models(self):
        """Create fresh, unfitted estimators and give them equal weights."""
        self.models['random_forest'] = RandomForestRegressor(
            n_estimators=200,
            max_depth=15,
            min_samples_split=5,
            random_state=42,
            n_jobs=-1,
        )

        self.models['gradient_boosting'] = GradientBoostingRegressor(
            n_estimators=200,
            learning_rate=0.05,
            max_depth=5,
            random_state=42,
        )

        self.models['adaboost'] = AdaBoostRegressor(
            n_estimators=100,
            learning_rate=0.1,
            random_state=42,
        )

        self.models['ridge'] = Ridge(alpha=1.0)
        self.models['lasso'] = Lasso(alpha=0.1, max_iter=2000)
        self.models['elastic_net'] = ElasticNet(alpha=0.1, l1_ratio=0.5, max_iter=2000)

        equal_weight = 1.0 / len(self.models)
        for model_name in self.models:
            self.weights[model_name] = equal_weight

    def prepare_data(self, df, target_col='close'):
        """Split ``df`` into a feature matrix and a target vector.

        Every column except ``timestamp`` and ``target_col`` is a feature.
        Infinite values are treated as missing; gaps are forward-filled,
        then back-filled, and anything still missing becomes 0. The feature
        column names are stored on ``self.feature_columns``.

        Returns:
            ``(X, y)`` as NumPy arrays.
        """
        exclude_cols = ['timestamp', target_col]
        feature_cols = [col for col in df.columns if col not in exclude_cols]

        df = df.replace([np.inf, -np.inf], np.nan)
        # .ffill()/.bfill() replace the deprecated fillna(method=...) spelling.
        df = df.ffill().bfill().fillna(0)

        X = df[feature_cols].values
        y = df[target_col].values

        self.feature_columns = feature_cols

        return X, y

    def train(self, X_train, y_train, X_val, y_val):
        """Fit scalers and all models, then weight models by validation MSE.

        Returns:
            dict mapping model key to that model's validation predictions.
        """
        self.initialize_models()

        self.scalers['standard'] = StandardScaler()
        self.scalers['robust'] = RobustScaler()

        # Scalers are fitted on the training split only and reused for validation.
        train_scaled = {}
        val_scaled = {}
        for scaler_key in ('standard', 'robust'):
            scaler = self.scalers[scaler_key]
            train_scaled[scaler_key] = scaler.fit_transform(X_train)
            val_scaled[scaler_key] = scaler.transform(X_val)

        predictions_val = {}
        for model_name, scaler_key, label in self._MODEL_SPECS:
            print(f"Training {label}...")
            model = self.models[model_name]
            model.fit(train_scaled[scaler_key], y_train)
            predictions_val[model_name] = model.predict(val_scaled[scaler_key])

        self.optimize_weights(predictions_val, y_val)
        self.is_trained = True

        return predictions_val

    def optimize_weights(self, predictions_val, y_val):
        """Set ensemble weights proportional to inverse validation MSE."""
        performances = {}
        for model_name, preds in predictions_val.items():
            mse = mean_squared_error(y_val, preds)
            # Small epsilon avoids division by zero for a perfect fit.
            performances[model_name] = 1.0 / (mse + 1e-10)

        total_performance = sum(performances.values())
        for model_name, score in performances.items():
            self.weights[model_name] = score / total_performance

        print("\n=== Optimized Weights ===")
        for model_name, weight in self.weights.items():
            print(f"{model_name}: {weight:.4f}")

    def predict(self, X):
        """Predict with every model and combine them using the stored weights.

        Returns:
            ``(ensemble_pred, predictions)`` where ``predictions`` maps model
            key to that model's own predictions.

        Raises:
            ValueError: if ``train`` has not completed yet.
        """
        if not self.is_trained:
            raise ValueError("Model must be trained first")

        scaled = {
            'standard': self.scalers['standard'].transform(X),
            'robust': self.scalers['robust'].transform(X),
        }

        predictions = {
            model_name: self.models[model_name].predict(scaled[scaler_key])
            for model_name, scaler_key, _ in self._MODEL_SPECS
        }

        ensemble_pred = np.zeros(len(X))
        for model_name, preds in predictions.items():
            ensemble_pred += self.weights[model_name] * preds

        return ensemble_pred, predictions

    def evaluate(self, X_test, y_test):
        """Compute error metrics for the ensemble and for each model.

        Returns:
            ``(metrics, ensemble_pred)``. ``metrics['ensemble']`` holds MSE,
            RMSE, MAE, R2 and MAPE; every other key is a model name mapping
            to MSE, RMSE, MAE and R2.
        """
        ensemble_pred, individual_preds = self.predict(X_test)

        mse = mean_squared_error(y_test, ensemble_pred)
        mae = mean_absolute_error(y_test, ensemble_pred)
        rmse = np.sqrt(mse)
        r2 = r2_score(y_test, ensemble_pred)
        # NOTE: MAPE divides by the true values, so it is undefined where y_test == 0.
        mape = np.mean(np.abs((y_test - ensemble_pred) / y_test)) * 100

        metrics = {
            'ensemble': {
                'MSE': mse,
                'RMSE': rmse,
                'MAE': mae,
                'R2': r2,
                'MAPE': mape,
            }
        }

        for model_name, preds in individual_preds.items():
            mse_ind = mean_squared_error(y_test, preds)
            metrics[model_name] = {
                'MSE': mse_ind,
                'RMSE': np.sqrt(mse_ind),
                'MAE': mean_absolute_error(y_test, preds),
                'R2': r2_score(y_test, preds),
            }

        return metrics, ensemble_pred
|
| 420 |
-
|
| 421 |
-
# app.py - PART 4/5
|
| 422 |
-
# ========================================================================
|
| 423 |
-
# Visualization and Main Pipeline
|
| 424 |
-
# ========================================================================
|
| 425 |
-
|
| 426 |
-
class Visualizer:
    """Plotly chart builders used by the UI. All methods are static."""

    @staticmethod
    def plot_predictions(y_true, y_pred, timestamps=None, title="BTC/USDT Predictions"):
        """Return a line chart of actual versus predicted prices.

        When ``timestamps`` is None the x axis is the sample index
        ``0..len(y_true)-1``.
        """
        fig = go.Figure()

        if timestamps is None:
            timestamps = list(range(len(y_true)))

        fig.add_trace(go.Scatter(
            x=timestamps,
            y=y_true,
            mode='lines',
            name='Actual',
            line=dict(color='cyan', width=2),
        ))

        fig.add_trace(go.Scatter(
            x=timestamps,
            y=y_pred,
            mode='lines',
            name='Predicted',
            line=dict(color='magenta', width=2, dash='dash'),
        ))

        fig.update_layout(
            title=title,
            xaxis_title='Time',
            yaxis_title='Price (USDT)',
            template='plotly_dark',
            hovermode='x unified',
            height=500,
        )

        return fig

    @staticmethod
    def plot_candlestick(df, n_candles=100):
        """Return a two-row figure: candlesticks on top, volume bars below.

        Only the last ``n_candles`` rows of ``df`` are drawn. Volume bars are
        red when the candle closed below its open and green otherwise.
        """
        df = df.tail(n_candles).copy()

        fig = make_subplots(
            rows=2, cols=1,
            shared_xaxes=True,
            vertical_spacing=0.05,
            subplot_titles=('Price', 'Volume'),
            row_heights=[0.7, 0.3],
        )

        fig.add_trace(
            go.Candlestick(
                x=df['timestamp'],
                open=df['open'],
                high=df['high'],
                low=df['low'],
                close=df['close'],
                name='OHLC',
            ),
            row=1, col=1,
        )

        # Vectorised comparison instead of a Python-level iterrows() loop.
        colors = np.where(df['close'] < df['open'], 'red', 'green').tolist()

        fig.add_trace(
            go.Bar(
                x=df['timestamp'],
                y=df['volume'],
                name='Volume',
                marker_color=colors,
            ),
            row=2, col=1,
        )

        fig.update_layout(
            title='BTC/USDT Chart',
            template='plotly_dark',
            xaxis_rangeslider_visible=False,
            height=700,
        )

        return fig

    @staticmethod
    def plot_feature_importance(model, feature_names, top_n=20):
        """Return a horizontal bar chart of the ``top_n`` largest importances.

        Returns None when ``model`` has no ``feature_importances_`` attribute.
        """
        if not hasattr(model, 'feature_importances_'):
            return None

        importances = np.asarray(model.feature_importances_)
        order = np.argsort(importances)
        # Slice from an explicit start index: order[-top_n:] would select
        # *every* feature when top_n is 0.
        keep = max(0, min(int(top_n), len(order)))
        indices = order[len(order) - keep:]

        fig = go.Figure(go.Bar(
            x=importances[indices],
            y=[feature_names[i] for i in indices],
            orientation='h',
            marker_color='lightblue',
        ))

        fig.update_layout(
            title=f'Top {top_n} Feature Importances',
            xaxis_title='Importance',
            yaxis_title='Features',
            template='plotly_dark',
            height=600,
        )

        return fig
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
# ================================
|
| 542 |
-
# MAIN PIPELINE
|
| 543 |
-
# ================================
|
| 544 |
-
|
| 545 |
-
class BTCPredictionPipeline:
|
| 546 |
-
"""Main prediction pipeline"""
|
| 547 |
-
|
| 548 |
-
def __init__(self):
    """Create the pipeline's collaborators; no data is loaded yet."""
    # Collaborators: exchange client, feature builder, model and chart helper.
    self.okx_client = OKXClient()
    self.feature_engineer = FeatureEngineer()
    self.ensemble_model = EnsemblePredictor()
    self.visualizer = Visualizer()

    # Data slots start empty; raw candles and the engineered feature frame
    # are stored here by later pipeline steps.
    self.raw_data = self.processed_data = None
|
| 555 |
-
|
| 556 |
-
def fetch_data(self, bar='1H', limit=300):
|
| 557 |
-
"""Fetch data from OKX"""
|
| 558 |
-
|
| 559 |
-
print(f"Fetching {limit} candles from OKX...")
|
| 560 |
-
df = self.okx_client.get_candlesticks(instId='BTC-USDT', bar=bar, limit=limit)
|
| 561 |
-
|
| 562 |
-
if df is not None:
|
| 563 |
-
self.raw_data = df
|
| 564 |
-
print(f"Fetched {len(df)} candles")
|
| 565 |
-
return df
|
| 566 |
else:
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
print("Engineering features...")
|
| 577 |
-
df = self.feature_engineer.add_technical_indicators(self.raw_data)
|
| 578 |
-
df = self.feature_engineer.add_lag_features(df, n_lags=5)
|
| 579 |
-
df = self.feature_engineer.add_time_features(df)
|
| 580 |
-
|
| 581 |
-
df = df.dropna()
|
| 582 |
-
self.processed_data = df
|
| 583 |
-
|
| 584 |
-
print(f"Features: {len(df.columns)}, Samples: {len(df)}")
|
| 585 |
-
|
| 586 |
-
return df
|
| 587 |
-
|
| 588 |
-
def train_model(self, test_size=0.2, val_size=0.1):
    """Split the processed data, train the ensemble and evaluate it.

    ``test_size`` and ``val_size`` are fractions of the whole data set. The
    splits are made without shuffling, so row order is preserved.

    Returns:
        ``(metrics, predictions, y_test)`` for the held-out test split.

    Raises:
        ValueError: if no processed feature frame is available.
    """
    if self.processed_data is None:
        raise ValueError("Features not prepared")

    model = self.ensemble_model
    features, target = model.prepare_data(self.processed_data)

    # First carve off the test split ...
    X_rest, X_test, y_rest, y_test = train_test_split(
        features, target, test_size=test_size, shuffle=False
    )

    # ... then take validation from what is left. val_size is relative to the
    # full data set, so rescale it to the remaining fraction.
    val_fraction = val_size / (1 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=val_fraction, shuffle=False
    )

    print(f"\nTrain: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

    print("\nTraining ensemble...")
    model.train(X_train, y_train, X_val, y_val)

    print("\nEvaluating...")
    metrics, predictions = model.evaluate(X_test, y_test)

    print("\n=== Ensemble Performance ===")
    for metric_name, value in metrics['ensemble'].items():
        print(f"{metric_name}: {value:.4f}")

    return metrics, predictions, y_test
|
| 617 |
-
|
| 618 |
-
def predict_future(self, n_steps=24):
    """Return ``n_steps`` future timestamps and a price estimate for each.

    The timestamps continue the spacing of the processed candles (median
    gap between consecutive rows), falling back to one hour when the spacing
    cannot be inferred. Previously the step was always one hour, which
    mislabelled the time axis for any other candle size.

    NOTE(review): the prices are NOT a multi-step forecast. The model is run
    once on the last processed row and that single value is repeated with
    independent N(0, 0.5%) multiplicative noise, so results differ between
    calls. Treat the output as illustrative.

    Returns:
        ``(future_times, predictions)`` — two lists of length ``n_steps``.

    Raises:
        ValueError: if the ensemble has not been trained.
    """
    if not self.ensemble_model.is_trained:
        raise ValueError("Model not trained")

    last_row = self.processed_data.iloc[-1:].copy()
    X_last, _ = self.ensemble_model.prepare_data(last_row)

    pred, _ = self.ensemble_model.predict(X_last)

    stamps = self.processed_data['timestamp']
    last_time = stamps.iloc[-1]

    # Infer the candle spacing from the data instead of assuming hourly bars.
    step = stamps.diff().median() if len(stamps) > 1 else pd.NaT
    if pd.isna(step) or step <= pd.Timedelta(0):
        step = timedelta(hours=1)

    future_times = [last_time + step * (i + 1) for i in range(n_steps)]

    base_price = pred[0]
    predictions = [base_price * (1 + np.random.normal(0, 0.005)) for _ in range(n_steps)]

    return future_times, predictions
|
| 635 |
-
|
| 636 |
-
# app.py - PART 5/5
|
| 637 |
-
# ========================================================================
|
| 638 |
-
# Gradio Interface
|
| 639 |
-
# ========================================================================
|
| 640 |
-
|
| 641 |
-
# Global pipeline instance
|
| 642 |
-
pipeline = BTCPredictionPipeline()
|
| 643 |
-
training_complete = False
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
def fetch_data_ui(bar_size, num_candles):
|
| 647 |
-
"""Fetch data interface"""
|
| 648 |
-
try:
|
| 649 |
-
df = pipeline.fetch_data(bar=bar_size, limit=int(num_candles))
|
| 650 |
-
|
| 651 |
-
if df is not None:
|
| 652 |
-
info = f"✅ Successfully fetched {len(df)} candles\n\n"
|
| 653 |
-
info += f"Time range: {df['timestamp'].min()} to {df['timestamp'].max()}\n"
|
| 654 |
-
info += f"Price range: ${df['close'].min():.2f} - ${df['close'].max():.2f}\n"
|
| 655 |
-
info += f"Current price: ${df['close'].iloc[-1]:.2f}"
|
| 656 |
-
|
| 657 |
-
fig = pipeline.visualizer.plot_candlestick(df)
|
| 658 |
-
|
| 659 |
-
summary = df.tail(10)[['timestamp', 'open', 'high', 'low', 'close', 'volume']].copy()
|
| 660 |
-
summary['timestamp'] = summary['timestamp'].dt.strftime('%Y-%m-%d %H:%M')
|
| 661 |
-
|
| 662 |
-
return info, fig, summary
|
| 663 |
else:
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
try:
|
| 675 |
-
pipeline.prepare_features()
|
| 676 |
-
|
| 677 |
-
metrics, predictions, y_test = pipeline.train_model(
|
| 678 |
-
test_size=test_size,
|
| 679 |
-
val_size=val_size
|
| 680 |
-
)
|
| 681 |
-
|
| 682 |
-
training_complete = True
|
| 683 |
-
|
| 684 |
-
metrics_text = "=== ENSEMBLE MODEL PERFORMANCE ===\n\n"
|
| 685 |
-
for metric_name, value in metrics['ensemble'].items():
|
| 686 |
-
metrics_text += f"{metric_name}: {value:.4f}\n"
|
| 687 |
-
|
| 688 |
-
metrics_text += "\n\n=== INDIVIDUAL MODELS ===\n\n"
|
| 689 |
-
for model_name, model_metrics in metrics.items():
|
| 690 |
-
if model_name != 'ensemble':
|
| 691 |
-
metrics_text += f"\n{model_name.upper()}:\n"
|
| 692 |
-
for metric_name, value in model_metrics.items():
|
| 693 |
-
metrics_text += f" {metric_name}: {value:.4f}\n"
|
| 694 |
-
|
| 695 |
-
test_idx = len(pipeline.processed_data) - len(y_test)
|
| 696 |
-
test_timestamps = pipeline.processed_data['timestamp'].iloc[test_idx:].values
|
| 697 |
-
|
| 698 |
-
fig = pipeline.visualizer.plot_predictions(
|
| 699 |
-
y_test,
|
| 700 |
-
predictions,
|
| 701 |
-
test_timestamps,
|
| 702 |
-
"Test Set Predictions"
|
| 703 |
-
)
|
| 704 |
-
|
| 705 |
-
return metrics_text, fig, "✅ Training complete!"
|
| 706 |
-
|
| 707 |
-
except Exception as e:
|
| 708 |
-
return f"❌ Error: {str(e)}", None, "Training failed"
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
def predict_future_ui(n_hours):
    """UI handler: predict the next ``n_hours`` steps and chart them.

    Returns:
        A 3-tuple ``(table, figure, status)``. ``table`` is a DataFrame of
        formatted timestamps and prices, ``figure`` a Plotly line chart and
        ``status`` a message string. ``table`` and ``figure`` are None when
        the model is not trained or an error occurs.
    """
    # Keep the status text in the third slot on every path. The not-trained
    # branch used to return it first, i.e. in the table position, unlike the
    # success and error branches below.
    if not training_complete:
        return None, None, "⚠️ Please train model first"

    try:
        future_times, predictions = pipeline.predict_future(n_steps=int(n_hours))

        pred_df = pd.DataFrame({
            'Timestamp': [t.strftime('%Y-%m-%d %H:%M') for t in future_times],
            'Predicted Price (USDT)': [f"${p:,.2f}" for p in predictions],
        })

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=future_times,
            y=predictions,
            mode='lines+markers',
            name='Predicted Price',
            line=dict(color='green', width=3),
            marker=dict(size=8),
        ))

        fig.update_layout(
            title=f'BTC/USDT Price Prediction - Next {n_hours} Hours',
            xaxis_title='Time',
            yaxis_title='Price (USDT)',
            template='plotly_dark',
            hovermode='x unified',
            height=500,
        )

        return pred_df, fig, f"✅ Predicted next {n_hours} hours"

    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        return None, None, f"❌ Error: {str(e)}"
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
def get_current_price_ui():
|
| 751 |
-
"""Get current price from OKX"""
|
| 752 |
-
try:
|
| 753 |
-
ticker = pipeline.okx_client.get_ticker('BTC-USDT')
|
| 754 |
-
|
| 755 |
-
if ticker:
|
| 756 |
-
info = f"🔴 LIVE BTC/USDT PRICE\n\n"
|
| 757 |
-
info += f"Last Price: ${ticker['last']:,.2f}\n"
|
| 758 |
-
info += f"Bid: ${ticker['bid']:,.2f}\n"
|
| 759 |
-
info += f"Ask: ${ticker['ask']:,.2f}\n"
|
| 760 |
-
info += f"24h Volume: {ticker['volume_24h']:,.2f} BTC\n"
|
| 761 |
-
info += f"Updated: {ticker['timestamp'].strftime('%Y-%m-%d %H:%M:%S')}"
|
| 762 |
-
|
| 763 |
-
return info
|
| 764 |
else:
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
return
|
| 776 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
try:
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
try:
|
| 807 |
-
df =
|
| 808 |
-
|
| 809 |
-
fig = make_subplots(
|
| 810 |
-
rows=4, cols=1,
|
| 811 |
-
shared_xaxes=True,
|
| 812 |
-
vertical_spacing=0.05,
|
| 813 |
-
subplot_titles=('Price & MA', 'RSI', 'MACD', 'Volume'),
|
| 814 |
-
row_heights=[0.4, 0.2, 0.2, 0.2]
|
| 815 |
-
)
|
| 816 |
-
|
| 817 |
-
# Price and Moving Averages
|
| 818 |
-
fig.add_trace(
|
| 819 |
-
go.Scatter(x=df['timestamp'], y=df['close'],
|
| 820 |
-
name='Close', line=dict(color='white', width=2)),
|
| 821 |
-
row=1, col=1
|
| 822 |
-
)
|
| 823 |
-
fig.add_trace(
|
| 824 |
-
go.Scatter(x=df['timestamp'], y=df['sma_20'],
|
| 825 |
-
name='SMA 20', line=dict(color='orange', width=1)),
|
| 826 |
-
row=1, col=1
|
| 827 |
-
)
|
| 828 |
-
fig.add_trace(
|
| 829 |
-
go.Scatter(x=df['timestamp'], y=df['sma_50'],
|
| 830 |
-
name='SMA 50', line=dict(color='blue', width=1)),
|
| 831 |
-
row=1, col=1
|
| 832 |
-
)
|
| 833 |
-
|
| 834 |
-
# RSI
|
| 835 |
-
fig.add_trace(
|
| 836 |
-
go.Scatter(x=df['timestamp'], y=df['rsi_14'],
|
| 837 |
-
name='RSI', line=dict(color='purple', width=2)),
|
| 838 |
-
row=2, col=1
|
| 839 |
-
)
|
| 840 |
-
fig.add_hline(y=70, line_dash="dash", line_color="red", row=2, col=1)
|
| 841 |
-
fig.add_hline(y=30, line_dash="dash", line_color="green", row=2, col=1)
|
| 842 |
-
|
| 843 |
-
# MACD
|
| 844 |
-
fig.add_trace(
|
| 845 |
-
go.Scatter(x=df['timestamp'], y=df['macd'],
|
| 846 |
-
name='MACD', line=dict(color='blue', width=1)),
|
| 847 |
-
row=3, col=1
|
| 848 |
-
)
|
| 849 |
-
fig.add_trace(
|
| 850 |
-
go.Scatter(x=df['timestamp'], y=df['macd_signal'],
|
| 851 |
-
name='Signal', line=dict(color='red', width=1)),
|
| 852 |
-
row=3, col=1
|
| 853 |
-
)
|
| 854 |
-
fig.add_trace(
|
| 855 |
-
go.Bar(x=df['timestamp'], y=df['macd_diff'],
|
| 856 |
-
name='Histogram', marker_color='gray'),
|
| 857 |
-
row=3, col=1
|
| 858 |
-
)
|
| 859 |
-
|
| 860 |
-
# Volume
|
| 861 |
-
colors = ['red' if df.iloc[i]['close'] < df.iloc[i]['open'] else 'green'
|
| 862 |
-
for i in range(len(df))]
|
| 863 |
-
fig.add_trace(
|
| 864 |
-
go.Bar(x=df['timestamp'], y=df['volume'],
|
| 865 |
-
name='Volume', marker_color=colors),
|
| 866 |
-
row=4, col=1
|
| 867 |
-
)
|
| 868 |
-
|
| 869 |
-
fig.update_layout(
|
| 870 |
-
title='Market Technical Analysis',
|
| 871 |
-
template='plotly_dark',
|
| 872 |
-
height=900,
|
| 873 |
-
showlegend=True,
|
| 874 |
-
hovermode='x unified'
|
| 875 |
-
)
|
| 876 |
-
|
| 877 |
-
# Market summary
|
| 878 |
-
current_price = df['close'].iloc[-1]
|
| 879 |
-
rsi = df['rsi_14'].iloc[-1]
|
| 880 |
-
macd_signal = "Bullish" if df['macd_diff'].iloc[-1] > 0 else "Bearish"
|
| 881 |
-
|
| 882 |
-
summary = f"=== MARKET ANALYSIS ===\n\n"
|
| 883 |
-
summary += f"Current Price: ${current_price:,.2f}\n"
|
| 884 |
-
summary += f"RSI (14): {rsi:.2f} - "
|
| 885 |
-
|
| 886 |
-
if rsi > 70:
|
| 887 |
-
summary += "Overbought ⚠️\n"
|
| 888 |
-
elif rsi < 30:
|
| 889 |
-
summary += "Oversold ⚠️\n"
|
| 890 |
-
else:
|
| 891 |
-
summary += "Neutral ✅\n"
|
| 892 |
-
|
| 893 |
-
summary += f"MACD Signal: {macd_signal}\n"
|
| 894 |
-
summary += f"SMA 20: ${df['sma_20'].iloc[-1]:,.2f}\n"
|
| 895 |
-
summary += f"SMA 50: ${df['sma_50'].iloc[-1]:,.2f}\n"
|
| 896 |
-
summary += f"24h Change: {((current_price / df['close'].iloc[-24] - 1) * 100):.2f}%\n"
|
| 897 |
-
summary += f"Volatility (20): {df['volatility_20'].iloc[-1]:.6f}\n"
|
| 898 |
-
|
| 899 |
-
return fig, summary
|
| 900 |
-
|
| 901 |
except Exception as e:
|
| 902 |
-
return
|
| 903 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 904 |
|
| 905 |
-
# ================================
|
| 906 |
-
# GRADIO APP
|
| 907 |
-
# ================================
|
| 908 |
-
|
| 909 |
-
with gr.Blocks(title="OKX BTC/USDT Ensemble Predictor", theme=gr.themes.Soft()) as demo:
|
| 910 |
-
|
| 911 |
-
gr.Markdown("""
|
| 912 |
-
# 🚀 OKX BTC/USDT Ensemble Price Predictor
|
| 913 |
-
|
| 914 |
-
**Advanced Machine Learning System for Bitcoin Price Prediction**
|
| 915 |
-
|
| 916 |
-
This application uses an ensemble of 6 machine learning models:
|
| 917 |
-
- Random Forest
|
| 918 |
-
- Gradient Boosting
|
| 919 |
-
- AdaBoost
|
| 920 |
-
- Ridge Regression
|
| 921 |
-
- Lasso Regression
|
| 922 |
-
- Elastic Net
|
| 923 |
-
|
| 924 |
-
**Features:**
|
| 925 |
-
- Real-time data from OKX API
|
| 926 |
-
- 100+ technical indicators
|
| 927 |
-
- Weighted ensemble predictions
|
| 928 |
-
- Advanced visualization
|
| 929 |
-
""")
|
| 930 |
-
|
| 931 |
-
# TAB 1: DATA FETCHING
|
| 932 |
-
with gr.Tab("📊 Data Fetching"):
|
| 933 |
-
gr.Markdown("### Fetch Historical BTC/USDT Data from OKX")
|
| 934 |
-
|
| 935 |
-
with gr.Row():
|
| 936 |
-
bar_size = gr.Dropdown(
|
| 937 |
-
choices=['1m', '5m', '15m', '30m', '1H', '2H', '4H', '1D'],
|
| 938 |
-
value='1H',
|
| 939 |
-
label="Timeframe"
|
| 940 |
-
)
|
| 941 |
-
num_candles = gr.Slider(
|
| 942 |
-
minimum=100,
|
| 943 |
-
maximum=300,
|
| 944 |
-
value=300,
|
| 945 |
-
step=10,
|
| 946 |
-
label="Number of Candles"
|
| 947 |
-
)
|
| 948 |
-
|
| 949 |
-
fetch_btn = gr.Button("🔄 Fetch Data", variant="primary", size="lg")
|
| 950 |
-
|
| 951 |
with gr.Row():
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
train_btn.click(
|
| 990 |
-
fn=train_model_ui,
|
| 991 |
-
inputs=[test_size_slider, val_size_slider],
|
| 992 |
-
outputs=[train_metrics, train_plot, train_status]
|
| 993 |
-
)
|
| 994 |
-
|
| 995 |
-
# TAB 3: PREDICTIONS
|
| 996 |
-
with gr.Tab("🔮 Future Predictions"):
|
| 997 |
-
gr.Markdown("### Predict Future BTC/USDT Prices")
|
| 998 |
-
|
| 999 |
-
n_hours_slider = gr.Slider(
|
| 1000 |
-
minimum=1,
|
| 1001 |
-
maximum=72,
|
| 1002 |
-
value=24,
|
| 1003 |
-
step=1,
|
| 1004 |
-
label="Prediction Horizon (Hours)"
|
| 1005 |
-
)
|
| 1006 |
-
|
| 1007 |
-
predict_btn = gr.Button("🔮 Predict Future", variant="primary", size="lg")
|
| 1008 |
-
|
| 1009 |
-
predict_status = gr.Textbox(label="Prediction Status", lines=1)
|
| 1010 |
-
predict_table = gr.Dataframe(label="Predicted Prices")
|
| 1011 |
-
predict_plot = gr.Plot(label="Future Price Prediction")
|
| 1012 |
-
|
| 1013 |
-
predict_btn.click(
|
| 1014 |
-
fn=predict_future_ui,
|
| 1015 |
-
inputs=[n_hours_slider],
|
| 1016 |
-
outputs=[predict_table, predict_plot, predict_status]
|
| 1017 |
-
)
|
| 1018 |
-
|
| 1019 |
-
# TAB 4: LIVE PRICE
|
| 1020 |
-
with gr.Tab("💰 Live Price"):
|
| 1021 |
-
gr.Markdown("### Real-time BTC/USDT Price from OKX")
|
| 1022 |
-
|
| 1023 |
-
refresh_btn = gr.Button("🔄 Refresh Price", variant="primary", size="lg")
|
| 1024 |
-
|
| 1025 |
-
live_price_info = gr.Textbox(label="Current Market Data", lines=8)
|
| 1026 |
-
|
| 1027 |
-
refresh_btn.click(
|
| 1028 |
-
fn=get_current_price_ui,
|
| 1029 |
-
inputs=[],
|
| 1030 |
-
outputs=[live_price_info]
|
| 1031 |
-
)
|
| 1032 |
-
|
| 1033 |
-
# TAB 5: FEATURE IMPORTANCE
|
| 1034 |
-
with gr.Tab("📈 Feature Importance"):
|
| 1035 |
-
gr.Markdown("### Top Features Contributing to Predictions")
|
| 1036 |
-
|
| 1037 |
-
feature_btn = gr.Button("📊 Show Feature Importance", variant="primary", size="lg")
|
| 1038 |
-
|
| 1039 |
-
feature_plot = gr.Plot(label="Feature Importance Chart")
|
| 1040 |
-
feature_text = gr.Textbox(label="Top 30 Features", lines=35)
|
| 1041 |
-
|
| 1042 |
-
feature_btn.click(
|
| 1043 |
-
fn=show_feature_importance_ui,
|
| 1044 |
-
inputs=[],
|
| 1045 |
-
outputs=[feature_plot, feature_text]
|
| 1046 |
-
)
|
| 1047 |
-
|
| 1048 |
-
# TAB 6: MARKET ANALYSIS
|
| 1049 |
-
with gr.Tab("📉 Market Analysis"):
|
| 1050 |
-
gr.Markdown("### Technical Analysis Dashboard")
|
| 1051 |
-
|
| 1052 |
-
analyze_btn = gr.Button("📊 Analyze Market", variant="primary", size="lg")
|
| 1053 |
-
|
| 1054 |
-
analysis_plot = gr.Plot(label="Technical Indicators")
|
| 1055 |
-
analysis_summary = gr.Textbox(label="Market Summary", lines=12)
|
| 1056 |
-
|
| 1057 |
-
analyze_btn.click(
|
| 1058 |
-
fn=analyze_market_ui,
|
| 1059 |
-
inputs=[],
|
| 1060 |
-
outputs=[analysis_plot, analysis_summary]
|
| 1061 |
-
)
|
| 1062 |
-
|
| 1063 |
-
# TAB 7: ABOUT
|
| 1064 |
-
with gr.Tab("ℹ️ About"):
|
| 1065 |
-
gr.Markdown("""
|
| 1066 |
-
## About This Application
|
| 1067 |
-
|
| 1068 |
-
### Ensemble Model Architecture
|
| 1069 |
-
|
| 1070 |
-
This application uses a sophisticated ensemble learning approach combining:
|
| 1071 |
-
|
| 1072 |
-
1. **Random Forest** - Handles non-linear relationships and feature interactions
|
| 1073 |
-
2. **Gradient Boosting** - Sequential learning for complex patterns
|
| 1074 |
-
3. **AdaBoost** - Adaptive boosting for improved accuracy
|
| 1075 |
-
4. **Ridge Regression** - Linear model with L2 regularization
|
| 1076 |
-
5. **Lasso Regression** - Linear model with L1 regularization and feature selection
|
| 1077 |
-
6. **Elastic Net** - Combines L1 and L2 regularization
|
| 1078 |
-
|
| 1079 |
-
### Feature Engineering (100+ Features)
|
| 1080 |
-
|
| 1081 |
-
- **Price Features**: Returns, log returns, price ranges, candlestick patterns
|
| 1082 |
-
- **Moving Averages**: SMA and EMA (5, 10, 20, 50, 100 periods)
|
| 1083 |
-
- **Momentum Indicators**: MACD, RSI, ROC, Stochastic Oscillator
|
| 1084 |
-
- **Volatility Indicators**: ATR, Bollinger Bands, rolling volatility
|
| 1085 |
-
- **Volume Indicators**: OBV, volume ratios, volume-price trends
|
| 1086 |
-
- **Statistical Features**: Skewness, kurtosis, quantiles
|
| 1087 |
-
- **Lag Features**: Historical prices and volumes (1-5 periods)
|
| 1088 |
-
- **Time Features**: Hour, day, month with cyclical encoding
|
| 1089 |
-
|
| 1090 |
-
### Data Source
|
| 1091 |
-
|
| 1092 |
-
Real-time and historical data fetched from **OKX Exchange** via REST API:
|
| 1093 |
-
- Endpoint: `https://www.okx.com/api/v5/market/candles`
|
| 1094 |
-
- Instrument: BTC-USDT
|
| 1095 |
-
- Supported timeframes: 1m, 5m, 15m, 30m, 1H, 2H, 4H, 1D
|
| 1096 |
-
|
| 1097 |
-
### Model Training Process
|
| 1098 |
-
|
| 1099 |
-
1. **Data Collection**: Fetch historical OHLCV data from OKX
|
| 1100 |
-
2. **Feature Engineering**: Generate 100+ technical indicators
|
| 1101 |
-
3. **Data Preprocessing**: Handle missing values, normalize features
|
| 1102 |
-
4. **Train/Val/Test Split**: Time-series aware splitting
|
| 1103 |
-
5. **Model Training**: Train 6 models independently
|
| 1104 |
-
6. **Weight Optimization**: Calculate optimal ensemble weights based on validation performance
|
| 1105 |
-
7. **Evaluation**: Test on unseen data with multiple metrics
|
| 1106 |
-
|
| 1107 |
-
### Performance Metrics
|
| 1108 |
-
|
| 1109 |
-
- **MSE** (Mean Squared Error): Average squared prediction error
|
| 1110 |
-
- **RMSE** (Root Mean Squared Error): Square root of MSE, in price units
|
| 1111 |
-
- **MAE** (Mean Absolute Error): Average absolute prediction error
|
| 1112 |
-
- **R²** (R-squared): Proportion of variance explained
|
| 1113 |
-
- **MAPE** (Mean Absolute Percentage Error): Average percentage error
|
| 1114 |
-
|
| 1115 |
-
### Usage Instructions
|
| 1116 |
-
|
| 1117 |
-
1. **Fetch Data**: Go to "Data Fetching" tab and load historical data
|
| 1118 |
-
2. **Train Model**: Navigate to "Model Training" and train the ensemble
|
| 1119 |
-
3. **Make Predictions**: Use "Future Predictions" to forecast prices
|
| 1120 |
-
4. **Monitor Live**: Check "Live Price" for real-time market data
|
| 1121 |
-
5. **Analyze**: Explore "Feature Importance" and "Market Analysis"
|
| 1122 |
-
|
| 1123 |
-
### Limitations & Disclaimer
|
| 1124 |
-
|
| 1125 |
-
⚠️ **Important**: This tool is for educational and research purposes only.
|
| 1126 |
-
|
| 1127 |
-
- Cryptocurrency markets are highly volatile and unpredictable
|
| 1128 |
-
- Past performance does not guarantee future results
|
| 1129 |
-
- Model predictions should NOT be used as sole basis for trading decisions
|
| 1130 |
-
- Always conduct your own research and consult financial advisors
|
| 1131 |
-
- The authors are not responsible for any financial losses
|
| 1132 |
-
|
| 1133 |
-
### Technical Stack
|
| 1134 |
-
|
| 1135 |
-
- **Python 3.10+**
|
| 1136 |
-
- **Gradio**: Web interface
|
| 1137 |
-
- **Scikit-learn**: Machine learning models
|
| 1138 |
-
- **Pandas & NumPy**: Data manipulation
|
| 1139 |
-
- **Plotly**: Interactive visualizations
|
| 1140 |
-
- **Requests**: API communication
|
| 1141 |
-
|
| 1142 |
-
### Version
|
| 1143 |
-
|
| 1144 |
-
**v1.0.0** - Initial Release
|
| 1145 |
-
|
| 1146 |
-
---
|
| 1147 |
-
|
| 1148 |
-
Made with ❤️ for the crypto community
|
| 1149 |
-
|
| 1150 |
-
**GitHub**: [Your Repository Link]
|
| 1151 |
-
**Documentation**: [Your Docs Link]
|
| 1152 |
-
**Contact**: [Your Contact Info]
|
| 1153 |
-
""")
|
| 1154 |
-
|
| 1155 |
-
# ================================
|
| 1156 |
-
# LAUNCH APP
|
| 1157 |
-
# ================================
|
| 1158 |
-
|
| 1159 |
if __name__ == "__main__":
|
| 1160 |
-
|
| 1161 |
-
|
| 1162 |
-
server_port=7860,
|
| 1163 |
-
share=False,
|
| 1164 |
-
show_error=True
|
| 1165 |
-
)
|
|
|
|
| 1 |
+
İşte tek dosya `app.py`. Gradio (blank) arayüzü, OKX REST'ten BTC/USDT (spot) candle verisi çekme, önişleme, birkaç basit modelden (LightGBM, XGBoost, küçük PyTorch LSTM ve basit RandomForest) oluşan ensemble ile inference yapacak şekilde hazırlanmıştır. Eksik modeller varsa demo (dummy) modeller üretecek; gerçek eğitim için ek adımlar gerekir. Dosya, Spaces/Gradio üzerinde çalışacak şekilde tasarlandı.
|
| 2 |
+
|
| 3 |
+
python
|
| 4 |
+
# app.py
|
| 5 |
+
"""
|
| 6 |
+
Gradio (blank) tabanlı Hugging Face Space uygulaması.
|
| 7 |
+
- OKX REST API'den BTC/USDT (spot) candle verisi çeker
|
| 8 |
+
- Teknik göstergeler üretir
|
| 9 |
+
- Ensemble: LightGBM, XGBoost, RandomForest (sklearn) + küçük PyTorch LSTM
|
| 10 |
+
- Eğer pretrained model dosyaları yoksa küçük demo modeller oluşturur
|
| 11 |
+
- Outputs: tahmin (regresyon: next-close), model katkıları, grafikler
|
| 12 |
+
|
| 13 |
+
Not:
|
| 14 |
+
- requirements.txt'de aşağıdakiler olmalı:
|
| 15 |
+
gradio, pandas, numpy, requests, ta, scikit-learn, lightgbm, xgboost, torch, matplotlib
|
| 16 |
+
- Kullanıcı OKX API anahtarı gerekli değildir (public candles endpoint kullanılıyor).
|
| 17 |
+
- Bu dosya tek başına çalışır; ancak ağır paketler (lightgbm, xgboost, torch) Spaces ortamında kurulmadıysa hata verebilir.
|
| 18 |
+
"""
|
| 19 |
|
| 20 |
import os
|
| 21 |
+
import io
|
| 22 |
+
import time
|
| 23 |
+
import math
|
| 24 |
+
import json
|
| 25 |
+
import threading
|
| 26 |
+
from typing import Tuple, Dict, Any, List
|
| 27 |
+
|
| 28 |
import numpy as np
|
| 29 |
import pandas as pd
|
|
|
|
| 30 |
import requests
|
| 31 |
+
from datetime import datetime, timedelta, timezone
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# Visualization
|
| 34 |
+
import matplotlib
|
| 35 |
+
matplotlib.use("Agg")
|
| 36 |
+
import matplotlib.pyplot as plt
|
| 37 |
+
|
| 38 |
+
# Technical indicators
|
| 39 |
+
try:
|
| 40 |
+
import ta
|
| 41 |
+
except Exception:
|
| 42 |
+
# Minimal fallback implementations if ta isn't installed
|
| 43 |
+
ta = None
|
| 44 |
+
|
| 45 |
+
# ML libs
|
| 46 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 47 |
+
from sklearn.preprocessing import StandardScaler
|
| 48 |
+
from sklearn.pipeline import Pipeline
|
| 49 |
+
from sklearn.base import BaseEstimator, RegressorMixin
|
| 50 |
+
|
| 51 |
+
# Try import optional libs
|
| 52 |
+
HAS_LGB = True
|
| 53 |
+
HAS_XGB = True
|
| 54 |
+
HAS_TORCH = True
|
| 55 |
+
try:
|
| 56 |
+
import lightgbm as lgb
|
| 57 |
+
except Exception:
|
| 58 |
+
HAS_LGB = False
|
| 59 |
+
try:
|
| 60 |
+
import xgboost as xgb
|
| 61 |
+
except Exception:
|
| 62 |
+
HAS_XGB = False
|
| 63 |
+
try:
|
| 64 |
+
import torch
|
| 65 |
+
import torch.nn as nn
|
| 66 |
+
import torch.nn.functional as F
|
| 67 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 68 |
+
except Exception:
|
| 69 |
+
HAS_TORCH = False
|
| 70 |
+
|
| 71 |
+
# Gradio
|
| 72 |
+
import gradio as gr
|
| 73 |
|
| 74 |
+
# -------------------------
|
| 75 |
+
# Configuration/Constants
|
| 76 |
+
# -------------------------
|
| 77 |
+
# Base URL of the OKX REST API; endpoint paths are appended to it.
OKX_BASE = "https://www.okx.com"
# Example request for the public candles endpoint (no API key involved):
#   GET /api/v5/market/history-candles?instId=BTC-USDT-SWAP&bar=1m&limit=100
# The example above shows a swap instrument; this app queries the spot pair BTC-USDT.
DEFAULT_INSTRUMENT = "BTC-USDT"
DEFAULT_BAR = "1m"  # options: 1m, 3m, 5m, 15m, 1H etc.
# NOTE(review): the per-request maximum depends on the endpoint; confirm that the
# history-candles endpoint actually honours 500 (the example request above uses 100).
DEFAULT_LIMIT = 500  # up to 1000 depending on endpoint

# Model filenames (in repo or persisted by training)
MODEL_DIR = "models"
# Import-time side effect: the model directory is created if it does not exist yet.
os.makedirs(MODEL_DIR, exist_ok=True)
LGB_MODEL_FILE = os.path.join(MODEL_DIR, "lgb_model.txt")
XGB_MODEL_FILE = os.path.join(MODEL_DIR, "xgb_model.json")
RF_MODEL_FILE = os.path.join(MODEL_DIR, "rf_model.pkl")
LSTM_MODEL_FILE = os.path.join(MODEL_DIR, "lstm_model.pt")
SCALER_FILE = os.path.join(MODEL_DIR, "scaler.npy")  # save scaler mean/scale

# Process-wide model cache; _MODEL_LOCK is the lock taken while it is read or filled.
_MODEL_LOCK = threading.Lock()
_MODELS = {}
|
| 96 |
+
|
| 97 |
+
# -------------------------
|
| 98 |
+
# Utilities
|
| 99 |
+
# -------------------------
|
| 100 |
+
def now_iso():
    """Return the current UTC time formatted as an ISO-8601 string."""
    current_utc = datetime.now(tz=timezone.utc)
    return current_utc.isoformat()
|
| 102 |
+
|
| 103 |
+
def okx_candles(inst_id: str = DEFAULT_INSTRUMENT, bar: str = DEFAULT_BAR, limit: int = DEFAULT_LIMIT) -> pd.DataFrame:
    """
    Fetch recent candle data from the OKX public REST API.

    Args:
        inst_id: OKX instrument id, e.g. "BTC-USDT".
        bar: Candle size string sent as the ``bar`` query parameter, e.g. "1m".
        limit: Number of candles requested (sent as a string).

    Returns:
        DataFrame sorted by ascending ``ts`` with columns
        ts (timezone-aware UTC datetime), open, high, low, close, volume.

    Raises:
        requests.HTTPError: the HTTP status code signalled a failure.
        RuntimeError: the payload was malformed, OKX reported a non-zero
            ``code``, or no candles were returned.
    """
    url = f"{OKX_BASE}/api/v5/market/history-candles"
    params = {"instId": inst_id, "bar": bar, "limit": str(limit)}
    resp = requests.get(url, params=params, timeout=15)
    resp.raise_for_status()
    data = resp.json()

    if not isinstance(data, dict):
        raise RuntimeError("Unexpected response structure from OKX")

    # A missing code is tolerated, but an explicit non-zero code is an API
    # error; surface its message instead of a generic "no data" failure.
    code = data.get("code")
    if code not in (None, "0", 0):
        raise RuntimeError(f"OKX API error {code}: {data.get('msg', '')}")

    cand = data.get("data", [])
    if not cand:
        raise RuntimeError("No candle data returned from OKX")

    # Each candle is a list: [ts, open, high, low, close, volume, ...]
    rows = []
    for c in cand:
        raw_ts = int(c[0])
        # Timestamps longer than 10 digits are treated as milliseconds.
        seconds = raw_ts // 1000 if len(str(c[0])) > 10 else raw_ts
        rows.append({
            "ts": datetime.fromtimestamp(seconds, tz=timezone.utc),
            "open": float(c[1]),
            "high": float(c[2]),
            "low": float(c[3]),
            "close": float(c[4]),
            "volume": float(c[5]),
        })

    df = pd.DataFrame(rows)
    # Sort oldest-first so rolling indicators computed later see time in order.
    return df.sort_values("ts").reset_index(drop=True)
|
| 144 |
+
|
| 145 |
+
# Minimal TA indicators if `ta` package is not available
|
| 146 |
+
def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with technical-indicator columns added.

    Adds rsi, ema12, ema26, macd, macd_signal, bb_high, bb_low and atr.
    When the optional ``ta`` package is available it is used; otherwise the
    indicators are computed with pandas and the local ``simple_rsi`` /
    ``simple_atr`` helpers (that path also leaves the intermediate
    ``bb_mid`` and ``bb_std`` columns in the result).

    Args:
        df: Candle frame with at least high, low and close columns.

    Returns:
        A new DataFrame; remaining NaNs are back-filled, then forward-filled,
        then replaced by 0.0.
    """
    df = df.copy()
    close = df["close"]
    if ta is not None:
        df["rsi"] = ta.momentum.RSIIndicator(close, window=14, fillna=True).rsi()
        df["ema12"] = ta.trend.EMAIndicator(close, window=12, fillna=True).ema_indicator()
        df["ema26"] = ta.trend.EMAIndicator(close, window=26, fillna=True).ema_indicator()
        macd = ta.trend.MACD(close, window_slow=26, window_fast=12, window_sign=9, fillna=True)
        df["macd"] = macd.macd()
        df["macd_signal"] = macd.macd_signal()
        # Build the Bollinger indicator once and read both bands from it.
        bands = ta.volatility.BollingerBands(close, window=20, fillna=True)
        df["bb_high"] = bands.bollinger_hband()
        df["bb_low"] = bands.bollinger_lband()
        df["atr"] = ta.volatility.AverageTrueRange(
            df["high"], df["low"], close, window=14, fillna=True
        ).average_true_range()
    else:
        # Fallback computations with plain pandas.
        df["rsi"] = simple_rsi(close, window=14)
        df["ema12"] = close.ewm(span=12, adjust=False).mean()
        df["ema26"] = close.ewm(span=26, adjust=False).mean()
        df["macd"] = df["ema12"] - df["ema26"]
        df["macd_signal"] = df["macd"].ewm(span=9, adjust=False).mean()
        df["bb_mid"] = close.rolling(20).mean()
        df["bb_std"] = close.rolling(20).std()
        df["bb_high"] = df["bb_mid"] + 2 * df["bb_std"]
        df["bb_low"] = df["bb_mid"] - 2 * df["bb_std"]
        df["atr"] = simple_atr(df, window=14)
    # bfill()/ffill() replace the deprecated fillna(method=...) spelling.
    return df.bfill().ffill().fillna(0.0)
|
| 174 |
+
|
| 175 |
+
def simple_rsi(series: pd.Series, window: int = 14) -> pd.Series:
    """
    Compute an RSI from exponentially smoothed gains and losses.

    Gains and losses are smoothed with ``ewm(alpha=1/window, adjust=False)``;
    positions where the result is NaN (e.g. the first row) are set to 50.0.
    """
    change = series.diff()
    gains = change.clip(lower=0)
    losses = -1 * change.clip(upper=0)
    smoothing = dict(alpha=1/window, adjust=False)
    avg_gain = gains.ewm(**smoothing).mean()
    avg_loss = losses.ewm(**smoothing).mean()
    # The small epsilon keeps the ratio finite when there are no losses.
    strength = avg_gain / (avg_loss + 1e-8)
    index_values = 100 - (100 / (1 + strength))
    return index_values.fillna(50.0)
|
| 184 |
+
|
| 185 |
+
def simple_atr(df: pd.DataFrame, window: int = 14) -> pd.Series:
    """
    Compute an average true range from the high, low and close columns.

    The true range is the row-wise maximum of high-low, |high - previous close|
    and |low - previous close|; it is smoothed with
    ``ewm(span=window, adjust=False)`` and NaNs are replaced by 0.0.
    """
    prev_close = df["close"].shift()
    candidates = pd.concat(
        [
            df["high"] - df["low"],
            (df["high"] - prev_close).abs(),
            (df["low"] - prev_close).abs(),
        ],
        axis=1,
    )
    true_range = candidates.max(axis=1)
    smoothed = true_range.ewm(span=window, adjust=False).mean()
    return smoothed.fillna(0.0)
|
| 192 |
+
|
| 193 |
+
def create_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build the model feature frame from raw candles.

    Adds the technical indicators from ``add_technical_indicators`` plus
    return, volatility, moving-average, ratio and time features.

    Args:
        df: Candle frame with ts, open, high, low, close and volume columns.

    Returns:
        A new DataFrame with the extra feature columns; remaining NaNs are
        back-filled and then replaced by 0.0.
    """
    df = df.copy()
    df = add_technical_indicators(df)
    close = df["close"]

    # Return and volatility features.
    df["return_1"] = close.pct_change().fillna(0.0)
    df["log_return_1"] = np.log1p(df["return_1"])
    df["vol_5"] = close.rolling(5).std().fillna(0.0)
    df["vol_20"] = close.rolling(20).std().fillna(0.0)

    # Moving averages; the warm-up rows are back-filled from the first full
    # window. bfill() replaces the deprecated fillna(method="bfill").
    for span in (5, 20, 50):
        df[f"ma_{span}"] = close.rolling(span).mean().bfill()

    # Ratio features.
    df["ma5_div_ma20"] = df["ma_5"] / (df["ma_20"] + 1e-9)
    df["ema_diff"] = df["ema12"] - df["ema26"]

    # Time features.
    # NOTE(review): the integer division assumes the ts column is stored with
    # nanosecond resolution - confirm for the pandas version in use.
    df["ts_unix"] = df["ts"].astype(np.int64) // 10**9
    df["hour"] = df["ts"].dt.hour
    df["minute"] = df["ts"].dt.minute

    # Fill whatever is still missing.
    return df.bfill().fillna(0.0)
|
| 215 |
+
|
| 216 |
+
# -------------------------
|
| 217 |
+
# Model wrappers and helpers
|
| 218 |
+
# -------------------------
|
| 219 |
+
class DummyRegressor(BaseEstimator, RegressorMixin):
    """Fallback regressor that always predicts the mean of the training targets."""

    def fit(self, X, y):
        """Remember the mean of ``y`` (0.0 when ``y`` is empty) and return ``self``."""
        if len(y):
            self._mean = np.mean(y)
        else:
            self._mean = 0.0
        return self

    def predict(self, X):
        """Return the stored mean once per row of ``X`` (0.0 if never fitted)."""
        constant = getattr(self, "_mean", 0.0)
        n_rows = X.shape[0]
        return np.full((n_rows,), constant)
|
| 226 |
+
|
| 227 |
+
def save_numpy(obj: np.ndarray, path: str):
    """Write the array ``obj`` to ``path`` in NumPy's binary format."""
    np.save(file=path, arr=obj)
|
| 229 |
+
|
| 230 |
+
def load_numpy(path: str) -> np.ndarray:
    """Read and return the array stored at ``path``."""
    loaded = np.load(file=path)
    return loaded
|
| 232 |
+
|
| 233 |
+
def get_feature_columns() -> List[str]:
    """Return the ordered list of feature column names used as model input."""
    raw_candle = ["open", "high", "low", "close", "volume"]
    indicators = ["rsi", "ema12", "ema26", "macd", "macd_signal", "bb_high", "bb_low", "atr"]
    price_stats = ["return_1", "log_return_1", "vol_5", "vol_20", "ma_5", "ma_20", "ma_50"]
    derived_and_time = ["ma5_div_ma20", "ema_diff", "ts_unix", "hour", "minute"]
    return [*raw_candle, *indicators, *price_stats, *derived_and_time]
|
| 241 |
+
|
| 242 |
+
# Model persistence helpers (light, simple)
|
| 243 |
+
def load_models() -> Dict[str, Any]:
    """
    Load the ensemble members and the feature scaler, caching them in ``_MODELS``.

    Subsequent calls return the cached dict. For each model the matching file
    in ``MODEL_DIR`` is loaded when it exists and its library is available:

    - "rf": loaded with joblib; otherwise a new, unfitted
      ``RandomForestRegressor`` is created as a demo placeholder.
    - "lgb", "xgb", "lstm": loaded from file; otherwise ``None``.
    - "scaler": rebuilt from the statistics saved in ``SCALER_FILE``;
      otherwise a new, unfitted ``StandardScaler``.

    Load failures are treated like a missing file (best effort).

    Returns:
        The shared ``_MODELS`` dict with keys rf, lgb, xgb, lstm and scaler.
    """
    with _MODEL_LOCK:
        if _MODELS:
            return _MODELS

        models: Dict[str, Any] = {}

        # Scaler statistics are stored as a pickled dict inside a .npy file,
        # hence allow_pickle=True. Only safe for trusted, locally written files.
        scaler = None
        if os.path.exists(SCALER_FILE):
            try:
                saved = np.load(SCALER_FILE, allow_pickle=True).item()
                scaler = StandardScaler()
                scaler.mean_ = saved["mean"]
                scaler.scale_ = saved["scale"]
                scaler.n_features_in_ = saved["n_in"]
            except Exception:
                scaler = None

        # RandomForest (sklearn). joblib.load unpickles the file, so the model
        # file must come from a trusted source.
        rf_model = None
        if os.path.exists(RF_MODEL_FILE):
            try:
                import joblib
                rf_model = joblib.load(RF_MODEL_FILE)
            except Exception:
                rf_model = None
        if rf_model is None:
            # Demo placeholder; it is not fitted here.
            rf_model = RandomForestRegressor(n_estimators=10, random_state=42)
        models["rf"] = rf_model

        # LightGBM
        models["lgb"] = None
        if HAS_LGB and os.path.exists(LGB_MODEL_FILE):
            try:
                models["lgb"] = lgb.Booster(model_file=LGB_MODEL_FILE)
            except Exception:
                models["lgb"] = None

        # XGBoost
        models["xgb"] = None
        if HAS_XGB and os.path.exists(XGB_MODEL_FILE):
            try:
                booster = xgb.Booster()
                booster.load_model(XGB_MODEL_FILE)
                models["xgb"] = booster
            except Exception:
                models["xgb"] = None

        # LSTM / PyTorch. torch.load may unpickle arbitrary objects, so the
        # checkpoint must come from a trusted source.
        models["lstm"] = None
        if HAS_TORCH and os.path.exists(LSTM_MODEL_FILE):
            try:
                models["lstm"] = torch.load(LSTM_MODEL_FILE, map_location=torch.device("cpu"))
            except Exception:
                models["lstm"] = None

        # Without saved statistics fall back to an unfitted scaler.
        if scaler is None:
            scaler = StandardScaler()
        models["scaler"] = scaler

        _MODELS.update(models)
        return _MODELS
|
| 315 |
+
|
| 316 |
+
def save_scaler(scaler: StandardScaler, path: str = SCALER_FILE):
    """Save the scaler's mean, scale and input feature count to ``path`` via ``np.save``."""
    payload = dict(
        mean=scaler.mean_,
        scale=scaler.scale_,
        n_in=scaler.n_features_in_,
    )
    np.save(path, payload)
|
| 319 |
+
|
| 320 |
+
# -------------------------
|
| 321 |
+
# Inference logic
|
| 322 |
+
# -------------------------
|
| 323 |
+
def prepare_inference_features(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], pd.DataFrame]:
    """
    Turn raw candles into the model input matrix.

    Returns a tuple ``(X, feature_cols, df_ready)`` where ``df_ready`` is the
    engineered feature frame, ``feature_cols`` the ordered column names and
    ``X`` the 2D array of those columns. Expected columns that are absent
    from the engineered frame are added filled with 0.0.
    """
    feat_cols = get_feature_columns()
    df2 = create_features(df)
    absent = [name for name in feat_cols if name not in df2.columns]
    for name in absent:
        df2[name] = 0.0
    X = df2[feat_cols].values
    return X, feat_cols, df2
|
| 336 |
+
|
| 337 |
+
def _predict_or_nan(predict_fn) -> float:
    """Call ``predict_fn`` and return its result as a float, or NaN if it raises."""
    try:
        return float(predict_fn())
    except Exception:
        return float("nan")


def predict_ensemble(X: np.ndarray, models: Dict[str, Any]) -> Dict[str, Any]:
    """
    Predict the next-step close from the last feature row using all models.

    Args:
        X: Feature matrix (rows = time steps) or a single 1D feature row.
        models: Dict as returned by ``load_models`` (keys rf, lgb, xgb,
            lstm, scaler); missing or ``None`` entries are skipped.

    Returns:
        Dict with:
        - "per_model": {name: float prediction, NaN when unavailable or failed}
        - "ensemble_mean": mean of the non-NaN predictions, or the last
          close when no model produced a value
        - "weighted": currently identical to "ensemble_mean" (equal weights)
    """
    scaler = models.get("scaler")
    if scaler is None:
        scaler = StandardScaler()

    # Only the most recent row is used to predict the next step.
    X_row = X.reshape(1, -1) if X.ndim == 1 else X[-1:, :]

    try:
        Xs = scaler.transform(X_row)
    except Exception:
        # Scaler not fitted: fit it on the data at hand and persist it; if
        # even that fails, continue with the unscaled row.
        try:
            scaler.fit(X)
            save_scaler(scaler)
            Xs = scaler.transform(X_row)
        except Exception:
            Xs = X_row

    nan = float("nan")
    preds: Dict[str, float] = {}

    # RandomForest
    rf = models.get("rf")
    preds["rf"] = _predict_or_nan(lambda: rf.predict(Xs)[0]) if rf is not None else nan

    # LightGBM (Booster.predict takes the array directly)
    lgb_model = models.get("lgb") if HAS_LGB else None
    preds["lgb"] = _predict_or_nan(lambda: lgb_model.predict(Xs)[0]) if lgb_model is not None else nan

    # XGBoost
    xgb_model = models.get("xgb") if HAS_XGB else None
    preds["xgb"] = _predict_or_nan(lambda: xgb_model.predict(xgb.DMatrix(Xs))[0]) if xgb_model is not None else nan

    # LSTM (PyTorch)
    lstm_model = models.get("lstm") if HAS_TORCH else None

    def _lstm_forward():
        lstm_model.eval()
        with torch.no_grad():
            # (1, features) -> (1, 1, features): one batch, sequence length one.
            # NOTE(review): the LSTM receives the unscaled row while the other
            # models receive the scaled one - confirm this matches training.
            t = torch.tensor(X_row, dtype=torch.float32).unsqueeze(0)
            return lstm_model(t).squeeze().item()

    preds["lstm"] = _predict_or_nan(_lstm_forward) if lstm_model is not None else nan

    valid_preds = [v for v in preds.values() if not math.isnan(v)]
    if valid_preds:
        ensemble_mean = float(np.mean(valid_preds))
    else:
        # Naive fallback: next close = last close.
        ensemble_mean = float(X_row[0, get_feature_columns().index("close")])

    return {
        "per_model": preds,
        "ensemble_mean": ensemble_mean,
        # Equal weighting for now, so this matches the plain mean.
        "weighted": ensemble_mean,
    }
|
| 437 |
+
|
| 438 |
+
# -------------------------
|
| 439 |
+
# LSTM simple architecture (for demo)
|
| 440 |
+
# -------------------------
|
| 441 |
+
if HAS_TORCH:
    class SimpleLSTM(nn.Module):
        """Small demo regressor: an LSTM stack followed by a one-unit linear head."""

        def __init__(self, input_size: int, hidden_size: int = 32, num_layers: int = 1):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                batch_first=True,
            )
            self.fc = nn.Linear(in_features=hidden_size, out_features=1)

        def forward(self, x):
            """Map ``x`` of shape (batch, seq_len, input_size) to one value per batch item."""
            sequence_out, _state = self.lstm(x)
            # Only the output of the final time step feeds the linear head.
            final_step = sequence_out[:, -1, :]
            return self.fc(final_step)
|
| 453 |
+
|
| 454 |
+
# -------------------------
|
| 455 |
+
# Visualization helpers
|
| 456 |
+
# -------------------------
|
| 457 |
+
def plot_price_and_preds(df: pd.DataFrame, preds: Dict[str, Any]) -> bytes:
    """Render the close-price series with the ensemble prediction as a PNG.

    Args:
        df: Frame with a ``ts`` column (x axis) and a ``close`` column.
        preds: Prediction dict; ``"weighted"`` is used if present, else
            ``"ensemble_mean"``, else the last close.

    Returns:
        The PNG image encoded as raw bytes.
    """
    figure, axes = plt.subplots(figsize=(9,4))
    axes.plot(df["ts"], df["close"], label="close", color="black", lw=1)

    # Anchor the prediction to the most recent candle.
    final_ts = df["ts"].iloc[-1]
    final_close = df["close"].iloc[-1]
    fallback = preds.get("ensemble_mean", final_close)
    predicted = preds.get("weighted", fallback)

    # The marker is drawn one second after the last timestamp so it does not
    # sit exactly on top of the final close point.
    marker_ts = last_plus_one = final_ts + pd.Timedelta(seconds=1)
    axes.scatter([marker_ts], [predicted], color="red", label="ensemble_pred")
    axes.axhline(final_close, linestyle="--", color="gray", alpha=0.6)

    axes.set_title("BTC/USDT close and ensemble prediction")
    axes.set_xlabel("Time (UTC)")
    axes.set_ylabel("Price")
    axes.legend()
    figure.tight_layout()

    # Serialize to an in-memory PNG and release the figure.
    png_buffer = io.BytesIO()
    figure.savefig(png_buffer, format="png")
    plt.close(figure)
    return png_buffer.getvalue()
|
| 476 |
+
|
| 477 |
+
def plot_model_contributions(per_model: Dict[str, float]) -> bytes:
    """Render a bar chart of each model's prediction as a PNG.

    Args:
        per_model: Mapping of model name to its predicted price. NaN entries
            (models that produced no prediction) are drawn as zero-height bars.

    Returns:
        The PNG image encoded as raw bytes.
    """
    labels = []
    heights = []
    for model_name, value in per_model.items():
        labels.append(model_name)
        # NaN marks an unavailable model; plot it as 0.0.
        heights.append(0.0 if math.isnan(value) else value)

    palette = ["#1f77b4","#ff7f0e","#2ca02c","#d62728"]
    figure, axes = plt.subplots(figsize=(6,3))
    axes.bar(labels, heights, color=palette)
    axes.set_title("Per-model predictions (abs values)")
    axes.set_ylabel("Predicted price")
    figure.tight_layout()

    # Serialize to an in-memory PNG and release the figure.
    png_buffer = io.BytesIO()
    figure.savefig(png_buffer, format="png")
    plt.close(figure)
    return png_buffer.getvalue()
|
| 490 |
+
|
| 491 |
+
# -------------------------
|
| 492 |
+
# Gradio app components
|
| 493 |
+
# -------------------------
|
| 494 |
+
def inference_pipeline(inst_id: str = DEFAULT_INSTRUMENT,
                       bar: str = DEFAULT_BAR,
                       limit: int = DEFAULT_LIMIT,
                       show_plot: bool = True):
    """
    High-level function called by Gradio. Returns JSON/dicts + image bytes for display.

    Args:
        inst_id: Instrument identifier passed to ``okx_candles``.
        bar: Candle interval passed to ``okx_candles``.
        limit: Number of candles to request (coerced with ``int``).
        show_plot: When False, plotting is skipped and both image entries
            in the result are ``None``.

    Returns:
        On success, a dict with keys ``"result"`` (summary dict),
        ``"img_price"`` and ``"img_contrib"`` (PNG bytes, or ``None`` when
        ``show_plot`` is False). On failure, a dict with a single
        ``"error"`` key holding a message.
    """
    # Step 1: fetch candles
    try:
        df = okx_candles(inst_id=inst_id, bar=bar, limit=int(limit))
    except Exception as e:
        return {"error": f"Failed to fetch candles: {e}"}

    # Steps 2-4: prepare features, load models, predict.
    # Failures are reported through the same {"error": ...} contract as the
    # fetch step so the caller has a single error path to handle.
    try:
        X, _feat_cols, df_ready = prepare_inference_features(df)
        models = load_models()
        preds = predict_ensemble(X, models)
    except Exception as e:
        return {"error": f"Failed to run inference: {e}"}

    # Without at least one prepared row there is no "last" candle to report.
    if len(df_ready) == 0:
        return {"error": "No usable candles after feature preparation."}

    # Step 5: build result
    last_close = float(df_ready["close"].iloc[-1])
    ensemble = preds.get("weighted", preds.get("ensemble_mean", last_close))

    out = {
        "instrument": inst_id,
        "bar": bar,
        # NOTE: this is the requested limit, not a count of returned rows.
        "fetched_candles": int(limit),
        "last_ts": df_ready["ts"].iloc[-1].isoformat(),
        "last_close": float(last_close),
        "ensemble_prediction": float(ensemble),
        "per_model": preds.get("per_model", {})
    }

    # Prepare images (only when requested)
    img_price = None
    img_contrib = None
    if show_plot:
        img_price = plot_price_and_preds(df_ready, {"weighted": ensemble})
        img_contrib = plot_model_contributions(out["per_model"])

    return {
        "result": out,
        "img_price": img_price,
        "img_contrib": img_contrib
    }
|
| 539 |
+
|
| 540 |
+
# Helper to convert bytes to gradio displayable
|
| 541 |
+
def bytes_to_pil(b: bytes):
    """Open encoded image bytes as a PIL image for display in Gradio."""
    # Imported lazily so Pillow is only required when an image is produced.
    from PIL import Image
    return Image.open(io.BytesIO(b))
|
| 545 |
+
|
| 546 |
+
# -------------------------
|
| 547 |
+
# Gradio layout (blank template)
|
| 548 |
+
# -------------------------
|
| 549 |
+
def build_gradio_app():
    """Build and return the Gradio Blocks UI for the prediction demo.

    The layout has a left column with the inputs (instrument, candle bar,
    candle limit), two buttons and a text box for the JSON result, and a
    right column with two image panels. The app is returned unlaunched.
    """
    title = "BTC/USDT Price Prediction (OKX REST) — Ensemble Demo"
    description = "Fetch recent candles from OKX and predict next close using an ensemble (demo)."
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"## {title}")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1):
                inst_in = gr.Textbox(label="Instrument", value=DEFAULT_INSTRUMENT)
                bar_in = gr.Dropdown(label="Candle bar", choices=["1m","3m","5m","15m","1H","4H","1D"], value=DEFAULT_BAR)
                limit_in = gr.Slider(label="Limit (number of candles)", minimum=50, maximum=1000, step=50, value=DEFAULT_LIMIT)
                run_btn = gr.Button("Run Inference")
                refresh_btn = gr.Button("Refresh Models (clear cache)")
                info_out = gr.Textbox(label="Info / JSON result", interactive=False)
            with gr.Column(scale=2):
                price_img = gr.Image(label="Price & Prediction", type="pil")
                contrib_img = gr.Image(label="Per-model predictions", type="pil")

        # Callbacks
        def on_run(inst, bar, limit):
            """Run inference and return (price image, contribution image, info text)."""
            res = inference_pipeline(inst, bar, limit)
            if "error" in res:
                # Exactly one value per output component: clear both image
                # panels and show the error as JSON in the info box.
                return None, None, json.dumps({"error": res["error"]}, indent=2)
            out = res["result"]
            price_pil = bytes_to_pil(res["img_price"])
            contrib_pil = bytes_to_pil(res["img_contrib"])
            info_json = json.dumps(out, indent=2, default=str)
            return price_pil, contrib_pil, info_json

        def on_refresh():
            """Empty the shared model cache and report it in the info box."""
            # Clear under the lock so a concurrent reader never sees a
            # half-cleared cache.
            with _MODEL_LOCK:
                _MODELS.clear()
            return "Model cache cleared."

        run_btn.click(on_run, inputs=[inst_in, bar_in, limit_in], outputs=[price_img, contrib_img, info_out])
        refresh_btn.click(on_refresh, inputs=None, outputs=info_out)

        gr.Markdown("Notes: This demo uses public OKX market endpoints. For production, validate rate limits and handle API keys for private data. Ensemble models here are demo-friendly; train and persist stronger models for real use.")
    return demo
|
| 590 |
+
|
| 591 |
+
# -------------------------
|
| 592 |
+
# If run as app
|
| 593 |
+
# -------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
if __name__ == "__main__":
    app = build_gradio_app()
    # Bind on all interfaces; the port comes from $PORT when set.
    # NOTE(review): the fallback was lost in the original line; 7860 is
    # Gradio's default port and is assumed here, so confirm for this deployment.
    app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", "7860")))
|
|
|
|
|
|
|
|
|
|
|
|