aromidvar1355 commited on
Commit
5141802
Β·
verified Β·
1 Parent(s): 1737f2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -78
app.py CHANGED
@@ -1,82 +1,106 @@
1
  import gradio as gr
2
  import pandas as pd
3
- import numpy as np
4
- from core.data import preprocess_data
5
- from core.trainer import train_model
6
  from core.plot import plot_forecast, plot_metrics
7
- from core.utils import plot_loss_curve
8
- from core.model_runner import get_model
9
-
10
- def forecast_pipeline(data, model_type, window_size, forecast_steps, epochs):
11
- try:
12
- df = pd.read_csv(data.name) if hasattr(data, 'name') else pd.read_csv("sample_data.csv")
13
- df = df.dropna()
14
-
15
- X_train, y_train, X_test, y_test, scaler = preprocess_data(df, window_size, forecast_steps)
16
-
17
- model = get_model(model_type, input_shape=X_train.shape[1:])
18
- history, predictions = train_model(model, X_train, y_train, X_test, y_test, epochs)
19
-
20
- # Inverse transform predictions and actuals
21
- predicted = scaler.inverse_transform(predictions)
22
- actual = scaler.inverse_transform(y_test)
23
-
24
- # Metrics
25
- mse = np.mean((predicted - actual)**2)
26
- mae = np.mean(np.abs(predicted - actual))
27
- r2 = 1 - (np.sum((predicted - actual) ** 2) / np.sum((actual - np.mean(actual)) ** 2))
28
-
29
- result = {
30
- "actual": actual.flatten(),
31
- "predicted": predicted.flatten(),
32
- "loss": history.history['loss'],
33
- "val_loss": history.history.get('val_loss', []),
34
- "metrics": {
35
- "MSE": mse,
36
- "MAE": mae,
37
- "RΒ²": r2
38
- }
39
- }
40
-
41
- return (
42
- plot_forecast(result),
43
- plot_loss_curve(result),
44
- plot_metrics(result),
45
- result["metrics"]
46
- )
47
- except Exception as e:
48
- return f"Error: {str(e)}", None, None, None
49
-
50
- # Gradio interface
51
- with gr.Blocks(title="πŸ“ˆ MarketPredict: Forecasting Dashboard") as demo:
52
- gr.Markdown("# 🧠 MarketPredict: Advanced Time Series Forecasting")
53
- gr.Markdown("Upload a CSV with time-series data or use the sample. Select model and parameters to begin training and see predictions.")
54
-
55
- with gr.Row():
56
- with gr.Column():
57
- file_input = gr.File(label="Upload CSV", file_types=[".csv"])
58
- model_choice = gr.Dropdown(
59
- choices=["LSTM", "GRU", "CNN", "Transformer", "MLP"],
60
- label="Select Model",
61
- value="LSTM"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
63
- window_slider = gr.Slider(5, 60, value=30, step=5, label="Window Size")
64
- forecast_slider = gr.Slider(1, 20, value=5, step=1, label="Forecast Steps")
65
- epoch_slider = gr.Slider(5, 200, value=50, step=5, label="Epochs")
66
-
67
- run_btn = gr.Button("πŸš€ Run Forecast")
68
-
69
- with gr.Column():
70
- forecast_plot = gr.Plot(label="πŸ“‰ Forecast vs Actual")
71
- loss_plot = gr.Plot(label="πŸ“‰ Training Loss Curve")
72
- metrics_plot = gr.Plot(label="πŸ“Š Error Metrics")
73
- metrics_out = gr.JSON(label="πŸ“Œ Evaluation Metrics")
74
-
75
- run_btn.click(
76
- fn=forecast_pipeline,
77
- inputs=[file_input, model_choice, window_slider, forecast_slider, epoch_slider],
78
- outputs=[forecast_plot, loss_plot, metrics_plot, metrics_out]
79
- )
80
-
81
- if __name__ == "__main__":
82
- demo.launch()
 
 
 
 
 
1
  import gradio as gr
2
  import pandas as pd
3
+ from core.data import get_data
4
+ from core.model_runner import run_model
 
5
  from core.plot import plot_forecast, plot_metrics
6
+ from config import AVAILABLE_MODELS, DEFAULT_TICKERS, DEFAULT_PARAMS
7
+
8
+
9
+ def main_interface():
10
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
11
+ gr.Markdown("# πŸ“ˆ AI Forecasting Studio")
12
+ gr.Markdown("Upload time series data or fetch from Yahoo Finance and forecast with deep learning models (LSTM, GRU, Transformer, etc).")
13
+
14
+ with gr.Row():
15
+ with gr.Column(scale=1):
16
+ data_src = gr.Radio(["Yahoo Finance", "Upload CSV"], label="Data Source", value="Yahoo Finance")
17
+
18
+ ticker = gr.Dropdown(choices=DEFAULT_TICKERS, label="Ticker Symbol", value="BTC-USD", visible=True)
19
+ file_upload = gr.File(label="Upload CSV", visible=False, file_types=[".csv"])
20
+
21
+ start_date = gr.Textbox(label="Start Date (YYYY-MM-DD)", value="2022-01-01")
22
+ end_date = gr.Textbox(label="End Date (YYYY-MM-DD)", value="2023-12-31")
23
+ horizon = gr.Slider(minimum=10, maximum=200, step=1, label="Forecast Horizon (Days)", value=30)
24
+
25
+ gr.Markdown("## βš™οΈ Model Configuration")
26
+
27
+ model = gr.Dropdown(choices=AVAILABLE_MODELS, label="Model", value="LSTM")
28
+ hidden_units = gr.Slider(minimum=8, maximum=512, step=8, label="Hidden Units", value=64)
29
+ n_layers = gr.Slider(minimum=1, maximum=5, step=1, label="Number of Layers", value=2)
30
+ epochs = gr.Slider(minimum=10, maximum=300, step=10, label="Epochs", value=100)
31
+ learning_rate = gr.Slider(minimum=1e-5, maximum=0.01, step=1e-5, label="Learning Rate", value=0.001)
32
+ window_size = gr.Slider(minimum=5, maximum=90, step=1, label="Window Size", value=30)
33
+ test_split = gr.Slider(minimum=0.05, maximum=0.5, step=0.01, label="Test Split (Fraction)", value=0.2)
34
+
35
+ run_btn = gr.Button("πŸš€ Train & Predict")
36
+ status = gr.Textbox(label="Status", lines=3, interactive=False)
37
+
38
+ with gr.Column(scale=2):
39
+ forecast_plot = gr.Plot(label="πŸ“Š Forecast Results")
40
+ error_plot = gr.Plot(label="πŸ“‰ Backtest / Error Analysis")
41
+
42
+ # === Backend Logic ===
43
+
44
+ def run_pipeline(data_src, ticker, file_upload, start_date, end_date, horizon,
45
+ model, hidden_units, n_layers, epochs, learning_rate,
46
+ window_size, test_split):
47
+
48
+ try:
49
+ start_dt = pd.to_datetime(start_date)
50
+ end_dt = pd.to_datetime(end_date)
51
+
52
+ if start_dt >= end_dt:
53
+ return None, None, "❌ End date must be after start date."
54
+
55
+ df = get_data(data_src, ticker, file_upload, start_date, end_date)
56
+ if df is None or df.empty:
57
+ return None, None, "❌ Failed to load or parse dataset."
58
+
59
+ result = run_model(
60
+ df=df,
61
+ model_name=model,
62
+ horizon=horizon,
63
+ hidden_units=hidden_units,
64
+ n_layers=n_layers,
65
+ epochs=epochs,
66
+ learning_rate=learning_rate,
67
+ window_size=window_size,
68
+ test_split=test_split
69
+ )
70
+
71
+ forecast_fig = plot_forecast(result)
72
+ error_fig = plot_metrics(result)
73
+
74
+ return forecast_fig, error_fig, "βœ… Forecast complete!"
75
+ except Exception as e:
76
+ return None, None, f"❌ Error occurred: {str(e)}"
77
+
78
+ def toggle_data_input(src):
79
+ return (
80
+ gr.update(visible=(src == "Yahoo Finance")),
81
+ gr.update(visible=(src == "Upload CSV"))
82
  )
83
+
84
+ # === Event Bindings ===
85
+ run_btn.click(
86
+ fn=run_pipeline,
87
+ inputs=[
88
+ data_src, ticker, file_upload,
89
+ start_date, end_date, horizon,
90
+ model, hidden_units, n_layers, epochs,
91
+ learning_rate, window_size, test_split
92
+ ],
93
+ outputs=[forecast_plot, error_plot, status]
94
+ )
95
+
96
+ data_src.change(
97
+ fn=toggle_data_input,
98
+ inputs=[data_src],
99
+ outputs=[ticker, file_upload]
100
+ )
101
+
102
+ return app
103
+
104
+
105
+ if __name__ == '__main__':
106
+ main_interface().launch()