Spaces:
Sleeping
Sleeping
import logging

import gradio as gr
import pandas as pd

from core.data import load_data
from core.model_runner import get_model
from core.plot import plot_forecast, plot_metrics_r2, plot_metrics_errors, plot_loss_curve, plot_future_forecast
from config import AVAILABLE_MODELS, DEFAULT_TICKERS
logger = logging.getLogger(__name__)

# Number of values wired to ``run_btn.click(outputs=...)`` in main_interface:
# six plot/table components followed by the status textbox.
_N_OUTPUTS = 7


def _error_outputs(message):
    """Return an output tuple that clears every result component and shows ``message``.

    The first six slots (plots and the predictions table) are ``None``; the
    last slot is the status text, matching the ``outputs=`` order used in
    ``main_interface``.
    """
    return (None,) * (_N_OUTPUTS - 1) + (message,)


def _run_pipeline(data_src, ticker, file_upload, start_date, end_date, horizon, model,
                  hidden_units, n_layers, epochs, learning_rate, beta1, beta2,
                  weight_decay, dropout, window_size, test_split):
    """Load data, train the selected model and build every UI output.

    Arguments arrive straight from the Gradio input components, in the order
    listed in ``run_btn.click(inputs=...)``. Slider values may arrive as
    floats, so integer hyper-parameters are cast with ``int()`` before being
    handed to ``get_model``.

    Returns:
        A 7-tuple ``(backtest_fig, future_fig, future_df, r2_fig, error_fig,
        loss_fig, status_message)``. On any failure the first six slots are
        ``None`` and the status message describes the problem.
    """
    try:
        # Validate the date range up front so the user gets a clear message
        # instead of an opaque downstream failure.
        start = pd.to_datetime(start_date)
        end = pd.to_datetime(end_date)
        if start >= end:
            return _error_outputs("❌ Start date must be earlier than end date.")

        source_key = "csv" if data_src == "Upload CSV" else "yahoo"
        if source_key == "csv" and file_upload is None:
            return _error_outputs("❌ Please upload a CSV file.")

        df = load_data(data_src=source_key, ticker=ticker, file_upload=file_upload,
                       start=start_date, end=end_date)
        if df is None or df.empty:
            return _error_outputs("❌ Failed to load data. Please check input.")

        result = get_model(
            df=df,
            model_name=model,
            horizon=int(horizon),
            hidden_units=int(hidden_units),
            n_layers=int(n_layers),
            epochs=int(epochs),
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            dropout=dropout,
            window_size=int(window_size),
            test_split=test_split,
        )

        backtest_fig = plot_forecast(result)
        future_fig = plot_future_forecast(df, result)
        r2_fig = plot_metrics_r2(result)
        error_fig = plot_metrics_errors(result)
        loss_fig = plot_loss_curve(result)

        msg = "✅ Done."
        if "latest_prediction" in result:
            preds = list(result["latest_prediction"])
            # Coerce in case the loader left dates as strings.
            last_date = pd.to_datetime(df["Date"].iloc[-1])
            # One date per returned prediction, so the table can never be
            # built from mismatched column lengths.
            # NOTE(review): freq="B" skips weekends, but crypto tickers such as
            # BTC-USD trade daily — confirm this matches get_model's horizon steps.
            future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1),
                                         periods=len(preds), freq="B")
            future_df = pd.DataFrame({"Date": future_dates, "Predicted Value": preds})
            msg += f" Next predicted value(s): {[f'{val:.2f}' for val in preds]}"
        else:
            future_df = pd.DataFrame()
        return backtest_fig, future_fig, future_df, r2_fig, error_fig, loss_fig, msg
    except Exception as e:
        # UI boundary: surface the error in the status box, but keep the
        # traceback in the server log for debugging.
        logger.exception("Forecast pipeline failed")
        return _error_outputs(f"❌ Error: {e}")


def _toggle_file(src):
    """Show the CSV upload widget only when the "Upload CSV" source is selected."""
    return gr.update(visible=(src == "Upload CSV"))


def main_interface():
    """Build the Gradio Blocks UI for the forecasting studio.

    The left column collects the data source, date range, forecast horizon
    and model hyper-parameters. The right column shows the backtest and
    future-forecast plots, a table of future predictions, metric plots and
    the training-loss curve.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("# 📈 AI Forecasting Studio")
        with gr.Row():
            with gr.Column(scale=1):
                data_src = gr.Radio(["Yahoo Finance", "Upload CSV"], label="Data Source", value="Yahoo Finance")
                ticker = gr.Dropdown(choices=DEFAULT_TICKERS, label="Ticker", value="BTC-USD")
                # Hidden until "Upload CSV" is chosen (see _toggle_file).
                file_upload = gr.File(label="Upload CSV", visible=False, file_types=[".csv"])
                start_date = gr.Textbox(label="Start Date (YYYY-MM-DD)", value="2022-01-01")
                end_date = gr.Textbox(label="End Date (YYYY-MM-DD)", value="2023-12-31")
                horizon = gr.Slider(1, 15, step=1, label="Forecast Days", value=1)

                gr.Markdown("## ⚙️ Model Settings")
                model = gr.Dropdown(choices=AVAILABLE_MODELS, label="Model", value="LSTM")
                # Integer-valued settings get an explicit step=1. Without it Gradio
                # derives a step from the slider range, which can be fractional.
                hidden_units = gr.Slider(8, 512, step=1, label="Hidden Units", value=64)
                n_layers = gr.Slider(1, 5, step=1, label="# Hidden Layers", value=2)
                epochs = gr.Slider(10, 300, step=1, label="Epochs", value=100)
                learning_rate = gr.Slider(1e-5, 0.01, label="Learning Rate", value=0.001)
                # AdamW optimizer settings, forwarded to get_model.
                beta1 = gr.Slider(0.8, 0.95, label="AdamW Beta1", value=0.9, step=0.01)
                beta2 = gr.Slider(0.9, 0.999, label="AdamW Beta2", value=0.999, step=0.001)
                weight_decay = gr.Slider(0.0, 0.1, label="Weight Decay", value=0.01, step=0.001)
                dropout = gr.Slider(0.0, 0.3, label="Drop Out", value=0.2)
                window_size = gr.Slider(5, 90, step=1, label="Window Size", value=30)
                test_split = gr.Slider(0.05, 0.5, label="Test Split", value=0.2)

                run_btn = gr.Button("🚀 Train & Predict")
                status = gr.Textbox(label="Status", interactive=False, lines=2)

            with gr.Column(scale=2):
                backtest_plot = gr.Plot(label="📊 Backtesting: Actual vs Forecast")
                future_plot = gr.Plot(label="🔮 Future Forecast")
                future_table = gr.Dataframe(label="📋 Future Predictions")
                r2_plot = gr.Plot(label="📏 R² and MAPE Metrics")
                error_plot = gr.Plot(label="📉 RMSE and MAE Metrics")
                loss_plot = gr.Plot(label="📉 Training Loss Curve")

        # Event wiring must happen inside the Blocks context. The input order
        # must match _run_pipeline's parameters, and the output order must
        # match its returned tuple.
        run_btn.click(
            fn=_run_pipeline,
            inputs=[
                data_src, ticker, file_upload,
                start_date, end_date, horizon, model,
                hidden_units, n_layers, epochs, learning_rate,
                beta1, beta2, weight_decay, dropout, window_size, test_split,
            ],
            outputs=[backtest_plot, future_plot, future_table, r2_plot, error_plot, loss_plot, status],
        )
        data_src.change(fn=_toggle_file, inputs=[data_src], outputs=[file_upload])
    return app
# Script entry point: build the UI and start the Gradio server.
if __name__ == '__main__':
    demo = main_interface()
    demo.launch()