danielthatu12 committed on
Commit
5866ad0
·
verified ·
1 Parent(s): 97bc428

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +569 -567
app.py CHANGED
@@ -1,567 +1,569 @@
1
- """
2
- app.py – StockBuddy Flask API
3
- =================================
4
- LIGHTWEIGHT CHANGES vs original:
5
- [OPT-A] Removed the startup TF validation model (was creating & running a test
6
- LSTM on every cold start – wastes ~10 s and ~100 MB RAM on free tier).
7
- Replaced with a simple tf.constant() smoke-test.
8
- [OPT-B] PORT is now read from the PORT environment variable so the server
9
- works on Render (sets $PORT automatically) and Hugging Face Spaces
10
- (expects port 7860) without code changes.
11
- [OPT-C] time_step updated to 30 throughout (was 45) to match the lighter model.
12
- All REST API routes are unchanged from the original.
13
- """
14
-
15
- from flask import Flask, request, jsonify
16
- from flask_cors import CORS
17
- import numpy as np
18
- import pandas as pd
19
- import os
20
- import threading
21
- import time
22
- from datetime import datetime, timedelta
23
- import json
24
- import model as stock_model
25
- import sys
26
- import requests
27
- import traceback
28
- from sklearn.preprocessing import MinMaxScaler
29
- from tensorflow.keras.models import Sequential
30
- from tensorflow.keras.layers import LSTM, Dense, Dropout
31
- from tensorflow.keras.callbacks import Callback
32
- import tensorflow as tf
33
- import xgboost as xgb
34
-
35
# Create the Flask application and enable CORS on it.
app = Flask(__name__)
CORS(app)


@app.route("/", methods=["GET"])
def home():
    """Liveness endpoint: report that the API process is up."""
    payload = {"status": "running", "message": "StockBuddy API is live!"}
    return jsonify(payload)
42
-
43
# [OPT-A] Lightweight TF smoke-test instead of building & running a full LSTM
def validate_tensorflow():
    """Run a minimal TensorFlow sanity check without building a model.

    Returns:
        tuple: ``(ok, message)``. ``ok`` is True when a constant tensor could
        be created and the GPU list could be queried; ``message`` describes
        the detected device mode, or the error when ``ok`` is False.
    """
    try:
        print("TensorFlow version:", tf.__version__)
        # Creating one constant is enough to prove the runtime works; model
        # construction is left to the first prediction request.
        tf.constant([1.0, 2.0, 3.0])
        gpu_devices = tf.config.list_physical_devices("GPU")
        msg = (
            f"GPU available ({len(gpu_devices)} device(s)) – running in GPU mode."
            if gpu_devices
            else "No GPU detected – running in CPU mode (expected on free tier)."
        )
        print(f"TensorFlow OK: {msg}")
        return True, msg
    except Exception as e:
        print(f"TensorFlow validation failed: {e}")
        return False, f"TensorFlow error: {e}"


# Run the smoke-test once at import time; the outcome gates /api/predict.
tf_status, tf_message = validate_tensorflow()
print(f"TensorFlow validation: {tf_message}" if tf_status
      else f"WARNING: {tf_message}")

# Registry of prediction tasks, keyed by task id.
prediction_tasks = {}
73
-
74
-
75
class PredictionTask:
    """Background job that fetches data, trains the models and forecasts one symbol.

    The work runs on a daemon thread started by :meth:`run`.  Callers poll
    ``status``, ``progress`` and ``result``.  ``status`` starts as
    ``"pending"``, becomes ``"running"`` and ends as ``"completed"``,
    ``"failed"`` or ``"stopped"``.  On failure ``result`` is
    ``{"error": <message>}``.
    """

    def __init__(self, user_id, symbol, days_ahead):
        """Store the request parameters and build a unique task id."""
        self.user_id = user_id
        self.symbol = symbol
        self.days_ahead = days_ahead
        self.progress = 0
        self.status = "pending"
        self.result = None
        self.sentiment_result = None
        self.thread = None
        self.stop_requested = False
        self.stop_acknowledged = False
        # Unique task ID: millisecond timestamp + random hex suffix
        timestamp = int(time.time() * 1000)
        random_suffix = os.urandom(4).hex()
        self.task_id = f"{user_id}_{symbol}_{timestamp}_{random_suffix}"

    def run(self):
        """Start the prediction on a daemon thread and return the task id."""
        self.thread = threading.Thread(target=self._run_prediction, daemon=True)
        self.thread.start()
        return self.task_id

    def is_stop_requested(self):
        """Callback for model training loops to poll the stop flag.

        The first call after a stop request also records the acknowledgement
        and marks the task as stopped.
        """
        if self.stop_requested and not self.stop_acknowledged:
            self.stop_acknowledged = True
            self.status = "stopped"
            return True
        return self.stop_requested

    def _fail(self, message):
        """Mark the task as failed with *message* as the reported error."""
        self.status = "failed"
        self.result = {"error": message}

    def _stop_if_requested(self):
        """Mark the task stopped when a stop was requested; return True to abort."""
        if self.stop_requested:
            self.status = "stopped"
            return True
        return False

    @staticmethod
    def _next_trading_days(last_date, count):
        """Return the next *count* weekdays (Mon-Fri) strictly after *last_date*.

        Only weekends are skipped; market holidays are not taken into account.
        """
        days = []
        current = last_date
        while len(days) < count:
            current = current + timedelta(days=1)
            if current.weekday() < 5:
                days.append(current)
        return days

    def _run_prediction(self):
        """Thread body: run the whole pipeline, recording progress and outcome."""
        try:
            print(f"Starting prediction for {self.symbol} (task: {self.task_id})")
            self.status = "running"
            self.progress = 10

            # ── Fetch historical data ────────────────────────────────────────
            print(f"Fetching historical data for {self.symbol}...")
            try:
                data = stock_model.fetch_stock_data(self.symbol, outputsize="compact")
            except Exception as e:
                error_msg = str(e)
                print(f"\n[ERROR] {error_msg}\n")
                self._fail(error_msg)
                return

            # Test for None before calling len(): otherwise the TypeError from
            # len(None) would replace this clearer message.
            if data is None:
                self._fail(f"Could not fetch data for {self.symbol}")
                return
            print(f"Fetched {len(data)} rows for {self.symbol}")

            if self._stop_if_requested():
                return

            if len(data) < 60:
                self._fail(f"Insufficient data for {self.symbol} "
                           f"(got {len(data)}, need ≥60)")
                return

            # ── Extract last actual close ────────────────────────────────────
            try:
                if isinstance(data, pd.DataFrame) and "Close" in data.columns:
                    last_actual_close = float(data["Close"].iloc[-1])
                else:
                    last_actual_close = float(data.iloc[-1, 0])
                last_date = data.index[-1]
                print(f"Latest close for {self.symbol}: "
                      f"${last_actual_close:.2f} on {last_date.strftime('%Y-%m-%d')}")
            except Exception as e:
                self._fail(f"Error reading price data: {e}")
                return

            # The close is the denominator of the percentage change below; a
            # zero, negative or NaN value cannot produce a meaningful result.
            if not last_actual_close > 0:
                self._fail(f"Invalid last close price for {self.symbol}: "
                           f"{last_actual_close}")
                return

            self.progress = 20
            if self._stop_if_requested():
                return

            # ── Sentiment analysis (best effort, never fatal) ────────────────
            try:
                print(f"Fetching news for {self.symbol}...")
                headlines = stock_model.fetch_finnhub_news(self.symbol)
                print(f"Got {len(headlines)} headlines")
                self.progress = 30
                if self._stop_if_requested():
                    return

                sentiment_results, sentiment_totals = \
                    stock_model.analyze_sentiment(headlines)
                sentiment_summary = stock_model.generate_sentiment_summary(
                    sentiment_totals, headlines, self.symbol)
                self.sentiment_result = {
                    "totals": sentiment_totals,
                    "summary": sentiment_summary,
                }
            except Exception as e:
                print(f"Sentiment error (non-fatal): {e}")
                self.sentiment_result = {
                    "totals": {"positive": 0, "negative": 0, "neutral": 0},
                    "summary": f"Unable to analyse sentiment: {e}",
                }

            self.progress = 40
            if self._stop_if_requested():
                return

            # ── Preprocess data ──────────────────────────────────────────────
            try:
                print("Preprocessing data...")
                scaled_data, scaler = stock_model.preprocess_data(data)

                # [OPT-C] time_step 45 → 30
                time_step = 30
                X, y = stock_model.create_sequences(scaled_data, time_step)
                print(f"Sequences: X={X.shape}, y={y.shape}")
            except Exception as e:
                self._fail(f"Preprocessing failed: {e}")
                return

            if len(X) == 0:
                self._fail(f"Could not create training sequences for {self.symbol}")
                return

            self.progress = 50
            if self._stop_if_requested():
                return

            # ── Train LSTM ───────────────────────────────────────────────────
            try:
                train_size = int(len(X) * 0.8)
                if train_size == 0:
                    self._fail("Not enough data to split for training")
                    return

                X_train, y_train = X[:train_size], y[:train_size]
                self.progress = 55
                print(f"Training LSTM with {len(X_train)} samples...")
                lstm_model = stock_model.train_lstm(
                    X_train, y_train, time_step, self.is_stop_requested)
            except Exception as e:
                self._fail(f"LSTM training failed: {e}")
                return

            if self._stop_if_requested():
                return

            self.progress = 75

            # ── Train XGBoost on residuals (errors here are non-fatal) ───────
            try:
                print("Calculating residuals for XGBoost...")
                lstm_preds = lstm_model.predict(X_train, verbose=0).flatten()
                residuals = y_train - lstm_preds
                xgb_model = stock_model.train_xgboost(
                    X_train.reshape(X_train.shape[0], -1),
                    residuals,
                    self.is_stop_requested,
                )
                if self.stop_requested or xgb_model is None:
                    self.status = "stopped"
                    return
            except Exception as e:
                print(f"XGBoost training error (non-fatal): {e}")
                xgb_model = None

            self.progress = 90
            if self._stop_if_requested():
                return

            # ── Generate predictions ─────────────────────────────────────────
            try:
                print(f"Generating {self.days_ahead}-day predictions...")
                predictions = stock_model.predict_stock_price(
                    lstm_model, xgb_model, scaled_data, scaler,
                    time_step, self.days_ahead, self.is_stop_requested,
                )
                if self.stop_requested or predictions is None:
                    self.status = "stopped"
                    return
            except Exception as e:
                self._fail(f"Prediction generation failed: {e}")
                return

            self.progress = 95
            if self._stop_if_requested():
                return

            # ── Build future trading-day dates ───────────────────────────────
            # One weekday per predicted value, starting after the last close.
            future_dates = self._next_trading_days(last_date, len(predictions))

            if self._stop_if_requested():
                return

            # ── Assemble result payload ──────────────────────────────────────
            prediction_data = []
            for day, row in zip(future_dates, predictions):
                predicted_price = float(row[0])
                percent_change = (
                    (predicted_price - last_actual_close) / last_actual_close * 100
                )
                prediction_data.append({
                    "date": day.strftime("%Y-%m-%d"),
                    "price": round(predicted_price, 2),
                    "change": round(percent_change, 2),
                })

            self.result = {
                "symbol": self.symbol,
                "lastActualClose": {
                    "date": last_date.strftime("%Y-%m-%d"),
                    "price": round(last_actual_close, 2),
                },
                "predictions": prediction_data,
                "sentiment": self.sentiment_result,
                "tableDisplay": True,
            }
            self.progress = 100
            self.status = "completed"
            print(f"Prediction complete for {self.symbol}")

        except Exception as e:
            # Last-resort handler so the thread never dies without a status.
            error_msg = str(e)
            self._fail(error_msg)
            print(f"\n[ERROR] {error_msg}\n")
            traceback.print_exc()
342
-
343
-
344
- # =============================================================================
345
- # REST API ROUTES
346
- # (all routes are identical to the original – no frontend changes needed)
347
- # =============================================================================
348
-
349
@app.route("/api/predict", methods=["POST"])
def start_prediction():
    """Validate the request body and launch a background prediction task.

    Expects a JSON object with ``userId``, ``symbol`` and an optional
    ``daysAhead`` (default 5).

    Returns:
        JSON with ``taskId``, ``status`` and ``message`` on success; a JSON
        ``error`` with status 400 for invalid input, 503 when the TensorFlow
        startup check failed, or 500 for unexpected errors.
    """
    try:
        # silent=True returns None instead of raising on a missing, mistyped
        # or malformed JSON body, so such requests get a 400 rather than
        # falling through to the generic 500 handler below.
        data = request.get_json(silent=True)
        print(f"POST /api/predict body={data}")

        if not data or not isinstance(data, dict):
            return jsonify({"error": "Invalid or missing request body"}), 400

        user_id = data.get("userId")
        symbol = data.get("symbol")

        if not user_id or not symbol:
            return jsonify({"error": "Missing required parameters (userId or symbol)"}), 400

        if not isinstance(symbol, str) or len(symbol) > 10:
            return jsonify({"error": f"Invalid symbol format: {symbol}"}), 400

        try:
            days_ahead = int(data.get("daysAhead", 5))
        except (TypeError, ValueError):
            return jsonify({"error": "daysAhead must be an integer"}), 400
        if days_ahead < 1:
            return jsonify({"error": "daysAhead must be at least 1"}), 400

        if not tf_status:
            return jsonify({
                "error": f"Prediction service unavailable: {tf_message}",
                "tf_status": tf_message,
            }), 503

        task = PredictionTask(user_id, symbol, days_ahead)
        # Register before starting the thread so the task is always
        # discoverable by the status/stop endpoints once it is running.
        prediction_tasks[task.task_id] = task
        task_id = task.run()

        return jsonify({
            "taskId": task_id,
            "status": "pending",
            "message": f"Prediction started for {symbol}",
        })
    except ValueError as e:
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        print(f"Critical error starting prediction: {e}")
        traceback.print_exc()
        return jsonify({"error": "Failed to start prediction", "details": str(e)}), 500
389
-
390
-
391
@app.route("/api/predict/status/<task_id>", methods=["GET"])
def prediction_status(task_id):
    """Report status, progress and (when finished) the result of a task.

    A task marked completed whose result is not a well-formed predictions
    payload is converted to a failed task before it is reported.
    """
    try:
        task = prediction_tasks.get(task_id)
        if not task:
            return jsonify({"error": "Task not found"}), 404

        try:
            if task.status == "completed" and task.result:
                outcome = task.result
                problem = None
                if not isinstance(outcome, dict):
                    problem = "Invalid result format"
                elif not ("predictions" in outcome
                          and isinstance(outcome["predictions"], list)):
                    problem = "Missing prediction data"
                elif any(
                    not isinstance(entry, dict)
                    or "date" not in entry
                    or "price" not in entry
                    for entry in outcome["predictions"]
                ):
                    problem = "Malformed prediction data"

                if problem is not None:
                    task.status = "failed"
                    task.result = {"error": problem}

            # The result is only exposed once the task has finished.
            finished = task.status in ["completed", "failed"]
            failure = None
            if task.status == "failed" and task.result and isinstance(task.result, dict):
                failure = task.result.get("error")

            return jsonify({
                "taskId": task_id,
                "status": task.status,
                "progress": task.progress,
                "result": task.result if finished else None,
                "error": failure,
            })
        except Exception as e:
            print(f"Error generating status response: {e}")
            return jsonify({
                "taskId": task_id,
                "status": "error",
                "progress": task.progress,
                "error": str(e),
            })
    except Exception as e:
        print(f"Critical error in prediction status: {e}")
        return jsonify({"taskId": task_id, "status": "error",
                        "error": "Server error"}), 500
436
-
437
-
438
@app.route("/api/predict/stop/<task_id>", methods=["POST"])
def stop_prediction(task_id):
    """Request cancellation of a task and report its resulting state.

    Sets the task's stop flag and, while its worker thread is alive, waits up
    to two seconds for the worker to acknowledge the stop or exit.  A task
    that has already completed or failed keeps that status.

    Returns:
        JSON with the task id, status, symbol, progress and stop flags, or a
        404 JSON error when the task id is unknown.
    """
    task = prediction_tasks.get(task_id)
    if not task:
        return jsonify({"error": "Task not found"}), 404

    # Outcomes that must not be overwritten by a late stop request.
    finished_states = ("completed", "failed")

    task.stop_requested = True

    # NOTE(review): the scraped source lost its indentation; the trailing
    # ``else`` is assumed to pair with this ``if`` (thread not running), not
    # with the polling loop — confirm against the real file.
    if task.thread and task.thread.is_alive():
        if task.status not in finished_states:
            task.status = "stopping"
        print(f"Stop requested for task {task_id} ({task.symbol})")
        deadline = time.time() + 2
        while time.time() < deadline:
            # The worker may exit through a plain stop check without ever
            # setting stop_acknowledged, so watch thread liveness as well.
            if task.stop_acknowledged or not task.thread.is_alive():
                break
            time.sleep(0.1)
        worker_done = task.stop_acknowledged or not task.thread.is_alive()
        if worker_done and task.status == "stopping":
            task.status = "stopped"
    elif task.status not in finished_states:
        task.status = "stopped"

    return jsonify({
        "taskId": task_id,
        "status": task.status,
        "symbol": task.symbol,
        "progress": task.progress,
        "stopRequested": task.stop_requested,
        "stopAcknowledged": task.stop_acknowledged,
    })
466
-
467
-
468
@app.route("/api/predict/sentiment/<symbol>", methods=["GET"])
def get_sentiment(symbol):
    """Return news-sentiment totals and a summary for *symbol*.

    Any error raised while fetching or analysing the news is reported as a
    JSON error with status 500.
    """
    try:
        news = stock_model.fetch_finnhub_news(symbol)
        _, totals = stock_model.analyze_sentiment(news)
        summary = stock_model.generate_sentiment_summary(totals, news, symbol)
        sentiment = {
            "totals": totals,
            "summary": summary,
            "period": 28,
        }
        return jsonify({"symbol": symbol, "sentiment": sentiment})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
486
-
487
-
488
# Upper bound (seconds) for each outbound probe made by /api/diagnose, so a
# slow upstream cannot hang the request indefinitely.
_DIAGNOSE_HTTP_TIMEOUT = 10


def _check_alpha_vantage():
    """Probe the Alpha Vantage daily time-series endpoint using AAPL."""
    try:
        resp = requests.get(
            "https://www.alphavantage.co/query",
            params={
                "function": "TIME_SERIES_DAILY",
                "symbol": "AAPL",
                "apikey": stock_model.ALPHAVANTAGE_API_KEY,
                "outputsize": "compact",
                "datatype": "json",
            },
            timeout=_DIAGNOSE_HTTP_TIMEOUT,
        )
        rj = resp.json()
        has_data = "Time Series (Daily)" in rj
        return {
            "status_code": resp.status_code,
            "has_data": has_data,
            # Without a time series, surface whichever message field is set.
            "error": None if has_data else (
                rj.get("Error Message") or rj.get("Note") or rj.get("Information")
            ),
        }
    except Exception as e:
        return {"error": str(e)}


def _check_finnhub():
    """Probe the Finnhub general-news endpoint."""
    try:
        resp = requests.get(
            "https://finnhub.io/api/v1/news?category=general",
            headers={"X-Finnhub-Token": stock_model.FINNHUB_API_KEY},
            timeout=_DIAGNOSE_HTTP_TIMEOUT,
        )
        return {
            "status_code": resp.status_code,
            "has_data": len(resp.json()) > 0,
            "error": None if resp.status_code == 200 else str(resp.text),
        }
    except Exception as e:
        return {"error": str(e)}


def _check_sequence_creation():
    """Exercise ``stock_model.create_sequences`` on synthetic data."""
    try:
        test_data = np.random.rand(100, 6)  # 6 features (OPT-2)
        test_scaler = MinMaxScaler()
        # Column 0 becomes a scaled 0..99 ramp.
        test_data[:, 0] = test_scaler.fit_transform(
            np.arange(100).reshape(-1, 1)).flatten()
        X, y = stock_model.create_sequences(test_data, time_step=30)
        return {
            "sequence_creation": {
                "success": len(X) > 0,
                "X_shape": str(X.shape),
                "y_shape": str(y.shape),
            }
        }
    except Exception as e:
        return {"error": str(e)}


@app.route("/api/diagnose", methods=["GET"])
def diagnose():
    """Diagnostic endpoint – checks environment, APIs and model primitives.

    Each probe records its own error, so one failing probe does not fail the
    whole report; only an unexpected error yields the 500 response.
    """
    try:
        env_info = {
            "python_version": sys.version,
            "tensorflow_version": tf.__version__,
            "numpy_version": np.__version__,
            "pandas_version": pd.__version__,
            "xgboost_version": xgb.__version__,
        }

        api_status = {
            "alpha_vantage": _check_alpha_vantage(),
            "finnhub": _check_finnhub(),
        }
        model_status = _check_sequence_creation()

        return jsonify({
            "timestamp": datetime.now().isoformat(),
            "status": "OK",
            "environment": env_info,
            "api_status": api_status,
            "model_status": model_status,
        })
    except Exception as e:
        return jsonify({"status": "ERROR", "error": str(e)}), 500
558
-
559
-
560
if __name__ == "__main__":
    # [OPT-B] The listening port comes from the PORT environment variable so
    # the same code runs unchanged on:
    #   • Render        (sets $PORT automatically, usually 10000)
    #   • Hugging Face  (expects 7860)
    #   • Local dev     (falls back to 5001)
    listen_port = int(os.environ.get("PORT", 5001))
    print(f"Starting StockBuddy API on port {listen_port}")
    app.run(host="0.0.0.0", port=listen_port)
 
 
 
1
+ """
2
+ app.py – StockBuddy Flask API
3
+ =================================
4
+ LIGHTWEIGHT CHANGES vs original:
5
+ [OPT-A] Removed the startup TF validation model (was creating & running a test
6
+ LSTM on every cold start – wastes ~10 s and ~100 MB RAM on free tier).
7
+ Replaced with a simple tf.constant() smoke-test.
8
+ [OPT-B] PORT is now read from the PORT environment variable so the server
9
+ works on Render (sets $PORT automatically) and Hugging Face Spaces
10
+ (expects port 7860) without code changes.
11
+ [OPT-C] time_step updated to 30 throughout (was 45) to match the lighter model.
12
+ All REST API routes are unchanged from the original.
13
+ """
14
+
15
+ from flask import Flask, request, jsonify
16
+ from flask_cors import CORS
17
+ import numpy as np
18
+ import pandas as pd
19
+ import os
20
+ import threading
21
+ import time
22
+ from datetime import datetime, timedelta
23
+ import json
24
+ import model as stock_model
25
+ import sys
26
+ import requests
27
+ import traceback
28
+ from sklearn.preprocessing import MinMaxScaler
29
+ from tensorflow.keras.models import Sequential
30
+ from tensorflow.keras.layers import LSTM, Dense, Dropout
31
+ from tensorflow.keras.callbacks import Callback
32
+ import tensorflow as tf
33
+ import xgboost as xgb
34
+
35
# Create the Flask application and enable CORS on it.
app = Flask(__name__)
CORS(app)


@app.route("/", methods=["GET"])
def home():
    """Liveness endpoint: report that the API process is up."""
    payload = {"status": "running", "message": "StockBuddy API is live!"}
    return jsonify(payload)
42
+
43
# [OPT-A] Lightweight TF smoke-test instead of building & running a full LSTM
def validate_tensorflow():
    """Run a minimal TensorFlow sanity check without building a model.

    Returns:
        tuple: ``(ok, message)``. ``ok`` is True when a constant tensor could
        be created and the GPU list could be queried; ``message`` describes
        the detected device mode, or the error when ``ok`` is False.
    """
    try:
        print("TensorFlow version:", tf.__version__)
        # Creating one constant is enough to prove the runtime works; model
        # construction is left to the first prediction request.
        tf.constant([1.0, 2.0, 3.0])
        gpu_devices = tf.config.list_physical_devices("GPU")
        msg = (
            f"GPU available ({len(gpu_devices)} device(s)) – running in GPU mode."
            if gpu_devices
            else "No GPU detected – running in CPU mode (expected on free tier)."
        )
        print(f"TensorFlow OK: {msg}")
        return True, msg
    except Exception as e:
        print(f"TensorFlow validation failed: {e}")
        return False, f"TensorFlow error: {e}"


# Run the smoke-test once at import time; the outcome gates /api/predict.
tf_status, tf_message = validate_tensorflow()
print(f"TensorFlow validation: {tf_message}" if tf_status
      else f"WARNING: {tf_message}")

# Registry of prediction tasks, keyed by task id.
prediction_tasks = {}
73
+
74
+
75
class PredictionTask:
    """Background job that fetches data, trains the models and forecasts one symbol.

    The work runs on a daemon thread started by :meth:`run`.  Callers poll
    ``status``, ``progress`` and ``result``.  ``status`` starts as
    ``"pending"``, becomes ``"running"`` and ends as ``"completed"``,
    ``"failed"`` or ``"stopped"``.  On failure ``result`` is
    ``{"error": <message>}``.
    """

    def __init__(self, user_id, symbol, days_ahead):
        """Store the request parameters and build a unique task id."""
        self.user_id = user_id
        self.symbol = symbol
        self.days_ahead = days_ahead
        self.progress = 0
        self.status = "pending"
        self.result = None
        self.sentiment_result = None
        self.thread = None
        self.stop_requested = False
        self.stop_acknowledged = False
        # Unique task ID: millisecond timestamp + random hex suffix
        timestamp = int(time.time() * 1000)
        random_suffix = os.urandom(4).hex()
        self.task_id = f"{user_id}_{symbol}_{timestamp}_{random_suffix}"

    def run(self):
        """Start the prediction on a daemon thread and return the task id."""
        self.thread = threading.Thread(target=self._run_prediction, daemon=True)
        self.thread.start()
        return self.task_id

    def is_stop_requested(self):
        """Callback for model training loops to poll the stop flag.

        The first call after a stop request also records the acknowledgement
        and marks the task as stopped.
        """
        if self.stop_requested and not self.stop_acknowledged:
            self.stop_acknowledged = True
            self.status = "stopped"
            return True
        return self.stop_requested

    def _fail(self, message):
        """Mark the task as failed with *message* as the reported error."""
        self.status = "failed"
        self.result = {"error": message}

    def _stop_if_requested(self):
        """Mark the task stopped when a stop was requested; return True to abort."""
        if self.stop_requested:
            self.status = "stopped"
            return True
        return False

    @staticmethod
    def _next_trading_days(last_date, count):
        """Return the next *count* weekdays (Mon-Fri) strictly after *last_date*.

        Only weekends are skipped; market holidays are not taken into account.
        """
        days = []
        current = last_date
        while len(days) < count:
            current = current + timedelta(days=1)
            if current.weekday() < 5:
                days.append(current)
        return days

    def _run_prediction(self):
        """Thread body: run the whole pipeline, recording progress and outcome."""
        try:
            print(f"Starting prediction for {self.symbol} (task: {self.task_id})")
            self.status = "running"
            self.progress = 10

            # ── Fetch historical data ────────────────────────────────────────
            print(f"Fetching historical data for {self.symbol}...")
            try:
                data = stock_model.fetch_stock_data(self.symbol, outputsize="compact")
            except Exception as e:
                error_msg = str(e)
                print(f"\n[ERROR] {error_msg}\n")
                self._fail(error_msg)
                return

            # Test for None before calling len(): otherwise the TypeError from
            # len(None) would replace this clearer message.
            if data is None:
                self._fail(f"Could not fetch data for {self.symbol}")
                return
            print(f"Fetched {len(data)} rows for {self.symbol}")

            if self._stop_if_requested():
                return

            if len(data) < 60:
                self._fail(f"Insufficient data for {self.symbol} "
                           f"(got {len(data)}, need ≥60)")
                return

            # ── Extract last actual close ────────────────────────────────────
            try:
                if isinstance(data, pd.DataFrame) and "Close" in data.columns:
                    last_actual_close = float(data["Close"].iloc[-1])
                else:
                    last_actual_close = float(data.iloc[-1, 0])
                last_date = data.index[-1]
                print(f"Latest close for {self.symbol}: "
                      f"${last_actual_close:.2f} on {last_date.strftime('%Y-%m-%d')}")
            except Exception as e:
                self._fail(f"Error reading price data: {e}")
                return

            # The close is the denominator of the percentage change below; a
            # zero, negative or NaN value cannot produce a meaningful result.
            if not last_actual_close > 0:
                self._fail(f"Invalid last close price for {self.symbol}: "
                           f"{last_actual_close}")
                return

            self.progress = 20
            if self._stop_if_requested():
                return

            # ── Sentiment analysis (best effort, never fatal) ────────────────
            try:
                print(f"Fetching news for {self.symbol}...")
                headlines = stock_model.fetch_finnhub_news(self.symbol)
                print(f"Got {len(headlines)} headlines")
                self.progress = 30
                if self._stop_if_requested():
                    return

                sentiment_results, sentiment_totals = \
                    stock_model.analyze_sentiment(headlines)
                sentiment_summary = stock_model.generate_sentiment_summary(
                    sentiment_totals, headlines, self.symbol)
                self.sentiment_result = {
                    "totals": sentiment_totals,
                    "summary": sentiment_summary,
                }
            except Exception as e:
                print(f"Sentiment error (non-fatal): {e}")
                self.sentiment_result = {
                    "totals": {"positive": 0, "negative": 0, "neutral": 0},
                    "summary": f"Unable to analyse sentiment: {e}",
                }

            self.progress = 40
            if self._stop_if_requested():
                return

            # ── Preprocess data ──────────────────────────────────────────────
            try:
                print("Preprocessing data...")
                scaled_data, scaler = stock_model.preprocess_data(data)

                # [OPT-C] time_step 45 → 30
                time_step = 30
                X, y = stock_model.create_sequences(scaled_data, time_step)
                print(f"Sequences: X={X.shape}, y={y.shape}")
            except Exception as e:
                self._fail(f"Preprocessing failed: {e}")
                return

            if len(X) == 0:
                self._fail(f"Could not create training sequences for {self.symbol}")
                return

            self.progress = 50
            if self._stop_if_requested():
                return

            # ── Train LSTM ───────────────────────────────────────────────────
            try:
                train_size = int(len(X) * 0.8)
                if train_size == 0:
                    self._fail("Not enough data to split for training")
                    return

                X_train, y_train = X[:train_size], y[:train_size]
                self.progress = 55
                print(f"Training LSTM with {len(X_train)} samples...")
                lstm_model = stock_model.train_lstm(
                    X_train, y_train, time_step, self.is_stop_requested)
            except Exception as e:
                self._fail(f"LSTM training failed: {e}")
                return

            if self._stop_if_requested():
                return

            self.progress = 75

            # ── Train XGBoost on residuals (errors here are non-fatal) ───────
            try:
                print("Calculating residuals for XGBoost...")
                lstm_preds = lstm_model.predict(X_train, verbose=0).flatten()
                residuals = y_train - lstm_preds
                xgb_model = stock_model.train_xgboost(
                    X_train.reshape(X_train.shape[0], -1),
                    residuals,
                    self.is_stop_requested,
                )
                if self.stop_requested or xgb_model is None:
                    self.status = "stopped"
                    return
            except Exception as e:
                print(f"XGBoost training error (non-fatal): {e}")
                xgb_model = None

            self.progress = 90
            if self._stop_if_requested():
                return

            # ── Generate predictions ─────────────────────────────────────────
            try:
                print(f"Generating {self.days_ahead}-day predictions...")
                predictions = stock_model.predict_stock_price(
                    lstm_model, xgb_model, scaled_data, scaler,
                    time_step, self.days_ahead, self.is_stop_requested,
                )
                if self.stop_requested or predictions is None:
                    self.status = "stopped"
                    return
            except Exception as e:
                self._fail(f"Prediction generation failed: {e}")
                return

            self.progress = 95
            if self._stop_if_requested():
                return

            # ── Build future trading-day dates ───────────────────────────────
            # One weekday per predicted value, starting after the last close.
            future_dates = self._next_trading_days(last_date, len(predictions))

            if self._stop_if_requested():
                return

            # ── Assemble result payload ──────────────────────────────────────
            prediction_data = []
            for day, row in zip(future_dates, predictions):
                predicted_price = float(row[0])
                percent_change = (
                    (predicted_price - last_actual_close) / last_actual_close * 100
                )
                prediction_data.append({
                    "date": day.strftime("%Y-%m-%d"),
                    "price": round(predicted_price, 2),
                    "change": round(percent_change, 2),
                })

            self.result = {
                "symbol": self.symbol,
                "lastActualClose": {
                    "date": last_date.strftime("%Y-%m-%d"),
                    "price": round(last_actual_close, 2),
                },
                "predictions": prediction_data,
                "sentiment": self.sentiment_result,
                "tableDisplay": True,
            }
            self.progress = 100
            self.status = "completed"
            print(f"Prediction complete for {self.symbol}")

        except Exception as e:
            # Last-resort handler so the thread never dies without a status.
            error_msg = str(e)
            self._fail(error_msg)
            print(f"\n[ERROR] {error_msg}\n")
            traceback.print_exc()
342
+
343
+
344
+ # =============================================================================
345
+ # REST API ROUTES
346
+ # (all routes are identical to the original – no frontend changes needed)
347
+ # =============================================================================
348
+
349
@app.route("/api/predict", methods=["POST"])
def start_prediction():
    """Validate the request body and launch a background prediction task.

    Expects a JSON object with ``userId``, ``symbol`` and an optional
    ``daysAhead`` (default 5).

    Returns:
        JSON with ``taskId``, ``status`` and ``message`` on success; a JSON
        ``error`` with status 400 for invalid input, 503 when the TensorFlow
        startup check failed, or 500 for unexpected errors.
    """
    try:
        # silent=True returns None instead of raising on a missing, mistyped
        # or malformed JSON body, so such requests get a 400 rather than
        # falling through to the generic 500 handler below.
        data = request.get_json(silent=True)
        print(f"POST /api/predict body={data}")

        if not data or not isinstance(data, dict):
            return jsonify({"error": "Invalid or missing request body"}), 400

        user_id = data.get("userId")
        symbol = data.get("symbol")

        if not user_id or not symbol:
            return jsonify({"error": "Missing required parameters (userId or symbol)"}), 400

        if not isinstance(symbol, str) or len(symbol) > 10:
            return jsonify({"error": f"Invalid symbol format: {symbol}"}), 400

        try:
            days_ahead = int(data.get("daysAhead", 5))
        except (TypeError, ValueError):
            return jsonify({"error": "daysAhead must be an integer"}), 400
        if days_ahead < 1:
            return jsonify({"error": "daysAhead must be at least 1"}), 400

        if not tf_status:
            return jsonify({
                "error": f"Prediction service unavailable: {tf_message}",
                "tf_status": tf_message,
            }), 503

        task = PredictionTask(user_id, symbol, days_ahead)
        # Register before starting the thread so the task is always
        # discoverable by the status/stop endpoints once it is running.
        prediction_tasks[task.task_id] = task
        task_id = task.run()

        return jsonify({
            "taskId": task_id,
            "status": "pending",
            "message": f"Prediction started for {symbol}",
        })
    except ValueError as e:
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        print(f"Critical error starting prediction: {e}")
        traceback.print_exc()
        return jsonify({"error": "Failed to start prediction", "details": str(e)}), 500
389
+
390
+
391
@app.route("/api/predict/status/<task_id>", methods=["GET"])
def prediction_status(task_id):
    """Report status, progress and (when finished) the result of a task.

    A task that claims to be completed is sanity-checked first: its result
    must be a dict holding a ``predictions`` list whose entries are dicts
    with ``date`` and ``price`` keys. A result that fails the check flips the
    task to ``failed`` with an explanatory error before the response is built.

    Returns 404 for an unknown ``task_id`` and 500 on an unexpected error.
    """
    try:
        task = prediction_tasks.get(task_id)
        if not task:
            return jsonify({"error": "Task not found"}), 404

        try:
            if task.status == "completed" and task.result:
                # Work out the first reason (if any) the result is unusable.
                problem = None
                if not isinstance(task.result, dict):
                    problem = "Invalid result format"
                elif not ("predictions" in task.result
                          and isinstance(task.result["predictions"], list)):
                    problem = "Missing prediction data"
                elif any(
                    not isinstance(entry, dict)
                    or "date" not in entry
                    or "price" not in entry
                    for entry in task.result["predictions"]
                ):
                    problem = "Malformed prediction data"

                if problem is not None:
                    task.status = "failed"
                    task.result = {"error": problem}

            finished = task.status in ["completed", "failed"]
            has_error_dict = (
                task.status == "failed"
                and task.result
                and isinstance(task.result, dict)
            )
            return jsonify({
                "taskId": task_id,
                "status": task.status,
                "progress": task.progress,
                "result": task.result if finished else None,
                "error": task.result.get("error") if has_error_dict else None,
            })
        except Exception as e:
            # Building the payload failed; still answer with what is known.
            print(f"Error generating status response: {e}")
            return jsonify({
                "taskId": task_id,
                "status": "error",
                "progress": task.progress,
                "error": str(e),
            })
    except Exception as e:
        print(f"Critical error in prediction status: {e}")
        return jsonify({"taskId": task_id, "status": "error",
                        "error": "Server error"}), 500
436
+
437
+
438
@app.route("/api/predict/stop/<task_id>", methods=["POST"])
def stop_prediction(task_id):
    """Ask a prediction task to stop and report its resulting state.

    Sets the task's ``stop_requested`` flag. If the worker thread is still
    alive the task is marked ``stopping`` and this handler polls for up to
    two seconds for ``stop_acknowledged``; otherwise the task is marked
    ``stopped`` immediately. Returns 404 for an unknown ``task_id``.
    """
    task = prediction_tasks.get(task_id)
    if not task:
        return jsonify({"error": "Task not found"}), 404

    task.stop_requested = True

    # NOTE(review): the source's indentation was lost; the `else` is taken to
    # pair with this liveness check (not with the polling loop) - confirm.
    if not (task.thread and task.thread.is_alive()):
        task.status = "stopped"
    else:
        task.status = "stopping"
        print(f"Stop requested for task {task_id} ({task.symbol})")
        wait_began = time.time()
        # Poll every 100 ms, for at most 2 s, for the worker's acknowledgement.
        while time.time() - wait_began < 2:
            if task.stop_acknowledged:
                task.status = "stopped"
                break
            time.sleep(0.1)

    payload = {
        "taskId": task_id,
        "status": task.status,
        "symbol": task.symbol,
        "progress": task.progress,
        "stopRequested": task.stop_requested,
        "stopAcknowledged": task.stop_acknowledged,
    }
    return jsonify(payload)
466
+
467
+
468
@app.route("/api/predict/sentiment/<symbol>", methods=["GET"])
def get_sentiment(symbol):
    """Return news-sentiment totals and a generated summary for ``symbol``.

    Headlines are fetched, scored and summarised through the ``stock_model``
    helpers; any exception they raise is reported as a 500 whose ``error``
    field is the exception text.
    """
    try:
        news = stock_model.fetch_finnhub_news(symbol)
        # Only the aggregate totals are used here; per-headline scores are not.
        _per_headline, totals = stock_model.analyze_sentiment(news)
        summary = stock_model.generate_sentiment_summary(totals, news, symbol)

        sentiment_block = {
            "totals": totals,
            "summary": summary,
            "period": 28,
        }
        return jsonify({"symbol": symbol, "sentiment": sentiment_block})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
486
+
487
+
488
# Upper bound (seconds) for each outbound probe made by /api/diagnose. Without
# a timeout `requests` waits indefinitely, so a stalled upstream would hang
# the diagnostic request (and the worker serving it).
_DIAGNOSE_HTTP_TIMEOUT = 10


def _probe_alpha_vantage():
    """Request compact AAPL daily data from Alpha Vantage and summarise the reply."""
    try:
        resp = requests.get(
            "https://www.alphavantage.co/query",
            params={
                "function": "TIME_SERIES_DAILY",
                "symbol": "AAPL",
                "apikey": stock_model.ALPHAVANTAGE_API_KEY,
                "outputsize": "compact",
                "datatype": "json",
            },
            timeout=_DIAGNOSE_HTTP_TIMEOUT,
        )
        rj = resp.json()
        has_series = "Time Series (Daily)" in rj
        return {
            "status_code": resp.status_code,
            "has_data": has_series,
            # When the series is absent, surface whichever message key the
            # response carries instead.
            "error": None if has_series else (
                rj.get("Error Message") or rj.get("Note") or rj.get("Information")
            ),
        }
    except Exception as e:
        return {"error": str(e)}


def _probe_finnhub():
    """Request Finnhub general news and summarise the reply."""
    try:
        resp = requests.get(
            "https://finnhub.io/api/v1/news?category=general",
            headers={"X-Finnhub-Token": stock_model.FINNHUB_API_KEY},
            timeout=_DIAGNOSE_HTTP_TIMEOUT,
        )
        return {
            "status_code": resp.status_code,
            "has_data": len(resp.json()) > 0,
            "error": None if resp.status_code == 200 else str(resp.text),
        }
    except Exception as e:
        return {"error": str(e)}


def _probe_sequence_creation():
    """Run ``stock_model.create_sequences`` on synthetic data and report the shapes."""
    try:
        test_data = np.random.rand(100, 6)  # 6 features (OPT-2)
        # Column 0 is replaced with a min-max-scaled 0..99 ramp.
        test_data[:, 0] = MinMaxScaler().fit_transform(
            np.arange(100).reshape(-1, 1)).flatten()
        X, y = stock_model.create_sequences(test_data, time_step=30)
        return {
            "sequence_creation": {
                "success": len(X) > 0,
                "X_shape": str(X.shape),
                "y_shape": str(y.shape),
            }
        }
    except Exception as e:
        return {"error": str(e)}


@app.route("/api/diagnose", methods=["GET"])
def diagnose():
    """Diagnostic endpoint – checks environment, APIs and model primitives.

    Each probe catches its own exceptions and reports them inside its section,
    so one failing dependency does not hide the others. Returns 500 with
    ``status: "ERROR"`` only if assembling the report itself fails.
    """
    try:
        env_info = {
            "python_version": sys.version,
            "tensorflow_version": tf.__version__,
            "numpy_version": np.__version__,
            "pandas_version": pd.__version__,
            "xgboost_version": xgb.__version__,
        }

        api_status = {
            "alpha_vantage": _probe_alpha_vantage(),
            "finnhub": _probe_finnhub(),
        }
        model_status = _probe_sequence_creation()

        return jsonify({
            "timestamp": datetime.now().isoformat(),
            "status": "OK",
            "environment": env_info,
            "api_status": api_status,
            "model_status": model_status,
        })
    except Exception as e:
        return jsonify({"status": "ERROR", "error": str(e)}), 500
560
+
561
+
562
if __name__ == "__main__":
    # [OPT-B] The listening port is taken from the PORT environment variable,
    # so the same code runs unchanged on:
    #   • Render       (sets $PORT automatically, usually 10000)
    #   • Hugging Face (expects 7860)
    #   • Local dev    (falls back to 5001)
    port = int(os.getenv("PORT", 5001))
    print(f"Starting StockBuddy API on port {port}")
    # Bind on all interfaces.
    app.run(port=port, host="0.0.0.0")