# Snapshot of Hugging Face Space commit 1f4d19b ("update tab 2") by xxnithicxx
# -*- coding: utf-8 -*-
"""
Gradio Web Demo for Customer Segmentation Project
A comprehensive interactive interface showcasing:
- Tab 1: Dashboard with KPIs and EDA visualizations
- Tab 2: Clustering Playground with interactive K selection
- Tab 3: Customer DNA analysis with Radar charts
- Tab 4: Segment Prediction for new customers
"""
import sys
import os
import numpy as np
import pandas as pd
import gradio as gr
from datetime import datetime, timedelta
from pathlib import Path
# Add src path for clustering_library
sys.path.insert(0, '../src')
# Import utilities
from utils.data_loader import DataLoader, get_data_loader
from utils.clustering_models import ClusteringModels, init_clustering_models
from utils.visualizations import (
create_kpi_display,
plot_revenue_over_time,
plot_hourly_daily_heatmap,
plot_elbow_silhouette,
plot_clusters_pca_2d,
plot_radar_chart,
create_cluster_stats_table,
)
from sklearn.preprocessing import StandardScaler
# ============================================================================
# INITIALIZATION
# ============================================================================
def initialize_app():
    """Load processed data, prepare clustering models, and cache PCA plots.

    Returns:
        Tuple of (data_loader, clustering_models, raw_data,
        scaled_features, original_features, pca_plots_cache).
    """
    print("Initializing app...")
    # Load data
    loader = DataLoader("./data/processed")
    scaled = loader.scaled_features
    original = loader.original_features
    transactions = loader.raw_data
    # Initialize clustering models
    models_dir = "./models"
    clustering = ClusteringModels(scaled, original, models_dir)
    # Reuse persisted K-Means models when any pickle exists; otherwise train,
    # project with PCA, and persist for the next startup.
    models_path = Path(models_dir)
    if models_path.exists() and any(models_path.glob("kmeans_k*.pkl")):
        print("Loading pre-trained models...")
        clustering.load_models(k_range=range(2, 11))
    else:
        print("Models not found. Training models...")
        clustering.train_models(k_range=range(2, 11))
        clustering.apply_pca(n_components=None)
        clustering.save_models()
    # Loaded models may not carry PCA features — compute them if missing.
    if clustering.pca_features is None:
        print("Applying PCA...")
        clustering.apply_pca(n_components=None)
    init_clustering_models(scaled, original, models_dir)
    # Pre-render one PCA scatter per K so the Tab 2 slider updates instantly.
    print("Pre-computing PCA plots for all K values...")
    cached_plots = {}
    for k in range(2, 11):
        if k in clustering.cluster_labels:
            cached_plots[k] = plot_clusters_pca_2d(
                clustering.pca_features, clustering.cluster_labels[k], k
            )
            print(f" Cached PCA plot for K={k}")
    print("All PCA plots cached successfully!")
    return loader, clustering, transactions, scaled, original, cached_plots
# Global variables (will be initialized at app startup)
data_loader = None  # DataLoader wrapping the processed dataset
cm = None  # ClusteringModels instance (K-Means models + PCA features)
raw_data = None  # raw transaction DataFrame used by the dashboard plots
scaled_features = None  # standardized feature matrix used for clustering
original_features = None  # unscaled feature matrix (for interpretation)
pca_plots_cache = None  # dict mapping K -> pre-rendered PCA scatter figure
# ============================================================================
# TAB 1: DASHBOARD OVERVIEW
# ============================================================================
def get_kpi_data():
    """Return the KPI metrics computed by the global data loader."""
    return data_loader.get_kpi_metrics()
def get_dashboard_plots():
    """Build the Tab 1 artifacts.

    Returns:
        Tuple of (kpi_html, revenue_figure, heatmap_figure).
    """
    banner_html = create_kpi_display(get_kpi_data())
    return (
        banner_html,
        plot_revenue_over_time(raw_data),   # full-range revenue trend
        plot_hourly_daily_heatmap(raw_data),  # hour x weekday activity
    )
def create_tab1():
    """Create Tab 1: Dashboard Overview (KPIs, revenue trend, heatmap)."""
    with gr.TabItem("Dashboard - Overview"):
        gr.Markdown("# Data Overview Analysis")
        # KPI banner plus the two pre-built figures.
        banner_html, revenue_fig, heatmap_fig = get_dashboard_plots()
        gr.HTML(banner_html)
        gr.Markdown("## Revenue Over Time")
        # Date range picker for the revenue chart.
        with gr.Row():
            start_picker = gr.DateTime(
                label="From Date",
                value=raw_data["InvoiceDate"].min(),
            )
            end_picker = gr.DateTime(
                label="To Date",
                value=raw_data["InvoiceDate"].max(),
            )
        revenue_plot = gr.Plot(
            label="Revenue Chart",
            value=revenue_fig,
        )

        def refresh_revenue(start, end):
            """Re-plot revenue for the chosen window; full range if unset."""
            if start is None or end is None:
                return revenue_fig
            return plot_revenue_over_time(raw_data, start, end)

        # Either date picker re-renders the revenue chart.
        for picker in (start_picker, end_picker):
            picker.change(
                fn=refresh_revenue,
                inputs=[start_picker, end_picker],
                outputs=revenue_plot,
            )
        gr.Markdown("## Shopping Behavior by Hour and Day")
        gr.Plot(
            label="Shopping Activity Heatmap",
            value=heatmap_fig,
        )
        gr.Markdown("""
### Insights:
- **Heatmap** shows shopping patterns by hour (0-23) and day of week
- **Revenue Over Time** shows overall sales trend (12 months)
- Filter by date range to zoom into peak months (Christmas, etc.)
""")
# ============================================================================
# TAB 2: CLUSTERING PLAYGROUND
# ============================================================================
def get_optimal_clusters_data():
    """Return (inertias, silhouette_scores, k_list) for K in [2, 10]."""
    return cm.inertias, cm.silhouette_scores, list(range(2, 11))
def create_tab2():
    """Create Tab 2: Clustering Playground (elbow/silhouette + PCA view)."""
    with gr.TabItem("Clustering - Playground"):
        gr.Markdown("# Explore K-Means Clustering Algorithm")
        gr.Markdown("""
Adjust the slider to select different numbers of clusters (K) and see how the algorithm
divides customers into different groups.
""")
        # Pre-computed model-selection diagnostics.
        inertias, silhouette_scores, k_list = get_optimal_clusters_data()
        # Elbow + Silhouette plot (static, cached)
        gr.Markdown("## Determine Optimal Number of Clusters")
        gr.Plot(value=plot_elbow_silhouette(inertias, silhouette_scores, range(2, 11)))
        gr.Markdown("""
**Explanation:**
- **Elbow Method**: Find the "elbow" point where increasing K doesn't significantly reduce inertia
- **Silhouette Score**: Higher is better. Clusters are more distinct when score is high
- **Recommendation**: K=3 or K=4 are both good choices
""")
        # Slider-driven PCA visualization, served entirely from the cache.
        gr.Markdown("## Visualize Clusters in PCA Space")
        cluster_count = gr.Slider(
            minimum=2,
            maximum=10,
            value=4,
            step=1,
            label="Select number of clusters (K)",
            interactive=True,
        )

        def cached_pca_figure(k):
            """Return the pre-rendered PCA scatter for K, or None if absent."""
            return pca_plots_cache.get(k)

        scatter = gr.Plot(
            label="Scatter Plot: PC1 vs PC2",
            value=cached_pca_figure(4),  # default K=4
        )
        cluster_count.change(
            fn=cached_pca_figure,
            inputs=cluster_count,
            outputs=scatter,
        )
        gr.Markdown("""
**How to Use:**
- Each **point** represents one customer
- **Color** indicates which cluster the customer belongs to
- When changing K, clusters will be instantly updated from cache
""")
# ============================================================================
# TAB 3: CUSTOMER DNA
# ============================================================================
def create_tab3():
    """Create Tab 3: Customer DNA.

    Builds a K selector (3 or 4 clusters), a cluster picker, a radar chart
    comparing the selected cluster against the overall average, and a
    per-cluster statistics table.
    """
    with gr.TabItem("Analysis - Customer DNA"):
        gr.Markdown("# Deep Analysis: Characteristics of Each Cluster")
        gr.Markdown("""
Select a cluster to see detailed characteristics of customers in it.
The Radar chart shows how this cluster differs from the overall average.
""")
        # Get available clusters (K=3 and K=4)
        k_choices = [3, 4]
        with gr.Row():
            k_select = gr.Radio(
                choices=k_choices,
                value=4,
                label="Select Model (K clusters)"
            )
            cluster_select = gr.Dropdown(
                choices=[0, 1, 2, 3],
                value=0,
                label="Select Cluster",
                interactive=True
            )

        def update_cluster_choices(k):
            """Rebuild the dropdown with cluster ids 0..k-1, reset to 0."""
            return gr.Dropdown(
                choices=list(range(k)),
                value=0,
                interactive=True
            )

        k_select.change(
            fn=update_cluster_choices,
            inputs=k_select,
            outputs=cluster_select
        )
        # Radar chart
        gr.Markdown("### Radar Chart - Comparison with Overall Average")

        def update_radar_and_stats(k, cluster_idx):
            """Return (radar_figure, stats_dataframe) for the selection.

            FIX: clamp cluster_idx to [0, k). When K shrinks (e.g. 4 -> 3)
            this handler can fire while the dropdown still holds an index
            that no longer exists for the new K.
            """
            if cluster_idx is None or not 0 <= cluster_idx < k:
                cluster_idx = 0
            cluster_info = cm.get_cluster_info(k)
            cluster_means = cluster_info["means"]
            # Radar chart for the selected cluster vs the overall average.
            radar_fig = plot_radar_chart(cluster_means, k, cluster_idx=cluster_idx)
            # Detailed per-cluster statistics table.
            stats_df = create_cluster_stats_table(cluster_means, k)
            return radar_fig, stats_df

        # FIX: compute the initial state BEFORE creating the components and
        # pass it via value=. The original assigned component.value after
        # construction, which bypasses Gradio's postprocessing and does not
        # reliably set the rendered default.
        initial_k = 4
        initial_radar, initial_stats = update_radar_and_stats(initial_k, 0)
        radar_plot = gr.Plot(label="Radar Chart", value=initial_radar)
        stats_table = gr.Dataframe(label="Detailed Statistics", value=initial_stats)
        # Update when K or cluster changes
        k_select.change(
            fn=update_radar_and_stats,
            inputs=[k_select, cluster_select],
            outputs=[radar_plot, stats_table]
        )
        cluster_select.change(
            fn=update_radar_and_stats,
            inputs=[k_select, cluster_select],
            outputs=[radar_plot, stats_table]
        )
        gr.Markdown("""
### How to Read Radar Chart:
- **Each axis = 1 customer characteristic** (normalized 0-1 scale)
- **Further from center = higher value** for that characteristic
- **Shape of polygon** represents the cluster's profile
- **Compare clusters** by looking at shape and size
""")
# ============================================================================
# MAIN APP
# ============================================================================
def main():
    """Build and return the Gradio Blocks demo (loads data/models first)."""
    global data_loader, cm, raw_data, scaled_features, original_features, pca_plots_cache
    print("Starting Gradio app initialization...")
    (data_loader, cm, raw_data,
     scaled_features, original_features, pca_plots_cache) = initialize_app()
    print("App initialized successfully!")
    # Assemble the interface: header, one builder per tab, footer.
    with gr.Blocks(title="Customer Segmentation Demo") as demo:
        # Header
        gr.Markdown("""
# Customer Segmentation - Advanced Analysis
Interactive demo showcasing customer clustering analysis with K-Means.
Explore data stories, clustering patterns, and predict segments for new customers.
""")
        # Tabs
        create_tab1()
        create_tab2()
        create_tab3()
        # Footer
        gr.Markdown("""
---
**Project:** Advanced Customer Segmentation
**Data:** Online Retail (2010-2011) - Customers: 3,920+ - Transactions: 354,000+
**Built from:** Project by Dr.Nguyen Thai Ha
""")
    return demo
if __name__ == "__main__":
    # Build the Blocks app (also loads data and trains/loads models).
    demo = main()
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
        server_port=7860,  # standard Gradio / HF Spaces port
        share=False,  # no public tunnel
        show_error=True  # surface Python tracebacks in the UI
    )