Alvin3y1 commited on
Commit
ad70443
·
verified ·
1 Parent(s): e96c872

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -26
app.py CHANGED
@@ -8,7 +8,7 @@ import pandas as pd
8
  import numpy as np
9
  from aiohttp import web
10
  from sklearn.ensemble import RandomForestRegressor
11
- from sklearn.metrics import mean_squared_error
12
 
13
  SYMBOL_KRAKEN = "BTC/USD"
14
  PORT = 7860
@@ -30,6 +30,7 @@ market_state = {
30
  }
31
 
32
  connected_clients = set()
 
33
 
34
  def calculate_indicators(candles):
35
  if len(candles) < 100:
@@ -72,7 +73,11 @@ def calculate_indicators(candles):
72
  df['vol_change'] = df['volume'].pct_change()
73
  df['log_ret'] = np.log(df['close'] / df['close'].shift(1))
74
 
75
- for lag in [1, 2, 3]:
 
 
 
 
76
  df[f'rsi_lag{lag}'] = df['rsi'].shift(lag)
77
  df[f'macd_hist_lag{lag}'] = df['macd_hist'].shift(lag)
78
  df[f'log_ret_lag{lag}'] = df['log_ret'].shift(lag)
@@ -88,32 +93,38 @@ def train_model(df):
88
  'dist_ema20', 'dist_ema50',
89
  'bb_width', 'bb_pos',
90
  'vol_change', 'log_ret',
91
- 'rsi_lag1', 'rsi_lag2', 'rsi_lag3',
92
- 'macd_hist_lag1', 'macd_hist_lag2', 'macd_hist_lag3',
93
- 'log_ret_lag1', 'log_ret_lag2', 'log_ret_lag3',
94
- 'vol_change_lag1', 'vol_change_lag2', 'vol_change_lag3'
95
  ]
96
 
 
 
 
 
 
 
97
  data = df.dropna().copy()
98
- targets = []
 
 
99
 
100
  for i in range(1, PREDICTION_HORIZON + 1):
101
  col_name = f'target_return_{i}'
102
- data[col_name] = (data['close'].shift(-i) - data['close']) / data['close']
103
- targets.append(col_name)
104
-
105
- data = data.dropna()
 
106
 
107
  if len(data) < 200:
108
  return None, None
109
 
110
  X = data[feature_cols].values
111
- y = data[targets].values
112
 
113
  model = RandomForestRegressor(
114
- n_estimators=150,
115
- max_depth=20,
116
- min_samples_split=4,
117
  min_samples_leaf=2,
118
  max_features='sqrt',
119
  n_jobs=-1,
@@ -136,12 +147,15 @@ def get_prediction(df, model, residual_std):
136
  'dist_ema20', 'dist_ema50',
137
  'bb_width', 'bb_pos',
138
  'vol_change', 'log_ret',
139
- 'rsi_lag1', 'rsi_lag2', 'rsi_lag3',
140
- 'macd_hist_lag1', 'macd_hist_lag2', 'macd_hist_lag3',
141
- 'log_ret_lag1', 'log_ret_lag2', 'log_ret_lag3',
142
- 'vol_change_lag1', 'vol_change_lag2', 'vol_change_lag3'
143
  ]
144
 
 
 
 
 
 
 
145
  last_row = df.iloc[[-1]][feature_cols]
146
 
147
  if last_row.isnull().values.any():
@@ -171,7 +185,7 @@ def get_prediction(df, model, residual_std):
171
 
172
  return pred_data
173
 
174
- def process_market_data():
175
  if not market_state['ready'] or not market_state['ohlc_history']:
176
  return {"error": "Initializing..."}
177
 
@@ -181,7 +195,8 @@ def process_market_data():
181
 
182
  if market_state['model'] is None or (time.time() - market_state['last_training_time'] > TRAIN_INTERVAL):
183
  try:
184
- model, res_std = train_model(df)
 
185
  if model is not None:
186
  market_state['model'] = model
187
  market_state['model_residuals'] = res_std
@@ -196,7 +211,8 @@ def process_market_data():
196
  logging.error(f"Prediction failed: {e}")
197
 
198
  df_clean = df.replace([np.inf, -np.inf], np.nan)
199
- df_clean = df_clean.astype(object).where(pd.notnull(df_clean), None)
 
200
 
201
  last_close = float(df['close'].iloc[-1]) if len(df) > 0 else 0
202
  first_close = float(df['close'].iloc[0]) if len(df) > 0 else 0
@@ -479,7 +495,7 @@ document.addEventListener('DOMContentLoaded', () => {
479
  document.getElementById('ema-val').textContent = lastData.ema20 ? lastData.ema20.toFixed(2) : '--';
480
  document.getElementById('bb-upper').textContent = lastData.bb_upper ? lastData.bb_upper.toFixed(2) : '--';
481
  document.getElementById('bb-lower').textContent = lastData.bb_lower ? lastData.bb_lower.toFixed(2) : '--';
482
- const macdVal = lastData.macd;
483
  const macdEl = document.getElementById('macd-val');
484
  if (macdVal !== null && macdVal !== undefined) {
485
  macdEl.textContent = macdVal.toFixed(2);
@@ -633,16 +649,18 @@ async def kraken_rest_worker():
633
  async def broadcast_worker():
634
  while True:
635
  if connected_clients and market_state['ready']:
636
- payload = process_market_data()
637
  if payload and "data" in payload:
638
  msg = json.dumps(payload)
 
639
  disconnected = set()
640
- for ws in connected_clients:
641
  try:
642
  await ws.send_str(msg)
643
  except Exception:
644
  disconnected.add(ws)
645
- connected_clients.difference_update(disconnected)
 
646
  await asyncio.sleep(BROADCAST_RATE)
647
 
648
  async def websocket_handler(request):
 
8
  import numpy as np
9
  from aiohttp import web
10
  from sklearn.ensemble import RandomForestRegressor
11
+ from concurrent.futures import ThreadPoolExecutor
12
 
13
  SYMBOL_KRAKEN = "BTC/USD"
14
  PORT = 7860
 
30
  }
31
 
32
  connected_clients = set()
33
+ executor = ThreadPoolExecutor(max_workers=1)
34
 
35
  def calculate_indicators(candles):
36
  if len(candles) < 100:
 
73
  df['vol_change'] = df['volume'].pct_change()
74
  df['log_ret'] = np.log(df['close'] / df['close'].shift(1))
75
 
76
+ df['datetime'] = pd.to_datetime(df['time'], unit='s')
77
+ df['hour_sin'] = np.sin(2 * np.pi * df['datetime'].dt.hour / 24)
78
+ df['hour_cos'] = np.cos(2 * np.pi * df['datetime'].dt.hour / 24)
79
+
80
+ for lag in [1, 2, 3, 5, 8]:
81
  df[f'rsi_lag{lag}'] = df['rsi'].shift(lag)
82
  df[f'macd_hist_lag{lag}'] = df['macd_hist'].shift(lag)
83
  df[f'log_ret_lag{lag}'] = df['log_ret'].shift(lag)
 
93
  'dist_ema20', 'dist_ema50',
94
  'bb_width', 'bb_pos',
95
  'vol_change', 'log_ret',
96
+ 'hour_sin', 'hour_cos'
 
 
 
97
  ]
98
 
99
+ for lag in [1, 2, 3, 5, 8]:
100
+ feature_cols.extend([
101
+ f'rsi_lag{lag}', f'macd_hist_lag{lag}',
102
+ f'log_ret_lag{lag}', f'vol_change_lag{lag}'
103
+ ])
104
+
105
  data = df.dropna().copy()
106
+
107
+ target_cols_dict = {}
108
+ target_names = []
109
 
110
  for i in range(1, PREDICTION_HORIZON + 1):
111
  col_name = f'target_return_{i}'
112
+ target_cols_dict[col_name] = (data['close'].shift(-i) - data['close']) / data['close']
113
+ target_names.append(col_name)
114
+
115
+ targets_df = pd.DataFrame(target_cols_dict, index=data.index)
116
+ data = pd.concat([data, targets_df], axis=1).dropna()
117
 
118
  if len(data) < 200:
119
  return None, None
120
 
121
  X = data[feature_cols].values
122
+ y = data[target_names].values
123
 
124
  model = RandomForestRegressor(
125
+ n_estimators=200,
126
+ max_depth=25,
127
+ min_samples_split=5,
128
  min_samples_leaf=2,
129
  max_features='sqrt',
130
  n_jobs=-1,
 
147
  'dist_ema20', 'dist_ema50',
148
  'bb_width', 'bb_pos',
149
  'vol_change', 'log_ret',
150
+ 'hour_sin', 'hour_cos'
 
 
 
151
  ]
152
 
153
+ for lag in [1, 2, 3, 5, 8]:
154
+ feature_cols.extend([
155
+ f'rsi_lag{lag}', f'macd_hist_lag{lag}',
156
+ f'log_ret_lag{lag}', f'vol_change_lag{lag}'
157
+ ])
158
+
159
  last_row = df.iloc[[-1]][feature_cols]
160
 
161
  if last_row.isnull().values.any():
 
185
 
186
  return pred_data
187
 
188
+ async def process_market_data():
189
  if not market_state['ready'] or not market_state['ohlc_history']:
190
  return {"error": "Initializing..."}
191
 
 
195
 
196
  if market_state['model'] is None or (time.time() - market_state['last_training_time'] > TRAIN_INTERVAL):
197
  try:
198
+ loop = asyncio.get_running_loop()
199
+ model, res_std = await loop.run_in_executor(executor, train_model, df)
200
  if model is not None:
201
  market_state['model'] = model
202
  market_state['model_residuals'] = res_std
 
211
  logging.error(f"Prediction failed: {e}")
212
 
213
  df_clean = df.replace([np.inf, -np.inf], np.nan)
214
+ cols_to_keep = ['time', 'open', 'high', 'low', 'close', 'volume', 'ema20', 'bb_upper', 'bb_lower', 'rsi', 'macd_hist']
215
+ df_clean = df_clean[cols_to_keep].where(pd.notnull(df_clean), None)
216
 
217
  last_close = float(df['close'].iloc[-1]) if len(df) > 0 else 0
218
  first_close = float(df['close'].iloc[0]) if len(df) > 0 else 0
 
495
  document.getElementById('ema-val').textContent = lastData.ema20 ? lastData.ema20.toFixed(2) : '--';
496
  document.getElementById('bb-upper').textContent = lastData.bb_upper ? lastData.bb_upper.toFixed(2) : '--';
497
  document.getElementById('bb-lower').textContent = lastData.bb_lower ? lastData.bb_lower.toFixed(2) : '--';
498
+ const macdVal = lastData.macd_hist;
499
  const macdEl = document.getElementById('macd-val');
500
  if (macdVal !== null && macdVal !== undefined) {
501
  macdEl.textContent = macdVal.toFixed(2);
 
649
  async def broadcast_worker():
650
  while True:
651
  if connected_clients and market_state['ready']:
652
+ payload = await process_market_data()
653
  if payload and "data" in payload:
654
  msg = json.dumps(payload)
655
+ current_clients = connected_clients.copy()
656
  disconnected = set()
657
+ for ws in current_clients:
658
  try:
659
  await ws.send_str(msg)
660
  except Exception:
661
  disconnected.add(ws)
662
+ if disconnected:
663
+ connected_clients.difference_update(disconnected)
664
  await asyncio.sleep(BROADCAST_RATE)
665
 
666
  async def websocket_handler(request):