|
|
"""
|
|
|
Demand Prediction System - Streamlit Dashboard
|
|
|
|
|
|
Interactive dashboard for visualizing sales trends and making demand predictions.
|
|
|
"""
|
|
|
|
|
|
import streamlit as st
|
|
|
import pandas as pd
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
import seaborn as sns
|
|
|
import joblib
|
|
|
import json
|
|
|
from datetime import datetime, timedelta, date as dt_date
|
|
|
import os
|
|
|
import warnings
|
|
|
# NOTE(review): this silences *all* warnings for the whole process, including
# pandas / scikit-learn deprecation and feature-name warnings that can point at
# real bugs. Consider narrowing it to specific categories or modules.
warnings.filterwarnings('ignore')


# Browser-tab title/icon, wide layout and an initially expanded sidebar.
# Kept ahead of every other Streamlit call: st.set_page_config has historically
# been required to be the first Streamlit command in a script.
st.set_page_config(
    page_title="Demand Prediction Dashboard",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded"
)


# Global CSS injected as raw HTML (hence unsafe_allow_html=True):
#   .main-header      - the page title rendered by main()
#   .metric-card      - not referenced elsewhere in this file
#   .stButton>button  - full-width, blue styling applied to every Streamlit button
st.markdown("""
<style>
.main-header {
font-size: 2.5rem;
font-weight: bold;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.metric-card {
background-color: #f0f2f6;
padding: 1rem;
border-radius: 0.5rem;
margin: 0.5rem 0;
}
.stButton>button {
width: 100%;
background-color: #1f77b4;
color: white;
}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
# Input dataset (the dashboard asks the user to create it with generate_dataset.py).
# All paths are relative to the current working directory.
DATA_PATH = 'data/sales.csv'

# Directory holding the training artifacts (produced by train_model.py, per the
# dashboard's error messages).
MODEL_DIR = 'models'

# Individual artifacts inside MODEL_DIR.
MODEL_PATH = MODEL_DIR + '/best_model.joblib'
TS_MODEL_PATH = MODEL_DIR + '/best_timeseries_model.joblib'
PREPROCESSING_PATH = MODEL_DIR + '/preprocessing.joblib'
ALL_MODELS_METADATA_PATH = MODEL_DIR + '/all_models_metadata.json'
|
|
|
|
|
|
|
|
|
@st.cache_data
def _read_sales_csv(path, mtime):
    """Read the sales CSV at ``path`` and parse its ``date`` column.

    ``mtime`` is not used in the body: it only participates in Streamlit's
    cache key so that a regenerated file is re-read instead of served from a
    stale cache entry. (It must not start with an underscore, because Streamlit
    excludes underscore-prefixed arguments from hashing.)
    """
    df = pd.read_csv(path)
    df['date'] = pd.to_datetime(df['date'])
    return df


def load_data():
    """Load the sales data with caching.

    Returns:
        A DataFrame read from DATA_PATH with ``date`` parsed to datetimes, or
        None when the file does not exist.
    """
    # Check for the file outside the cached function so that a missing file is
    # never cached as None; otherwise the "data not found" error would persist
    # after the dataset has been generated.
    if not os.path.exists(DATA_PATH):
        return None
    return _read_sales_csv(DATA_PATH, os.path.getmtime(DATA_PATH))
|
|
|
|
|
|
|
|
|
def _file_mtime(path):
    """Return ``path``'s modification time, or None when the file is missing."""
    return os.path.getmtime(path) if os.path.exists(path) else None


def _load_joblib(path, errors):
    """Load a joblib artifact if it exists.

    Returns None when the file is missing or cannot be loaded. Load failures are
    appended to ``errors`` as messages instead of being raised, so one broken
    artifact does not take down the whole dashboard.
    """
    if not os.path.exists(path):
        return None
    try:
        return joblib.load(path)
    except Exception as exc:  # best-effort: report and continue without this artifact
        errors.append(f"Could not load {path}: {exc}")
        return None


@st.cache_resource
def _load_models_cached(artifact_mtimes):
    """Load the metadata JSON and the joblib artifacts from MODEL_DIR.

    ``artifact_mtimes`` is not read in the body; it is part of Streamlit's cache
    key so that re-training (which rewrites the files) invalidates the cache.
    (It must not start with an underscore, or Streamlit would not hash it.)

    Returns:
        ``(models, errors)``, where ``errors`` lists human-readable load failures.
    """
    errors = []
    models = {
        'ml_model': None,
        'ts_model': None,
        'preprocessing': None,
        'model_name': None,
        'is_timeseries': False,
        'metadata': None,
    }

    if os.path.exists(ALL_MODELS_METADATA_PATH):
        try:
            with open(ALL_MODELS_METADATA_PATH, 'r', encoding='utf-8') as f:
                models['metadata'] = json.load(f)
        except (OSError, ValueError) as exc:  # json.JSONDecodeError is a ValueError
            errors.append(f"Could not read {ALL_MODELS_METADATA_PATH}: {exc}")
        else:
            models['model_name'] = models['metadata'].get('best_model', 'Unknown')
            models['is_timeseries'] = models['model_name'] in ['ARIMA', 'Prophet']

    models['ml_model'] = _load_joblib(MODEL_PATH, errors)
    models['ts_model'] = _load_joblib(TS_MODEL_PATH, errors)
    models['preprocessing'] = _load_joblib(PREPROCESSING_PATH, errors)
    return models, errors


def load_models():
    """Load trained models, metadata and preprocessing artifacts, with caching.

    Returns:
        A dict with keys:
        - ``'ml_model'``, ``'ts_model'``, ``'preprocessing'``: each None when
          unavailable.
        - ``'model_name'``: the metadata's ``'best_model'``.
        - ``'is_timeseries'``: whether that name is ARIMA/Prophet.
        - ``'metadata'``: the parsed JSON, or None.

    Artifacts that exist but fail to load are reported via ``st.warning``
    instead of being silently ignored.
    """
    # Keying the cache on file mtimes means that a missing-models result is not
    # cached forever: once train_model.py writes the files, they are picked up.
    artifact_mtimes = tuple(
        _file_mtime(path)
        for path in (ALL_MODELS_METADATA_PATH, MODEL_PATH, TS_MODEL_PATH, PREPROCESSING_PATH)
    )
    models, errors = _load_models_cached(artifact_mtimes)
    for message in errors:
        st.warning(f"⚠️ {message}")
    return models
|
|
|
|
|
|
|
|
|
def prepare_features_ml(product_id, date, price, discount, category, preprocessing_data):
    """Build the scaled single-row feature matrix consumed by the ML model.

    Args:
        product_id: Raw product identifier. It is encoded with the fitted
            product encoder; unseen IDs fall back to the encoder's first class.
        date: Target date (``datetime.date``, ``pd.Timestamp``, or anything
            ``pd.to_datetime`` accepts).
        price: Unit price.
        discount: Discount value as entered in the UI.
        category: Raw category label. Unseen labels are encoded as 0.
        preprocessing_data: Dict with ``'encoders'`` (``'category'`` and
            ``'product_id'``), ``'feature_names'`` and ``'scaler'``, or None.

    Returns:
        ``scaler.transform`` applied to a 1 x len(feature_names) array ordered
        by ``feature_names``, or None when ``preprocessing_data`` is None.
    """
    if preprocessing_data is None:
        return None

    # pd.Timestamp is itself a datetime.date subclass, so this one check covers
    # dates, datetimes and Timestamps; anything else goes through to_datetime.
    timestamp = pd.Timestamp(date) if isinstance(date, dt_date) else pd.to_datetime(date)

    encoders = preprocessing_data['encoders']
    category_encoder = encoders['category']
    product_encoder = encoders['product_id']

    try:
        category_code = category_encoder.transform([category])[0]
    except ValueError:
        category_code = 0  # category not seen during training

    try:
        product_code = product_encoder.transform([product_id])[0]
    except ValueError:
        # Product not seen during training: use the first known class instead.
        product_code = product_encoder.transform([product_encoder.classes_[0]])[0]

    weekday = timestamp.weekday()
    values_by_name = {
        'price': price,
        'discount': discount,
        'day': timestamp.day,
        'month': timestamp.month,
        'day_of_week': weekday,
        'weekend': int(weekday >= 5),  # Saturday=5, Sunday=6
        'year': timestamp.year,
        'quarter': timestamp.quarter,
        'category_encoded': category_code,
        'product_id_encoded': product_code,
    }

    # Order the columns exactly as the model was trained.
    row = [values_by_name[name] for name in preprocessing_data['feature_names']]
    return preprocessing_data['scaler'].transform(np.array([row]))
|
|
|
|
|
|
|
|
|
def predict_ml(product_id, date, price, discount, category, model, preprocessing_data):
    """Predict demand for one product on one day with the ML model.

    The features are built by ``prepare_features_ml``. Returns the model's
    first prediction clipped at zero, or None when no feature matrix could be
    built.
    """
    feature_matrix = prepare_features_ml(product_id, date, price, discount, category, preprocessing_data)
    if feature_matrix is not None:
        raw_prediction = model.predict(feature_matrix)[0]
        return max(0, raw_prediction)  # demand cannot be negative
    return None
|
|
|
|
|
|
|
|
|
def _arima_forecast_through(model, date):
    """Return an ARIMA-style forecast path ending at ``date``, or None on failure."""
    try:
        # statsmodels' ``forecast`` accepts a date for ``steps`` and forecasts
        # every step up to and including it; the last value is the target day.
        return model.forecast(steps=date)
    except Exception:
        pass
    try:
        # Models fit without a dated, fixed-frequency index only accept an
        # integer step count. The gap to ``date`` is unknown here, so fall back
        # to one step past the end of the training sample.
        return model.forecast(steps=1)
    except Exception:
        return None


def _clip_forecast_value(forecast):
    """Return the last value of ``forecast`` clipped at zero, or None if unusable."""
    if forecast is None:
        return None
    try:
        # Works for scalars, numpy arrays and (date-indexed) pandas Series alike,
        # without relying on label-vs-position semantics of ``forecast[0]``.
        values = np.ravel(np.asarray(forecast, dtype=float))
    except (TypeError, ValueError):
        return None
    if values.size == 0 or not np.isfinite(values[-1]):
        return None
    return max(0.0, float(values[-1]))


def predict_timeseries(date, model, model_name):
    """Forecast total daily demand for ``date`` with a fitted time-series model.

    Args:
        date: Target date (``datetime.date``, ``pd.Timestamp``, or anything
            ``pd.to_datetime`` accepts).
        model: The fitted model.
            - ``'ARIMA'``: must expose ``forecast(steps=...)``.
            - ``'Prophet'``: must expose ``predict(DataFrame)`` returning a
              ``'yhat'`` column.
        model_name: ``'ARIMA'`` or ``'Prophet'``; any other value yields None.

    Returns:
        The forecast for ``date`` clipped at zero, as a float. Returns None when
        the model name is unsupported, the model call fails, or the forecast is
        not finite.
    """
    # pd.Timestamp subclasses datetime.date, so Timestamps take the first branch.
    date = pd.Timestamp(date) if isinstance(date, dt_date) else pd.to_datetime(date)

    if model_name == 'ARIMA':
        forecast = _arima_forecast_through(model, date)
    elif model_name == 'Prophet':
        try:
            forecast = model.predict(pd.DataFrame({'ds': [date]}))['yhat']
        except Exception:
            return None
    else:
        return None

    return _clip_forecast_value(forecast)
|
|
|
|
|
|
|
|
|
def main():
    """Dashboard entry point: header, sidebar navigation/info, then the chosen page.

    Shows an error and stops early when the sales data cannot be loaded.
    """
    st.markdown('<h1 class="main-header">📊 Demand Prediction Dashboard</h1>', unsafe_allow_html=True)

    df = load_data()
    if df is None:
        st.error("❌ Sales data not found. Please run generate_dataset.py first.")
        return

    models = load_models()

    # Page label -> renderer; the labels double as the sidebar radio options.
    page_renderers = {
        "📈 Sales Trends": lambda: show_sales_trends(df),
        "🔮 Demand Prediction": lambda: show_prediction_interface(df, models),
        "📊 Model Comparison": lambda: show_model_comparison(models),
    }

    with st.sidebar:
        st.header("⚙️ Navigation")
        page = st.radio("Select Page", list(page_renderers), index=0)

        st.markdown("---")
        st.header("ℹ️ Information")
        metadata = models['metadata']
        if metadata:
            best_model = metadata.get('best_model', 'Unknown')
            st.info(f"**Best Model:** {best_model}")
            scored_models = metadata.get('all_models', {})
            if best_model in scored_models:
                best_r2 = scored_models[best_model].get('r2', 0)
                st.metric("R2 Score", f"{best_r2:.4f}")

    renderer = page_renderers.get(page)
    if renderer is not None:
        renderer()
|
|
|
|
|
|
|
|
|
def _filter_sales(df, category, product, date_range):
    """Return a copy of the rows of ``df`` matching the trend-page filters.

    ``category`` / ``product`` equal to ``'All'`` disable that filter.
    ``date_range`` is applied (inclusive at both ends) only when it is a
    2-tuple; ``st.date_input`` yields a shorter tuple while a range is only
    partly picked.
    """
    mask = pd.Series(True, index=df.index)
    if category != 'All':
        mask &= df['category'] == category
    if product != 'All':
        # Compare against the value taken from the data itself rather than
        # forcing int(): product IDs are not guaranteed to be integers.
        mask &= df['product_id'] == product
    if isinstance(date_range, tuple) and len(date_range) == 2:
        start = pd.to_datetime(date_range[0])
        end = pd.to_datetime(date_range[1])
        mask &= (df['date'] >= start) & (df['date'] <= end)
    # Copy so that later column work never touches a view of ``df``.
    return df[mask].copy()


def _render_trend_figure(fig):
    """Lay out, display and close a Matplotlib figure.

    Streamlit re-runs the script on every interaction, and pyplot keeps each
    figure alive until it is closed, so skipping ``plt.close`` leaks memory.
    """
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)


def _render_daily_tab(filtered_df):
    """Line chart of total sales per day plus summary metrics."""
    st.subheader("Daily Sales Trends")
    daily_sales = filtered_df.groupby('date')['sales_quantity'].sum().reset_index()

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.plot(daily_sales['date'], daily_sales['sales_quantity'], linewidth=2, alpha=0.7)
    ax.set_title('Total Daily Sales Quantity', fontsize=16, fontweight='bold')
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Sales Quantity', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='x', labelrotation=45)
    _render_trend_figure(fig)

    quantities = daily_sales['sales_quantity']
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Sales", f"{quantities.sum():,.0f}")
    with col2:
        st.metric("Average Daily", f"{quantities.mean():.1f}")
    with col3:
        st.metric("Max Daily", f"{quantities.max():,.0f}")
    with col4:
        st.metric("Min Daily", f"{quantities.min():,.0f}")


def _render_monthly_tab(filtered_df):
    """Bar chart of total sales per calendar month."""
    st.subheader("Monthly Sales Trends")
    # Group by the month period directly instead of adding a helper column.
    months = filtered_df['date'].dt.to_period('M')
    monthly_sales = filtered_df.groupby(months)['sales_quantity'].sum()
    labels = monthly_sales.index.astype(str)
    positions = list(range(len(monthly_sales)))

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(positions, monthly_sales.to_numpy(), alpha=0.7, color='steelblue')
    ax.set_title('Monthly Sales Quantity', fontsize=16, fontweight='bold')
    ax.set_xlabel('Month', fontsize=12)
    ax.set_ylabel('Sales Quantity', fontsize=12)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')
    _render_trend_figure(fig)


def _render_category_tab(filtered_df):
    """Horizontal bar chart and summary table of sales per category."""
    st.subheader("Sales by Category")
    category_sales = filtered_df.groupby('category')['sales_quantity'].sum().sort_values(ascending=False)

    fig, ax = plt.subplots(figsize=(12, 6))
    category_sales.plot(kind='barh', ax=ax, color='coral', alpha=0.7)
    ax.set_title('Total Sales by Category', fontsize=16, fontweight='bold')
    ax.set_xlabel('Total Sales Quantity', fontsize=12)
    ax.set_ylabel('Category', fontsize=12)
    ax.grid(True, alpha=0.3, axis='x')
    _render_trend_figure(fig)

    category_stats = filtered_df.groupby('category').agg({
        'sales_quantity': ['sum', 'mean', 'std', 'min', 'max']
    }).round(2)
    category_stats.columns = ['Total', 'Average', 'Std Dev', 'Min', 'Max']
    st.dataframe(category_stats, use_container_width=True)


def _render_price_tab(filtered_df):
    """Scatter of price vs. quantity (coloured by discount) and their correlation."""
    st.subheader("Price vs Demand Relationship")

    fig, ax = plt.subplots(figsize=(12, 6))
    scatter = ax.scatter(filtered_df['price'], filtered_df['sales_quantity'],
                         c=filtered_df['discount'], cmap='viridis', alpha=0.6, s=50)
    ax.set_title('Price vs Sales Quantity (colored by discount)', fontsize=16, fontweight='bold')
    ax.set_xlabel('Price', fontsize=12)
    ax.set_ylabel('Sales Quantity', fontsize=12)
    ax.grid(True, alpha=0.3)
    fig.colorbar(scatter, ax=ax, label='Discount %')
    _render_trend_figure(fig)

    # Correlation is undefined (NaN) for a single row or a constant column.
    correlation = filtered_df['price'].corr(filtered_df['sales_quantity'])
    st.metric("Price-Demand Correlation", "N/A" if pd.isna(correlation) else f"{correlation:.3f}")


def show_sales_trends(df):
    """Display the sales-trend page: filters, then one tab per visualization.

    Args:
        df: Sales data with ``date`` (datetime), ``category``, ``product_id``,
            ``price``, ``discount`` and ``sales_quantity`` columns (the columns
            read on this page).
    """
    st.header("📈 Sales Trends Analysis")

    col1, col2, col3 = st.columns(3)
    with col1:
        categories = ['All'] + sorted(df['category'].unique().tolist())
        selected_category = st.selectbox("Select Category", categories)
    with col2:
        products = ['All'] + sorted(df['product_id'].unique().tolist())
        selected_product = st.selectbox("Select Product", products)
    with col3:
        date_range = st.date_input(
            "Select Date Range",
            value=(df['date'].min(), df['date'].max()),
            min_value=df['date'].min(),
            max_value=df['date'].max()
        )

    filtered_df = _filter_sales(df, selected_category, selected_product, date_range)
    if filtered_df.empty:
        st.warning("No data available for selected filters.")
        return

    tab1, tab2, tab3, tab4 = st.tabs(["📅 Daily Trends", "📆 Monthly Trends", "📦 Category Analysis", "💰 Price vs Demand"])
    with tab1:
        _render_daily_tab(filtered_df)
    with tab2:
        _render_monthly_tab(filtered_df)
    with tab3:
        _render_category_tab(filtered_df)
    with tab4:
        _render_price_tab(filtered_df)
|
|
|
|
|
|
|
|
|
def _prediction_date_input(df):
    """Render the forecast-date picker and return the chosen ``datetime.date``.

    Dates up to and including the last observed sale are not allowed. The
    default is 30 days from today, moved forward to the first allowed date when
    the data already extends beyond it. A default earlier than ``min_value``
    makes ``st.date_input`` raise.
    """
    first_allowed = df['date'].max().date() + timedelta(days=1)
    default_date = max(datetime.now().date() + timedelta(days=30), first_allowed)
    return st.date_input(
        "Select Date for Prediction",
        value=default_date,
        min_value=first_allowed
    )


def _timeseries_model_name(models):
    """Return the model family ('ARIMA' or 'Prophet') to use with ``models['ts_model']``.

    ``models['model_name']`` is the overall best model. It can be an ML model
    even though a time-series artifact exists (e.g. when the user explicitly
    picks "Time-Series"), and predict_timeseries rejects such names. In that
    case, fall back to duck typing: Prophet models provide
    ``make_future_dataframe``; anything else is treated as ARIMA-like.
    """
    name = models['model_name']
    if name in ['ARIMA', 'Prophet']:
        return name
    return 'Prophet' if hasattr(models['ts_model'], 'make_future_dataframe') else 'ARIMA'


def _show_timeseries_prediction(df, models):
    """Overall (all-product) daily demand forecast using the time-series model."""
    st.subheader("Overall Daily Demand Prediction")

    col1, col2 = st.columns(2)
    with col1:
        prediction_date = _prediction_date_input(df)
    with col2:
        # Two empty lines push the button down to line up with the date picker.
        st.write("")
        st.write("")
        predict_clicked = st.button("🔮 Predict Demand", type="primary")

    if not predict_clicked:
        return
    if models['ts_model'] is None:
        st.error("Time-series model not available.")
        return

    with st.spinner("Making prediction..."):
        prediction = predict_timeseries(
            prediction_date,
            models['ts_model'],
            _timeseries_model_name(models)
        )

    if prediction is None:
        st.error("Failed to make prediction.")
        return

    st.success("✅ Prediction Complete!")
    target = pd.to_datetime(prediction_date)
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Predicted Daily Demand", f"{prediction:,.0f} units")
    with col2:
        st.metric("Day of Week", target.strftime('%A'))
    with col3:
        st.metric("Weekend", "Yes" if target.weekday() >= 5 else "No")
    st.info("💡 This prediction represents the total daily demand across all products.")


def _show_ml_prediction(df, models):
    """Product-specific demand forecast using the ML model."""
    st.subheader("Product-Specific Demand Prediction")

    categories = sorted(df['category'].unique().tolist())

    col1, col2 = st.columns(2)
    with col1:
        selected_category = st.selectbox("Select Category", categories)
        # Offer only products that occur in the chosen category, so the
        # category and product fed to the model are consistent with each other.
        category_products = df.loc[df['category'] == selected_category, 'product_id']
        products = sorted(category_products.unique().tolist())
        selected_product = st.selectbox("Select Product ID", products)
        prediction_date = _prediction_date_input(df)
    with col2:
        price = st.number_input(
            "Product Price ($)",
            min_value=0.01,
            value=100.0,
            step=1.0,
            format="%.2f"
        )
        discount = st.slider(
            "Discount (%)",
            min_value=0,
            max_value=100,
            value=0,
            step=5
        )

    product_data = df[df['product_id'] == selected_product]
    if not product_data.empty:
        with st.expander("📊 Product Statistics"):
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Avg Price", f"${product_data['price'].mean():.2f}")
            with col2:
                st.metric("Avg Sales", f"{product_data['sales_quantity'].mean():.1f}")
            with col3:
                st.metric("Total Sales", f"{product_data['sales_quantity'].sum():,.0f}")
            with col4:
                st.metric("Category", selected_category)

    if not st.button("🔮 Predict Demand", type="primary"):
        return
    if models['ml_model'] is None or models['preprocessing'] is None:
        st.error("ML model or preprocessing not available.")
        return

    with st.spinner("Making prediction..."):
        prediction = predict_ml(
            selected_product,
            prediction_date,
            price,
            discount,
            selected_category,
            models['ml_model'],
            models['preprocessing']
        )

    if prediction is None:
        st.error("Failed to make prediction.")
        return

    st.success("✅ Prediction Complete!")
    target = pd.to_datetime(prediction_date)
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Predicted Demand", f"{prediction:,.0f} units")
    with col2:
        st.metric("Price", f"${price:.2f}")
    with col3:
        st.metric("Discount", f"{discount}%")
    with col4:
        st.metric("Day", target.strftime('%A'))

    st.markdown("### 📈 Prediction Insights")
    insights = []
    if target.weekday() >= 5:
        insights.append("📅 Weekend - typically higher demand")
    if target.month in (11, 12):
        insights.append("🎄 Holiday season - peak sales period")
    if discount > 0:
        insights.append(f"💰 {discount}% discount - may increase demand")
    for insight in insights:
        st.info(insight)


def show_prediction_interface(df, models):
    """Display the interactive prediction page.

    The time-series view (overall daily demand) is used when the user picks
    "Time-Series", or picks "Auto" and the best model is a time-series model.
    Otherwise the product-specific ML view is shown.

    Args:
        df: Sales data; ``date`` must be datetime-typed.
        models: Dict as returned by ``load_models``.
    """
    st.header("🔮 Demand Prediction")

    if models['ml_model'] is None and models['ts_model'] is None:
        st.error("❌ No trained models found. Please run train_model.py first.")
        return

    model_type = st.radio(
        "Select Model Type",
        ["Auto (Best Model)", "Machine Learning", "Time-Series"],
        horizontal=True
    )
    st.markdown("---")

    use_timeseries = model_type == "Time-Series" or (
        model_type == "Auto (Best Model)" and models['is_timeseries']
    )
    if use_timeseries:
        _show_timeseries_prediction(df, models)
    else:
        _show_ml_prediction(df, models)
|
|
|
|
|
|
|
|
|
def _plot_metric_bars(model_names, values, colors, title, ylabel):
    """Draw one bar per model for a metric, display it, and close the figure.

    Closing matters because pyplot keeps figures alive until closed and
    Streamlit re-runs the script on every interaction.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(model_names, values, color=colors, alpha=0.7)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12)
    ax.tick_params(axis='x', rotation=45)
    ax.grid(True, alpha=0.3, axis='y')
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)


def show_model_comparison(models):
    """Display the model comparison page from the training metadata.

    Reads ``models['metadata']['all_models']`` (model name -> dict with optional
    ``'mae'``, ``'rmse'``, ``'r2'``, shown as 0 when missing) and
    ``models['metadata']['best_model']``. The best model is highlighted.
    """
    st.header("📊 Model Comparison")

    if models['metadata'] is None:
        st.warning("Model metadata not available. Please run train_model.py first.")
        return

    metadata = models['metadata']
    all_models = metadata.get('all_models', {})
    best_model = metadata.get('best_model', 'Unknown')

    if not all_models:
        st.warning("No model comparison data available.")
        return

    st.subheader("Model Performance Metrics")

    timeseries_names = ['ARIMA', 'Prophet']
    comparison_df = pd.DataFrame([
        {
            'Model': model_name,
            'Type': 'Time-Series' if model_name in timeseries_names else 'Machine Learning',
            'MAE': metrics.get('mae', 0),
            'RMSE': metrics.get('rmse', 0),
            'R2 Score': metrics.get('r2', 0),
        }
        for model_name, metrics in all_models.items()
    ])

    def highlight_best(row):
        # Row-wise styler: green background across the whole best-model row.
        if row['Model'] == best_model:
            return ['background-color: #90EE90'] * len(row)
        return [''] * len(row)

    st.dataframe(
        comparison_df.style.apply(highlight_best, axis=1),
        use_container_width=True
    )

    st.subheader("Performance Comparison Charts")

    model_names = comparison_df['Model'].tolist()
    # Time-series models in coral, ML models in sky blue.
    colors = ['coral' if name in timeseries_names else 'skyblue' for name in model_names]

    col1, col2 = st.columns(2)
    with col1:
        _plot_metric_bars(model_names, comparison_df['MAE'].tolist(), colors,
                          'MAE Comparison (Lower is Better)', 'MAE')
    with col2:
        _plot_metric_bars(model_names, comparison_df['R2 Score'].tolist(), colors,
                          'R2 Score Comparison (Higher is Better)', 'R2 Score')

    st.markdown("---")
    st.success(f"🏆 **Best Model: {best_model}**")
    if best_model in all_models:
        best_metrics = all_models[best_model]
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("MAE", f"{best_metrics.get('mae', 0):.2f}")
        with col2:
            st.metric("RMSE", f"{best_metrics.get('rmse', 0):.2f}")
        with col3:
            st.metric("R2 Score", f"{best_metrics.get('r2', 0):.4f}")
|
|
|
|
|
|
|
|
|
# Script entry point: render the dashboard when this file runs as the main
# module (presumably launched via `streamlit run`, which executes the script
# as `__main__`).
if __name__ == "__main__":
    main()
|
|
|
|