import streamlit as st
import requests
import json
import pandas as pd
import numpy as np # Added for SHAP/model manipulation
import joblib # Added for local model loading
import shap # Added for SHAP calculation
import matplotlib
# --- FIX: Set non-interactive Matplotlib backend for headless environments ---
matplotlib.use('Agg')
import matplotlib.pyplot as plt # Added for SHAP plotting
from datetime import datetime
import os
import traceback
from sklearn.pipeline import Pipeline # Added for unpickling the pipeline structure
from sklearn.compose import ColumnTransformer # Added for unpickling the ColumnTransformer
# ----------------- APP CONFIGURATION (MUST BE FIRST STREAMLIT COMMAND) -----------------
# Streamlit raises if set_page_config is not the first st.* call, so this
# runs before any widgets, errors, or markdown below.
_PAGE_CONFIG = dict(
    page_title="SuperKart Sales Predictor",
    page_icon="🛒",
    layout="wide",  # Use wide layout for a professional look
    initial_sidebar_state="auto",
)
st.set_page_config(**_PAGE_CONFIG)
# --- CRITICAL LOCAL IMPORTS ---
# This is necessary for local model loading/explanation as the pipeline expects this class.
# joblib.load would fail to unpickle the pipeline if ManualProductTypeMapper
# cannot be imported at load time.
try:
    # Ensure 'custom_transformers.py' is in the root directory for this import to succeed
    from custom_transformers import ManualProductTypeMapper
except ImportError as e:
    # This st.error is now safe because st.set_page_config has already been called.
    st.error(f"Failed to import custom component: {e}. Please ensure 'custom_transformers.py' is in the root directory.")
    # Exit app if essential custom component is missing
    st.stop()
# -----------------------------
# Define the backend API URL (Using the user's specified URL)
BACKEND_URL = "https://Varun6299-SuperKartPredictor.hf.space/predict"
# Local pipeline artifact; used only for the SHAP explanation feature.
MODEL_PATH = "final_xgboost_pipeline.joblib"
# --- GLOBAL MODEL & SHAP SETUP (For Local Explanation) ---
@st.cache_resource
def load_model_and_setup_shap():
    """Load the local pipeline, build a SHAP explainer, and return feature names.

    Returns:
        tuple: (pipeline, explainer, cleaned_feature_names) on success, or
        (None, None, None) when the model file is absent or setup fails —
        callers treat a None pipeline as "explanation feature disabled".
    """
    if not os.path.exists(MODEL_PATH):
        st.warning(f"Model file not found locally: {MODEL_PATH}. Explanation feature is disabled.")
        return None, None, None
    try:
        pipeline = joblib.load(MODEL_PATH)
        # Step names must match those used when the pipeline was trained.
        xgb_model = pipeline.named_steps['tune_XGBoost_regressor']
        preprocessor = pipeline[:-1]  # Everything before the regressor
        explainer = shap.TreeExplainer(xgb_model)
        # Extract transformed feature names (CRITICAL for labelling SHAP plots,
        # since SHAP values are computed on the post-ColumnTransformer matrix).
        ct = preprocessor.named_steps['col_transform']
        transformed_feature_names = list(ct.get_feature_names_out())

        # FIX: strip the transformer prefix only when it is a true *leading*
        # prefix. The previous str.replace() removed the substring anywhere in
        # the name, which could corrupt a feature name that happened to
        # contain 'ohe_cat__' or 'remainder__' mid-string.
        def _strip_prefix(name):
            # One pass over the known ColumnTransformer prefixes.
            for prefix in ('ohe_cat__', 'remainder__'):
                if name.startswith(prefix):
                    return name[len(prefix):]
            return name

        cleaned_feature_names = [_strip_prefix(name) for name in transformed_feature_names]
        return pipeline, explainer, cleaned_feature_names
    except Exception as e:
        # Broad catch is deliberate: any setup failure should disable the
        # explanation feature rather than crash the whole app.
        st.error(f"Local Model/SHAP Setup Failed. Cannot explain prediction.")
        st.code(f"Error Details: {e}")
        return None, None, None
# Load all necessary components globally for the "Explain" feature.
# st.cache_resource ensures this runs once per server process, not per rerun.
LOCAL_PIPELINE, LOCAL_EXPLAINER, TRANSFORMED_FEATURE_NAMES = load_model_and_setup_shap()
# ------------------------------------
# --- SHAP UTILITY FUNCTIONS ---
def show_waterfall_plot(explanation):
    """Generates and displays the SHAP Waterfall plot using st.pyplot.

    Args:
        explanation: A shap.Explanation for a single sample.
    """
    plt.style.use('default')
    # shap.plots.waterfall draws onto the *current* figure, so create it first.
    # FIX: use plt.figure() — the Axes returned by plt.subplots() was unused.
    fig = plt.figure(figsize=(10, 6), dpi=100)
    shap.plots.waterfall(explanation, show=False)
    # NOTE(review): extra savefig kwargs (bbox_inches) passed through st.pyplot
    # are deprecated in recent Streamlit releases — confirm the pinned version
    # still accepts them before upgrading.
    st.pyplot(fig, bbox_inches='tight')
    # Close explicitly so repeated clicks don't leak figures.
    plt.close(fig)
def get_shap_explanation(pipeline, input_df, explainer, feature_names):
    """Calculates SHAP values for a single raw input row.

    The raw DataFrame is run through every preprocessing step of the pipeline
    (all steps except the final regressor), SHAP values are computed on the
    resulting (1, N) matrix, and the single sample is wrapped in a
    shap.Explanation ready for waterfall plotting.
    """
    # Apply only the preprocessing steps; yields a 2D matrix of shape (1, N).
    transformed = pipeline[:-1].transform(input_df)
    # Passing the full 2D matrix avoids reshape errors inside shap_values.
    per_sample_values = explainer.shap_values(transformed)
    # Wrap the first (and only) row: 1D SHAP values plus 1D feature data
    # for plot labelling.
    return shap.Explanation(
        values=per_sample_values[0],
        base_values=explainer.expected_value,
        data=transformed[0, :],
        feature_names=feature_names,
    )
# ----------------- UI SETUP -----------------
# Seed session-state keys with defaults so later reads never raise; values
# survive Streamlit reruns between button clicks.
_SESSION_DEFAULTS = {
    'last_prediction': None,
    'last_input_df': None,
    'explanation_shown': False,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Custom CSS/Theme & Intro
# NOTE(review): this markdown is plain text despite unsafe_allow_html=True —
# presumably an HTML-styled header was intended here; confirm with the author.
st.markdown("""
🛒 SuperKart Sales Forecasting Tool
""", unsafe_allow_html=True)
# Collapsible business-context blurb for first-time users.
with st.expander("About SuperKart and this Predictor", expanded=False):
    st.markdown("""
**SuperKart** is a major retail chain operating supermarkets and food marts across various tier cities, offering a wide range of products.
To **optimize inventory management** and make informed decisions around regional sales strategies, this tool accurately forecasts the sales revenue of its outlets for the upcoming quarter. This robust predictive model is built using historical sales data to integrate forecasting directly into SuperKart's decision-making systems.
""")
# Usage hint: predict first (remote API), then explain (local SHAP).
st.info("💡 **Instructions:** Adjust the features below and click **1. Predict Sales** first. Then, click **2. Explain Prediction** to see the feature contributions.")
# ----------------- INPUT FIELDS (using columns for better layout) -----------------
# --- Store Details (Sidebar) ---
st.sidebar.header("Store Configuration 🏢")
# The establishment-year slider is capped at the current calendar year.
current_year = datetime.now().year
store_establishment_year = st.sidebar.slider(
    'Store Establishment Year', min_value=1980, max_value=current_year, value=2000, step=1
)
store_location_city_type = st.sidebar.selectbox(
    'Store Location City Type', ['Tier 1', 'Tier 2', 'Tier 3']
)
store_size = st.sidebar.selectbox('Store Size', ['Medium', 'Small', 'High'])
store_type = st.sidebar.selectbox(
    'Store Type',
    ['Supermarket Type1', 'Supermarket Type2', 'Departmental Store', 'Food Mart'],
)
# --- Product Details (Main Area - Columns) ---
st.header("Product Specifications 📦")
col1, col2, col3 = st.columns(3)

# Column 1: core numerical attributes.
with col1:
    product_weight = st.number_input(
        'Product Weight (kg)', min_value=0.5, max_value=30.0, value=12.0, step=0.1,
        help="Weight of the product in kilograms.",
    )
    product_mrp = st.slider(
        'Product MRP ($)', min_value=10.0, max_value=300.0, value=150.0, step=0.5,
        help="Maximum Retail Price of the product.",
    )

# Column 2: shelf allocation and stock level.
with col2:
    product_allocated_area = st.number_input(
        'Shelf Area (%)', min_value=0.0, max_value=100.0, value=10.0, step=0.1,
        help="Percentage of store shelf area allocated to this product.",
    )
    product_quantity = st.number_input(
        'Stock Quantity', min_value=0, value=100,
        help="Current stock level for the product.",
    )

# Column 3: categorical attributes ('reg' is a raw label seen in training data).
with col3:
    product_sugar_content = st.selectbox(
        'Sugar/Fat Content', ['Low Sugar', 'Regular', 'No Sugar', 'reg']
    )
    _PRODUCT_TYPES = [
        'Frozen Foods', 'Dairy', 'Canned', 'Baking Goods', 'Health and Hygiene',
        'Food Mart', 'Snack Foods', 'Meat', 'Household', 'Hard Drinks',
        'Fruits and Vegetables', 'Breads', 'Soft Drinks', 'Breakfast', 'Others',
        'Starchy Foods', 'Seafood',
    ]
    product_type = st.selectbox('Product Type', _PRODUCT_TYPES)
# --- Data Assembly ---
# Raw widget values keyed by the training-time column names. Insertion order
# of this dict matches the column order the pipeline was fitted on.
input_data_dict = {
    'Product_Weight': product_weight,
    'Product_Sugar_Content': product_sugar_content,
    'Product_Allocated_Area': product_allocated_area,
    'Product_Type': product_type,
    'Product_Quantity': product_quantity,
    'Product_MRP': product_mrp,
    'Store_Establishment_Year': store_establishment_year,
    'Store_Size': store_size,
    'Store_Location_City_Type': store_location_city_type,
    'Store_Type': store_type,
}
# Column list derived from the dict's insertion order (identical to the
# hand-written list it replaces); used to pin the DataFrame column order.
feature_cols = list(input_data_dict)
input_df = pd.DataFrame([input_data_dict], columns=feature_cols)
# ----------------- BUTTONS AND LOGIC -----------------
st.markdown("---")
# --- 1. PREDICT BUTTON (Calls Remote API) ---
# This button is now on its own row
if st.button('1. Predict Sales Total (Remote)', use_container_width=True, key='predict_btn'):
    st.session_state.explanation_shown = False  # Reset explanation status
    st.session_state.last_prediction = None  # Clear previous prediction display while loading
    with st.spinner('Calling remote API...'):
        try:
            # Send the single raw row as a JSON object to the API.
            # FIX: add a request timeout so a stalled backend cannot hang the
            # app indefinitely; a timeout raises requests.exceptions.Timeout,
            # a subclass of RequestException handled below.
            response = requests.post(
                BACKEND_URL,
                json=input_df.to_dict(orient='records')[0],
                timeout=30,
            )
            response.raise_for_status()
            response_data = response.json()
            # Backend contract: payload carries this exact key on success.
            prediction = response_data.get('Predicted Sales (in dollars)')
            if prediction is not None:
                st.session_state.last_prediction = round(prediction, 2)
                st.session_state.last_input_df = input_df  # Store input for explanation
                st.session_state.explanation_shown = False
                st.success("Prediction retrieved successfully.")
            else:
                # Show the raw payload so the contract mismatch is debuggable.
                st.error("Prediction failed: Backend response is missing the 'Predicted Sales (in dollars)' key.")
                st.json(response_data)
        except requests.exceptions.RequestException as e:
            st.error(f"Connection Error: Could not connect to the backend API at {BACKEND_URL}. Please ensure the backend is running.")
            st.code(f"Error Details: {e}")
        except Exception as e:
            st.error(f"An unexpected error occurred: {e}")
# --- Display Prediction Metric and Explanation Block (Conditional on Prediction) ---
# Only rendered once a successful prediction exists in session state.
if st.session_state.last_prediction is not None:
    st.subheader("✅ Latest Forecast")
    st.metric(
        label="Forecast for Next Period",
        value=f"${st.session_state.last_prediction:,.2f}"
    )
    # --- 2. EXPLAIN BUTTON (Local SHAP Calculation) - Now placed directly below the metric ---
    # Disable the button entirely when the local model failed to load.
    explain_disabled = LOCAL_PIPELINE is None
    # Place the explain button below the metric
    if st.button('2. Explain Prediction (Local SHAP)', use_container_width=True, disabled=explain_disabled, key='explain_btn'):
        if LOCAL_PIPELINE is None:
            # Defensive re-check; normally unreachable because the button is disabled.
            st.warning("Local model is unavailable. Check error messages at the top.")
        else:
            with st.spinner('Calculating local SHAP explanation...'):
                try:
                    # Retrieve the DataFrame from the session state
                    # (the exact input that produced the displayed forecast).
                    df_to_explain = st.session_state.last_input_df
                    # --- SHAP EXPLANATION ---
                    # 1. Title change for clarity
                    st.subheader("💡 Breakdown of the Forecast (SHAP Analysis)")
                    # Get the SHAP Explanation object
                    shap_explanation = get_shap_explanation(
                        LOCAL_PIPELINE,
                        df_to_explain,
                        LOCAL_EXPLAINER,
                        TRANSFORMED_FEATURE_NAMES
                    )
                    # 2. Base Value notation simplification
                    st.markdown(f"""
Base Value: ${LOCAL_EXPLAINER.expected_value:,.2f}
(Average sales prediction).
""", unsafe_allow_html=True)
                    # Display the Waterfall Plot (This plot will appear below the button)
                    show_waterfall_plot(shap_explanation)
                    # 3. Plot Interpretation Guide notation simplification
                    st.markdown("##### Plot Interpretation Guide")
                    st.markdown(f"""
The Waterfall Plot shows how each feature value contributes to pushing the model's output from the **Base Value** (the average prediction: ${LOCAL_EXPLAINER.expected_value:,.2f}) to the final **Forecast** (${st.session_state.last_prediction:,.2f}).
* **Red Bars:** Feature values that **increase** the prediction (positive contribution).
* **Blue Bars:** Feature values that **decrease** the prediction (negative contribution).
* The length of the bar indicates the magnitude (strength) of the feature's influence on the prediction.
""")
                    st.session_state.explanation_shown = True
                    st.success("Explanation generated successfully.")
                except Exception as e:
                    # Broad catch: a failed explanation must not crash the app.
                    st.error(f"Error during local SHAP calculation.")
                    st.code(f"Error Details: {e}")
# --- Footer Guidance (Moved inside the conditional block for prediction) ---
# NOTE(review): the original indentation was ambiguous; these hints are placed
# at top level so the elif (no prediction yet) is syntactically meaningful —
# confirm intent. As written, the elif appears effectively unreachable because
# explanation_shown is only True after a prediction exists.
if LOCAL_PIPELINE is not None and not st.session_state.explanation_shown:
    st.markdown("""
If prediction was successful, click the **Explain Prediction** button to understand the feature contributions.
""", unsafe_allow_html=True)
# The following block is kept for when a prediction hasn't been made yet, but the local pipeline is ready.
elif LOCAL_PIPELINE is not None and st.session_state.last_prediction is None:
    st.markdown("""
Click **1. Predict Sales Total (Remote)** to get a forecast first.
""", unsafe_allow_html=True)