customer / app.py
entropy25's picture
Update app.py
e26889f verified
# -*- coding: utf-8 -*-
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_auc_score, precision_recall_curve
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
from datetime import datetime, timedelta
import io
import base64
import warnings
from typing import Optional, Tuple, Dict, Any
warnings.filterwarnings('ignore')
# Try importing optional dependencies
try:
import xgboost as xgb
XGBOOST_AVAILABLE = True
except ImportError:
XGBOOST_AVAILABLE = False
try:
from reportlab.lib.pagesizes import letter, A4
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from reportlab.lib import colors
REPORTLAB_AVAILABLE = True
except ImportError:
REPORTLAB_AVAILABLE = False
# Business configuration: thresholds read by the segmentation, churn-model,
# visualization and recommendation code in this file.
BUSINESS_CONFIG = {
# A customer whose recency_days exceeds this value is labelled as churned
# when the churn model's training target is built.
'churn_threshold_days': 90,
# Churn-probability cut-off: sets the predicted_churn flag, positions the
# threshold line on the probability histogram, and triggers the "URGENT"
# recommendation.
'high_risk_probability': 0.7,
# Number of bins requested from pd.qcut / pd.cut for the R/F/M scores, and
# the minimum customer count needed before scoring is attempted.
# NOTE(review): the score label lists are hard-coded with 5 entries, so a
# value other than 5 looks incompatible with them - confirm before changing.
'rfm_quantiles': 5,
# ChurnPredictor.train_model refuses to train with fewer customers than this.
'min_customers_for_model': 10
}
# UI color scheme: named hex colors used by the Plotly charts. Individual keys
# color specific traces (e.g. 'primary' for lines, 'danger'/'warning'/'success'
# for risk levels); the full value list, in this order, is also used as the
# discrete color sequence for the segment pie and RFM scatter charts.
COLORS = {
'primary': '#6366f1',
'success': '#10b981',
'warning': '#f59e0b',
'danger': '#ef4444',
'purple': '#8b5cf6',
'pink': '#ec4899',
'blue': '#3b82f6',
# Same hex value as 'primary'.
'indigo': '#6366f1'
}
class DataProcessor:
    """Handles data loading, validation, and preprocessing.

    All methods are static; the class only groups the CSV ingestion steps.
    """

    @staticmethod
    def load_and_validate(file) -> Tuple[Optional[pd.DataFrame], str]:
        """Load a CSV upload and normalize it to the standard schema.

        Args:
            file: A filesystem path (``str`` or path-like) or an object that
                exposes the path through a ``name`` attribute. ``None`` means
                nothing was uploaded.

        Returns:
            ``(df, status)`` on success, where ``df`` contains the columns
            ``customer_id``, ``order_date`` and ``amount``. On any failure
            (no file, unreadable CSV, required columns missing, no valid rows
            left) the result is ``(None, message)``; this method never raises.
        """
        import os

        if file is None:
            return None, "Please upload a CSV file"
        try:
            # Accept a plain path as well as a wrapper object carrying the
            # path in ``.name``. Path-like objects are checked explicitly
            # because ``pathlib.Path.name`` is only the base name.
            csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
            df = pd.read_csv(csv_path)

            # Flexible column mapping
            column_mapping = DataProcessor._map_columns(df.columns)
            if not column_mapping:
                return None, f"Required columns not found. Available: {list(df.columns)}"
            df = df.rename(columns=column_mapping)

            # Clean and validate data
            initial_rows = len(df)
            df = DataProcessor._clean_data(df)
            final_rows = len(df)
            if final_rows == 0:
                return None, "No valid data after cleaning"

            status = f"Data loaded successfully! {final_rows} records from {df['customer_id'].nunique()} customers"
            if initial_rows != final_rows:
                status += f" ({initial_rows - final_rows} invalid rows removed)"
            return df, status
        except Exception as e:
            # Boundary of the upload workflow: report the problem to the UI
            # instead of propagating it.
            return None, f"Error loading data: {str(e)}"

    @staticmethod
    def _map_columns(columns) -> Dict[str, str]:
        """Map CSV columns to the standard names.

        For each required field the best available column is chosen, trying
        in order: an exact match on the standard name, an exact match on a
        known alias, and finally a column that merely contains an alias.
        A column is never assigned to more than one field.

        Returns:
            ``{original_column: standard_name}`` covering all three required
            fields, or an empty dict if any field cannot be matched.
        """
        required = ['customer_id', 'order_date', 'amount']
        column_variations = {
            'customer_id': ['customer', 'cust_id', 'id', 'customerid', 'client_id', 'customer_id'],
            'order_date': ['date', 'order_date', 'orderdate', 'purchase_date', 'transaction_date'],
            'amount': ['revenue', 'value', 'price', 'total', 'sales', 'order_value', 'amount']
        }
        # str() guards against non-string labels (e.g. integer headers)
        normalized = [(col, str(col).lower().strip()) for col in columns]

        mapping: Dict[str, str] = {}
        for req_col in required:
            aliases = column_variations[req_col]
            available = [(col, name) for col, name in normalized if col not in mapping]
            # Strongest evidence first, so that e.g. 'order_id' (contains
            # 'id') can no longer shadow a real 'customer_id' column.
            exact = [col for col, name in available if name == req_col]
            alias_exact = [col for col, name in available if name in aliases]
            alias_partial = [
                col for col, name in available
                if any(alias in name for alias in aliases)
            ]
            candidates = exact + alias_exact + alias_partial
            if not candidates:
                return {}
            mapping[candidates[0]] = req_col
        return mapping

    @staticmethod
    def _clean_data(df: pd.DataFrame) -> pd.DataFrame:
        """Return a cleaned copy of ``df`` containing only valid order rows.

        A row is kept when it has a customer id, a parseable order date and
        a strictly positive numeric amount. The input frame is not modified.
        """
        df = df.copy()
        df['order_date'] = pd.to_datetime(df['order_date'], errors='coerce')
        df['amount'] = pd.to_numeric(df['amount'], errors='coerce')

        # Missing ids must be detected BEFORE the string cast: astype(str)
        # would turn NaN into the literal 'nan' and keep those rows.
        valid = (
            df['customer_id'].notna()
            & df['order_date'].notna()
            & df['amount'].notna()
            & (df['amount'] > 0)  # drops zero and negative amounts
        )
        df = df.loc[valid].copy()
        df['customer_id'] = df['customer_id'].astype(str)
        return df
class RFMAnalyzer:
    """Derives per-customer Recency / Frequency / Monetary metrics."""

    @staticmethod
    def calculate_rfm_metrics(df: pd.DataFrame) -> pd.DataFrame:
        """Aggregate order rows into one metrics row per customer.

        Expects ``customer_id``, ``order_date`` and ``amount`` columns.
        The result has ``customer_id`` as a regular column followed by the
        order-date and amount aggregates, ``recency_days`` and
        ``customer_lifetime_days``.
        """
        # Recency is measured from the day after the newest order in the data
        reference_date = df['order_date'].max() + timedelta(days=1)

        per_customer = df.groupby('customer_id').agg(
            last_order_date=('order_date', 'max'),
            frequency=('order_date', 'count'),
            first_order_date=('order_date', 'min'),
            monetary=('amount', 'sum'),
            avg_order_value=('amount', 'mean'),
            std_amount=('amount', 'std'),
            min_amount=('amount', 'min'),
            max_amount=('amount', 'max'),
        )

        since_last_order = reference_date - per_customer['last_order_date']
        active_span = per_customer['last_order_date'] - per_customer['first_order_date']
        per_customer['recency_days'] = since_last_order.dt.days
        per_customer['customer_lifetime_days'] = active_span.dt.days

        # The standard deviation is undefined for single-order customers
        for column in ('std_amount', 'customer_lifetime_days'):
            per_customer[column] = per_customer[column].fillna(0)

        return per_customer.reset_index()
class CustomerSegmenter:
    """Handles customer segmentation based on RFM analysis."""

    # Segments that map to an elevated churn-risk tier; all others are 'Low'.
    _HIGH_RISK_SEGMENTS = frozenset({'Lost Customers', 'At Risk'})
    _MEDIUM_RISK_SEGMENTS = frozenset({'Others', 'Cannot Lose Them'})
    # Score used when a customer cannot be ranked.
    _NEUTRAL_SCORE = 3

    @staticmethod
    def perform_segmentation(customer_metrics: pd.DataFrame) -> pd.DataFrame:
        """Segment customers using RFM scores.

        Returns a copy of ``customer_metrics`` with integer ``R_Score``,
        ``F_Score`` and ``M_Score`` columns (1-5, higher is better) plus the
        derived ``Segment`` and ``Churn_Risk`` labels. With fewer customers
        than the configured number of quantiles every score is the neutral 3.
        """
        df = customer_metrics.copy()
        n_bins = BUSINESS_CONFIG['rfm_quantiles']

        if len(df) >= n_bins:
            # Each metric is scored on its own, so heavy ties in one metric
            # (typically frequency) no longer force the other two onto the
            # coarser equal-width fallback.
            # Recency is inverted: fewer days since the last order is better.
            df['R_Score'] = CustomerSegmenter._score_metric(df['recency_days'], n_bins, [5, 4, 3, 2, 1])
            df['F_Score'] = CustomerSegmenter._score_metric(df['frequency'], n_bins, [1, 2, 3, 4, 5])
            df['M_Score'] = CustomerSegmenter._score_metric(df['monetary'], n_bins, [1, 2, 3, 4, 5])
        else:
            for col in ('R_Score', 'F_Score', 'M_Score'):
                df[col] = CustomerSegmenter._NEUTRAL_SCORE

        # Plain iteration instead of DataFrame.apply(axis=1): it also behaves
        # on an empty frame and avoids building a Series per row.
        df['Segment'] = [
            CustomerSegmenter._segment_for_scores(r, f, m)
            for r, f, m in zip(df['R_Score'], df['F_Score'], df['M_Score'])
        ]
        df['Churn_Risk'] = [
            CustomerSegmenter._risk_for_segment(segment) for segment in df['Segment']
        ]
        return df

    @staticmethod
    def _score_metric(values: pd.Series, n_bins: int, labels) -> pd.Series:
        """Bin one metric into integer scores.

        Quantile bins are preferred. When tied values collapse quantile
        edges (so ``labels`` no longer fits and pandas raises ValueError),
        equal-width bins are used for this metric instead. Values that end
        up unbinned receive the neutral score.
        """
        try:
            binned = pd.qcut(values, n_bins, labels=labels, duplicates='drop')
        except ValueError:
            binned = pd.cut(values, bins=n_bins, labels=labels, include_lowest=True)
        numeric = pd.to_numeric(binned, errors='coerce')
        return numeric.fillna(CustomerSegmenter._NEUTRAL_SCORE).astype(int)

    @staticmethod
    def _segment_for_scores(r, f, m) -> str:
        """Return the segment name for one (recency, frequency, monetary) score triple."""
        if r >= 4 and f >= 4 and m >= 4:
            return 'Champions'
        if r >= 3 and f >= 3 and m >= 3:
            return 'Loyal Customers'
        if r >= 3 and f >= 2:
            return 'Potential Loyalists'
        if r >= 4 and f <= 2:
            return 'New Customers'
        if r <= 2 and f >= 3:
            return 'At Risk'
        if r <= 2 and f <= 2 and m >= 3:
            return 'Cannot Lose Them'
        if r <= 2 and f <= 2 and m <= 2:
            return 'Lost Customers'
        return 'Others'

    @staticmethod
    def _risk_for_segment(segment: str) -> str:
        """Translate a segment name into a churn-risk tier."""
        if segment in CustomerSegmenter._HIGH_RISK_SEGMENTS:
            return 'High'
        if segment in CustomerSegmenter._MEDIUM_RISK_SEGMENTS:
            return 'Medium'
        return 'Low'

    @staticmethod
    def _assign_segment(row) -> str:
        """Assign customer segment based on the row's RFM scores.

        ``row`` is any mapping-like object with ``R_Score``, ``F_Score`` and
        ``M_Score`` entries.
        """
        return CustomerSegmenter._segment_for_scores(
            row['R_Score'], row['F_Score'], row['M_Score']
        )

    @staticmethod
    def _assign_risk_level(row) -> str:
        """Assign churn risk level ('High', 'Medium' or 'Low') from the row's RFM scores."""
        return CustomerSegmenter._risk_for_segment(CustomerSegmenter._assign_segment(row))
class ChurnPredictor:
    """Handles churn prediction model training and inference."""

    # Per-customer metric columns fed to the classifier; shared by
    # train_model() and predict() so the two can never drift apart.
    FEATURE_COLUMNS = [
        'recency_days', 'frequency', 'monetary', 'avg_order_value',
        'std_amount', 'min_amount', 'max_amount', 'customer_lifetime_days'
    ]
    # Upper bound on the number of cross-validation folds.
    MAX_CV_FOLDS = 5

    def __init__(self):
        self.model = None  # fitted classifier, or None until training succeeds
        self.feature_importance = None  # DataFrame(feature, importance) after training
        self.model_metrics = {}  # evaluation summary of the last successful training

    @staticmethod
    def _cv_folds(y: pd.Series, max_folds: int = 5) -> int:
        """Return the largest stratified fold count ``y`` supports.

        Stratified folds need at least one sample of every class per fold,
        so the count is capped by the size of the rarest class.
        """
        class_counts = y.value_counts()
        if class_counts.empty:
            return 0
        return min(max_folds, int(class_counts.min()))

    def train_model(self, customer_metrics: pd.DataFrame) -> Tuple[bool, str, Dict]:
        """Train churn prediction model.

        Returns:
            ``(success, message, metrics)``. On failure ``metrics`` is empty
            and the predictor's previous state is left untouched.
        """
        min_customers = BUSINESS_CONFIG['min_customers_for_model']
        if len(customer_metrics) < min_customers:
            return False, f"Insufficient data for training (minimum {min_customers} customers required)", {}

        feature_cols = self.FEATURE_COLUMNS
        X = customer_metrics[feature_cols]
        # NOTE(review): the label is a pure function of 'recency_days', which
        # is also a feature, so the model can reproduce the threshold and the
        # reported scores will look near-perfect. Confirm whether the label
        # should come from a later observation window instead.
        y = (customer_metrics['recency_days'] > BUSINESS_CONFIG['churn_threshold_days']).astype(int)

        # Check for sufficient class diversity
        if y.nunique() < 2:
            return False, "Cannot train model: all customers have the same churn status", {}
        cv_folds = self._cv_folds(y, self.MAX_CV_FOLDS)
        if cv_folds < 2:
            # A stratified split cannot place a single-member class in both
            # the training and the test partition.
            return False, "Cannot train model: each churn class needs at least 2 customers", {}

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        # Select model: XGBoost when usable, Random Forest otherwise
        model = None
        model_name = "Random Forest Classifier"
        if XGBOOST_AVAILABLE:
            try:
                model = xgb.XGBClassifier(random_state=42, eval_metric='logloss')
                model_name = "XGBoost Classifier"
            except Exception:
                model = None
        if model is None:
            model = RandomForestClassifier(random_state=42, n_estimators=100)
        model.fit(X_train, y_train)

        # Evaluate on the held-out split
        y_pred = model.predict(X_test)
        y_pred_proba = model.predict_proba(X_test)[:, 1]
        accuracy = accuracy_score(y_test, y_pred)
        # ROC AUC is undefined when a tiny test split contains one class only
        if y_test.nunique() > 1:
            auc_score = roc_auc_score(y_test, y_pred_proba)
        else:
            auc_score = float('nan')

        # Cross-validation with a fold count the rarest class can support
        cv_scores = cross_val_score(model, X, y, cv=cv_folds, scoring='roc_auc')

        # Publish results only after every step above has succeeded
        self.model = model
        self.feature_importance = pd.DataFrame({
            'feature': feature_cols,
            'importance': model.feature_importances_
        }).sort_values('importance', ascending=False)
        self.model_metrics = {
            'accuracy': accuracy,
            'auc_score': auc_score,
            'cv_mean': cv_scores.mean(),
            'cv_std': cv_scores.std(),
            'model_name': model_name,
            'n_features': len(feature_cols),
            'n_samples': len(X_train)
        }
        return True, "Model trained successfully", self.model_metrics

    def predict(self, customer_metrics: pd.DataFrame) -> pd.DataFrame:
        """Make churn predictions.

        Returns ``customer_metrics`` itself when no model has been trained;
        otherwise a copy with ``churn_probability`` (0-1) and the binary
        ``predicted_churn`` flag added.
        """
        if self.model is None:
            return customer_metrics

        probabilities = self.model.predict_proba(customer_metrics[self.FEATURE_COLUMNS])[:, 1]
        result = customer_metrics.copy()
        result['churn_probability'] = probabilities
        result['predicted_churn'] = (
            probabilities > BUSINESS_CONFIG['high_risk_probability']
        ).astype(int)
        return result
class VisualizationEngine:
    """Handles all chart creation and visualization.

    Every method is static and returns a Plotly figure.
    """

    # granularity -> (pandas period alias, title prefix). 'day' is handled
    # separately because it groups by calendar date rather than by period.
    _PERIODS = {
        'week': ('W', 'Weekly'),
        'month': ('M', 'Monthly'),
        'year': ('Y', 'Yearly'),
    }

    @staticmethod
    def create_segment_chart(customer_data: pd.DataFrame):
        """Create customer segment distribution chart (donut)."""
        segment_counts = customer_data['Segment'].value_counts().reset_index()
        segment_counts.columns = ['Segment', 'Count']

        fig = px.pie(
            segment_counts,
            values='Count',
            names='Segment',
            title='Customer Segment Distribution',
            hole=0.4,
            color_discrete_sequence=list(COLORS.values())
        )
        fig.update_traces(textposition='inside', textinfo='percent+label')
        fig.update_layout(height=400, title={'x': 0.5, 'xanchor': 'center'})
        return fig

    @staticmethod
    def create_rfm_scatter(customer_data: pd.DataFrame):
        """Create RFM scatter plot: recency vs. frequency, sized by revenue."""
        fig = px.scatter(
            customer_data,
            x='recency_days',
            y='frequency',
            size='monetary',
            color='Segment',
            title='RFM Customer Behavior Matrix',
            labels={
                'recency_days': 'Days Since Last Purchase',
                'frequency': 'Purchase Frequency',
                'monetary': 'Total Revenue'
            },
            color_discrete_sequence=list(COLORS.values())
        )
        fig.update_layout(height=400, title={'x': 0.5, 'xanchor': 'center'})
        return fig

    @staticmethod
    def create_churn_chart(customer_data: pd.DataFrame, has_predictions: bool = False):
        """Create churn risk visualization.

        Shows a histogram of model probabilities when predictions are
        available, otherwise a bar chart of the rule-based risk levels.
        """
        if has_predictions and 'churn_probability' in customer_data.columns:
            fig = px.histogram(
                customer_data,
                x='churn_probability',
                nbins=20,
                title='Churn Probability Distribution',
                labels={'churn_probability': 'Churn Probability', 'count': 'Number of Customers'},
                color_discrete_sequence=[COLORS['primary']]
            )
            fig.add_vline(x=BUSINESS_CONFIG['high_risk_probability'], line_dash="dash",
                          line_color=COLORS['danger'], annotation_text="High Risk Threshold")
        else:
            risk_counts = customer_data['Churn_Risk'].value_counts().reset_index()
            risk_counts.columns = ['Risk_Level', 'Count']
            colors_map = {'High': COLORS['danger'], 'Medium': COLORS['warning'], 'Low': COLORS['success']}
            fig = px.bar(
                risk_counts,
                x='Risk_Level',
                y='Count',
                title='Customer Churn Risk Distribution',
                color='Risk_Level',
                color_discrete_map=colors_map
            )
            fig.update_layout(showlegend=False)

        fig.update_layout(height=400, title={'x': 0.5, 'xanchor': 'center'})
        return fig

    @staticmethod
    def create_revenue_trend(df: pd.DataFrame, time_granularity='month', customer_filter='all'):
        """Create revenue trend visualization with filters.

        Args:
            df: Order rows with ``customer_id``, ``order_date`` and ``amount``.
            time_granularity: 'day', 'week', 'month' or 'year' (case is
                ignored); anything else falls back to monthly.
            customer_filter: A customer id, or 'all' / empty / ``None`` for
                every customer.
        """
        # An empty selection means "no filter", for the data AND the title
        show_all = (not customer_filter) or customer_filter == 'all'

        if show_all:
            df_copy = df.copy()
        else:
            # Copy the slice so adding the period column below is safe
            df_copy = df[df['customer_id'] == customer_filter].copy()
            if df_copy.empty:
                # Return empty chart with message
                fig = go.Figure()
                fig.add_annotation(text=f"No data found for customer {customer_filter}",
                                   xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
                fig.update_layout(title=f"Revenue Trend - {customer_filter}", height=400)
                return fig

        # Group by time granularity
        granularity = str(time_granularity).strip().lower()
        if granularity == 'day':
            df_copy['time_period'] = df_copy['order_date'].dt.date
            title_suffix = "Daily"
        else:
            periods = VisualizationEngine._PERIODS
            alias, title_suffix = periods.get(granularity, periods['month'])
            df_copy['time_period'] = df_copy['order_date'].dt.to_period(alias)

        revenue_data = df_copy.groupby('time_period')['amount'].sum().reset_index()
        revenue_data['time_period'] = revenue_data['time_period'].astype(str)

        if show_all:
            title = f"{title_suffix} Revenue Trends - All Customers"
        else:
            title = f"{title_suffix} Revenue Trends - {customer_filter}"

        fig = px.line(
            revenue_data,
            x='time_period',
            y='amount',
            title=title,
            labels={'amount': 'Revenue ($)', 'time_period': 'Time Period'}
        )
        fig.update_traces(line_color=COLORS['primary'], line_width=3)
        fig.update_layout(height=400, title={'x': 0.5, 'xanchor': 'center'})
        return fig

    @staticmethod
    def create_feature_importance_chart(feature_importance: pd.DataFrame):
        """Create horizontal bar chart of the first eight rows of ``feature_importance``."""
        fig = px.bar(
            feature_importance.head(8),
            x='importance',
            y='feature',
            orientation='h',
            title='Feature Importance Analysis',
            labels={'importance': 'Importance Score', 'feature': 'Features'},
            color='importance',
            color_continuous_scale='viridis'
        )
        fig.update_layout(
            height=500,
            showlegend=False,
            plot_bgcolor='white',
            paper_bgcolor='white',
            title={'x': 0.5, 'xanchor': 'center'},
            yaxis={'categoryorder': 'total ascending'}
        )
        return fig
class ReportGenerator:
    """Builds the downloadable PDF summary report."""

    @staticmethod
    def generate_pdf_report(customer_data: pd.DataFrame, model_metrics: Dict) -> bytes:
        """Render the analytics report and return it as PDF bytes.

        The report contains an executive summary, a segment distribution
        table and, when ``model_metrics`` is non-empty, the churn model's
        evaluation figures.

        Raises:
            ImportError: if ReportLab is not installed.
        """
        if not REPORTLAB_AVAILABLE:
            raise ImportError("PDF generation requires ReportLab library")

        with io.BytesIO() as buffer:
            doc = SimpleDocTemplate(buffer, pagesize=A4,
                                    rightMargin=72, leftMargin=72,
                                    topMargin=72, bottomMargin=18)
            styles = getSampleStyleSheet()
            heading = styles['Heading2']
            body = styles['Normal']

            # --- Title ---
            title_style = ParagraphStyle('CustomTitle', parent=styles['Title'],
                                         fontSize=24, spaceAfter=30, alignment=1)
            story = [
                Paragraph("B2B Customer Analytics Report", title_style),
                Spacer(1, 12),
                Paragraph("Executive Summary", heading),
            ]

            # --- Executive summary ---
            total_customers = len(customer_data)
            total_revenue = customer_data['monetary'].sum()
            avg_revenue = customer_data['monetary'].mean()
            summary_text = (
                "\n"
                f"This comprehensive analysis covers {total_customers:,} customers with combined revenue of ${total_revenue:,.2f}.\n"
                f"The average customer value is ${avg_revenue:,.2f}. Customer segmentation and churn risk assessment\n"
                "have been performed using advanced RFM analysis and machine learning techniques.\n"
            )
            story += [Paragraph(summary_text, body), Spacer(1, 20)]

            # --- Segment distribution table ---
            story.append(Paragraph("Customer Segmentation Overview", heading))
            segment_rows = [['Segment', 'Count', 'Percentage']]
            for segment, count in customer_data['Segment'].value_counts().items():
                share = (count / total_customers) * 100
                segment_rows.append([segment, str(count), f"{share:.1f}%"])

            header_row = ((0, 0), (-1, 0))
            whole_table = ((0, 0), (-1, -1))
            table_style = TableStyle([
                ('BACKGROUND', *header_row, colors.grey),
                ('TEXTCOLOR', *header_row, colors.whitesmoke),
                ('ALIGN', *whole_table, 'CENTER'),
                ('FONTNAME', *header_row, 'Helvetica-Bold'),
                ('FONTSIZE', *header_row, 14),
                ('BOTTOMPADDING', *header_row, 12),
                ('BACKGROUND', (0, 1), (-1, -1), colors.beige),
                ('GRID', *whole_table, 1, colors.black)
            ])
            segment_table = Table(segment_rows)
            segment_table.setStyle(table_style)
            story += [segment_table, Spacer(1, 20)]

            # --- Model performance (only after a model has been trained) ---
            if model_metrics:
                story.append(Paragraph("Churn Prediction Model Performance", heading))
                model_lines = [
                    "",
                    f"Model Type: {model_metrics['model_name']}<br/>",
                    f"Accuracy: {model_metrics['accuracy']:.1%}<br/>",
                    f"AUC Score: {model_metrics['auc_score']:.3f}<br/>",
                    f"Cross-validation Score: {model_metrics['cv_mean']:.3f} ± {model_metrics['cv_std']:.3f}<br/>",
                    f"Features Used: {model_metrics['n_features']}<br/>",
                    f"Training Samples: {model_metrics['n_samples']}",
                    "",
                ]
                story.append(Paragraph("\n".join(model_lines), body))

            # Render into the in-memory buffer and hand back the raw bytes
            doc.build(story)
            return buffer.getvalue()
class B2BCustomerAnalytics:
    """Main analytics orchestrator.

    Holds the state of one analysis session: the cleaned order rows
    (``raw_data``), the per-customer metrics with segments and, after
    training, churn predictions (``customer_metrics``), and the churn model.
    ``customer_metrics['churn_probability']`` is always kept on a 0-1 scale;
    percentage formatting is applied to display copies only.
    """

    def __init__(self):
        self.raw_data = None
        self.customer_metrics = None
        self.churn_predictor = ChurnPredictor()
        self.has_trained_model = False

    def load_data(self, file) -> Tuple[str, str, Optional[pd.DataFrame]]:
        """Load and process data.

        Returns ``(status, dashboard_html, preview)``; on failure the
        dashboard is an empty string and the preview is ``None``.
        """
        self.raw_data, status = DataProcessor.load_and_validate(file)

        # Everything derived from an earlier upload is stale now, whether or
        # not the new load succeeded: drop the metrics and the trained model
        # so charts and reports never mix two datasets.
        self.customer_metrics = None
        self.churn_predictor = ChurnPredictor()
        self.has_trained_model = False

        if self.raw_data is None:
            return status, "", None

        rfm_metrics = RFMAnalyzer.calculate_rfm_metrics(self.raw_data)
        self.customer_metrics = CustomerSegmenter.perform_segmentation(rfm_metrics)

        dashboard_html = self._generate_dashboard()
        preview_data = self._prepare_preview_data()
        return status, dashboard_html, preview_data

    def train_churn_model(self) -> Tuple[str, Optional[Any]]:
        """Train churn prediction model.

        Returns ``(results_html, feature_importance_chart)`` on success, or
        ``(error_message, None)`` when training is not possible.
        """
        if self.customer_metrics is None:
            return "No data available. Please upload data first.", None

        try:
            success, message, metrics = self.churn_predictor.train_model(self.customer_metrics)
        except Exception as exc:
            # UI boundary: report library errors (e.g. degenerate class
            # splits) to the user instead of failing the callback.
            return f"Model training failed: {exc}", None
        if not success:
            return f"Model training failed: {message}", None

        self.has_trained_model = True
        # Update predictions
        self.customer_metrics = self.churn_predictor.predict(self.customer_metrics)
        results_html = self._format_model_results(metrics)
        chart = VisualizationEngine.create_feature_importance_chart(
            self.churn_predictor.feature_importance
        )
        return results_html, chart

    def get_visualizations(self) -> Tuple[Any, Any, Any, Any]:
        """Get all visualizations as (segment, rfm, churn, revenue) charts."""
        if self.customer_metrics is None:
            return None, None, None, None

        segment_chart = VisualizationEngine.create_segment_chart(self.customer_metrics)
        rfm_chart = VisualizationEngine.create_rfm_scatter(self.customer_metrics)
        churn_chart = VisualizationEngine.create_churn_chart(
            self.customer_metrics, self.has_trained_model
        )
        revenue_chart = VisualizationEngine.create_revenue_trend(self.raw_data)
        return segment_chart, rfm_chart, churn_chart, revenue_chart

    def get_revenue_chart_with_filters(self, time_granularity='month', customer_filter='all'):
        """Get revenue chart with time and customer filters (None without data)."""
        if self.raw_data is None:
            return None
        return VisualizationEngine.create_revenue_trend(
            self.raw_data, time_granularity, customer_filter
        )

    def get_customer_list(self):
        """Get list of customer IDs for dropdown, prefixed with 'all'."""
        if self.raw_data is None:
            return []
        return ['all'] + sorted(self.raw_data['customer_id'].unique().tolist())

    def get_customer_table(self) -> Optional[pd.DataFrame]:
        """Get formatted customer table (first 50 customers).

        The stored metrics are never modified, so calling this repeatedly
        always yields the same table.
        """
        if self.customer_metrics is None:
            return None

        columns = ['customer_id', 'Segment', 'Churn_Risk', 'recency_days',
                   'frequency', 'monetary', 'avg_order_value']
        has_probability = 'churn_probability' in self.customer_metrics.columns
        if has_probability:
            columns.append('churn_probability')

        table_data = self.customer_metrics[columns].copy()
        if has_probability:
            # Convert to percent on the display copy only; doing this on
            # self.customer_metrics would compound on every call and break
            # every consumer that expects a 0-1 probability.
            table_data['churn_probability'] = (
                table_data['churn_probability'] * 100
            ).round(1)
        table_data['monetary'] = table_data['monetary'].round(2)
        table_data['avg_order_value'] = table_data['avg_order_value'].round(2)

        # Rename columns for display
        display_names = {
            'customer_id': 'Customer ID',
            'Segment': 'Segment',
            'Churn_Risk': 'Risk Level',
            'recency_days': 'Recency (Days)',
            'frequency': 'Frequency',
            'monetary': 'Total Spent ($)',
            'avg_order_value': 'Avg Order ($)',
            'churn_probability': 'Churn Probability (%)'
        }
        table_data = table_data.rename(columns=display_names)
        return table_data.head(50)

    def get_customer_insights(self, customer_id: str) -> str:
        """Get detailed customer insights as HTML.

        The id is matched exactly first and, failing that, with surrounding
        whitespace removed. Returns a plain message when no data is loaded,
        the id is blank, or the customer is unknown.
        """
        from html import escape

        raw_id = '' if customer_id is None else str(customer_id)
        stripped_id = raw_id.strip()
        if self.customer_metrics is None or not stripped_id:
            return "Please enter a valid customer ID"

        known_ids = self.customer_metrics['customer_id']
        for candidate in dict.fromkeys((raw_id, stripped_id)):
            customer_data = self.customer_metrics[known_ids == candidate]
            if not customer_data.empty:
                return self._format_customer_profile(customer_data.iloc[0])

        # The id is user input and may be rendered as HTML
        return f"Customer {escape(raw_id)} not found"

    def generate_report(self) -> bytes:
        """Generate PDF report.

        Raises:
            ValueError: if no data has been loaded.
        """
        if self.customer_metrics is None:
            raise ValueError("No data available for report generation")
        return ReportGenerator.generate_pdf_report(
            self.customer_metrics,
            self.churn_predictor.model_metrics
        )

    def _generate_dashboard(self) -> str:
        """Generate dashboard HTML (KPI cards plus segment counts)."""
        total_customers = len(self.customer_metrics)
        total_revenue = self.customer_metrics['monetary'].sum()
        avg_order_value = self.customer_metrics['avg_order_value'].mean()
        high_risk_customers = (self.customer_metrics['Churn_Risk'] == 'High').sum()
        segment_dist = self.customer_metrics['Segment'].value_counts()
        segment_items = ' '.join(
            [f'<div><strong>{segment}:</strong> {count}</div>' for segment, count in segment_dist.items()]
        )
        return f"""
<div style="display: flex; flex-wrap: wrap; gap: 1rem; margin-bottom: 2rem;">
<div style="flex: 1; min-width: 200px; background: linear-gradient(135deg, #3b82f6, #1d4ed8); padding: 1.5rem; border-radius: 12px; color: white; text-align: center;">
<h3 style="margin: 0 0 0.5rem 0; font-size: 0.9rem; opacity: 0.9;">Total Customers</h3>
<div style="font-size: 2.5rem; font-weight: bold;">{total_customers:,}</div>
</div>
<div style="flex: 1; min-width: 200px; background: linear-gradient(135deg, #10b981, #047857); padding: 1.5rem; border-radius: 12px; color: white; text-align: center;">
<h3 style="margin: 0 0 0.5rem 0; font-size: 0.9rem; opacity: 0.9;">Total Revenue</h3>
<div style="font-size: 2.5rem; font-weight: bold;">${total_revenue/1000000:.1f}M</div>
</div>
<div style="flex: 1; min-width: 200px; background: linear-gradient(135deg, #8b5cf6, #6d28d9); padding: 1.5rem; border-radius: 12px; color: white; text-align: center;">
<h3 style="margin: 0 0 0.5rem 0; font-size: 0.9rem; opacity: 0.9;">Avg Order Value</h3>
<div style="font-size: 2.5rem; font-weight: bold;">${avg_order_value:.0f}</div>
</div>
<div style="flex: 1; min-width: 200px; background: linear-gradient(135deg, #ef4444, #dc2626); padding: 1.5rem; border-radius: 12px; color: white; text-align: center;">
<h3 style="margin: 0 0 0.5rem 0; font-size: 0.9rem; opacity: 0.9;">High Risk Customers</h3>
<div style="font-size: 2.5rem; font-weight: bold;">{high_risk_customers}</div>
</div>
</div>
<div style="background: #f8fafc; padding: 1.5rem; border-radius: 12px; border-left: 4px solid #6366f1;">
<h4 style="margin: 0 0 1rem 0; color: #374151;">Customer Segments Overview</h4>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 1rem;">
{segment_items}
</div>
</div>
"""

    def _prepare_preview_data(self) -> pd.DataFrame:
        """Prepare data preview: first 20 order rows with segment and risk attached."""
        if self.raw_data is None:
            return pd.DataFrame()
        preview = self.raw_data.merge(
            self.customer_metrics[['customer_id', 'Segment', 'Churn_Risk']],
            on='customer_id',
            how='left'
        )
        return preview.head(20)

    def _format_model_results(self, metrics: Dict) -> str:
        """Format model training results as an HTML summary card."""
        return f"""
<div style="background: white; padding: 2rem; border-radius: 1rem; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); margin-bottom: 2rem;">
<div style="text-align: center; margin-bottom: 2rem;">
<h3 style="color: #1f2937; font-size: 1.5rem; font-weight: bold; margin-bottom: 0.5rem;">
Model Training Completed Successfully
</h3>
<p style="color: #6b7280;">{metrics['model_name']} with Advanced Feature Engineering</p>
</div>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 1rem; margin-bottom: 2rem;">
<div style="background: linear-gradient(135deg, #6366f1, #4f46e5); padding: 1rem; border-radius: 8px; text-align: center; color: white;">
<div style="font-size: 2rem; font-weight: bold;">{metrics['accuracy']:.1%}</div>
<div style="font-size: 0.9rem;">Accuracy</div>
</div>
<div style="background: linear-gradient(135deg, #10b981, #059669); padding: 1rem; border-radius: 8px; text-align: center; color: white;">
<div style="font-size: 2rem; font-weight: bold;">{metrics['auc_score']:.3f}</div>
<div style="font-size: 0.9rem;">AUC Score</div>
</div>
<div style="background: linear-gradient(135deg, #f59e0b, #d97706); padding: 1rem; border-radius: 8px; text-align: center; color: white;">
<div style="font-size: 2rem; font-weight: bold;">{metrics['n_features']}</div>
<div style="font-size: 0.9rem;">Features Used</div>
</div>
<div style="background: linear-gradient(135deg, #8b5cf6, #7c3aed); padding: 1rem; border-radius: 8px; text-align: center; color: white;">
<div style="font-size: 2rem; font-weight: bold;">{metrics['cv_mean']:.3f}</div>
<div style="font-size: 0.9rem;">CV Score</div>
</div>
</div>
</div>
"""

    def _format_customer_profile(self, customer) -> str:
        """Format individual customer profile as HTML.

        ``customer`` is one row of ``customer_metrics``. When no model has
        been trained the churn probability shown defaults to 0.5.
        """
        from html import escape

        churn_prob = customer.get('churn_probability', 0.5)
        recommendations = self._get_customer_recommendations(
            customer['Segment'], customer['Churn_Risk'], churn_prob, customer['recency_days']
        )
        # Customer ids come straight from the uploaded CSV
        safe_customer_id = escape(str(customer['customer_id']))
        return f"""
<div style="background: white; padding: 2rem; border-radius: 1rem; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); margin-bottom: 1rem;">
<h3 style="text-align: center; color: #1f2937; margin-bottom: 1.5rem;">Customer Profile: {safe_customer_id}</h3>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 1rem; margin-bottom: 2rem;">
<div style="background: linear-gradient(135deg, #6366f1, #4f46e5); padding: 1rem; border-radius: 8px; color: white; text-align: center;">
<h4 style="margin: 0 0 0.5rem 0; font-size: 0.9rem; opacity: 0.9;">Segment</h4>
<div style="font-size: 1.2rem; font-weight: bold;">{customer['Segment']}</div>
</div>
<div style="background: linear-gradient(135deg, #ef4444, #dc2626); padding: 1rem; border-radius: 8px; color: white; text-align: center;">
<h4 style="margin: 0 0 0.5rem 0; font-size: 0.9rem; opacity: 0.9;">Churn Risk</h4>
<div style="font-size: 1.2rem; font-weight: bold;">{customer['Churn_Risk']}</div>
</div>
<div style="background: linear-gradient(135deg, #8b5cf6, #6d28d9); padding: 1rem; border-radius: 8px; color: white; text-align: center;">
<h4 style="margin: 0 0 0.5rem 0; font-size: 0.9rem; opacity: 0.9;">Churn Probability</h4>
<div style="font-size: 1.2rem; font-weight: bold;">{churn_prob:.1%}</div>
</div>
</div>
<div style="background: #f8fafc; padding: 1.5rem; border-radius: 8px; margin-bottom: 1rem;">
<h4 style="color: #374151; margin-bottom: 1rem;">Transaction Analytics</h4>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 1rem;">
<div>
<div style="font-size: 0.8rem; color: #6b7280; margin-bottom: 0.2rem;">Purchase Frequency</div>
<div style="font-size: 1.5rem; font-weight: bold; color: #1f2937;">{customer['frequency']}</div>
</div>
<div>
<div style="font-size: 0.8rem; color: #6b7280; margin-bottom: 0.2rem;">Total Spent</div>
<div style="font-size: 1.5rem; font-weight: bold; color: #1f2937;">${customer['monetary']:,.0f}</div>
</div>
<div>
<div style="font-size: 0.8rem; color: #6b7280; margin-bottom: 0.2rem;">Avg Order Value</div>
<div style="font-size: 1.5rem; font-weight: bold; color: #1f2937;">${customer['avg_order_value']:.0f}</div>
</div>
<div>
<div style="font-size: 0.8rem; color: #6b7280; margin-bottom: 0.2rem;">Days Since Last Order</div>
<div style="font-size: 1.5rem; font-weight: bold; color: #1f2937;">{customer['recency_days']}</div>
</div>
</div>
</div>
<div style="background: linear-gradient(135deg, #f0f9ff, #e0f2fe); border-left: 4px solid #3b82f6; padding: 1rem; border-radius: 4px;">
<h4 style="color: #1e40af; margin-bottom: 0.5rem;">Recommendations</h4>
<p style="color: #1f2937; margin: 0;">{recommendations}</p>
</div>
</div>
"""

    def _get_customer_recommendations(self, segment: str, risk_level: str,
                                      churn_prob: float, recency: int) -> str:
        """Generate personalized recommendations.

        Combines a risk-based action, a segment-based action and a recency
        reminder (more than 60 days since the last order) into one string
        joined by bullets.
        """
        recommendations = []
        if risk_level == 'High' or churn_prob > BUSINESS_CONFIG['high_risk_probability']:
            recommendations.append("URGENT: Personal outreach required within 24 hours")
            recommendations.append("Offer retention incentive or loyalty program")
        elif risk_level == 'Medium':
            recommendations.append("Send personalized re-engagement campaign")

        if segment == 'Champions':
            recommendations.append("Invite to VIP program or advisory board")
        elif segment == 'At Risk':
            recommendations.append("Proactive customer success intervention needed")
        elif segment == 'New Customers':
            recommendations.append("Deploy onboarding campaign sequence")
        elif segment == 'Lost Customers':
            recommendations.append("Win-back campaign with deep discount offer")

        if recency > 60:
            recommendations.append("Re-engagement campaign with special offer recommended")

        return " • ".join(recommendations) if recommendations else "Continue monitoring customer engagement patterns."
def update_revenue_chart(analytics_instance, time_gran, customer_id):
    """Update revenue chart based on filters.

    Best-effort UI callback: returns the chart produced by
    ``analytics_instance.get_revenue_chart_with_filters`` or ``None`` when
    that call fails. Failures are logged rather than silently discarded.
    """
    import logging

    try:
        return analytics_instance.get_revenue_chart_with_filters(time_gran, customer_id)
    except Exception:
        logging.getLogger(__name__).exception(
            "Failed to update revenue chart (granularity=%r, customer=%r)",
            time_gran, customer_id,
        )
        return None
def update_customer_dropdown(analytics_instance):
    """Update customer dropdown options.

    Returns a ``gr.Dropdown`` listing the session's customers with 'all'
    selected. If the customer list is empty or cannot be built, the dropdown
    offers only 'all' so the selected value is always a valid choice.
    """
    try:
        # An empty list (no data loaded yet) would leave 'all' outside the choices
        customers = analytics_instance.get_customer_list() or ['all']
        return gr.Dropdown(choices=customers, value='all')
    except Exception:
        # Best-effort fallback; unlike a bare ``except`` this lets
        # KeyboardInterrupt / SystemExit propagate.
        return gr.Dropdown(choices=['all'], value='all')
def create_gradio_interface():
    """Create the enhanced Gradio interface.

    Builds a ``gr.Blocks`` app with six tabs (Data Upload & Dashboard,
    Customer Segmentation, Churn Prediction, Revenue Analytics, Customer
    Insights, Reports), defines the event handlers as closures and wires
    them to the components.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    # Local imports: only the handlers defined below need them.
    import html
    import tempfile

    # Custom CSS for modern styling
    custom_css = """
.gradio-container {
font-family: 'Inter', system-ui, sans-serif !important;
max-width: 1200px !important;
}
.tab-nav {
background: #f8fafc !important;
border-radius: 8px !important;
}
"""
    with gr.Blocks(theme=gr.themes.Soft(), title="B2B Customer Analytics", css=custom_css) as demo:
        # Initialize analytics instance per session
        analytics = gr.State(B2BCustomerAnalytics())
        gr.HTML("""
<div style="background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%); padding: 2rem; border-radius: 1rem; color: white; text-align: center; margin-bottom: 2rem;">
<h1 style="font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem;">B2B Customer Analytics Platform</h1>
<p style="font-size: 1.1rem; opacity: 0.9;">Advanced Customer Segmentation & Churn Prediction</p>
<div style="font-size: 0.9rem; opacity: 0.8; margin-top: 1rem;">
Upload your customer data CSV with columns: customer_id, order_date, amount (or similar)
</div>
</div>
""")
        with gr.Tabs():
            with gr.Tab("Data Upload & Dashboard"):
                with gr.Row():
                    with gr.Column():
                        file_input = gr.File(
                            label="Upload Customer Data CSV",
                            file_types=[".csv"],
                            type="filepath"
                        )
                        load_btn = gr.Button(
                            "Load & Process Data",
                            variant="primary",
                            size="lg"
                        )
                        load_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            max_lines=2
                        )
                summary_display = gr.HTML()
                data_preview = gr.DataFrame(label="Data Preview (First 20 Rows)")
            with gr.Tab("Customer Segmentation"):
                with gr.Row():
                    with gr.Column():
                        segment_chart = gr.Plot(label="Customer Segments Distribution")
                    with gr.Column():
                        rfm_chart = gr.Plot(label="RFM Behavior Analysis")
                customer_table = gr.DataFrame(label="Customer Segmentation Details")
                gr.HTML("""
<div style="background: #f0f9ff; padding: 1rem; border-radius: 8px; border-left: 4px solid #3b82f6; margin-top: 1rem;">
<h4 style="color: #1e40af; margin: 0 0 0.5rem 0;">Segment Definitions</h4>
<p style="margin: 0; color: #1f2937; font-size: 0.9rem;">
<strong>Champions:</strong> High value, frequent customers •
<strong>Loyal Customers:</strong> Regular, valuable customers •
<strong>At Risk:</strong> Previously valuable but declining activity •
<strong>Lost Customers:</strong> Haven't purchased recently
</p>
</div>
""")
            with gr.Tab("Churn Prediction"):
                train_btn = gr.Button(
                    "Train Churn Prediction Model",
                    variant="primary",
                    size="lg"
                )
                model_results = gr.HTML()
                with gr.Row():
                    with gr.Column():
                        feature_importance_chart = gr.Plot(label="Feature Importance Analysis")
                    with gr.Column():
                        churn_distribution_chart = gr.Plot(label="Churn Risk Distribution")
                gr.HTML("""
<div style="background: #fef3c7; padding: 1rem; border-radius: 8px; border-left: 4px solid #f59e0b; margin-top: 1rem;">
<h4 style="color: #92400e; margin: 0 0 0.5rem 0;">Model Information</h4>
<p style="margin: 0; color: #1f2937; font-size: 0.9rem;">
The model uses advanced features including customer lifetime, purchase patterns, and RFM metrics.
Customers with >90 days since last purchase are considered churned for training purposes.
</p>
</div>
""")
            with gr.Tab("Revenue Analytics"):
                with gr.Row():
                    with gr.Column(scale=1):
                        time_granularity = gr.Radio(
                            choices=['day', 'week', 'month', 'year'],
                            value='month',
                            label="Time Granularity"
                        )
                        # 'all' is offered from the start so the default value
                        # is a valid choice even before data is loaded; the
                        # full customer list is filled in by safe_load_data.
                        customer_filter = gr.Dropdown(
                            choices=['all'],
                            value='all',
                            label="Customer Filter"
                        )
                        update_chart_btn = gr.Button("Update Chart", variant="primary")
                    with gr.Column(scale=3):
                        revenue_chart = gr.Plot(label="Revenue Trends")
                # Add the info box
                gr.HTML("""
<div style="background: #ecfdf5; padding: 1rem; border-radius: 8px; border-left: 4px solid #10b981; margin-top: 1rem;">
<h4 style="color: #065f46; margin: 0 0 0.5rem 0;">Interactive Revenue Analysis</h4>
<p style="margin: 0; color: #1f2937; font-size: 0.9rem;">
Select time granularity (day/week/month/year) and specific customers to analyze revenue patterns.
Use "all" to view aggregate trends across all customers.
</p>
</div>
""")
            with gr.Tab("Customer Insights"):
                with gr.Row():
                    customer_id_input = gr.Textbox(
                        label="Customer ID",
                        placeholder="Enter customer ID for detailed analysis",
                        scale=3
                    )
                    insights_btn = gr.Button(
                        "Get Customer Profile",
                        variant="primary",
                        scale=1
                    )
                customer_insights = gr.HTML()
            with gr.Tab("Reports"):
                with gr.Row():
                    with gr.Column():
                        gr.HTML("""
<div style="background: white; padding: 2rem; border-radius: 1rem; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
<h3 style="color: #1f2937; margin-bottom: 1rem;">Generate Comprehensive Report</h3>
<p style="color: #6b7280; margin-bottom: 1.5rem;">
Create a detailed PDF report including customer segmentation analysis,
churn predictions, and actionable business insights.
</p>
</div>
""")
                        report_btn = gr.Button(
                            "Generate PDF Report",
                            variant="primary",
                            size="lg"
                        )
                    with gr.Column():
                        report_file = gr.File(
                            label="Download Report",
                            interactive=False
                        )

        # Event handlers with proper error handling
        def load_failure(analytics_instance, message):
            """Return the 10 load outputs for a failed load.

            The order matches the ``outputs`` list of ``load_btn.click``.
            """
            return (
                analytics_instance, message, "",
                None,                    # data preview
                None, None, None, None,  # segment, RFM, churn and revenue charts
                None,                    # customer table
                gr.Dropdown(choices=['all'], value='all'),
            )

        def safe_load_data(analytics_instance, file):
            """Load the uploaded CSV and refresh every dependent view."""
            try:
                if file is None:
                    return load_failure(analytics_instance, "Please upload a CSV file")
                status, dashboard, preview = analytics_instance.load_data(file)
                if "successfully" not in status:
                    return load_failure(analytics_instance, status)
                charts = analytics_instance.get_visualizations()
                table = analytics_instance.get_customer_table()
                # Update customer dropdown
                customers = analytics_instance.get_customer_list()
                customer_dropdown = gr.Dropdown(choices=customers, value='all')
                return analytics_instance, status, dashboard, preview, *charts, table, customer_dropdown
            except Exception as e:
                return load_failure(analytics_instance, f"Error loading data: {str(e)}")

        def safe_train_model(analytics_instance):
            """Train the churn model and refresh the churn distribution chart."""
            try:
                result_html, chart = analytics_instance.train_churn_model()
                # Update churn chart after training
                updated_charts = analytics_instance.get_visualizations()
                return analytics_instance, result_html, chart, updated_charts[2]
            except Exception as e:
                # The message is rendered by a gr.HTML component, so escape the
                # exception text to keep characters such as '<' from being
                # interpreted as markup.
                error_msg = f"Error training model: {html.escape(str(e))}"
                return analytics_instance, error_msg, None, None

        def safe_get_insights(analytics_instance, customer_id):
            """Return the HTML profile for one customer, or an error message."""
            try:
                return analytics_instance.get_customer_insights(customer_id)
            except Exception as e:
                # Rendered by gr.HTML: escape the exception text.
                return f"Error getting insights: {html.escape(str(e))}"

        def safe_generate_report(analytics_instance):
            """Write the PDF report to a temporary file and return its path."""
            try:
                if analytics_instance.customer_metrics is None:
                    # Tell the user why nothing was produced instead of
                    # silently returning an empty download.
                    gr.Warning("Please load customer data before generating a report")
                    return None
                pdf_bytes = analytics_instance.generate_report()
                # Save to temporary file; delete=False keeps the file on disk
                # after the handle is closed so its path can be returned.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
                    tmp.write(pdf_bytes)
                return tmp.name
            except Exception as e:
                gr.Warning(f"Error generating report: {str(e)}")
                return None

        # Wire up events
        load_btn.click(
            fn=safe_load_data,
            inputs=[analytics, file_input],
            outputs=[analytics, load_status, summary_display, data_preview,
                     segment_chart, rfm_chart, churn_distribution_chart, revenue_chart,
                     customer_table, customer_filter]
        )
        # Update chart on button click and automatically on filter change;
        # all three triggers share the same handler, inputs and output.
        revenue_inputs = [analytics, time_granularity, customer_filter]
        for register in (update_chart_btn.click, time_granularity.change, customer_filter.change):
            register(
                fn=update_revenue_chart,
                inputs=revenue_inputs,
                outputs=[revenue_chart]
            )
        train_btn.click(
            fn=safe_train_model,
            inputs=[analytics],
            outputs=[analytics, model_results, feature_importance_chart, churn_distribution_chart]
        )
        insights_btn.click(
            fn=safe_get_insights,
            inputs=[analytics, customer_id_input],
            outputs=[customer_insights]
        )
        report_btn.click(
            fn=safe_generate_report,
            inputs=[analytics],
            outputs=[report_file]
        )
        # Auto-update customer insights on Enter key
        customer_id_input.submit(
            fn=safe_get_insights,
            inputs=[analytics, customer_id_input],
            outputs=[customer_insights]
        )
    return demo
# Script entry point: build the interface and serve it.
if __name__ == "__main__":
    # Launch settings gathered in one place: listen on all interfaces on
    # port 7860, request a share link, and show handler errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_error": True,
    }
    demo = create_gradio_interface()
    demo.launch(**launch_options)