# Source: emb-explorer-demo/shared/utils/representatives.py (2.43 kB)
# Commit 269ea1f (verified), by netzhang:
#   Deploy merged demo: representative images (#42), t-SNE exact solver (#45),
#   PCA reproducibility (#46), decoupled projection/KMeans + thread pipeline,
#   demo header/footer
"""Find representative members of clusters.
Given embeddings and cluster labels, rank each cluster's members by proximity
to the cluster centroid. Returns more candidates than strictly requested
(oversampled) so callers that render images can skip candidates whose image
fails to load and still show the desired number per cluster.
"""
from typing import Dict, List
import numpy as np
from shared.utils.logging_config import get_logger
logger = get_logger(__name__)


def find_cluster_representatives(
    embeddings: np.ndarray,
    labels,
    n_per_cluster: int = 3,
    oversample: int = 4,
) -> Dict[object, List[int]]:
    """Rank each cluster's members by closeness to the cluster centroid.

    The centroid of a cluster is the mean of its members' embeddings, and
    closeness is measured with squared Euclidean distance. Members that are
    exactly equidistant from the centroid (e.g. duplicate embeddings) are
    ordered by ascending index so the result is reproducible.

    Args:
        embeddings: (N, D) array of embeddings (row i aligns with label i).
        labels: array-like of length N with cluster labels (int or str).
            The lengths of `labels` and `embeddings` are not validated here.
        n_per_cluster: how many representatives the caller intends to show.
            Negative values are treated as 0.
        oversample: multiplier for how many candidate indices to return per
            cluster (n_per_cluster * oversample), so failed image loads can be
            skipped while still surfacing n_per_cluster images. Values below 1
            are treated as 1 (no oversampling).

    Returns:
        Dict mapping each cluster label to a list of global indices into
        `embeddings`, ordered closest-to-centroid first, capped at
        n_per_cluster * oversample (or the cluster size, whichever is smaller).
        Label keys are converted to native Python scalars where possible.
    """
    labels = np.asarray(labels)
    embeddings = np.asarray(embeddings)

    # Clamp both factors before multiplying: a negative count would otherwise
    # turn the slice below into "drop items from the end", and two negative
    # arguments would multiply into a positive count.
    n_wanted = max(n_per_cluster, 0)
    n_candidates = n_wanted * max(oversample, 1)

    representatives: Dict[object, List[int]] = {}
    for cluster_id in np.unique(labels):
        member_idxs = np.where(labels == cluster_id)[0]
        if member_idxs.size == 0:
            # A label that never compares equal to itself (such as NaN)
            # yields no members; skip it.
            continue

        cluster_embeds = embeddings[member_idxs]
        centroid = cluster_embeds.mean(axis=0)

        # Squared Euclidean distance is enough for ranking; the square root
        # would not change the order.
        dists = np.sum((cluster_embeds - centroid) ** 2, axis=1)

        # A stable sort keeps equidistant members in ascending index order,
        # which makes tie-breaking deterministic.
        order = np.argsort(dists, kind="stable")[:n_candidates]

        # Keep the label's native Python type for clean dict keys / display.
        key = cluster_id.item() if hasattr(cluster_id, "item") else cluster_id
        representatives[key] = member_idxs[order].tolist()

    logger.debug(
        f"Found representatives for {len(representatives)} clusters "
        f"(up to {n_candidates} candidates each)"
    )
    return representatives