File size: 1,157 Bytes
e0f19f0
777d446
e0f19f0
777d446
 
e0f19f0
777d446
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0f19f0
777d446
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import pandas as pd
from sentence_transformers import SentenceTransformer, util

# Load Q&A pairs from the CSV file
def load_qa_data(file_path="train_data.csv"):
    """Load parallel question/answer lists from a CSV file.

    The CSV must contain ``question`` and ``answer`` columns. Rows where
    either field is missing are skipped so that the two returned lists stay
    index-aligned and contain only strings (a blank cell would otherwise be
    returned as a float ``NaN``).

    Args:
        file_path: Path to the CSV file to read.

    Returns:
        A ``(questions, answers)`` tuple of equally long lists of ``str``.

    Raises:
        KeyError: If the ``question`` or ``answer`` column is absent.
    """
    df = pd.read_csv(file_path)
    # Drop incomplete pairs in one step so both columns lose the same rows.
    complete = df.dropna(subset=["question", "answer"])
    # Cast to str so purely numeric cells do not come back as int/float.
    questions = complete["question"].astype(str).tolist()
    answers = complete["answer"].astype(str).tolist()
    return questions, answers

# Set up embeddings
# Module-level setup: these statements run once at import time so that the
# functions below can reuse the model and the precomputed embeddings.
model = SentenceTransformer("all-MiniLM-L6-v2")
# Uses the default path "train_data.csv"; both lists come from the same
# DataFrame, so questions[i] pairs with answers[i].
questions, answers = load_qa_data()
# Encode every stored question up front; later lookups only need to encode
# the incoming query.
question_embeddings = model.encode(questions, convert_to_tensor=True)

# Search top-k context
def retrieve_context(query, top_k=3):
    """Build a context string from the stored Q&A pairs closest to *query*.

    The query is embedded with the module-level model and compared against
    the precomputed question embeddings via ``util.semantic_search``.

    Args:
        query: Text to look up.
        top_k: Maximum number of Q&A pairs to include.

    Returns:
        The matched pairs formatted as ``Q: ...`` / ``A: ...`` entries
        separated by blank lines, with surrounding whitespace stripped.
    """
    encoded_query = model.encode(query, convert_to_tensor=True)
    ranked = util.semantic_search(encoded_query, question_embeddings, top_k=top_k)

    # Only one query is searched, so its hits are the first result list.
    corpus_ids = [match["corpus_id"] for match in ranked[0]]
    entries = [f"Q: {questions[idx]}\nA: {answers[idx]}\n\n" for idx in corpus_ids]

    return "".join(entries).strip()

# Use RAG to check if it's a finance question
def is_financial_text_via_context(query, threshold=0.3):
    """Report whether *query* is close enough to any stored question.

    Args:
        query: Text to classify.
        threshold: Cosine-similarity value the best match must exceed.

    Returns:
        ``True`` when the highest cosine similarity between the query and
        the stored question embeddings is strictly greater than
        *threshold*, otherwise ``False``.
    """
    encoded_query = model.encode(query, convert_to_tensor=True)
    # A single query is compared, so row 0 holds its score for every question.
    scores = util.cos_sim(encoded_query, question_embeddings)[0]
    best_score = float(scores.max())
    return best_score > threshold