File size: 3,395 Bytes
77fb042
cd6bba8
6d7a718
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77fb042
6d7a718
 
 
 
 
 
 
 
 
 
 
 
 
 
4b9dc02
 
6d7a718
 
 
 
 
05f5d59
6d7a718
 
 
 
 
 
 
 
 
 
4b9dc02
6d7a718
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import gradio as gr
import tensorflow as tf
import joblib
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download

# Hide every CUDA device from this process (presumably to force CPU-only
# inference on hosts without a usable GPU).
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Hugging Face Hub repository and the artifact file names fetched from it.
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.h5"
SCALER_FILENAME = "scalers.joblib"

# --- Artifact download -------------------------------------------------------
# Download failures are logged rather than raised so the module still imports;
# the paths are left as None and the loading steps below are skipped.
print("--- Downloading model and scalers ---")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
    print("✅ Files downloaded successfully.")
except Exception as e:
    print(f"❌ Error downloading files: {e}")
    model_path, scalers_path = None, None

# Both stay None when the corresponding artifact could not be loaded;
# forecast_stock() reports that condition to the caller as an error string.
loaded_model_lstm = None
loaded_scalers = None

# --- Model loading -----------------------------------------------------------
if model_path and os.path.exists(model_path):
    try:
        loaded_model_lstm = tf.keras.models.load_model(
            model_path,
            # Maps the serialized loss name "mse" to a concrete loss object so
            # the saved model can be deserialized.
            custom_objects={"mse": tf.keras.losses.MeanSquaredError()}
        )
        print("✅ Model loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")

# --- Scaler loading ----------------------------------------------------------
if scalers_path and os.path.exists(scalers_path):
    try:
        # SECURITY NOTE: joblib.load unpickles the file, which can execute
        # arbitrary code. Only keep this if MODEL_REPO is fully trusted.
        loaded_scalers = joblib.load(scalers_path)
        print("✅ Scalers loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading scalers: {e}")

# Ticker symbol -> company name. Not referenced by the code visible in this
# file; kept for any external consumer.
ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}










def get_ticker_from_input(input_name):
    """Normalise raw user input into an upper-case ticker symbol.

    Args:
        input_name: Text supplied by the caller. ``None`` is tolerated.

    Returns:
        str: The input with surrounding whitespace removed and upper-cased,
        or an empty string when nothing usable was supplied. The empty
        string is falsy, so callers can reject it with ``if not ticker``.
    """
    if input_name is None:
        return ""
    # str() guards against non-string payloads; strip() stops whitespace-only
    # input from being treated as a valid (truthy) ticker.
    return str(input_name).strip().upper()






def forecast_stock(input_name, model, scalers_dict, input_width=60, data_df=None):
    """Predict the next closing price for a ticker from its recent closes.

    Args:
        input_name: Raw ticker text; normalised with ``get_ticker_from_input``.
        model: Object exposing ``predict(X, verbose=0)``. It is called with an
            array of shape ``(1, input_width, 1)`` and the value at ``[0][0]``
            of its result is used as the scaled prediction.
        scalers_dict: Mapping of ticker -> scaler exposing ``transform`` and
            ``inverse_transform``. When the ticker has no entry, the
            ``'ZURVY'`` entry is used as a fallback.
        input_width: Number of most recent rows fed to the model.
        data_df: Optional DataFrame with a ``'Close'`` column holding the
            price history. When omitted, a module-level ``data_df`` is used
            if one exists.

    Returns:
        str: A two-line forecast message on success, otherwise a message
        starting with ``"Error:"`` describing what was missing.
    """
    # Identity check for the model: a valid model object must not be rejected
    # just because it happens to evaluate as falsy.
    if model is None or not scalers_dict:
        return "Error: Model or scalers not loaded. The backend may have failed to start."
    ticker = get_ticker_from_input(input_name)
    if not ticker:
        return "Error: Invalid stock ticker."
    print(f"\n--- Generating forecast for {ticker} ---")

    # NOTE(review): the step that fetched price history is absent from this
    # file, so the original code referenced an undefined name and raised
    # NameError. Data can now be injected by the caller; the global lookup
    # keeps the previous behaviour if a module-level ``data_df`` is defined.
    if data_df is None:
        data_df = globals().get("data_df")
    if data_df is None:
        return f"Error: No historical price data available for {ticker}."
    if 'Close' not in data_df.columns:
        return "Error: Historical data has no 'Close' column."

    if len(data_df) < input_width:
        return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
    recent_data = data_df.tail(input_width)
    # -1 lets NumPy infer the row count, giving one column of closing prices
    # (the original negated the ``input`` builtin here, which is a TypeError).
    close_prices = recent_data['Close'].values.reshape(-1, 1)

    scaler = scalers_dict.get(ticker)
    if scaler is None:
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = scalers_dict.get('ZURVY')
        if scaler is None:
            return "Error: Default scaler 'ZURVY' not found."

    scaled_data = scaler.transform(close_prices)
    X_pred = scaled_data.reshape(1, input_width, 1)
    prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
    # Map the model output back from scaled units to a price.
    prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
    last_close = recent_data['Close'].iloc[-1]
    result = (
        f"Last known close for {ticker}: ${last_close:.2f}\n"
        f"Predicted next day's close price: ${prediction_actual:.2f}"
    )
    print(result)
    return result

def predict_api(ticker_symbol):
    """Run a forecast for ``ticker_symbol`` with the module-level artifacts.

    Forwards the ticker, together with ``loaded_model_lstm`` and
    ``loaded_scalers``, to ``forecast_stock`` and returns its string result.
    """
    forecast = forecast_stock(
        ticker_symbol,
        model=loaded_model_lstm,
        scalers_dict=loaded_scalers,
    )
    return forecast

# Gradio Blocks app: both textboxes are created with visible=False, and the
# textbox submit event is registered under api_name="predict" with
# predict_api as its handler.
with gr.Blocks() as app:
    gr.Markdown("This is the backend for the React Stock Forecaster App.")
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    output_text = gr.Textbox(label="Forecast", visible=False)
    ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")

# Rebinds `app` to whatever the call below returns, passing the "build"
# directory (presumably the compiled React front end — TODO confirm).
# NOTE(review): `mount_static_directory` could not be confirmed as part of
# Gradio's documented public API; verify it exists in the pinned Gradio
# version, otherwise this line raises AttributeError at import time.
app = gr.mount_static_directory(app, "build")








# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    app.launch()