GodsDevProject commited on
Commit
8aeaa64
·
verified ·
1 Parent(s): bb2e03b

Create semantic/faiss_index.py

Browse files
Files changed (1) hide show
  1. semantic/faiss_index.py +7 -12
semantic/faiss_index.py CHANGED
@@ -1,26 +1,21 @@
1
  import faiss
2
- import numpy as np
3
  from sentence_transformers import SentenceTransformer
4
 
5
- MODEL = SentenceTransformer("all-MiniLM-L6-v2")
6
-
7
  class SemanticIndex:
8
  def __init__(self):
 
9
  self.index = None
10
  self.texts = []
11
 
12
  def build(self, texts):
13
  self.texts = texts
14
- embeddings = MODEL.encode(texts, convert_to_numpy=True)
15
- dim = embeddings.shape[1]
16
- self.index = faiss.IndexFlatL2(dim)
17
- self.index.add(embeddings)
18
 
19
  def search(self, query, k=10):
20
  if not self.index:
21
  return []
22
-
23
- q_emb = MODEL.encode([query], convert_to_numpy=True)
24
- distances, indices = self.index.search(q_emb, k)
25
-
26
- return [self.texts[i] for i in indices[0] if i < len(self.texts)]
 
1
  import faiss
 
2
  from sentence_transformers import SentenceTransformer
3
 
 
 
4
  class SemanticIndex:
5
  def __init__(self):
6
+ self.model = SentenceTransformer("all-MiniLM-L6-v2")
7
  self.index = None
8
  self.texts = []
9
 
10
  def build(self, texts):
11
  self.texts = texts
12
+ emb = self.model.encode(texts)
13
+ self.index = faiss.IndexFlatL2(len(emb[0]))
14
+ self.index.add(emb)
 
15
 
16
  def search(self, query, k=10):
17
  if not self.index:
18
  return []
19
+ q = self.model.encode([query])
20
+ _, idx = self.index.search(q, k)
21
+ return [self.texts[i] for i in idx[0]]