WillyCodesInit commited on
Commit
777d446
·
verified ·
1 Parent(s): ca9e3c9

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +29 -14
utils.py CHANGED
@@ -1,18 +1,33 @@
1
  import pandas as pd
2
- from sentence_transformers import SentenceTransformer
3
 
4
- # Load embedding model
5
- embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
6
-
7
- # Load CSV data
8
- def load_qa_data(file_path="qa_data.csv"):
9
  df = pd.read_csv(file_path)
10
- return list(zip(df["question"], df["answer"]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Basic semantic check (optional if you want to embed later)
13
- def is_financial_text(text: str) -> bool:
14
- finance_keywords = [
15
- "finance", "investment", "bank", "insurance", "credit", "budget", "economy",
16
- "inflation", "debt", "interest", "mortgage", "pension", "retirement", "savings"
17
- ]
18
- return any(keyword in text.lower() for keyword in finance_keywords)
 
1
  import pandas as pd
2
+ from sentence_transformers import SentenceTransformer, util
3
 
4
+ # Load and embed Q&A data
5
+ def load_qa_data(file_path="train_data.csv"):
 
 
 
6
  df = pd.read_csv(file_path)
7
+ questions = df["question"].tolist()
8
+ answers = df["answer"].tolist()
9
+ return questions, answers
10
+
11
+ # Set up embeddings
12
+ model = SentenceTransformer("all-MiniLM-L6-v2")
13
+ questions, answers = load_qa_data()
14
+ question_embeddings = model.encode(questions, convert_to_tensor=True)
15
+
16
+ # Search top-k context
17
+ def retrieve_context(query, top_k=3):
18
+ query_embedding = model.encode(query, convert_to_tensor=True)
19
+ hits = util.semantic_search(query_embedding, question_embeddings, top_k=top_k)[0]
20
+
21
+ context = ""
22
+ for hit in hits:
23
+ idx = hit["corpus_id"]
24
+ context += f"Q: {questions[idx]}\nA: {answers[idx]}\n\n"
25
+
26
+ return context.strip()
27
 
28
+ # Use RAG to check if it's a finance question
29
+ def is_financial_text_via_context(query, threshold=0.3):
30
+ query_embedding = model.encode(query, convert_to_tensor=True)
31
+ similarities = util.cos_sim(query_embedding, question_embeddings)[0]
32
+ max_sim = float(similarities.max())
33
+ return max_sim > threshold