GoshawkVortexAI commited on
Commit
e973ba4
·
verified ·
1 Parent(s): 53c5b4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +576 -1145
app.py CHANGED
@@ -1,1165 +1,596 @@
1
-
2
- # app.py - PARÇA 1/5
3
- # ========================================================================
4
- # İmport ve OKX REST API Client
5
- # ========================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  import os
 
 
 
 
 
 
 
8
  import numpy as np
9
  import pandas as pd
10
- import gradio as gr
11
  import requests
12
- import json
13
- from datetime import datetime, timedelta
14
- import warnings
15
- warnings.filterwarnings('ignore')
16
-
17
- # Machine Learning
18
- from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, AdaBoostRegressor
19
- from sklearn.linear_model import Ridge, Lasso, ElasticNet
20
- from sklearn.svm import SVR
21
- from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
22
- from sklearn.model_selection import train_test_split
23
- from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
24
 
25
  # Visualization
26
- import plotly.graph_objects as go
27
- from plotly.subplots import make_subplots
28
-
29
-
30
- # ================================
31
- # OKX REST API CLIENT
32
- # ================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- class OKXClient:
35
- """OKX REST API Client for BTC/USDT data"""
36
-
37
- def __init__(self):
38
- self.base_url = "https://www.okx.com"
39
- self.session = requests.Session()
40
- self.session.headers.update({
41
- 'Content-Type': 'application/json',
42
- 'User-Agent': 'Mozilla/5.0'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  })
44
-
45
- def get_candlesticks(self, instId='BTC-USDT', bar='1H', limit=300):
46
- """
47
- Get candlestick data from OKX
48
-
49
- Args:
50
- instId: Instrument ID (default: BTC-USDT)
51
- bar: Bar size (1m, 5m, 15m, 1H, 4H, 1D)
52
- limit: Number of candles (max 300)
53
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  try:
55
- endpoint = f"{self.base_url}/api/v5/market/candles"
56
- params = {
57
- 'instId': instId,
58
- 'bar': bar,
59
- 'limit': str(limit)
60
- }
61
-
62
- response = self.session.get(endpoint, params=params, timeout=10)
63
-
64
- if response.status_code == 200:
65
- data = response.json()
66
-
67
- if data['code'] == '0':
68
- candles = data['data']
69
-
70
- df = pd.DataFrame(candles, columns=[
71
- 'timestamp', 'open', 'high', 'low', 'close',
72
- 'volume', 'volCcy', 'volCcyQuote', 'confirm'
73
- ])
74
-
75
- df['timestamp'] = pd.to_datetime(df['timestamp'].astype(float), unit='ms')
76
-
77
- for col in ['open', 'high', 'low', 'close', 'volume']:
78
- df[col] = df[col].astype(float)
79
-
80
- df = df.sort_values('timestamp').reset_index(drop=True)
81
-
82
- return df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
83
- else:
84
- print(f"API Error: {data['msg']}")
85
- return None
86
  else:
87
- print(f"HTTP Error: {response.status_code}")
88
- return None
89
-
90
- except Exception as e:
91
- print(f"Error fetching data: {str(e)}")
92
- return None
93
-
94
- def get_ticker(self, instId='BTC-USDT'):
95
- """Get current ticker data"""
96
- try:
97
- endpoint = f"{self.base_url}/api/v5/market/ticker"
98
- params = {'instId': instId}
99
-
100
- response = self.session.get(endpoint, params=params, timeout=10)
101
-
102
- if response.status_code == 200:
103
- data = response.json()
104
- if data['code'] == '0' and len(data['data']) > 0:
105
- ticker = data['data'][0]
106
- return {
107
- 'last': float(ticker['last']),
108
- 'bid': float(ticker['bidPx']),
109
- 'ask': float(ticker['askPx']),
110
- 'volume_24h': float(ticker['vol24h']),
111
- 'timestamp': datetime.now()
112
- }
113
- return None
114
-
115
- except Exception as e:
116
- print(f"Error fetching ticker: {str(e)}")
117
- return None
118
- # app.py - PARÇA 2/5
119
- # ========================================================================
120
- # Feature Engineering Module
121
- # ========================================================================
122
-
123
- class FeatureEngineer:
124
- """Advanced feature engineering for crypto price prediction"""
125
-
126
- @staticmethod
127
- def add_technical_indicators(df):
128
- """Add comprehensive technical indicators"""
129
- df = df.copy()
130
-
131
- # Basic features
132
- df['returns'] = df['close'].pct_change()
133
- df['log_returns'] = np.log(df['close'] / df['close'].shift(1))
134
- df['price_range'] = df['high'] - df['low']
135
- df['price_change'] = df['close'] - df['open']
136
- df['body'] = abs(df['close'] - df['open'])
137
- df['upper_shadow'] = df['high'] - df[['open', 'close']].max(axis=1)
138
- df['lower_shadow'] = df[['open', 'close']].min(axis=1) - df['low']
139
-
140
- # Moving Averages
141
- for window in [5, 10, 20, 50, 100]:
142
- df[f'sma_{window}'] = df['close'].rolling(window=window).mean()
143
- df[f'ema_{window}'] = df['close'].ewm(span=window, adjust=False).mean()
144
- df[f'price_to_sma_{window}'] = df['close'] / df[f'sma_{window}']
145
-
146
- # MACD
147
- exp1 = df['close'].ewm(span=12, adjust=False).mean()
148
- exp2 = df['close'].ewm(span=26, adjust=False).mean()
149
- df['macd'] = exp1 - exp2
150
- df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
151
- df['macd_diff'] = df['macd'] - df['macd_signal']
152
-
153
- # RSI
154
- for period in [14, 28]:
155
- delta = df['close'].diff()
156
- gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
157
- loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
158
- rs = gain / loss
159
- df[f'rsi_{period}'] = 100 - (100 / (1 + rs))
160
-
161
- # Bollinger Bands
162
- for window in [20, 50]:
163
- rolling_mean = df['close'].rolling(window=window).mean()
164
- rolling_std = df['close'].rolling(window=window).std()
165
- df[f'bb_upper_{window}'] = rolling_mean + (rolling_std * 2)
166
- df[f'bb_lower_{window}'] = rolling_mean - (rolling_std * 2)
167
- df[f'bb_width_{window}'] = df[f'bb_upper_{window}'] - df[f'bb_lower_{window}']
168
- df[f'bb_position_{window}'] = (df['close'] - df[f'bb_lower_{window}']) / df[f'bb_width_{window}']
169
-
170
- # ATR
171
- high_low = df['high'] - df['low']
172
- high_close = np.abs(df['high'] - df['close'].shift())
173
- low_close = np.abs(df['low'] - df['close'].shift())
174
- ranges = pd.concat([high_low, high_close, low_close], axis=1)
175
- true_range = np.max(ranges, axis=1)
176
- df['atr_14'] = true_range.rolling(14).mean()
177
-
178
- # Stochastic Oscillator
179
- low_14 = df['low'].rolling(window=14).min()
180
- high_14 = df['high'].rolling(window=14).max()
181
- df['stoch_k'] = 100 * ((df['close'] - low_14) / (high_14 - low_14))
182
- df['stoch_d'] = df['stoch_k'].rolling(window=3).mean()
183
-
184
- # Volume features
185
- df['volume_sma_20'] = df['volume'].rolling(window=20).mean()
186
- df['volume_ratio'] = df['volume'] / df['volume_sma_20']
187
- df['volume_price_trend'] = df['volume'] * df['returns']
188
-
189
- # OBV
190
- df['obv'] = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
191
-
192
- # Momentum
193
- for period in [5, 10, 20]:
194
- df[f'momentum_{period}'] = df['close'].diff(period)
195
- df[f'roc_{period}'] = df['close'].pct_change(period)
196
-
197
- # Volatility
198
- for window in [5, 10, 20, 30]:
199
- df[f'volatility_{window}'] = df['returns'].rolling(window=window).std()
200
-
201
- # Statistical features
202
- for window in [10, 20]:
203
- df[f'skew_{window}'] = df['returns'].rolling(window=window).skew()
204
- df[f'kurt_{window}'] = df['returns'].rolling(window=window).kurt()
205
-
206
- return df
207
-
208
- @staticmethod
209
- def add_lag_features(df, n_lags=5):
210
- """Add lagged features"""
211
- df = df.copy()
212
-
213
- for lag in range(1, n_lags + 1):
214
- df[f'close_lag_{lag}'] = df['close'].shift(lag)
215
- df[f'volume_lag_{lag}'] = df['volume'].shift(lag)
216
- df[f'returns_lag_{lag}'] = df['returns'].shift(lag)
217
-
218
- return df
219
-
220
- @staticmethod
221
- def add_time_features(df):
222
- """Add time-based features"""
223
- df = df.copy()
224
-
225
- df['hour'] = df['timestamp'].dt.hour
226
- df['day_of_week'] = df['timestamp'].dt.dayofweek
227
- df['day_of_month'] = df['timestamp'].dt.day
228
- df['month'] = df['timestamp'].dt.month
229
-
230
- # Cyclical encoding
231
- df['hour_sin'] = np.sin(2 * np.pi * df['hour'] / 24)
232
- df['hour_cos'] = np.cos(2 * np.pi * df['hour'] / 24)
233
- df['day_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
234
- df['day_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)
235
-
236
- return df
237
-
238
-
239
- # app.py - PARÇA 3/5
240
- # ========================================================================
241
- # Ensemble Model
242
- # ========================================================================
243
-
244
- class EnsemblePredictor:
245
- """Advanced Ensemble Model for BTC/USDT prediction"""
246
-
247
- def __init__(self):
248
- self.models = {}
249
- self.weights = {}
250
- self.scalers = {}
251
- self.feature_columns = None
252
- self.is_trained = False
253
-
254
- def initialize_models(self):
255
- """Initialize all models"""
256
-
257
- self.models['random_forest'] = RandomForestRegressor(
258
- n_estimators=200,
259
- max_depth=15,
260
- min_samples_split=5,
261
- random_state=42,
262
- n_jobs=-1
263
- )
264
-
265
- self.models['gradient_boosting'] = GradientBoostingRegressor(
266
- n_estimators=200,
267
- learning_rate=0.05,
268
- max_depth=5,
269
- random_state=42
270
- )
271
-
272
- self.models['adaboost'] = AdaBoostRegressor(
273
- n_estimators=100,
274
- learning_rate=0.1,
275
- random_state=42
276
- )
277
-
278
- self.models['ridge'] = Ridge(alpha=1.0)
279
- self.models['lasso'] = Lasso(alpha=0.1, max_iter=2000)
280
- self.models['elastic_net'] = ElasticNet(alpha=0.1, l1_ratio=0.5, max_iter=2000)
281
-
282
- for model_name in self.models.keys():
283
- self.weights[model_name] = 1.0 / len(self.models)
284
-
285
- def prepare_data(self, df, target_col='close'):
286
- """Prepare data for training"""
287
-
288
- exclude_cols = ['timestamp', target_col]
289
- feature_cols = [col for col in df.columns if col not in exclude_cols]
290
-
291
- df = df.replace([np.inf, -np.inf], np.nan)
292
- df = df.fillna(method='ffill').fillna(method='bfill').fillna(0)
293
-
294
- X = df[feature_cols].values
295
- y = df[target_col].values
296
-
297
- self.feature_columns = feature_cols
298
-
299
- return X, y
300
-
301
- def train(self, X_train, y_train, X_val, y_val):
302
- """Train ensemble model"""
303
-
304
- self.initialize_models()
305
-
306
- self.scalers['standard'] = StandardScaler()
307
- self.scalers['robust'] = RobustScaler()
308
-
309
- X_train_standard = self.scalers['standard'].fit_transform(X_train)
310
- X_val_standard = self.scalers['standard'].transform(X_val)
311
-
312
- X_train_robust = self.scalers['robust'].fit_transform(X_train)
313
- X_val_robust = self.scalers['robust'].transform(X_val)
314
-
315
- predictions_val = {}
316
-
317
- print("Training Random Forest...")
318
- self.models['random_forest'].fit(X_train_standard, y_train)
319
- predictions_val['random_forest'] = self.models['random_forest'].predict(X_val_standard)
320
-
321
- print("Training Gradient Boosting...")
322
- self.models['gradient_boosting'].fit(X_train_standard, y_train)
323
- predictions_val['gradient_boosting'] = self.models['gradient_boosting'].predict(X_val_standard)
324
-
325
- print("Training AdaBoost...")
326
- self.models['adaboost'].fit(X_train_standard, y_train)
327
- predictions_val['adaboost'] = self.models['adaboost'].predict(X_val_standard)
328
-
329
- print("Training Ridge...")
330
- self.models['ridge'].fit(X_train_robust, y_train)
331
- predictions_val['ridge'] = self.models['ridge'].predict(X_val_robust)
332
-
333
- print("Training Lasso...")
334
- self.models['lasso'].fit(X_train_robust, y_train)
335
- predictions_val['lasso'] = self.models['lasso'].predict(X_val_robust)
336
-
337
- print("Training Elastic Net...")
338
- self.models['elastic_net'].fit(X_train_robust, y_train)
339
- predictions_val['elastic_net'] = self.models['elastic_net'].predict(X_val_robust)
340
-
341
- self.optimize_weights(predictions_val, y_val)
342
- self.is_trained = True
343
-
344
- return predictions_val
345
-
346
- def optimize_weights(self, predictions_val, y_val):
347
- """Optimize ensemble weights"""
348
-
349
- performances = {}
350
- for model_name, preds in predictions_val.items():
351
- mse = mean_squared_error(y_val, preds)
352
- performances[model_name] = 1.0 / (mse + 1e-10)
353
-
354
- total_performance = sum(performances.values())
355
- for model_name in performances:
356
- self.weights[model_name] = performances[model_name] / total_performance
357
-
358
- print("\n=== Optimized Weights ===")
359
- for model_name, weight in self.weights.items():
360
- print(f"{model_name}: {weight:.4f}")
361
-
362
- def predict(self, X):
363
- """Make ensemble predictions"""
364
-
365
- if not self.is_trained:
366
- raise ValueError("Model must be trained first")
367
-
368
- X_standard = self.scalers['standard'].transform(X)
369
- X_robust = self.scalers['robust'].transform(X)
370
-
371
- predictions = {}
372
- predictions['random_forest'] = self.models['random_forest'].predict(X_standard)
373
- predictions['gradient_boosting'] = self.models['gradient_boosting'].predict(X_standard)
374
- predictions['adaboost'] = self.models['adaboost'].predict(X_standard)
375
- predictions['ridge'] = self.models['ridge'].predict(X_robust)
376
- predictions['lasso'] = self.models['lasso'].predict(X_robust)
377
- predictions['elastic_net'] = self.models['elastic_net'].predict(X_robust)
378
-
379
- ensemble_pred = np.zeros(len(X))
380
- for model_name, preds in predictions.items():
381
- ensemble_pred += self.weights[model_name] * preds
382
-
383
- return ensemble_pred, predictions
384
-
385
- def evaluate(self, X_test, y_test):
386
- """Evaluate model"""
387
-
388
- ensemble_pred, individual_preds = self.predict(X_test)
389
-
390
- mse = mean_squared_error(y_test, ensemble_pred)
391
- mae = mean_absolute_error(y_test, ensemble_pred)
392
- rmse = np.sqrt(mse)
393
- r2 = r2_score(y_test, ensemble_pred)
394
- mape = np.mean(np.abs((y_test - ensemble_pred) / y_test)) * 100
395
-
396
- metrics = {
397
- 'ensemble': {
398
- 'MSE': mse,
399
- 'RMSE': rmse,
400
- 'MAE': mae,
401
- 'R2': r2,
402
- 'MAPE': mape
403
- }
404
- }
405
-
406
- for model_name, preds in individual_preds.items():
407
- mse_ind = mean_squared_error(y_test, preds)
408
- rmse_ind = np.sqrt(mse_ind)
409
- mae_ind = mean_absolute_error(y_test, preds)
410
- r2_ind = r2_score(y_test, preds)
411
-
412
- metrics[model_name] = {
413
- 'MSE': mse_ind,
414
- 'RMSE': rmse_ind,
415
- 'MAE': mae_ind,
416
- 'R2': r2_ind
417
- }
418
-
419
- return metrics, ensemble_pred
420
-
421
- # app.py - PARÇA 4/5
422
- # ========================================================================
423
- # Visualization ve Main Pipeline
424
- # ========================================================================
425
-
426
- class Visualizer:
427
- """Visualization utilities"""
428
-
429
- @staticmethod
430
- def plot_predictions(y_true, y_pred, timestamps=None, title="BTC/USDT Predictions"):
431
- """Plot actual vs predicted"""
432
-
433
- fig = go.Figure()
434
-
435
- if timestamps is None:
436
- timestamps = list(range(len(y_true)))
437
-
438
- fig.add_trace(go.Scatter(
439
- x=timestamps,
440
- y=y_true,
441
- mode='lines',
442
- name='Actual',
443
- line=dict(color='cyan', width=2)
444
- ))
445
-
446
- fig.add_trace(go.Scatter(
447
- x=timestamps,
448
- y=y_pred,
449
- mode='lines',
450
- name='Predicted',
451
- line=dict(color='magenta', width=2, dash='dash')
452
- ))
453
-
454
- fig.update_layout(
455
- title=title,
456
- xaxis_title='Time',
457
- yaxis_title='Price (USDT)',
458
- template='plotly_dark',
459
- hovermode='x unified',
460
- height=500
461
- )
462
-
463
- return fig
464
-
465
- @staticmethod
466
- def plot_candlestick(df, n_candles=100):
467
- """Plot candlestick chart"""
468
-
469
- df = df.tail(n_candles).copy()
470
-
471
- fig = make_subplots(
472
- rows=2, cols=1,
473
- shared_xaxes=True,
474
- vertical_spacing=0.05,
475
- subplot_titles=('Price', 'Volume'),
476
- row_heights=[0.7, 0.3]
477
- )
478
-
479
- fig.add_trace(
480
- go.Candlestick(
481
- x=df['timestamp'],
482
- open=df['open'],
483
- high=df['high'],
484
- low=df['low'],
485
- close=df['close'],
486
- name='OHLC'
487
- ),
488
- row=1, col=1
489
- )
490
-
491
- colors = ['red' if row['close'] < row['open'] else 'green'
492
- for idx, row in df.iterrows()]
493
-
494
- fig.add_trace(
495
- go.Bar(
496
- x=df['timestamp'],
497
- y=df['volume'],
498
- name='Volume',
499
- marker_color=colors
500
- ),
501
- row=2, col=1
502
- )
503
-
504
- fig.update_layout(
505
- title='BTC/USDT Chart',
506
- template='plotly_dark',
507
- xaxis_rangeslider_visible=False,
508
- height=700
509
- )
510
-
511
- return fig
512
-
513
- @staticmethod
514
- def plot_feature_importance(model, feature_names, top_n=20):
515
- """Plot feature importance"""
516
-
517
- if hasattr(model, 'feature_importances_'):
518
- importances = model.feature_importances_
519
- indices = np.argsort(importances)[-top_n:]
520
-
521
- fig = go.Figure(go.Bar(
522
- x=importances[indices],
523
- y=[feature_names[i] for i in indices],
524
- orientation='h',
525
- marker_color='lightblue'
526
- ))
527
-
528
- fig.update_layout(
529
- title=f'Top {top_n} Feature Importances',
530
- xaxis_title='Importance',
531
- yaxis_title='Features',
532
- template='plotly_dark',
533
- height=600
534
- )
535
-
536
- return fig
537
-
538
- return None
539
-
540
-
541
- # ================================
542
- # MAIN PIPELINE
543
- # ================================
544
-
545
- class BTCPredictionPipeline:
546
- """Main prediction pipeline"""
547
-
548
- def __init__(self):
549
- self.okx_client = OKXClient()
550
- self.feature_engineer = FeatureEngineer()
551
- self.ensemble_model = EnsemblePredictor()
552
- self.visualizer = Visualizer()
553
- self.raw_data = None
554
- self.processed_data = None
555
-
556
- def fetch_data(self, bar='1H', limit=300):
557
- """Fetch data from OKX"""
558
-
559
- print(f"Fetching {limit} candles from OKX...")
560
- df = self.okx_client.get_candlesticks(instId='BTC-USDT', bar=bar, limit=limit)
561
-
562
- if df is not None:
563
- self.raw_data = df
564
- print(f"Fetched {len(df)} candles")
565
- return df
566
  else:
567
- print("Failed to fetch data")
568
- return None
569
-
570
- def prepare_features(self):
571
- """Prepare features"""
572
-
573
- if self.raw_data is None:
574
- raise ValueError("No data available")
575
-
576
- print("Engineering features...")
577
- df = self.feature_engineer.add_technical_indicators(self.raw_data)
578
- df = self.feature_engineer.add_lag_features(df, n_lags=5)
579
- df = self.feature_engineer.add_time_features(df)
580
-
581
- df = df.dropna()
582
- self.processed_data = df
583
-
584
- print(f"Features: {len(df.columns)}, Samples: {len(df)}")
585
-
586
- return df
587
-
588
- def train_model(self, test_size=0.2, val_size=0.1):
589
- """Train ensemble model"""
590
-
591
- if self.processed_data is None:
592
- raise ValueError("Features not prepared")
593
-
594
- X, y = self.ensemble_model.prepare_data(self.processed_data)
595
-
596
- X_temp, X_test, y_temp, y_test = train_test_split(
597
- X, y, test_size=test_size, shuffle=False
598
- )
599
-
600
- X_train, X_val, y_train, y_val = train_test_split(
601
- X_temp, y_temp, test_size=val_size/(1-test_size), shuffle=False
602
- )
603
-
604
- print(f"\nTrain: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")
605
-
606
- print("\nTraining ensemble...")
607
- self.ensemble_model.train(X_train, y_train, X_val, y_val)
608
-
609
- print("\nEvaluating...")
610
- metrics, predictions = self.ensemble_model.evaluate(X_test, y_test)
611
-
612
- print("\n=== Ensemble Performance ===")
613
- for metric_name, value in metrics['ensemble'].items():
614
- print(f"{metric_name}: {value:.4f}")
615
-
616
- return metrics, predictions, y_test
617
-
618
- def predict_future(self, n_steps=24):
619
- """Predict future prices"""
620
-
621
- if not self.ensemble_model.is_trained:
622
- raise ValueError("Model not trained")
623
-
624
- last_data = self.processed_data.iloc[-1:].copy()
625
- X_last, _ = self.ensemble_model.prepare_data(last_data)
626
-
627
- pred, _ = self.ensemble_model.predict(X_last)
628
-
629
- last_time = self.processed_data['timestamp'].iloc[-1]
630
- future_times = [last_time + timedelta(hours=i+1) for i in range(n_steps)]
631
-
632
- predictions = [pred[0] * (1 + np.random.normal(0, 0.005)) for _ in range(n_steps)]
633
-
634
- return future_times, predictions
635
-
636
- # app.py - PARÇA 5/5
637
- # ========================================================================
638
- # Gradio Interface
639
- # ========================================================================
640
-
641
- # Global pipeline instance
642
- pipeline = BTCPredictionPipeline()
643
- training_complete = False
644
-
645
-
646
- def fetch_data_ui(bar_size, num_candles):
647
- """Fetch data interface"""
648
- try:
649
- df = pipeline.fetch_data(bar=bar_size, limit=int(num_candles))
650
-
651
- if df is not None:
652
- info = f"✅ Successfully fetched {len(df)} candles\n\n"
653
- info += f"Time range: {df['timestamp'].min()} to {df['timestamp'].max()}\n"
654
- info += f"Price range: ${df['close'].min():.2f} - ${df['close'].max():.2f}\n"
655
- info += f"Current price: ${df['close'].iloc[-1]:.2f}"
656
-
657
- fig = pipeline.visualizer.plot_candlestick(df)
658
-
659
- summary = df.tail(10)[['timestamp', 'open', 'high', 'low', 'close', 'volume']].copy()
660
- summary['timestamp'] = summary['timestamp'].dt.strftime('%Y-%m-%d %H:%M')
661
-
662
- return info, fig, summary
663
  else:
664
- return "❌ Failed to fetch data", None, None
665
-
666
- except Exception as e:
667
- return f"❌ Error: {str(e)}", None, None
668
-
669
-
670
- def train_model_ui(test_size, val_size):
671
- """Train model interface"""
672
- global training_complete
673
-
674
- try:
675
- pipeline.prepare_features()
676
-
677
- metrics, predictions, y_test = pipeline.train_model(
678
- test_size=test_size,
679
- val_size=val_size
680
- )
681
-
682
- training_complete = True
683
-
684
- metrics_text = "=== ENSEMBLE MODEL PERFORMANCE ===\n\n"
685
- for metric_name, value in metrics['ensemble'].items():
686
- metrics_text += f"{metric_name}: {value:.4f}\n"
687
-
688
- metrics_text += "\n\n=== INDIVIDUAL MODELS ===\n\n"
689
- for model_name, model_metrics in metrics.items():
690
- if model_name != 'ensemble':
691
- metrics_text += f"\n{model_name.upper()}:\n"
692
- for metric_name, value in model_metrics.items():
693
- metrics_text += f" {metric_name}: {value:.4f}\n"
694
-
695
- test_idx = len(pipeline.processed_data) - len(y_test)
696
- test_timestamps = pipeline.processed_data['timestamp'].iloc[test_idx:].values
697
-
698
- fig = pipeline.visualizer.plot_predictions(
699
- y_test,
700
- predictions,
701
- test_timestamps,
702
- "Test Set Predictions"
703
- )
704
-
705
- return metrics_text, fig, "✅ Training complete!"
706
-
707
- except Exception as e:
708
- return f"❌ Error: {str(e)}", None, "Training failed"
709
-
710
-
711
- def predict_future_ui(n_hours):
712
- """Predict future interface"""
713
-
714
- if not training_complete:
715
- return "⚠️ Please train model first", None, None
716
-
717
- try:
718
- future_times, predictions = pipeline.predict_future(n_steps=int(n_hours))
719
-
720
- pred_df = pd.DataFrame({
721
- 'Timestamp': [t.strftime('%Y-%m-%d %H:%M') for t in future_times],
722
- 'Predicted Price (USDT)': [f"${p:,.2f}" for p in predictions]
723
- })
724
-
725
- fig = go.Figure()
726
- fig.add_trace(go.Scatter(
727
- x=future_times,
728
- y=predictions,
729
- mode='lines+markers',
730
- name='Predicted Price',
731
- line=dict(color='green', width=3),
732
- marker=dict(size=8)
733
- ))
734
-
735
- fig.update_layout(
736
- title=f'BTC/USDT Price Prediction - Next {n_hours} Hours',
737
- xaxis_title='Time',
738
- yaxis_title='Price (USDT)',
739
- template='plotly_dark',
740
- hovermode='x unified',
741
- height=500
742
- )
743
-
744
- return pred_df, fig, f"✅ Predicted next {n_hours} hours"
745
-
746
- except Exception as e:
747
- return None, None, f"❌ Error: {str(e)}"
748
-
749
-
750
- def get_current_price_ui():
751
- """Get current price from OKX"""
752
- try:
753
- ticker = pipeline.okx_client.get_ticker('BTC-USDT')
754
-
755
- if ticker:
756
- info = f"🔴 LIVE BTC/USDT PRICE\n\n"
757
- info += f"Last Price: ${ticker['last']:,.2f}\n"
758
- info += f"Bid: ${ticker['bid']:,.2f}\n"
759
- info += f"Ask: ${ticker['ask']:,.2f}\n"
760
- info += f"24h Volume: {ticker['volume_24h']:,.2f} BTC\n"
761
- info += f"Updated: {ticker['timestamp'].strftime('%Y-%m-%d %H:%M:%S')}"
762
-
763
- return info
764
  else:
765
- return " Failed to fetch current price"
766
-
767
- except Exception as e:
768
- return f"❌ Error: {str(e)}"
769
-
770
-
771
- def show_feature_importance_ui():
772
- """Show feature importance"""
773
-
774
- if not training_complete:
775
- return None, "⚠️ Please train model first"
776
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
777
  try:
778
- model = pipeline.ensemble_model.models['random_forest']
779
- feature_names = pipeline.ensemble_model.feature_columns
780
-
781
- fig = pipeline.visualizer.plot_feature_importance(
782
- model,
783
- feature_names,
784
- top_n=30
785
- )
786
-
787
- importances = model.feature_importances_
788
- indices = np.argsort(importances)[-30:]
789
-
790
- importance_text = "=== TOP 30 FEATURES ===\n\n"
791
- for i, idx in enumerate(reversed(indices), 1):
792
- importance_text += f"{i}. {feature_names[idx]}: {importances[idx]:.6f}\n"
793
-
794
- return fig, importance_text
795
-
796
- except Exception as e:
797
- return None, f"❌ Error: {str(e)}"
798
-
799
-
800
- def analyze_market_ui():
801
- """Market analysis interface"""
802
-
803
- if pipeline.processed_data is None:
804
- return None, "⚠️ Please load data first"
805
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
806
  try:
807
- df = pipeline.processed_data.tail(200)
808
-
809
- fig = make_subplots(
810
- rows=4, cols=1,
811
- shared_xaxes=True,
812
- vertical_spacing=0.05,
813
- subplot_titles=('Price & MA', 'RSI', 'MACD', 'Volume'),
814
- row_heights=[0.4, 0.2, 0.2, 0.2]
815
- )
816
-
817
- # Price and Moving Averages
818
- fig.add_trace(
819
- go.Scatter(x=df['timestamp'], y=df['close'],
820
- name='Close', line=dict(color='white', width=2)),
821
- row=1, col=1
822
- )
823
- fig.add_trace(
824
- go.Scatter(x=df['timestamp'], y=df['sma_20'],
825
- name='SMA 20', line=dict(color='orange', width=1)),
826
- row=1, col=1
827
- )
828
- fig.add_trace(
829
- go.Scatter(x=df['timestamp'], y=df['sma_50'],
830
- name='SMA 50', line=dict(color='blue', width=1)),
831
- row=1, col=1
832
- )
833
-
834
- # RSI
835
- fig.add_trace(
836
- go.Scatter(x=df['timestamp'], y=df['rsi_14'],
837
- name='RSI', line=dict(color='purple', width=2)),
838
- row=2, col=1
839
- )
840
- fig.add_hline(y=70, line_dash="dash", line_color="red", row=2, col=1)
841
- fig.add_hline(y=30, line_dash="dash", line_color="green", row=2, col=1)
842
-
843
- # MACD
844
- fig.add_trace(
845
- go.Scatter(x=df['timestamp'], y=df['macd'],
846
- name='MACD', line=dict(color='blue', width=1)),
847
- row=3, col=1
848
- )
849
- fig.add_trace(
850
- go.Scatter(x=df['timestamp'], y=df['macd_signal'],
851
- name='Signal', line=dict(color='red', width=1)),
852
- row=3, col=1
853
- )
854
- fig.add_trace(
855
- go.Bar(x=df['timestamp'], y=df['macd_diff'],
856
- name='Histogram', marker_color='gray'),
857
- row=3, col=1
858
- )
859
-
860
- # Volume
861
- colors = ['red' if df.iloc[i]['close'] < df.iloc[i]['open'] else 'green'
862
- for i in range(len(df))]
863
- fig.add_trace(
864
- go.Bar(x=df['timestamp'], y=df['volume'],
865
- name='Volume', marker_color=colors),
866
- row=4, col=1
867
- )
868
-
869
- fig.update_layout(
870
- title='Market Technical Analysis',
871
- template='plotly_dark',
872
- height=900,
873
- showlegend=True,
874
- hovermode='x unified'
875
- )
876
-
877
- # Market summary
878
- current_price = df['close'].iloc[-1]
879
- rsi = df['rsi_14'].iloc[-1]
880
- macd_signal = "Bullish" if df['macd_diff'].iloc[-1] > 0 else "Bearish"
881
-
882
- summary = f"=== MARKET ANALYSIS ===\n\n"
883
- summary += f"Current Price: ${current_price:,.2f}\n"
884
- summary += f"RSI (14): {rsi:.2f} - "
885
-
886
- if rsi > 70:
887
- summary += "Overbought ⚠️\n"
888
- elif rsi < 30:
889
- summary += "Oversold ⚠️\n"
890
- else:
891
- summary += "Neutral ✅\n"
892
-
893
- summary += f"MACD Signal: {macd_signal}\n"
894
- summary += f"SMA 20: ${df['sma_20'].iloc[-1]:,.2f}\n"
895
- summary += f"SMA 50: ${df['sma_50'].iloc[-1]:,.2f}\n"
896
- summary += f"24h Change: {((current_price / df['close'].iloc[-24] - 1) * 100):.2f}%\n"
897
- summary += f"Volatility (20): {df['volatility_20'].iloc[-1]:.6f}\n"
898
-
899
- return fig, summary
900
-
901
  except Exception as e:
902
- return None, f" Error: {str(e)}"
903
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
904
 
905
- # ================================
906
- # GRADIO APP
907
- # ================================
908
-
909
- with gr.Blocks(title="OKX BTC/USDT Ensemble Predictor", theme=gr.themes.Soft()) as demo:
910
-
911
- gr.Markdown("""
912
- # 🚀 OKX BTC/USDT Ensemble Price Predictor
913
-
914
- **Advanced Machine Learning System for Bitcoin Price Prediction**
915
-
916
- This application uses an ensemble of 6 machine learning models:
917
- - Random Forest
918
- - Gradient Boosting
919
- - AdaBoost
920
- - Ridge Regression
921
- - Lasso Regression
922
- - Elastic Net
923
-
924
- **Features:**
925
- - Real-time data from OKX API
926
- - 100+ technical indicators
927
- - Weighted ensemble predictions
928
- - Advanced visualization
929
- """)
930
-
931
- # TAB 1: DATA FETCHING
932
- with gr.Tab("📊 Data Fetching"):
933
- gr.Markdown("### Fetch Historical BTC/USDT Data from OKX")
934
-
935
- with gr.Row():
936
- bar_size = gr.Dropdown(
937
- choices=['1m', '5m', '15m', '30m', '1H', '2H', '4H', '1D'],
938
- value='1H',
939
- label="Timeframe"
940
- )
941
- num_candles = gr.Slider(
942
- minimum=100,
943
- maximum=300,
944
- value=300,
945
- step=10,
946
- label="Number of Candles"
947
- )
948
-
949
- fetch_btn = gr.Button("🔄 Fetch Data", variant="primary", size="lg")
950
-
951
  with gr.Row():
952
- data_info = gr.Textbox(label="Data Info", lines=6)
953
-
954
- data_chart = gr.Plot(label="Price Chart")
955
- data_table = gr.Dataframe(label="Latest Data (Last 10 Candles)")
956
-
957
- fetch_btn.click(
958
- fn=fetch_data_ui,
959
- inputs=[bar_size, num_candles],
960
- outputs=[data_info, data_chart, data_table]
961
- )
962
-
963
- # TAB 2: MODEL TRAINING
964
- with gr.Tab("🤖 Model Training"):
965
- gr.Markdown("### Train Ensemble Model")
966
-
967
- with gr.Row():
968
- test_size_slider = gr.Slider(
969
- minimum=0.1,
970
- maximum=0.3,
971
- value=0.2,
972
- step=0.05,
973
- label="Test Set Size"
974
- )
975
- val_size_slider = gr.Slider(
976
- minimum=0.05,
977
- maximum=0.2,
978
- value=0.1,
979
- step=0.05,
980
- label="Validation Set Size"
981
- )
982
-
983
- train_btn = gr.Button("🚀 Train Model", variant="primary", size="lg")
984
-
985
- train_status = gr.Textbox(label="Training Status", lines=1)
986
- train_metrics = gr.Textbox(label="Model Performance Metrics", lines=20)
987
- train_plot = gr.Plot(label="Predictions vs Actual")
988
-
989
- train_btn.click(
990
- fn=train_model_ui,
991
- inputs=[test_size_slider, val_size_slider],
992
- outputs=[train_metrics, train_plot, train_status]
993
- )
994
-
995
- # TAB 3: PREDICTIONS
996
- with gr.Tab("🔮 Future Predictions"):
997
- gr.Markdown("### Predict Future BTC/USDT Prices")
998
-
999
- n_hours_slider = gr.Slider(
1000
- minimum=1,
1001
- maximum=72,
1002
- value=24,
1003
- step=1,
1004
- label="Prediction Horizon (Hours)"
1005
- )
1006
-
1007
- predict_btn = gr.Button("🔮 Predict Future", variant="primary", size="lg")
1008
-
1009
- predict_status = gr.Textbox(label="Prediction Status", lines=1)
1010
- predict_table = gr.Dataframe(label="Predicted Prices")
1011
- predict_plot = gr.Plot(label="Future Price Prediction")
1012
-
1013
- predict_btn.click(
1014
- fn=predict_future_ui,
1015
- inputs=[n_hours_slider],
1016
- outputs=[predict_table, predict_plot, predict_status]
1017
- )
1018
-
1019
- # TAB 4: LIVE PRICE
1020
- with gr.Tab("💰 Live Price"):
1021
- gr.Markdown("### Real-time BTC/USDT Price from OKX")
1022
-
1023
- refresh_btn = gr.Button("🔄 Refresh Price", variant="primary", size="lg")
1024
-
1025
- live_price_info = gr.Textbox(label="Current Market Data", lines=8)
1026
-
1027
- refresh_btn.click(
1028
- fn=get_current_price_ui,
1029
- inputs=[],
1030
- outputs=[live_price_info]
1031
- )
1032
-
1033
- # TAB 5: FEATURE IMPORTANCE
1034
- with gr.Tab("📈 Feature Importance"):
1035
- gr.Markdown("### Top Features Contributing to Predictions")
1036
-
1037
- feature_btn = gr.Button("📊 Show Feature Importance", variant="primary", size="lg")
1038
-
1039
- feature_plot = gr.Plot(label="Feature Importance Chart")
1040
- feature_text = gr.Textbox(label="Top 30 Features", lines=35)
1041
-
1042
- feature_btn.click(
1043
- fn=show_feature_importance_ui,
1044
- inputs=[],
1045
- outputs=[feature_plot, feature_text]
1046
- )
1047
-
1048
- # TAB 6: MARKET ANALYSIS
1049
- with gr.Tab("📉 Market Analysis"):
1050
- gr.Markdown("### Technical Analysis Dashboard")
1051
-
1052
- analyze_btn = gr.Button("📊 Analyze Market", variant="primary", size="lg")
1053
-
1054
- analysis_plot = gr.Plot(label="Technical Indicators")
1055
- analysis_summary = gr.Textbox(label="Market Summary", lines=12)
1056
-
1057
- analyze_btn.click(
1058
- fn=analyze_market_ui,
1059
- inputs=[],
1060
- outputs=[analysis_plot, analysis_summary]
1061
- )
1062
-
1063
- # TAB 7: ABOUT
1064
- with gr.Tab("ℹ️ About"):
1065
- gr.Markdown("""
1066
- ## About This Application
1067
-
1068
- ### Ensemble Model Architecture
1069
-
1070
- This application uses a sophisticated ensemble learning approach combining:
1071
-
1072
- 1. **Random Forest** - Handles non-linear relationships and feature interactions
1073
- 2. **Gradient Boosting** - Sequential learning for complex patterns
1074
- 3. **AdaBoost** - Adaptive boosting for improved accuracy
1075
- 4. **Ridge Regression** - Linear model with L2 regularization
1076
- 5. **Lasso Regression** - Linear model with L1 regularization and feature selection
1077
- 6. **Elastic Net** - Combines L1 and L2 regularization
1078
-
1079
- ### Feature Engineering (100+ Features)
1080
-
1081
- - **Price Features**: Returns, log returns, price ranges, candlestick patterns
1082
- - **Moving Averages**: SMA and EMA (5, 10, 20, 50, 100 periods)
1083
- - **Momentum Indicators**: MACD, RSI, ROC, Stochastic Oscillator
1084
- - **Volatility Indicators**: ATR, Bollinger Bands, rolling volatility
1085
- - **Volume Indicators**: OBV, volume ratios, volume-price trends
1086
- - **Statistical Features**: Skewness, kurtosis, quantiles
1087
- - **Lag Features**: Historical prices and volumes (1-5 periods)
1088
- - **Time Features**: Hour, day, month with cyclical encoding
1089
-
1090
- ### Data Source
1091
-
1092
- Real-time and historical data fetched from **OKX Exchange** via REST API:
1093
- - Endpoint: `https://www.okx.com/api/v5/market/candles`
1094
- - Instrument: BTC-USDT
1095
- - Supported timeframes: 1m, 5m, 15m, 30m, 1H, 2H, 4H, 1D
1096
-
1097
- ### Model Training Process
1098
-
1099
- 1. **Data Collection**: Fetch historical OHLCV data from OKX
1100
- 2. **Feature Engineering**: Generate 100+ technical indicators
1101
- 3. **Data Preprocessing**: Handle missing values, normalize features
1102
- 4. **Train/Val/Test Split**: Time-series aware splitting
1103
- 5. **Model Training**: Train 6 models independently
1104
- 6. **Weight Optimization**: Calculate optimal ensemble weights based on validation performance
1105
- 7. **Evaluation**: Test on unseen data with multiple metrics
1106
-
1107
- ### Performance Metrics
1108
-
1109
- - **MSE** (Mean Squared Error): Average squared prediction error
1110
- - **RMSE** (Root Mean Squared Error): Square root of MSE, in price units
1111
- - **MAE** (Mean Absolute Error): Average absolute prediction error
1112
- - **R²** (R-squared): Proportion of variance explained
1113
- - **MAPE** (Mean Absolute Percentage Error): Average percentage error
1114
-
1115
- ### Usage Instructions
1116
-
1117
- 1. **Fetch Data**: Go to "Data Fetching" tab and load historical data
1118
- 2. **Train Model**: Navigate to "Model Training" and train the ensemble
1119
- 3. **Make Predictions**: Use "Future Predictions" to forecast prices
1120
- 4. **Monitor Live**: Check "Live Price" for real-time market data
1121
- 5. **Analyze**: Explore "Feature Importance" and "Market Analysis"
1122
-
1123
- ### Limitations & Disclaimer
1124
-
1125
- ⚠️ **Important**: This tool is for educational and research purposes only.
1126
-
1127
- - Cryptocurrency markets are highly volatile and unpredictable
1128
- - Past performance does not guarantee future results
1129
- - Model predictions should NOT be used as sole basis for trading decisions
1130
- - Always conduct your own research and consult financial advisors
1131
- - The authors are not responsible for any financial losses
1132
-
1133
- ### Technical Stack
1134
-
1135
- - **Python 3.10+**
1136
- - **Gradio**: Web interface
1137
- - **Scikit-learn**: Machine learning models
1138
- - **Pandas & NumPy**: Data manipulation
1139
- - **Plotly**: Interactive visualizations
1140
- - **Requests**: API communication
1141
-
1142
- ### Version
1143
-
1144
- **v1.0.0** - Initial Release
1145
-
1146
- ---
1147
-
1148
- Made with ❤️ for the crypto community
1149
-
1150
- **GitHub**: [Your Repository Link]
1151
- **Documentation**: [Your Docs Link]
1152
- **Contact**: [Your Contact Info]
1153
- """)
1154
-
1155
- # ================================
1156
- # LAUNCH APP
1157
- # ================================
1158
-
1159
  if __name__ == "__main__":
1160
- demo.launch(
1161
- server_name="0.0.0.0",
1162
- server_port=7860,
1163
- share=False,
1164
- show_error=True
1165
- )
 
1
+ İşte tek dosya `app.py`. Gradio (blank) arayüzü, OKX REST'ten BTC/USDT (spot) candle verisi çekme, önişleme, birkaç basit modelden (LightGBM, XGBoost, küçük PyTorch LSTM ve basit RandomForest) oluşan ensemble ile inference yapacak şekilde hazırlanmıştır. Eksik modeller varsa demo (dummy) modeller üretecek; gerçek eğitim için ek adımlar gerekir. Dosya, Spaces/Gradio üzerinde çalışacak şekilde tasarlandı.
2
+
3
+ python
4
+ # app.py
5
+ """
6
+ Gradio (blank) tabanlı Hugging Face Space uygulaması.
7
+ - OKX REST API'den BTC/USDT (spot) candle verisi çeker
8
+ - Teknik göstergeler üretir
9
+ - Ensemble: LightGBM, XGBoost, RandomForest (sklearn) + küçük PyTorch LSTM
10
+ - Eğer pretrained model dosyaları yoksa küçük demo modeller oluşturur
11
+ - Outputs: tahmin (regresyon: next-close), model katkıları, grafikler
12
+
13
+ Not:
14
+ - requirements.txt'de aşağıdakiler olmalı:
15
+ gradio, pandas, numpy, requests, ta, scikit-learn, lightgbm, xgboost, torch, matplotlib
16
+ - Kullanıcı OKX API anahtarı gerekli değildir (public candles endpoint kullanılıyor).
17
+ - Bu dosya tek başına çalışır; ancak ağır paketler (lightgbm, xgboost, torch) Spaces ortamında kurulmadıysa hata verebilir.
18
+ """
19
 
20
  import os
21
+ import io
22
+ import time
23
+ import math
24
+ import json
25
+ import threading
26
+ from typing import Tuple, Dict, Any, List
27
+
28
  import numpy as np
29
  import pandas as pd
 
30
  import requests
31
+ from datetime import datetime, timedelta, timezone
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  # Visualization
34
+ import matplotlib
35
+ matplotlib.use("Agg")
36
+ import matplotlib.pyplot as plt
37
+
38
+ # Technical indicators
39
+ try:
40
+ import ta
41
+ except Exception:
42
+ # Minimal fallback implementations if ta isn't installed
43
+ ta = None
44
+
45
+ # ML libs
46
+ from sklearn.ensemble import RandomForestRegressor
47
+ from sklearn.preprocessing import StandardScaler
48
+ from sklearn.pipeline import Pipeline
49
+ from sklearn.base import BaseEstimator, RegressorMixin
50
+
51
+ # Try import optional libs
52
+ HAS_LGB = True
53
+ HAS_XGB = True
54
+ HAS_TORCH = True
55
+ try:
56
+ import lightgbm as lgb
57
+ except Exception:
58
+ HAS_LGB = False
59
+ try:
60
+ import xgboost as xgb
61
+ except Exception:
62
+ HAS_XGB = False
63
+ try:
64
+ import torch
65
+ import torch.nn as nn
66
+ import torch.nn.functional as F
67
+ from torch.utils.data import DataLoader, TensorDataset
68
+ except Exception:
69
+ HAS_TORCH = False
70
+
71
+ # Gradio
72
+ import gradio as gr
73
 
74
+ # -------------------------
75
+ # Configuration/Constants
76
+ # -------------------------
77
+ OKX_BASE = "https://www.okx.com"
78
+ # Public candles: GET /api/v5/market/history-candles?instId=BTC-USDT-SWAP&bar=1m&limit=100
79
+ # We'll use spot: BTC-USDT
80
+ DEFAULT_INSTRUMENT = "BTC-USDT"
81
+ DEFAULT_BAR = "1m" # options: 1m, 3m, 5m, 15m, 1H etc.
82
+ DEFAULT_LIMIT = 500 # up to 1000 depending on endpoint
83
+
84
+ # Model filenames (in repo or persisted by training)
85
+ MODEL_DIR = "models"
86
+ os.makedirs(MODEL_DIR, exist_ok=True)
87
+ LGB_MODEL_FILE = os.path.join(MODEL_DIR, "lgb_model.txt")
88
+ XGB_MODEL_FILE = os.path.join(MODEL_DIR, "xgb_model.json")
89
+ RF_MODEL_FILE = os.path.join(MODEL_DIR, "rf_model.pkl")
90
+ LSTM_MODEL_FILE = os.path.join(MODEL_DIR, "lstm_model.pt")
91
+ SCALER_FILE = os.path.join(MODEL_DIR, "scaler.npy") # save scaler mean/scale
92
+
93
+ # Thread-safe model cache
94
+ _MODEL_LOCK = threading.Lock()
95
+ _MODELS = {}
96
+
97
+ # -------------------------
98
+ # Utilities
99
+ # -------------------------
100
+ def now_iso():
101
+ return datetime.now(timezone.utc).isoformat()
102
+
103
+ def okx_candles(inst_id: str = DEFAULT_INSTRUMENT, bar: str = DEFAULT_BAR, limit: int = DEFAULT_LIMIT) -> pd.DataFrame:
104
+ """
105
+ Fetch recent candle data from OKX public REST API.
106
+ Returns DataFrame with columns: ts, open, high, low, close, volume
107
+ ts in UTC datetime
108
+ """
109
+ url = f"{OKX_BASE}/api/v5/market/history-candles"
110
+ params = {"instId": inst_id, "bar": bar, "limit": str(limit)}
111
+ resp = requests.get(url, params=params, timeout=15)
112
+ resp.raise_for_status()
113
+ data = resp.json()
114
+
115
+ if not data or data.get("code") not in (None, "0", 0):
116
+ # OKX returns "code": "0" on success sometimes; be permissive
117
+ # If structure unexpected, raise
118
+ # Try to parse anyway
119
+ pass
120
+
121
+ cand = data.get("data", [])
122
+ if not cand:
123
+ # Possibly different field
124
+ raise RuntimeError("No candle data returned from OKX")
125
+
126
+ # OKX returns list of lists: [ts, open, high, low, close, volume, ...]
127
+ # timestamp in millis
128
+ rows = []
129
+ for c in cand:
130
+ # According to OKX docs: [ts, open, high, low, close, volume]
131
+ ts = int(c[0]) // 1000 if len(str(c[0])) > 10 else int(c[0])
132
+ dt = datetime.fromtimestamp(ts, tz=timezone.utc)
133
+ rows.append({
134
+ "ts": dt,
135
+ "open": float(c[1]),
136
+ "high": float(c[2]),
137
+ "low": float(c[3]),
138
+ "close": float(c[4]),
139
+ "volume": float(c[5])
140
  })
141
+ df = pd.DataFrame(rows)
142
+ df = df.sort_values("ts").reset_index(drop=True)
143
+ return df
144
+
145
+ # Minimal TA indicators if `ta` package is not available
146
+ def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
147
+ df = df.copy()
148
+ if ta is not None:
149
+ # Use ta to add common indicators
150
+ df["rsi"] = ta.momentum.RSIIndicator(df["close"], window=14, fillna=True).rsi()
151
+ df["ema12"] = ta.trend.EMAIndicator(df["close"], window=12, fillna=True).ema_indicator()
152
+ df["ema26"] = ta.trend.EMAIndicator(df["close"], window=26, fillna=True).ema_indicator()
153
+ macd = ta.trend.MACD(df["close"], window_slow=26, window_fast=12, window_sign=9, fillna=True)
154
+ df["macd"] = macd.macd()
155
+ df["macd_signal"] = macd.macd_signal()
156
+ df["bb_high"] = ta.volatility.BollingerBands(df["close"], window=20, fillna=True).bollinger_hband()
157
+ df["bb_low"] = ta.volatility.BollingerBands(df["close"], window=20, fillna=True).bollinger_lband()
158
+ df["atr"] = ta.volatility.AverageTrueRange(df["high"], df["low"], df["close"], window=14, fillna=True).average_true_range()
159
+ else:
160
+ # Fallback simple computations
161
+ df["rsi"] = simple_rsi(df["close"], window=14)
162
+ df["ema12"] = df["close"].ewm(span=12, adjust=False).mean()
163
+ df["ema26"] = df["close"].ewm(span=26, adjust=False).mean()
164
+ df["macd"] = df["ema12"] - df["ema26"]
165
+ df["macd_signal"] = df["macd"].ewm(span=9, adjust=False).mean()
166
+ df["bb_mid"] = df["close"].rolling(20).mean()
167
+ df["bb_std"] = df["close"].rolling(20).std()
168
+ df["bb_high"] = df["bb_mid"] + 2 * df["bb_std"]
169
+ df["bb_low"] = df["bb_mid"] - 2 * df["bb_std"]
170
+ df["atr"] = simple_atr(df, window=14)
171
+ # Fill na
172
+ df = df.fillna(method="bfill").fillna(method="ffill").fillna(0.0)
173
+ return df
174
+
175
+ def simple_rsi(series: pd.Series, window: int = 14) -> pd.Series:
176
+ delta = series.diff()
177
+ up = delta.clip(lower=0)
178
+ down = -1 * delta.clip(upper=0)
179
+ ma_up = up.ewm(alpha=1/window, adjust=False).mean()
180
+ ma_down = down.ewm(alpha=1/window, adjust=False).mean()
181
+ rs = ma_up / (ma_down + 1e-8)
182
+ rsi = 100 - (100 / (1 + rs))
183
+ return rsi.fillna(50.0)
184
+
185
+ def simple_atr(df: pd.DataFrame, window: int = 14) -> pd.Series:
186
+ high_low = df["high"] - df["low"]
187
+ high_close = (df["high"] - df["close"].shift()).abs()
188
+ low_close = (df["low"] - df["close"].shift()).abs()
189
+ tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
190
+ atr = tr.ewm(span=window, adjust=False).mean()
191
+ return atr.fillna(0.0)
192
+
193
+ def create_features(df: pd.DataFrame) -> pd.DataFrame:
194
+ df = df.copy()
195
+ df = add_technical_indicators(df)
196
+ # Returns features aligned to each row predicting next row's close
197
+ # Feature engineering: returns, log returns, vol, moving averages, ratios
198
+ df["return_1"] = df["close"].pct_change().fillna(0.0)
199
+ df["log_return_1"] = np.log1p(df["return_1"])
200
+ df["vol_5"] = df["close"].rolling(5).std().fillna(0.0)
201
+ df["vol_20"] = df["close"].rolling(20).std().fillna(0.0)
202
+ df["ma_5"] = df["close"].rolling(5).mean().fillna(method="bfill")
203
+ df["ma_20"] = df["close"].rolling(20).mean().fillna(method="bfill")
204
+ df["ma_50"] = df["close"].rolling(50).mean().fillna(method="bfill")
205
+ # ratio features
206
+ df["ma5_div_ma20"] = df["ma_5"] / (df["ma_20"] + 1e-9)
207
+ df["ema_diff"] = df["ema12"] - df["ema26"]
208
+ # time features
209
+ df["ts_unix"] = df["ts"].astype(np.int64) // 10**9
210
+ df["hour"] = df["ts"].dt.hour
211
+ df["minute"] = df["ts"].dt.minute
212
+ # fill remaining na
213
+ df = df.fillna(method="bfill").fillna(0.0)
214
+ return df
215
+
216
+ # -------------------------
217
+ # Model wrappers and helpers
218
+ # -------------------------
219
+ class DummyRegressor(BaseEstimator, RegressorMixin):
220
+ """Simple mean predictor used as fallback."""
221
+ def fit(self, X, y):
222
+ self._mean = np.mean(y) if len(y) else 0.0
223
+ return self
224
+ def predict(self, X):
225
+ return np.full((X.shape[0],), getattr(self, "_mean", 0.0))
226
+
227
+ def save_numpy(obj: np.ndarray, path: str):
228
+ np.save(path, obj)
229
+
230
+ def load_numpy(path: str) -> np.ndarray:
231
+ return np.load(path)
232
+
233
+ def get_feature_columns() -> List[str]:
234
+ cols = [
235
+ "open","high","low","close","volume",
236
+ "rsi","ema12","ema26","macd","macd_signal","bb_high","bb_low","atr",
237
+ "return_1","log_return_1","vol_5","vol_20","ma_5","ma_20","ma_50",
238
+ "ma5_div_ma20","ema_diff","ts_unix","hour","minute"
239
+ ]
240
+ return cols
241
+
242
+ # Model persistence helpers (light, simple)
243
+ def load_models() -> Dict[str, Any]:
244
+ """
245
+ Try to load pretrained models from MODEL_DIR. If missing, create small demo models.
246
+ Returns dict of models and scaler.
247
+ """
248
+ with _MODEL_LOCK:
249
+ if _MODELS:
250
+ return _MODELS
251
+
252
+ models = {}
253
+ scaler = None
254
+
255
+ # Try load scaler if exists
256
+ if os.path.exists(SCALER_FILE):
257
+ try:
258
+ sc = np.load(SCALER_FILE, allow_pickle=True).item()
259
+ scaler = StandardScaler()
260
+ scaler.mean_ = sc["mean"]
261
+ scaler.scale_ = sc["scale"]
262
+ scaler.n_features_in_ = sc["n_in"]
263
+ except Exception:
264
+ scaler = None
265
+
266
+ # RandomForest (sklearn)
267
  try:
268
+ import joblib
269
+ if os.path.exists(RF_MODEL_FILE):
270
+ models["rf"] = joblib.load(RF_MODEL_FILE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  else:
272
+ raise FileNotFoundError
273
+ except Exception:
274
+ # create small RF demo
275
+ models["rf"] = RandomForestRegressor(n_estimators=10, random_state=42)
276
+
277
+ # LightGBM
278
+ if HAS_LGB and os.path.exists(LGB_MODEL_FILE):
279
+ try:
280
+ models["lgb"] = lgb.Booster(model_file=LGB_MODEL_FILE)
281
+ except Exception:
282
+ models["lgb"] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  else:
284
+ models["lgb"] = None if not HAS_LGB else None
285
+
286
+ # XGBoost
287
+ if HAS_XGB and os.path.exists(XGB_MODEL_FILE):
288
+ try:
289
+ models["xgb"] = xgb.Booster()
290
+ models["xgb"].load_model(XGB_MODEL_FILE)
291
+ except Exception:
292
+ models["xgb"] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  else:
294
+ models["xgb"] = None
295
+
296
+ # LSTM / PyTorch
297
+ if HAS_TORCH and os.path.exists(LSTM_MODEL_FILE):
298
+ try:
299
+ lstm = torch.load(LSTM_MODEL_FILE, map_location=torch.device("cpu"))
300
+ models["lstm"] = lstm
301
+ except Exception:
302
+ models["lstm"] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  else:
304
+ models["lstm"] = None
305
+
306
+ # If scaler missing, create a dummy one later in pipeline when training; for inference create StandardScaler default
307
+ if scaler is None:
308
+ scaler = StandardScaler()
309
+
310
+ # Create an ensemble wrapper
311
+ models["scaler"] = scaler
312
+
313
+ _MODELS.update(models)
314
+ return _MODELS
315
+
316
+ def save_scaler(scaler: StandardScaler, path: str = SCALER_FILE):
317
+ obj = {"mean": scaler.mean_, "scale": scaler.scale_, "n_in": scaler.n_features_in_}
318
+ np.save(path, obj)
319
+
320
+ # -------------------------
321
+ # Inference logic
322
+ # -------------------------
323
+ def prepare_inference_features(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], pd.DataFrame]:
324
+ """
325
+ Takes raw candles df, returns (X, feature_cols, df_ready)
326
+ X is 2D array for model input, aligned so that each row predicts next close.
327
+ """
328
+ df2 = create_features(df)
329
+ feat_cols = get_feature_columns()
330
+ # Ensure columns present
331
+ for c in feat_cols:
332
+ if c not in df2.columns:
333
+ df2[c] = 0.0
334
+ X = df2[feat_cols].values
335
+ return X, feat_cols, df2
336
+
337
+ def predict_ensemble(X: np.ndarray, models: Dict[str, Any]) -> Dict[str, Any]:
338
+ """
339
+ Predict next-step close using ensemble of models.
340
+ Return dict:
341
+ - per_model_preds: {name: scalar_pred}
342
+ - ensemble_mean: float
343
+ - weighted: float (weights fallback equal)
344
+ """
345
+ scaler = models.get("scaler", None)
346
+ if scaler is None:
347
+ scaler = StandardScaler()
348
+ # Use last row features to predict next
349
+ if X.ndim == 1:
350
+ X_row = X.reshape(1, -1)
351
+ else:
352
+ X_row = X[-1:, :]
353
+ # scale
354
  try:
355
+ Xs = scaler.transform(X_row)
356
+ except Exception:
357
+ # If scaler not fitted, fit on X (fallback)
358
+ try:
359
+ scaler.fit(X)
360
+ save_scaler(scaler)
361
+ Xs = scaler.transform(X_row)
362
+ except Exception:
363
+ Xs = X_row
364
+
365
+ preds = {}
366
+ # RandomForest
367
+ rf = models.get("rf", None)
368
+ if rf is not None:
369
+ try:
370
+ p = rf.predict(Xs)[0]
371
+ except Exception:
372
+ p = float(np.nan)
373
+ else:
374
+ p = float(np.nan)
375
+ preds["rf"] = float(p)
376
+
377
+ # LightGBM
378
+ if HAS_LGB and models.get("lgb", None) is not None:
379
+ try:
380
+ dmat = lgb.Dataset(Xs, free_raw_data=False)
381
+ p = models["lgb"].predict(Xs)[0]
382
+ except Exception:
383
+ p = float(np.nan)
384
+ else:
385
+ p = float(np.nan)
386
+ preds["lgb"] = float(p)
387
+
388
+ # XGBoost
389
+ if HAS_XGB and models.get("xgb", None) is not None:
390
+ try:
391
+ dm = xgb.DMatrix(Xs)
392
+ p = models["xgb"].predict(dm)[0]
393
+ except Exception:
394
+ p = float(np.nan)
395
+ else:
396
+ p = float(np.nan)
397
+ preds["xgb"] = float(p)
398
+
399
+ # LSTM (PyTorch)
400
+ if HAS_TORCH and models.get("lstm", None) is not None:
401
+ try:
402
+ model = models["lstm"]
403
+ model.eval()
404
+ with torch.no_grad():
405
+ t = torch.tensor(X_row, dtype=torch.float32).unsqueeze(0) # shape (1,1,features) if expected
406
+ # try both (1,features) or (1,seq,features)
407
+ if t.dim() == 3:
408
+ out = model(t)
409
+ else:
410
+ # reshape to (1,1,features)
411
+ t2 = t.unsqueeze(1)
412
+ out = model(t2)
413
+ p = float(out.squeeze().cpu().numpy())
414
+ except Exception:
415
+ p = float(np.nan)
416
+ else:
417
+ p = float(np.nan)
418
+ preds["lstm"] = float(p)
419
+
420
+ # If models missing, fallback: use RF or mean of last price as naive
421
+ valid_preds = [v for v in preds.values() if not (math.isnan(v) or v is None)]
422
+ if not valid_preds:
423
+ # fallback naive next-close = last close
424
+ naive = float(X_row[0, get_feature_columns().index("close")])
425
+ ensemble_mean = naive
426
+ weighted = naive
427
+ else:
428
+ ensemble_mean = float(np.nanmean(valid_preds))
429
+ # Simple weighting: prefer models that exist; equal weight
430
+ weighted = ensemble_mean
431
+
432
+ return {
433
+ "per_model": preds,
434
+ "ensemble_mean": ensemble_mean,
435
+ "weighted": weighted
436
+ }
437
+
438
+ # -------------------------
439
+ # LSTM simple architecture (for demo)
440
+ # -------------------------
441
+ if HAS_TORCH:
442
+ class SimpleLSTM(nn.Module):
443
+ def __init__(self, input_size: int, hidden_size: int = 32, num_layers: int = 1):
444
+ super().__init__()
445
+ self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
446
+ self.fc = nn.Linear(hidden_size, 1)
447
+ def forward(self, x):
448
+ # x: (batch, seq_len, input_size)
449
+ out, _ = self.lstm(x)
450
+ # take last time step
451
+ last = out[:, -1, :]
452
+ return self.fc(last)
453
+
454
+ # -------------------------
455
+ # Visualization helpers
456
+ # -------------------------
457
+ def plot_price_and_preds(df: pd.DataFrame, preds: Dict[str, Any]) -> bytes:
458
+ fig, ax = plt.subplots(figsize=(9,4))
459
+ ax.plot(df["ts"], df["close"], label="close", color="black", lw=1)
460
+ # mark last price and ensemble prediction
461
+ last_ts = df["ts"].iloc[-1]
462
+ last_close = df["close"].iloc[-1]
463
+ pred = preds.get("weighted", preds.get("ensemble_mean", last_close))
464
+ ax.scatter([last_ts + pd.Timedelta(seconds=1)], [pred], color="red", label="ensemble_pred")
465
+ ax.axhline(last_close, linestyle="--", color="gray", alpha=0.6)
466
+ ax.set_title("BTC/USDT close and ensemble prediction")
467
+ ax.set_xlabel("Time (UTC)")
468
+ ax.set_ylabel("Price")
469
+ ax.legend()
470
+ fig.tight_layout()
471
+ buf = io.BytesIO()
472
+ fig.savefig(buf, format="png")
473
+ plt.close(fig)
474
+ buf.seek(0)
475
+ return buf.read()
476
+
477
+ def plot_model_contributions(per_model: Dict[str, float]) -> bytes:
478
+ names = list(per_model.keys())
479
+ vals = [per_model[n] if (not math.isnan(per_model[n])) else 0.0 for n in names]
480
+ fig, ax = plt.subplots(figsize=(6,3))
481
+ ax.bar(names, vals, color=["#1f77b4","#ff7f0e","#2ca02c","#d62728"])
482
+ ax.set_title("Per-model predictions (abs values)")
483
+ ax.set_ylabel("Predicted price")
484
+ fig.tight_layout()
485
+ buf = io.BytesIO()
486
+ fig.savefig(buf, format="png")
487
+ plt.close(fig)
488
+ buf.seek(0)
489
+ return buf.read()
490
+
491
+ # -------------------------
492
+ # Gradio app components
493
+ # -------------------------
494
+ def inference_pipeline(inst_id: str = DEFAULT_INSTRUMENT,
495
+ bar: str = DEFAULT_BAR,
496
+ limit: int = DEFAULT_LIMIT,
497
+ show_plot: bool = True):
498
+ """
499
+ High-level function called by Gradio. Returns JSON/dicts + image bytes for display.
500
+ """
501
+ # Step 1: fetch candles
502
  try:
503
+ df = okx_candles(inst_id=inst_id, bar=bar, limit=int(limit))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  except Exception as e:
505
+ return {"error": f"Failed to fetch candles: {e}"}
506
+
507
+ # Step 2: prepare features
508
+ X, feat_cols, df_ready = prepare_inference_features(df)
509
+
510
+ # Step 3: load models
511
+ models = load_models()
512
+
513
+ # Step 4: predict
514
+ preds = predict_ensemble(X, models)
515
+
516
+ # Step 5: build result
517
+ last_close = float(df_ready["close"].iloc[-1])
518
+ ensemble = preds.get("weighted", preds.get("ensemble_mean", last_close))
519
+
520
+ out = {
521
+ "instrument": inst_id,
522
+ "bar": bar,
523
+ "fetched_candles": int(limit),
524
+ "last_ts": df_ready["ts"].iloc[-1].isoformat(),
525
+ "last_close": float(last_close),
526
+ "ensemble_prediction": float(ensemble),
527
+ "per_model": preds.get("per_model", {})
528
+ }
529
+
530
+ # Prepare images
531
+ img_price = plot_price_and_preds(df_ready, {"weighted": ensemble})
532
+ img_contrib = plot_model_contributions(out["per_model"])
533
+
534
+ return {
535
+ "result": out,
536
+ "img_price": img_price,
537
+ "img_contrib": img_contrib
538
+ }
539
+
540
+ # Helper to convert bytes to gradio displayable
541
+ def bytes_to_pil(b: bytes):
542
+ from PIL import Image
543
+ buf = io.BytesIO(b)
544
+ return Image.open(buf)
545
+
546
+ # -------------------------
547
+ # Gradio layout (blank template)
548
+ # -------------------------
549
+ def build_gradio_app():
550
+ title = "BTC/USDT Price Prediction (OKX REST) — Ensemble Demo"
551
+ description = "Fetch recent candles from OKX and predict next close using an ensemble (demo)."
552
+ with gr.Blocks(title=title) as demo:
553
+ gr.Markdown(f"## {title}")
554
+ gr.Markdown(description)
555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556
  with gr.Row():
557
+ with gr.Column(scale=1):
558
+ inst_in = gr.Textbox(label="Instrument", value=DEFAULT_INSTRUMENT)
559
+ bar_in = gr.Dropdown(label="Candle bar", choices=["1m","3m","5m","15m","1H","4H","1D"], value=DEFAULT_BAR)
560
+ limit_in = gr.Slider(label="Limit (number of candles)", minimum=50, maximum=1000, step=50, value=DEFAULT_LIMIT)
561
+ run_btn = gr.Button("Run Inference")
562
+ refresh_btn = gr.Button("Refresh Models (clear cache)")
563
+ info_out = gr.Textbox(label="Info / JSON result", interactive=False)
564
+ with gr.Column(scale=2):
565
+ price_img = gr.Image(label="Price & Prediction", type="pil")
566
+ contrib_img = gr.Image(label="Per-model predictions", type="pil")
567
+
568
+ # Callbacks
569
+ def on_run(inst, bar, limit):
570
+ res = inference_pipeline(inst, bar, limit)
571
+ if "error" in res:
572
+ return "", gr.update(value=None), gr.update(value=None), json.dumps({"error": res["error"]}, indent=2)
573
+ out = res["result"]
574
+ price_pil = bytes_to_pil(res["img_price"])
575
+ contrib_pil = bytes_to_pil(res["img_contrib"])
576
+ info_json = json.dumps(out, indent=2, default=str)
577
+ return price_pil, contrib_pil, info_json
578
+
579
+ def on_refresh():
580
+ # clear model cache and reload
581
+ with _MODEL_LOCK:
582
+ _MODELS.clear()
583
+ return "Model cache cleared."
584
+
585
+ run_btn.click(on_run, inputs=[inst_in, bar_in, limit_in], outputs=[price_img, contrib_img, info_out])
586
+ refresh_btn.click(on_refresh, inputs=None, outputs=info_out)
587
+
588
+ gr.Markdown("Notes: This demo uses public OKX market endpoints. For production, validate rate limits and handle API keys for private data. Ensemble models here are demo-friendly; train and persist stronger models for real use.")
589
+ return demo
590
+
591
+ # -------------------------
592
+ # If run as app
593
+ # -------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
  if __name__ == "__main__":
595
+ app = build_gradio_app()
596
+ app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", ave)