Update app.py
Browse files
app.py
CHANGED
|
@@ -8,12 +8,12 @@ import numpy as np
|
|
| 8 |
from aiohttp import web
|
| 9 |
import websockets
|
| 10 |
from sklearn.ensemble import RandomForestRegressor
|
| 11 |
-
from sklearn.model_selection import train_test_split
|
| 12 |
|
|
|
|
| 13 |
SYMBOL_KRAKEN = "BTC/USD"
|
| 14 |
PORT = 7860
|
| 15 |
-
BROADCAST_RATE = 1.0
|
| 16 |
-
PREDICTION_HORIZON = 100
|
| 17 |
|
| 18 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
|
| 19 |
|
|
@@ -83,22 +83,26 @@ def train_model(df):
|
|
| 83 |
logging.info("Training ML Model...")
|
| 84 |
|
| 85 |
# 1. Prepare Features (X)
|
| 86 |
-
# We use the indicators we calculated as features
|
| 87 |
feature_cols = ['close', 'ema', 'bb_upper', 'bb_lower', 'rsi', 'macd', 'stoch_k', 'atr', 'obv', 'vwap']
|
| 88 |
|
| 89 |
-
#
|
| 90 |
data = df.dropna().copy()
|
| 91 |
|
| 92 |
-
# Create Targets (y)
|
| 93 |
-
#
|
| 94 |
-
|
| 95 |
targets = []
|
|
|
|
| 96 |
for i in range(1, PREDICTION_HORIZON + 1):
|
| 97 |
col_name = f'target_{i}'
|
| 98 |
-
|
| 99 |
targets.append(col_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
-
# Drop
|
| 102 |
data = data.dropna()
|
| 103 |
|
| 104 |
if len(data) < 100:
|
|
@@ -109,7 +113,6 @@ def train_model(df):
|
|
| 109 |
y = data[targets].values
|
| 110 |
|
| 111 |
# Train Random Forest
|
| 112 |
-
# n_estimators=50 is kept low for speed in this demo. Increase for accuracy.
|
| 113 |
model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
|
| 114 |
model.fit(X, y)
|
| 115 |
|
|
@@ -127,14 +130,14 @@ def get_prediction(df, model):
|
|
| 127 |
if last_row.isnull().values.any(): return []
|
| 128 |
|
| 129 |
# Predict
|
| 130 |
-
prediction = model.predict(last_row.values)[0]
|
| 131 |
|
| 132 |
# Format for frontend
|
| 133 |
current_time = int(df.iloc[-1]['time'])
|
| 134 |
pred_data = []
|
| 135 |
for i, price in enumerate(prediction):
|
| 136 |
pred_data.append({
|
| 137 |
-
"time": current_time + ((i + 1) * 60),
|
| 138 |
"value": float(price)
|
| 139 |
})
|
| 140 |
|
|
@@ -149,7 +152,7 @@ def process_market_data():
|
|
| 149 |
if df is None or len(df) < 50: return {"error": "Not enough data"}
|
| 150 |
|
| 151 |
# 2. Train Model (Periodically)
|
| 152 |
-
# Train initially or every 15 minutes
|
| 153 |
if market_state['model'] is None or (time.time() - market_state['last_training_time'] > 900):
|
| 154 |
market_state['model'] = train_model(df)
|
| 155 |
market_state['last_training_time'] = time.time()
|
|
@@ -222,7 +225,6 @@ HTML_PAGE = f"""
|
|
| 222 |
const mainChart = LightweightCharts.createChart(mainEl, commonOpts);
|
| 223 |
const candles = mainChart.addCandlestickSeries({{ upColor: '#00ff9d', downColor: '#ff3b3b', borderVisible: false }});
|
| 224 |
const ema = mainChart.addLineSeries({{ color: '#2962FF', lineWidth: 1 }});
|
| 225 |
-
// Prediction Line
|
| 226 |
const predLine = mainChart.addLineSeries({{ color: '#bf5af2', lineWidth: 2, lineStyle: 2, title: 'AI Forecast' }});
|
| 227 |
|
| 228 |
const oscChart = LightweightCharts.createChart(oscEl, commonOpts);
|
|
@@ -258,7 +260,6 @@ HTML_PAGE = f"""
|
|
| 258 |
ema.setData(mapData('ema'));
|
| 259 |
rsi.setData(mapData('rsi'));
|
| 260 |
|
| 261 |
-
// Prediction Data (Future)
|
| 262 |
if(payload.prediction && payload.prediction.length > 0) {{
|
| 263 |
predLine.setData(payload.prediction);
|
| 264 |
}}
|
|
@@ -283,8 +284,6 @@ HTML_PAGE = f"""
|
|
| 283 |
async def kraken_worker():
|
| 284 |
global market_state
|
| 285 |
try:
|
| 286 |
-
# 1. REST Snapshot (Get MORE data for training)
|
| 287 |
-
# 720 is roughly Kraken's max per request for 1m timeframe
|
| 288 |
async with aiohttp.ClientSession() as session:
|
| 289 |
url = "https://api.kraken.com/0/public/OHLC?pair=XBTUSD&interval=1"
|
| 290 |
async with session.get(url) as response:
|
|
@@ -294,7 +293,6 @@ async def kraken_worker():
|
|
| 294 |
for key in data['result']:
|
| 295 |
if key != 'last':
|
| 296 |
raw = data['result'][key]
|
| 297 |
-
# Keep last 720 candles for training
|
| 298 |
market_state['ohlc_history'] = [
|
| 299 |
{
|
| 300 |
'time': int(c[0]),
|
|
@@ -311,7 +309,7 @@ async def kraken_worker():
|
|
| 311 |
except Exception as e:
|
| 312 |
logging.error(f"Init Error: {e}")
|
| 313 |
|
| 314 |
-
#
|
| 315 |
while True:
|
| 316 |
try:
|
| 317 |
async with websockets.connect("wss://ws.kraken.com/v2") as ws:
|
|
@@ -347,7 +345,6 @@ async def kraken_worker():
|
|
| 347 |
'close': price,
|
| 348 |
'volume': vol
|
| 349 |
})
|
| 350 |
-
# Keep buffer size larger for ML
|
| 351 |
if len(market_state['ohlc_history']) > 800:
|
| 352 |
market_state['ohlc_history'].pop(0)
|
| 353 |
except: pass
|
|
@@ -379,7 +376,6 @@ async def kraken_worker():
|
|
| 379 |
async def broadcast_worker():
|
| 380 |
while True:
|
| 381 |
if connected_clients and market_state['ready']:
|
| 382 |
-
# This calculation might take 100-200ms depending on CPU
|
| 383 |
payload = process_market_data()
|
| 384 |
if payload and "data" in payload:
|
| 385 |
msg = json.dumps(payload)
|
|
|
|
| 8 |
from aiohttp import web
|
| 9 |
import websockets
|
| 10 |
from sklearn.ensemble import RandomForestRegressor
|
|
|
|
| 11 |
|
| 12 |
+
# Configuration
|
| 13 |
SYMBOL_KRAKEN = "BTC/USD"
|
| 14 |
PORT = 7860
|
| 15 |
+
BROADCAST_RATE = 1.0
|
| 16 |
+
PREDICTION_HORIZON = 100
|
| 17 |
|
| 18 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
|
| 19 |
|
|
|
|
| 83 |
logging.info("Training ML Model...")
|
| 84 |
|
| 85 |
# 1. Prepare Features (X)
|
|
|
|
| 86 |
feature_cols = ['close', 'ema', 'bb_upper', 'bb_lower', 'rsi', 'macd', 'stoch_k', 'atr', 'obv', 'vwap']
|
| 87 |
|
| 88 |
+
# Create a clean copy for training data
|
| 89 |
data = df.dropna().copy()
|
| 90 |
|
| 91 |
+
# 2. Create Targets (y) - OPTIMIZED to fix Fragmentation Warning
|
| 92 |
+
# Instead of adding columns in a loop, we create a dict and concat once
|
| 93 |
+
future_shifts = {}
|
| 94 |
targets = []
|
| 95 |
+
|
| 96 |
for i in range(1, PREDICTION_HORIZON + 1):
|
| 97 |
col_name = f'target_{i}'
|
| 98 |
+
future_shifts[col_name] = data['close'].shift(-i)
|
| 99 |
targets.append(col_name)
|
| 100 |
+
|
| 101 |
+
# Concatenate all target columns at once
|
| 102 |
+
target_df = pd.DataFrame(future_shifts, index=data.index)
|
| 103 |
+
data = pd.concat([data, target_df], axis=1)
|
| 104 |
|
| 105 |
+
# Drop rows where we don't have future data (the last 100 rows)
|
| 106 |
data = data.dropna()
|
| 107 |
|
| 108 |
if len(data) < 100:
|
|
|
|
| 113 |
y = data[targets].values
|
| 114 |
|
| 115 |
# Train Random Forest
|
|
|
|
| 116 |
model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
|
| 117 |
model.fit(X, y)
|
| 118 |
|
|
|
|
| 130 |
if last_row.isnull().values.any(): return []
|
| 131 |
|
| 132 |
# Predict
|
| 133 |
+
prediction = model.predict(last_row.values)[0]
|
| 134 |
|
| 135 |
# Format for frontend
|
| 136 |
current_time = int(df.iloc[-1]['time'])
|
| 137 |
pred_data = []
|
| 138 |
for i, price in enumerate(prediction):
|
| 139 |
pred_data.append({
|
| 140 |
+
"time": current_time + ((i + 1) * 60),
|
| 141 |
"value": float(price)
|
| 142 |
})
|
| 143 |
|
|
|
|
| 152 |
if df is None or len(df) < 50: return {"error": "Not enough data"}
|
| 153 |
|
| 154 |
# 2. Train Model (Periodically)
|
| 155 |
+
# Train initially or every 15 minutes (900 seconds)
|
| 156 |
if market_state['model'] is None or (time.time() - market_state['last_training_time'] > 900):
|
| 157 |
market_state['model'] = train_model(df)
|
| 158 |
market_state['last_training_time'] = time.time()
|
|
|
|
| 225 |
const mainChart = LightweightCharts.createChart(mainEl, commonOpts);
|
| 226 |
const candles = mainChart.addCandlestickSeries({{ upColor: '#00ff9d', downColor: '#ff3b3b', borderVisible: false }});
|
| 227 |
const ema = mainChart.addLineSeries({{ color: '#2962FF', lineWidth: 1 }});
|
|
|
|
| 228 |
const predLine = mainChart.addLineSeries({{ color: '#bf5af2', lineWidth: 2, lineStyle: 2, title: 'AI Forecast' }});
|
| 229 |
|
| 230 |
const oscChart = LightweightCharts.createChart(oscEl, commonOpts);
|
|
|
|
| 260 |
ema.setData(mapData('ema'));
|
| 261 |
rsi.setData(mapData('rsi'));
|
| 262 |
|
|
|
|
| 263 |
if(payload.prediction && payload.prediction.length > 0) {{
|
| 264 |
predLine.setData(payload.prediction);
|
| 265 |
}}
|
|
|
|
| 284 |
async def kraken_worker():
|
| 285 |
global market_state
|
| 286 |
try:
|
|
|
|
|
|
|
| 287 |
async with aiohttp.ClientSession() as session:
|
| 288 |
url = "https://api.kraken.com/0/public/OHLC?pair=XBTUSD&interval=1"
|
| 289 |
async with session.get(url) as response:
|
|
|
|
| 293 |
for key in data['result']:
|
| 294 |
if key != 'last':
|
| 295 |
raw = data['result'][key]
|
|
|
|
| 296 |
market_state['ohlc_history'] = [
|
| 297 |
{
|
| 298 |
'time': int(c[0]),
|
|
|
|
| 309 |
except Exception as e:
|
| 310 |
logging.error(f"Init Error: {e}")
|
| 311 |
|
| 312 |
+
# WebSocket Stream
|
| 313 |
while True:
|
| 314 |
try:
|
| 315 |
async with websockets.connect("wss://ws.kraken.com/v2") as ws:
|
|
|
|
| 345 |
'close': price,
|
| 346 |
'volume': vol
|
| 347 |
})
|
|
|
|
| 348 |
if len(market_state['ohlc_history']) > 800:
|
| 349 |
market_state['ohlc_history'].pop(0)
|
| 350 |
except: pass
|
|
|
|
| 376 |
async def broadcast_worker():
|
| 377 |
while True:
|
| 378 |
if connected_clients and market_state['ready']:
|
|
|
|
| 379 |
payload = process_market_data()
|
| 380 |
if payload and "data" in payload:
|
| 381 |
msg = json.dumps(payload)
|