Amit-kr26's picture
HF Spaces deployment
c9f187d
Raw
History Blame Contribute Delete
2.4 kB
"""KMeans customer segmentation model."""
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
# Display names keyed by KMeans cluster index (looked up by get_segment_name).
# NOTE(review): KMeans numbers its clusters arbitrarily, so this fixed
# index -> name table only holds if clusters are re-ordered or checked
# against their centroids after fitting -- confirm how callers rely on it.
SEGMENT_NAMES = dict(
    enumerate(
        [
            "High Value",
            "Loyal",
            "Occasional",
            "At Risk",
            "Lost",
        ]
    )
)

# Model inputs, in feature-matrix column order. build_feature_matrix takes
# the first three from the RFM frame and the remaining ones from the CLV frame.
_RFM_FEATURES = ["recency_days", "frequency", "monetary"]
_CLV_FEATURES = ["avg_order_value", "tenure_days", "avg_days_between_orders"]
FEATURE_COLS = _RFM_FEATURES + _CLV_FEATURES
class CustomerSegmentationModel:
def __init__(self, n_clusters: int = 5, random_state: int = 42) -> None:
self.n_clusters = n_clusters
self.random_state = random_state
self.pipeline: Pipeline | None = None
def build_feature_matrix(
self, rfm_df: pd.DataFrame, clv_df: pd.DataFrame
) -> tuple[pd.DataFrame, np.ndarray]:
"""Merge RFM + CLV features and return (index_df, feature_matrix)."""
merged = rfm_df[["customer_unique_id"] + FEATURE_COLS[:3]].merge(
clv_df[["customer_unique_id", "avg_order_value", "tenure_days", "avg_days_between_orders"]],
on="customer_unique_id",
how="inner",
)
merged = merged.dropna(subset=FEATURE_COLS)
X = merged[FEATURE_COLS].values
return merged[["customer_unique_id"]], X
def fit(self, X: np.ndarray) -> "CustomerSegmentationModel":
self.pipeline = Pipeline(
[
("scaler", StandardScaler()),
("kmeans", KMeans(n_clusters=self.n_clusters, random_state=self.random_state, n_init=10)),
]
)
self.pipeline.fit(X)
return self
def predict(self, X: np.ndarray) -> np.ndarray:
if self.pipeline is None:
raise RuntimeError("Model not fitted. Call fit() first.")
return self.pipeline.predict(X)
def evaluate(self, X: np.ndarray) -> dict:
labels = self.predict(X)
X_scaled = self.pipeline.named_steps["scaler"].transform(X)
return {
"silhouette_score": float(silhouette_score(X_scaled, labels, sample_size=10000)),
"inertia": float(self.pipeline.named_steps["kmeans"].inertia_),
"n_clusters": self.n_clusters,
}
def get_segment_name(self, cluster_id: int) -> str:
return SEGMENT_NAMES.get(cluster_id, f"Segment {cluster_id}")