vikaswebdev's picture
Upload 17 files
7f90ea0 verified
"""
Demand Prediction System - Streamlit Dashboard
Interactive dashboard for visualizing sales trends and making demand predictions.
"""
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import json
from datetime import datetime, timedelta, date as dt_date
import os
import warnings
warnings.filterwarnings('ignore')

# Page configuration: browser-tab title/icon, wide layout, and a sidebar that
# starts expanded. Applied before any other Streamlit output is produced.
_PAGE_CONFIG = dict(
    page_title="Demand Prediction Dashboard",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS for better styling: header typography, a card look for metrics,
# and full-width blue buttons. It is raw HTML, hence unsafe_allow_html=True.
_CUSTOM_CSS = """
<style>
.main-header {
font-size: 2.5rem;
font-weight: bold;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.metric-card {
background-color: #f0f2f6;
padding: 1rem;
border-radius: 0.5rem;
margin: 0.5rem 0;
}
.stButton>button {
width: 100%;
background-color: #1f77b4;
color: white;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Configuration. All paths are relative, so they resolve against the working
# directory the dashboard is launched from.
DATA_PATH = 'data/sales.csv'
MODEL_DIR = 'models'

# Trained-model artifacts, all stored under MODEL_DIR.
MODEL_PATH = MODEL_DIR + '/best_model.joblib'
TS_MODEL_PATH = MODEL_DIR + '/best_timeseries_model.joblib'
PREPROCESSING_PATH = MODEL_DIR + '/preprocessing.joblib'
ALL_MODELS_METADATA_PATH = MODEL_DIR + '/all_models_metadata.json'
@st.cache_data
def load_data():
    """Read the sales CSV at DATA_PATH and parse its ``date`` column.

    Returns:
        A DataFrame whose ``date`` column is converted with ``pd.to_datetime``,
        or ``None`` when DATA_PATH does not exist. The result is memoized by
        ``st.cache_data``.
    """
    if not os.path.exists(DATA_PATH):
        return None
    sales = pd.read_csv(DATA_PATH)
    sales['date'] = pd.to_datetime(sales['date'])
    return sales
@st.cache_resource
def load_models():
    """Load the trained model artifacts, memoized by ``st.cache_resource``.

    Every artifact is optional. A missing file leaves its entry as ``None``. A
    file that exists but cannot be read is logged, recorded in
    ``load_errors``, and also left as ``None``, so the dashboard keeps running.

    Returns:
        dict with keys:
          - 'ml_model': object unpickled from MODEL_PATH, or None.
          - 'ts_model': object unpickled from TS_MODEL_PATH, or None.
          - 'preprocessing': object unpickled from PREPROCESSING_PATH, or None.
          - 'model_name': metadata's 'best_model' value ('Unknown' if that key
            is absent), or None when no metadata was loaded.
          - 'is_timeseries': True when model_name is 'ARIMA' or 'Prophet'.
          - 'metadata': the parsed metadata JSON object, or None.
          - 'load_errors': {artifact key: "ExcType: message"} for artifacts
            whose file existed but failed to load.
    """
    import logging

    logger = logging.getLogger(__name__)
    models = {
        'ml_model': None,
        'ts_model': None,
        'preprocessing': None,
        'model_name': None,
        'is_timeseries': False,
        'metadata': None,
        'load_errors': {},
    }

    def _load(key, path, loader):
        """Return ``loader(path)``, or None if the file is missing or unreadable."""
        if not os.path.exists(path):
            return None
        try:
            return loader(path)
        except Exception as exc:
            # Best effort: a corrupt or incompatible artifact must not take the
            # whole dashboard down, but the reason must not vanish silently.
            logger.warning("Could not load %s from %s", key, path, exc_info=True)
            models['load_errors'][key] = f"{type(exc).__name__}: {exc}"
            return None

    def _read_json(path):
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    # Load metadata
    metadata = _load('metadata', ALL_MODELS_METADATA_PATH, _read_json)
    if metadata is not None and not isinstance(metadata, dict):
        # The rest of the dashboard calls .get() on it, so only a JSON object is usable.
        models['load_errors']['metadata'] = "metadata JSON is not an object"
        metadata = None
    if metadata is not None:
        models['metadata'] = metadata
        models['model_name'] = metadata.get('best_model', 'Unknown')
        models['is_timeseries'] = models['model_name'] in ['ARIMA', 'Prophet']

    # joblib.load unpickles, so these paths must only ever hold trusted,
    # locally produced artifacts.
    models['ml_model'] = _load('ml_model', MODEL_PATH, joblib.load)
    models['ts_model'] = _load('ts_model', TS_MODEL_PATH, joblib.load)
    models['preprocessing'] = _load('preprocessing', PREPROCESSING_PATH, joblib.load)
    return models
def prepare_features_ml(product_id, date, price, discount, category, preprocessing_data):
    """Build the scaled, single-row feature matrix for the ML model.

    Args:
        product_id: product identifier, encoded with
            ``preprocessing_data['encoders']['product_id']``. If the encoder
            raises ValueError, the code of the encoder's first class is used.
        date: prediction date; a ``datetime.date``/``datetime``/``pd.Timestamp``
            or anything ``pd.to_datetime`` accepts.
        price: passed through unchanged as the 'price' feature.
        discount: passed through unchanged as the 'discount' feature.
        category: category label, encoded with
            ``preprocessing_data['encoders']['category']``. If the encoder
            raises ValueError, the code 0 is used.
        preprocessing_data: mapping with 'encoders', 'feature_names' (column
            order used when building the row) and 'scaler'.

    Returns:
        ``scaler.transform`` applied to a one-row array ordered by
        ``feature_names``, or None when ``preprocessing_data`` is None.
    """
    if preprocessing_data is None:
        return None

    # pd.Timestamp is itself a datetime.date subclass, so it takes the first path.
    when = pd.Timestamp(date) if isinstance(date, dt_date) else pd.to_datetime(date)
    weekday_idx = when.weekday()

    encoders = preprocessing_data['encoders']
    category_enc = encoders['category']
    product_enc = encoders['product_id']

    try:
        category_code = category_enc.transform([category])[0]
    except ValueError:
        category_code = 0

    try:
        product_code = product_enc.transform([product_id])[0]
    except ValueError:
        # Unseen product: reuse the code of the encoder's first known class.
        product_code = product_enc.transform([product_enc.classes_[0]])[0]

    values = {
        'price': price,
        'discount': discount,
        'day': when.day,
        'month': when.month,
        'day_of_week': weekday_idx,
        'weekend': int(weekday_idx >= 5),
        'year': when.year,
        'quarter': when.quarter,
        'category_encoded': category_code,
        'product_id_encoded': product_code,
    }

    # Column order must match the order used at training time.
    row = [values[name] for name in preprocessing_data['feature_names']]
    return preprocessing_data['scaler'].transform(np.array([row]))
def predict_ml(product_id, date, price, discount, category, model, preprocessing_data):
    """Predict demand for one product on one date with the ML model.

    The feature row comes from ``prepare_features_ml``. The first value of
    ``model.predict`` is returned, clipped below at 0. Returns None when no
    features could be built (``preprocessing_data`` is None).
    """
    features = prepare_features_ml(product_id, date, price, discount, category, preprocessing_data)
    return None if features is None else max(0, model.predict(features)[0])
def _arima_steps_to(target, model, last_date=None):
    """Return the forecast horizon (>= 1) from the end of training to ``target``.

    When ``last_date`` is not supplied, it is read from the last label of
    ``model.fittedvalues`` if that is a pandas object with a DatetimeIndex.
    If the training end cannot be determined, 1 is returned (one-step-ahead).
    NOTE(review): counts calendar days, i.e. assumes the series was fit on one
    observation per day, matching the dashboard's "daily demand" forecast.
    """
    if last_date is None:
        try:
            fitted_index = getattr(getattr(model, 'fittedvalues', None), 'index', None)
        except Exception:
            fitted_index = None
        if isinstance(fitted_index, pd.DatetimeIndex) and len(fitted_index) > 0:
            last_date = fitted_index[-1]
    if last_date is None:
        return 1

    end = pd.Timestamp(last_date)
    if end.tzinfo is not None:
        end = end.tz_localize(None)
    if target.tzinfo is not None:
        target = target.tz_localize(None)
    # Dates on/before the training end fall back to the one-step forecast.
    return max(1, (target.normalize() - end.normalize()).days)


def predict_timeseries(date, model, model_name, last_date=None):
    """Predict total demand for ``date`` with a fitted time-series model.

    Args:
        date: target date; a ``datetime.date``/``datetime``/``pd.Timestamp`` or
            anything ``pd.to_datetime`` accepts.
        model: fitted model. For 'ARIMA' it must provide ``forecast(steps=...)``.
            For 'Prophet' it must provide ``predict(df)`` returning a frame with
            a 'yhat' column.
        model_name: 'ARIMA' or 'Prophet'; any other value yields None.
        last_date: optional date of the last training observation (ARIMA only),
            used to forecast far enough ahead to reach ``date``. If omitted, it
            is inferred from ``model.fittedvalues`` when possible.

    Returns:
        The prediction clipped below at 0, or None if ``model_name`` is
        unsupported or the model call fails.
    """
    # pd.Timestamp is a datetime.date subclass, so it takes the first path.
    ts = pd.Timestamp(date) if isinstance(date, dt_date) else pd.to_datetime(date)

    if model_name == 'ARIMA':
        try:
            steps = _arima_steps_to(ts, model, last_date)
            forecast = model.forecast(steps=steps)
            # Index positionally: forecast may be a date-indexed Series, where
            # forecast[0] is a label lookup, or a plain/0-d array.
            values = np.atleast_1d(np.asarray(forecast, dtype=float))
            return max(0.0, float(values[-1]))
        except Exception:
            return None

    if model_name == 'Prophet':
        try:
            forecast = model.predict(pd.DataFrame({'ds': [ts]}))
            return max(0, forecast['yhat'].iloc[0])
        except Exception:
            return None

    return None
def main():
    """Render the dashboard: header, data/model loading, sidebar, selected page.

    Shows an error and stops early when the sales CSV cannot be found.
    """
    st.markdown('<h1 class="main-header">📊 Demand Prediction Dashboard</h1>', unsafe_allow_html=True)

    df = load_data()
    if df is None:
        st.error("❌ Sales data not found. Please run generate_dataset.py first.")
        return

    models = load_models()

    page_options = ["📈 Sales Trends", "🔮 Demand Prediction", "📊 Model Comparison"]

    with st.sidebar:
        st.header("⚙️ Navigation")
        page = st.radio("Select Page", page_options, index=0)

        st.markdown("---")
        st.header("ℹ️ Information")
        metadata = models['metadata']
        if metadata:
            best_model = metadata.get('best_model', 'Unknown')
            st.info(f"**Best Model:** {best_model}")
            scored_models = metadata.get('all_models', {})
            if best_model in scored_models:
                r2_value = scored_models[best_model].get('r2', 0)
                st.metric("R2 Score", f"{r2_value:.4f}")

    # Page label -> renderer for the main content area.
    renderers = {
        page_options[0]: lambda: show_sales_trends(df),
        page_options[1]: lambda: show_prediction_interface(df, models),
        page_options[2]: lambda: show_model_comparison(models),
    }
    render = renderers.get(page)
    if render is not None:
        render()
def show_sales_trends(df):
    """Render the "Sales Trends" page: filter widgets followed by four analysis tabs.

    Args:
        df: sales data. Columns read here: ``date`` (datetime), ``category``,
            ``product_id``, ``sales_quantity``, ``price`` and ``discount``.
    """
    st.header("📈 Sales Trends Analysis")

    filtered_df = _select_sales_subset(df)
    if len(filtered_df) == 0:
        st.warning("No data available for selected filters.")
        return

    tab1, tab2, tab3, tab4 = st.tabs(["📅 Daily Trends", "📆 Monthly Trends", "📦 Category Analysis", "💰 Price vs Demand"])
    with tab1:
        _render_daily_trends(filtered_df)
    with tab2:
        _render_monthly_trends(filtered_df)
    with tab3:
        _render_category_analysis(filtered_df)
    with tab4:
        _render_price_vs_demand(filtered_df)


def _select_sales_subset(df):
    """Render the category / product / date-range filters and return matching rows."""
    col1, col2, col3 = st.columns(3)
    with col1:
        categories = ['All'] + sorted(df['category'].unique().tolist())
        selected_category = st.selectbox("Select Category", categories)
    with col2:
        products = ['All'] + sorted(df['product_id'].unique().tolist())
        selected_product = st.selectbox("Select Product", products)
    with col3:
        date_range = st.date_input(
            "Select Date Range",
            value=(df['date'].min(), df['date'].max()),
            min_value=df['date'].min(),
            max_value=df['date'].max(),
        )

    mask = pd.Series(True, index=df.index)
    if selected_category != 'All':
        mask &= df['category'] == selected_category
    if selected_product != 'All':
        # The option value comes straight from the column, so compare it as-is;
        # the former int() cast crashed for non-numeric product ids.
        mask &= df['product_id'] == selected_product
    # Only filter on a complete (start, end) pair; date_input can return a
    # shorter tuple while the user is still picking the range.
    if isinstance(date_range, tuple) and len(date_range) == 2:
        start, end = (pd.to_datetime(d) for d in date_range)
        mask &= df['date'].between(start, end)
    # Return an independent frame so later column work never hits a slice.
    return df.loc[mask].copy()


def _display_trend_figure(fig):
    """Lay out ``fig``, hand it to Streamlit, then close it.

    pyplot keeps every figure it creates alive until it is closed, and each
    Streamlit rerun creates new ones; without closing, memory grows per rerun.
    """
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)


def _render_daily_trends(data):
    """Line chart of total sales per day plus summary metrics of the daily totals."""
    st.subheader("Daily Sales Trends")
    daily_sales = data.groupby('date')['sales_quantity'].sum().reset_index()

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.plot(daily_sales['date'], daily_sales['sales_quantity'], linewidth=2, alpha=0.7)
    ax.set_title('Total Daily Sales Quantity', fontsize=16, fontweight='bold')
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Sales Quantity', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='x', rotation=45)
    _display_trend_figure(fig)

    # Statistics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Sales", f"{daily_sales['sales_quantity'].sum():,.0f}")
    with col2:
        st.metric("Average Daily", f"{daily_sales['sales_quantity'].mean():.1f}")
    with col3:
        st.metric("Max Daily", f"{daily_sales['sales_quantity'].max():,.0f}")
    with col4:
        st.metric("Min Daily", f"{daily_sales['sales_quantity'].min():,.0f}")


def _render_monthly_trends(data):
    """Bar chart of total sales per calendar month."""
    st.subheader("Monthly Sales Trends")
    # Group by a derived Series instead of adding a column to ``data``.
    monthly_sales = data.groupby(data['date'].dt.to_period('M'))['sales_quantity'].sum()
    positions = list(range(len(monthly_sales)))

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(positions, monthly_sales.to_numpy(), alpha=0.7, color='steelblue')
    ax.set_title('Monthly Sales Quantity', fontsize=16, fontweight='bold')
    ax.set_xlabel('Month', fontsize=12)
    ax.set_ylabel('Sales Quantity', fontsize=12)
    ax.set_xticks(positions)
    ax.set_xticklabels(monthly_sales.index.astype(str), rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')
    _display_trend_figure(fig)


def _render_category_analysis(data):
    """Horizontal bar chart of sales per category and a per-category stats table."""
    st.subheader("Sales by Category")
    category_sales = data.groupby('category')['sales_quantity'].sum().sort_values(ascending=False)

    fig, ax = plt.subplots(figsize=(12, 6))
    category_sales.plot(kind='barh', ax=ax, color='coral', alpha=0.7)
    ax.set_title('Total Sales by Category', fontsize=16, fontweight='bold')
    ax.set_xlabel('Total Sales Quantity', fontsize=12)
    ax.set_ylabel('Category', fontsize=12)
    ax.grid(True, alpha=0.3, axis='x')
    _display_trend_figure(fig)

    # Category statistics table
    category_stats = (
        data.groupby('category')['sales_quantity']
        .agg(['sum', 'mean', 'std', 'min', 'max'])
        .round(2)
    )
    category_stats.columns = ['Total', 'Average', 'Std Dev', 'Min', 'Max']
    st.dataframe(category_stats, use_container_width=True)


def _render_price_vs_demand(data):
    """Price vs quantity scatter (colored by discount) and their correlation."""
    st.subheader("Price vs Demand Relationship")

    fig, ax = plt.subplots(figsize=(12, 6))
    scatter = ax.scatter(data['price'], data['sales_quantity'],
                         c=data['discount'], cmap='viridis', alpha=0.6, s=50)
    ax.set_title('Price vs Sales Quantity (colored by discount)', fontsize=16, fontweight='bold')
    ax.set_xlabel('Price', fontsize=12)
    ax.set_ylabel('Sales Quantity', fontsize=12)
    ax.grid(True, alpha=0.3)
    fig.colorbar(scatter, ax=ax, label='Discount %')
    _display_trend_figure(fig)

    # Correlation: NaN for fewer than two rows or a constant column.
    correlation = data['price'].corr(data['sales_quantity'])
    shown = "N/A" if pd.isna(correlation) else f"{correlation:.3f}"
    st.metric("Price-Demand Correlation", shown)
def show_prediction_interface(df, models):
    """Render the "Demand Prediction" page.

    The user picks Auto / Machine Learning / Time-Series. Auto follows
    ``models['is_timeseries']`` but falls back to the other model family when
    the preferred model did not load.

    Args:
        df: sales data (reads ``date``, ``category``, ``product_id``, ``price``,
            ``sales_quantity``).
        models: dict returned by load_models().
    """
    st.header("🔮 Demand Prediction")

    # Check if models are available
    if models['ml_model'] is None and models['ts_model'] is None:
        st.error("❌ No trained models found. Please run train_model.py first.")
        return

    model_type = st.radio(
        "Select Model Type",
        ["Auto (Best Model)", "Machine Learning", "Time-Series"],
        horizontal=True,
    )
    st.markdown("---")

    if model_type == "Auto (Best Model)":
        use_timeseries = models['is_timeseries']
        # Fall back to whichever model family actually loaded.
        if use_timeseries and models['ts_model'] is None:
            use_timeseries = False
        elif not use_timeseries and models['ml_model'] is None:
            use_timeseries = True
    else:
        use_timeseries = model_type == "Time-Series"

    if use_timeseries:
        _render_timeseries_prediction(df, models)
    else:
        _render_ml_prediction(df, models)


def _default_prediction_date(df):
    """Return ``(default, earliest)`` for the prediction date picker.

    ``earliest`` is the day after the last date in ``df``. ``default`` is
    today + 30 days, but never earlier than ``earliest``. Streamlit rejects a
    date_input default that lies below ``min_value``, which happens whenever
    the data extends past today + 30 days.
    """
    earliest = df['date'].max().date() + timedelta(days=1)
    default = max(earliest, datetime.now().date() + timedelta(days=30))
    return default, earliest


def _timeseries_model_name(models):
    """Return the model name to pass to predict_timeseries for models['ts_model']."""
    name = models['model_name']
    if name in ('ARIMA', 'Prophet'):
        return name
    # The overall best model is an ML model, so its name says nothing about
    # the saved time-series model; infer the flavour from the object's API.
    # NOTE(review): assumes ts_model is a statsmodels results object (has
    # .forecast) or a Prophet model (has .predict) - confirm with train_model.py.
    ts_model = models['ts_model']
    if hasattr(ts_model, 'forecast'):
        return 'ARIMA'
    if hasattr(ts_model, 'predict'):
        return 'Prophet'
    return name


def _render_timeseries_prediction(df, models):
    """Date picker and button that forecast total daily demand with models['ts_model']."""
    st.subheader("Overall Daily Demand Prediction")
    default_date, earliest_date = _default_prediction_date(df)

    col1, col2 = st.columns(2)
    with col1:
        prediction_date = st.date_input(
            "Select Date for Prediction",
            value=default_date,
            min_value=earliest_date,
        )
    with col2:
        st.write("")  # Spacing
        st.write("")  # Spacing
        clicked = st.button("🔮 Predict Demand", type="primary")

    if not clicked:
        return
    if models['ts_model'] is None:
        st.error("Time-series model not available.")
        return

    with st.spinner("Making prediction..."):
        prediction = predict_timeseries(
            prediction_date,
            models['ts_model'],
            _timeseries_model_name(models),
        )
    if prediction is None:
        st.error("Failed to make prediction.")
        return

    st.success("✅ Prediction Complete!")
    when = pd.to_datetime(prediction_date)
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Predicted Daily Demand", f"{prediction:,.0f} units")
    with col2:
        st.metric("Day of Week", when.strftime('%A'))
    with col3:
        st.metric("Weekend", "Yes" if when.weekday() >= 5 else "No")
    st.info("💡 This prediction represents the total daily demand across all products.")


def _render_ml_prediction(df, models):
    """Product/price/discount inputs and a button that predicts with models['ml_model']."""
    st.subheader("Product-Specific Demand Prediction")
    categories = sorted(df['category'].unique().tolist())
    default_date, earliest_date = _default_prediction_date(df)

    col1, col2 = st.columns(2)
    with col1:
        selected_category = st.selectbox("Select Category", categories)
        # Only offer products seen in the chosen category, so the encoded
        # category/product pair fed to the model actually occurs in the data.
        in_category = df.loc[df['category'] == selected_category, 'product_id']
        products = sorted(in_category.unique().tolist())
        selected_product = st.selectbox("Select Product ID", products)
        prediction_date = st.date_input(
            "Select Date for Prediction",
            value=default_date,
            min_value=earliest_date,
        )
    with col2:
        price = st.number_input(
            "Product Price ($)",
            min_value=0.01,
            value=100.0,
            step=1.0,
            format="%.2f",
        )
        # NOTE(review): this is a 0-100 percentage; confirm it matches the
        # scale of the 'discount' feature used at training time.
        discount = st.slider(
            "Discount (%)",
            min_value=0,
            max_value=100,
            value=0,
            step=5,
        )

    # Show product statistics
    product_data = df[df['product_id'] == selected_product]
    if len(product_data) > 0:
        with st.expander("📊 Product Statistics"):
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Avg Price", f"${product_data['price'].mean():.2f}")
            with col2:
                st.metric("Avg Sales", f"{product_data['sales_quantity'].mean():.1f}")
            with col3:
                st.metric("Total Sales", f"{product_data['sales_quantity'].sum():,.0f}")
            with col4:
                st.metric("Category", selected_category)

    if not st.button("🔮 Predict Demand", type="primary"):
        return
    if models['ml_model'] is None or models['preprocessing'] is None:
        st.error("ML model or preprocessing not available.")
        return

    with st.spinner("Making prediction..."):
        prediction = predict_ml(
            selected_product,
            prediction_date,
            price,
            discount,
            selected_category,
            models['ml_model'],
            models['preprocessing'],
        )
    if prediction is None:
        st.error("Failed to make prediction.")
        return

    st.success("✅ Prediction Complete!")
    date_obj = pd.to_datetime(prediction_date)
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Predicted Demand", f"{prediction:,.0f} units")
    with col2:
        st.metric("Price", f"${price:.2f}")
    with col3:
        st.metric("Discount", f"{discount}%")
    with col4:
        st.metric("Day", date_obj.strftime('%A'))

    # Additional insights
    st.markdown("### 📈 Prediction Insights")
    insights = []
    if date_obj.weekday() >= 5:
        insights.append("📅 Weekend - typically higher demand")
    if date_obj.month in [11, 12]:
        insights.append("🎄 Holiday season - peak sales period")
    if discount > 0:
        insights.append(f"💰 {discount}% discount - may increase demand")
    for insight in insights:
        st.info(insight)
def show_model_comparison(models):
    """Render the "Model Comparison" page from the training metadata.

    Args:
        models: dict returned by load_models(). Only ``models['metadata']`` is
            read, via its 'all_models' (model name -> metrics dict with 'mae',
            'rmse', 'r2') and 'best_model' entries.
    """
    st.header("📊 Model Comparison")

    if models['metadata'] is None:
        st.warning("Model metadata not available. Please run train_model.py first.")
        return

    metadata = models['metadata']
    all_models = metadata.get('all_models', {})
    best_model = metadata.get('best_model', 'Unknown')
    if not all_models:
        st.warning("No model comparison data available.")
        return

    timeseries_names = ('ARIMA', 'Prophet')

    # Model metrics table. A missing metric becomes NaN (blank cell / no bar)
    # rather than 0, which would masquerade as a perfect MAE or RMSE.
    st.subheader("Model Performance Metrics")
    comparison_df = pd.DataFrame([
        {
            'Model': model_name,
            'Type': 'Time-Series' if model_name in timeseries_names else 'Machine Learning',
            'MAE': metrics.get('mae', np.nan),
            'RMSE': metrics.get('rmse', np.nan),
            'R2 Score': metrics.get('r2', np.nan),
        }
        for model_name, metrics in all_models.items()
    ])

    def highlight_best(row):
        # Green background on every cell of the best model's row.
        if row['Model'] == best_model:
            return ['background-color: #90EE90'] * len(row)
        return [''] * len(row)

    st.dataframe(
        comparison_df.style.apply(highlight_best, axis=1),
        use_container_width=True,
    )

    # Visualizations
    st.subheader("Performance Comparison Charts")
    model_names = comparison_df['Model'].tolist()
    colors = ['coral' if name in timeseries_names else 'skyblue' for name in model_names]

    def metric_bar_chart(column, title):
        """Draw one bar per model for ``column`` and hand the figure to Streamlit."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bar(model_names, comparison_df[column].tolist(), color=colors, alpha=0.7)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_ylabel(column, fontsize=12)
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3, axis='y')
        fig.tight_layout()
        st.pyplot(fig)
        # pyplot retains figures until closed; close so reruns don't pile them up.
        plt.close(fig)

    col1, col2 = st.columns(2)
    with col1:
        metric_bar_chart('MAE', 'MAE Comparison (Lower is Better)')
    with col2:
        metric_bar_chart('R2 Score', 'R2 Score Comparison (Higher is Better)')

    # Best model info
    st.markdown("---")
    st.success(f"🏆 **Best Model: {best_model}**")
    if best_model in all_models:
        best_metrics = all_models[best_model]

        def fmt(key, spec):
            value = best_metrics.get(key)
            return "N/A" if value is None else format(value, spec)

        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("MAE", fmt('mae', '.2f'))
        with col2:
            st.metric("RMSE", fmt('rmse', '.2f'))
        with col3:
            st.metric("R2 Score", fmt('r2', '.4f'))
# Script entry point: render the dashboard when this module is executed directly.
if __name__ == "__main__":
    main()