Spaces:
Running
Running
| import os | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import tensorflow as tf | |
| from fuzzywuzzy import process | |
| import gradio as gr | |
# --- Load data ---
# Paths default to the Kaggle working directory used during training; set the
# DATA_PATH / MODEL_DIR environment variables to run from another location.
DATA_PATH = os.environ.get("DATA_PATH", "/kaggle/working/all_stocks_long.csv")
MODEL_DIR = os.environ.get("MODEL_DIR", "/kaggle/working/stock-forecaster-lstm")
MODEL_FILE = "model_lstm.h5"
SCALER_FILE = "scalers.joblib"
combined_fe = pd.read_csv(DATA_PATH, parse_dates=['date'])
# Normalise tickers (trim + upper-case) so user lookups match regardless of
# stray whitespace or case in the CSV.
combined_fe['ticker'] = combined_fe['ticker'].str.strip().str.upper()
# Rows without a ticker can never be selected and would otherwise leak NaN into
# the ticker/name lookup tables built from this frame.
combined_fe = combined_fe.dropna(subset=['ticker'])
# --- Load model ---
model_path = os.path.join(MODEL_DIR, MODEL_FILE)
scaler_path = os.path.join(MODEL_DIR, SCALER_FILE)
# The app only calls predict(), so skip restoring the training configuration
# (optimizer / loss / metrics). That step is not needed for inference and may
# fail to deserialize for .h5 files saved by a different Keras version.
loaded_model = tf.keras.models.load_model(model_path, compile=False)
# NOTE: joblib.load unpickles the file, which can execute arbitrary code --
# only point SCALER_FILE at artifacts produced by this project's own training.
loaded_scalers = joblib.load(scaler_path)
# --- Build ticker <-> company mappings ---
top_tickers = combined_fe['ticker'].unique()
# No company-name source is wired in yet, so each ticker doubles as its own
# display name; both tables are identity maps over the tickers in the data.
# Replace the values of ticker_to_name with real names when available.
ticker_to_name = dict(zip(top_tickers, top_tickers))
name_to_ticker = {name: ticker for ticker, name in ticker_to_name.items()}
# --- Prediction helpers ---
def get_ticker_from_input(input_str, min_score=80):
    """Resolve free-text user input to a ticker symbol.

    Resolution order: exact ticker (case-insensitive), exact company name,
    then a fuzzy match against the known company names.

    Args:
        input_str: Ticker or company name typed by the user. ``None`` and
            blank strings resolve to ``None``.
        min_score: A fuzzy match is accepted only if its score is strictly
            greater than this value.

    Returns:
        The matching ticker, or ``None`` if nothing matched well enough.
    """
    if input_str is None:
        return None
    query = str(input_str).strip()
    if not query:
        return None

    upper_query = query.upper()
    if upper_query in ticker_to_name:
        return upper_query
    if query in name_to_ticker:
        return name_to_ticker[query]

    # extractOne returns None when there are no choices to compare against.
    match = process.extractOne(query, list(name_to_ticker.keys()))
    if match is None:
        return None
    best_match, score = match[0], match[1]
    if score > min_score:
        return name_to_ticker[best_match]
    return None
def make_windows(series, input_width=60, horizon=1):
    """Slice a time series into supervised (window, target) pairs.

    Window ``k`` is ``arr[k:k + input_width]`` and its target is the value
    ``horizon`` steps after the window's last element, i.e.
    ``arr[k + input_width + horizon - 1]``.

    Args:
        series: Array-like ordered by time: a pandas Series/DataFrame, NumPy
            array, or list. Rows are time steps; extra columns are kept as
            features. Values are cast to float32.
        input_width: Number of consecutive time steps in each input window.
        horizon: Steps past the end of the window the target lies (1 = next).

    Returns:
        ``(X, y)`` float32 arrays shaped ``(n, input_width, *features)`` and
        ``(n, *features)`` with ``n = len(series) - input_width - horizon + 1``.
        When the series is too short, ``n`` is 0 but the trailing dimensions
        are preserved.

    Raises:
        ValueError: if ``input_width`` or ``horizon`` is less than 1.
    """
    if input_width < 1 or horizon < 1:
        raise ValueError("input_width and horizon must both be >= 1")

    # np.asarray accepts pandas objects as well as plain arrays and lists,
    # unlike ``.values`` which exists only on pandas objects.
    arr = np.asarray(series, dtype=np.float32)
    feature_shape = arr.shape[1:]
    n_windows = len(arr) - input_width - horizon + 1

    if n_windows <= 0:
        # Keep the window axis so callers can rely on X.ndim even when empty.
        return (
            np.empty((0, input_width) + feature_shape, dtype=np.float32),
            np.empty((0,) + feature_shape, dtype=np.float32),
        )

    X = np.stack([arr[start:start + input_width] for start in range(n_windows)])
    # Targets for windows 0..n_windows-1 are the contiguous tail of the series.
    y = arr[input_width + horizon - 1:].copy()
    return X, y
def forecast_stock(input_name, input_width=60):
    """Predict the next-day close price for a ticker or company name.

    Args:
        input_name: Ticker symbol or company name entered by the user.
        input_width: Number of most recent closes fed to the model. The model
            is fed a ``(1, input_width, 1)`` tensor, so this must match the
            window length it was trained with (60 by default).

    Returns:
        The predicted close rounded to 2 decimals as a float, or a
        human-readable message string when the request cannot be served.
    """
    ticker = get_ticker_from_input(input_name)
    if not ticker:
        return f"Could not find ticker for '{input_name}'"

    # Drop missing closes before taking the window so a data gap cannot feed
    # NaN into the scaler/model (which would make the prediction NaN).
    closes = (
        combined_fe.loc[combined_fe['ticker'] == ticker]
        .sort_values('date')['close']
        .dropna()
        .tail(input_width)
    )
    if len(closes) < input_width:
        return f"Not enough historical data for {ticker}."

    # Scalers are looked up per ticker; a ticker present in the CSV may have
    # no fitted scaler, which previously crashed with an unhandled KeyError.
    try:
        scaler = loaded_scalers[ticker]
    except KeyError:
        return f"No trained scaler available for {ticker}."

    close_prices = closes.to_numpy().reshape(-1, 1)
    scaled_window = scaler.transform(close_prices)
    X_pred = scaled_window.reshape(1, input_width, 1)
    prediction_scaled = loaded_model.predict(X_pred, verbose=0)[0][0]
    prediction_actual = scaler.inverse_transform([[prediction_scaled]])[0][0]
    # Return a plain Python float rather than a NumPy scalar.
    return round(float(prediction_actual), 2)
# --- Gradio Interface ---
iface = gr.Interface(
    fn=forecast_stock,
    inputs=gr.Textbox(label="Enter Ticker or Company Name"),
    outputs=gr.Textbox(label="Predicted Next Day Close"),
    title="Stock Price Forecaster (LSTM)",
    description="Enter a stock ticker or company name to predict the next day's close price.",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse forecast_stock) does not block on launch.
if __name__ == "__main__":
    iface.launch()