| """Find representative members of clusters. |
| |
| Given embeddings and cluster labels, rank each cluster's members by proximity |
| to the cluster centroid. Returns more candidates than strictly requested |
| (oversampled) so callers that render images can skip candidates whose image |
| fails to load and still show the desired number per cluster. |
| """ |
|
|
| from typing import Dict, List |
|
|
| import numpy as np |
|
|
| from shared.utils.logging_config import get_logger |
|
|
# Module-level logger obtained from the project's shared logging helper.
logger = get_logger(__name__)
|
|
|
|
def find_cluster_representatives(
    embeddings: np.ndarray,
    labels,
    n_per_cluster: int = 3,
    oversample: int = 4,
) -> Dict[object, List[int]]:
    """Rank each cluster's members by closeness to the cluster centroid.

    Closeness is the squared Euclidean distance to the mean embedding of the
    cluster. Ties are broken by original row order (stable sort), so the
    ranking is deterministic even when several members are equidistant
    (e.g. duplicate embeddings).

    Args:
        embeddings: (N, D) array of embeddings (row i aligns with label i).
        labels: array-like of length N with cluster labels (int or str).
        n_per_cluster: how many representatives the caller intends to show.
            Values <= 0 yield an empty candidate list for every cluster.
        oversample: multiplier for how many candidate indices to return per
            cluster (n_per_cluster * oversample), so failed image loads can be
            skipped while still surfacing n_per_cluster images. Values below 1
            fall back to returning n_per_cluster candidates.

    Returns:
        Dict mapping each cluster label (converted to a native Python scalar
        when possible) to a list of global row indices into ``embeddings``,
        ordered closest-to-centroid first and capped at
        n_per_cluster * oversample (or the cluster size, whichever is smaller).

    Raises:
        ValueError: if ``labels`` and ``embeddings`` do not have the same
            number of entries.
    """
    labels = np.asarray(labels)
    embeddings = np.asarray(embeddings)
    if len(labels) != len(embeddings):
        raise ValueError(
            f"labels has {len(labels)} entries but embeddings has "
            f"{len(embeddings)} rows; they must align one-to-one"
        )

    # Clamp at 0: a negative slice bound below would silently drop the
    # *farthest* members instead of returning no candidates.
    n_candidates = max(n_per_cluster * oversample, n_per_cluster, 0)

    representatives: Dict[object, List[int]] = {}
    for cluster_id in np.unique(labels):
        member_idxs = np.where(labels == cluster_id)[0]
        # Can be empty for a label that never compares equal to itself
        # (e.g. NaN in a float label array); there is nothing to rank then.
        if member_idxs.size == 0:
            continue
        cluster_embeds = embeddings[member_idxs]
        centroid = cluster_embeds.mean(axis=0)

        # Squared distance gives the same ordering as Euclidean distance
        # without the sqrt. A stable sort keeps ties in row order.
        dists = np.sum((cluster_embeds - centroid) ** 2, axis=1)
        order = np.argsort(dists, kind="stable")[:n_candidates]

        # np.unique yields NumPy scalars; convert to native Python values so
        # callers can look clusters up with plain ints/strs.
        key = cluster_id.item() if hasattr(cluster_id, "item") else cluster_id
        representatives[key] = member_idxs[order].tolist()

    logger.debug(
        "Found representatives for %d clusters (up to %d candidates each)",
        len(representatives),
        n_candidates,
    )
    return representatives
|
|