balodhi committed on
Commit
3bb3644
·
1 Parent(s): 6718a33

Fix embeddings: add embed_texts helper, build doc_embeddings after chunking, and use embed_texts in retrieve

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -10,8 +10,10 @@ import numpy as np
10
 
11
  embedding_client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2")
12
 
13
- def embed(texts):
14
- return embedding_client.feature_extraction(texts)
 
 
15
 
16
 
17
 
@@ -22,8 +24,6 @@ def embed(texts):
22
  with open("gita.txt", "r", encoding="utf-8") as f:
23
  raw_text = f.read()
24
 
25
- doc_embeddings = np.array(embed(documents))
26
-
27
  def chunk_text(text, chunk_size=500, overlap=50):
28
  chunks = []
29
  start = 0
@@ -33,6 +33,9 @@ def chunk_text(text, chunk_size=500, overlap=50):
33
  start += chunk_size - overlap
34
  return chunks
35
 
 
 
 
36
  # Embedding model (small + free)
37
  #embedder = SentenceTransformer("all-MiniLM-L6-v2")
38
  #doc_embeddings = #embedder.encode(documents)
@@ -40,7 +43,7 @@ def chunk_text(text, chunk_size=500, overlap=50):
40
  #doc_embeddings = embedder.encode(documents)
41
 
42
  def retrieve(query, top_k=4):
43
- query_embedding = embedder.encode([query])[0]
44
  scores = np.dot(doc_embeddings, query_embedding)
45
  top_indices = np.argsort(scores)[-top_k:][::-1]
46
  results = [documents[i] for i in top_indices]
 
10
 
11
  embedding_client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2")
12
 
13
+ def embed_texts(texts):
14
+ if isinstance(texts, str):
15
+ texts = [texts]
16
+ return np.array(embedding_client.feature_extraction(texts))
17
 
18
 
19
 
 
24
  with open("gita.txt", "r", encoding="utf-8") as f:
25
  raw_text = f.read()
26
 
 
 
27
  def chunk_text(text, chunk_size=500, overlap=50):
28
  chunks = []
29
  start = 0
 
33
  start += chunk_size - overlap
34
  return chunks
35
 
36
+ documents = chunk_text(raw_text)
37
+ doc_embeddings = embed_texts(documents)
38
+
39
  # Embedding model (small + free)
40
  #embedder = SentenceTransformer("all-MiniLM-L6-v2")
41
  #doc_embeddings = #embedder.encode(documents)
 
43
  #doc_embeddings = embedder.encode(documents)
44
 
45
  def retrieve(query, top_k=4):
46
+ query_embedding = embed_texts(query)[0]
47
  scores = np.dot(doc_embeddings, query_embedding)
48
  top_indices = np.argsort(scores)[-top_k:][::-1]
49
  results = [documents[i] for i in top_indices]