import streamlit as st
import requests
import json
import pandas as pd
import numpy as np  # Added for SHAP/model manipulation
import joblib  # Added for local model loading
import shap  # Added for SHAP calculation
import matplotlib

# --- FIX: Set non-interactive Matplotlib backend for headless environments ---
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # Added for SHAP plotting

from datetime import datetime
import os
import traceback
from sklearn.pipeline import Pipeline  # Added for unpickling the pipeline structure
from sklearn.compose import ColumnTransformer  # Added for unpickling the ColumnTransformer

# ----------------- APP CONFIGURATION (MUST BE FIRST STREAMLIT COMMAND) -----------------
st.set_page_config(
    page_title="SuperKart Sales Predictor",
    page_icon="🛒",
    layout="wide",  # Use wide layout for a professional look
    initial_sidebar_state="auto"
)

# --- CRITICAL LOCAL IMPORTS ---
# This import is necessary for local model loading/explanation because the
# pickled pipeline references this class.
try:
    # Ensure 'custom_transformers.py' is in the root directory for this import to succeed
    from custom_transformers import ManualProductTypeMapper
except ImportError as e:
    # This st.error is safe because st.set_page_config has already been called.
    st.error(
        f"Failed to import custom component: {e}. "
        "Please ensure 'custom_transformers.py' is in the root directory."
    )
    # Exit the app if an essential custom component is missing
    st.stop()
# -----------------------------

# Define the backend API URL
BACKEND_URL = "https://Varun6299-SuperKartPredictor.hf.space/predict"
MODEL_PATH = "final_xgboost_pipeline.joblib"


# --- GLOBAL MODEL & SHAP SETUP (For Local Explanation) ---
@st.cache_resource
def load_model_and_setup_shap():
    """Loads the model, extracts components, sets up SHAP, and gets feature names."""
    if not os.path.exists(MODEL_PATH):
        st.warning(f"Model file not found locally: {MODEL_PATH}. Explanation feature is disabled.")
        return None, None, None
    try:
        pipeline = joblib.load(MODEL_PATH)

        # Extract components using the step names defined in the training code
        xgb_model = pipeline.named_steps['tune_XGBoost_regressor']
        preprocessor = pipeline[:-1]  # Everything before the regressor
        explainer = shap.TreeExplainer(xgb_model)

        # Extract transformed feature names (CRITICAL STEP for plotting)
        ct = preprocessor.named_steps['col_transform']
        transformed_feature_names = list(ct.get_feature_names_out())

        # CLEANUP: Strip the 'ohe_cat__' and 'remainder__' prefixes for better
        # plot readability (Request 1)
        def clean_feature_name(name):
            if name.startswith('ohe_cat__'):
                return name.replace('ohe_cat__', '')
            if name.startswith('remainder__'):
                return name.replace('remainder__', '')
            return name

        cleaned_feature_names = [clean_feature_name(name) for name in transformed_feature_names]

        return pipeline, explainer, cleaned_feature_names
    except Exception as e:
        st.error("Local Model/SHAP setup failed. Cannot explain predictions.")
        st.code(f"Error Details: {e}")
        return None, None, None


# Load all necessary components globally for the "Explain" feature
LOCAL_PIPELINE, LOCAL_EXPLAINER, TRANSFORMED_FEATURE_NAMES = load_model_and_setup_shap()
# ------------------------------------


# --- SHAP UTILITY FUNCTIONS ---
def show_waterfall_plot(explanation):
    """Generates and displays the SHAP waterfall plot using st.pyplot."""
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(10, 6), dpi=100)
    shap.plots.waterfall(explanation, show=False)
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)


def get_shap_explanation(pipeline, input_df, explainer, feature_names):
    """Calculates SHAP values for a single raw input."""
    # Transform the raw input DataFrame into a NumPy array of shape (1, N)
    X_transformed = pipeline[:-1].transform(input_df)

    # FIX: Pass the 2D array (shape (1, N)) to shap_values to avoid a reshape error.
    shap_values_2d = explainer.shap_values(X_transformed)

    # Extract the 1D SHAP values for the single sample (shape (N,))
    shap_values = shap_values_2d[0]

    # Extract the 1D feature data for plot labeling (shape (N,))
    explanation_data = X_transformed[0, :]

    # Create the SHAP Explanation object
    explanation = shap.Explanation(
        values=shap_values,
        base_values=explainer.expected_value,
        data=explanation_data,
        feature_names=feature_names
    )
    return explanation


# ----------------- UI SETUP -----------------
# Initialize Streamlit session state for prediction results
if 'last_prediction' not in st.session_state:
    st.session_state.last_prediction = None
if 'last_input_df' not in st.session_state:
    st.session_state.last_input_df = None
if 'explanation_shown' not in st.session_state:
    st.session_state.explanation_shown = False

# Custom CSS/Theme & Intro
st.markdown("""

<h1 style="text-align: center;">🛒 SuperKart Sales Forecasting Tool</h1>


""", unsafe_allow_html=True) with st.expander("About SuperKart and this Predictor", expanded=False): st.markdown(""" **SuperKart** is a major retail chain operating supermarkets and food marts across various tier cities, offering a wide range of products. To **optimize inventory management** and make informed decisions around regional sales strategies, this tool accurately forecasts the sales revenue of its outlets for the upcoming quarter. This robust predictive model is built using historical sales data to integrate forecasting directly into SuperKart's decision-making systems. """) st.info("💡 **Instructions:** Adjust the features below and click **1. Predict Sales** first. Then, click **2. Explain Prediction** to see the feature contributions.") # ----------------- INPUT FIELDS (using columns for better layout) ----------------- # --- Store Details (Sidebar) --- st.sidebar.header("Store Configuration 🏢") current_year = datetime.now().year store_establishment_year = st.sidebar.slider( 'Store Establishment Year', min_value=1980, max_value=current_year, value=2000, step=1 ) store_location_city_type = st.sidebar.selectbox( 'Store Location City Type', ['Tier 1', 'Tier 2', 'Tier 3'] ) store_size = st.sidebar.selectbox( 'Store Size', ['Medium', 'Small', 'High'] ) store_type = st.sidebar.selectbox( 'Store Type', ['Supermarket Type1', 'Supermarket Type2', 'Departmental Store', 'Food Mart'] ) # --- Product Details (Main Area - Columns) --- st.header("Product Specifications 📦") col1, col2, col3 = st.columns(3) # Column 1: Numerical Features with col1: product_weight = st.number_input( 'Product Weight (kg)', min_value=0.5, max_value=30.0, value=12.0, step=0.1, help="Weight of the product in kilograms." ) product_mrp = st.slider( 'Product MRP ($)', min_value=10.0, max_value=300.0, value=150.0, step=0.5, help="Maximum Retail Price of the product." 
    )

# Column 2: Other numerical/categorical features
with col2:
    product_allocated_area = st.number_input(
        'Shelf Area (%)',
        min_value=0.0,
        max_value=100.0,
        value=10.0,
        step=0.1,
        help="Percentage of store shelf area allocated to this product."
    )
    product_quantity = st.number_input(
        'Stock Quantity',
        min_value=0,
        value=100,
        help="Current stock level for the product."
    )

# Column 3: Categorical features
with col3:
    product_sugar_content = st.selectbox(
        'Sugar/Fat Content',
        ['Low Sugar', 'Regular', 'No Sugar', 'reg']
    )
    product_type = st.selectbox(
        'Product Type',
        ['Frozen Foods', 'Dairy', 'Canned', 'Baking Goods', 'Health and Hygiene',
         'Food Mart', 'Snack Foods', 'Meat', 'Household', 'Hard Drinks',
         'Fruits and Vegetables', 'Breads', 'Soft Drinks', 'Breakfast', 'Others',
         'Starchy Foods', 'Seafood']
    )

# --- Data Assembly ---
feature_cols = [
    'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
    'Product_Type', 'Product_Quantity', 'Product_MRP',
    'Store_Establishment_Year', 'Store_Size', 'Store_Location_City_Type', 'Store_Type'
]

input_data_dict = {
    'Product_Weight': product_weight,
    'Product_Sugar_Content': product_sugar_content,
    'Product_Allocated_Area': product_allocated_area,
    'Product_Type': product_type,
    'Product_Quantity': product_quantity,
    'Product_MRP': product_mrp,
    'Store_Establishment_Year': store_establishment_year,
    'Store_Size': store_size,
    'Store_Location_City_Type': store_location_city_type,
    'Store_Type': store_type
}

input_df = pd.DataFrame([input_data_dict], columns=feature_cols)

# ----------------- BUTTONS AND LOGIC -----------------
st.markdown("---")

# --- 1. PREDICT BUTTON (Calls Remote API) ---
if st.button('1. Predict Sales Total (Remote)', use_container_width=True, key='predict_btn'):
    st.session_state.explanation_shown = False  # Reset explanation status
    st.session_state.last_prediction = None  # Clear the previous prediction while loading
    with st.spinner('Calling remote API...'):
        try:
            # Send the raw feature values as a flat JSON object to the API
            response = requests.post(
                BACKEND_URL,
                json=input_df.to_dict(orient='records')[0],
                timeout=30
            )
            response.raise_for_status()
            response_data = response.json()
            prediction = response_data.get('Predicted Sales (in dollars)')

            if prediction is not None:
                st.session_state.last_prediction = round(prediction, 2)
                st.session_state.last_input_df = input_df  # Store the input for explanation
                st.session_state.explanation_shown = False
                st.success("Prediction retrieved successfully.")
            else:
                st.error("Prediction failed: the backend response is missing the 'Predicted Sales (in dollars)' key.")
                st.json(response_data)
        except requests.exceptions.RequestException as e:
            st.error(f"Connection Error: could not reach the backend API at {BACKEND_URL}. Please ensure the backend is running.")
            st.code(f"Error Details: {e}")
        except Exception as e:
            st.error(f"An unexpected error occurred: {e}")

# --- Display the prediction metric and explanation block (conditional on a prediction) ---
if st.session_state.last_prediction is not None:
    st.subheader("✅ Latest Forecast")
    st.metric(
        label="Forecast for Next Period",
        value=f"${st.session_state.last_prediction:,.2f}"
    )

    # --- 2. EXPLAIN BUTTON (Local SHAP calculation), placed directly below the metric ---
    explain_disabled = LOCAL_PIPELINE is None

    if st.button('2. Explain Prediction (Local SHAP)', use_container_width=True,
                 disabled=explain_disabled, key='explain_btn'):
        if LOCAL_PIPELINE is None:
            st.warning("Local model is unavailable. Check the error messages at the top.")
        else:
            with st.spinner('Calculating local SHAP explanation...'):
                try:
                    # Retrieve the stored input DataFrame from the session state
                    df_to_explain = st.session_state.last_input_df

                    # --- SHAP EXPLANATION ---
                    st.subheader("💡 Breakdown of the Forecast (SHAP Analysis)")

                    # Get the SHAP Explanation object
                    shap_explanation = get_shap_explanation(
                        LOCAL_PIPELINE,
                        df_to_explain,
                        LOCAL_EXPLAINER,
                        TRANSFORMED_FEATURE_NAMES
                    )

                    # Base Value summary
                    st.markdown(f"""

**Base Value:** ${LOCAL_EXPLAINER.expected_value:,.2f} (average sales prediction).

""", unsafe_allow_html=True) # Display the Waterfall Plot (This plot will appear below the button) show_waterfall_plot(shap_explanation) # 3. Plot Interpretation Guide notation simplification st.markdown("##### Plot Interpretation Guide") st.markdown(f""" The Waterfall Plot shows how each feature value contributes to pushing the model's output from the **Base Value** (the average prediction: ${LOCAL_EXPLAINER.expected_value:,.2f}) to the final **Forecast** (${st.session_state.last_prediction:,.2f}). * **Red Bars:** Feature values that **increase** the prediction (positive contribution). * **Blue Bars:** Feature values that **decrease** the prediction (negative contribution). * The length of the bar indicates the magnitude (strength) of the feature's influence on the prediction. """) st.session_state.explanation_shown = True st.success("Explanation generated successfully.") except Exception as e: st.error(f"Error during local SHAP calculation.") st.code(f"Error Details: {e}") # --- Footer Guidance (Moved inside the conditional block for prediction) --- if LOCAL_PIPELINE is not None and not st.session_state.explanation_shown: st.markdown("""
Click the **2. Explain Prediction** button to understand the feature contributions.
""", unsafe_allow_html=True) # The following block is kept for when a prediction hasn't been made yet, but the local pipeline is ready. elif LOCAL_PIPELINE is not None and st.session_state.last_prediction is None: st.markdown("""
Click **1. Predict Sales Total (Remote)** to get a forecast first.
""", unsafe_allow_html=True)