Prajwal3009 commited on
Commit
e5cf81c
·
verified ·
1 Parent(s): 55a0897

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -22
app.py CHANGED
@@ -55,13 +55,32 @@ def calculate_metrics(actual, predicted):
55
 
56
  # Function to plot ACF and PACF
57
  def plot_acf_pacf(series):
58
- fig, ax = plt.subplots(1, 2, figsize=(12, 6))
59
- plot_acf(series, lags=20, ax=ax[0])
60
- plot_pacf(series, lags=20, ax=ax[1])
61
- ax[0].set_title('Auto-Correlation Function (ACF)')
62
- ax[1].set_title('Partial Auto-Correlation Function (PACF)')
63
- st.pyplot(fig)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # Forecasting button
66
  if st.button("Forecast"):
67
  # Prepare the historical data for comparison
@@ -149,20 +168,7 @@ if st.button("Forecast"):
149
  ax.set_ylabel('Sales Quantity')
150
  ax.legend()
151
  st.pyplot(fig)
152
- import matplotlib.pyplot as plt
153
- import numpy as np
154
- import pandas as pd
155
- import pickle
156
- import os
157
- import streamlit as st
158
- from sklearn.metrics import mean_absolute_error, mean_squared_error
159
- from statsmodels.tsa.arima.model import ARIMA
160
 
161
- # Function to calculate evaluation metrics
162
- def calculate_metrics(actual, predicted):
163
- mae = mean_absolute_error(actual, predicted)
164
- rmse = np.sqrt(mean_squared_error(actual, predicted))
165
- return mae, rmse
166
 
167
  if st.button("Compare Models"):
168
  # Load and forecast with ARIMA
@@ -204,7 +210,7 @@ if st.button("Compare Models"):
204
  axs[0, 0].legend()
205
 
206
  # Calculate and display ARIMA metrics
207
- arima_mae, arima_rmse = calculate_metrics(product_sales_series[-15:], st.session_state.arima_forecast)
208
  st.write(f'**ARIMA Model Metrics:**\n- MAE: {arima_mae:.2f}\n- RMSE: {arima_rmse:.2f}')
209
 
210
  # Plot Decision Tree forecast
@@ -216,7 +222,7 @@ if st.button("Compare Models"):
216
  axs[0, 1].legend()
217
 
218
  # Calculate and display Decision Tree metrics
219
- dt_mae, dt_rmse = calculate_metrics(product_sales_series[-15:], st.session_state.dt_forecast)
220
  st.write(f'**Decision Tree Model Metrics:**\n- MAE: {dt_mae:.2f}\n- RMSE: {dt_rmse:.2f}')
221
 
222
  # Plot XGBoost forecast
@@ -228,7 +234,7 @@ if st.button("Compare Models"):
228
  axs[1, 0].legend()
229
 
230
  # Calculate and display XGBoost metrics
231
- xgb_mae, xgb_rmse = calculate_metrics(product_sales_series[-15:], st.session_state.xgb_forecast)
232
  st.write(f'**XGBoost Model Metrics:**\n- MAE: {xgb_mae:.2f}\n- RMSE: {xgb_rmse:.2f}')
233
 
234
  # Hide the last subplot (bottom right) if not needed
 
55
 
56
  # Function to plot ACF and PACF
57
  def plot_acf_pacf(series):
58
+ try:
59
+ # Ensure the series has enough data points for the desired number of lags
60
+ if len(series) < 20:
61
+ st.write("The series is too short to generate ACF and PACF plots with 20 lags.")
62
+ return # Skip plotting
 
63
 
64
+ fig, ax = plt.subplots(1, 2, figsize=(12, 6))
65
+
66
+ # Plot ACF and PACF with error handling for shape mismatch
67
+ plot_acf(series, lags=20, ax=ax[0])
68
+ plot_pacf(series, lags=20, ax=ax[1])
69
+
70
+ # Set titles for the plots
71
+ ax[0].set_title('Auto-Correlation Function (ACF)')
72
+ ax[1].set_title('Partial Auto-Correlation Function (PACF)')
73
+
74
+ # Display the plots using Streamlit
75
+ st.pyplot(fig)
76
+
77
+ except ValueError as e:
78
+ st.error(f"An error occurred while plotting ACF and PACF: {e}")
79
+
80
+ def calculate_metrics2(actual, predicted):
81
+ mae = mean_absolute_error(actual, predicted)
82
+ rmse = np.sqrt(mean_squared_error(actual, predicted))
83
+ return mae, rmse
84
  # Forecasting button
85
  if st.button("Forecast"):
86
  # Prepare the historical data for comparison
 
168
  ax.set_ylabel('Sales Quantity')
169
  ax.legend()
170
  st.pyplot(fig)
 
 
 
 
 
 
 
 
171
 
 
 
 
 
 
172
 
173
  if st.button("Compare Models"):
174
  # Load and forecast with ARIMA
 
210
  axs[0, 0].legend()
211
 
212
  # Calculate and display ARIMA metrics
213
+ arima_mae, arima_rmse = calculate_metrics2(product_sales_series[-15:], st.session_state.arima_forecast)
214
  st.write(f'**ARIMA Model Metrics:**\n- MAE: {arima_mae:.2f}\n- RMSE: {arima_rmse:.2f}')
215
 
216
  # Plot Decision Tree forecast
 
222
  axs[0, 1].legend()
223
 
224
  # Calculate and display Decision Tree metrics
225
+ dt_mae, dt_rmse = calculate_metrics2(product_sales_series[-15:], st.session_state.dt_forecast)
226
  st.write(f'**Decision Tree Model Metrics:**\n- MAE: {dt_mae:.2f}\n- RMSE: {dt_rmse:.2f}')
227
 
228
  # Plot XGBoost forecast
 
234
  axs[1, 0].legend()
235
 
236
  # Calculate and display XGBoost metrics
237
+ xgb_mae, xgb_rmse = calculate_metrics2(product_sales_series[-15:], st.session_state.xgb_forecast)
238
  st.write(f'**XGBoost Model Metrics:**\n- MAE: {xgb_mae:.2f}\n- RMSE: {xgb_rmse:.2f}')
239
 
240
  # Hide the last subplot (bottom right) if not needed