GodsDevProject committed on
Commit
96eb5a4
·
verified ·
1 Parent(s): a86976f

Create ingest/cluster.py

Browse files
Files changed (1) hide show
  1. ingest/cluster.py +16 -24
ingest/cluster.py CHANGED
@@ -1,29 +1,21 @@
1
- from sentence_transformers import SentenceTransformer
2
- import faiss, numpy as np
3
- import plotly.graph_objects as go
4
 
5
- model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
 
 
6
 
7
def semantic_cluster_plot(results):
    """Return a Plotly scatter of search results placed by their embeddings.

    Each result's title and snippet are concatenated and encoded with the
    module-level sentence-transformer ``model``.  The first two embedding
    dimensions are used directly as the x/y coordinates; this is a cheap
    stand-in for a real 2-D projection, not a dimensionality reduction.

    Args:
        results: Sequence of mappings, each providing string values under
            the ``"title"`` and ``"snippet"`` keys.

    Returns:
        A ``plotly.graph_objects.Figure``.  It is empty when ``results``
        yields no items; otherwise it holds one marker per result, with the
        result title as hover text.

    Raises:
        KeyError: If a result lacks ``"title"`` or ``"snippet"``.
    """
    texts = [r["title"] + " " + r["snippet"] for r in results]
    if not texts:
        return go.Figure()

    # No nearest-neighbour index is built: the plot only needs the raw
    # embedding coordinates, so indexing the vectors would be wasted work.
    embeddings = model.encode(texts)

    # simple 2D projection (first 2 dims for HF safety)
    x, y = embeddings[:, 0], embeddings[:, 1]

    fig = go.Figure(
        data=go.Scatter(
            x=x,
            y=y,
            mode="markers",
            text=[r["title"] for r in results],
            marker=dict(size=8),
        )
    )
    fig.update_layout(title="Semantic Document Clusters")
    return fig
 
1
+ import faiss
2
+ import numpy as np
3
+ from sklearn.feature_extraction.text import TfidfVectorizer
4
 
5
def semantic_clusters(documents, k=5):
    """Group documents into at most ``k`` clusters of similar wording.

    Documents are turned into TF-IDF vectors (up to 512 features, English
    stop words removed) and partitioned with FAISS k-means.  Each document
    is then labelled with the id of its nearest centroid.

    Args:
        documents: list[str] of texts to cluster.
        k: Requested number of clusters.  It is clamped to the range
            ``1..len(documents)`` because k-means cannot use more
            centroids than there are points.

    Returns:
        list[int] with one cluster id per document, in input order.  Ids
        lie in ``range(min(k, len(documents)))``.  Fewer than two
        documents, or documents with no usable vocabulary, all receive
        cluster id ``0``.
    """
    n_docs = len(documents)
    if n_docs < 2:
        return [0] * n_docs

    vectorizer = TfidfVectorizer(max_features=512, stop_words="english")
    try:
        tfidf = vectorizer.fit_transform(documents)
    except ValueError:
        # scikit-learn raises ValueError when the vocabulary is empty
        # (e.g. every document consists solely of stop words).  There is
        # nothing to separate the documents by, so keep them together.
        return [0] * n_docs

    # FAISS requires C-contiguous float32 input.
    vectors = np.ascontiguousarray(tfidf.toarray(), dtype="float32")
    dim = vectors.shape[1]

    n_clusters = max(1, min(int(k), n_docs))

    # Searching the documents against an index of themselves would only
    # return each document's own row number, so train real centroids and
    # assign every document to the closest one instead.
    kmeans = faiss.Kmeans(
        dim,
        n_clusters,
        seed=1234,  # fixed seed keeps labels reproducible across calls
        min_points_per_centroid=1,  # small corpora are expected here
    )
    kmeans.train(vectors)

    _, labels = kmeans.index.search(vectors, 1)
    return labels.ravel().tolist()