"""Generate clustered prototype databases from embeddings.""" from __future__ import annotations import argparse from pathlib import Path from typing import Iterable, Optional import faiss import numpy as np import torch class GPUKMeansClusterer: def __init__(self, dim: int, n_clusters: int = 500, n_iter: int = 20, n_gpu: int = 1): self.clus = faiss.Clustering(dim, n_clusters) self.clus.verbose = True self.clus.niter = n_iter self.dim = dim self.n_clusters = n_clusters self.clus.update_index = True res = [faiss.StandardGpuResources() for _ in range(n_gpu)] flat_config = [] for i in range(n_gpu): cfg = faiss.GpuIndexFlatConfig() cfg.useFloat16 = False cfg.device = i flat_config.append(cfg) if n_gpu == 1: self.index = faiss.GpuIndexFlatL2(res[0], self.dim, flat_config[0]) else: indexes = [faiss.GpuIndexFlatL2(res[i], self.dim, flat_config[i]) for i in range(n_gpu)] self.index = faiss.IndexReplicas() for sub_index in indexes: self.index.addIndex(sub_index) def fit(self, embeddings_np: np.ndarray) -> np.ndarray: self.index.reset() self.clus.train(embeddings_np, self.index) centroids = faiss.vector_float_to_array(self.clus.centroids) centroids = centroids.reshape(self.n_clusters, self.dim) return centroids def gen_data(dict_data): embeddings = dict_data["embeddings"] labels = dict_data["labels"] ids = dict_data["ids"] classes = dict_data["classes"] return embeddings, labels, ids, classes def build_argument_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Cluster embeddings into prototype databases using GPU K-Means.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("--database", type=Path, required=True, help="Input embedding database (.pt).") parser.add_argument("--output", type=Path, required=True, help="Output path for the clustered database.") parser.add_argument("--clusters", type=int, default=10000) parser.add_argument("--dimension", type=int, default=1024) parser.add_argument("--iterations", type=int, default=100) parser.add_argument("--gpus", type=int, default=1) parser.add_argument("--human-class-name", type=str, default="human", help="Label representing humans in the class list.") return parser def cluster_database(args: argparse.Namespace) -> None: data_emb, data_labels, data_ids, data_classes = gen_data(torch.load(args.database)) human_idx = data_classes.index(args.human_class_name) datapos = (data_labels == human_idx).long() pos2cnt = {0: args.clusters, 1: args.clusters} pos2name = {0: ["llm"], 1: ["human"]} datapos_np = datapos.cpu().numpy() kmeans = GPUKMeansClusterer(args.dimension, n_clusters=args.clusters, n_iter=args.iterations, n_gpu=args.gpus) all_centers = {} save_labels = None for key in data_emb: now_emb = data_emb[key].float().cpu().numpy() all_center = [] all_labels = [] for pos in pos2cnt: pos_emb = now_emb[datapos_np == pos] pos_center = kmeans.fit(pos_emb) all_center.append(pos_center) all_labels.append(np.full((pos_center.shape[0],), pos)) all_center = np.concatenate(all_center, axis=0) all_labels = np.concatenate(all_labels, axis=0) all_center = torch.from_numpy(all_center).to(dtype=torch.bfloat16) all_labels = torch.from_numpy(all_labels).to(dtype=torch.long) all_centers[key] = all_center save_labels = all_labels save_ids = torch.arange(save_labels.shape[0], dtype=torch.long) classes = [None] * len(pos2name.keys()) for pos in pos2name: classes[pos] = ','.join(pos2name[pos]) emb_dict = {"embeddings": all_centers, "labels": save_labels, "ids": save_ids, "classes": classes} 
    args.output.parent.mkdir(parents=True, exist_ok=True)
    torch.save(emb_dict, args.output)
    print(f"All centers saved to: {args.output}")


def main(argv: Optional[Iterable[str]] = None) -> None:
    parser = build_argument_parser()
    args = parser.parse_args(argv)
    cluster_database(args)


__all__ = ["build_argument_parser", "cluster_database", "main"]


if __name__ == "__main__":
    main()
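# ---------------------------------------------------------------------------
# Illustrative only (not part of the original script): a minimal sketch of how
# the saved prototype database could be consumed. The "layer_24" key, the file
# name, and the query tensor are hypothetical; actual keys depend on the input
# embedding database.
#
#     db = torch.load("prototypes.pt")
#     protos = db["embeddings"]["layer_24"].float()       # (2 * clusters, dim)
#     labels = db["labels"]                               # 0 = "llm", 1 = "human"
#     query = torch.randn(4, protos.shape[1])             # a batch of query embeddings
#     nearest = torch.cdist(query, protos).argmin(dim=1)  # nearest prototype per row
#     predictions = labels[nearest]                       # predicted class indices
# ---------------------------------------------------------------------------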