# finsmart / utils.py
# WillyCodesInit's picture
# Update utils.py
# 777d446 verified
import pandas as pd
from sentence_transformers import SentenceTransformer, util
# Load Q&A data (the questions are embedded once at module level, below)
def load_qa_data(file_path="train_data.csv"):
    """Load question/answer pairs from a CSV file.

    The CSV must contain ``question`` and ``answer`` columns. Rows in which
    either field is missing are dropped as a pair, so the two returned lists
    stay index-aligned, and the remaining values are converted to ``str``.

    Args:
        file_path: Path to the CSV file to read.

    Returns:
        A ``(questions, answers)`` tuple of equal-length lists of strings,
        where ``questions[i]`` and ``answers[i]`` come from the same row.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        KeyError: If the ``question`` or ``answer`` column is absent.
    """
    df = pd.read_csv(file_path)
    # pandas represents empty cells as float NaN, which is not valid text for
    # embedding or for formatting into a context string. Drop incomplete rows
    # as whole pairs rather than per column so indices keep matching.
    complete_rows = df.dropna(subset=["question", "answer"])
    questions = complete_rows["question"].astype(str).tolist()
    answers = complete_rows["answer"].astype(str).tolist()
    return questions, answers
# Set up embeddings.
# NOTE: these statements run at import time. Importing this module loads the
# "all-MiniLM-L6-v2" sentence-transformers model, reads the Q&A CSV from
# load_qa_data()'s default path ("train_data.csv", relative to the working
# directory), and encodes every question once so the functions below can
# reuse the result instead of re-encoding the corpus per query.
model = SentenceTransformer("all-MiniLM-L6-v2")
# Parallel lists: questions[i] and answers[i] come from the same CSV row.
questions, answers = load_qa_data()
# Corpus embeddings, requested as a tensor (convert_to_tensor=True); row order
# follows `questions`, which is what lets a corpus index map back to a Q&A pair.
question_embeddings = model.encode(questions, convert_to_tensor=True)
# Search top-k context
def retrieve_context(query, top_k=3):
    """Return the stored Q&A pairs most similar to *query* as one text block.

    The query is encoded with the module-level ``model`` and compared against
    the pre-computed ``question_embeddings``. Each selected pair is rendered
    as ``"Q: <question>\\nA: <answer>"``, and pairs are separated by a blank
    line.

    Args:
        query: The user text to search with.
        top_k: Maximum number of Q&A pairs to include.

    Returns:
        The formatted pairs as a single string with leading and trailing
        whitespace stripped.
    """
    encoded_query = model.encode(query, convert_to_tensor=True)
    # semantic_search yields one result list per query; only one query is
    # passed in, so the first (and only) list is the one we want.
    matches = util.semantic_search(encoded_query, question_embeddings, top_k=top_k)[0]
    # Each hit's "corpus_id" indexes into the parallel questions/answers lists.
    hit_indices = (match["corpus_id"] for match in matches)
    rendered = "".join(
        f"Q: {questions[i]}\nA: {answers[i]}\n\n" for i in hit_indices
    )
    return rendered.strip()
# Check whether a query looks finance-related by its embedding similarity to the known Q&A questions
def is_financial_text_via_context(query, threshold=0.3):
    """Decide whether *query* is close enough to any stored question.

    The query is encoded with the module-level ``model`` and its cosine
    similarity to every entry in ``question_embeddings`` is computed. The
    query counts as a match when the best similarity is strictly greater
    than ``threshold``.

    Args:
        query: The user text to classify.
        threshold: Similarity the best match must exceed.

    Returns:
        ``True`` if the highest cosine similarity exceeds ``threshold``,
        otherwise ``False``.
    """
    encoded_query = model.encode(query, convert_to_tensor=True)
    # cos_sim returns a (queries x corpus) matrix; row 0 belongs to our
    # single query.
    scores = util.cos_sim(encoded_query, question_embeddings)[0]
    best_score = float(scores.max())
    return best_score > threshold