# NOTE: "Spaces: Running / Running" banner below is Hugging Face Spaces UI
# residue from the page this file was scraped from, not program content.
"""
app.py – StockBuddy Flask API
=================================
LIGHTWEIGHT CHANGES vs original:
[OPT-A] Removed the startup TF validation model (was creating & running a test
        LSTM on every cold start – wastes ~10 s and ~100 MB RAM on free tier).
        Replaced with a simple tf.constant() smoke-test.
[OPT-B] PORT is now read from the PORT environment variable so the server
        works on Render (sets $PORT automatically) and Hugging Face Spaces
        (expects port 7860) without code changes.
[OPT-C] time_step updated to 30 throughout (was 45) to match the lighter model.
All REST API routes are unchanged from the original.
"""
# --- Standard library -------------------------------------------------------
import json
import os
import sys
import threading
import time
import traceback
from datetime import datetime, timedelta

# --- Third-party ------------------------------------------------------------
import numpy as np
import pandas as pd
import requests
import tensorflow as tf
import xgboost as xgb
from flask import Flask, request, jsonify
from flask_cors import CORS
from sklearn.preprocessing import MinMaxScaler
# NOTE(review): Sequential / layers / Callback look unused in this file —
# presumably kept for model.py compatibility or historical reasons; confirm
# before removing.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.callbacks import Callback

# --- Local ------------------------------------------------------------------
import model as stock_model

app = Flask(__name__)
CORS(app)  # allow the Spring/JS frontend on another origin to call this API
def read_json_dict_from_request():
    """
    Parse the JSON request body without relying on request.json alone.

    Hugging Face / some proxies or clients occasionally send bodies that trip
    Flask's default parser (BOM, odd encoding, or malformed JSON).  Returns
    ``(dict, None)`` on success or ``(None, error_message)`` on failure, and
    logs a short preview server-side for debugging.
    """
    raw_body = request.get_data(cache=True)
    if not raw_body:
        return None, "Empty request body"

    mime = (request.mimetype or "").lower()
    if mime and mime != "application/json":
        # Spring/Java normally sends application/json; log odd types for debugging
        print(f"POST /api/predict unexpected mimetype={request.mimetype!r} content_type={request.content_type!r}")

    # utf-8-sig transparently drops a leading BOM if one is present.
    try:
        decoded = raw_body.decode("utf-8-sig")
    except UnicodeDecodeError as e:
        return None, f"Body is not valid UTF-8: {e}"

    decoded = decoded.strip()
    if not decoded:
        return None, "Empty body after decode"

    try:
        parsed = json.loads(decoded)
    except json.JSONDecodeError as e:
        preview = decoded[:400].replace("\r", " ").replace("\n", " ")
        print(
            f"POST /api/predict JSONDecodeError: {e.msg} (line {e.lineno}, col {e.colno}); "
            f"len={len(decoded)} preview={preview!r}"
        )
        return None, (
            f"Invalid JSON ({e.msg} at line {e.lineno}, column {e.colno}). "
            "Send strict JSON with double-quoted keys, e.g. "
            '{"userId":"...","symbol":"AAPL","daysAhead":5}. '
            f"Body preview: {preview!r}"
        )

    if not isinstance(parsed, dict):
        return None, f"JSON root must be an object, got {type(parsed).__name__}"
    return parsed, None
def home():
    """Liveness probe: report that the API is up.

    NOTE(review): no @app.route decorator is visible in this extraction —
    presumably this was registered at "/"; confirm against the original file.
    """
    payload = {"status": "running", "message": "StockBuddy API is live!"}
    return jsonify(payload)
# [OPT-A] Lightweight TF smoke-test instead of building & running a full LSTM
def validate_tensorflow():
    """Quick TensorFlow sanity-check (no model created, no GPU required).

    Returns (ok, message).  A tiny constant op is enough to confirm TF
    imports and the runtime works; real model creation is deferred to the
    first prediction request so cold-start stays fast on free-tier hosts.
    """
    try:
        print("TensorFlow version:", tf.__version__)
        _ = tf.constant([1.0, 2.0, 3.0])
        gpus = tf.config.list_physical_devices("GPU")
        msg = (
            f"GPU available ({len(gpus)} device(s)) – running in GPU mode."
            if gpus
            else "No GPU detected – running in CPU mode (expected on free tier)."
        )
        print(f"TensorFlow OK: {msg}")
        return True, msg
    except Exception as e:
        print(f"TensorFlow validation failed: {e}")
        return False, f"TensorFlow error: {e}"
# Run the smoke-test once at import/startup; routes consult tf_status later.
tf_status, tf_message = validate_tensorflow()
if tf_status:
    print(f"TensorFlow validation: {tf_message}")
else:
    print(f"WARNING: {tf_message}")

# Dictionary to store running prediction tasks, keyed by task_id.
prediction_tasks = {}
class PredictionTask:
    """One background stock-prediction job.

    Owns a daemon thread that fetches historical prices, runs news-sentiment
    analysis, trains an LSTM (plus an optional XGBoost residual model) and
    writes the final JSON-ready payload into ``self.result``.  The status
    route polls ``status``/``progress``; the stop route sets
    ``stop_requested`` for cooperative cancellation, which the worker
    acknowledges via ``stop_acknowledged``.
    """

    def __init__(self, user_id: str, symbol: str, days_ahead: int):
        self.user_id = user_id
        self.symbol = symbol
        self.days_ahead = days_ahead
        self.progress = 0              # coarse 0-100 milestones, set by the worker
        self.status = "pending"        # pending -> running -> completed | failed | stopped
        self.result = None             # final payload dict, or {"error": ...} on failure
        self.sentiment_result = None   # {"totals": ..., "summary": ...}
        self.thread = None             # set by run()
        self.stop_requested = False    # set by the stop route (other thread)
        self.stop_acknowledged = False # set once the worker observes the stop
        # Unique task ID: millisecond timestamp + random hex suffix
        timestamp = int(time.time() * 1000)
        random_suffix = os.urandom(4).hex()
        self.task_id = f"{user_id}_{symbol}_{timestamp}_{random_suffix}"

    def run(self) -> str:
        """Start the worker on a daemon thread and return the task id."""
        self.thread = threading.Thread(target=self._run_prediction)
        # Daemon so a hung training job can never block process shutdown.
        self.thread.daemon = True
        self.thread.start()
        return self.task_id

    def is_stop_requested(self) -> bool:
        """Callback for model training loops to poll stop flag."""
        # First observation of the stop flips status to "stopped" exactly once.
        if self.stop_requested and not self.stop_acknowledged:
            self.stop_acknowledged = True
            self.status = "stopped"
            return True
        return self.stop_requested

    def _run_prediction(self):
        """Worker body: the full fetch -> sentiment -> train -> predict pipeline.

        Runs on the task's daemon thread.  Never raises: every failure path
        sets ``status`` and ``result`` instead.  Checks ``stop_requested``
        between stages so a cancel takes effect quickly.
        """
        try:
            print(f"Starting prediction for {self.symbol} (task: {self.task_id})")
            self.status = "running"
            self.progress = 10

            # ── Fetch historical data ────────────────────────────────────────
            print(f"Fetching historical data for {self.symbol}...")
            try:
                data = stock_model.fetch_stock_data(self.symbol, outputsize="compact")
                print(f"Fetched {len(data)} rows for {self.symbol}")
            except Exception as e:
                print(f"Data fetch error: {e}")
                self.status = "failed"
                self.result = {"error": f"Could not fetch data for {self.symbol}: {e}"}
                return
            if data is None:
                self.status = "failed"
                self.result = {"error": f"Could not fetch data for {self.symbol}"}
                return
            if self.stop_requested:
                self.status = "stopped"; return
            # Need at least 60 rows so the 30-step sequences leave a usable split.
            if len(data) < 60:
                self.status = "failed"
                self.result = {"error": f"Insufficient data for {self.symbol} "
                                        f"(got {len(data)}, need ≥60)"}
                return

            # ── Extract last actual close ────────────────────────────────────
            try:
                if isinstance(data, pd.DataFrame) and "Close" in data.columns:
                    last_actual_close = float(data["Close"].iloc[-1])
                    last_date = data.index[-1]
                else:
                    # Fallback: assume the first column holds the close price.
                    # NOTE(review): verify against fetch_stock_data's actual layout.
                    last_actual_close = float(data.iloc[-1, 0])
                    last_date = data.index[-1]
                print(f"Latest close for {self.symbol}: "
                      f"${last_actual_close:.2f} on {last_date.strftime('%Y-%m-%d')}")
            except Exception as e:
                self.status = "failed"
                self.result = {"error": f"Error reading price data: {e}"}
                return
            self.progress = 20
            if self.stop_requested:
                self.status = "stopped"; return

            # ── Sentiment analysis (best-effort: failure is non-fatal) ───────
            try:
                print(f"Fetching news for {self.symbol}...")
                headlines = stock_model.fetch_finnhub_news(self.symbol)
                print(f"Got {len(headlines)} headlines")
                self.progress = 30
                if self.stop_requested:
                    self.status = "stopped"; return
                sentiment_results, sentiment_totals = \
                    stock_model.analyze_sentiment(headlines)
                sentiment_summary = stock_model.generate_sentiment_summary(
                    sentiment_totals, headlines, self.symbol)
                self.sentiment_result = {
                    "totals": sentiment_totals,
                    "summary": sentiment_summary,
                }
            except Exception as e:
                print(f"Sentiment error (non-fatal): {e}")
                self.sentiment_result = {
                    "totals": {"positive": 0, "negative": 0, "neutral": 0},
                    "summary": f"Unable to analyse sentiment: {e}",
                }
            self.progress = 40
            if self.stop_requested:
                self.status = "stopped"; return

            # ── Preprocess data ──────────────────────────────────────────────
            try:
                print("Preprocessing data...")
                scaled_data, scaler = stock_model.preprocess_data(data)
                # [OPT-C] time_step 45 → 30
                time_step = 30
                X, y = stock_model.create_sequences(scaled_data, time_step)
                print(f"Sequences: X={X.shape}, y={y.shape}")
            except Exception as e:
                self.status = "failed"
                self.result = {"error": f"Preprocessing failed: {e}"}
                return
            if len(X) == 0:
                self.status = "failed"
                self.result = {"error": f"Could not create training sequences for {self.symbol}"}
                return
            self.progress = 50
            if self.stop_requested:
                self.status = "stopped"; return

            # ── Train LSTM on the first 80% of sequences ─────────────────────
            try:
                train_size = int(len(X) * 0.8)
                if train_size == 0:
                    self.status = "failed"
                    self.result = {"error": "Not enough data to split for training"}
                    return
                X_train, y_train = X[:train_size], y[:train_size]
                self.progress = 55
                print(f"Training LSTM with {len(X_train)} samples...")
                # is_stop_requested is passed so the training loop can abort early.
                lstm_model = stock_model.train_lstm(
                    X_train, y_train, time_step, self.is_stop_requested)
            except Exception as e:
                self.status = "failed"
                self.result = {"error": f"LSTM training failed: {e}"}
                return
            if self.stop_requested:
                self.status = "stopped"; return
            self.progress = 75
            if self.stop_requested:
                self.status = "stopped"; return

            # ── Train XGBoost on LSTM residuals (optional booster) ───────────
            try:
                print("Calculating residuals for XGBoost...")
                lstm_preds = lstm_model.predict(X_train, verbose=0).flatten()
                residuals = y_train - lstm_preds
                # Sequences are flattened to 2-D since XGBoost takes tabular input.
                xgb_model = stock_model.train_xgboost(
                    X_train.reshape(X_train.shape[0], -1),
                    residuals,
                    self.is_stop_requested,
                )
                # NOTE(review): a None return here marks the task "stopped" even
                # when no stop was requested — presumably train_xgboost only
                # returns None on stop; confirm in model.py.
                if self.stop_requested or xgb_model is None:
                    self.status = "stopped"; return
            except Exception as e:
                # XGBoost is an enhancement; fall back to LSTM-only predictions.
                print(f"XGBoost training error (non-fatal): {e}")
                xgb_model = None
            self.progress = 90
            if self.stop_requested:
                self.status = "stopped"; return

            # ── Generate predictions ─────────────────────────────────────────
            try:
                print(f"Generating {self.days_ahead}-day predictions...")
                predictions = stock_model.predict_stock_price(
                    lstm_model, xgb_model, scaled_data, scaler,
                    time_step, self.days_ahead, self.is_stop_requested,
                )
                if self.stop_requested or predictions is None:
                    self.status = "stopped"; return
            except Exception as e:
                self.status = "failed"
                self.result = {"error": f"Prediction generation failed: {e}"}
                return
            self.progress = 95
            if self.stop_requested:
                self.status = "stopped"; return

            # ── Build future trading-day dates (skip Sat/Sun, weekday > 4) ───
            future_dates = []
            for i in range(1, self.days_ahead + 1):
                if self.stop_requested:
                    break
                next_date = last_date + timedelta(days=i)
                while next_date.weekday() > 4:
                    next_date += timedelta(days=1)
                future_dates.append(next_date)
            if self.stop_requested:
                self.status = "stopped"; return

            # Deduplicate dates (weekend-skipping can map two offsets to one day).
            unique_future_dates = []
            seen_dates = set()
            for date in future_dates:
                ds = date.strftime("%Y-%m-%d")
                if ds not in seen_dates:
                    seen_dates.add(ds)
                    unique_future_dates.append(date)
            # Pad if needed so every prediction row gets a date.
            # NOTE(review): if unique_future_dates were ever empty (days_ahead<=0)
            # the [-1] below would raise, and if the computed day were already in
            # seen_dates this loop would not advance — both look unreachable for
            # valid inputs, but worth a guard; confirm.
            while (len(unique_future_dates) < len(predictions)
                   and not self.stop_requested):
                next_date = unique_future_dates[-1] + timedelta(days=1)
                while next_date.weekday() > 4:
                    next_date += timedelta(days=1)
                ds = next_date.strftime("%Y-%m-%d")
                if ds not in seen_dates:
                    unique_future_dates.append(next_date)
                    seen_dates.add(ds)
            if self.stop_requested:
                self.status = "stopped"; return
            unique_future_dates = unique_future_dates[: len(predictions)]

            # ── Assemble result payload ──────────────────────────────────────
            prediction_data = []
            for i in range(min(len(unique_future_dates), len(predictions))):
                predicted_price = float(predictions[i][0])
                percent_change = (
                    (predicted_price - last_actual_close) / last_actual_close * 100
                )
                prediction_data.append({
                    "date": unique_future_dates[i].strftime("%Y-%m-%d"),
                    "price": round(predicted_price, 2),
                    "change": round(percent_change, 2),
                })
            self.result = {
                "symbol": self.symbol,
                "lastActualClose": {
                    "date": last_date.strftime("%Y-%m-%d"),
                    "price": round(last_actual_close, 2),
                },
                "predictions": prediction_data,
                "sentiment": self.sentiment_result,
                "tableDisplay": True,
            }
            self.progress = 100
            self.status = "completed"
            print(f"Prediction complete for {self.symbol}")
        except Exception as e:
            # Catch-all so the daemon thread never dies with an unrecorded error.
            self.status = "failed"
            self.result = {"error": str(e)}
            print(f"Prediction task error: {e}")
            traceback.print_exc()
# =============================================================================
# REST API ROUTES
# (all routes are identical to the original – no frontend changes needed)
# NOTE(review): the @app.route decorators are not visible in this extraction;
# confirm route registration against the original file.
# =============================================================================
def start_prediction():
    """
    Start a background prediction task.

    Expects a JSON body like {"userId": "...", "symbol": "AAPL", "daysAhead": 5}.
    Returns 400 on bad input, 503 when TensorFlow failed its startup check,
    otherwise {"taskId", "status", "message"} for the caller to poll.

    NOTE(review): no @app.route decorator is visible in this extraction; the
    log lines suggest POST /api/predict — confirm against the original file.
    """
    try:
        data, parse_err = read_json_dict_from_request()
        if parse_err:
            return jsonify({"error": "Invalid or missing JSON body", "details": parse_err}), 400
        print(f"POST /api/predict body keys={list(data.keys())}")
        user_id = data.get("userId")
        symbol = data.get("symbol")
        days_ahead = int(data.get("daysAhead", 5))
        if not user_id or not symbol:
            return jsonify({"error": "Missing required parameters (userId or symbol)"}), 400
        if not isinstance(symbol, str) or len(symbol) > 10:
            return jsonify({"error": f"Invalid symbol format: {symbol}"}), 400
        # [FIX] daysAhead was unbounded: negative or huge values were accepted
        # and passed straight into the (expensive) prediction pipeline.
        if not 1 <= days_ahead <= 30:
            return jsonify({"error": f"daysAhead must be between 1 and 30, got {days_ahead}"}), 400
        if not tf_status:
            # TF failed its startup smoke-test; predictions cannot run.
            return jsonify({
                "error": f"Prediction service unavailable: {tf_message}",
                "tf_status": tf_message,
            }), 503
        task = PredictionTask(user_id, symbol, days_ahead)
        task_id = task.run()
        prediction_tasks[task_id] = task
        return jsonify({
            "taskId": task_id,
            "status": "pending",
            "message": f"Prediction started for {symbol}",
        })
    except ValueError as e:
        # int(daysAhead) on a non-numeric value lands here.
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        print(f"Critical error starting prediction: {e}")
        traceback.print_exc()
        return jsonify({"error": "Failed to start prediction", "details": str(e)}), 500
def prediction_status(task_id):
    """
    Report status/progress for a running task; include the result only once
    the task has completed.

    Before returning a "completed" result, defensively validate its shape so
    the frontend never receives a malformed payload — a bad payload flips the
    task to "failed" with a descriptive error instead.

    NOTE(review): no @app.route decorator is visible in this extraction;
    confirm the route path (likely GET with <task_id>) against the original.
    """
    try:
        task = prediction_tasks.get(task_id)
        if not task:
            return jsonify({"error": "Task not found"}), 404
        try:
            if task.status == "completed" and task.result:
                if isinstance(task.result, dict):
                    if "predictions" in task.result and isinstance(
                            task.result["predictions"], list):
                        # Every row must carry at least "date" and "price".
                        for pred in task.result["predictions"]:
                            if (not isinstance(pred, dict)
                                    or "date" not in pred
                                    or "price" not in pred):
                                task.status = "failed"
                                task.result = {"error": "Malformed prediction data"}
                                break
                    else:
                        task.status = "failed"
                        task.result = {"error": "Missing prediction data"}
                else:
                    task.status = "failed"
                    task.result = {"error": "Invalid result format"}
            return jsonify({
                "taskId": task_id,
                "status": task.status,
                "progress": task.progress,
                # Result is withheld until (and unless) the task completed.
                "result": task.result if task.status == "completed" else None,
            })
        except Exception as e:
            # Serialization problems degrade to a status="error" body, not a 500.
            print(f"Error generating status response: {e}")
            return jsonify({
                "taskId": task_id,
                "status": "error",
                "progress": task.progress,
                "error": str(e),
            })
    except Exception as e:
        print(f"Critical error in prediction status: {e}")
        return jsonify({"taskId": task_id, "status": "error",
                        "error": "Server error"}), 500
def stop_prediction(task_id):
    """Request cooperative cancellation of a running task.

    Sets the task's stop flag, then waits up to ~2 s for the worker thread to
    acknowledge it.  If the worker is no longer alive the task is marked
    stopped immediately.  Always returns the current stop bookkeeping so the
    caller can see whether the worker has acknowledged yet.
    """
    task = prediction_tasks.get(task_id)
    if not task:
        return jsonify({"error": "Task not found"}), 404

    task.stop_requested = True
    worker_alive = task.thread and task.thread.is_alive()
    if worker_alive:
        task.status = "stopping"
        print(f"Stop requested for task {task_id} ({task.symbol})")
        deadline = time.time() + 2
        # Poll briefly; if the worker doesn't acknowledge in time the status
        # simply stays "stopping" and the worker will flip it later.
        while time.time() < deadline and not task.stop_acknowledged:
            time.sleep(0.1)
        if task.stop_acknowledged:
            task.status = "stopped"
    else:
        task.status = "stopped"

    return jsonify({
        "taskId": task_id,
        "status": task.status,
        "symbol": task.symbol,
        "progress": task.progress,
        "stopRequested": task.stop_requested,
        "stopAcknowledged": task.stop_acknowledged,
    })
def get_sentiment(symbol):
    """
    Run news-sentiment analysis for *symbol* and return totals plus a
    human-readable summary.  Any failure is reported as a 500 with the error
    message in the body.

    NOTE(review): no @app.route decorator is visible in this extraction;
    confirm the route path against the original file.
    """
    try:
        headlines = stock_model.fetch_finnhub_news(symbol)
        # [FIX] the per-headline results were computed but never used; discard
        # them explicitly instead of binding an unused local.
        _, sentiment_totals = stock_model.analyze_sentiment(headlines)
        sentiment_summary = stock_model.generate_sentiment_summary(
            sentiment_totals, headlines, symbol)
        return jsonify({
            "symbol": symbol,
            "sentiment": {
                "totals": sentiment_totals,
                "summary": sentiment_summary,
                # NOTE(review): 28 presumably means "days of news covered" —
                # confirm against the frontend's use of this field.
                "period": 28,
            },
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def diagnose():
    """Diagnostic endpoint – checks environment, APIs and model primitives.

    Returns a JSON report with:
      * environment  – interpreter and library versions
      * api_status   – live reachability checks for Alpha Vantage and Finnhub
      * model_status – a create_sequences smoke-test on synthetic data

    NOTE(review): no @app.route decorator is visible in this extraction;
    confirm the route path against the original file.
    """
    try:
        env_info = {
            "python_version": sys.version,
            "tensorflow_version": tf.__version__,
            "numpy_version": np.__version__,
            "pandas_version": pd.__version__,
            "xgboost_version": xgb.__version__,
        }

        api_status = {}
        try:
            url = "https://www.alphavantage.co/query"
            params = {
                "function": "TIME_SERIES_DAILY",
                "symbol": "AAPL",
                "apikey": stock_model.ALPHAVANTAGE_API_KEY,
                "outputsize": "compact",
                "datatype": "json",
            }
            # [FIX] timeout added – without one a hung upstream API would hang
            # this diagnostic request (and its worker) indefinitely.
            resp = requests.get(url, params=params, timeout=10)
            rj = resp.json()
            api_status["alpha_vantage"] = {
                "status_code": resp.status_code,
                "has_data": "Time Series (Daily)" in rj,
                "error": rj.get("Error Message") or rj.get("Note")
                if "Time Series (Daily)" not in rj else None,
            }
        except Exception as e:
            api_status["alpha_vantage"] = {"error": str(e)}

        try:
            headers = {"X-Finnhub-Token": stock_model.FINNHUB_API_KEY}
            # [FIX] timeout added (same rationale as above).
            resp = requests.get(
                "https://finnhub.io/api/v1/news?category=general",
                headers=headers, timeout=10)
            api_status["finnhub"] = {
                "status_code": resp.status_code,
                "has_data": len(resp.json()) > 0,
                "error": None if resp.status_code == 200 else str(resp.text),
            }
        except Exception as e:
            api_status["finnhub"] = {"error": str(e)}

        model_status = {}
        try:
            test_data = np.random.rand(100, 6)  # 6 features (OPT-2)
            test_scaler = MinMaxScaler()
            # First column becomes a scaled monotonic ramp so sequence creation
            # operates on predictable data.
            test_data[:, 0] = test_scaler.fit_transform(
                np.arange(100).reshape(-1, 1)).flatten()
            X, y = stock_model.create_sequences(test_data, time_step=30)
            model_status["sequence_creation"] = {
                "success": len(X) > 0,
                "X_shape": str(X.shape),
                "y_shape": str(y.shape),
            }
        except Exception as e:
            model_status["error"] = str(e)

        return jsonify({
            "timestamp": datetime.now().isoformat(),
            "status": "OK",
            "environment": env_info,
            "api_status": api_status,
            "model_status": model_status,
        })
    except Exception as e:
        return jsonify({"status": "ERROR", "error": str(e)}), 500
if __name__ == "__main__":
    # [OPT-B] The listen port comes from the environment, so one build runs on:
    #   • Render            – sets $PORT automatically (usually 10000)
    #   • Hugging Face      – expects 7860
    #   • local development – falls back to 5001
    port = int(os.environ.get("PORT", 5001))
    print(f"Starting StockBuddy API on port {port}")
    app.run(host="0.0.0.0", port=port)