# StockPredict / app.py
# Last commit: "Update app.py" (fbaddf2, verified) by aromidvar1355
import logging

import gradio as gr
import pandas as pd

from core.data import load_data
from core.model_runner import get_model
from core.plot import plot_forecast, plot_metrics_r2, plot_metrics_errors, plot_loss_curve, plot_future_forecast
from config import AVAILABLE_MODELS, DEFAULT_TICKERS
# Module-level logger so pipeline failures leave a traceback in the server log,
# not only a one-line message in the Status box.
logger = logging.getLogger(__name__)

# Choices offered by the "Data Source" radio.
_YAHOO_SOURCE = "Yahoo Finance"
_CSV_SOURCE = "Upload CSV"

# run_btn.click is wired to exactly this many output components; every return
# path of _run_pipeline must produce a tuple of this length.
_N_OUTPUTS = 7


def _error_outputs(message):
    """Return an output tuple that clears every plot/table and shows *message*."""
    return (None,) * (_N_OUTPUTS - 1) + (message,)


def _parse_date_range(start_date, end_date):
    """Parse the start/end textboxes and return them as ``YYYY-MM-DD`` strings.

    Raises:
        ValueError: if a date is blank or unparseable, or start is not before end.
        TypeError: if the two dates cannot be compared (e.g. one tz-aware, one naive).
    """
    start = pd.to_datetime(str(start_date).strip())
    end = pd.to_datetime(str(end_date).strip())
    # pd.to_datetime("") yields NaT instead of raising, so check explicitly.
    if pd.isna(start) or pd.isna(end):
        raise ValueError("start and end dates are required (YYYY-MM-DD)")
    if start >= end:
        raise ValueError("start date must be before end date")
    return start.strftime("%Y-%m-%d"), end.strftime("%Y-%m-%d")


def _future_table(df, predictions):
    """Pair each prediction with a future business day after the last row of *df*.

    Args:
        df: the loaded data; assumes a ``Date`` column whose last row is the most
            recent observation (TODO confirm against core.data.load_data).
        predictions: list of predicted values, one per future step.

    Returns:
        pd.DataFrame with ``Date`` and ``Predicted Value`` columns.
    """
    # Coerce in case the Date column holds strings (e.g. from an uploaded CSV).
    last_date = pd.to_datetime(df["Date"].iloc[-1])
    # Size the date index from the predictions themselves so a length mismatch
    # with the requested horizon cannot make the DataFrame constructor raise.
    future_dates = pd.date_range(
        start=last_date + pd.Timedelta(days=1), periods=len(predictions), freq="B"
    )
    return pd.DataFrame({"Date": future_dates, "Predicted Value": predictions})


def _run_pipeline(data_src, ticker, file_upload, start_date, end_date, horizon, model,
                  hidden_units, n_layers, epochs, learning_rate, beta1, beta2, weight_decay,
                  dropout, window_size, test_split):
    """Load data, train the selected model, and build every output of the UI.

    Parameters mirror the input components wired in ``main_interface`` (same order).
    Integer hyperparameters are cast with ``int()`` because Gradio sliders
    deliver floats.

    Returns:
        A 7-tuple ``(backtest_plot, future_plot, future_table, r2_plot,
        error_plot, loss_plot, status_message)``. On failure all but the status
        message are ``None``.
    """
    try:
        start, end = _parse_date_range(start_date, end_date)
    except (ValueError, TypeError) as exc:
        return _error_outputs(f"❌ Invalid date: {exc}")

    try:
        source_key = "csv" if data_src == _CSV_SOURCE else "yahoo"
        df = load_data(data_src=source_key, ticker=ticker, file_upload=file_upload,
                       start=start, end=end)
        if df is None or df.empty:
            return _error_outputs("❌ Failed to load data. Please check input.")

        result = get_model(
            df=df,
            model_name=model,
            horizon=int(horizon),
            hidden_units=int(hidden_units),
            n_layers=int(n_layers),
            epochs=int(epochs),
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            dropout=dropout,
            window_size=int(window_size),
            test_split=test_split,
        )

        forecast_plot = plot_forecast(result)
        future_plot = plot_future_forecast(df, result)
        r2_plot = plot_metrics_r2(result)
        error_plot = plot_metrics_errors(result)
        loss_plot = plot_loss_curve(result)

        msg = "✅ Done."
        if "latest_prediction" in result:
            # Materialise once: the values are used both for the table and the message.
            preds = list(result["latest_prediction"])
            future_df = _future_table(df, preds)
            msg += f" Next predicted value(s): {[f'{val:.2f}' for val in preds]}"
        else:
            future_df = pd.DataFrame()
        return forecast_plot, future_plot, future_df, r2_plot, error_plot, loss_plot, msg
    except Exception as exc:  # UI boundary: report in the Status box instead of crashing
        logger.exception("Forecast pipeline failed")
        return _error_outputs(f"❌ Error: {exc}")


def _toggle_file(src):
    """Show the CSV upload widget only while the CSV data source is selected."""
    return gr.update(visible=(src == _CSV_SOURCE))


def main_interface():
    """Build the forecasting-studio Gradio app.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("# 📈 AI Forecasting Studio")
        with gr.Row():
            with gr.Column(scale=1):
                data_src = gr.Radio([_YAHOO_SOURCE, _CSV_SOURCE], label="Data Source", value=_YAHOO_SOURCE)
                ticker = gr.Dropdown(choices=DEFAULT_TICKERS, label="Ticker", value="BTC-USD")
                file_upload = gr.File(label="Upload CSV", visible=False, file_types=[".csv"])
                start_date = gr.Textbox(label="Start Date (YYYY-MM-DD)", value="2022-01-01")
                end_date = gr.Textbox(label="End Date (YYYY-MM-DD)", value="2023-12-31")
                horizon = gr.Slider(1, 15, step=1, label="Forecast Days", value=1)

                gr.Markdown("## ⚙️ Model Settings")
                model = gr.Dropdown(choices=AVAILABLE_MODELS, label="Model", value="LSTM")
                # Integer hyperparameters get an explicit step=1; without it Gradio
                # derives a fractional step from the range (0.1 for 5-90).
                hidden_units = gr.Slider(8, 512, step=1, label="Hidden Units", value=64)
                n_layers = gr.Slider(1, 5, step=1, label="# Hidden Layers", value=2)
                epochs = gr.Slider(10, 300, step=1, label="Epochs", value=100)
                learning_rate = gr.Slider(1e-5, 0.01, label="Learning Rate", value=0.001)
                beta1 = gr.Slider(0.8, 0.95, label="AdamW Beta1", value=0.9, step=0.01)
                beta2 = gr.Slider(0.9, 0.999, label="AdamW Beta2", value=0.999, step=0.001)
                weight_decay = gr.Slider(0.0, 0.1, label="Weight Decay", value=0.01, step=0.001)
                dropout = gr.Slider(0.0, 0.3, label="Drop Out", value=0.2)
                window_size = gr.Slider(5, 90, step=1, label="Window Size", value=30)
                test_split = gr.Slider(0.05, 0.5, label="Test Split", value=0.2)

                run_btn = gr.Button("🚀 Train & Predict")
                status = gr.Textbox(label="Status", interactive=False, lines=2)

            with gr.Column(scale=2):
                backtest_plot = gr.Plot(label="📊 Backtesting: Actual vs Forecast")
                future_plot = gr.Plot(label="🔮 Future Forecast")
                future_table = gr.Dataframe(label="📋 Future Predictions")
                r2_plot = gr.Plot(label="📉 R² and MAPE Metrics")
                error_plot = gr.Plot(label="📉 RMSE and MAE Metrics")
                loss_plot = gr.Plot(label="📈 Training Loss Curve")

        # Input order must match _run_pipeline's parameter order.
        run_btn.click(
            fn=_run_pipeline,
            inputs=[
                data_src, ticker, file_upload,
                start_date, end_date, horizon, model,
                hidden_units, n_layers, epochs, learning_rate,
                beta1, beta2, weight_decay, dropout, window_size, test_split,
            ],
            outputs=[backtest_plot, future_plot, future_table, r2_plot, error_plot, loss_plot, status],
        )
        data_src.change(fn=_toggle_file, inputs=[data_src], outputs=[file_upload])
    return app
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    demo = main_interface()
    demo.launch()