# Hugging Face Spaces page header captured with this listing (Space status: Sleeping).
| import streamlit as st | |
| import pandas as pd | |
| import joblib | |
| import shap | |
| import matplotlib.pyplot as plt | |
# Load the trained pipeline and label encoder once per server process.
# Streamlit re-executes this entire script on every widget interaction, so
# without caching both joblib files would be re-read from disk on each rerun.
@st.cache_resource
def load_artifacts(
    pipeline_path: str = 'xgb_pipeline.joblib',
    encoder_path: str = 'label_encoder.joblib',
):
    """Load and return ``(pipeline, label_encoder)`` from joblib files.

    Args:
        pipeline_path: Path to the fitted scikit-learn Pipeline
            (preprocessing steps followed by the XGBoost classifier).
        encoder_path: Path to the fitted LabelEncoder used to map encoded
            predictions back to milk-type labels.

    Returns:
        A ``(pipeline, label_encoder)`` tuple.
    """
    # NOTE: joblib.load unpickles its input -- only load trusted files.
    return joblib.load(pipeline_path), joblib.load(encoder_path)


pipeline, le = load_artifacts()
st.title('Cheese Milk Type Classifier with SHAP Explanation')

# ===== User Input =====
# Choices offered by each selectbox, in display order.
COUNTRY_OPTIONS = [
    'France', 'Italy', 'Switzerland', 'United States', 'Germany',
    'Spain', 'England', 'Australia', 'Canada', 'Greece',
]
TEXTURE_OPTIONS = [
    'creamy', 'firm', 'crumbly', 'smooth', 'hard', 'soft', 'dense',
    'elastic', 'dry', 'buttery', 'mild', 'open',
]
RIND_OPTIONS = [
    'natural', 'washed', 'none', 'ash coated', 'leaf wrapped', 'Unknown',
]
# Kept as strings (not bools): the widget returns exactly what is listed here.
VEGETARIAN_OPTIONS = ['True', 'False']
FLAVOR_OPTIONS = [
    'nutty', 'mild', 'sweet', 'earthy', 'fruity', 'acidic', 'buttery',
    'smokey', 'full-flavored', 'salty', 'creamy', 'sharp', 'bitter', 'Unknown',
]

country = st.selectbox('Country', COUNTRY_OPTIONS)
texture = st.selectbox('Texture', TEXTURE_OPTIONS)
rind = st.selectbox('Rind', RIND_OPTIONS)
vegetarian = st.selectbox('Vegetarian Status', VEGETARIAN_OPTIONS)
flavor = st.selectbox('Flavor', FLAVOR_OPTIONS)

# Single-row frame holding the user's selections. The column names presumably
# must match the training features expected by the pipeline -- verify if the
# model is retrained.
user_row = {
    'country': country,
    'texture': texture,
    'rind': rind,
    'vegetarian': vegetarian,
    'flavor': flavor,
}
input_df = pd.DataFrame([user_row])
# ===== Prediction and SHAP =====
if st.button('Predict', key='predict_button'):
    # Step 1: predict and decode the label.
    pred_encoded = pipeline.predict(input_df)[0]
    pred_label = le.inverse_transform([pred_encoded])[0]
    st.write(f'Predicted Milk Type: **{pred_label}**')

    # Step 2: SHAP explanation.
    # The last pipeline step is the estimator that produced the prediction
    # above; everything before it is preprocessing. Positional indexing for
    # both keeps the two lookups consistent (no reliance on a step name).
    model = pipeline[-1]
    preprocess = pipeline[:-1]
    input_transformed = preprocess.transform(input_df)

    # Label SHAP features with the transformed (e.g. one-hot) column names
    # when the preprocessing can report them; otherwise SHAP falls back to
    # generic "Feature N" labels.
    try:
        feature_names = list(preprocess.get_feature_names_out())
    except (AttributeError, ValueError):
        feature_names = None

    explainer = shap.Explainer(model, feature_names=feature_names)
    shap_values = explainer(input_transformed)

    # A multiclass model yields (n_samples, n_features, n_classes), so one
    # row is still 2-D; the bar plot needs one value per feature, so explain
    # the predicted class. Assumes the model's class indices are the
    # LabelEncoder codes 0..K-1 (as required by XGBClassifier) -- confirm if
    # the training setup changes.
    row_explanation = shap_values[0]
    if row_explanation.values.ndim == 2:
        row_explanation = row_explanation[:, int(pred_encoded)]

    # Plot SHAP values (bar plot for the top 10 features).
    st.subheader('Feature Contribution (SHAP Values)')
    # shap.plots.bar draws on the current axes, which plt.subplots() sets.
    fig, ax = plt.subplots()
    try:
        shap.plots.bar(row_explanation, max_display=10, show=False)
        st.pyplot(fig)
    finally:
        # Every rerun creates a new figure; close it so pyplot does not
        # keep them all alive.
        plt.close(fig)