Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import joblib | |
| from sklearn.metrics import mean_absolute_error, mean_squared_error | |
| from math import sqrt | |
# Step 1: Load the Dataset
print("Loading Dataset...")
data_file = "webtraffic.csv"
try:
    webtraffic_data = pd.read_csv(data_file)
except Exception as e:
    print(f"Error loading dataset: {e}")
    # The dashboard is useless without data, so abort with a NON-zero status.
    # The bare `exit()` builtin exits with status 0 (reporting success) and is
    # only injected by the `site` module, so it is missing under `python -S`.
    raise SystemExit(1) from e
print("Dataset loaded successfully!")
# Step 2: Ensure 'Datetime' column exists or create it
# Validate up front so a malformed CSV fails at startup with a clear message,
# instead of an opaque KeyError here or a generic error on every button click.
if "Sessions" not in webtraffic_data.columns:
    print("Error: dataset has no 'Sessions' column.")
    raise SystemExit(1)
if "Datetime" not in webtraffic_data.columns:
    if "Hour Index" not in webtraffic_data.columns:
        print("Error: dataset has neither a 'Datetime' nor an 'Hour Index' column.")
        raise SystemExit(1)
    print("Datetime column missing. Creating from 'Hour Index'.")
    # 'Hour Index' is interpreted as an hour offset from a fixed, arbitrary origin.
    start_date = pd.Timestamp("2024-01-01 00:00:00")
    webtraffic_data["Datetime"] = start_date + pd.to_timedelta(
        webtraffic_data["Hour Index"], unit="h"
    )
else:
    webtraffic_data["Datetime"] = pd.to_datetime(webtraffic_data["Datetime"])
# Sort chronologically and renumber rows so that positional access
# (iloc[-1], iloc[-n:]) always means "most recent observations".
webtraffic_data.sort_values("Datetime", inplace=True, ignore_index=True)
# Step 3: Load SARIMA Model
print("Loading SARIMA Model...")
try:
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only ship/load a model artifact you produced and trust.
    sarima_model = joblib.load("sarima_model.pkl")
except Exception as e:
    print(f"Error loading SARIMA model: {e}")
    # Non-zero status: the dashboard cannot serve predictions without a model.
    raise SystemExit(1) from e
print("SARIMA model loaded successfully!")

# Step 4: Define Functions for Gradio Dashboard
future_periods = 48  # Number of hours to predict (also the size of the "recent" window)
def _sarima_forecast_frame():
    """Forecast the `future_periods` hours that follow the last observation.

    Returns:
        DataFrame with "Datetime" (hourly timestamps strictly after the last
        observed "Datetime") and "SARIMA_Predicted" (the model's forecasts).
    """
    last_timestamp = webtraffic_data["Datetime"].iloc[-1]
    # periods + 1 then drop the first stamp, which equals the last observation.
    # Lowercase "h": the "H" alias is deprecated in recent pandas, and "h" matches
    # the unit used when the Datetime column is built from 'Hour Index'.
    future_dates = pd.date_range(
        start=last_timestamp, periods=future_periods + 1, freq="h"
    )[1:]
    # The model may return an ndarray or an indexed Series; use raw values so the
    # DataFrame below is assembled purely by position.
    predictions = pd.Series(sarima_model.predict(n_periods=future_periods)).to_numpy(
        dtype=float
    )
    return pd.DataFrame({"Datetime": future_dates, "SARIMA_Predicted": predictions})


def _sarima_recent_fit_metrics():
    """Score the model's in-sample predictions over the most recent observed hours.

    Future forecasts cannot be scored (there are no actuals for them yet), so the
    model's fitted values for the last `future_periods` observed hours are compared
    with the observed Sessions instead.

    Returns:
        (mae, rmse, window) or None when aligned in-sample predictions are not
        available (model lacks `predict_in_sample`, it raised, or its length does
        not match the loaded data).
    """
    predict_in_sample = getattr(sarima_model, "predict_in_sample", None)
    if predict_in_sample is None:
        return None
    actual = webtraffic_data["Sessions"].to_numpy(dtype=float)
    try:
        fitted = pd.Series(predict_in_sample()).to_numpy(dtype=float)
    except Exception as e:  # e.g. model was fitted with exogenous regressors
        print(f"Could not compute in-sample SARIMA predictions: {e}")
        return None
    # Positions only line up if the model was presumably fitted on exactly this
    # Sessions series; otherwise refuse to report a misleading number.
    if len(fitted) != len(actual):
        return None
    window = min(future_periods, len(actual))
    if window == 0:
        return None
    actual_tail = actual[-window:]
    fitted_tail = fitted[-window:]
    mae = mean_absolute_error(actual_tail, fitted_tail)
    rmse = sqrt(mean_squared_error(actual_tail, fitted_tail))
    return mae, rmse, window


def generate_sarima_plot():
    """Generate SARIMA predictions and return a detailed plot with metrics.

    Plots the full observed series together with the `future_periods`-hour
    forecast, saves it to "sarima_prediction_plot.png", and formats fit metrics.

    Returns:
        (plot_path, metrics_text) on success, or
        (None, error_message) if anything failed (the error is printed to stdout).
    """
    try:
        future_predictions = _sarima_forecast_frame()
        scores = _sarima_recent_fit_metrics()

        # Use an explicit Figure rather than pyplot's implicit "current figure",
        # and always close it so failures don't leak figures in the server process.
        fig, ax = plt.subplots(figsize=(15, 6))
        try:
            ax.plot(
                webtraffic_data["Datetime"],
                webtraffic_data["Sessions"],
                label="Actual Traffic",
                color="black",
                linestyle="dotted",
                linewidth=2,
            )
            ax.plot(
                future_predictions["Datetime"],
                future_predictions["SARIMA_Predicted"],
                label="SARIMA Predicted",
                color="blue",
                linewidth=2,
            )
            ax.set_title("SARIMA Predictions vs Actual Traffic", fontsize=16)
            ax.set_xlabel("Datetime", fontsize=12)
            ax.set_ylabel("Sessions", fontsize=12)
            ax.legend(loc="upper left")
            ax.grid(True)
            fig.tight_layout()
            plot_path = "sarima_prediction_plot.png"
            fig.savefig(plot_path)
        finally:
            plt.close(fig)

        if scores is None:
            metrics = "\n".join([
                "SARIMA Model Metrics:",
                "- Unavailable: no in-sample predictions aligned with the loaded data.",
            ])
        else:
            mae_sarima, rmse_sarima, window = scores
            metrics = "\n".join([
                f"SARIMA Model Metrics (in-sample, last {window} observed hours):",
                f"- Mean Absolute Error (MAE): {mae_sarima:.2f}",
                f"- Root Mean Squared Error (RMSE): {rmse_sarima:.2f}",
            ])
        return plot_path, metrics
    except Exception as e:
        print(f"Error generating SARIMA plot: {e}")
        return None, "Error in generating output. Please check the data and model."
def generate_zoomed_plot():
    """Generate a zoomed-in SARIMA prediction plot.

    Shows only the last `future_periods` observed hours followed by the
    `future_periods`-hour forecast, saved to "sarima_zoomed_plot.png".

    Returns:
        The saved plot path, or None if anything failed (the error is printed
        to stdout).
    """
    try:
        last_timestamp = webtraffic_data["Datetime"].iloc[-1]
        # periods + 1 then drop the first stamp, which equals the last observation.
        # Lowercase "h": the "H" alias is deprecated in recent pandas.
        future_dates = pd.date_range(
            start=last_timestamp, periods=future_periods + 1, freq="h"
        )[1:]
        # The model may return an ndarray or an indexed Series; plot raw values.
        predictions = pd.Series(
            sarima_model.predict(n_periods=future_periods)
        ).to_numpy(dtype=float)
        recent = webtraffic_data.iloc[-future_periods:]

        # Explicit Figure, always closed, so failures don't leak figures in the
        # long-running server process.
        fig, ax = plt.subplots(figsize=(15, 6))
        try:
            ax.plot(
                recent["Datetime"],
                recent["Sessions"],
                label="Actual Traffic (Zoomed)",
                color="black",
                linestyle="dotted",
                linewidth=2,
            )
            ax.plot(
                future_dates,
                predictions,
                label="SARIMA Predicted (Zoomed)",
                color="green",
                linewidth=2,
            )
            ax.set_title("Zoomed-In SARIMA Predictions vs Actual Traffic", fontsize=16)
            ax.set_xlabel("Datetime", fontsize=12)
            ax.set_ylabel("Sessions", fontsize=12)
            ax.legend(loc="upper left")
            ax.grid(True)
            fig.tight_layout()
            zoomed_plot_path = "sarima_zoomed_plot.png"
            fig.savefig(zoomed_plot_path)
        finally:
            plt.close(fig)
        return zoomed_plot_path
    except Exception as e:
        print(f"Error generating zoomed plot: {e}")
        return None
# Step 5: Gradio Dashboard with Two Tiles and Metrics
with gr.Blocks() as dashboard:
    # Page header.
    gr.Markdown("## Enhanced SARIMA Web Traffic Prediction Dashboard")
    gr.Markdown("This dashboard includes SARIMA predictions, performance metrics, and a zoomed-in view of recent data.")

    # Output tiles, rendered top-to-bottom in creation order:
    # full forecast plot, its metrics text, then the zoomed plot.
    plot_output = gr.Image(label="SARIMA Prediction Plot")
    metrics_output = gr.Textbox(label="Model Metrics", lines=6)
    zoomed_plot_output = gr.Image(label="Zoomed-In Prediction Plot")

    # Action buttons; neither handler takes any user input.
    predict_button = gr.Button("Generate Predictions")
    predict_button.click(
        fn=generate_sarima_plot,
        inputs=[],
        outputs=[plot_output, metrics_output],
    )
    zoom_button = gr.Button("Generate Zoomed-In Plot")
    zoom_button.click(
        fn=generate_zoomed_plot,
        inputs=[],
        outputs=[zoomed_plot_output],
    )
# Launch the Gradio Dashboard only when run as a script (not on import).
if __name__ == "__main__":
    launch_banner = "\nLaunching Enhanced Gradio Dashboard..."
    print(launch_banner)
    dashboard.launch()