WillyCodesInit committed on
Commit
f143c26
·
verified ·
1 Parent(s): 308e42a

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +33 -20
utils.py CHANGED
@@ -1,25 +1,38 @@
1
- import pandas as pd
2
- from sentence_transformers import SentenceTransformer
3
- import faiss
4
  import numpy as np
 
 
 
5
 
6
- def load_dataset(path):
7
- df = pd.read_csv(path)
8
- df = df.dropna(subset=['question', 'answer'])
 
 
 
 
9
  return df
10
 
11
- def embed_questions(df):
12
- embed_model = SentenceTransformer("all-MiniLM-L6-v2")
13
- embeddings = embed_model.encode(df["question"].tolist(), convert_to_tensor=False)
14
- index = faiss.IndexFlatL2(embeddings[0].shape[0])
15
- index.add(np.array(embeddings))
16
- return embed_model, index
 
 
 
 
 
 
 
 
 
17
 
18
- def retrieve_context(query, embed_model, index, df, k=3):
19
- query_embedding = embed_model.encode([query])[0]
20
- distances, indices = index.search(np.array([query_embedding]), k)
21
- results = []
22
- for i in indices[0]:
23
- if i < len(df):
24
- results.append(f"Q: {df.iloc[i]['question']}\nA: {df.iloc[i]['answer']}")
25
- return "\n\n".join(results)
 
 
 
 
1
  import numpy as np
2
+ import faiss
3
+ import json
4
+ from sentence_transformers import SentenceTransformer
5
 
6
def load_dataset(file_path):
    """
    Load a Q&A dataset from a CSV file.

    Parameters
    ----------
    file_path : str or file-like
        Path (or buffer) for a CSV with at least 'question' and 'answer'
        columns.

    Returns
    -------
    pandas.DataFrame
        The dataset with rows missing a question or an answer removed.
        Note: this returns a DataFrame, not a list of Q&A pairs.
    """
    # Local import keeps pandas out of module import time for callers
    # that never load a CSV (matches the original's local import).
    import pandas as pd

    df = pd.read_csv(file_path)
    # Drop incomplete rows; rebind instead of inplace=True so no frame is
    # mutated in place (inplace mutation is discouraged by pandas).
    df = df.dropna(subset=["question", "answer"])
    return df
14
 
15
def embed_questions(df, model_name='all-MiniLM-L6-v2'):
    """
    Encode every Q&A row as a dense vector and build a FAISS index over them.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'question' and 'answer' columns.
    model_name : str
        SentenceTransformer model identifier to load.

    Returns
    -------
    tuple[list[str], faiss.IndexFlatL2]
        The formatted "Q: ... A: ..." passages and an exact index over
        their embeddings.

    NOTE(review): the SentenceTransformer instance is created here but not
    returned, so callers must construct their own model (with the same
    `model_name`) before calling `retrieve_context` — confirm this is
    intentional.
    """
    encoder = SentenceTransformer(model_name)

    # Fold each question/answer pair into a single passage to embed.
    passages = []
    for question, answer in zip(df["question"], df["answer"]):
        passages.append(f"Q: {question.strip()} A: {answer.strip()}")

    vectors = np.asarray(
        encoder.encode(passages, show_progress_bar=True), dtype="float32"
    )

    # IndexFlatL2 does exact nearest-neighbour search under squared
    # Euclidean (L2) distance — not cosine similarity.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    return passages, index
30
 
31
def retrieve_context(query, embed_model, index, qa_pairs, top_k=3):
    """
    Retrieve the most relevant Q&A passages for a query.

    Parameters
    ----------
    query : str
        The user query to embed and search with.
    embed_model : object
        Anything exposing ``encode(list[str]) -> array-like`` (e.g. a
        SentenceTransformer).
    index : object
        A FAISS-style index exposing ``search(vectors, k) -> (D, I)``.
    qa_pairs : list[str]
        The passages that were indexed, in insertion order.
    top_k : int
        Number of neighbours to request.

    Returns
    -------
    str
        The retrieved passages, one per line, each prefixed with "- ".
    """
    query_embedding = embed_model.encode([query])
    D, I = index.search(np.array(query_embedding).astype("float32"), top_k)
    # FAISS pads the id matrix with -1 when fewer than top_k vectors are
    # indexed; a raw qa_pairs[-1] lookup would silently return the *last*
    # pair, so filter out-of-range ids explicitly.
    retrieved_qa_pairs = [qa_pairs[i] for i in I[0] if 0 <= i < len(qa_pairs)]
    return "\n".join(f"- {pair}" for pair in retrieved_qa_pairs)