balodhi committed on
Commit
c3ac79c
·
1 Parent(s): 9e8b3eb

added gita.txt, removed faiss

Browse files
Files changed (2) hide show
  1. app.py +19 -7
  2. gita.txt +0 -0
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  from sentence_transformers import SentenceTransformer
4
- import faiss
5
  import numpy as np
6
 
 
 
7
  # =========================
8
  # Load and Prepare Gita Text
9
  # =========================
@@ -28,17 +30,27 @@ embedder = SentenceTransformer("all-MiniLM-L6-v2")
28
  doc_embeddings = embedder.encode(documents)
29
  dimension = doc_embeddings.shape[1]
30
 
31
- index = faiss.IndexFlatL2(dimension)
32
- index.add(np.array(doc_embeddings))
33
-
34
 
35
  def retrieve(query, top_k=4):
36
- query_embedding = embedder.encode([query])
37
- distances, indices = index.search(np.array(query_embedding), top_k)
38
- results = [documents[i] for i in indices[0]]
 
39
  return "\n\n".join(results)
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
42
  # =========================
43
  # RAG Chat Function
44
  # =========================
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  from sentence_transformers import SentenceTransformer
4
+ #import faiss
5
  import numpy as np
6
 
7
+
8
+
9
  # =========================
10
  # Load and Prepare Gita Text
11
  # =========================
 
30
  doc_embeddings = embedder.encode(documents)
31
  dimension = doc_embeddings.shape[1]
32
 
33
+ doc_embeddings = embedder.encode(documents)
 
 
34
 
35
  def retrieve(query, top_k=4):
36
+ query_embedding = embedder.encode([query])[0]
37
+ scores = np.dot(doc_embeddings, query_embedding)
38
+ top_indices = np.argsort(scores)[-top_k:][::-1]
39
+ results = [documents[i] for i in top_indices]
40
  return "\n\n".join(results)
41
 
42
 
43
+ # index = faiss.IndexFlatL2(dimension)
44
+ # index.add(np.array(doc_embeddings))
45
+
46
+
47
+ # def retrieve(query, top_k=4):
48
+ # query_embedding = embedder.encode([query])
49
+ # distances, indices = index.search(np.array(query_embedding), top_k)
50
+ # results = [documents[i] for i in indices[0]]
51
+ # return "\n\n".join(results)
52
+
53
+
54
  # =========================
55
  # RAG Chat Function
56
  # =========================
gita.txt ADDED
The diff for this file is too large to render. See raw diff