Spaces:
Sleeping
Sleeping
Commit ·
1867a74
1
Parent(s): a32c351
Fix SARIMA and LSTM deployment issues
Browse files
app.py
CHANGED
|
@@ -9,46 +9,41 @@ from sklearn.metrics import mean_absolute_error, mean_squared_error
|
|
| 9 |
# Load the dataset
|
| 10 |
webtraffic_data = pd.read_csv("webtraffic.csv")
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
webtraffic_data['Datetime'] = pd.date_range(start='2023-01-01', periods=len(webtraffic_data), freq='H')
|
| 17 |
-
|
| 18 |
-
# Split the data into train/test for evaluation
|
| 19 |
train_size = int(len(webtraffic_data) * 0.8)
|
| 20 |
-
test_size = len(webtraffic_data) - train_size
|
| 21 |
train_data = webtraffic_data.iloc[:train_size]
|
| 22 |
test_data = webtraffic_data.iloc[train_size:]
|
| 23 |
|
| 24 |
-
# Load
|
| 25 |
sarima_model = joblib.load("sarima_model.pkl") # SARIMA model
|
| 26 |
lstm_model = tf.keras.models.load_model("lstm_model.keras") # LSTM model
|
| 27 |
|
| 28 |
-
# Initialize
|
| 29 |
-
future_periods = len(test_data)
|
| 30 |
-
|
| 31 |
-
# Generate predictions for SARIMA
|
| 32 |
-
sarima_predictions = sarima_model.forecast(steps=future_periods)
|
| 33 |
-
|
| 34 |
-
# Prepare data for LSTM predictions
|
| 35 |
from sklearn.preprocessing import MinMaxScaler
|
| 36 |
|
| 37 |
scaler_X = MinMaxScaler(feature_range=(0, 1))
|
| 38 |
scaler_y = MinMaxScaler(feature_range=(0, 1))
|
| 39 |
|
| 40 |
-
# Fit
|
| 41 |
X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
| 42 |
y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
| 43 |
|
| 44 |
-
# Scale test data
|
| 45 |
X_test_scaled = scaler_X.transform(test_data['Sessions'].values.reshape(-1, 1))
|
| 46 |
y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
|
| 47 |
|
| 48 |
-
# Reshape data for LSTM
|
| 49 |
X_test_lstm = X_test_scaled.reshape((X_test_scaled.shape[0], 1, X_test_scaled.shape[1]))
|
| 50 |
|
| 51 |
-
#
|
|
|
|
|
|
|
|
|
|
| 52 |
lstm_predictions_scaled = lstm_model.predict(X_test_lstm)
|
| 53 |
lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled).flatten()
|
| 54 |
|
|
@@ -59,19 +54,18 @@ future_predictions = pd.DataFrame({
|
|
| 59 |
"LSTM_Predicted": lstm_predictions
|
| 60 |
})
|
| 61 |
|
| 62 |
-
# Calculate metrics
|
| 63 |
-
|
| 64 |
-
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
|
| 69 |
-
# Function to generate
|
| 70 |
def generate_plot(model):
|
| 71 |
"""Generate plot based on the selected model."""
|
| 72 |
plt.figure(figsize=(15, 6))
|
| 73 |
-
|
| 74 |
-
plt.plot(actual_dates, test_data['Sessions'], label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
|
| 75 |
|
| 76 |
if model == "SARIMA":
|
| 77 |
plt.plot(future_predictions['Datetime'], future_predictions['SARIMA_Predicted'], label='SARIMA Predicted', color='blue', linewidth=2)
|
|
@@ -89,41 +83,41 @@ def generate_plot(model):
|
|
| 89 |
plt.close()
|
| 90 |
return plot_path
|
| 91 |
|
| 92 |
-
# Function to display metrics
|
| 93 |
def display_metrics():
|
| 94 |
-
"""Generate
|
| 95 |
metrics = {
|
| 96 |
"Model": ["SARIMA", "LSTM"],
|
| 97 |
-
"Mean Absolute Error (MAE)": [
|
| 98 |
-
"Root Mean Squared Error (RMSE)": [
|
| 99 |
}
|
| 100 |
return pd.DataFrame(metrics)
|
| 101 |
|
| 102 |
# Gradio interface function
|
| 103 |
def dashboard_interface(model="SARIMA"):
|
| 104 |
"""Generate plot and metrics for the selected model."""
|
| 105 |
-
plot_path = generate_plot(model)
|
| 106 |
-
metrics_df = display_metrics()
|
| 107 |
return plot_path, metrics_df.to_string()
|
| 108 |
|
| 109 |
-
# Build the Gradio
|
| 110 |
with gr.Blocks() as dashboard:
|
| 111 |
-
gr.Markdown("##
|
| 112 |
-
gr.Markdown("
|
| 113 |
|
| 114 |
# Dropdown for model selection
|
| 115 |
model_selection = gr.Dropdown(["SARIMA", "LSTM"], label="Select Model", value="SARIMA")
|
| 116 |
|
| 117 |
# Outputs: Plot and Metrics
|
| 118 |
plot_output = gr.Image(label="Prediction Plot")
|
| 119 |
-
metrics_output = gr.Textbox(label="Metrics", lines=
|
| 120 |
|
| 121 |
# Button to update dashboard
|
| 122 |
gr.Button("Update Dashboard").click(
|
| 123 |
-
fn=dashboard_interface,
|
| 124 |
-
inputs=[model_selection],
|
| 125 |
-
outputs=[plot_output, metrics_output]
|
| 126 |
)
|
| 127 |
|
| 128 |
-
# Launch the
|
| 129 |
dashboard.launch()
|
|
|
|
| 9 |
# Load the dataset
|
| 10 |
webtraffic_data = pd.read_csv("webtraffic.csv")
|
| 11 |
|
| 12 |
+
# Convert 'Hour Index' to datetime
|
| 13 |
+
start_date = pd.Timestamp("2024-01-01 00:00:00")
|
| 14 |
+
webtraffic_data['Datetime'] = start_date + pd.to_timedelta(webtraffic_data['Hour Index'], unit='h')
|
| 15 |
+
webtraffic_data.drop(columns=['Hour Index'], inplace=True)
|
| 16 |
|
| 17 |
+
# Split the data into train/test
|
|
|
|
|
|
|
|
|
|
| 18 |
train_size = int(len(webtraffic_data) * 0.8)
|
|
|
|
| 19 |
train_data = webtraffic_data.iloc[:train_size]
|
| 20 |
test_data = webtraffic_data.iloc[train_size:]
|
| 21 |
|
| 22 |
+
# Load pre-trained models
|
| 23 |
sarima_model = joblib.load("sarima_model.pkl") # SARIMA model
|
| 24 |
lstm_model = tf.keras.models.load_model("lstm_model.keras") # LSTM model
|
| 25 |
|
| 26 |
+
# Initialize scalers and scale the data for LSTM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
from sklearn.preprocessing import MinMaxScaler
|
| 28 |
|
| 29 |
scaler_X = MinMaxScaler(feature_range=(0, 1))
|
| 30 |
scaler_y = MinMaxScaler(feature_range=(0, 1))
|
| 31 |
|
| 32 |
+
# Fit scalers on the training data
|
| 33 |
X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
| 34 |
y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
| 35 |
|
| 36 |
+
# Scale the test data
|
| 37 |
X_test_scaled = scaler_X.transform(test_data['Sessions'].values.reshape(-1, 1))
|
| 38 |
y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
|
| 39 |
|
| 40 |
+
# Reshape test data for LSTM
|
| 41 |
X_test_lstm = X_test_scaled.reshape((X_test_scaled.shape[0], 1, X_test_scaled.shape[1]))
|
| 42 |
|
| 43 |
+
# Generate predictions for SARIMA
|
| 44 |
+
sarima_predictions = sarima_model.predict(start=len(train_data), end=len(webtraffic_data) - 1)
|
| 45 |
+
|
| 46 |
+
# Generate predictions for LSTM
|
| 47 |
lstm_predictions_scaled = lstm_model.predict(X_test_lstm)
|
| 48 |
lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled).flatten()
|
| 49 |
|
|
|
|
| 54 |
"LSTM_Predicted": lstm_predictions
|
| 55 |
})
|
| 56 |
|
| 57 |
+
# Calculate metrics
|
| 58 |
+
mae_sarima = mean_absolute_error(test_data['Sessions'], sarima_predictions)
|
| 59 |
+
rmse_sarima = mean_squared_error(test_data['Sessions'], sarima_predictions, squared=False)
|
| 60 |
|
| 61 |
+
mae_lstm = mean_absolute_error(test_data['Sessions'], lstm_predictions)
|
| 62 |
+
rmse_lstm = mean_squared_error(test_data['Sessions'], lstm_predictions, squared=False)
|
| 63 |
|
| 64 |
+
# Function to generate plots
|
| 65 |
def generate_plot(model):
|
| 66 |
"""Generate plot based on the selected model."""
|
| 67 |
plt.figure(figsize=(15, 6))
|
| 68 |
+
plt.plot(test_data['Datetime'], test_data['Sessions'], label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
|
|
|
|
| 69 |
|
| 70 |
if model == "SARIMA":
|
| 71 |
plt.plot(future_predictions['Datetime'], future_predictions['SARIMA_Predicted'], label='SARIMA Predicted', color='blue', linewidth=2)
|
|
|
|
| 83 |
plt.close()
|
| 84 |
return plot_path
|
| 85 |
|
| 86 |
+
# Function to display metrics
|
| 87 |
def display_metrics():
|
| 88 |
+
"""Generate metrics for both models."""
|
| 89 |
metrics = {
|
| 90 |
"Model": ["SARIMA", "LSTM"],
|
| 91 |
+
"Mean Absolute Error (MAE)": [mae_sarima, mae_lstm],
|
| 92 |
+
"Root Mean Squared Error (RMSE)": [rmse_sarima, rmse_lstm]
|
| 93 |
}
|
| 94 |
return pd.DataFrame(metrics)
|
| 95 |
|
| 96 |
# Gradio interface function
|
| 97 |
def dashboard_interface(model="SARIMA"):
|
| 98 |
"""Generate plot and metrics for the selected model."""
|
| 99 |
+
plot_path = generate_plot(model)
|
| 100 |
+
metrics_df = display_metrics()
|
| 101 |
return plot_path, metrics_df.to_string()
|
| 102 |
|
| 103 |
+
# Build the Gradio dashboard
|
| 104 |
with gr.Blocks() as dashboard:
|
| 105 |
+
gr.Markdown("## Web Traffic Prediction Dashboard")
|
| 106 |
+
gr.Markdown("Select a model to view its predictions and performance metrics.")
|
| 107 |
|
| 108 |
# Dropdown for model selection
|
| 109 |
model_selection = gr.Dropdown(["SARIMA", "LSTM"], label="Select Model", value="SARIMA")
|
| 110 |
|
| 111 |
# Outputs: Plot and Metrics
|
| 112 |
plot_output = gr.Image(label="Prediction Plot")
|
| 113 |
+
metrics_output = gr.Textbox(label="Metrics", lines=10)
|
| 114 |
|
| 115 |
# Button to update dashboard
|
| 116 |
gr.Button("Update Dashboard").click(
|
| 117 |
+
fn=dashboard_interface,
|
| 118 |
+
inputs=[model_selection],
|
| 119 |
+
outputs=[plot_output, metrics_output]
|
| 120 |
)
|
| 121 |
|
| 122 |
+
# Launch the dashboard
|
| 123 |
dashboard.launch()
|