WillyCodesInit committed on
Commit
644cc5a
·
verified ·
1 Parent(s): 0c571b0

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +46 -47
src/streamlit_app.py CHANGED
@@ -8,45 +8,50 @@ from sentence_transformers import SentenceTransformer
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  from huggingface_hub import login
10
 
11
- # --- HuggingFace login ---
12
- HF_TOKEN = os.getenv("HF_TOKEN")
13
  if HF_TOKEN:
14
  login(token=HF_TOKEN)
15
-
16
- # Path to the data file within the 'src' folder
17
- data_path = os.path.join(os.path.dirname(__file__), 'train_data.csv')
18
-
19
-
20
-
21
- # Load data
22
- train = pd.read_csv(data_path)
23
-
24
 
25
  # --- Load data ---
 
 
 
 
 
26
 
27
- questions = train['question'].tolist()
28
- answers = train['answer'].tolist()
29
-
30
  qa_pairs = [f"Q: {q} A: {a}" for q, a in zip(questions, answers)]
31
 
32
- # --- Embedding model ---
33
- embedding_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
34
- answer_embeddings = embedding_model.encode(answers)
35
-
36
- # --- FAISS index ---
37
- index = faiss.IndexFlatL2(answer_embeddings.shape[1])
38
- index.add(np.array(answer_embeddings))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # --- LLaMA model setup ---
41
- model_name = "meta-llama/Llama-3-8B-Instruct" # Update to a valid space-available model if needed
42
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
43
- model = AutoModelForCausalLM.from_pretrained(
44
- model_name,
45
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
46
- device_map="auto"
47
- )
48
 
49
- # --- Helper Functions ---
50
  def sanitize_answer(question, answer):
51
  return any(word.lower() in answer.lower() for word in question.lower().split())
52
 
@@ -74,7 +79,6 @@ def ask_finance_bot(user_query, top_k=3):
74
  count = recent_questions.get(normalized_query, 0) + 1
75
  recent_questions[normalized_query] = count
76
 
77
- # Embed user query
78
  query_embedding = embedding_model.encode([user_query])
79
  D, I = index.search(np.array(query_embedding), top_k)
80
  retrieved_answers = [answers[i] for i in I[0]]
@@ -86,18 +90,13 @@ def ask_finance_bot(user_query, top_k=3):
86
  "You are a highly knowledgeable AI assistant specializing strictly in finance.\n"
87
  "Strictly answer only financially related topics.\n"
88
  "Never answer questions that are not financially related.\n"
89
- "Do not answer anything outside finance.\n"
90
  "Always provide accurate, objective, and concise answers to financial questions.\n"
91
- "Avoid unnecessary elaboration and focus directly on answering the user's query.\n"
92
- "Use the background context only if it is accurate, clear, and relevant. If the context is unclear, incomplete, low-quality, or irrelevant, ignore it and generate your own correct, concise financial answer.\n"
93
- "Do not copy or repeat the context verbatim — instead, synthesize your own response based on it.\n"
94
- "Do not speculate or use personal phrases like 'I think' or 'In my opinion'.\n"
95
- "If a valid financial question is asked, always answer — never refuse or say 'I can't help with that.'\n"
96
  "If a question is unrelated to finance, respond: 'I'm specialized in finance and can't help with that. How can I assist you with a finance-related question today?'\n"
97
  "If a greeting like 'Hi', 'Hello', or 'Hey' is used, respond with: 'Hello! How can I help you with your finance-related question today?'\n"
98
  )
99
 
100
- for _ in range(6):
101
  prompt = f"""{instruction}
102
 
103
  Background context:
@@ -124,17 +123,17 @@ Answer:"""
124
 
125
  return "I'm not confident in the response. Please consult a certified financial expert."
126
 
127
- # --- Streamlit App UI ---
128
  st.set_page_config(page_title="DiMowkayBot - Finance Assistant", layout="centered")
129
- st.title("DiMowkayBot - Your Finance Q&A Assistant")
130
 
131
- user_query = st.text_input("Enter your finance-related question:")
132
 
133
  if user_query:
134
- if not is_finance_question(user_query):
135
- st.warning("I'm specialized in finance and can't help with that. How can I assist you with a finance-related question today?")
136
- else:
137
- with st.spinner("Thinking..."):
138
  answer = ask_finance_bot(user_query)
139
- st.success("Response:")
140
- st.write(answer)
 
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  from huggingface_hub import login
10
 
11
+ # --- Hugging Face login ---
12
+ HF_TOKEN = st.secrets.get("HF_TOKEN", os.getenv("HF_TOKEN"))
13
  if HF_TOKEN:
14
  login(token=HF_TOKEN)
15
+ else:
16
+ st.error("Hugging Face token not found. Please set it in secrets.toml or environment.")
17
+ st.stop()
 
 
 
 
 
 
18
 
19
  # --- Load data ---
20
+ @st.cache_data
21
+ def load_data():
22
+ data_path = os.path.join(os.path.dirname(__file__), 'train_data.csv')
23
+ df = pd.read_csv(data_path)
24
+ return df['question'].tolist(), df['answer'].tolist()
25
 
26
+ questions, answers = load_data()
 
 
27
  qa_pairs = [f"Q: {q} A: {a}" for q, a in zip(questions, answers)]
28
 
29
+ # --- Embedding model and FAISS index ---
30
+ @st.cache_resource
31
+ def setup_embeddings():
32
+ embedder = SentenceTransformer('paraphrase-MiniLM-L6-v2')
33
+ answer_embeddings = embedder.encode(answers, show_progress_bar=True)
34
+ index = faiss.IndexFlatL2(answer_embeddings.shape[1])
35
+ index.add(np.array(answer_embeddings))
36
+ return embedder, index
37
+
38
+ embedding_model, index = setup_embeddings()
39
+
40
+ # --- Load LLaMA model ---
41
+ @st.cache_resource
42
+ def load_llama_model():
43
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # Ensure you have access
44
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ model_id,
47
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
48
+ device_map="auto"
49
+ )
50
+ return tokenizer, model
51
 
52
+ tokenizer, model = load_llama_model()
 
 
 
 
 
 
 
53
 
54
+ # --- Helper functions ---
55
  def sanitize_answer(question, answer):
56
  return any(word.lower() in answer.lower() for word in question.lower().split())
57
 
 
79
  count = recent_questions.get(normalized_query, 0) + 1
80
  recent_questions[normalized_query] = count
81
 
 
82
  query_embedding = embedding_model.encode([user_query])
83
  D, I = index.search(np.array(query_embedding), top_k)
84
  retrieved_answers = [answers[i] for i in I[0]]
 
90
  "You are a highly knowledgeable AI assistant specializing strictly in finance.\n"
91
  "Strictly answer only financially related topics.\n"
92
  "Never answer questions that are not financially related.\n"
 
93
  "Always provide accurate, objective, and concise answers to financial questions.\n"
94
+ "If a valid financial question is asked, always answer.\n"
 
 
 
 
95
  "If a question is unrelated to finance, respond: 'I'm specialized in finance and can't help with that. How can I assist you with a finance-related question today?'\n"
96
  "If a greeting like 'Hi', 'Hello', or 'Hey' is used, respond with: 'Hello! How can I help you with your finance-related question today?'\n"
97
  )
98
 
99
+ for _ in range(4):
100
  prompt = f"""{instruction}
101
 
102
  Background context:
 
123
 
124
  return "I'm not confident in the response. Please consult a certified financial expert."
125
 
126
+ # --- Streamlit UI ---
127
  st.set_page_config(page_title="DiMowkayBot - Finance Assistant", layout="centered")
128
+ st.title("🤖 DiMowkayBot - Your Finance Q&A Assistant")
129
 
130
+ user_query = st.text_input("Ask a finance-related question:")
131
 
132
  if user_query:
133
+ with st.spinner("Thinking..."):
134
+ if not is_finance_question(user_query):
135
+ st.warning("I'm specialized in finance and can't help with that. How can I assist you with a finance-related question today?")
136
+ else:
137
  answer = ask_finance_bot(user_query)
138
+ st.success("Response:")
139
+ st.write(answer)