# Page header captured with this file: danielthatu12 - "Fixed app.py file" - commit 8ef1b90
"""
app.py – StockBuddy Flask API
=================================
LIGHTWEIGHT CHANGES vs original:
[OPT-A] Removed the startup TF validation model (was creating & running a test
LSTM on every cold start – wastes ~10 s and ~100 MB RAM on free tier).
Replaced with a simple tf.constant() smoke-test.
[OPT-B] PORT is now read from the PORT environment variable so the server
works on Render (sets $PORT automatically) and Hugging Face Spaces
(expects port 7860) without code changes.
[OPT-C] time_step updated to 30 throughout (was 45) to match the lighter model.
All REST API routes are unchanged from the original.
"""
from flask import Flask, request, jsonify
from flask_cors import CORS
import numpy as np
import pandas as pd
import os
import threading
import time
from datetime import datetime, timedelta
import json
import model as stock_model
import sys
import requests
import traceback
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.callbacks import Callback
import tensorflow as tf
import xgboost as xgb
# Flask application object; every @app.route decorator below registers on it.
app = Flask(__name__)
# Enable CORS for the whole app. No options are passed, so the flask-cors
# defaults apply (no origin restriction is configured in this file).
CORS(app)
def read_json_dict_from_request():
    """
    Decode the current request body as a JSON object without using request.json.

    The raw bytes are decoded as UTF-8 (a leading BOM is tolerated), stripped
    of surrounding whitespace and parsed with ``json.loads``. Unexpected
    mimetypes are only logged, never rejected.

    Returns:
        ``(dict, None)`` when the body is a JSON object, otherwise
        ``(None, error_message)`` describing why the body was rejected.
    """
    body = request.get_data(cache=True)
    if not body:
        return None, "Empty request body"

    # Spring/Java normally sends application/json; log odd types for debugging
    declared_type = (request.mimetype or "").lower()
    if declared_type not in ("", "application/json"):
        print(f"POST /api/predict unexpected mimetype={request.mimetype!r} content_type={request.content_type!r}")

    try:
        # "utf-8-sig" silently drops a byte-order mark if one is present.
        text = body.decode("utf-8-sig").strip()
    except UnicodeDecodeError as exc:
        return None, f"Body is not valid UTF-8: {exc}"
    if not text:
        return None, "Empty body after decode"

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as exc:
        # Single-line preview so the log entry and the error stay readable.
        preview = text[:400].replace("\r", " ").replace("\n", " ")
        print(
            f"POST /api/predict JSONDecodeError: {exc.msg} (line {exc.lineno}, col {exc.colno}); "
            f"len={len(text)} preview={preview!r}"
        )
        message = (
            f"Invalid JSON ({exc.msg} at line {exc.lineno}, column {exc.colno}). "
            "Send strict JSON with double-quoted keys, e.g. "
            '{"userId":"...","symbol":"AAPL","daysAhead":5}. '
            f"Body preview: {preview!r}"
        )
        return None, message

    if isinstance(parsed, dict):
        return parsed, None
    return None, f"JSON root must be an object, got {type(parsed).__name__}"
@app.route("/", methods=["GET"])
def home():
    """Liveness endpoint: report that the API process is up."""
    payload = {"status": "running", "message": "StockBuddy API is live!"}
    return jsonify(payload)
# [OPT-A] Lightweight TF smoke-test instead of building & running a full LSTM
def validate_tensorflow():
    """
    Run a minimal TensorFlow sanity check without building any model.

    Returns:
        ``(True, message)`` describing whether a GPU was found, or
        ``(False, message)`` carrying the error text if anything raised.
    """
    try:
        print("TensorFlow version:", tf.__version__)
        # Creating one small constant is enough to prove the runtime works;
        # real model construction waits for the first prediction request.
        tf.constant([1.0, 2.0, 3.0])
        gpu_devices = tf.config.list_physical_devices("GPU")
        summary = (
            f"GPU available ({len(gpu_devices)} device(s)) – running in GPU mode."
            if gpu_devices
            else "No GPU detected – running in CPU mode (expected on free tier)."
        )
        print(f"TensorFlow OK: {summary}")
        return True, summary
    except Exception as exc:
        print(f"TensorFlow validation failed: {exc}")
        return False, f"TensorFlow error: {exc}"
# Smoke-test TensorFlow once at import time; /api/predict consults the outcome
# (tf_status / tf_message) before accepting work.
tf_status, tf_message = validate_tensorflow()
print(f"TensorFlow validation: {tf_message}" if tf_status else f"WARNING: {tf_message}")

# In-memory registry of prediction tasks, keyed by task id.
prediction_tasks = {}
class PredictionTask:
    """
    One background prediction job for a single user/symbol pair.

    The job runs in a daemon thread and publishes its state through plain
    attributes that the REST handlers read:

    * ``status``   - "pending", "running", "stopped", "failed" or "completed"
      (the stop endpoint may also set "stopping").
    * ``progress`` - integer percentage, 0-100.
    * ``result``   - result payload on success, ``{"error": ...}`` on failure.
    * ``stop_requested`` / ``stop_acknowledged`` - cooperative cancellation
      flags; the worker sets the latter once it has noticed the former.
    """

    def __init__(self, user_id, symbol, days_ahead):
        """Store the request parameters and derive a unique task id."""
        self.user_id = user_id
        self.symbol = symbol
        self.days_ahead = days_ahead
        self.progress = 0
        self.status = "pending"
        self.result = None
        self.sentiment_result = None
        self.thread = None
        self.stop_requested = False
        self.stop_acknowledged = False
        # Unique task ID: millisecond timestamp + random hex suffix
        timestamp = int(time.time() * 1000)
        random_suffix = os.urandom(4).hex()
        self.task_id = f"{user_id}_{symbol}_{timestamp}_{random_suffix}"

    def run(self):
        """Start the worker thread and return the task id."""
        self.thread = threading.Thread(target=self._run_prediction)
        self.thread.daemon = True
        self.thread.start()
        return self.task_id

    def is_stop_requested(self):
        """Callback for model training loops to poll stop flag."""
        if self.stop_requested and not self.stop_acknowledged:
            self.stop_acknowledged = True
            self.status = "stopped"
            return True
        return self.stop_requested

    def _should_stop(self):
        """
        Stage checkpoint: if a stop was requested, acknowledge it and report True.

        Acknowledging here (not only inside ``is_stop_requested``) lets a
        caller waiting on ``stop_acknowledged`` see the stop even when the
        worker notices it between pipeline stages rather than inside a
        training loop.
        """
        if not self.stop_requested:
            return False
        self.stop_acknowledged = True
        self.status = "stopped"
        return True

    def _fail(self, message):
        """Mark the task as failed with ``message`` as the reported error."""
        self.status = "failed"
        self.result = {"error": message}

    @staticmethod
    def _future_trading_days(last_date, count):
        """
        Return the next ``count`` weekdays (Mon-Fri) strictly after ``last_date``.

        Only weekends are skipped; market holidays are not considered.
        A ``count`` of zero or less yields an empty list.
        """
        days = []
        current = last_date
        while len(days) < count:
            current = current + timedelta(days=1)
            if current.weekday() <= 4:
                days.append(current)
        return days

    def _run_prediction(self):
        """
        Worker body: fetch data, score sentiment, train models, build the result.

        Every stage reports failure through ``status``/``result`` instead of
        raising, and a stop request is honoured at each stage boundary.
        Sentiment and XGBoost errors are treated as non-fatal.
        """
        try:
            print(f"Starting prediction for {self.symbol} (task: {self.task_id})")
            self.status = "running"
            self.progress = 10

            # ── Fetch historical data ────────────────────────────────────────
            print(f"Fetching historical data for {self.symbol}...")
            try:
                data = stock_model.fetch_stock_data(self.symbol, outputsize="compact")
            except Exception as e:
                print(f"Data fetch error: {e}")
                self._fail(f"Could not fetch data for {self.symbol}: {e}")
                return
            # Check for None before calling len(), so a missing dataset gives
            # the dedicated message rather than a TypeError from len(None).
            if data is None:
                self._fail(f"Could not fetch data for {self.symbol}")
                return
            print(f"Fetched {len(data)} rows for {self.symbol}")
            if self._should_stop():
                return
            if len(data) < 60:
                self._fail(f"Insufficient data for {self.symbol} "
                           f"(got {len(data)}, need ≥60)")
                return

            # ── Extract last actual close ────────────────────────────────────
            try:
                if isinstance(data, pd.DataFrame) and "Close" in data.columns:
                    last_actual_close = float(data["Close"].iloc[-1])
                else:
                    # No "Close" column: fall back to the first column.
                    last_actual_close = float(data.iloc[-1, 0])
                last_date = data.index[-1]
                print(f"Latest close for {self.symbol}: "
                      f"${last_actual_close:.2f} on {last_date.strftime('%Y-%m-%d')}")
            except Exception as e:
                self._fail(f"Error reading price data: {e}")
                return
            self.progress = 20
            if self._should_stop():
                return

            # ── Sentiment analysis (best effort) ─────────────────────────────
            try:
                print(f"Fetching news for {self.symbol}...")
                headlines = stock_model.fetch_finnhub_news(self.symbol)
                print(f"Got {len(headlines)} headlines")
                self.progress = 30
                if self._should_stop():
                    return
                _, sentiment_totals = stock_model.analyze_sentiment(headlines)
                sentiment_summary = stock_model.generate_sentiment_summary(
                    sentiment_totals, headlines, self.symbol)
                self.sentiment_result = {
                    "totals": sentiment_totals,
                    "summary": sentiment_summary,
                }
            except Exception as e:
                # A sentiment failure must not abort the price prediction.
                print(f"Sentiment error (non-fatal): {e}")
                self.sentiment_result = {
                    "totals": {"positive": 0, "negative": 0, "neutral": 0},
                    "summary": f"Unable to analyse sentiment: {e}",
                }
            self.progress = 40
            if self._should_stop():
                return

            # ── Preprocess data ──────────────────────────────────────────────
            try:
                print("Preprocessing data...")
                scaled_data, scaler = stock_model.preprocess_data(data)
                # [OPT-C] time_step 45 → 30
                time_step = 30
                X, y = stock_model.create_sequences(scaled_data, time_step)
                print(f"Sequences: X={X.shape}, y={y.shape}")
            except Exception as e:
                self._fail(f"Preprocessing failed: {e}")
                return
            if len(X) == 0:
                self._fail(f"Could not create training sequences for {self.symbol}")
                return
            self.progress = 50
            if self._should_stop():
                return

            # ── Train LSTM ───────────────────────────────────────────────────
            try:
                # Only the first 80% of the sequences are used for training.
                train_size = int(len(X) * 0.8)
                if train_size == 0:
                    self._fail("Not enough data to split for training")
                    return
                X_train, y_train = X[:train_size], y[:train_size]
                self.progress = 55
                print(f"Training LSTM with {len(X_train)} samples...")
                lstm_model = stock_model.train_lstm(
                    X_train, y_train, time_step, self.is_stop_requested)
            except Exception as e:
                self._fail(f"LSTM training failed: {e}")
                return
            if self._should_stop():
                return
            self.progress = 75

            # ── Train XGBoost on residuals ───────────────────────────────────
            try:
                print("Calculating residuals for XGBoost...")
                lstm_preds = lstm_model.predict(X_train, verbose=0).flatten()
                residuals = y_train - lstm_preds
                xgb_model = stock_model.train_xgboost(
                    X_train.reshape(X_train.shape[0], -1),
                    residuals,
                    self.is_stop_requested,
                )
                # A None model returned without an exception is treated as a stop.
                if self._should_stop() or xgb_model is None:
                    self.status = "stopped"
                    return
            except Exception as e:
                # Continue with the LSTM alone when the XGBoost step errors.
                print(f"XGBoost training error (non-fatal): {e}")
                xgb_model = None
            self.progress = 90
            if self._should_stop():
                return

            # ── Generate predictions ─────────────────────────────────────────
            try:
                print(f"Generating {self.days_ahead}-day predictions...")
                predictions = stock_model.predict_stock_price(
                    lstm_model, xgb_model, scaled_data, scaler,
                    time_step, self.days_ahead, self.is_stop_requested,
                )
                # A None prediction set is likewise treated as a stop.
                if self._should_stop() or predictions is None:
                    self.status = "stopped"
                    return
            except Exception as e:
                self._fail(f"Prediction generation failed: {e}")
                return
            self.progress = 95
            if self._should_stop():
                return

            # ── Build future trading-day dates ───────────────────────────────
            # One weekday per predicted value, starting after the last close.
            future_dates = self._future_trading_days(last_date, len(predictions))
            if self._should_stop():
                return

            # ── Assemble result payload ──────────────────────────────────────
            prediction_data = []
            for day, row in zip(future_dates, predictions):
                predicted_price = float(row[0])
                percent_change = (
                    (predicted_price - last_actual_close) / last_actual_close * 100
                )
                prediction_data.append({
                    "date": day.strftime("%Y-%m-%d"),
                    "price": round(predicted_price, 2),
                    "change": round(percent_change, 2),
                })
            self.result = {
                "symbol": self.symbol,
                "lastActualClose": {
                    "date": last_date.strftime("%Y-%m-%d"),
                    "price": round(last_actual_close, 2),
                },
                "predictions": prediction_data,
                "sentiment": self.sentiment_result,
                "tableDisplay": True,
            }
            self.progress = 100
            self.status = "completed"
            print(f"Prediction complete for {self.symbol}")
        except Exception as e:
            # Last-resort guard so the thread never dies silently.
            self.status = "failed"
            self.result = {"error": str(e)}
            print(f"Prediction task error: {e}")
            traceback.print_exc()
# =============================================================================
# REST API ROUTES
# (all routes are identical to the original – no frontend changes needed)
# =============================================================================
@app.route("/api/predict", methods=["POST"])
def start_prediction():
    """
    Start a background prediction task.

    Expects a JSON object with ``userId`` and ``symbol`` (both required) and
    an optional ``daysAhead`` (defaults to 5, must be an integer >= 1).

    Returns:
        200 with ``taskId``/``status``/``message`` on success;
        400 for a malformed body or invalid parameters;
        503 when the TensorFlow startup check failed;
        500 for any unexpected error.
    """
    try:
        data, parse_err = read_json_dict_from_request()
        if parse_err:
            return jsonify({"error": "Invalid or missing JSON body", "details": parse_err}), 400
        print(f"POST /api/predict body keys={list(data.keys())}")
        user_id = data.get("userId")
        symbol = data.get("symbol")
        # int() raises TypeError for null/list/dict and OverflowError for
        # Infinity; all of these are client errors, not server faults.
        try:
            days_ahead = int(data.get("daysAhead", 5))
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "daysAhead must be an integer"}), 400
        if days_ahead < 1:
            return jsonify({"error": "daysAhead must be at least 1"}), 400
        if not user_id or not symbol:
            return jsonify({"error": "Missing required parameters (userId or symbol)"}), 400
        if not isinstance(symbol, str) or len(symbol) > 10:
            return jsonify({"error": f"Invalid symbol format: {symbol}"}), 400
        if not tf_status:
            return jsonify({
                "error": f"Prediction service unavailable: {tf_message}",
                "tf_status": tf_message,
            }), 503
        task = PredictionTask(user_id, symbol, days_ahead)
        # Register before starting so the task is discoverable for its whole
        # lifetime (the id is assigned in the constructor).
        prediction_tasks[task.task_id] = task
        task_id = task.run()
        return jsonify({
            "taskId": task_id,
            "status": "pending",
            "message": f"Prediction started for {symbol}",
        })
    except ValueError as e:
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        print(f"Critical error starting prediction: {e}")
        traceback.print_exc()
        return jsonify({"error": "Failed to start prediction", "details": str(e)}), 500
@app.route("/api/predict/status/<task_id>", methods=["GET"])
def prediction_status(task_id):
    """
    Report status and progress for a task, plus its result once completed.

    A completed task with a non-empty result is validated first; if the
    result is not a dict holding a list of ``{"date", "price", ...}`` entries
    the task is downgraded to "failed". The result is only included in the
    response while the status is "completed".
    """
    try:
        task = prediction_tasks.get(task_id)
        if not task:
            return jsonify({"error": "Task not found"}), 404
        try:
            if task.status == "completed" and task.result:
                problem = None
                if not isinstance(task.result, dict):
                    problem = "Invalid result format"
                elif not isinstance(task.result.get("predictions"), list):
                    problem = "Missing prediction data"
                elif any(
                    not isinstance(entry, dict)
                    or "date" not in entry
                    or "price" not in entry
                    for entry in task.result["predictions"]
                ):
                    problem = "Malformed prediction data"
                if problem is not None:
                    task.status = "failed"
                    task.result = {"error": problem}
            visible_result = task.result if task.status == "completed" else None
            return jsonify({
                "taskId": task_id,
                "status": task.status,
                "progress": task.progress,
                "result": visible_result,
            })
        except Exception as exc:
            print(f"Error generating status response: {exc}")
            return jsonify({
                "taskId": task_id,
                "status": "error",
                "progress": task.progress,
                "error": str(exc),
            })
    except Exception as exc:
        print(f"Critical error in prediction status: {exc}")
        return jsonify({"taskId": task_id, "status": "error",
                        "error": "Server error"}), 500
@app.route("/api/predict/stop/<task_id>", methods=["POST"])
def stop_prediction(task_id):
    """
    Request cooperative cancellation of a task and report its state.

    Sets ``stop_requested`` and, while the worker thread is alive, waits up
    to two seconds for it to wind down. The status becomes "stopped" once
    the worker acknowledges the request or its thread has exited; if neither
    happens within the grace period the status is left as "stopping".
    """
    task = prediction_tasks.get(task_id)
    if not task:
        return jsonify({"error": "Task not found"}), 404
    task.stop_requested = True
    if task.thread and task.thread.is_alive():
        task.status = "stopping"
        print(f"Stop requested for task {task_id} ({task.symbol})")
        deadline = time.time() + 2
        while time.time() < deadline:
            # The worker may exit without setting the acknowledged flag, so a
            # dead thread counts as stopped too; otherwise the "stopping"
            # status written above would never be cleared.
            if task.stop_acknowledged or not task.thread.is_alive():
                task.status = "stopped"
                break
            time.sleep(0.1)
    else:
        task.status = "stopped"
    return jsonify({
        "taskId": task_id,
        "status": task.status,
        "symbol": task.symbol,
        "progress": task.progress,
        "stopRequested": task.stop_requested,
        "stopAcknowledged": task.stop_acknowledged,
    })
@app.route("/api/predict/sentiment/<symbol>", methods=["GET"])
def get_sentiment(symbol):
    """
    Return news-sentiment totals and a summary for ``symbol``.

    Any exception raised along the way is reported as a 500 response with
    the exception text in ``error``.
    """
    try:
        news = stock_model.fetch_finnhub_news(symbol)
        # Only the aggregate totals are reported; per-headline results are unused.
        _, totals = stock_model.analyze_sentiment(news)
        summary = stock_model.generate_sentiment_summary(totals, news, symbol)
        sentiment = {
            "totals": totals,
            "summary": summary,
            "period": 28,
        }
        return jsonify({"symbol": symbol, "sentiment": sentiment})
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500
def _probe_alpha_vantage(timeout):
    """Query Alpha Vantage for AAPL daily data and summarise the response."""
    try:
        url = "https://www.alphavantage.co/query"
        params = {
            "function": "TIME_SERIES_DAILY",
            "symbol": "AAPL",
            "apikey": stock_model.ALPHAVANTAGE_API_KEY,
            "outputsize": "compact",
            "datatype": "json",
        }
        resp = requests.get(url, params=params, timeout=timeout)
        rj = resp.json()
        has_data = "Time Series (Daily)" in rj
        return {
            "status_code": resp.status_code,
            "has_data": has_data,
            # Surface the provider's message only when the series is missing.
            "error": None if has_data
            else (rj.get("Error Message") or rj.get("Note")),
        }
    except Exception as e:
        return {"error": str(e)}


def _probe_finnhub(timeout):
    """Query the Finnhub general-news endpoint and summarise the response."""
    try:
        headers = {"X-Finnhub-Token": stock_model.FINNHUB_API_KEY}
        resp = requests.get(
            "https://finnhub.io/api/v1/news?category=general",
            headers=headers, timeout=timeout)
        return {
            "status_code": resp.status_code,
            "has_data": len(resp.json()) > 0,
            "error": None if resp.status_code == 200 else str(resp.text),
        }
    except Exception as e:
        return {"error": str(e)}


def _probe_sequence_creation():
    """Run create_sequences on synthetic data and report the output shapes."""
    model_status = {}
    try:
        test_data = np.random.rand(100, 6)  # 6 features (OPT-2)
        test_scaler = MinMaxScaler()
        # Column 0 becomes a scaled 0..99 ramp; the others stay random.
        test_data[:, 0] = test_scaler.fit_transform(
            np.arange(100).reshape(-1, 1)).flatten()
        X, y = stock_model.create_sequences(test_data, time_step=30)
        model_status["sequence_creation"] = {
            "success": len(X) > 0,
            "X_shape": str(X.shape),
            "y_shape": str(y.shape),
        }
    except Exception as e:
        model_status["error"] = str(e)
    return model_status


@app.route("/api/diagnose", methods=["GET"])
def diagnose():
    """
    Diagnostic endpoint – checks environment, APIs and model primitives.

    Reports library versions, the outcome of one live request to each of
    Alpha Vantage and Finnhub, and a sequence-creation self-test. Each probe
    records its own failure, so one broken dependency does not hide the
    others. Outbound requests carry a timeout so an unresponsive provider
    cannot block this handler indefinitely.
    """
    # Seconds to wait on each external API before reporting it as an error.
    request_timeout = 10
    try:
        env_info = {
            "python_version": sys.version,
            "tensorflow_version": tf.__version__,
            "numpy_version": np.__version__,
            "pandas_version": pd.__version__,
            "xgboost_version": xgb.__version__,
        }
        api_status = {
            "alpha_vantage": _probe_alpha_vantage(request_timeout),
            "finnhub": _probe_finnhub(request_timeout),
        }
        model_status = _probe_sequence_creation()
        return jsonify({
            "timestamp": datetime.now().isoformat(),
            "status": "OK",
            "environment": env_info,
            "api_status": api_status,
            "model_status": model_status,
        })
    except Exception as e:
        return jsonify({"status": "ERROR", "error": str(e)}), 500
if __name__ == "__main__":
    # [OPT-B] The listening port comes from $PORT so one build runs on:
    #   • Render (sets $PORT automatically, usually 10000)
    #   • Hugging Face (expects 7860)
    #   • Local dev (falls back to 5001)
    port = int(os.getenv("PORT", 5001))
    print(f"Starting StockBuddy API on port {port}")
    # Bind on all interfaces so the server is reachable from outside the host.
    app.run(host="0.0.0.0", port=port)