stellar-search / clustering_service.py
X1ng1's picture
fixed imports
95a501f
"""
Clustering service using hierarchical clustering
"""
import numpy as np
from typing import List, Dict, Tuple, Optional
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import pdist, squareform
import logging
from config import config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ClusteringService:
"""Service for hierarchical clustering of message embeddings"""
def __init__(
self,
distance_threshold: Optional[float] = None,
min_cluster_size: Optional[int] = None,
random_seed: Optional[int] = None
):
"""
Initialize clustering service
Args:
distance_threshold: Distance threshold for hierarchical clustering
min_cluster_size: Minimum number of messages per cluster
random_seed: Random seed for reproducibility
"""
self.distance_threshold = distance_threshold or config.DISTANCE_THRESHOLD
self.min_cluster_size = min_cluster_size or config.MIN_CLUSTER_SIZE
self.random_seed = random_seed or config.RANDOM_SEED
# Set random seed for reproducibility
np.random.seed(self.random_seed)
logger.info(f"Clustering service initialized with distance_threshold={self.distance_threshold}, "
f"min_cluster_size={self.min_cluster_size}")
def cluster_messages(
self,
embeddings: np.ndarray,
message_ids: List[str]
) -> Tuple[List[int], Dict[int, List[int]], np.ndarray]:
"""
Perform hierarchical clustering on embeddings
Args:
embeddings: Matrix of embeddings (n_messages, embedding_dim)
message_ids: List of message IDs corresponding to embeddings
Returns:
Tuple of (cluster_labels, cluster_to_messages, linkage_matrix)
"""
n_messages = len(embeddings)
if n_messages < 2:
logger.warning("Too few messages for clustering")
return [0] * n_messages, {0: list(range(n_messages))}, None
logger.info(f"Clustering {n_messages} messages")
# Compute distance matrix (using cosine distance)
# Since embeddings are normalized, cosine distance = 1 - dot product
distances = 1 - np.dot(embeddings, embeddings.T)
np.fill_diagonal(distances, 0) # Ensure diagonal is 0
# Convert to condensed distance matrix for scipy
condensed_distances = squareform(distances, checks=False)
# Perform hierarchical clustering using ward linkage
# Ward minimizes variance within clusters - good for text clustering
linkage_matrix = linkage(
condensed_distances,
method='ward',
optimal_ordering=False # Disable for determinism
)
# Form flat clusters using distance threshold
cluster_labels = fcluster(
linkage_matrix,
t=self.distance_threshold,
criterion='distance'
)
# Convert to 0-indexed
cluster_labels = cluster_labels - 1
# Group messages by cluster
cluster_to_messages = {}
for idx, cluster_id in enumerate(cluster_labels):
if cluster_id not in cluster_to_messages:
cluster_to_messages[cluster_id] = []
cluster_to_messages[cluster_id].append(idx)
# Filter small clusters
cluster_labels, cluster_to_messages = self._filter_small_clusters(
cluster_labels,
cluster_to_messages,
n_messages
)
logger.info(f"Created {len(cluster_to_messages)} clusters")
return cluster_labels.tolist(), cluster_to_messages, linkage_matrix
def _filter_small_clusters(
self,
cluster_labels: np.ndarray,
cluster_to_messages: Dict[int, List[int]],
n_messages: int
) -> Tuple[np.ndarray, Dict[int, List[int]]]:
"""
Filter out clusters that are too small
Args:
cluster_labels: Array of cluster labels
cluster_to_messages: Mapping of cluster IDs to message indices
n_messages: Total number of messages
Returns:
Filtered cluster_labels and cluster_to_messages
"""
# Find clusters that meet minimum size
valid_clusters = {
cluster_id: messages
for cluster_id, messages in cluster_to_messages.items()
if len(messages) >= self.min_cluster_size
}
if not valid_clusters:
# If no valid clusters, create one cluster with all messages
logger.warning("No clusters meet minimum size requirement, creating single cluster")
return np.zeros(n_messages, dtype=int), {0: list(range(n_messages))}
# Create a mapping from old cluster IDs to new ones
old_to_new = {old_id: new_id for new_id, old_id in enumerate(sorted(valid_clusters.keys()))}
# Reassign labels
new_labels = np.full(n_messages, -1, dtype=int)
new_cluster_to_messages = {}
for old_id, new_id in old_to_new.items():
messages = cluster_to_messages[old_id]
new_cluster_to_messages[new_id] = messages
for msg_idx in messages:
new_labels[msg_idx] = new_id
# Assign unclustered messages (from small clusters) to nearest valid cluster
unclustered_mask = new_labels == -1
if np.any(unclustered_mask):
logger.info(f"Reassigning {np.sum(unclustered_mask)} messages from small clusters")
# For simplicity, assign to cluster 0 (could be improved with nearest neighbor)
new_labels[unclustered_mask] = 0
unclustered_indices = np.where(unclustered_mask)[0].tolist()
new_cluster_to_messages[0].extend(unclustered_indices)
return new_labels, new_cluster_to_messages
def compute_cluster_centroids(
self,
embeddings: np.ndarray,
cluster_to_messages: Dict[int, List[int]]
) -> Dict[int, np.ndarray]:
"""
Compute centroid for each cluster
Args:
embeddings: Matrix of embeddings
cluster_to_messages: Mapping of cluster IDs to message indices
Returns:
Dictionary mapping cluster IDs to centroid vectors
"""
centroids = {}
for cluster_id, message_indices in cluster_to_messages.items():
cluster_embeddings = embeddings[message_indices]
centroid = np.mean(cluster_embeddings, axis=0)
# Normalize centroid
centroid = centroid / np.linalg.norm(centroid)
centroids[cluster_id] = centroid
return centroids
def build_hierarchy(
self,
linkage_matrix: np.ndarray,
cluster_to_messages: Dict[int, List[int]],
max_depth: int = 3
) -> Dict[int, Optional[int]]:
"""
Build hierarchical structure from linkage matrix
Args:
linkage_matrix: Linkage matrix from hierarchical clustering
cluster_to_messages: Mapping of cluster IDs to message indices
max_depth: Maximum depth of hierarchy
Returns:
Dictionary mapping cluster IDs to parent cluster IDs
"""
# For now, return flat hierarchy (no parents)
# Can be extended to build multi-level hierarchy
return {cluster_id: None for cluster_id in cluster_to_messages.keys()}
def get_representative_messages(
self,
embeddings: np.ndarray,
cluster_messages: List[int],
centroid: np.ndarray,
top_k: int = 5
) -> List[int]:
"""
Get the most representative messages for a cluster
Args:
embeddings: Matrix of all embeddings
cluster_messages: Indices of messages in this cluster
centroid: Cluster centroid
top_k: Number of representative messages to return
Returns:
List of message indices closest to centroid
"""
if len(cluster_messages) <= top_k:
return cluster_messages
cluster_embeddings = embeddings[cluster_messages]
similarities = np.dot(cluster_embeddings, centroid)
top_k_local = np.argsort(similarities)[-top_k:][::-1]
return [cluster_messages[i] for i in top_k_local]
# Global instance
_clustering_service: Optional[ClusteringService] = None
def get_clustering_service() -> ClusteringService:
"""Get or create the global clustering service instance"""
global _clustering_service
if _clustering_service is None:
_clustering_service = ClusteringService()
return _clustering_service