larrysim commited on
Commit
7993469
Β·
verified Β·
1 Parent(s): 08b3d75

Update app.py

Browse files

add in gemini api keys

Files changed (1) hide show
  1. app.py +78 -62
app.py CHANGED
@@ -9,7 +9,7 @@ import shutil
9
  # ==========================================
10
  # 1. PAGE CONFIG (MUST BE FIRST)
11
  # ==========================================
12
- st.set_page_config(page_title="Bank Loan Agent (SQL)", layout="wide")
13
 
14
  # Suppress warnings
15
  warnings.filterwarnings("ignore")
@@ -22,7 +22,12 @@ INDEX_PATH = "faiss_index"
22
  REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate Policy.pdf"]
23
 
24
  try:
 
25
  from langchain_groq import ChatGroq
 
 
 
 
26
  from langchain_huggingface import HuggingFaceEmbeddings
27
  from langchain_community.vectorstores import FAISS
28
  from langchain_community.callbacks import StreamlitCallbackHandler
@@ -36,13 +41,14 @@ try:
36
 
37
  except ImportError as e:
38
  st.error(f"❌ Critical Import Error: {e}")
 
39
  st.stop()
40
 
41
  # ==========================================
42
  # 3. DATABASE SETUP
43
  # ==========================================
44
  def init_db():
45
- """Converts CSV files to SQLite DB. Handles errors gracefully."""
46
  if os.path.exists(DB_FILE):
47
  return
48
 
@@ -60,7 +66,6 @@ def init_db():
60
  df.columns = [c.strip() for c in df.columns]
61
  if 'ID' in df.columns:
62
  df['ID'] = df['ID'].astype(str)
63
-
64
  try:
65
  df.to_sql(table, conn, if_exists='replace', index=False)
66
  except Exception:
@@ -70,10 +75,8 @@ def init_db():
70
  finally:
71
  conn.close()
72
 
73
- # Initialize DB on startup
74
  init_db()
75
 
76
- # Helper for SQL tools
77
  def run_query(query, params=()):
78
  try:
79
  with sqlite3.connect(DB_FILE) as conn:
@@ -94,7 +97,7 @@ def get_credit_score(user_id: str) -> str:
94
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
95
  if row and not isinstance(row, str):
96
  return f"Credit Score: {row[0]}"
97
- return "User ID not found in Credit DB."
98
 
99
  @tool
100
  def get_account_status(user_id: str) -> str:
@@ -106,28 +109,26 @@ def get_account_status(user_id: str) -> str:
106
  )
107
  if row and not isinstance(row, str):
108
  return f"Customer Name: {row[0]}, Nationality: {row[1]}, Status: {row[2]}, Email: {row[3]}"
109
- return "User ID not found in Account DB."
110
 
111
  @tool
112
  def check_pr_status(user_id: str) -> str:
113
  """Queries SQL DB for PR Status."""
114
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
115
  row = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", (clean_id,))
116
-
117
  if not row or (isinstance(row, str) and "no such column" in row.lower()):
118
  row = run_query("SELECT Is_PR FROM pr_status WHERE ID = ?", (clean_id,))
119
-
120
  if row and not isinstance(row, str):
121
  return f"PR Status: {row[0]}"
122
- return "PR Status: False (Record not found)"
123
 
124
  # ==========================================
125
  # 5. STREAMLIT APP UI
126
  # ==========================================
127
- st.title("πŸ€– Multi-Policy Loan Assessor (SQL + RAG)")
128
  st.markdown("Agent connects to **SQLite Database** and **Persistent Vector Store**")
129
 
130
- # Calculate missing PDFs globally so everyone can see it
131
  pdfs_missing = [f for f in REQUIRED_PDFS if not os.path.exists(f)]
132
 
133
  # --- METRICS FUNCTION ---
@@ -142,58 +143,74 @@ def update_metrics(placeholder):
142
  col_kpi1.metric("AI Processing", f"{ai_time:.1f}s")
143
  col_kpi2.metric("Time Saved", f"{time_saved/60:.1f} min", delta=f"{saved_pct:.1f}% faster")
144
 
145
- # --- SIDEBAR ---
146
  with st.sidebar:
147
  st.header("πŸ” Authentication")
148
 
149
- # Initialize Session State for Key
150
- if 'is_key_valid' not in st.session_state:
151
- st.session_state['is_key_valid'] = False
152
-
153
- # MANUAL ENTRY ONLY (No Secret Check)
154
- if not st.session_state['is_key_valid']:
155
- api_key_input = st.text_input("Enter Groq API Key", type="password", key="input_key")
156
- if st.button("Validate API Key"):
 
 
 
 
 
 
 
 
 
 
 
 
157
  if not api_key_input:
158
  st.error("⚠️ Please enter a key.")
159
  else:
160
  try:
161
- with st.spinner("Validating..."):
162
- test_llm = ChatGroq(api_key=api_key_input, model_name="llama-3.3-70b-versatile")
163
- test_llm.invoke("Test")
164
- st.session_state['groq_api_key'] = api_key_input
165
- st.session_state['is_key_valid'] = True
 
 
 
 
 
 
 
166
  st.success("βœ… Valid Key!")
167
  time.sleep(0.5)
168
  st.rerun()
169
  except Exception as e:
170
- st.error(f"❌ Invalid Key: {e}")
171
  else:
172
- st.success("βœ… API Key Active")
173
- if st.button("πŸ”΄ Reset Key"):
174
- st.session_state['is_key_valid'] = False
175
- st.session_state['groq_api_key'] = None
176
  st.rerun()
177
 
178
  st.divider()
179
- st.subheader("πŸ› οΈ System Maintenance")
180
 
181
  if st.button("♻️ Rebuild Knowledge Base"):
182
- if os.path.exists(INDEX_PATH):
183
- shutil.rmtree(INDEX_PATH)
184
  st.cache_resource.clear()
185
  st.success("Cache cleared.")
186
  time.sleep(1)
187
  st.rerun()
188
 
189
- if st.button("πŸ’Ύ Reload CSVs to DB"):
190
- if os.path.exists(DB_FILE):
191
- os.remove(DB_FILE)
192
  init_db()
193
  st.success("Database refreshed.")
194
 
195
  st.divider()
196
-
197
  if os.path.exists(DB_FILE) and not pdfs_missing:
198
  st.success("βœ… System Ready")
199
  else:
@@ -204,18 +221,17 @@ with st.sidebar:
204
  update_metrics(metrics_placeholder)
205
 
206
  # --- MAIN LOGIC ---
207
- if st.session_state.get('is_key_valid', False):
208
-
209
- os.environ["GROQ_API_KEY"] = st.session_state['groq_api_key']
210
 
211
- # --- RAG SETUP ---
212
  @st.cache_resource
213
  def setup_rag():
214
- # Check global variable here
215
  if pdfs_missing:
216
  st.error(f"Missing PDFs: {pdfs_missing}")
217
  st.stop()
218
 
 
 
219
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
220
 
221
  if os.path.exists(INDEX_PATH):
@@ -233,11 +249,19 @@ if st.session_state.get('is_key_valid', False):
233
  vectorstore.save_local(INDEX_PATH)
234
  return vectorstore.as_retriever()
235
 
236
- with st.spinner("Initializing AI..."):
237
  retriever = setup_rag()
238
 
239
- llm = ChatGroq(temperature=0, model_name="llama-3.3-70b-versatile")
240
-
 
 
 
 
 
 
 
 
241
  rag_prompt = ChatPromptTemplate.from_template("Answer based on context:\n{context}\nQuestion: {question}")
242
  rag_chain = (
243
  {"context": retriever | (lambda d: "\n".join([x.page_content for x in d])), "question": RunnablePassthrough()}
@@ -260,6 +284,7 @@ if st.session_state.get('is_key_valid', False):
260
  agent = create_tool_calling_agent(llm, tools, prompt)
261
  agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)
262
 
 
263
  col1, col2 = st.columns([1, 2])
264
  with col1:
265
  st.subheader("1. Customer Details")
@@ -277,20 +302,16 @@ if st.session_state.get('is_key_valid', False):
277
 
278
  with col2:
279
  if btn:
 
280
  if use_simulation:
281
  query = f"""
282
  Process Loan for Customer ID: {uid}.
283
- *** SIMULATION MODE ***
284
  1. DO NOT query 'get_credit_score' or 'account_status' for Score/Status.
285
  2. USE: Score: {sim_score}, Status: {sim_status}
286
  3. Query 'get_account_status' ONLY for Name/Nationality.
287
  4. Consult Policy Docs for risk/rates.
288
- 5. Provide a Final Recommendation Report that MUST include:
289
- - Customer Name, ID, Email
290
- - Risk Level, Interest Rate
291
- - Final Decision (Approve/Reject)
292
- - Justification for Decision (Cite specific PDF policies)
293
- - Format in a clear markdown table.
294
  """
295
  else:
296
  query = f"""
@@ -298,15 +319,10 @@ if st.session_state.get('is_key_valid', False):
298
  1. Query SQL tools for Name, Email, Nationality, Status, Score.
299
  2. IF Nationality is 'Singaporean', SKIP 'check_pr_status'.
300
  3. Consult Policy Docs for risk/rates.
301
- 4. Provide a Final Recommendation Report that MUST include:
302
- - Customer Name, ID, Email
303
- - Risk Level, Interest Rate
304
- - Final Decision (Approve/Reject)
305
- - Justification for Decision (Cite specific PDF policies)
306
- - Format in a clear markdown table.
307
  """
308
 
309
- with st.status("πŸ€– Agent is processing...", expanded=True) as status:
310
  st_callback = StreamlitCallbackHandler(st.container())
311
  try:
312
  start_time = time.time()
@@ -335,5 +351,5 @@ if st.session_state.get('is_key_valid', False):
335
  email_draft = llm.invoke(email_prompt).content
336
  st.text_area("Email Draft", value=email_draft, height=200)
337
 
338
- elif not st.session_state.get('is_key_valid', False):
339
- st.info("πŸ‘ˆ Please validate your Groq API Key.")
 
9
  # ==========================================
10
  # 1. PAGE CONFIG (MUST BE FIRST)
11
  # ==========================================
12
+ st.set_page_config(page_title="Bank Loan Agent (Multi-Model)", layout="wide")
13
 
14
  # Suppress warnings
15
  warnings.filterwarnings("ignore")
 
22
  REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate Policy.pdf"]
23
 
24
  try:
25
+ # GROQ IMPORTS
26
  from langchain_groq import ChatGroq
27
+ # GOOGLE IMPORTS
28
+ from langchain_google_genai import ChatGoogleGenerativeAI
29
+
30
+ # SHARED IMPORTS
31
  from langchain_huggingface import HuggingFaceEmbeddings
32
  from langchain_community.vectorstores import FAISS
33
  from langchain_community.callbacks import StreamlitCallbackHandler
 
41
 
42
  except ImportError as e:
43
  st.error(f"❌ Critical Import Error: {e}")
44
+ st.info("πŸ’‘ Suggestion: Add 'langchain-google-genai' to requirements.txt")
45
  st.stop()
46
 
47
  # ==========================================
48
  # 3. DATABASE SETUP
49
  # ==========================================
50
  def init_db():
51
+ """Converts CSV files to SQLite DB."""
52
  if os.path.exists(DB_FILE):
53
  return
54
 
 
66
  df.columns = [c.strip() for c in df.columns]
67
  if 'ID' in df.columns:
68
  df['ID'] = df['ID'].astype(str)
 
69
  try:
70
  df.to_sql(table, conn, if_exists='replace', index=False)
71
  except Exception:
 
75
  finally:
76
  conn.close()
77
 
 
78
  init_db()
79
 
 
80
  def run_query(query, params=()):
81
  try:
82
  with sqlite3.connect(DB_FILE) as conn:
 
97
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
98
  if row and not isinstance(row, str):
99
  return f"Credit Score: {row[0]}"
100
+ return "User ID not found."
101
 
102
  @tool
103
  def get_account_status(user_id: str) -> str:
 
109
  )
110
  if row and not isinstance(row, str):
111
  return f"Customer Name: {row[0]}, Nationality: {row[1]}, Status: {row[2]}, Email: {row[3]}"
112
+ return "User ID not found."
113
 
114
  @tool
115
  def check_pr_status(user_id: str) -> str:
116
  """Queries SQL DB for PR Status."""
117
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
118
  row = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", (clean_id,))
 
119
  if not row or (isinstance(row, str) and "no such column" in row.lower()):
120
  row = run_query("SELECT Is_PR FROM pr_status WHERE ID = ?", (clean_id,))
 
121
  if row and not isinstance(row, str):
122
  return f"PR Status: {row[0]}"
123
+ return "PR Status: False."
124
 
125
  # ==========================================
126
  # 5. STREAMLIT APP UI
127
  # ==========================================
128
+ st.title("πŸ€– Multi-Model Loan Assessor")
129
  st.markdown("Agent connects to **SQLite Database** and **Persistent Vector Store**")
130
 
131
+ # Calculate missing PDFs
132
  pdfs_missing = [f for f in REQUIRED_PDFS if not os.path.exists(f)]
133
 
134
  # --- METRICS FUNCTION ---
 
143
  col_kpi1.metric("AI Processing", f"{ai_time:.1f}s")
144
  col_kpi2.metric("Time Saved", f"{time_saved/60:.1f} min", delta=f"{saved_pct:.1f}% faster")
145
 
146
+ # --- SIDEBAR (MULTI-PROVIDER) ---
147
  with st.sidebar:
148
  st.header("πŸ” Authentication")
149
 
150
+ # 1. Provider Selector
151
+ provider_option = st.radio("Select AI Model:", ["Groq (Llama-3)", "Google (Gemini)"])
152
+
153
+ # Initialize State
154
+ if 'auth_status' not in st.session_state:
155
+ st.session_state['auth_status'] = False
156
+ st.session_state['api_key'] = None
157
+ st.session_state['provider'] = None
158
+
159
+ # Reset if user switches provider
160
+ if st.session_state.get('provider') != provider_option:
161
+ st.session_state['auth_status'] = False
162
+ st.session_state['api_key'] = None
163
+ st.session_state['provider'] = provider_option
164
+
165
+ # 2. Authentication Logic
166
+ if not st.session_state['auth_status']:
167
+ api_key_input = st.text_input(f"Enter {provider_option} API Key", type="password")
168
+
169
+ if st.button("Validate Key"):
170
  if not api_key_input:
171
  st.error("⚠️ Please enter a key.")
172
  else:
173
  try:
174
+ with st.spinner(f"Contacting {provider_option}..."):
175
+ # DYNAMIC VALIDATION
176
+ if "Groq" in provider_option:
177
+ test_llm = ChatGroq(api_key=api_key_input, model_name="llama-3.3-70b-versatile")
178
+ else:
179
+ test_llm = ChatGoogleGenerativeAI(google_api_key=api_key_input, model="gemini-1.5-flash")
180
+
181
+ test_llm.invoke("Test Connection")
182
+
183
+ # Store Success
184
+ st.session_state['auth_status'] = True
185
+ st.session_state['api_key'] = api_key_input
186
  st.success("βœ… Valid Key!")
187
  time.sleep(0.5)
188
  st.rerun()
189
  except Exception as e:
190
+ st.error(f"❌ Connection Failed: {e}")
191
  else:
192
+ st.success(f"βœ… {st.session_state['provider']} Active")
193
+ if st.button("πŸ”΄ Logout / Change Model"):
194
+ st.session_state['auth_status'] = False
195
+ st.session_state['api_key'] = None
196
  st.rerun()
197
 
198
  st.divider()
199
+ st.subheader("πŸ› οΈ Maintenance")
200
 
201
  if st.button("♻️ Rebuild Knowledge Base"):
202
+ if os.path.exists(INDEX_PATH): shutil.rmtree(INDEX_PATH)
 
203
  st.cache_resource.clear()
204
  st.success("Cache cleared.")
205
  time.sleep(1)
206
  st.rerun()
207
 
208
+ if st.button("πŸ’Ύ Reload CSVs"):
209
+ if os.path.exists(DB_FILE): os.remove(DB_FILE)
 
210
  init_db()
211
  st.success("Database refreshed.")
212
 
213
  st.divider()
 
214
  if os.path.exists(DB_FILE) and not pdfs_missing:
215
  st.success("βœ… System Ready")
216
  else:
 
221
  update_metrics(metrics_placeholder)
222
 
223
  # --- MAIN LOGIC ---
224
+ if st.session_state.get('auth_status', False):
 
 
225
 
226
+ # --- RAG SETUP (Shared Embeddings) ---
227
  @st.cache_resource
228
  def setup_rag():
 
229
  if pdfs_missing:
230
  st.error(f"Missing PDFs: {pdfs_missing}")
231
  st.stop()
232
 
233
+ # We use HuggingFace embeddings for BOTH providers to keep the Vector Store compatible
234
+ # This prevents having to rebuild the index every time you switch models.
235
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
236
 
237
  if os.path.exists(INDEX_PATH):
 
249
  vectorstore.save_local(INDEX_PATH)
250
  return vectorstore.as_retriever()
251
 
252
+ with st.spinner("Initializing Knowledge Base..."):
253
  retriever = setup_rag()
254
 
255
+ # --- DYNAMIC LLM INSTANTIATION ---
256
+ current_key = st.session_state['api_key']
257
+ current_provider = st.session_state['provider']
258
+
259
+ if "Groq" in current_provider:
260
+ llm = ChatGroq(api_key=current_key, temperature=0, model_name="llama-3.3-70b-versatile")
261
+ else:
262
+ llm = ChatGoogleGenerativeAI(google_api_key=current_key, temperature=0, model="gemini-1.5-flash")
263
+
264
+ # --- AGENT CHAIN ---
265
  rag_prompt = ChatPromptTemplate.from_template("Answer based on context:\n{context}\nQuestion: {question}")
266
  rag_chain = (
267
  {"context": retriever | (lambda d: "\n".join([x.page_content for x in d])), "question": RunnablePassthrough()}
 
284
  agent = create_tool_calling_agent(llm, tools, prompt)
285
  agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)
286
 
287
+ # --- UI INPUT ---
288
  col1, col2 = st.columns([1, 2])
289
  with col1:
290
  st.subheader("1. Customer Details")
 
302
 
303
  with col2:
304
  if btn:
305
+ # Build Prompt
306
  if use_simulation:
307
  query = f"""
308
  Process Loan for Customer ID: {uid}.
309
+ *** SIMULATION MODE ACTIVE ***
310
  1. DO NOT query 'get_credit_score' or 'account_status' for Score/Status.
311
  2. USE: Score: {sim_score}, Status: {sim_status}
312
  3. Query 'get_account_status' ONLY for Name/Nationality.
313
  4. Consult Policy Docs for risk/rates.
314
+ 5. Output Final Report table + Justification.
 
 
 
 
 
315
  """
316
  else:
317
  query = f"""
 
319
  1. Query SQL tools for Name, Email, Nationality, Status, Score.
320
  2. IF Nationality is 'Singaporean', SKIP 'check_pr_status'.
321
  3. Consult Policy Docs for risk/rates.
322
+ 4. Output Final Report table + Justification.
 
 
 
 
 
323
  """
324
 
325
+ with st.status(f"πŸ€– Agent ({current_provider}) is processing...", expanded=True) as status:
326
  st_callback = StreamlitCallbackHandler(st.container())
327
  try:
328
  start_time = time.time()
 
351
  email_draft = llm.invoke(email_prompt).content
352
  st.text_area("Email Draft", value=email_draft, height=200)
353
 
354
+ elif not st.session_state.get('auth_status', False):
355
+ st.info("πŸ‘ˆ Please select a provider (Groq or Gemini) and validate your key.")