Alvin3y1 commited on
Commit
a4227fa
·
verified ·
1 Parent(s): aad8eb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -43
app.py CHANGED
@@ -9,7 +9,7 @@ from aiohttp import web
9
  import websockets
10
  from sklearn.ensemble import RandomForestRegressor
11
 
12
- # Configuration
13
  SYMBOL_KRAKEN = "BTC/USD"
14
  PORT = 7860
15
  BROADCAST_RATE = 1.0
@@ -82,14 +82,12 @@ def calculate_indicators(candles):
82
  def train_model(df):
83
  logging.info("Training ML Model...")
84
 
85
- # 1. Prepare Features (X)
86
  feature_cols = ['close', 'ema', 'bb_upper', 'bb_lower', 'rsi', 'macd', 'stoch_k', 'atr', 'obv', 'vwap']
87
 
88
- # Create a clean copy for training data
89
  data = df.dropna().copy()
90
 
91
- # 2. Create Targets (y) - OPTIMIZED to fix Fragmentation Warning
92
- # Instead of adding columns in a loop, we create a dict and concat once
93
  future_shifts = {}
94
  targets = []
95
 
@@ -98,11 +96,9 @@ def train_model(df):
98
  future_shifts[col_name] = data['close'].shift(-i)
99
  targets.append(col_name)
100
 
101
- # Concatenate all target columns at once
102
  target_df = pd.DataFrame(future_shifts, index=data.index)
103
  data = pd.concat([data, target_df], axis=1)
104
 
105
- # Drop rows where we don't have future data (the last 100 rows)
106
  data = data.dropna()
107
 
108
  if len(data) < 100:
@@ -122,17 +118,13 @@ def train_model(df):
122
  def get_prediction(df, model):
123
  if model is None: return []
124
 
125
- # Get the very last row of data (current market state)
126
  feature_cols = ['close', 'ema', 'bb_upper', 'bb_lower', 'rsi', 'macd', 'stoch_k', 'atr', 'obv', 'vwap']
127
  last_row = df.iloc[[-1]][feature_cols]
128
 
129
- # Check for NaNs
130
  if last_row.isnull().values.any(): return []
131
 
132
- # Predict
133
  prediction = model.predict(last_row.values)[0]
134
 
135
- # Format for frontend
136
  current_time = int(df.iloc[-1]['time'])
137
  pred_data = []
138
  for i, price in enumerate(prediction):
@@ -152,16 +144,24 @@ def process_market_data():
152
  if df is None or len(df) < 50: return {"error": "Not enough data"}
153
 
154
  # 2. Train Model (Periodically)
155
- # Train initially or every 15 minutes (900 seconds)
156
  if market_state['model'] is None or (time.time() - market_state['last_training_time'] > 900):
157
- market_state['model'] = train_model(df)
158
- market_state['last_training_time'] = time.time()
 
 
 
159
 
160
  # 3. Get Prediction
161
- predictions = get_prediction(df, market_state['model'])
 
 
 
 
162
 
163
- # 4. Prepare JSON
164
- full_data = df.where(pd.notnull(df), None).to_dict('records')
 
 
165
 
166
  return {
167
  "data": full_data,
@@ -190,21 +190,21 @@ HTML_PAGE = f"""
190
  <body>
191
  <div class="header">
192
  <span style="color:#00e676">{SYMBOL_KRAKEN} + Random Forest (Next 100 Candles)</span>
193
- <span id="clock" style="color:#888">Initializing...</span>
194
  </div>
195
 
196
  <div id="charts-container">
197
  <div id="main-chart" class="chart-row">
198
  <div class="legend">
199
  <span class="l-item" style="color:#00ff9d">Price</span>
200
- <span class="l-item" style="color:#bf5af2">AI Prediction</span>
201
  <span class="l-item" style="color:#2962FF">EMA</span>
202
  </div>
203
  </div>
204
  <div id="osc-chart" class="chart-row">
205
  <div class="legend">
206
  <span class="l-item" style="color:#9C27B0">RSI</span>
207
- <span class="l-item" style="color:#00BCD4">MACD</span>
208
  </div>
209
  </div>
210
  </div>
@@ -222,16 +222,18 @@ HTML_PAGE = f"""
222
  crosshair: {{ mode: 1 }}
223
  }};
224
 
 
225
  const mainChart = LightweightCharts.createChart(mainEl, commonOpts);
226
  const candles = mainChart.addCandlestickSeries({{ upColor: '#00ff9d', downColor: '#ff3b3b', borderVisible: false }});
227
  const ema = mainChart.addLineSeries({{ color: '#2962FF', lineWidth: 1 }});
228
- const predLine = mainChart.addLineSeries({{ color: '#bf5af2', lineWidth: 2, lineStyle: 2, title: 'AI Forecast' }});
229
 
230
  const oscChart = LightweightCharts.createChart(oscEl, commonOpts);
231
  const rsi = oscChart.addLineSeries({{ color: '#9C27B0', lineWidth: 1 }});
232
  const macdHist = oscChart.addHistogramSeries({{ priceScaleId: 'macd', color: '#2962FF' }});
233
  oscChart.priceScale('macd').applyOptions({{ scaleMargins: {{ top: 0.8, bottom: 0 }} }});
234
 
 
235
  new ResizeObserver(entries => {{
236
  for (let e of entries) {{
237
  if(e.target === mainEl) mainChart.applyOptions({{ width: e.contentRect.width, height: e.contentRect.height }});
@@ -239,6 +241,7 @@ HTML_PAGE = f"""
239
  }}
240
  }}).observe(document.body);
241
 
 
242
  function syncCharts(source, targets) {{
243
  source.timeScale().subscribeVisibleLogicalRangeChange(range => {{
244
  targets.forEach(t => t.timeScale().setVisibleLogicalRange(range));
@@ -247,33 +250,50 @@ HTML_PAGE = f"""
247
  syncCharts(mainChart, [oscChart]);
248
  syncCharts(oscChart, [mainChart]);
249
 
 
250
  function connect() {{
251
  const ws = new WebSocket((location.protocol === 'https:' ? 'wss' : 'ws') + '://' + location.host + '/ws');
 
252
  ws.onmessage = (e) => {{
253
- const payload = JSON.parse(e.data);
254
- if (!payload.data) return;
255
-
256
- const d = payload.data;
257
- const mapData = (key) => d.map(x => ({{ time: x.time, value: x[key] }})).filter(x => x.value !== null);
258
-
259
- candles.setData(d.map(x => ({{ time: x.time, open: x.open, high: x.high, low: x.low, close: x.close }})));
260
- ema.setData(mapData('ema'));
261
- rsi.setData(mapData('rsi'));
262
-
263
- if(payload.prediction && payload.prediction.length > 0) {{
264
- predLine.setData(payload.prediction);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  }}
266
-
267
- macdHist.setData(d.map(x => ({{
268
- time: x.time,
269
- value: x.macd_hist || 0,
270
- color: (x.macd_hist||0) >= 0 ? '#26a69a' : '#ef5350'
271
- }})));
272
-
273
- document.getElementById('clock').innerText = new Date().toISOString().split('T')[1].split('.')[0] + ' UTC';
274
  }};
275
- ws.onclose = () => setTimeout(connect, 2000);
 
 
 
 
276
  }}
 
277
  connect();
278
  }});
279
  </script>
@@ -309,7 +329,6 @@ async def kraken_worker():
309
  except Exception as e:
310
  logging.error(f"Init Error: {e}")
311
 
312
- # WebSocket Stream
313
  while True:
314
  try:
315
  async with websockets.connect("wss://ws.kraken.com/v2") as ws:
 
9
  import websockets
10
  from sklearn.ensemble import RandomForestRegressor
11
 
12
+ # --- Configuration ---
13
  SYMBOL_KRAKEN = "BTC/USD"
14
  PORT = 7860
15
  BROADCAST_RATE = 1.0
 
82
  def train_model(df):
83
  logging.info("Training ML Model...")
84
 
 
85
  feature_cols = ['close', 'ema', 'bb_upper', 'bb_lower', 'rsi', 'macd', 'stoch_k', 'atr', 'obv', 'vwap']
86
 
87
+ # Clean data for training
88
  data = df.dropna().copy()
89
 
90
+ # Create Targets efficiently (fix for fragmentation warning)
 
91
  future_shifts = {}
92
  targets = []
93
 
 
96
  future_shifts[col_name] = data['close'].shift(-i)
97
  targets.append(col_name)
98
 
 
99
  target_df = pd.DataFrame(future_shifts, index=data.index)
100
  data = pd.concat([data, target_df], axis=1)
101
 
 
102
  data = data.dropna()
103
 
104
  if len(data) < 100:
 
118
  def get_prediction(df, model):
119
  if model is None: return []
120
 
 
121
  feature_cols = ['close', 'ema', 'bb_upper', 'bb_lower', 'rsi', 'macd', 'stoch_k', 'atr', 'obv', 'vwap']
122
  last_row = df.iloc[[-1]][feature_cols]
123
 
 
124
  if last_row.isnull().values.any(): return []
125
 
 
126
  prediction = model.predict(last_row.values)[0]
127
 
 
128
  current_time = int(df.iloc[-1]['time'])
129
  pred_data = []
130
  for i, price in enumerate(prediction):
 
144
  if df is None or len(df) < 50: return {"error": "Not enough data"}
145
 
146
  # 2. Train Model (Periodically)
 
147
  if market_state['model'] is None or (time.time() - market_state['last_training_time'] > 900):
148
+ try:
149
+ market_state['model'] = train_model(df)
150
+ market_state['last_training_time'] = time.time()
151
+ except Exception as e:
152
+ logging.error(f"Training failed: {e}")
153
 
154
  # 3. Get Prediction
155
+ predictions = []
156
+ try:
157
+ predictions = get_prediction(df, market_state['model'])
158
+ except Exception as e:
159
+ logging.error(f"Prediction failed: {e}")
160
 
161
+ # 4. Clean Data for JSON (Remove Infinity/NaN)
162
+ # This prevents the "blank graph" issue caused by invalid JSON
163
+ df_clean = df.replace([np.inf, -np.inf], np.nan)
164
+ full_data = df_clean.where(pd.notnull(df_clean), None).to_dict('records')
165
 
166
  return {
167
  "data": full_data,
 
190
  <body>
191
  <div class="header">
192
  <span style="color:#00e676">{SYMBOL_KRAKEN} + Random Forest (Next 100 Candles)</span>
193
+ <span id="clock" style="color:#888">Connecting...</span>
194
  </div>
195
 
196
  <div id="charts-container">
197
  <div id="main-chart" class="chart-row">
198
  <div class="legend">
199
  <span class="l-item" style="color:#00ff9d">Price</span>
200
+ <span class="l-item" style="color:#bf5af2">AI Forecast</span>
201
  <span class="l-item" style="color:#2962FF">EMA</span>
202
  </div>
203
  </div>
204
  <div id="osc-chart" class="chart-row">
205
  <div class="legend">
206
  <span class="l-item" style="color:#9C27B0">RSI</span>
207
+ <span class="l-item" style="color:#26a69a">MACD</span>
208
  </div>
209
  </div>
210
  </div>
 
222
  crosshair: {{ mode: 1 }}
223
  }};
224
 
225
+ // 1. Initialize Charts
226
  const mainChart = LightweightCharts.createChart(mainEl, commonOpts);
227
  const candles = mainChart.addCandlestickSeries({{ upColor: '#00ff9d', downColor: '#ff3b3b', borderVisible: false }});
228
  const ema = mainChart.addLineSeries({{ color: '#2962FF', lineWidth: 1 }});
229
+ const predLine = mainChart.addLineSeries({{ color: '#bf5af2', lineWidth: 2, lineStyle: 2 }});
230
 
231
  const oscChart = LightweightCharts.createChart(oscEl, commonOpts);
232
  const rsi = oscChart.addLineSeries({{ color: '#9C27B0', lineWidth: 1 }});
233
  const macdHist = oscChart.addHistogramSeries({{ priceScaleId: 'macd', color: '#2962FF' }});
234
  oscChart.priceScale('macd').applyOptions({{ scaleMargins: {{ top: 0.8, bottom: 0 }} }});
235
 
236
+ // 2. Responsive Resize
237
  new ResizeObserver(entries => {{
238
  for (let e of entries) {{
239
  if(e.target === mainEl) mainChart.applyOptions({{ width: e.contentRect.width, height: e.contentRect.height }});
 
241
  }}
242
  }}).observe(document.body);
243
 
244
+ // 3. Sync Time Scales
245
  function syncCharts(source, targets) {{
246
  source.timeScale().subscribeVisibleLogicalRangeChange(range => {{
247
  targets.forEach(t => t.timeScale().setVisibleLogicalRange(range));
 
250
  syncCharts(mainChart, [oscChart]);
251
  syncCharts(oscChart, [mainChart]);
252
 
253
+ // 4. WebSocket Logic
254
  function connect() {{
255
  const ws = new WebSocket((location.protocol === 'https:' ? 'wss' : 'ws') + '://' + location.host + '/ws');
256
+
257
  ws.onmessage = (e) => {{
258
+ try {{
259
+ const payload = JSON.parse(e.data);
260
+ if (!payload.data) return;
261
+
262
+ const d = payload.data;
263
+
264
+ // Helper to map safely (avoids undefined/null crashes)
265
+ const mapData = (key) => d
266
+ .map(x => ({{ time: x.time, value: x[key] }}))
267
+ .filter(x => x.value !== null && x.value !== undefined);
268
+
269
+ // Set Data
270
+ candles.setData(d.map(x => ({{ time: x.time, open: x.open, high: x.high, low: x.low, close: x.close }})));
271
+ ema.setData(mapData('ema'));
272
+ rsi.setData(mapData('rsi'));
273
+
274
+ if(payload.prediction && payload.prediction.length > 0) {{
275
+ predLine.setData(payload.prediction);
276
+ }}
277
+
278
+ macdHist.setData(d.map(x => ({{
279
+ time: x.time,
280
+ value: x.macd_hist || 0,
281
+ color: (x.macd_hist||0) >= 0 ? '#26a69a' : '#ef5350'
282
+ }})));
283
+
284
+ // Update Clock
285
+ document.getElementById('clock').innerText = new Date().toISOString().split('T')[1].split('.')[0] + ' UTC';
286
+ }} catch (err) {{
287
+ console.error("Chart Render Error:", err);
288
  }}
 
 
 
 
 
 
 
 
289
  }};
290
+
291
+ ws.onclose = () => {{
292
+ document.getElementById('clock').innerText = "Disconnected. Retrying...";
293
+ setTimeout(connect, 2000);
294
+ }};
295
  }}
296
+
297
  connect();
298
  }});
299
  </script>
 
329
  except Exception as e:
330
  logging.error(f"Init Error: {e}")
331
 
 
332
  while True:
333
  try:
334
  async with websockets.connect("wss://ws.kraken.com/v2") as ws: