WillyCodesInit committed on
Commit
9539702
·
verified ·
1 Parent(s): f0a3492

Create model_utils.py

Browse files
Files changed (1) hide show
  1. src/model_utils.py +83 -0
src/model_utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_utils.py
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ # --- Load LLaMA model ---
6
def load_llama_model(model_id="meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=False):
    """Load a causal language model and its tokenizer from the Hugging Face Hub.

    Args:
        model_id: Hub repository id of the checkpoint to load. The default is
            a gated repository, so the caller must already have been granted
            access to it.
        trust_remote_code: Whether to allow Python code shipped inside the
            model repository to be executed while loading. Disabled by
            default; only enable it for repositories you trust.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
    )
    # Half precision only when a CUDA device is available; otherwise fall back
    # to full precision.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=trust_remote_code,
    )
    return tokenizer, model
15
+
16
+ # --- Helper functions ---
17
def is_finance_question(user_query, tokenizer, model):
    """Ask the model whether ``user_query`` is a finance question.

    Args:
        user_query: The raw question text supplied by the user.
        tokenizer: Callable tokenizer returning a mapping that contains
            ``"input_ids"`` and supports ``.to(device)``; must also provide
            ``decode`` and ``eos_token_id``.
        model: Generation model exposing ``device`` and ``generate``.

    Returns:
        True when the model's reply starts with "yes" (case-insensitive),
        otherwise False.
    """
    check_prompt = (
        "You are a financial expert. Determine whether the following question is clearly about finance:\n\n"
        f"Question: {user_query}\n\n"
        "Respond only with 'Yes' or 'No'."
    )
    inputs = tokenizer(check_prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Greedy decoding gives a deterministic Yes/No verdict. A temperature of
    # 0.0 is not a valid sampling temperature, so sampling is switched off
    # explicitly instead.
    output_ids = model.generate(
        **inputs,
        max_new_tokens=10,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The generated sequence starts with the prompt tokens; decode only the
    # continuation, otherwise the text would always begin with the prompt and
    # could never start with "yes".
    new_tokens = output_ids[0][prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return response.lower().startswith("yes")
33
+
34
# Per-process tally of how many times each normalized question has been asked.
# It lives at module level so the count survives between calls; the dict is
# kept in least-recently-asked-first order and capped to bound memory use.
_RECENT_QUESTION_COUNTS = {}
_MAX_TRACKED_QUESTIONS = 1024


def ask_finance_bot(user_query, answers, embedding_model, index, tokenizer, model, top_k=3):
    """Answer a finance question using retrieved context and a generative model.

    The query is embedded, the ``top_k`` nearest stored answers are looked up
    through ``index`` and placed in the prompt as background context, and the
    model is sampled up to four times until a reply shares at least one word
    with the query.

    Args:
        user_query: The question text supplied by the user.
        answers: Sequence indexed by the integer ids that ``index.search``
            returns.
        embedding_model: Object whose ``encode(list_of_str)`` returns the
            query embedding(s).
        index: Object whose ``search(vectors, k)`` returns a
            ``(distances, ids)`` pair with one row per query vector.
            NOTE(review): presumably a FAISS index -- confirm with the caller.
        tokenizer: Callable tokenizer returning a mapping that contains
            ``"input_ids"`` and supports ``.to(device)``; must also provide
            ``decode`` and ``eos_token_id``.
        model: Generation model exposing ``device`` and ``generate``.
        top_k: Number of neighbours to retrieve.

    Returns:
        The generated answer text, or a fixed fallback sentence when none of
        the attempts mentions any word of the query.
    """
    # Local import: the module header does not import numpy.
    import numpy as np

    normalized_query = user_query.lower().strip()

    # Re-inserting after pop() moves the entry to the end of the dict, so the
    # first key is always the least recently asked question.
    count = _RECENT_QUESTION_COUNTS.pop(normalized_query, 0) + 1
    if len(_RECENT_QUESTION_COUNTS) >= _MAX_TRACKED_QUESTIONS:
        del _RECENT_QUESTION_COUNTS[next(iter(_RECENT_QUESTION_COUNTS))]
    _RECENT_QUESTION_COUNTS[normalized_query] = count

    # float32 is the dtype FAISS-style indexes expect for query vectors.
    query_embedding = np.asarray(embedding_model.encode([user_query]), dtype="float32")
    _, neighbor_ids = index.search(query_embedding, top_k)
    # A negative id marks "no neighbour found" (FAISS uses -1); indexing with
    # it would silently pick the last stored answer, so skip those slots.
    retrieved_answers = [answers[i] for i in neighbor_ids[0] if i >= 0]
    context = "\n".join(f"- {text}" for text in retrieved_answers)

    # Repeated questions are sampled with a higher temperature (capped at 1.0).
    temperature = min(0.7 + 0.1 * (count - 1), 1.0)

    instruction = (
        "You are a highly knowledgeable AI assistant specializing strictly in finance.\n"
        "Strictly answer only financially related topics.\n"
        "Never answer questions that are not financially related.\n"
        "Always provide accurate, objective, and concise answers to financial questions.\n"
        "If a valid financial question is asked, always answer.\n"
        "If a question is unrelated to finance, respond: 'I'm specialized in finance and can't help with that. How can I assist you with a finance-related question today?'\n"
        "If a greeting like 'Hi', 'Hello', or 'Hey' is used, respond with: 'Hello! How can I help you with your finance-related question today?'\n"
    )

    # The prompt is identical for every attempt, so build and tokenize it once.
    prompt = (
        f"{instruction}\n"
        "\n"
        "Background context:\n"
        f"{context}\n"
        "\n"
        f"User question: {user_query}\n"
        "\n"
        "Answer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    query_words = set(normalized_query.split())

    for _ in range(4):
        # Sampling is enabled explicitly so that the temperature takes effect
        # and retries can actually produce different answers.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Drop the prompt tokens rather than splitting on "Answer:", which
        # would truncate replies that contain that marker themselves.
        answer_text = tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        ).strip()

        # Crude relevance check: accept the first reply that contains any
        # word of the query as a substring.
        lowered_answer = answer_text.lower()
        if any(word in lowered_answer for word in query_words):
            return answer_text

    return "I'm not confident in the response. Please consult a certified financial expert."