| | |
| | import os, asyncio |
| | from huggingface_hub import InferenceClient |
| | from sklearn.cluster import KMeans |
| |
|
| | |
# Access token for the Hugging Face Inference API, read from the environment
# (None when the variable is unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model identifier requested for every embedding call in this module.
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"

# Single module-level client, created once at import time and shared by all
# requests.
client = InferenceClient(token=HF_TOKEN)
| |
|
async def embed_texts(texts: list[str]) -> list[list[float]]:
    """
    Compute embeddings for a list of texts via HF Inference API.

    Each text is sent as its own blocking request on a worker thread
    (``asyncio.to_thread``), so the requests overlap and the event loop is
    not stalled while waiting on the network.

    Args:
        texts: Strings to embed.

    Returns:
        One embedding per input text, in the same order as ``texts``.
        An empty input returns an empty list without issuing any request.

    Raises:
        Any exception raised by the underlying client call propagates from
        the first failing request.
    """
    def _embed(text: str) -> list[float]:
        # InferenceClient exposes embeddings through the feature-extraction
        # task; it has no ``embed`` method. The result is array-like, so it
        # is converted to plain Python lists to match the annotated return.
        # NOTE(review): assumes the model yields one vector per text --
        # confirm for models that return token-level features.
        vector = client.feature_extraction(text, model=EMBED_MODEL)
        return vector.tolist()

    if not texts:
        return []

    pending = [asyncio.to_thread(_embed, text) for text in texts]
    # gather() keeps results in submission order regardless of which request
    # finishes first.
    return await asyncio.gather(*pending)
| |
|
async def cluster_embeddings(embs: list[list[float]], n_clusters: int = 5) -> list[int]:
    """
    Cluster embeddings into n_clusters, return list of cluster labels.

    Args:
        embs: Embedding vectors, one per item to cluster.
        n_clusters: Requested number of clusters. When there are fewer
            embeddings than this, the number of clusters is reduced to the
            number of embeddings so small inputs do not raise.

    Returns:
        One integer cluster label per embedding, in input order. An empty
        input returns an empty list.
    """
    n_samples = len(embs)
    if n_samples == 0:
        # KMeans rejects empty input; there is nothing to label.
        return []

    # KMeans requires n_samples >= n_clusters.
    effective_clusters = min(n_clusters, n_samples)
    kmeans = KMeans(n_clusters=effective_clusters, random_state=0)

    # Fitting is CPU-bound and synchronous; run it off the event loop so
    # other coroutines keep making progress.
    labels = await asyncio.to_thread(kmeans.fit_predict, embs)
    return labels.tolist()
| |
|