Spaces:
Runtime error
Runtime error
File size: 1,194 Bytes
8775e4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# evaluate.py
# Purpose: small evaluation and visualization helpers for clustering results
import numpy as np
import pandas as pd
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
import seaborn as sns
def silhouette(embs, labels):
    """Return the silhouette score of a clustering, ignoring noise points.

    Samples whose label is negative (e.g. ``-1`` for noise) are dropped
    before scoring.

    Args:
        embs: Array-like whose first axis indexes samples, i.e. shape
            (n_samples, n_features); row ``i`` is the embedding of sample ``i``.
        labels: Array-like of length n_samples holding one cluster label per
            sample; negative labels mark noise and are excluded.

    Returns:
        The silhouette score (a float), or ``None`` when it is undefined for
        the non-noise samples. ``silhouette_score`` only accepts between 2
        and ``n_samples - 1`` distinct labels (inclusive).
    """
    # Accept plain lists as well as ndarrays; `labels >= 0` needs an array.
    embs = np.asarray(embs)
    labels = np.asarray(labels)
    mask = labels >= 0
    kept_labels = labels[mask]
    n_samples = kept_labels.size
    n_clusters = np.unique(kept_labels).size
    # silhouette_score raises ValueError outside 2 <= n_clusters <= n_samples - 1
    # (e.g. all non-noise points in one cluster), so report "undefined" instead
    # of crashing. This also covers the original "<= 1 sample" case.
    if not 2 <= n_clusters <= n_samples - 1:
        return None
    return silhouette_score(embs[mask], kept_labels)
def cluster_stats(df_original, labels):
    """Summarise each cluster by customer count and median income / spend score.

    ``df_original`` itself is not modified. The labels are attached to a
    copy as a ``cluster`` column, which is then used as the group key.

    Args:
        df_original: DataFrame expected to contain ``customer_id``,
            ``annual_income`` and ``spend_score`` columns.
        labels: One cluster label per row of ``df_original``.

    Returns:
        DataFrame indexed by cluster label, with the non-null count of
        ``customer_id`` and the medians of ``annual_income`` and
        ``spend_score``.

    Raises:
        KeyError: (from pandas) if one of the aggregated columns is missing.
        ValueError: (from pandas) if ``labels`` does not have one entry per row.
    """
    aggregations = {
        'customer_id': 'count',
        'annual_income': 'median',
        'spend_score': 'median',
    }
    labelled = df_original.copy()
    labelled['cluster'] = labels
    return labelled.groupby('cluster').agg(aggregations)
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Report the silhouette score and per-cluster stats of a clustering.'
    )
    parser.add_argument('--features', default='data/features.parquet',
                        help='Parquet file with one row per sample (default: %(default)s)')
    parser.add_argument('--emb', default='data/embeddings.npy',
                        help='.npy file of per-sample embeddings (default: %(default)s)')
    parser.add_argument('--labels', default='data/cluster_labels.npy',
                        help='.npy file of per-sample cluster labels (default: %(default)s)')
    args = parser.parse_args()

    df = pd.read_parquet(args.features)
    embs = np.load(args.emb)
    labels = np.load(args.labels)

    # Fail early with a readable message; otherwise the boolean mask inside
    # silhouette() fails with an opaque numpy IndexError.
    if len(embs) != len(labels):
        raise SystemExit(
            f'Embeddings ({len(embs)} rows) and labels ({len(labels)} entries) '
            'must have the same length.'
        )

    s = silhouette(embs, labels)
    print('Silhouette score (ignoring noise labels -1):', s)

    # Catch only the failures cluster_stats is expected to raise; anything
    # else is a genuine bug and should surface with a traceback.
    try:
        stats = cluster_stats(df, labels)
    except KeyError as err:
        # One of the aggregated columns is absent from the features file.
        print(f'Could not compute descriptive stats (missing columns): {err}')
    except ValueError as err:
        # e.g. the number of labels does not match the number of feature rows.
        print(f'Could not compute descriptive stats: {err}')
    else:
        print(stats)