|
|
|
|
|
import streamlit as st |
|
|
import requests |
|
|
import json |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import joblib |
|
|
import shap |
|
|
import matplotlib |
|
|
|
|
|
matplotlib.use('Agg') |
|
|
import matplotlib.pyplot as plt |
|
|
from datetime import datetime |
|
|
import os |
|
|
import traceback |
|
|
from sklearn.pipeline import Pipeline |
|
|
from sklearn.compose import ColumnTransformer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page-level settings (browser tab title/icon, wide layout, default sidebar state).
# Kept ahead of every other st.* call in the script.
# NOTE(review): the original icon was destroyed by a bad re-encoding (only the
# stray character "๐" survived, which is not an emoji). The shopping cart below
# is a stand-in chosen to match the app's retail theme -- confirm the intended icon.
st.set_page_config(
    page_title="SuperKart Sales Predictor",
    page_icon="🛒",
    layout="wide",
    initial_sidebar_state="auto",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ManualProductTypeMapper is not referenced directly in this script; presumably it
# must be importable so the saved pipeline can be unpickled -- confirm.
# If the module is missing, report it in the page and halt the script run.
try:
    from custom_transformers import ManualProductTypeMapper
except ImportError as import_error:
    import_failure_message = (
        f"Failed to import custom component: {import_error}. "
        "Please ensure 'custom_transformers.py' is in the root directory."
    )
    st.error(import_failure_message)
    st.stop()

# Remote prediction endpoint that receives one JSON record per request.
BACKEND_URL = "https://Varun6299-SuperKartPredictor.hf.space/predict"

# Local copy of the fitted pipeline, used only for SHAP explanations.
MODEL_PATH = "final_xgboost_pipeline.joblib"
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model_and_setup_shap():
    """Load the local pipeline and prepare everything needed for SHAP plots.

    Steps performed:
      1. Load the joblib-serialized pipeline from ``MODEL_PATH``.
      2. Pull out the final ``'tune_XGBoost_regressor'`` step and wrap it in a
         ``shap.TreeExplainer``.
      3. Read the output column names of the ``'col_transform'`` step and strip
         the ``ohe_cat__`` / ``remainder__`` prefixes so plot labels are readable.

    Returns:
        tuple: ``(pipeline, explainer, cleaned_feature_names)`` on success, or
        ``(None, None, None)`` when the model file is missing or any setup step
        fails (a warning/error is rendered in the page in that case).

    Note:
        The function is wrapped in ``st.cache_resource``, so whatever it returns
        -- including the failure triple -- is reused on later script runs.
    """
    if not os.path.exists(MODEL_PATH):
        st.warning(f"Model file not found locally: {MODEL_PATH}. Explanation feature is disabled.")
        return None, None, None

    def clean_feature_name(name):
        """Drop a leading transformer prefix; leave the rest of the name intact."""
        # Slice instead of str.replace(): replace() would also delete the prefix
        # text if it happened to occur again later inside the feature name.
        for prefix in ('ohe_cat__', 'remainder__'):
            if name.startswith(prefix):
                return name[len(prefix):]
        return name

    try:
        # SECURITY NOTE: joblib.load unpickles arbitrary Python objects. Only point
        # MODEL_PATH at a model file that ships with (and is trusted by) this app.
        pipeline = joblib.load(MODEL_PATH)

        # The last step is the regressor; everything before it is preprocessing.
        xgb_model = pipeline.named_steps['tune_XGBoost_regressor']
        preprocessor = pipeline[:-1]

        explainer = shap.TreeExplainer(xgb_model)

        # Column names as emitted by the column transformer, in output order.
        ct = preprocessor.named_steps['col_transform']
        transformed_feature_names = list(ct.get_feature_names_out())
        cleaned_feature_names = [clean_feature_name(name) for name in transformed_feature_names]

        return pipeline, explainer, cleaned_feature_names

    except Exception as e:
        # Explanations are optional: report the problem and let the app continue
        # with the explain feature disabled rather than crashing the page.
        st.error("Local Model/SHAP Setup Failed. Cannot explain prediction.")
        st.code(f"Error Details: {e}")
        return None, None, None
|
|
|
|
|
|
|
|
# Module-level handles to the local artifacts. All three are None when the
# loader could not find or initialise the model.
(
    LOCAL_PIPELINE,
    LOCAL_EXPLAINER,
    TRANSFORMED_FEATURE_NAMES,
) = load_model_and_setup_shap()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def show_waterfall_plot(explanation):
    """Draw a SHAP waterfall chart for ``explanation`` and render it in the page.

    Args:
        explanation: A single-row ``shap.Explanation`` (as built by
            ``get_shap_explanation``).

    The figure is always closed afterwards, even if plotting or rendering
    raises, so repeated clicks do not accumulate open matplotlib figures.
    """
    plt.style.use('default')
    # waterfall() is given no figure here, so it is expected to draw on the
    # current pyplot figure, i.e. the one created on the next line.
    fig, _ = plt.subplots(figsize=(10, 6), dpi=100)
    try:
        shap.plots.waterfall(explanation, show=False)
        st.pyplot(fig, bbox_inches='tight')
    finally:
        plt.close(fig)
|
|
|
|
|
def get_shap_explanation(pipeline, input_df, explainer, feature_names):
    """Compute the SHAP explanation for the first row of a raw input frame.

    Args:
        pipeline: Fitted pipeline whose last step is the model; all earlier
            steps are applied to ``input_df`` via ``pipeline[:-1].transform``.
        input_df: Raw (untransformed) input; only its first row is explained.
        explainer: Object exposing ``shap_values(X)`` and ``expected_value``
            (a ``shap.TreeExplainer`` in this app).
        feature_names: Labels for the transformed columns, in column order.

    Returns:
        shap.Explanation: values, scalar base value, transformed feature values
        and names for the single explained row.
    """
    # Apply every step except the final estimator.
    X_transformed = pipeline[:-1].transform(input_df)

    # The explainer receives the matrix exactly as the transformer produced it,
    # so the attribution sees the same representation the model step would.
    shap_values_2d = explainer.shap_values(X_transformed)
    shap_values = shap_values_2d[0]

    # The displayed feature values must be a plain 1-D array. Transformers can
    # emit a scipy sparse matrix or a DataFrame, neither of which supports the
    # dense `[0, :]` indexing, so normalise the first row explicitly.
    if hasattr(X_transformed, "toarray"):
        explanation_data = X_transformed.toarray()[0]
    elif hasattr(X_transformed, "iloc"):
        explanation_data = X_transformed.iloc[0].to_numpy()
    else:
        explanation_data = np.asarray(X_transformed)[0]

    # expected_value may be a scalar or a one-element array depending on the
    # explainer/model versions; a single-row explanation needs a scalar.
    base_value = float(np.ravel(explainer.expected_value)[0])

    return shap.Explanation(
        values=shap_values,
        base_values=base_value,
        data=explanation_data,
        feature_names=feature_names,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Seed the per-session keys this script relies on. Existing values are left
# untouched so state survives Streamlit's reruns.
_SESSION_DEFAULTS = {
    'last_prediction': None,      # most recent forecast returned by the backend
    'last_input_df': None,        # the input frame that produced that forecast
    'explanation_shown': False,   # whether a SHAP explanation has been rendered
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
|
|
|
|
|
|
|
|
|
|
# Global CSS overrides followed by the page title and a divider.
# NOTE(review): the title's leading emoji was lost to a bad re-encoding (only the
# stray character "๐" survived). The shopping cart is a stand-in -- confirm.
st.markdown("""
<style>
    /* 1. REMOVE WHITE PANEL BACKGROUNDS */
    .stBlock {
        background-color: transparent !important;
        border-radius: 0px !important;
    }
    .stAlert {
        border-radius: 8px !important;
    }
    .big-font {
        font-size:30px !important;
        font-weight: bold;
        color: #0E8388;
    }
    .stButton>button {
        font-size: 18px;
        font-weight: bold;
        color: white;
        background-color: #007bff;
        border-radius: 8px;
        padding: 10px 20px;
        border: none;
    }
    .stApp {
        background-color: #f0f2f6;
    }
    [data-testid="stMetricValue"] {
        font-size: 3rem;
        color: #007bff;
    }
</style>
<p class="big-font">🛒 SuperKart Sales Forecasting Tool</p>
<hr style="border: 2px solid #007bff; border-radius: 5px;">
""", unsafe_allow_html=True)
|
|
|
|
|
# Collapsed background blurb about the business context.
with st.expander("About SuperKart and this Predictor", expanded=False):
    st.markdown("""
    **SuperKart** is a major retail chain operating supermarkets and food marts across various tier cities, offering a wide range of products.

    To **optimize inventory management** and make informed decisions around regional sales strategies, this tool accurately forecasts the sales revenue of its outlets for the upcoming quarter. This robust predictive model is built using historical sales data to integrate forecasting directly into SuperKart's decision-making systems.
    """)

# Usage hint shown above the inputs. The light-bulb emoji replaces the
# mis-encoded text "๐ก" that the source contained in its place.
st.info("💡 **Instructions:** Adjust the features below and click **1. Predict Sales** first. Then, click **2. Explain Prediction** to see the feature contributions.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sidebar: store-level inputs. The building emoji replaces the mis-encoded
# text "๐ข" that the source contained in its place.
st.sidebar.header("Store Configuration 🏢")

# The establishment-year slider is capped at the current calendar year.
current_year = datetime.now().year
store_establishment_year = st.sidebar.slider(
    'Store Establishment Year',
    min_value=1980,
    max_value=current_year,
    value=2000,
    step=1,
)

store_location_city_type = st.sidebar.selectbox(
    'Store Location City Type',
    ['Tier 1', 'Tier 2', 'Tier 3'],
)

store_size = st.sidebar.selectbox(
    'Store Size',
    ['Medium', 'Small', 'High'],
)

store_type = st.sidebar.selectbox(
    'Store Type',
    ['Supermarket Type1', 'Supermarket Type2', 'Departmental Store', 'Food Mart'],
)
|
|
|
|
|
|
|
|
|
|
|
# Main area: product-level inputs laid out in three columns. The package emoji
# replaces the mis-encoded text "๐ฆ" that the source contained in its place.
st.header("Product Specifications 📦")
col1, col2, col3 = st.columns(3)

# Column 1: physical weight and price.
with col1:
    product_weight = st.number_input(
        'Product Weight (kg)',
        min_value=0.5,
        max_value=30.0,
        value=12.0,
        step=0.1,
        help="Weight of the product in kilograms.",
    )
    product_mrp = st.slider(
        'Product MRP ($)',
        min_value=10.0,
        max_value=300.0,
        value=150.0,
        step=0.5,
        help="Maximum Retail Price of the product.",
    )

# Column 2: shelf allocation and stock level.
with col2:
    product_allocated_area = st.number_input(
        'Shelf Area (%)',
        min_value=0.0,
        max_value=100.0,
        value=10.0,
        step=0.1,
        help="Percentage of store shelf area allocated to this product.",
    )
    product_quantity = st.number_input(
        'Stock Quantity',
        min_value=0,
        value=100,
        help="Current stock level for the product.",
    )

# Column 3: categorical product attributes.
with col3:
    # NOTE(review): 'reg' sits alongside 'Regular', and 'Food Mart' (also a store
    # type) appears among the product types below. Presumably these mirror raw
    # category values the model was trained on -- confirm before pruning.
    product_sugar_content = st.selectbox(
        'Sugar/Fat Content',
        ['Low Sugar', 'Regular', 'No Sugar', 'reg'],
    )
    product_type = st.selectbox(
        'Product Type',
        [
            'Frozen Foods', 'Dairy', 'Canned', 'Baking Goods',
            'Health and Hygiene', 'Food Mart', 'Snack Foods', 'Meat',
            'Household', 'Hard Drinks', 'Fruits and Vegetables', 'Breads',
            'Soft Drinks', 'Breakfast', 'Others', 'Starchy Foods', 'Seafood',
        ],
    )
|
|
|
|
|
|
|
|
|
|
|
# Column order used for the single-row input frame.
feature_cols = [
    'Product_Weight',
    'Product_Sugar_Content',
    'Product_Allocated_Area',
    'Product_Type',
    'Product_Quantity',
    'Product_MRP',
    'Store_Establishment_Year',
    'Store_Size',
    'Store_Location_City_Type',
    'Store_Type',
]

# Current widget values, listed in the same order as feature_cols.
_widget_values = (
    product_weight,
    product_sugar_content,
    product_allocated_area,
    product_type,
    product_quantity,
    product_mrp,
    store_establishment_year,
    store_size,
    store_location_city_type,
    store_type,
)

# Pair each column name with its widget value, then wrap the record in a
# one-row DataFrame for the API call and the local explainer.
input_data_dict = dict(zip(feature_cols, _widget_values))
input_df = pd.DataFrame([input_data_dict], columns=feature_cols)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
st.markdown("---")

# Step 1: send the current inputs to the remote backend and store the forecast
# in session state so it survives the rerun triggered by the explain button.
if st.button('1. Predict Sales Total (Remote)', use_container_width=True, key='predict_btn'):
    # A new request invalidates whatever forecast/explanation was on screen.
    st.session_state.explanation_shown = False
    st.session_state.last_prediction = None

    with st.spinner('Calling remote API...'):
        try:
            payload = input_df.to_dict(orient='records')[0]
            # requests never times out by default, so an unresponsive backend
            # would hang this script run indefinitely. (connect, read) limits
            # in seconds; a timeout raises a RequestException handled below.
            response = requests.post(BACKEND_URL, json=payload, timeout=(10, 60))
            response.raise_for_status()

            response_data = response.json()
            prediction = response_data.get('Predicted Sales (in dollars)')

            if prediction is not None:
                # float() makes a non-numeric payload fail here, inside the
                # handler, instead of later when the value is formatted.
                st.session_state.last_prediction = round(float(prediction), 2)
                # Keep the exact inputs that produced this forecast so the
                # explanation matches it even if widgets change afterwards.
                st.session_state.last_input_df = input_df
                st.success("Prediction retrieved successfully.")
            else:
                st.error("Prediction failed: Backend response is missing the 'Predicted Sales (in dollars)' key.")
                st.json(response_data)

        except requests.exceptions.RequestException as e:
            st.error(f"Connection Error: Could not connect to the backend API at {BACKEND_URL}. Please ensure the backend is running.")
            st.code(f"Error Details: {e}")
        except Exception as e:
            # Top-level UI boundary: surface anything unexpected (bad JSON,
            # non-numeric prediction, ...) in the page instead of crashing.
            st.error(f"An unexpected error occurred: {e}")
|
|
|
|
|
|
|
|
# Show the most recent forecast, if there is one. The check-mark emoji replaces
# mis-encoded text that had split this string literal across two lines.
if st.session_state.last_prediction is not None:
    st.subheader("✅ Latest Forecast")
    st.metric(
        label="Forecast for Next Period",
        value=f"${st.session_state.last_prediction:,.2f}",
    )

# Step 2 needs both the local model and a stored input row to explain, so the
# button is disabled until a prediction has been made.
explain_disabled = LOCAL_PIPELINE is None or st.session_state.last_input_df is None

if st.button('2. Explain Prediction (Local SHAP)', use_container_width=True, disabled=explain_disabled, key='explain_btn'):
    if LOCAL_PIPELINE is None:
        st.warning("Local model is unavailable. Check error messages at the top.")
    elif st.session_state.last_input_df is None or st.session_state.last_prediction is None:
        # Defensive: the button is disabled in this state, but never run SHAP
        # on a missing input frame.
        st.warning("No forecast to explain yet. Click '1. Predict Sales Total (Remote)' first.")
    else:
        with st.spinner('Calculating local SHAP explanation...'):
            try:
                df_to_explain = st.session_state.last_input_df

                # expected_value may be a scalar or a one-element array; the
                # ",.2f" format spec only works on a plain number.
                base_value = float(np.ravel(LOCAL_EXPLAINER.expected_value)[0])

                # The light-bulb emoji replaces the mis-encoded text "๐ก".
                st.subheader("💡 Breakdown of the Forecast (SHAP Analysis)")

                shap_explanation = get_shap_explanation(
                    LOCAL_PIPELINE,
                    df_to_explain,
                    LOCAL_EXPLAINER,
                    TRANSFORMED_FEATURE_NAMES,
                )

                st.markdown(f"""
                <p style='font-size:16px;'>
                Base Value: ${base_value:,.2f}
                (Average sales prediction).
                </p>
                """, unsafe_allow_html=True)

                show_waterfall_plot(shap_explanation)

                st.markdown("##### Plot Interpretation Guide")
                # Dollar signs are escaped: two unescaped "$" in one markdown
                # paragraph would render the text between them as LaTeX.
                st.markdown(f"""
                The Waterfall Plot shows how each feature value contributes to pushing the model's output from the **Base Value** (the average prediction: \\${base_value:,.2f}) to the final **Forecast** (\\${st.session_state.last_prediction:,.2f}).

                * **Red Bars:** Feature values that **increase** the prediction (positive contribution).
                * **Blue Bars:** Feature values that **decrease** the prediction (negative contribution).
                * The length of the bar indicates the magnitude (strength) of the feature's influence on the prediction.
                """)

                st.session_state.explanation_shown = True
                st.success("Explanation generated successfully.")

            except Exception as e:
                # Top-level UI boundary: report the failure in the page.
                st.error("Error during local SHAP calculation.")
                st.code(f"Error Details: {e}")

# Contextual hint under the buttons. The "no forecast yet" case is tested first;
# in the original ordering that branch could never be reached.
# <strong> is used instead of **...** because the text sits inside a raw HTML
# <div>, where markdown emphasis is not rendered.
if LOCAL_PIPELINE is not None:
    if st.session_state.last_prediction is None:
        st.markdown("""
        <div style='margin-top: 20px; padding: 10px; background-color: #fff3cd; border-radius: 8px; border-left: 5px solid #ffc107;'>
        Click <strong>1. Predict Sales Total (Remote)</strong> to get a forecast first.
        </div>
        """, unsafe_allow_html=True)
    elif not st.session_state.explanation_shown:
        st.markdown("""
        <div style='margin-top: 20px; padding: 10px; background-color: #f0f8ff; border-radius: 8px; border-left: 5px solid #007bff;'>
        If prediction was successful, click the <strong>Explain Prediction</strong> button to understand the feature contributions.
        </div>
        """, unsafe_allow_html=True)
|
|
|
|
|
|