File size: 5,362 Bytes
17ef946
 
 
 
 
 
 
 
 
eeb9a7a
 
17ef946
 
eeb9a7a
17ef946
eeb9a7a
17ef946
 
eeb9a7a
17ef946
 
 
 
 
eeb9a7a
 
17ef946
 
eeb9a7a
17ef946
 
 
 
 
 
eeb9a7a
17ef946
eeb9a7a
17ef946
 
 
 
 
 
eeb9a7a
 
 
 
 
 
 
 
 
 
 
 
17ef946
eeb9a7a
17ef946
 
eeb9a7a
 
17ef946
eeb9a7a
 
17ef946
eeb9a7a
17ef946
eeb9a7a
17ef946
 
eeb9a7a
17ef946
eeb9a7a
17ef946
 
 
eeb9a7a
 
 
17ef946
 
eeb9a7a
17ef946
eeb9a7a
17ef946
eeb9a7a
 
 
 
17ef946
eeb9a7a
 
 
 
17ef946
eeb9a7a
17ef946
eeb9a7a
 
 
 
17ef946
eeb9a7a
 
 
 
 
 
 
 
 
 
 
 
17ef946
 
eeb9a7a
 
 
 
 
 
 
 
17ef946
eeb9a7a
 
17ef946
 
eeb9a7a
17ef946
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import gradio as gr
import tensorflow as tf
import joblib
import numpy as np
import pandas as pd
import yfinance as yf
from huggingface_hub import hf_hub_download

# --- 0. Force CPU-only mode for TensorFlow ---
# Hiding every CUDA device stops TF from probing for (non-existent) GPU
# memory on a CPU-only Space instance.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# --- 1. Define Constants and Download Model/Scalers from Hugging Face Hub ---
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.keras"
SCALER_FILENAME = "scalers.joblib"

print("--- Downloading model and scalers from Hugging Face Hub ---")
try:
    # hf_hub_download resolves each artifact to a local cache path.
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
except Exception as e:
    print(f"❌ Critical Error: Could not download files from the Hub. {e}")
    # Null out both paths so the loading stage below knows startup failed.
    model_path = scalers_path = None
else:
    print("βœ… Files downloaded successfully.")

# --- 2. Load the Model and Scalers into Memory ---
# Both stay None on any failure; forecast_stock() checks for that and
# returns a readable error instead of crashing.
loaded_model_lstm = None
loaded_scalers = None

# Load the Keras LSTM model, if the download produced a usable path.
model_file_present = model_path and os.path.exists(model_path)
if model_file_present:
    try:
        loaded_model_lstm = tf.keras.models.load_model(model_path)
    except Exception as e:
        print(f"❌ Critical Error: Could not load the TensorFlow model. {e}")
    else:
        print("βœ… TensorFlow model loaded successfully.")

# Load the per-ticker scaler collection saved with joblib.
scaler_file_present = scalers_path and os.path.exists(scalers_path)
if scaler_file_present:
    try:
        loaded_scalers = joblib.load(scalers_path)
    except Exception as e:
        print(f"❌ Critical Error: Could not load the scalers file. {e}")
    else:
        print("βœ… Scalers loaded successfully.")

# --- 3. The Core Forecasting Function ---
def forecast_stock(input_name: str, input_width: int = 60) -> str:
    """
    Fetch live stock data, preprocess it, and return a prediction string.

    Args:
        input_name: Stock ticker symbol (whitespace stripped, upper-cased).
        input_width: Number of trailing trading days fed to the LSTM window.

    Returns:
        A human-readable multi-line forecast, or a string starting with
        "Error:" describing the failure — this function never raises to
        its caller (it backs a public API endpoint).
    """
    # Fail fast if the model/scalers didn't load during startup.
    # Use `is None` instead of truthiness: Keras models and array-like
    # objects don't reliably support bool().
    if loaded_model_lstm is None or loaded_scalers is None:
        return "Error: Model or scalers are not loaded. The backend may have failed to start correctly. Check the Space logs."

    # Guard against None as well as empty/blank input — the original
    # `input_name.strip()` raised AttributeError on None from the API.
    if not input_name or not input_name.strip():
        return "Error: Please enter a stock ticker."
    ticker = input_name.strip().upper()

    print(f"\n--- Generating forecast for {ticker} ---")

    # Fetch recent data using yfinance.
    try:
        # Fetch more than needed to ensure we have enough valid trading days.
        data_df = yf.download(ticker, period="200d", progress=False)
        if data_df is None or data_df.empty:
            return f"Error: No data found for ticker '{ticker}'. It may be delisted or an invalid symbol."
    except Exception as e:
        return f"Error fetching data for '{ticker}': {e}"

    # Recent yfinance versions return MultiIndex (field, ticker) columns,
    # in which case data_df['Close'] is a one-column DataFrame rather than
    # a Series — squeeze it down so downstream reshaping/formatting works.
    close_series = data_df['Close']
    if isinstance(close_series, pd.DataFrame):
        close_series = close_series.iloc[:, 0]
    close_series = close_series.dropna()  # NaN closes would poison scaling

    if len(close_series) < input_width:
        return f"Error: Not enough historical data for {ticker}. Need {input_width} days, but only found {len(close_series)}."

    # Prepare the trailing window for the model.
    recent_close = close_series.tail(input_width)
    close_prices = recent_close.to_numpy(dtype=float).reshape(-1, 1)

    # Find the correct scaler. The original model was trained on specific
    # stocks; fall back to the default 'ZURVY' scaler for unknown tickers.
    scaler = loaded_scalers.get(ticker)
    if scaler is None:
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = loaded_scalers.get('ZURVY')
        if scaler is None:
            return "Error: Critical failure. The default 'ZURVY' scaler could not be found."

    # Scale the data and make a prediction.
    try:
        scaled_data = scaler.transform(close_prices)
        # Reshape for LSTM: [batch, timesteps, features]
        X_pred = scaled_data.reshape(1, input_width, 1)

        prediction_scaled = loaded_model_lstm.predict(X_pred, verbose=0)[0][0]
        prediction_actual = scaler.inverse_transform(
            np.array([[prediction_scaled]])
        )[0][0]
    except Exception as e:
        return f"An error occurred during model prediction: {e}"

    # Format the final result. float() guards against pandas/numpy scalar
    # wrappers, which would otherwise break the ':.2f' format spec.
    last_close = float(recent_close.iloc[-1])
    result_str = (
        f"Forecast for: {ticker}\n"
        f"Last Close Price: ${last_close:.2f}\n"
        f"Predicted Next Day's Close: ${float(prediction_actual):.2f}"
    )
    print(result_str)
    return result_str

# --- 4. Create the Gradio Interface and API Endpoint ---
def predict_api(ticker_symbol: str) -> str:
    """API-facing wrapper: forwards the ticker symbol to forecast_stock.

    Kept as a separate function so the Gradio endpoint binds to a stable,
    minimal signature.
    """
    forecast = forecast_stock(ticker_symbol)
    return forecast

# Headless Gradio UI: the components exist only to expose an API endpoint
# that the React frontend calls; nothing here is meant to be used visually.
with gr.Blocks(title="Stock Forecaster Backend") as app:
    gr.Markdown("## Stock Forecaster Backend\nThis Gradio app serves the API for the React frontend.")
    
    # These components are not visible but are required to create the API endpoint
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    output_text = gr.Textbox(label="Forecast", visible=False)
    
    # This creates the API endpoint at /run/predict
    # NOTE(review): the route depends on the Gradio major version — 4.x
    # exposes named endpoints under /call/predict instead of /run/predict.
    # Confirm the path against the pinned Gradio version and the frontend.
    ticker_input.submit(
        fn=predict_api, 
        inputs=[ticker_input], 
        outputs=[output_text], 
        api_name="predict"
    )

# --- 5. Mount the static React build directory to be served ---
# This requires a recent version of Gradio (e.g., 4.x), specified in README.md
# NOTE(review): `gr.mount_static_directory` is not in Gradio's documented
# public API — confirm it exists in the pinned Gradio version. The documented
# routes to serve static files are `gr.set_static_paths(...)` or mounting the
# Blocks into a FastAPI app via `gr.mount_gradio_app` alongside StaticFiles.
# If the attribute is missing, startup fails with AttributeError.
app = gr.mount_static_directory(app, "build")

# --- 6. Launch the Gradio App ---
# Standard script entry guard; launch() starts the Gradio HTTP server.
if __name__ == "__main__":
    app.launch()