Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,91 +1,76 @@
|
|
| 1 |
import os
|
| 2 |
-
import gradio as gr
|
| 3 |
-
import tensorflow as tf
|
| 4 |
-
import joblib
|
| 5 |
-
import numpy as np
|
| 6 |
import pandas as pd
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
|
| 18 |
-
scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
|
| 19 |
-
print("✅ Files downloaded successfully.")
|
| 20 |
-
except Exception as e:
|
| 21 |
-
print(f"❌ Error downloading files: {e}")
|
| 22 |
-
model_path, scalers_path = None, None
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
loaded_model_lstm = tf.keras.models.load_model(
|
| 30 |
-
model_path,
|
| 31 |
-
custom_objects={"mse": tf.keras.losses.MeanSquaredError()}
|
| 32 |
-
)
|
| 33 |
-
print("✅ Model loaded successfully.")
|
| 34 |
-
except Exception as e:
|
| 35 |
-
print(f"❌ Error loading model: {e}")
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
except Exception as e:
|
| 42 |
-
print(f"❌ Error loading scalers: {e}")
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
def
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
def forecast_stock(input_name
|
| 50 |
-
if not model or not scalers_dict:
|
| 51 |
-
return "Error: Model or scalers not loaded. The backend may have failed to start."
|
| 52 |
ticker = get_ticker_from_input(input_name)
|
| 53 |
if not ticker:
|
| 54 |
-
return "
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
recent_data = data_df.tail(input_width)
|
| 60 |
-
close_prices = recent_data['Close'].values.reshape(-input, 1)
|
| 61 |
-
scaler = scalers_dict.get(ticker)
|
| 62 |
-
if not scaler:
|
| 63 |
-
print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
|
| 64 |
-
scaler = scalers_dict.get('ZURVY')
|
| 65 |
-
if not scaler:
|
| 66 |
-
return "Error: Default scaler 'ZURVY' not found."
|
| 67 |
scaled_data = scaler.transform(close_prices)
|
| 68 |
-
X_pred = scaled_data.reshape(1,
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
f"Last known close for {ticker}: ${last_close:.2f}\n"
|
| 74 |
-
f"Predicted next day's close price: ${prediction_actual:.2f}"
|
| 75 |
-
)
|
| 76 |
-
print(result)
|
| 77 |
-
return result
|
| 78 |
-
|
| 79 |
-
def predict_api(ticker_symbol):
|
| 80 |
-
return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
|
| 81 |
-
|
| 82 |
-
with gr.Blocks() as app:
|
| 83 |
-
gr.Markdown("This is the backend for the React Stock Forecaster App.")
|
| 84 |
-
ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
|
| 85 |
-
output_text = gr.Textbox(label="Forecast", visible=False)
|
| 86 |
-
ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
|
| 87 |
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
-
|
| 91 |
-
app.launch()
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import joblib
|
| 5 |
+
import tensorflow as tf
|
| 6 |
+
from fuzzywuzzy import process
|
| 7 |
+
import gradio as gr
|
| 8 |
|
| 9 |
+
# --- Load data ---
|
| 10 |
+
DATA_PATH = "/kaggle/working/all_stocks_long.csv" # update if needed
|
| 11 |
+
MODEL_DIR = "/kaggle/working/stock-forecaster-lstm"
|
| 12 |
+
MODEL_FILE = "model_lstm.h5"
|
| 13 |
+
SCALER_FILE = "scalers.joblib"
|
| 14 |
|
| 15 |
+
combined_fe = pd.read_csv(DATA_PATH, parse_dates=['date'])
|
| 16 |
+
combined_fe['ticker'] = combined_fe['ticker'].str.upper()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
# --- Load model ---
|
| 19 |
+
model_path = os.path.join(MODEL_DIR, MODEL_FILE)
|
| 20 |
+
scaler_path = os.path.join(MODEL_DIR, SCALER_FILE)
|
| 21 |
|
| 22 |
+
loaded_model = tf.keras.models.load_model(model_path)
|
| 23 |
+
loaded_scalers = joblib.load(scaler_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
# --- Build ticker <-> company mappings ---
|
| 26 |
+
top_tickers = combined_fe['ticker'].unique()
|
| 27 |
+
ticker_to_name = {t: t for t in top_tickers} # can update with real names if available
|
| 28 |
+
name_to_ticker = {v: k for k,v in ticker_to_name.items()}
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
# --- Prediction helpers ---
|
| 31 |
+
def get_ticker_from_input(input_str):
|
| 32 |
+
if input_str.upper() in ticker_to_name:
|
| 33 |
+
return input_str.upper()
|
| 34 |
+
if input_str in name_to_ticker:
|
| 35 |
+
return name_to_ticker[input_str]
|
| 36 |
+
best_match, score = process.extractOne(input_str, name_to_ticker.keys())
|
| 37 |
+
if score > 80:
|
| 38 |
+
return name_to_ticker[best_match]
|
| 39 |
+
return None
|
| 40 |
|
| 41 |
+
def make_windows(series, input_width=60, horizon=1):
|
| 42 |
+
arr = series.values.astype(np.float32)
|
| 43 |
+
X, y = [], []
|
| 44 |
+
for i in range(input_width, len(arr)-horizon+1):
|
| 45 |
+
X.append(arr[i-input_width:i])
|
| 46 |
+
y.append(arr[i + (horizon-1)])
|
| 47 |
+
return np.array(X), np.array(y)
|
| 48 |
|
| 49 |
+
def forecast_stock(input_name):
|
|
|
|
|
|
|
| 50 |
ticker = get_ticker_from_input(input_name)
|
| 51 |
if not ticker:
|
| 52 |
+
return f"Could not find ticker for '{input_name}'"
|
| 53 |
+
|
| 54 |
+
recent_data = combined_fe[combined_fe['ticker']==ticker].sort_values('date').tail(60)
|
| 55 |
+
if len(recent_data) < 60:
|
| 56 |
+
return f"Not enough historical data for {ticker}."
|
| 57 |
|
| 58 |
+
close_prices = recent_data['close'].values.reshape(-1, 1)
|
| 59 |
+
scaler = loaded_scalers[ticker]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
scaled_data = scaler.transform(close_prices)
|
| 61 |
+
X_pred = scaled_data.reshape(1, 60, 1)
|
| 62 |
+
|
| 63 |
+
prediction_scaled = loaded_model.predict(X_pred, verbose=0)[0][0]
|
| 64 |
+
prediction_actual = scaler.inverse_transform([[prediction_scaled]])[0][0]
|
| 65 |
+
return round(prediction_actual, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
+
# --- Gradio Interface ---
|
| 68 |
+
iface = gr.Interface(
|
| 69 |
+
fn=forecast_stock,
|
| 70 |
+
inputs=gr.Textbox(label="Enter Ticker or Company Name"),
|
| 71 |
+
outputs=gr.Textbox(label="Predicted Next Day Close"),
|
| 72 |
+
title="Stock Price Forecaster (LSTM)",
|
| 73 |
+
description="Enter a stock ticker or company name to predict the next day's close price."
|
| 74 |
+
)
|
| 75 |
|
| 76 |
+
iface.launch()
|
|
|