Alvin3y1 commited on
Commit
aad8eb4
·
verified ·
1 Parent(s): 6db9908

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -22
app.py CHANGED
@@ -8,12 +8,12 @@ import numpy as np
8
  from aiohttp import web
9
  import websockets
10
  from sklearn.ensemble import RandomForestRegressor
11
- from sklearn.model_selection import train_test_split
12
 
 
13
  SYMBOL_KRAKEN = "BTC/USD"
14
  PORT = 7860
15
- BROADCAST_RATE = 1.0 # Slowed down slightly for ML inference time
16
- PREDICTION_HORIZON = 100 # Predict next 100 minutes
17
 
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
19
 
@@ -83,22 +83,26 @@ def train_model(df):
83
  logging.info("Training ML Model...")
84
 
85
  # 1. Prepare Features (X)
86
- # We use the indicators we calculated as features
87
  feature_cols = ['close', 'ema', 'bb_upper', 'bb_lower', 'rsi', 'macd', 'stoch_k', 'atr', 'obv', 'vwap']
88
 
89
- # Drop NaNs created by indicators
90
  data = df.dropna().copy()
91
 
92
- # Create Targets (y)
93
- # y is the Close price shifted backwards by 1 to 100 steps
94
- # If t is now, we want to predict Close at t+1, t+2... t+100
95
  targets = []
 
96
  for i in range(1, PREDICTION_HORIZON + 1):
97
  col_name = f'target_{i}'
98
- data[col_name] = data['close'].shift(-i)
99
  targets.append(col_name)
 
 
 
 
100
 
101
- # Drop the end rows where we don't have future data yet
102
  data = data.dropna()
103
 
104
  if len(data) < 100:
@@ -109,7 +113,6 @@ def train_model(df):
109
  y = data[targets].values
110
 
111
  # Train Random Forest
112
- # n_estimators=50 is kept low for speed in this demo. Increase for accuracy.
113
  model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
114
  model.fit(X, y)
115
 
@@ -127,14 +130,14 @@ def get_prediction(df, model):
127
  if last_row.isnull().values.any(): return []
128
 
129
  # Predict
130
- prediction = model.predict(last_row.values)[0] # Returns array of 100 values
131
 
132
  # Format for frontend
133
  current_time = int(df.iloc[-1]['time'])
134
  pred_data = []
135
  for i, price in enumerate(prediction):
136
  pred_data.append({
137
- "time": current_time + ((i + 1) * 60), # Add minutes
138
  "value": float(price)
139
  })
140
 
@@ -149,7 +152,7 @@ def process_market_data():
149
  if df is None or len(df) < 50: return {"error": "Not enough data"}
150
 
151
  # 2. Train Model (Periodically)
152
- # Train initially or every 15 minutes to adapt to new trends
153
  if market_state['model'] is None or (time.time() - market_state['last_training_time'] > 900):
154
  market_state['model'] = train_model(df)
155
  market_state['last_training_time'] = time.time()
@@ -222,7 +225,6 @@ HTML_PAGE = f"""
222
  const mainChart = LightweightCharts.createChart(mainEl, commonOpts);
223
  const candles = mainChart.addCandlestickSeries({{ upColor: '#00ff9d', downColor: '#ff3b3b', borderVisible: false }});
224
  const ema = mainChart.addLineSeries({{ color: '#2962FF', lineWidth: 1 }});
225
- // Prediction Line
226
  const predLine = mainChart.addLineSeries({{ color: '#bf5af2', lineWidth: 2, lineStyle: 2, title: 'AI Forecast' }});
227
 
228
  const oscChart = LightweightCharts.createChart(oscEl, commonOpts);
@@ -258,7 +260,6 @@ HTML_PAGE = f"""
258
  ema.setData(mapData('ema'));
259
  rsi.setData(mapData('rsi'));
260
 
261
- // Prediction Data (Future)
262
  if(payload.prediction && payload.prediction.length > 0) {{
263
  predLine.setData(payload.prediction);
264
  }}
@@ -283,8 +284,6 @@ HTML_PAGE = f"""
283
  async def kraken_worker():
284
  global market_state
285
  try:
286
- # 1. REST Snapshot (Get MORE data for training)
287
- # 720 is roughly Kraken's max per request for 1m timeframe
288
  async with aiohttp.ClientSession() as session:
289
  url = "https://api.kraken.com/0/public/OHLC?pair=XBTUSD&interval=1"
290
  async with session.get(url) as response:
@@ -294,7 +293,6 @@ async def kraken_worker():
294
  for key in data['result']:
295
  if key != 'last':
296
  raw = data['result'][key]
297
- # Keep last 720 candles for training
298
  market_state['ohlc_history'] = [
299
  {
300
  'time': int(c[0]),
@@ -311,7 +309,7 @@ async def kraken_worker():
311
  except Exception as e:
312
  logging.error(f"Init Error: {e}")
313
 
314
- # 2. WebSocket Stream
315
  while True:
316
  try:
317
  async with websockets.connect("wss://ws.kraken.com/v2") as ws:
@@ -347,7 +345,6 @@ async def kraken_worker():
347
  'close': price,
348
  'volume': vol
349
  })
350
- # Keep buffer size larger for ML
351
  if len(market_state['ohlc_history']) > 800:
352
  market_state['ohlc_history'].pop(0)
353
  except: pass
@@ -379,7 +376,6 @@ async def kraken_worker():
379
  async def broadcast_worker():
380
  while True:
381
  if connected_clients and market_state['ready']:
382
- # This calculation might take 100-200ms depending on CPU
383
  payload = process_market_data()
384
  if payload and "data" in payload:
385
  msg = json.dumps(payload)
 
8
  from aiohttp import web
9
  import websockets
10
  from sklearn.ensemble import RandomForestRegressor
 
11
 
12
+ # Configuration
13
  SYMBOL_KRAKEN = "BTC/USD"
14
  PORT = 7860
15
+ BROADCAST_RATE = 1.0
16
+ PREDICTION_HORIZON = 100
17
 
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
19
 
 
83
  logging.info("Training ML Model...")
84
 
85
  # 1. Prepare Features (X)
 
86
  feature_cols = ['close', 'ema', 'bb_upper', 'bb_lower', 'rsi', 'macd', 'stoch_k', 'atr', 'obv', 'vwap']
87
 
88
+ # Create a clean copy for training data
89
  data = df.dropna().copy()
90
 
91
+ # 2. Create Targets (y) - OPTIMIZED to fix Fragmentation Warning
92
+ # Instead of adding columns in a loop, we create a dict and concat once
93
+ future_shifts = {}
94
  targets = []
95
+
96
  for i in range(1, PREDICTION_HORIZON + 1):
97
  col_name = f'target_{i}'
98
+ future_shifts[col_name] = data['close'].shift(-i)
99
  targets.append(col_name)
100
+
101
+ # Concatenate all target columns at once
102
+ target_df = pd.DataFrame(future_shifts, index=data.index)
103
+ data = pd.concat([data, target_df], axis=1)
104
 
105
+ # Drop rows where we don't have future data (the last 100 rows)
106
  data = data.dropna()
107
 
108
  if len(data) < 100:
 
113
  y = data[targets].values
114
 
115
  # Train Random Forest
 
116
  model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
117
  model.fit(X, y)
118
 
 
130
  if last_row.isnull().values.any(): return []
131
 
132
  # Predict
133
+ prediction = model.predict(last_row.values)[0]
134
 
135
  # Format for frontend
136
  current_time = int(df.iloc[-1]['time'])
137
  pred_data = []
138
  for i, price in enumerate(prediction):
139
  pred_data.append({
140
+ "time": current_time + ((i + 1) * 60),
141
  "value": float(price)
142
  })
143
 
 
152
  if df is None or len(df) < 50: return {"error": "Not enough data"}
153
 
154
  # 2. Train Model (Periodically)
155
+ # Train initially or every 15 minutes (900 seconds)
156
  if market_state['model'] is None or (time.time() - market_state['last_training_time'] > 900):
157
  market_state['model'] = train_model(df)
158
  market_state['last_training_time'] = time.time()
 
225
  const mainChart = LightweightCharts.createChart(mainEl, commonOpts);
226
  const candles = mainChart.addCandlestickSeries({{ upColor: '#00ff9d', downColor: '#ff3b3b', borderVisible: false }});
227
  const ema = mainChart.addLineSeries({{ color: '#2962FF', lineWidth: 1 }});
 
228
  const predLine = mainChart.addLineSeries({{ color: '#bf5af2', lineWidth: 2, lineStyle: 2, title: 'AI Forecast' }});
229
 
230
  const oscChart = LightweightCharts.createChart(oscEl, commonOpts);
 
260
  ema.setData(mapData('ema'));
261
  rsi.setData(mapData('rsi'));
262
 
 
263
  if(payload.prediction && payload.prediction.length > 0) {{
264
  predLine.setData(payload.prediction);
265
  }}
 
284
  async def kraken_worker():
285
  global market_state
286
  try:
 
 
287
  async with aiohttp.ClientSession() as session:
288
  url = "https://api.kraken.com/0/public/OHLC?pair=XBTUSD&interval=1"
289
  async with session.get(url) as response:
 
293
  for key in data['result']:
294
  if key != 'last':
295
  raw = data['result'][key]
 
296
  market_state['ohlc_history'] = [
297
  {
298
  'time': int(c[0]),
 
309
  except Exception as e:
310
  logging.error(f"Init Error: {e}")
311
 
312
+ # WebSocket Stream
313
  while True:
314
  try:
315
  async with websockets.connect("wss://ws.kraken.com/v2") as ws:
 
345
  'close': price,
346
  'volume': vol
347
  })
 
348
  if len(market_state['ohlc_history']) > 800:
349
  market_state['ohlc_history'].pop(0)
350
  except: pass
 
376
  async def broadcast_worker():
377
  while True:
378
  if connected_clients and market_state['ready']:
 
379
  payload = process_market_data()
380
  if payload and "data" in payload:
381
  msg = json.dumps(payload)