WillyCodesInit committed on
Commit
89da9cc
·
verified ·
1 Parent(s): b02256f

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +73 -22
utils.py CHANGED
@@ -1,32 +1,83 @@
 
1
  import pandas as pd
2
  import numpy as np
3
- import faiss
4
  from sentence_transformers import SentenceTransformer
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
6
 
7
- # Load your CSV with 'question' and 'answer' columns
8
- df = pd.read_csv("train_data.csv")
9
- qa_pairs = df["question"] + " | " + df["answer"]
 
10
 
11
- # Sentence Transformer for embeddings
12
- embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
13
- embeddings = embedding_model.encode(qa_pairs.tolist(), convert_to_numpy=True)
14
 
15
- # FAISS index
16
- dimension = embeddings.shape[1]
17
- index = faiss.IndexFlatL2(dimension)
18
- index.add(embeddings)
 
 
 
 
 
19
 
20
- # FLAN-T5
21
- tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
22
- model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
 
 
 
 
 
 
 
23
 
24
- def ask_finance_bot(user_query, top_k=3):
 
 
 
 
 
 
 
 
 
 
25
  query_embedding = embedding_model.encode([user_query])
26
- D, I = index.search(np.array(query_embedding), top_k)
27
- context = "\n".join([qa_pairs[i] for i in I[0]])
28
 
29
- prompt = f"Context:\n{context}\n\nQuestion: {user_query}\nAnswer:"
30
- inputs = tokenizer(prompt, return_tensors="pt")
31
- outputs = model.generate(**inputs, max_new_tokens=2045)
32
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
7
 
8
+ # Initialize model and tokenizer
9
+ model_name = "google/flan-t5-base" # You can use a different model if needed
10
+ model = AutoModelForCausalLM.from_pretrained(model_name)
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
 
13
+ # Sentence transformer model to encode questions for similarity
14
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
15
 
16
+ # Load question-answer data from CSV
17
+ def load_qa_data_from_csv(file_path):
18
+ """
19
+ Reads a CSV file containing question-answer pairs.
20
+ Assumes the CSV file has columns 'question' and 'answer'.
21
+ """
22
+ data = pd.read_csv(file_path)
23
+ qa_pairs = list(zip(data['question'], data['answer']))
24
+ return qa_pairs
25
 
26
+ # Load question-answer data from JSON
27
+ def load_qa_data_from_json(file_path):
28
+ """
29
+ Reads a JSON file containing question-answer pairs.
30
+ """
31
+ with open(file_path, 'r') as file:
32
+ data = json.load(file)
33
+
34
+ qa_pairs = [(item['question'], item['answer']) for item in data]
35
+ return qa_pairs
36
 
37
+ # Check if the question is related to finance
38
+ def is_valid_finance_question(question):
39
+ # Here you can refine the check to use model verification as well
40
+ # For now, we are doing a simple check based on keywords
41
+ finance_keywords = ['finance', 'investment', 'bank', 'insurance', 'credit', 'budget', 'economy', 'inflation',
42
+ 'debt', 'interest', 'mortgage', 'pension', 'retirement', 'savings']
43
+ return any(keyword in question.lower() for keyword in finance_keywords)
44
+
45
+ # Generate the response for a valid financial question
46
+ def ask_finance_bot(user_query, qa_pairs):
47
+ # Embed the user query
48
  query_embedding = embedding_model.encode([user_query])
 
 
49
 
50
+ # Assuming 'index' here is a pre-built FAISS index or similar structure
51
+ # For this example, using a basic search from qa_pairs
52
+ retrieved_qa_pairs = qa_pairs[:3] # Take top 3 for now, or improve with vector search
53
+
54
+ # Temperature control to avoid repetition if same question is asked frequently
55
+ temperature = 0.7
56
+
57
+ instruction = (
58
+ "You are a highly knowledgeable AI assistant specializing strictly in finance.\n"
59
+ "Strictly answer only financially related topics.\n"
60
+ "Do not answer anything outside finance.\n"
61
+ "Always provide accurate, objective, and concise answers to financial questions.\n"
62
+ )
63
+
64
+ # Create the prompt for the model
65
+ prompt = f"{instruction}\n\nUser query: {user_query}\nAnswer:"
66
+
67
+ input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
68
+ output_ids = model.generate(
69
+ **input_ids,
70
+ max_new_tokens=256,
71
+ temperature=temperature,
72
+ top_p=0.9,
73
+ pad_token_id=tokenizer.eos_token_id
74
+ )
75
+
76
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
77
+ answer_text = response.split("Answer:")[-1].strip()
78
+
79
+ if is_valid_finance_question(answer_text):
80
+ return answer_text
81
+ else:
82
+ return "I'm specialized in finance and can't help with that."
83
+