Spaces:
Runtime error
Runtime error
Upload app.py with huggingface_hub
Browse files
app.py
ADDED
|
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import joblib
|
| 3 |
+
import json
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from lifetimes import BetaGeoFitter, GammaGammaFitter
|
| 10 |
+
from supervised_explainability import SegmentClassifier
|
| 11 |
+
|
| 12 |
+
# Set dark theme for matplotlib globally
plt.rcParams.update({
    'grid.color': '#1f2937',
    'text.color': '#f3f4f6',
    'axes.labelcolor': '#9ca3af',
    'axes.edgecolor': '#1f2937',
    'xtick.color': '#9ca3af',
    'ytick.color': '#9ca3af',
    'figure.facecolor': '#070a13',
    'axes.facecolor': '#0d111e',
})

# Load the models, dataset and metadata.
# Any failure below is printed and every artifact is replaced by a None/empty
# placeholder; the handlers in this file test those placeholders before use.
try:
    df_baseline = pd.read_parquet('baseline_training_data.parquet')
    # Context manager so the JSON file handle is closed deterministically.
    with open('feature_names.json', encoding='utf-8') as names_file:
        feature_names = json.load(names_file)

    # Load CLV models
    # NOTE: joblib.load unpickles its input; only load artifacts you trust.
    bgf = BetaGeoFitter(penalizer_coef=0.01)
    bgf.params_ = joblib.load('clv_bgf_params.pkl')
    # fit() is never called on this instance, so `predict` is bound by hand.
    # NOTE(review): presumably needed by ggf.customer_lifetime_value(bgf, ...)
    # — confirm against the lifetimes API.
    bgf.predict = bgf.conditional_expected_number_of_purchases_up_to_time

    ggf = GammaGammaFitter(penalizer_coef=0.01)
    ggf.params_ = joblib.load('clv_ggf_params.pkl')

    # Load PCA and GMM
    pca_pipeline = joblib.load('pca_pipeline.pkl')
    gmm_model = joblib.load('gmm_model.pkl')

    # Load XGBoost Segment Classifier
    classifier = SegmentClassifier.load_model('xgb_classifier.pkl')

    # Calculate dataset averages for default mock enrichment values
    # (a fixed constant is used when the column is missing from the baseline).
    defaults = {
        'CPIAUCSL': df_baseline['CPIAUCSL'].mean() if 'CPIAUCSL' in df_baseline.columns else 218.0,
        'RSXFS': df_baseline['RSXFS'].mean() if 'RSXFS' in df_baseline.columns else 350000.0,
        'PSAVERT': df_baseline['PSAVERT'].mean() if 'PSAVERT' in df_baseline.columns else 6.0,
    }
except Exception as e:
    print(f"Error loading models or dataset: {e}")
    df_baseline = None
    # Bound here as well so the name always exists, like the other artifacts.
    feature_names = []
    bgf = ggf = pca_pipeline = gmm_model = classifier = None
    defaults = {}

# Custom cluster names and details, keyed by the integer cluster id that the
# segment classifier predicts. "color" is reused for the HTML card and the PCA marker.
CLUSTER_INFO = {
    0: {"name": "VIP / High-Value Loyalists", "desc": "Premium shoppers with high CLV, high repeat transactions, and balanced session activity.", "color": "#a78bfa"},
    1: {"name": "Inactive / Churned Customers", "desc": "Accounts with very high recency and near-zero active probability. Re-engagement needed.", "color": "#f87171"},
    2: {"name": "Instacart Bulk Grocery Buyers", "desc": "Routine buyers with high order counts and low days between orders. Consistent routine shoppers.", "color": "#34d399"},
    3: {"name": "Window Shoppers (High Carts, Low Purchase)", "desc": "High session views and carts, but low purchase counts and high cart abandonment rates.", "color": "#fbbf24"},
    4: {"name": "Weather-Sensitive Impulsive Buyers", "desc": "High weather-precipitation sensitivity flag. Driven by physical rain/precipitation states.", "color": "#f472b6"},
    5: {"name": "Casual One-Time Retail Buyers", "desc": "Low purchase frequency, high order values on single visits, and stable retail shopping profile.", "color": "#818cf8"},
    6: {"name": "Sensible / Standard Spenders", "desc": "Moderate transactional values, highly deliberate shopping paths, and balanced browsing habits.", "color": "#60a5fa"},
    7: {"name": "High-Intent Cart Abandoners", "desc": "High cart and browse history but high abandoned checkout rate. Re-target with promo vouchers.", "color": "#fb7185"},
    8: {"name": "New / Low-Engagement Registrations", "desc": "Recently created profiles with low transaction history, low sessions, and neutral activity.", "color": "#9ca3af"}
}
|
| 68 |
+
|
| 69 |
+
def make_local_shap_plot(importance_df, segment_name):
    """
    Plot a horizontal bar chart of local SHAP feature importances.

    Args:
        importance_df: DataFrame with 'feature' and 'importance_value' columns.
            Only the first 10 rows are drawn.
            NOTE(review): presumably already sorted by importance — confirm
            against SegmentClassifier.explain_local.
        segment_name: Segment label shown in the chart title.

    Returns:
        The matplotlib Figure, or None when there is nothing to plot
        (importance_df is None or empty).
    """
    if importance_df is None or importance_df.empty:
        return None

    fig, ax = plt.subplots(figsize=(8, 5))

    # barh draws the first row at the bottom, so reverse the top-10 slice to
    # put the first row of importance_df at the top of the chart.
    plot_df = importance_df.head(10).iloc[::-1]

    # Negative contributions in red, non-negative in purple.
    colors = ['#f43f5e' if val < 0 else '#8b5cf6' for val in plot_df['importance_value']]

    ax.barh(plot_df['feature'], plot_df['importance_value'], color=colors, edgecolor='none', height=0.6)

    # Customize grid and spines
    ax.axvline(x=0, color='#ffffff', alpha=0.2, linestyle='-', linewidth=1)
    ax.grid(True, axis='x', linestyle='--', alpha=0.3)
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

    # Title and labels
    ax.set_title(f"SHAP Values driving prediction of:\n{segment_name}", fontweight='bold', pad=15, fontsize=12)
    ax.set_xlabel("SHAP Influence (Direction & Strength)", labelpad=10)

    # Lay out this figure explicitly instead of pyplot's global "current
    # figure", which another concurrent request could have replaced.
    fig.tight_layout()
    return fig
|
| 99 |
+
|
| 100 |
+
def make_pca_scatter_plot(features_df, predicted_cluster, cluster_color):
    """
    Project one customer into PCA space and overlay it on the baseline population.

    Args:
        features_df: Single-row DataFrame of customer features; it must contain
            every numeric, non-id baseline column selected below.
        predicted_cluster: Predicted cluster id. Kept for interface
            compatibility; not used by the current implementation.
        cluster_color: Matplotlib colour used for the customer marker and its glow.

    Returns:
        The matplotlib Figure, or None when the baseline data / PCA pipeline
        is unavailable or drawing fails (the error is printed).
    """
    if df_baseline is None or pca_pipeline is None:
        return None

    fig = None  # tracked so a partially built figure can be released on failure
    try:
        # Get baseline PCs: numeric columns, minus anything whose name contains
        # "id" (case-insensitive) and minus the cluster label column.
        numeric_cols = df_baseline.select_dtypes(include='number').columns
        id_cols = [c for c in numeric_cols if 'id' in c.lower()]
        X_base = df_baseline[numeric_cols].drop(columns=id_cols + ['Segment_Cluster_ID'], errors='ignore').fillna(0)

        # Ensure identical column order
        base_features = X_base.columns.tolist()
        features_ordered = features_df[base_features].fillna(0)

        # Apply the fitted pipeline to both the baseline and the sample.
        # NOTE(review): the baseline projection is recomputed on every call —
        # consider caching it if the baseline is large.
        base_pca = pca_pipeline.transform(X_base)
        sample_pca = pca_pipeline.transform(features_ordered)
        sample_x, sample_y = sample_pca[0, 0], sample_pca[0, 1]

        fig, ax = plt.subplots(figsize=(8, 5))

        # Draw background baseline points as faint, nearly transparent dots
        ax.scatter(
            base_pca[:, 0],
            base_pca[:, 1],
            color='#ffffff',
            s=12,
            alpha=0.06,
            label='Baseline Customers',
            edgecolor='none'
        )

        # Plot the target customer as a glowing star
        ax.scatter(
            sample_x,
            sample_y,
            color=cluster_color,
            marker='*',
            s=450,
            edgecolors='#ffffff',
            linewidths=1.8,
            label='Current Customer',
            zorder=10
        )

        # Add glow effect around the star (larger translucent disc just beneath it)
        ax.scatter(
            sample_x,
            sample_y,
            color=cluster_color,
            marker='o',
            s=1000,
            alpha=0.25,
            edgecolors='none',
            zorder=9
        )

        ax.set_title("Customer Position in 2D PCA Latent Space", fontweight='bold', pad=15)
        ax.set_xlabel("Principal Component 1 (PC1)")
        ax.set_ylabel("Principal Component 2 (PC2)")
        ax.grid(True, linestyle='--', alpha=0.15)
        ax.legend(loc='upper right', framealpha=0.6, facecolor='#0d111e', edgecolor='#ffffff')

        # Lay out this figure explicitly rather than pyplot's global current figure.
        fig.tight_layout()
        return fig
    except Exception as e:
        print(f"Error drawing PCA plot: {e}")
        # Without this, a figure created before the failure would stay
        # registered with pyplot for the life of the process.
        if fig is not None:
            plt.close(fig)
        return None
|
| 170 |
+
|
| 171 |
+
def process_lookup(customer_id):
    """
    Look up a real customer in the baseline dataset and build the dashboard outputs.

    Args:
        customer_id: Customer identifier; compared as a string against the
            'customer_id' column of df_baseline.

    Returns:
        A 6-tuple (rfm_str, ecom_str, clv_str, desc_html, fig_shap, fig_pca),
        matching the six output components this handler is wired to. On any
        failure the first element carries the error message, the next three
        are empty strings and both figures are None.
    """
    def _failure(message):
        # Every error path must return exactly six values, one per output
        # component; the early returns previously produced seven.
        return message, "", "", "", None, None

    if df_baseline is None or classifier is None:
        return _failure("System error: Models or baseline parquet not loaded.")

    try:
        # NOTE(review): assumes the 'customer_id' column holds strings; with a
        # numeric column this comparison would never match — confirm.
        row = df_baseline[df_baseline['customer_id'] == str(customer_id)]
        if row.empty:
            return _failure(f"Customer ID '{customer_id}' not found in the baseline dataset.")

        # Extract features for XGBoost prediction
        # XGBoost expects the raw numeric features in exactly the same columns
        X_sample = row[feature_names].copy()

        # Perform Segment Prediction
        pred_cluster = int(classifier.predict(X_sample)[0])
        cluster_data = CLUSTER_INFO.get(pred_cluster, {"name": f"Segment {pred_cluster}", "desc": "Unknown segment characteristics.", "color": "#9ca3af"})

        # Extract specific metrics
        # NOTE(review): the emoji prefixes in these strings look mis-encoded in
        # this copy of the file; they are kept as found — restore from the original.
        rfm_str = f"๐ Recency: {int(row['Recency'].iloc[0])} days | ๐ Frequency: {int(row['Frequency'].iloc[0])} purchases | ๐ฐ Monetary: ${row['MonetaryValue'].iloc[0]:,.2f}"

        ecom_str = f"๐๏ธ Views: {int(row['total_views'].iloc[0])} | ๐ Carts: {int(row['total_carts'].iloc[0])} | ๐๏ธ Purchases: {int(row['total_purchases'].iloc[0])} | ๐ช Sessions: {int(row['total_sessions'].iloc[0])}"

        clv_str = f"๐ฏ Expected 12m CLV: ${row['expected_12m_clv'].iloc[0]:,.2f} | ๐ Active Prob: {row['expected_active_probability'].iloc[0]*100:.1f}% | ๐ Group: {row['ab_group'].iloc[0]}"

        desc_html = f"""
<div style="padding: 15px; border-left: 5px solid {cluster_data['color']}; background: rgba(255,255,255,0.02); border-radius: 4px;">
<h3 style="margin: 0; color: {cluster_data['color']}; font-size: 1.3em;">{cluster_data['name']}</h3>
<p style="margin: 8px 0 0 0; color: #d1d5db; line-height: 1.5;">{cluster_data['desc']}</p>
</div>
"""

        # Generate explanations & PCA scatter
        importance_df = classifier.explain_local(X_sample, pred_cluster)
        fig_shap = make_local_shap_plot(importance_df, cluster_data['name'])
        fig_pca = make_pca_scatter_plot(X_sample, pred_cluster, cluster_data['color'])

        return rfm_str, ecom_str, clv_str, desc_html, fig_shap, fig_pca
    except Exception as e:
        # Top-level handler boundary: log the traceback and surface the
        # message in the UI instead of crashing the request.
        import traceback
        traceback.print_exc()
        return _failure(f"Error running lookup inference: {e}")
|
| 215 |
+
|
| 216 |
+
def _first_as_float(value):
    """Return the first element of a one-item result as a float.

    Handles objects exposing ``.iloc``, objects exposing ``.item()``, and
    plain indexable sequences, in that order.
    """
    if hasattr(value, "iloc"):
        return float(value.iloc[0])
    if hasattr(value, "item"):
        return float(value.item())
    return float(value[0])


def process_simulation(recency, frequency, monetary, views, carts, purchases, sessions, abandon, instacart_orders, days_between_orders, precip):
    """
    Simulate a customer from UI inputs and build the dashboard outputs.

    Computes the lifetimes CLV predictions, assembles the feature row, predicts
    the segment with the already-loaded classifier, and generates the SHAP and
    PCA plots.

    Args:
        recency, frequency, monetary: Transactional RFM inputs.
        views, carts, purchases, sessions: eCommerce activity counts.
        abandon: "Yes"/"No" abandoned-checkout flag.
        instacart_orders, days_between_orders: Grocery habit inputs.
        precip: "Yes"/"No" purchased-during-precipitation flag.

    Returns:
        A 6-tuple (rfm_str, ecom_str, clv_str, desc_html, fig_shap, fig_pca).
        On failure the first element carries the error message, the next three
        are empty strings and both figures are None.
    """
    if classifier is None or bgf is None or ggf is None:
        return "System error: Predictors are not loaded.", "", "", "", None, None

    try:
        # Construct raw inputs
        precip_flag = 1 if precip == "Yes" else 0
        abandon_flag = 1 if abandon == "Yes" else 0
        # NOTE(review): T is derived as recency + 30; the rationale for the
        # 30 offset is not visible in this file — confirm.
        T = recency + 30

        # Calculate lifetimes CLV predictions
        if frequency == 0:
            # No repeat purchases: the models are skipped and both outputs are zero.
            active_prob = 0.0
            expected_clv = 0.0
        else:
            freq_s = pd.Series([frequency])
            rec_s = pd.Series([recency])
            T_s = pd.Series([T])
            mon_s = pd.Series([monetary])

            active_prob = _first_as_float(
                bgf.conditional_probability_alive(freq_s, rec_s, T_s)
            )
            expected_clv = _first_as_float(
                ggf.customer_lifetime_value(bgf, freq_s, rec_s, T_s, mon_s, time=12, discount_rate=0.01)
            )

        # Build features dict using defaults for macro factors
        input_data = {
            'Recency': recency,
            'Frequency': frequency,
            'MonetaryValue': monetary,
            'total_views': views,
            'total_carts': carts,
            'total_purchases': purchases,
            'total_ecommerce_spend': purchases * monetary,
            'total_sessions': sessions,
            'checkout_abandonment_flag': abandon_flag,
            'instacart_total_orders': instacart_orders,
            'avg_days_between_orders': days_between_orders,
            'purchased_during_precipitation': precip_flag,
            'T': T,
            'expected_active_probability': active_prob,
            'expected_12m_clv': expected_clv,
            'CPIAUCSL': defaults.get('CPIAUCSL', 218.0),
            'RSXFS': defaults.get('RSXFS', 350000.0),
            'PSAVERT': defaults.get('PSAVERT', 6.0)
        }

        X_sample = pd.DataFrame([input_data])
        X_sample = X_sample[feature_names]  # Align columns perfectly

        # Predict Segment
        pred_cluster = int(classifier.predict(X_sample)[0])
        cluster_data = CLUSTER_INFO.get(pred_cluster, {"name": f"Segment {pred_cluster}", "desc": "Unknown segment characteristics.", "color": "#9ca3af"})

        # Outputs
        # NOTE(review): the emoji prefixes in these strings look mis-encoded in
        # this copy of the file; they are kept as found — restore from the original.
        rfm_str = f"๐ Recency: {int(recency)} days | ๐ Frequency: {int(frequency)} purchases | ๐ฐ Monetary: ${monetary:,.2f}"
        ecom_str = f"๐๏ธ Views: {int(views)} | ๐ Carts: {int(carts)} | ๐๏ธ Purchases: {int(purchases)} | ๐ช Sessions: {int(sessions)}"
        clv_str = f"๐ฏ Simulated 12m CLV: ${expected_clv:,.2f} | ๐ Active Prob: {active_prob*100:.1f}% | ๐ Group: Simulation"

        desc_html = f"""
<div style="padding: 15px; border-left: 5px solid {cluster_data['color']}; background: rgba(255,255,255,0.02); border-radius: 4px;">
<h3 style="margin: 0; color: {cluster_data['color']}; font-size: 1.3em;">{cluster_data['name']}</h3>
<p style="margin: 8px 0 0 0; color: #d1d5db; line-height: 1.5;">{cluster_data['desc']}</p>
</div>
"""

        # Plots
        importance_df = classifier.explain_local(X_sample, pred_cluster)
        fig_shap = make_local_shap_plot(importance_df, cluster_data['name'])
        fig_pca = make_pca_scatter_plot(X_sample, pred_cluster, cluster_data['color'])

        return rfm_str, ecom_str, clv_str, desc_html, fig_shap, fig_pca
    except Exception as e:
        # Top-level handler boundary: log the traceback and surface the
        # message in the UI instead of crashing the request.
        import traceback
        traceback.print_exc()
        return f"Error running simulation: {e}", "", "", "", None, None
|
| 306 |
+
|
| 307 |
+
# Dropdown choices: the first 100 customer IDs from the baseline dataset,
# or an empty list when the baseline could not be loaded.
sample_ids = (
    [] if df_baseline is None
    else df_baseline['customer_id'].head(100).tolist()
)
|
| 312 |
+
|
| 313 |
+
# Custom CSS for gorgeous aesthetics and dark-mode styling.
# Handed to demo.launch(css=...) at the bottom of the file. The rules cover the
# page font/background, the gradient <h1>, primary buttons, and a `.stats-box`
# card class (no component visible in this file assigns that class).
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700&display=swap');

body, .gradio-container {
font-family: 'Outfit', sans-serif !important;
background: radial-gradient(circle at 10% 20%, hsla(230, 45%, 11%, 1) 0%, hsla(240, 60%, 4%, 1) 90.1%) !important;
color: #f3f4f6 !important;
}

.prose h1 {
font-weight: 700 !important;
background: linear-gradient(135deg, #c084fc, #6366f1) !important;
-webkit-background-clip: text !important;
-webkit-text-fill-color: transparent !important;
text-shadow: 0 4px 20px rgba(99, 102, 241, 0.15) !important;
text-align: center !important;
}

.gradio-container button.primary {
background: linear-gradient(135deg, #8b5cf6, #4f46e5) !important;
border: none !important;
border-radius: 8px !important;
font-weight: 600 !important;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1) !important;
box-shadow: 0 4px 15px rgba(99, 102, 241, 0.3) !important;
}

.gradio-container button.primary:hover {
transform: translateY(-2px) !important;
box-shadow: 0 8px 25px rgba(99, 102, 241, 0.5) !important;
}

/* Custom cards for stats */
.stats-box {
background: rgba(255, 255, 255, 0.02) !important;
border: 1px solid rgba(255, 255, 255, 0.06) !important;
border-radius: 12px !important;
padding: 15px !important;
transition: all 0.3s ease !important;
}

.stats-box:hover {
border-color: rgba(167, 139, 250, 0.25) !important;
background: rgba(255, 255, 255, 0.03) !important;
}
"""
|
| 360 |
+
|
| 361 |
+
# Gradio UI: three tabs (customer lookup, segment simulator, about page).
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been flattened — confirm the layout against the original.
# NOTE(review): several labels mention "Random Forest", while the loading code
# and plot titles refer to an XGBoost classifier with SHAP values — confirm
# which wording is correct.
with gr.Blocks() as demo:
    gr.HTML("<h1 style='font-size: 2.8em; margin-bottom: 5px; text-align: center;'>๐๏ธ Real-Time Customer Segmentation Dashboard</h1>")
    gr.HTML("<p style='text-align: center; color: #9ca3af; font-size: 1.1em; margin-bottom: 30px;'>Powered by PCA, BIC-Gaussian Mixture Soft Clustering, and Random Forest Classifiers.</p>")

    with gr.Tabs():
        # TAB 1: Real Customer Batch Explorer
        with gr.Tab("๐ Customer Batch Explorer & Lookup"):
            gr.HTML("<p style='color: #a78bfa; font-weight: 500; margin-bottom: 15px;'>Select an active Customer ID from the real ingested cohort (~5.5M customers baseline) to view behavioral metrics and predictions.</p>")

            with gr.Row():
                with gr.Column(scale=1):
                    # Choices are limited to the sample_ids list built above.
                    id_dropdown = gr.Dropdown(
                        choices=sample_ids,
                        value=sample_ids[0] if sample_ids else None,
                        label="Select Customer ID",
                        interactive=True
                    )

                    gr.HTML("<div style='margin-top: 20px;'><h4 style='color: #9ca3af; margin-bottom: 10px;'>๐ Predictor Outputs</h4></div>")
                    lookup_rfm = gr.Textbox(label="Transactional RFM Profile", interactive=False)
                    lookup_ecom = gr.Textbox(label="eCommerce Activity Metrics", interactive=False)
                    lookup_clv = gr.Textbox(label="CLV & A/B Testing Routing", interactive=False)

                    gr.HTML("<div style='margin-top: 20px;'><h4 style='color: #9ca3af; margin-bottom: 10px;'>๐ท๏ธ Segment Classification</h4></div>")
                    segment_desc = gr.HTML(label="Predicted Segment Description")

                    lookup_btn = gr.Button("Query Customer", variant="primary")

                with gr.Column(scale=2):
                    with gr.Row():
                        plot_shap = gr.Plot(label="Local Feature Contribution (Random Forest Gini)")
                        plot_pca = gr.Plot(label="Placement in PCA Latent Space")

            # process_lookup must return one value per component listed in `outputs` (six).
            lookup_btn.click(
                fn=process_lookup,
                inputs=[id_dropdown],
                outputs=[lookup_rfm, lookup_ecom, lookup_clv, segment_desc, plot_shap, plot_pca]
            )

        # TAB 2: Interactive Simulator
        with gr.Tab("๐งช Segment Simulator & Scenario Sandbox"):
            gr.HTML("<p style='color: #a78bfa; font-weight: 500; margin-bottom: 15px;'>Simulate a custom consumer's behaviour by moving sliders across Transactional, eCommerce, and routine grocery indicators.</p>")

            with gr.Row():
                with gr.Column(scale=1, variant="panel"):
                    gr.HTML("<h3 style='color: #c084fc; margin-top: 0;'>1. Transactional Metrics</h3>")
                    sim_recency = gr.Slider(minimum=1, maximum=365, value=45, label="Recency (Days since last purchase)")
                    sim_frequency = gr.Slider(minimum=0, maximum=50, value=5, label="Frequency (Number of repeat transactions)")
                    sim_monetary = gr.Slider(minimum=5, maximum=1000, value=75, label="Monetary Value ($ average per order)")

                    gr.HTML("<h3 style='color: #c084fc; margin-top: 15px;'>2. eCommerce Activity</h3>")
                    sim_views = gr.Slider(minimum=0, maximum=1000, value=120, label="Total Session Views")
                    sim_carts = gr.Slider(minimum=0, maximum=100, value=12, label="Total Items Added to Cart")
                    sim_purchases = gr.Slider(minimum=0, maximum=50, value=6, label="Total Completed eCommerce Purchases")
                    sim_sessions = gr.Slider(minimum=1, maximum=200, value=15, label="Total Session Volume")
                    # process_simulation maps the "Yes"/"No" choice to a 1/0 flag.
                    sim_abandon = gr.Dropdown(choices=["No", "Yes"], value="No", label="Abandoned Checkout Flag")

                    gr.HTML("<h3 style='color: #c084fc; margin-top: 15px;'>3. Instacart / Grocery Habits</h3>")
                    sim_instacart = gr.Slider(minimum=0, maximum=100, value=10, label="Instacart Total Order Volume")
                    sim_days = gr.Slider(minimum=0, maximum=30, value=7, label="Average Days Between Grocery Orders")

                    gr.HTML("<h3 style='color: #c084fc; margin-top: 15px;'>4. Environmental Context</h3>")
                    sim_precip = gr.Dropdown(choices=["No", "Yes"], value="No", label="Purchased During Precipitation (Rain/Snow)")

                    sim_btn = gr.Button("Simulate Customer Segment", variant="primary")

                with gr.Column(scale=2):
                    gr.HTML("<h3 style='color: #a78bfa; margin-top: 0;'>๐ฎ Simulation Predictions</h3>")
                    sim_out_rfm = gr.Textbox(label="Transactional RFM Profile", interactive=False)
                    sim_out_ecom = gr.Textbox(label="eCommerce Activity Metrics", interactive=False)
                    sim_out_clv = gr.Textbox(label="CLV & A/B Testing Routing", interactive=False)

                    sim_out_desc = gr.HTML(label="Predicted Segment Description")

                    with gr.Row():
                        sim_plot_shap = gr.Plot(label="Local Feature Contribution (Random Forest Gini)")
                    with gr.Row():
                        sim_plot_pca = gr.Plot(label="Placement in PCA Latent Space")

            # Input order must match process_simulation's positional parameters.
            sim_btn.click(
                fn=process_simulation,
                inputs=[
                    sim_recency, sim_frequency, sim_monetary,
                    sim_views, sim_carts, sim_purchases, sim_sessions, sim_abandon,
                    sim_instacart, sim_days, sim_precip
                ],
                outputs=[sim_out_rfm, sim_out_ecom, sim_out_clv, sim_out_desc, sim_plot_shap, sim_plot_pca]
            )

        # TAB 3: About & Metrics
        with gr.Tab("โน๏ธ About & Metrics Explanation"):
            gr.Markdown("""
### ๐ Project Overview
This application predicts Customer Segments and their Lifetime Value (CLV) using several machine learning and statistical models on real-world grocery, eCommerce, and mall purchase data.

We use **PCA** for Dimensionality Reduction, **Gaussian Mixture Models (GMM)** with BIC optimization for soft clustering, **BG/NBD** and **Gamma-Gamma** models for CLV, and **Random Forest** for segment prediction and feature importance explainability.

---
### ๐ Explanation of Metrics

#### 1. Transactional RFM Metrics
* **Recency:** The number of days since the customer's last purchase.
* **Frequency:** The total number of repeat transactions made by the customer.
* **Monetary Value:** The average amount spent per order.

#### 2. eCommerce Activity
* **Views:** Total number of product pages visited.
* **Carts:** Total number of items added to the shopping cart.
* **Purchases:** Successful eCommerce transactions completed.
* **Sessions:** Number of independent browsing sessions.
* **Abandoned Checkout Flag:** Whether the user has an abnormally high ratio of adding items to the cart without completing the purchase.

#### 3. Instacart / Grocery Habits
* **Total Orders:** Number of grocery orders placed on Instacart.
* **Days Between Orders:** The average gap (in days) between successive grocery runs.

#### 4. Macro-Economic & Environmental Context
* **Precipitation Flag:** Did the customer make purchases specifically during rain or snow? (Sourced from Open-Meteo API).
* **CPIAUCSL (Inflation):** Consumer Price Index for All Urban Consumers (Sourced from FRED API).
* **RSXFS (Retail Sales):** Advance Retail Sales: Retail Trade (Sourced from FRED API).
* **PSAVERT (Savings Rate):** Personal Saving Rate (Sourced from FRED API).

#### 5. Output Predictions
* **Expected 12m CLV:** The predicted monetary value the customer will generate for the business over the next 12 months.
* **Active Probability:** The mathematical likelihood that the customer hasn't permanently churned given their recency and frequency trends (from the BG/NBD model).
* **A/B Testing Routing:** Shows whether this user should be in a control or variant group for targeted marketing campaigns based on their risk profile.
""")

    # Load initial default values on app launch: run the lookup once for the
    # preselected dropdown value, skipped when no customer IDs are available.
    if sample_ids:
        demo.load(
            fn=process_lookup,
            inputs=[id_dropdown],
            outputs=[lookup_rfm, lookup_ecom, lookup_clv, segment_desc, plot_shap, plot_pca]
        )
|
| 496 |
+
|
| 497 |
+
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # NOTE(review): theme/css are passed to launch() here rather than to
    # gr.Blocks(); confirm the installed Gradio release accepts them as
    # launch() arguments.
    demo.launch(theme=gr.themes.Soft(), css=custom_css)
|