Spaces:
Running
Running
File size: 3,395 Bytes
77fb042 cd6bba8 6d7a718 77fb042 6d7a718 4b9dc02 6d7a718 05f5d59 6d7a718 4b9dc02 6d7a718 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import os
import gradio as gr
import tensorflow as tf
import joblib
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
# Hide all CUDA devices so TensorFlow runs on the CPU.
# NOTE(review): this is conventionally set *before* `import tensorflow`; here it
# runs after the import at the top of the file — confirm it still takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Hugging Face Hub repository holding the trained LSTM and its fitted scalers.
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.h5"
SCALER_FILENAME = "scalers.joblib"

print("--- Downloading model and scalers ---")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
    print("✅ Files downloaded successfully.")
except Exception as e:
    # Best-effort startup: log and continue with None paths so the app still
    # boots; the loaded_* globals below then stay None and callers must check.
    print(f"❌ Error downloading files: {e}")
    model_path, scalers_path = None, None

# Module-level artifacts consumed by the API handler; None means "not loaded".
loaded_model_lstm = None
loaded_scalers = None

if model_path and os.path.exists(model_path):
    try:
        loaded_model_lstm = tf.keras.models.load_model(
            model_path,
            # Presumably resolves the "mse" loss name stored in the .h5 file's
            # compile config — confirm against the Keras version in use.
            custom_objects={"mse": tf.keras.losses.MeanSquaredError()}
        )
        print("✅ Model loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")

if scalers_path and os.path.exists(scalers_path):
    try:
        # NOTE(review): joblib.load unpickles the file; only safe because the
        # artifact comes from this project's own Hub repository.
        loaded_scalers = joblib.load(scalers_path)
        print("✅ Scalers loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading scalers: {e}")
# Ticker symbol -> human-readable company name.
# NOTE(review): not referenced by any other code visible in this file —
# confirm whether it is still needed (e.g. by an external importer).
ticker_to_name = dict(ZURVY='Zurich Insurance Group AG')
def get_ticker_from_input(input_name):
    """Normalize raw user input into an upper-case ticker symbol.

    Surrounding whitespace is stripped so that " zurvy " and "ZURVY" resolve to
    the same scaler key, and a blank entry collapses to "" so callers can
    reject it with a simple truthiness check.

    Args:
        input_name: Text supplied by the caller; ``None`` is treated as empty.

    Returns:
        The stripped, upper-cased ticker, or ``""`` for empty/blank input.
    """
    return (input_name or "").strip().upper()
def forecast_stock(input_name, model, scalers_dict, input_width=60, data_df=None):
    """Predict the next closing price for a ticker with the LSTM model.

    Args:
        input_name: Raw ticker text; normalized via ``get_ticker_from_input``.
        model: Loaded Keras model. It is called on an array of shape
            ``(1, input_width, 1)`` and element ``[0][0]`` of its output is
            taken as the scaled prediction.
        scalers_dict: Mapping of ticker -> scaler exposing ``transform`` and
            ``inverse_transform``. The ``'ZURVY'`` entry is the fallback when
            the ticker has no scaler of its own.
        input_width: Number of most recent closing prices fed to the model.
        data_df: Historical prices with a ``'Close'`` column. The last rows are
            treated as the most recent (assumed chronological order — confirm
            with the data source).

    Returns:
        A human-readable forecast string, or a string starting with
        ``"Error:"`` explaining why no forecast could be produced.
    """
    if model is None or not scalers_dict:
        return "Error: Model or scalers not loaded. The backend may have failed to start."

    ticker = get_ticker_from_input(input_name)
    if not ticker:
        return "Error: Invalid stock ticker."
    print(f"\n--- Generating forecast for {ticker} ---")

    # The original referenced an undefined global `data_df`, so every call
    # raised NameError; the data must now be supplied explicitly.
    if data_df is None or 'Close' not in data_df:
        return f"Error: No historical price data available for {ticker}."
    if len(data_df) < input_width:
        return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."

    recent_data = data_df.tail(input_width)
    # 2-D column vector (one value per row) for the scaler.
    close_prices = recent_data['Close'].to_numpy().reshape(-1, 1)

    scaler = scalers_dict.get(ticker)
    if scaler is None:
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = scalers_dict.get('ZURVY')
        if scaler is None:
            return "Error: Default scaler 'ZURVY' not found."

    scaled_data = scaler.transform(close_prices)
    # One sequence of `input_width` single-feature steps, i.e. the Keras
    # LSTM (batch, timesteps, features) layout.
    X_pred = scaled_data.reshape(1, input_width, 1)
    prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
    # Map the scaled prediction back to price units with the same scaler.
    prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]

    last_close = recent_data['Close'].iloc[-1]
    result = (
        f"Last known close for {ticker}: ${last_close:.2f}\n"
        f"Predicted next day's close price: ${prediction_actual:.2f}"
    )
    print(result)
    return result
def predict_api(ticker_symbol):
    """Gradio API handler: forecast ``ticker_symbol`` using the startup artifacts.

    Delegates to ``forecast_stock`` with the module-level model and scalers
    loaded at import time; either may be ``None`` if loading failed.
    """
    return forecast_stock(
        ticker_symbol,
        model=loaded_model_lstm,
        scalers_dict=loaded_scalers,
    )
# Headless Gradio UI: the hidden textboxes exist only so that `predict_api`
# is exposed as the named API endpoint "predict" for the React frontend.
with gr.Blocks() as app:
    gr.Markdown("This is the backend for the React Stock Forecaster App.")
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    output_text = gr.Textbox(label="Forecast", visible=False)
    ticker_input.submit(
        fn=predict_api,
        inputs=[ticker_input],
        outputs=[output_text],
        api_name="predict",
    )

# NOTE(review): `gr.mount_static_directory` does not appear in Gradio's public
# API docs. Serving a React "build" folder is usually done with FastAPI
# `StaticFiles` plus `gr.mount_gradio_app`. Confirm this call exists in the
# pinned Gradio version.
app = gr.mount_static_directory(app, "build")

if __name__ == "__main__":
    app.launch()