Alvin3y1 commited on
Commit
2d2ed3b
·
verified ·
1 Parent(s): 2503fda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +465 -128
app.py CHANGED
@@ -6,207 +6,443 @@ import aiohttp
6
  import pandas as pd
7
  import numpy as np
8
  from aiohttp import web
9
- from sklearn.ensemble import RandomForestRegressor
 
 
 
10
 
11
  # --- CONFIGURATION ---
12
  SYMBOL_KRAKEN = "BTC/USD"
13
  PORT = 7860
14
  BROADCAST_RATE = 1.0
15
- PREDICTION_HORIZON = 100 # Predict next 100 candles
16
- MAX_HISTORY = 5000 # Store up to 5000 candles for training
17
- TRAIN_INTERVAL = 300 # Retrain model every 5 minutes
 
18
 
19
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  market_state = {
22
  "ohlc_history": [],
23
  "ready": False,
24
- "model": None,
 
25
  "last_training_time": 0,
26
  "last_price": 0,
27
- "price_change": 0
 
28
  }
29
 
30
  connected_clients = set()
31
 
 
 
 
 
 
 
 
 
 
32
  def calculate_indicators(candles):
33
- if len(candles) < 50:
 
34
  return None
35
 
36
- df = pd.DataFrame(candles)
37
  cols = ['open', 'high', 'low', 'close', 'volume']
38
  for c in cols:
39
- df[c] = df[c].astype(float)
 
 
 
 
40
 
41
- # --- Standard Indicators ---
42
- df['ema20'] = df['close'].ewm(span=20, adjust=False).mean()
43
- df['ema50'] = df['close'].ewm(span=50, adjust=False).mean()
 
 
 
 
 
44
 
45
- # Bollinger Bands
46
- df['std'] = df['close'].rolling(window=20).std()
47
- df['bb_upper'] = df['ema20'] + (df['std'] * 2)
48
- df['bb_lower'] = df['ema20'] - (df['std'] * 2)
 
49
 
50
- # RSI
51
- delta = df['close'].diff()
52
- gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
53
  loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
54
- rs = gain / loss
55
  df['rsi'] = 100 - (100 / (1 + rs))
 
 
 
 
 
56
 
57
- # MACD
58
- k = df['close'].ewm(span=12, adjust=False).mean()
59
- d = df['close'].ewm(span=26, adjust=False).mean()
60
- df['macd'] = k - d
61
  df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
62
  df['macd_hist'] = df['macd'] - df['macd_signal']
 
 
 
 
 
 
63
 
64
- # ATR
65
- df['tr0'] = abs(df['high'] - df['low'])
66
- df['tr1'] = abs(df['high'] - df['close'].shift())
67
- df['tr2'] = abs(df['low'] - df['close'].shift())
68
- df['tr'] = df[['tr0', 'tr1', 'tr2']].max(axis=1)
69
  df['atr'] = df['tr'].rolling(window=14).mean()
 
 
 
70
 
71
- # --- FEATURE ENGINEERING (Normalization) ---
72
- # We create features that represent % differences rather than raw prices
73
- # This helps the model learn patterns regardless of whether BTC is $20k or $100k
 
 
74
 
75
- # Distance from EMAs (Percentage)
76
- df['dist_ema20'] = (df['close'] - df['ema20']) / df['ema20']
77
- df['dist_ema50'] = (df['close'] - df['ema50']) / df['ema50']
78
 
79
- # Bollinger Band Width & Position
80
- df['bb_width'] = (df['bb_upper'] - df['bb_lower']) / df['ema20']
81
- df['bb_pos'] = (df['close'] - df['bb_lower']) / (df['bb_upper'] - df['bb_lower'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- # Volume Change
84
- df['vol_change'] = df['volume'].pct_change()
 
85
 
86
- # Log Returns (Momentum)
87
- df['log_ret'] = np.log(df['close'] / df['close'].shift(1))
 
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  return df
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def train_model(df):
92
- logging.info(f"Training ML Model on {len(df)} candles...")
 
 
 
93
 
94
- # Use normalized features for input
95
- feature_cols = [
96
- 'rsi', 'macd_hist', 'atr',
97
- 'dist_ema20', 'dist_ema50',
98
- 'bb_width', 'bb_pos',
99
- 'vol_change', 'log_ret'
100
- ]
101
 
102
- data = df.dropna().copy()
103
 
104
- # --- CREATE TARGETS (Percentage Change) ---
105
- targets = []
 
106
 
107
- # We want to predict the % return for the next 1 to N steps relative to CURRENT price
108
- for i in range(1, PREDICTION_HORIZON + 1):
109
- col_name = f'target_return_{i}'
110
- # Formula: (Price_Future - Price_Current) / Price_Current
111
- data[col_name] = (data['close'].shift(-i) - data['close']) / data['close']
112
- targets.append(col_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- data = data.dropna()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- if len(data) < 200:
117
- logging.warning("Not enough data to train model yet.")
118
- return None
119
 
120
- X = data[feature_cols].values
121
- y = data[targets].values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- # Increase estimators for better stability
124
- model = RandomForestRegressor(
125
- n_estimators=100,
126
- max_depth=15,
127
- min_samples_split=5,
128
- n_jobs=-1,
129
- random_state=42
130
- )
131
- model.fit(X, y)
 
 
 
 
 
 
 
132
 
133
- logging.info(f"Model Trained successfully.")
134
- return model
135
 
136
- def get_prediction(df, model):
137
- if model is None:
 
 
138
  return []
139
 
140
- feature_cols = [
141
- 'rsi', 'macd_hist', 'atr',
142
- 'dist_ema20', 'dist_ema50',
143
- 'bb_width', 'bb_pos',
144
- 'vol_change', 'log_ret'
145
- ]
146
-
147
- last_row = df.iloc[[-1]][feature_cols]
148
 
149
- if last_row.isnull().values.any():
 
 
 
150
  return []
151
-
152
- # The model predicts Percentage Returns
153
- predicted_returns = model.predict(last_row.values)[0]
154
 
155
- # Convert Percentage Returns back to Absolute Prices
156
- current_price = df.iloc[-1]['close']
157
- current_time = int(df.iloc[-1]['time'])
 
158
 
159
- pred_data = []
160
- for i, pct_change in enumerate(predicted_returns):
161
- # Reconstruct: Price = Current * (1 + Predicted_Return)
162
- future_price = current_price * (1 + pct_change)
163
 
164
- pred_data.append({
165
- "time": current_time + ((i + 1) * 60), # Add 60s for each step
166
- "value": float(future_price)
167
- })
168
-
169
- return pred_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  def process_market_data():
 
172
  if not market_state['ready'] or not market_state['ohlc_history']:
173
  return {"error": "Initializing..."}
174
 
175
  # 1. Calculate Indicators
176
  df = calculate_indicators(market_state['ohlc_history'])
177
- if df is None or len(df) < 50:
178
  return {"error": "Not enough data"}
179
 
180
- # 2. Train Model (Periodically)
181
- if market_state['model'] is None or (time.time() - market_state['last_training_time'] > TRAIN_INTERVAL):
 
 
 
 
 
 
 
182
  try:
183
- market_state['model'] = train_model(df)
184
- market_state['last_training_time'] = time.time()
 
 
 
185
  except Exception as e:
186
  logging.error(f"Training failed: {e}")
 
 
187
 
188
- # 3. Get Prediction
189
  predictions = []
190
  try:
191
- predictions = get_prediction(df, market_state['model'])
192
  except Exception as e:
193
  logging.error(f"Prediction failed: {e}")
194
 
195
- # 4. Prepare Data for Broadcast
196
- # Clean NaNs for JSON
197
  df_clean = df.replace([np.inf, -np.inf], np.nan)
198
  df_clean = df_clean.astype(object).where(pd.notnull(df_clean), None)
199
 
 
200
  last_close = float(df['close'].iloc[-1]) if len(df) > 0 else 0
201
- first_close = float(df['close'].iloc[0]) if len(df) > 0 else 0
202
  price_change = ((last_close - first_close) / first_close * 100) if first_close > 0 else 0
203
 
204
  market_state['last_price'] = last_close
205
  market_state['price_change'] = price_change
206
 
207
- # Only send last 500 candles to client to save bandwidth, but keep full history in memory
208
  display_data = df_clean.tail(500).to_dict('records')
209
- last_row = df.iloc[-1] if len(df) > 0 else {}
 
 
 
 
 
 
 
 
 
210
 
211
  return {
212
  "data": display_data,
@@ -214,14 +450,17 @@ def process_market_data():
214
  "stats": {
215
  "price": last_close,
216
  "change": round(price_change, 2),
217
- "rsi": round(float(last_row.get('rsi', 0)), 1) if pd.notna(last_row.get('rsi')) else 0,
218
- "macd": round(float(last_row.get('macd', 0)), 2) if pd.notna(last_row.get('macd')) else 0,
219
- "atr": round(float(last_row.get('atr', 0)), 2) if pd.notna(last_row.get('atr')) else 0,
220
- "volume": round(float(last_row.get('volume', 0)), 2) if pd.notna(last_row.get('volume')) else 0
 
 
221
  }
222
  }
223
 
224
- # --- FRONTEND HTML (No changes needed, handles price data perfectly) ---
 
225
  HTML_PAGE = """
226
  <!DOCTYPE html>
227
  <html lang="en">
@@ -280,6 +519,21 @@ HTML_PAGE = """
280
  color: #00ff88;
281
  }
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  .stats-row {
284
  display: flex;
285
  gap: 24px;
@@ -452,6 +706,15 @@ HTML_PAGE = """
452
  color: #bf5af2;
453
  z-index: 10;
454
  }
 
 
 
 
 
 
 
 
 
455
  </style>
456
  </head>
457
  <body>
@@ -459,6 +722,7 @@ HTML_PAGE = """
459
  <div class="logo-section">
460
  <div class="logo">QuantAI</div>
461
  <div class="symbol-badge">BTC/USD</div>
 
462
  </div>
463
 
464
  <div class="stats-row">
@@ -522,6 +786,7 @@ HTML_PAGE = """
522
  <span><div class="dot" style="background: #26a69a; opacity: 0.5"></div>Bollinger</span>
523
  </div>
524
  <div class="prediction-badge">AI Forecast: 100 candles</div>
 
525
  </div>
526
 
527
  <div id="volume-chart" class="chart-wrapper">
@@ -613,6 +878,21 @@ document.addEventListener('DOMContentLoaded', () => {
613
  crosshairMarkerVisible: false,
614
  title: 'Forecast'
615
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
 
617
  const volumeSeries = volChart.addHistogramSeries({
618
  priceFormat: { type: 'volume' },
@@ -627,6 +907,22 @@ document.addEventListener('DOMContentLoaded', () => {
627
  lineWidth: 2,
628
  priceScaleId: 'rsi'
629
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  oscChart.priceScale('rsi').applyOptions({
631
  scaleMargins: { top: 0.1, bottom: 0.1 }
632
  });
@@ -679,6 +975,19 @@ document.addEventListener('DOMContentLoaded', () => {
679
  rsiEl.className = 'stat-value ' + (rsiVal > 70 ? 'negative' : rsiVal < 30 ? 'positive' : 'neutral');
680
 
681
  document.getElementById('atr').textContent = stats.atr;
 
 
 
 
 
 
 
 
 
 
 
 
 
682
  }
683
 
684
  if (lastData) {
@@ -760,7 +1069,13 @@ document.addEventListener('DOMContentLoaded', () => {
760
  if (volData.length > 0) volumeSeries.setData(volData);
761
 
762
  const rsiData = safeMap(d, 'rsi');
763
- if (rsiData.length > 0) rsi.setData(rsiData);
 
 
 
 
 
 
764
 
765
  const macdData = d
766
  .filter(x => x && x.time && x.macd_hist !== null && x.macd_hist !== undefined && !isNaN(x.macd_hist))
@@ -771,6 +1086,7 @@ document.addEventListener('DOMContentLoaded', () => {
771
  }));
772
  if (macdData.length > 0) macdHist.setData(macdData);
773
 
 
774
  if (payload.prediction && payload.prediction.length > 0) {
775
  const lastCandle = candleData[candleData.length - 1];
776
  const predData = [
@@ -778,6 +1094,18 @@ document.addEventListener('DOMContentLoaded', () => {
778
  ...payload.prediction.filter(p => p && p.time && p.value !== null && !isNaN(p.value))
779
  ];
780
  predLine.setData(predData);
 
 
 
 
 
 
 
 
 
 
 
 
781
  }
782
 
783
  updateStats(payload.stats, d[d.length - 1]);
@@ -808,10 +1136,11 @@ document.addEventListener('DOMContentLoaded', () => {
808
  </html>
809
  """
810
 
 
811
  async def fetch_initial_data():
 
812
  try:
813
  async with aiohttp.ClientSession() as session:
814
- # Although Kraken returns limited data, we set logic to accumulate it over time.
815
  url = "https://api.kraken.com/0/public/OHLC?pair=XBTUSD&interval=1"
816
  async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
817
  if response.status == 200:
@@ -838,7 +1167,9 @@ async def fetch_initial_data():
838
  logging.error(f"Initial data fetch error: {e}")
839
  return False
840
 
 
841
  async def kraken_rest_worker():
 
842
  await fetch_initial_data()
843
 
844
  while True:
@@ -861,26 +1192,22 @@ async def kraken_rest_worker():
861
  'close': float(c[4]),
862
  'volume': float(c[6])
863
  }
864
- for c in raw[-10:]
865
  ]
866
 
867
- # Intelligent Merge to keep history
868
  if market_state['ohlc_history']:
869
  existing_times = {c['time'] for c in market_state['ohlc_history']}
870
  for nc in new_candles:
871
  if nc['time'] in existing_times:
872
- # Update existing (in case close price changed)
873
  for i, ec in enumerate(market_state['ohlc_history']):
874
  if ec['time'] == nc['time']:
875
  market_state['ohlc_history'][i] = nc
876
  break
877
  else:
878
- # Append new
879
  market_state['ohlc_history'].append(nc)
880
 
881
  market_state['ohlc_history'].sort(key=lambda x: x['time'])
882
 
883
- # Keep MAX_HISTORY (5000)
884
  if len(market_state['ohlc_history']) > MAX_HISTORY:
885
  market_state['ohlc_history'] = market_state['ohlc_history'][-MAX_HISTORY:]
886
 
@@ -891,7 +1218,9 @@ async def kraken_rest_worker():
891
 
892
  await asyncio.sleep(5)
893
 
 
894
  async def broadcast_worker():
 
895
  while True:
896
  if connected_clients and market_state['ready']:
897
  payload = process_market_data()
@@ -906,7 +1235,9 @@ async def broadcast_worker():
906
  connected_clients.difference_update(disconnected)
907
  await asyncio.sleep(BROADCAST_RATE)
908
 
 
909
  async def websocket_handler(request):
 
910
  ws = web.WebSocketResponse()
911
  await ws.prepare(request)
912
  connected_clients.add(ws)
@@ -919,17 +1250,22 @@ async def websocket_handler(request):
919
  logging.info(f"Client disconnected. Total: {len(connected_clients)}")
920
  return ws
921
 
 
922
  async def handle_index(request):
923
  return web.Response(text=HTML_PAGE, content_type='text/html')
924
 
 
925
  async def handle_health(request):
926
  return web.json_response({
927
  "status": "ok",
928
  "ready": market_state['ready'],
929
  "candles": len(market_state['ohlc_history']),
930
- "clients": len(connected_clients)
 
 
931
  })
932
 
 
933
  async def main():
934
  app = web.Application()
935
  app.router.add_get('/', handle_index)
@@ -948,6 +1284,7 @@ async def main():
948
 
949
  await asyncio.Event().wait()
950
 
 
951
  if __name__ == "__main__":
952
  try:
953
  asyncio.run(main())
 
6
  import pandas as pd
7
  import numpy as np
8
  from aiohttp import web
9
+ from sklearn.ensemble import GradientBoostingRegressor
10
+ from sklearn.preprocessing import RobustScaler
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
 
14
  # --- CONFIGURATION ---
15
  SYMBOL_KRAKEN = "BTC/USD"
16
  PORT = 7860
17
  BROADCAST_RATE = 1.0
18
+ PREDICTION_HORIZON = 100
19
+ MAX_HISTORY = 5000
20
+ TRAIN_INTERVAL = 300
21
+ MIN_TRAINING_SAMPLES = 300
22
 
23
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
24
 
25
+ # Feature columns for ML model
26
+ FEATURE_COLS = [
27
+ 'rsi_norm', 'rsi_slope',
28
+ 'macd_hist_norm', 'macd_slope',
29
+ 'atr_pct',
30
+ 'dist_ema20', 'dist_ema50', 'ema_cross',
31
+ 'bb_width', 'bb_pos',
32
+ 'vol_zscore',
33
+ 'ret_1', 'ret_5', 'ret_10', 'ret_20',
34
+ 'volatility_ratio',
35
+ 'candle_body', 'upper_wick', 'lower_wick',
36
+ 'trend_strength'
37
+ ]
38
+
39
+ # Key horizons to predict (reduces noise vs predicting all 100)
40
+ KEY_HORIZONS = [1, 3, 5, 10, 20, 35, 50, 75, 100]
41
+
42
  market_state = {
43
  "ohlc_history": [],
44
  "ready": False,
45
+ "models": {}, # Dictionary of models for each horizon
46
+ "scaler": None,
47
  "last_training_time": 0,
48
  "last_price": 0,
49
+ "price_change": 0,
50
+ "training_metrics": {}
51
  }
52
 
53
  connected_clients = set()
54
 
55
+
56
+ def safe_divide(a, b, default=0.0):
57
+ """Safe division that handles zeros and NaN"""
58
+ with np.errstate(divide='ignore', invalid='ignore'):
59
+ result = np.where(b != 0, a / b, default)
60
+ result = np.where(np.isfinite(result), result, default)
61
+ return result
62
+
63
+
64
  def calculate_indicators(candles):
65
+ """Calculate technical indicators with robust normalization"""
66
+ if len(candles) < 60:
67
  return None
68
 
69
+ df = pd.DataFrame(candles).copy()
70
  cols = ['open', 'high', 'low', 'close', 'volume']
71
  for c in cols:
72
+ df[c] = pd.to_numeric(df[c], errors='coerce')
73
+
74
+ df = df.dropna(subset=['open', 'high', 'low', 'close'])
75
+ if len(df) < 60:
76
+ return None
77
 
78
+ close = df['close']
79
+ high = df['high']
80
+ low = df['low']
81
+ volume = df['volume'].fillna(0)
82
+
83
+ # --- EXPONENTIAL MOVING AVERAGES ---
84
+ df['ema20'] = close.ewm(span=20, adjust=False).mean()
85
+ df['ema50'] = close.ewm(span=50, adjust=False).mean()
86
 
87
+ # --- BOLLINGER BANDS ---
88
+ df['sma20'] = close.rolling(window=20).mean()
89
+ df['std20'] = close.rolling(window=20).std()
90
+ df['bb_upper'] = df['sma20'] + (df['std20'] * 2)
91
+ df['bb_lower'] = df['sma20'] - (df['std20'] * 2)
92
 
93
+ # --- RSI ---
94
+ delta = close.diff()
95
+ gain = delta.where(delta > 0, 0).rolling(window=14).mean()
96
  loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
97
+ rs = safe_divide(gain.values, loss.values, 1.0)
98
  df['rsi'] = 100 - (100 / (1 + rs))
99
+ df['rsi'] = df['rsi'].fillna(50).clip(0, 100)
100
+
101
+ # Normalized RSI (centered at 0, range -1 to 1)
102
+ df['rsi_norm'] = (df['rsi'] - 50) / 50
103
+ df['rsi_slope'] = df['rsi'].diff(5).fillna(0) / 50 # 5-period RSI change
104
 
105
+ # --- MACD ---
106
+ ema12 = close.ewm(span=12, adjust=False).mean()
107
+ ema26 = close.ewm(span=26, adjust=False).mean()
108
+ df['macd'] = ema12 - ema26
109
  df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
110
  df['macd_hist'] = df['macd'] - df['macd_signal']
111
+
112
+ # Normalize MACD by ATR to make it price-independent
113
+ atr_for_norm = close.rolling(20).std().replace(0, 1)
114
+ df['macd_hist_norm'] = df['macd_hist'] / atr_for_norm
115
+ df['macd_hist_norm'] = df['macd_hist_norm'].clip(-5, 5)
116
+ df['macd_slope'] = df['macd_hist_norm'].diff(3).fillna(0)
117
 
118
+ # --- ATR (Average True Range) ---
119
+ tr1 = abs(high - low)
120
+ tr2 = abs(high - close.shift())
121
+ tr3 = abs(low - close.shift())
122
+ df['tr'] = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
123
  df['atr'] = df['tr'].rolling(window=14).mean()
124
+
125
+ # ATR as percentage of price (volatility measure)
126
+ df['atr_pct'] = safe_divide(df['atr'].values, close.values) * 100
127
 
128
+ # --- NORMALIZED PRICE FEATURES ---
129
+
130
+ # Distance from EMAs (percentage)
131
+ df['dist_ema20'] = safe_divide((close - df['ema20']).values, df['ema20'].values) * 100
132
+ df['dist_ema50'] = safe_divide((close - df['ema50']).values, df['ema50'].values) * 100
133
 
134
+ # EMA cross strength
135
+ df['ema_cross'] = safe_divide((df['ema20'] - df['ema50']).values, df['ema50'].values) * 100
 
136
 
137
+ # --- BOLLINGER BAND FEATURES ---
138
+ bb_range = df['bb_upper'] - df['bb_lower']
139
+ bb_range_safe = bb_range.replace(0, np.nan).fillna(close * 0.01) # Fallback to 1% of price
140
+
141
+ df['bb_width'] = safe_divide(bb_range.values, df['sma20'].values) * 100
142
+ df['bb_pos'] = safe_divide((close - df['bb_lower']).values, bb_range_safe.values)
143
+ df['bb_pos'] = df['bb_pos'].clip(-0.5, 1.5).fillna(0.5) # Allow some overflow
144
+
145
+ # --- VOLUME FEATURES ---
146
+ vol_mean = volume.rolling(window=20).mean().replace(0, 1)
147
+ vol_std = volume.rolling(window=20).std().replace(0, 1)
148
+ df['vol_zscore'] = safe_divide((volume - vol_mean).values, vol_std.values)
149
+ df['vol_zscore'] = df['vol_zscore'].clip(-3, 3).fillna(0)
150
+
151
+ # --- RETURN FEATURES (momentum) ---
152
+ df['ret_1'] = close.pct_change(1).fillna(0) * 100
153
+ df['ret_5'] = close.pct_change(5).fillna(0) * 100
154
+ df['ret_10'] = close.pct_change(10).fillna(0) * 100
155
+ df['ret_20'] = close.pct_change(20).fillna(0) * 100
156
 
157
+ # Clip extreme returns
158
+ for col in ['ret_1', 'ret_5', 'ret_10', 'ret_20']:
159
+ df[col] = df[col].clip(-10, 10)
160
 
161
+ # --- VOLATILITY FEATURES ---
162
+ vol_short = df['ret_1'].rolling(5).std().fillna(0)
163
+ vol_long = df['ret_1'].rolling(20).std().replace(0, 1)
164
+ df['volatility_ratio'] = safe_divide(vol_short.values, vol_long.values).clip(0, 3)
165
 
166
+ # --- CANDLESTICK FEATURES ---
167
+ candle_range = (high - low).replace(0, 0.01)
168
+ df['candle_body'] = safe_divide((close - df['open']).values, candle_range.values)
169
+ df['upper_wick'] = safe_divide((high - pd.concat([close, df['open']], axis=1).max(axis=1)).values, candle_range.values)
170
+ df['lower_wick'] = safe_divide((pd.concat([close, df['open']], axis=1).min(axis=1) - low).values, candle_range.values)
171
+
172
+ # --- TREND STRENGTH ---
173
+ # Compare current price to 20-period high/low range
174
+ rolling_high = high.rolling(20).max()
175
+ rolling_low = low.rolling(20).min()
176
+ rolling_range = (rolling_high - rolling_low).replace(0, 1)
177
+ df['trend_strength'] = safe_divide((close - rolling_low).values, rolling_range.values) * 2 - 1 # -1 to 1
178
+
179
+ # Replace any remaining infinities or NaN
180
+ df = df.replace([np.inf, -np.inf], np.nan)
181
+
182
  return df
183
 
184
+
185
+ def prepare_training_data(df):
186
+ """Prepare features and multi-horizon targets for training"""
187
+ data = df.copy()
188
+
189
+ # Create target: future return at each key horizon
190
+ target_cols = []
191
+ for h in KEY_HORIZONS:
192
+ col_name = f'target_{h}'
193
+ future_price = data['close'].shift(-h)
194
+ current_price = data['close']
195
+ # Target is percentage return
196
+ data[col_name] = safe_divide((future_price - current_price).values, current_price.values) * 100
197
+ target_cols.append(col_name)
198
+
199
+ # Drop rows with NaN in features or targets
200
+ required_cols = FEATURE_COLS + target_cols
201
+ data = data.dropna(subset=required_cols)
202
+
203
+ if len(data) < MIN_TRAINING_SAMPLES:
204
+ return None, None
205
+
206
+ X = data[FEATURE_COLS].values
207
+ y_dict = {h: data[f'target_{h}'].values for h in KEY_HORIZONS}
208
+
209
+ return X, y_dict
210
+
211
+
212
  def train_model(df):
213
+ """Train separate models for each prediction horizon"""
214
+ logging.info(f"Training ML Models on {len(df)} candles...")
215
+
216
+ X, y_dict = prepare_training_data(df)
217
 
218
+ if X is None:
219
+ logging.warning("Not enough training data")
220
+ return None, None
 
 
 
 
221
 
222
+ logging.info(f"Training data: {len(X)} samples, {len(FEATURE_COLS)} features")
223
 
224
+ # Robust scaling handles outliers better than StandardScaler
225
+ scaler = RobustScaler()
226
+ X_scaled = scaler.fit_transform(X)
227
 
228
+ models = {}
229
+ metrics = {}
230
+
231
+ for h in KEY_HORIZONS:
232
+ y = y_dict[h]
233
+
234
+ # Gradient Boosting with regularization to prevent overfitting
235
+ model = GradientBoostingRegressor(
236
+ n_estimators=150,
237
+ max_depth=4,
238
+ learning_rate=0.05,
239
+ min_samples_split=30,
240
+ min_samples_leaf=15,
241
+ subsample=0.8,
242
+ max_features='sqrt',
243
+ validation_fraction=0.15,
244
+ n_iter_no_change=10,
245
+ random_state=42,
246
+ verbose=0
247
+ )
248
+
249
+ model.fit(X_scaled, y)
250
+ models[h] = model
251
 
252
+ # Calculate training R² score
253
+ train_score = model.score(X_scaled, y)
254
+ metrics[h] = {'r2': round(train_score, 3)}
255
+
256
+ logging.info(f" Horizon {h:3d}: R² = {train_score:.3f}")
257
+
258
+ # Log feature importance (from longest horizon model)
259
+ if 100 in models:
260
+ importance = dict(zip(FEATURE_COLS, models[100].feature_importances_))
261
+ top_5 = sorted(importance.items(), key=lambda x: x[1], reverse=True)[:5]
262
+ logging.info(f"Top features: {[f'{k}:{v:.3f}' for k,v in top_5]}")
263
+
264
+ market_state['training_metrics'] = metrics
265
+ logging.info("Model training complete")
266
+
267
+ return models, scaler
268
 
 
 
 
269
 
270
+ def interpolate_predictions(horizon_preds, target_horizon):
271
+ """Interpolate between key horizon predictions for smooth curve"""
272
+ horizons = sorted(horizon_preds.keys())
273
+
274
+ if target_horizon <= horizons[0]:
275
+ return horizon_preds[horizons[0]]
276
+ if target_horizon >= horizons[-1]:
277
+ return horizon_preds[horizons[-1]]
278
+
279
+ # Find surrounding horizons
280
+ lower_h = max([h for h in horizons if h <= target_horizon])
281
+ upper_h = min([h for h in horizons if h >= target_horizon])
282
+
283
+ if lower_h == upper_h:
284
+ return horizon_preds[lower_h]
285
+
286
+ # Cubic interpolation weight for smoother curves
287
+ t = (target_horizon - lower_h) / (upper_h - lower_h)
288
+ t_smooth = t * t * (3 - 2 * t) # Smoothstep function
289
+
290
+ return horizon_preds[lower_h] + (horizon_preds[upper_h] - horizon_preds[lower_h]) * t_smooth
291
 
292
+
293
+ def apply_trend_smoothing(predictions, window=5):
294
+ """Apply exponential moving average smoothing to predictions"""
295
+ if len(predictions) < window:
296
+ return predictions
297
+
298
+ smoothed = []
299
+ alpha = 2 / (window + 1)
300
+
301
+ # Initialize with first value
302
+ ema = predictions[0]
303
+ smoothed.append(ema)
304
+
305
+ for i in range(1, len(predictions)):
306
+ ema = alpha * predictions[i] + (1 - alpha) * ema
307
+ smoothed.append(ema)
308
 
309
+ return smoothed
 
310
 
311
+
312
+ def get_prediction(df, models, scaler):
313
+ """Generate price predictions for the next N candles"""
314
+ if not models or scaler is None:
315
  return []
316
 
317
+ # Check if we have valid features
318
+ last_row = df.iloc[-1:].copy()
 
 
 
 
 
 
319
 
320
+ # Validate features
321
+ missing_features = [col for col in FEATURE_COLS if col not in last_row.columns]
322
+ if missing_features:
323
+ logging.error(f"Missing features: {missing_features}")
324
  return []
 
 
 
325
 
326
+ feature_values = last_row[FEATURE_COLS]
327
+ if feature_values.isnull().values.any():
328
+ logging.warning("NaN in prediction features")
329
+ return []
330
 
331
+ try:
332
+ X = feature_values.values
333
+ X_scaled = scaler.transform(X)
 
334
 
335
+ current_price = float(df.iloc[-1]['close'])
336
+ current_time = int(df.iloc[-1]['time'])
337
+
338
+ # Get predictions at key horizons
339
+ horizon_preds = {}
340
+ for h in KEY_HORIZONS:
341
+ if h in models:
342
+ pred_return = models[h].predict(X_scaled)[0]
343
+ # Clip extreme predictions
344
+ pred_return = np.clip(pred_return, -15, 15) # Max ±15% move
345
+ horizon_preds[h] = pred_return
346
+
347
+ if not horizon_preds:
348
+ return []
349
+
350
+ # Interpolate for all time steps
351
+ raw_returns = []
352
+ for i in range(1, PREDICTION_HORIZON + 1):
353
+ pct_return = interpolate_predictions(horizon_preds, i)
354
+ raw_returns.append(pct_return)
355
+
356
+ # Apply trend smoothing
357
+ smoothed_returns = apply_trend_smoothing(raw_returns, window=7)
358
+
359
+ # Convert to prices with momentum continuation
360
+ predictions = []
361
+ prev_price = current_price
362
+
363
+ for i, pct_return in enumerate(smoothed_returns):
364
+ # Price = current * (1 + cumulative_return%)
365
+ future_price = current_price * (1 + pct_return / 100)
366
+
367
+ # Add slight momentum continuation (reduces jumps)
368
+ if i > 0:
369
+ momentum = (future_price - prev_price) * 0.1
370
+ future_price = future_price + momentum
371
+
372
+ predictions.append({
373
+ "time": current_time + ((i + 1) * 60),
374
+ "value": round(float(future_price), 2)
375
+ })
376
+ prev_price = future_price
377
+
378
+ return predictions
379
+
380
+ except Exception as e:
381
+ logging.error(f"Prediction error: {e}")
382
+ return []
383
+
384
 
385
  def process_market_data():
386
+ """Process market data and generate predictions"""
387
  if not market_state['ready'] or not market_state['ohlc_history']:
388
  return {"error": "Initializing..."}
389
 
390
  # 1. Calculate Indicators
391
  df = calculate_indicators(market_state['ohlc_history'])
392
+ if df is None or len(df) < 60:
393
  return {"error": "Not enough data"}
394
 
395
+ # 2. Train Model Periodically
396
+ current_time = time.time()
397
+ should_train = (
398
+ market_state['models'] is None or
399
+ len(market_state['models']) == 0 or
400
+ (current_time - market_state['last_training_time'] > TRAIN_INTERVAL)
401
+ )
402
+
403
+ if should_train:
404
  try:
405
+ models, scaler = train_model(df)
406
+ if models:
407
+ market_state['models'] = models
408
+ market_state['scaler'] = scaler
409
+ market_state['last_training_time'] = current_time
410
  except Exception as e:
411
  logging.error(f"Training failed: {e}")
412
+ import traceback
413
+ traceback.print_exc()
414
 
415
+ # 3. Generate Predictions
416
  predictions = []
417
  try:
418
+ predictions = get_prediction(df, market_state['models'], market_state['scaler'])
419
  except Exception as e:
420
  logging.error(f"Prediction failed: {e}")
421
 
422
+ # 4. Prepare Display Data
 
423
  df_clean = df.replace([np.inf, -np.inf], np.nan)
424
  df_clean = df_clean.astype(object).where(pd.notnull(df_clean), None)
425
 
426
+ # Calculate stats
427
  last_close = float(df['close'].iloc[-1]) if len(df) > 0 else 0
428
+ first_close = float(df['close'].iloc[0]) if len(df) > 0 else last_close
429
  price_change = ((last_close - first_close) / first_close * 100) if first_close > 0 else 0
430
 
431
  market_state['last_price'] = last_close
432
  market_state['price_change'] = price_change
433
 
434
+ # Only send last 500 candles to client
435
  display_data = df_clean.tail(500).to_dict('records')
436
+
437
+ # Extract last row stats safely
438
+ last_row = df.iloc[-1]
439
+
440
+ def safe_get(series, key, default=0):
441
+ try:
442
+ val = series[key] if key in series.index else default
443
+ return float(val) if pd.notna(val) and np.isfinite(val) else default
444
+ except:
445
+ return default
446
 
447
  return {
448
  "data": display_data,
 
450
  "stats": {
451
  "price": last_close,
452
  "change": round(price_change, 2),
453
+ "rsi": round(safe_get(last_row, 'rsi'), 1),
454
+ "macd": round(safe_get(last_row, 'macd'), 2),
455
+ "atr": round(safe_get(last_row, 'atr'), 2),
456
+ "volume": round(safe_get(last_row, 'volume'), 2),
457
+ "candles": len(market_state['ohlc_history']),
458
+ "model_ready": len(market_state.get('models', {})) > 0
459
  }
460
  }
461
 
462
+
463
+ # --- FRONTEND HTML ---
464
  HTML_PAGE = """
465
  <!DOCTYPE html>
466
  <html lang="en">
 
519
  color: #00ff88;
520
  }
521
 
522
+ .model-badge {
523
+ background: rgba(191, 90, 242, 0.1);
524
+ border: 1px solid rgba(191, 90, 242, 0.3);
525
+ padding: 4px 10px;
526
+ border-radius: 12px;
527
+ font-size: 11px;
528
+ color: #bf5af2;
529
+ }
530
+
531
+ .model-badge.ready {
532
+ background: rgba(0, 255, 136, 0.1);
533
+ border-color: rgba(0, 255, 136, 0.3);
534
+ color: #00ff88;
535
+ }
536
+
537
  .stats-row {
538
  display: flex;
539
  gap: 24px;
 
706
  color: #bf5af2;
707
  z-index: 10;
708
  }
709
+
710
+ .candle-count {
711
+ position: absolute;
712
+ bottom: 12px;
713
+ right: 16px;
714
+ font-size: 10px;
715
+ color: #444;
716
+ z-index: 10;
717
+ }
718
  </style>
719
  </head>
720
  <body>
 
722
  <div class="logo-section">
723
  <div class="logo">QuantAI</div>
724
  <div class="symbol-badge">BTC/USD</div>
725
+ <div id="model-status" class="model-badge">Model: Training...</div>
726
  </div>
727
 
728
  <div class="stats-row">
 
786
  <span><div class="dot" style="background: #26a69a; opacity: 0.5"></div>Bollinger</span>
787
  </div>
788
  <div class="prediction-badge">AI Forecast: 100 candles</div>
789
+ <div id="candle-count" class="candle-count">Candles: --</div>
790
  </div>
791
 
792
  <div id="volume-chart" class="chart-wrapper">
 
878
  crosshairMarkerVisible: false,
879
  title: 'Forecast'
880
  });
881
+
882
+ // Prediction confidence band (optional visual)
883
+ const predUpper = mainChart.addLineSeries({
884
+ color: 'rgba(191, 90, 242, 0.15)',
885
+ lineWidth: 1,
886
+ lineStyle: LightweightCharts.LineStyle.Dotted,
887
+ crosshairMarkerVisible: false
888
+ });
889
+
890
+ const predLower = mainChart.addLineSeries({
891
+ color: 'rgba(191, 90, 242, 0.15)',
892
+ lineWidth: 1,
893
+ lineStyle: LightweightCharts.LineStyle.Dotted,
894
+ crosshairMarkerVisible: false
895
+ });
896
 
897
  const volumeSeries = volChart.addHistogramSeries({
898
  priceFormat: { type: 'volume' },
 
907
  lineWidth: 2,
908
  priceScaleId: 'rsi'
909
  });
910
+
911
+ // RSI overbought/oversold lines
912
+ const rsiUpper = oscChart.addLineSeries({
913
+ color: 'rgba(239, 83, 80, 0.3)',
914
+ lineWidth: 1,
915
+ lineStyle: LightweightCharts.LineStyle.Dashed,
916
+ priceScaleId: 'rsi'
917
+ });
918
+
919
+ const rsiLower = oscChart.addLineSeries({
920
+ color: 'rgba(38, 166, 154, 0.3)',
921
+ lineWidth: 1,
922
+ lineStyle: LightweightCharts.LineStyle.Dashed,
923
+ priceScaleId: 'rsi'
924
+ });
925
+
926
  oscChart.priceScale('rsi').applyOptions({
927
  scaleMargins: { top: 0.1, bottom: 0.1 }
928
  });
 
975
  rsiEl.className = 'stat-value ' + (rsiVal > 70 ? 'negative' : rsiVal < 30 ? 'positive' : 'neutral');
976
 
977
  document.getElementById('atr').textContent = stats.atr;
978
+
979
+ // Update model status
980
+ const modelBadge = document.getElementById('model-status');
981
+ if (stats.model_ready) {
982
+ modelBadge.textContent = 'Model: Active';
983
+ modelBadge.className = 'model-badge ready';
984
+ } else {
985
+ modelBadge.textContent = 'Model: Training...';
986
+ modelBadge.className = 'model-badge';
987
+ }
988
+
989
+ // Update candle count
990
+ document.getElementById('candle-count').textContent = 'Candles: ' + (stats.candles || '--');
991
  }
992
 
993
  if (lastData) {
 
1069
  if (volData.length > 0) volumeSeries.setData(volData);
1070
 
1071
  const rsiData = safeMap(d, 'rsi');
1072
+ if (rsiData.length > 0) {
1073
+ rsi.setData(rsiData);
1074
+ // Set RSI reference lines
1075
+ const times = rsiData.map(x => x.time);
1076
+ rsiUpper.setData(times.map(t => ({time: t, value: 70})));
1077
+ rsiLower.setData(times.map(t => ({time: t, value: 30})));
1078
+ }
1079
 
1080
  const macdData = d
1081
  .filter(x => x && x.time && x.macd_hist !== null && x.macd_hist !== undefined && !isNaN(x.macd_hist))
 
1086
  }));
1087
  if (macdData.length > 0) macdHist.setData(macdData);
1088
 
1089
+ // Handle predictions with confidence bands
1090
  if (payload.prediction && payload.prediction.length > 0) {
1091
  const lastCandle = candleData[candleData.length - 1];
1092
  const predData = [
 
1094
  ...payload.prediction.filter(p => p && p.time && p.value !== null && !isNaN(p.value))
1095
  ];
1096
  predLine.setData(predData);
1097
+
1098
+ // Add confidence bands (±1% expanding over time)
1099
+ const upperBand = predData.map((p, i) => ({
1100
+ time: p.time,
1101
+ value: p.value * (1 + 0.002 * Math.sqrt(i))
1102
+ }));
1103
+ const lowerBand = predData.map((p, i) => ({
1104
+ time: p.time,
1105
+ value: p.value * (1 - 0.002 * Math.sqrt(i))
1106
+ }));
1107
+ predUpper.setData(upperBand);
1108
+ predLower.setData(lowerBand);
1109
  }
1110
 
1111
  updateStats(payload.stats, d[d.length - 1]);
 
1136
  </html>
1137
  """
1138
 
1139
+
1140
  async def fetch_initial_data():
1141
+ """Fetch initial OHLC data from Kraken"""
1142
  try:
1143
  async with aiohttp.ClientSession() as session:
 
1144
  url = "https://api.kraken.com/0/public/OHLC?pair=XBTUSD&interval=1"
1145
  async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
1146
  if response.status == 200:
 
1167
  logging.error(f"Initial data fetch error: {e}")
1168
  return False
1169
 
1170
+
1171
  async def kraken_rest_worker():
1172
+ """Background worker to fetch and update OHLC data"""
1173
  await fetch_initial_data()
1174
 
1175
  while True:
 
1192
  'close': float(c[4]),
1193
  'volume': float(c[6])
1194
  }
1195
+ for c in raw[-20:] # Get last 20 candles for merging
1196
  ]
1197
 
 
1198
  if market_state['ohlc_history']:
1199
  existing_times = {c['time'] for c in market_state['ohlc_history']}
1200
  for nc in new_candles:
1201
  if nc['time'] in existing_times:
 
1202
  for i, ec in enumerate(market_state['ohlc_history']):
1203
  if ec['time'] == nc['time']:
1204
  market_state['ohlc_history'][i] = nc
1205
  break
1206
  else:
 
1207
  market_state['ohlc_history'].append(nc)
1208
 
1209
  market_state['ohlc_history'].sort(key=lambda x: x['time'])
1210
 
 
1211
  if len(market_state['ohlc_history']) > MAX_HISTORY:
1212
  market_state['ohlc_history'] = market_state['ohlc_history'][-MAX_HISTORY:]
1213
 
 
1218
 
1219
  await asyncio.sleep(5)
1220
 
1221
+
1222
  async def broadcast_worker():
1223
+ """Broadcast market data to connected clients"""
1224
  while True:
1225
  if connected_clients and market_state['ready']:
1226
  payload = process_market_data()
 
1235
  connected_clients.difference_update(disconnected)
1236
  await asyncio.sleep(BROADCAST_RATE)
1237
 
1238
+
1239
  async def websocket_handler(request):
1240
+ """Handle WebSocket connections"""
1241
  ws = web.WebSocketResponse()
1242
  await ws.prepare(request)
1243
  connected_clients.add(ws)
 
1250
  logging.info(f"Client disconnected. Total: {len(connected_clients)}")
1251
  return ws
1252
 
1253
+
1254
  async def handle_index(request):
1255
  return web.Response(text=HTML_PAGE, content_type='text/html')
1256
 
1257
+
1258
  async def handle_health(request):
1259
  return web.json_response({
1260
  "status": "ok",
1261
  "ready": market_state['ready'],
1262
  "candles": len(market_state['ohlc_history']),
1263
+ "clients": len(connected_clients),
1264
+ "model_ready": len(market_state.get('models', {})) > 0,
1265
+ "training_metrics": market_state.get('training_metrics', {})
1266
  })
1267
 
1268
+
1269
  async def main():
1270
  app = web.Application()
1271
  app.router.add_get('/', handle_index)
 
1284
 
1285
  await asyncio.Event().wait()
1286
 
1287
+
1288
  if __name__ == "__main__":
1289
  try:
1290
  asyncio.run(main())