Spaces:
Running
Running
| import gradio as gr | |
| import tensorflow as tf | |
| import joblib | |
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import yfinance as yf | |
| from huggingface_hub import hf_hub_download | |
# --- 1. Download Model and Scalers from Hugging Face Hub ---
# The artifacts are fetched from a Hub repo at startup, so nothing has to be
# uploaded to the Space by hand.
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.h5"
SCALER_FILENAME = "scalers.joblib"

# NOTE(review): the "β" prefix in the status messages below looks like a
# mis-encoded status glyph -- confirm the intended character.
print("--- Downloading model and scalers ---")
try:
    # Unpacking the generator downloads the model first, then the scalers;
    # both names are bound only when both downloads succeed.
    model_path, scalers_path = (
        hf_hub_download(repo_id=MODEL_REPO, filename=artifact_name)
        for artifact_name in (MODEL_FILENAME, SCALER_FILENAME)
    )
    print("β Files downloaded successfully.")
except Exception as e:
    # Startup boundary: report the failure and carry on with no artifacts so
    # the loading step below can skip them.
    print(f"β Error downloading files: {e}")
    model_path = scalers_path = None
# --- 2. Load the Model and Scalers ---
# Each global stays None when its artifact is missing or fails to load.
loaded_model_lstm = None
loaded_scalers = None

if model_path and os.path.exists(model_path):
    try:
        # compile=False: the model is only used for inference, so the saved
        # optimizer/loss/metrics are not needed. Skipping them also avoids
        # failures when a legacy .h5 training config cannot be deserialized
        # by the installed Keras version.
        loaded_model_lstm = tf.keras.models.load_model(model_path, compile=False)
        print("β Model loaded successfully.")
    except Exception as e:
        print(f"β Error loading model: {e}")

if scalers_path and os.path.exists(scalers_path):
    try:
        # SECURITY NOTE: joblib.load unpickles the file, and unpickling can
        # execute arbitrary code. This is only acceptable while the contents
        # of MODEL_REPO are trusted.
        loaded_scalers = joblib.load(scalers_path)
        print("β Scalers loaded successfully.")
    except Exception as e:
        print(f"β Error loading scalers: {e}")
# Ticker -> company-name lookup carried over from the original model's logic.
# A more robust solution would fetch this dynamically or store it better.
ticker_to_name = {
    "ZURVY": "Zurich Insurance Group AG",
}
def get_ticker_from_input(input_name):
    """Normalize raw user input into an upper-case ticker symbol.

    Simplified version for this app: the input is treated as a ticker
    symbol directly; no company-name lookup is performed.

    Args:
        input_name: Raw text supplied by the client.

    Returns:
        The input with surrounding whitespace removed and upper-cased, or
        an empty string when the input is not a string or is blank.
    """
    # Non-string input (e.g. None from an empty request) becomes "" so the
    # caller's "invalid ticker" check handles it instead of an AttributeError.
    if not isinstance(input_name, str):
        return ""
    return input_name.strip().upper()
| # --- 3. The Main Forecasting Function (Adapted from your code) --- | |
def forecast_stock(input_name, model, scalers_dict, input_width=60):
    """Predict the next day's closing price for a stock ticker.

    Downloads one year of price history with yfinance, scales the most
    recent ``input_width`` closing prices, runs them through ``model`` and
    converts the output back to a price. Failures are reported as an
    error string rather than raised, so the caller can display the result
    either way.

    Args:
        input_name: Raw ticker text; normalized by get_ticker_from_input.
        model: Object exposing ``predict(X, verbose=0)``; None is reported
            as "not loaded".
        scalers_dict: Mapping of ticker -> scaler exposing ``transform`` and
            ``inverse_transform``. The 'ZURVY' entry is used as a fallback
            for tickers without their own scaler.
        input_width: Number of most recent closing prices fed to the model.

    Returns:
        A two-line forecast summary, or a message describing what failed.
    """
    # Identity check instead of truthiness: a loaded model or scaler object
    # must never be rejected because of how it defines bool()/len().
    if model is None or not scalers_dict:
        return "Error: Model or scalers not loaded."

    ticker = get_ticker_from_input(input_name)
    if not ticker:
        return "Error: Invalid stock ticker."

    print(f"\n--- Generating forecast for {ticker} ---")

    # Fetch recent data using yfinance; only the download itself is guarded.
    try:
        data_df = yf.download(ticker, period="1y", progress=False)
    except Exception as e:
        return f"Error fetching data for {ticker}: {e}"

    no_data_message = f"Error: No data found for ticker {ticker}. It may be delisted or invalid."
    if data_df is None or data_df.empty:
        return no_data_message

    try:
        close = _close_series(data_df)
    except (KeyError, IndexError):
        return no_data_message

    # Count usable closes (after dropping missing values), not raw rows.
    if len(close) < input_width:
        return f"Error: Not enough historical data for {ticker}. Need {input_width} days, but only have {len(close)}."

    close_prices = close.tail(input_width).to_numpy(dtype=float).reshape(-1, 1)
    last_close = float(close_prices[-1, 0])

    # Note: the scalers were presumably fitted on specific stocks, so reusing
    # one for a different ticker may not be accurate.
    scaler = scalers_dict.get(ticker)
    if scaler is None:
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = scalers_dict.get('ZURVY')
        if scaler is None:
            return "Error: Default scaler 'ZURVY' not found."

    try:
        scaled_data = scaler.transform(close_prices)
        # Model input layout used here: (batch=1, input_width, 1 feature).
        X_pred = scaled_data.reshape(1, input_width, 1)
        prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
        prediction_actual = float(
            scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
        )
    except Exception as e:
        # Keep the "errors come back as strings" contract for scaling and
        # prediction failures too.
        return f"Error generating forecast for {ticker}: {e}"

    result = (
        f"Last known close for {ticker}: ${last_close:.2f}\n"
        f"Predicted next day's close price: ${prediction_actual:.2f}"
    )
    print(result)
    return result


def _close_series(data_df):
    """Return the 'Close' prices of a yfinance frame as a 1-D Series without NaNs."""
    close = data_df["Close"]
    if isinstance(close, pd.DataFrame):
        # With MultiIndex columns (price field, ticker), selecting "Close"
        # yields a one-column DataFrame; take that column so later scalar
        # formatting and reshaping operate on plain floats.
        close = close.iloc[:, 0]
    return close.dropna()
| # --- 4. Create the Gradio Interface --- | |
| # We create a simple function that Gradio can expose as an API endpoint. | |
def predict_api(ticker_symbol):
    """Forecast for ``ticker_symbol`` using the module-level model and scalers.

    This is the single-argument function bound to the Gradio event.

    Args:
        ticker_symbol: Ticker text sent by the client.

    Returns:
        Whatever forecast_stock returns for that ticker.
    """
    forecast_text = forecast_stock(
        input_name=ticker_symbol,
        model=loaded_model_lstm,
        scalers_dict=loaded_scalers,
    )
    return forecast_text
# We use a minimal Gradio interface because we only need its backend API
# capabilities; the visible page is just a one-line notice.
with gr.Blocks() as app:
    gr.Markdown("This is the backend for the React Stock Forecaster App.")

    # Hidden components: they exist only so the event below has an input
    # and an output to bind to.
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    output_text = gr.Textbox(label="Forecast", visible=False)

    # A function must be tied to an event to be exposed through the Gradio
    # API. api_name="predict" names the endpoint the React app calls; the
    # exact URL prefix depends on the installed Gradio version.
    ticker_input.submit(
        predict_api,
        inputs=[ticker_input],
        outputs=[output_text],
        api_name="predict",
    )

# --- 5. Mount and Serve the React App's Static Files ---
# Before running this, build the React app with `npm run build`, which
# creates a `build` directory of static files.
# NOTE(review): `mount_static_directory` does not appear in Gradio's
# documented Blocks API -- confirm against the installed version. It is
# looked up defensively so that a missing method cannot raise AttributeError
# before launch() and take the prediction API down with it.
_mount_static = getattr(app, "mount_static_directory", None)
if callable(_mount_static):
    _mount_static("./build")
else:
    print(
        "Warning: this Gradio version has no Blocks.mount_static_directory; "
        "./build is not being served."
    )

# Launch the server
app.launch()