larrysim commited on
Commit
e54fecb
·
verified ·
1 Parent(s): b76bcbc

Update app.py

Browse files

fix list error

Files changed (1) hide show
  1. app.py +150 -185
app.py CHANGED
@@ -6,6 +6,7 @@ import time
6
  import sqlite3
7
  import shutil
8
  import asyncio
 
9
 
10
  # ==========================================
11
  # 0. ASYNC FIX
@@ -16,39 +17,33 @@ except RuntimeError:
16
  asyncio.set_event_loop(asyncio.new_event_loop())
17
 
18
  # ==========================================
19
- # 1. PAGE CONFIG
20
  # ==========================================
21
  st.set_page_config(page_title="Bank Loan Agent", layout="wide")
22
  warnings.filterwarnings("ignore")
23
 
24
- # ==========================================
25
- # 2. IMPORTS & CONSTANTS
26
- # ==========================================
27
  DB_FILE = "bank.db"
28
  INDEX_PATH = "faiss_index"
29
  REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate Policy.pdf"]
30
 
31
  try:
32
  from langchain_groq import ChatGroq
33
- from langchain_google_genai import ChatGoogleGenerativeAI, HarmBlockThreshold, HarmCategory
34
  import google.generativeai as genai
35
-
36
  from langchain_huggingface import HuggingFaceEmbeddings
37
  from langchain_community.vectorstores import FAISS
38
- from langchain_community.callbacks import StreamlitCallbackHandler
39
  from langchain_community.document_loaders import PyPDFLoader
40
  from langchain_text_splitters import CharacterTextSplitter
41
  from langchain_core.prompts import PromptTemplate
42
  from langchain_core.runnables import RunnablePassthrough
43
  from langchain_core.output_parsers import StrOutputParser
44
  from langchain_core.tools import tool
45
- from langchain.agents import AgentExecutor, create_react_agent
46
  except ImportError as e:
47
  st.error(f"❌ Import Error: {e}")
48
  st.stop()
49
 
50
  # ==========================================
51
- # 3. DATABASE SETUP
52
  # ==========================================
53
  def init_db():
54
  if os.path.exists(DB_FILE): return
@@ -73,28 +68,23 @@ def run_query(query, params=()):
73
  return cursor.fetchone()
74
  except Exception as e: return f"DB Error: {e}"
75
 
76
- # ==========================================
77
- # 4. TOOLS
78
- # ==========================================
79
- @tool
80
- def get_credit_score(user_id: str) -> str:
81
- """Queries SQL DB for Credit Score. Input is ID string."""
82
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
83
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
84
  return f"Credit Score: {row[0]}" if (row and not isinstance(row, str)) else "User ID not found."
85
 
86
- @tool
87
- def get_account_status(user_id: str) -> str:
88
- """Queries SQL DB for Name, Nationality, Status, and Email. Input is ID string."""
89
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
90
  row = run_query("SELECT Name, Nationality, Account_Status, Email FROM account_status WHERE ID = ?", (clean_id,))
91
  if row and not isinstance(row, str):
92
  return f"Customer Name: {row[0]}, Nationality: {row[1]}, Status: {row[2]}, Email: {row[3]}"
93
  return "User ID not found."
94
 
95
- @tool
96
- def check_pr_status(user_id: str) -> str:
97
- """Queries SQL DB for PR Status. Input is ID string."""
98
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
99
  row = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", (clean_id,))
100
  if not row or (isinstance(row, str) and "no such column" in row.lower()):
@@ -102,170 +92,156 @@ def check_pr_status(user_id: str) -> str:
102
  return f"PR Status: {row[0]}" if (row and not isinstance(row, str)) else "PR Status: False."
103
 
104
  # ==========================================
105
- # 5. UI & AUTH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # ==========================================
107
  st.title("🤖 Multi-Model Loan Assessor")
108
  pdfs_missing = [f for f in REQUIRED_PDFS if not os.path.exists(f)]
109
 
110
- def update_metrics(placeholder):
111
- if 'execution_time' in st.session_state:
112
- col1, col2 = placeholder.columns(2)
113
- col1.metric("Processing Time", f"{st.session_state.execution_time:.2f}s")
114
- col2.metric("Status", "Success")
115
-
116
  with st.sidebar:
117
  st.header("🔐 Authentication")
118
- provider_option = st.radio("Select Model:", ["Groq (Llama-3)", "Google (Gemini)"])
119
 
120
- if 'auth_status' not in st.session_state:
121
- st.session_state['auth_status'] = False
122
- st.session_state['api_key'] = None
123
- st.session_state['provider'] = None
124
-
125
- if st.session_state.get('provider') != provider_option:
126
- st.session_state['auth_status'] = False
127
- st.session_state['api_key'] = None
128
- st.session_state['provider'] = provider_option
129
-
130
- if not st.session_state['auth_status']:
131
- api_key_input = st.text_input(f"Enter {provider_option} API Key", type="password")
132
- if st.button("Validate Key"):
133
- if not api_key_input:
134
- st.error("⚠️ Enter a key.")
135
- else:
136
- try:
137
- with st.spinner(f"Verifying {provider_option}..."):
138
- if "Groq" in provider_option:
139
- ChatGroq(api_key=api_key_input).invoke("Hi")
140
- else:
141
- genai.configure(api_key=api_key_input)
142
- list(genai.list_models())
143
-
144
- st.session_state['auth_status'] = True
145
- st.session_state['api_key'] = api_key_input
146
- st.success("✅ Valid!")
147
- time.sleep(0.5)
148
- st.rerun()
149
- except Exception as e:
150
- st.error(f"❌ Error: {e}")
151
- else:
152
- st.success(f"✅ {st.session_state['provider']} Active")
153
- if st.button("🔴 Logout"):
154
- st.session_state['auth_status'] = False
155
- st.rerun()
156
 
157
- st.divider()
158
- if st.button("♻️ Rebuild Database"):
159
  if os.path.exists(INDEX_PATH): shutil.rmtree(INDEX_PATH)
160
  st.cache_resource.clear()
161
  st.rerun()
162
 
163
- if os.path.exists(DB_FILE) and not pdfs_missing:
164
- st.success("✅ System Ready")
165
- else:
166
- st.warning(f"⚠️ Missing: {pdfs_missing}")
167
- update_metrics(st.empty())
168
-
169
- # ==========================================
170
- # 6. MAIN LOGIC
171
- # ==========================================
172
- if st.session_state.get('auth_status', False):
173
-
174
- current_key = st.session_state['api_key']
175
- current_provider = st.session_state['provider']
176
-
177
  @st.cache_resource
178
- def setup_rag(_provider, _key):
179
- if pdfs_missing: st.stop()
180
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
181
-
182
  if os.path.exists(INDEX_PATH):
183
  return FAISS.load_local(INDEX_PATH, embeddings, allow_dangerous_deserialization=True).as_retriever()
184
- else:
185
- documents = []
186
- for pdf_file in REQUIRED_PDFS:
187
- documents.extend(PyPDFLoader(pdf_file).load())
188
- splits = CharacterTextSplitter(chunk_size=600, chunk_overlap=50).split_documents(documents)
189
- vectorstore = FAISS.from_documents(splits, embeddings)
190
- vectorstore.save_local(INDEX_PATH)
191
- return vectorstore.as_retriever()
192
 
193
  with st.spinner("Loading AI..."):
194
- retriever = setup_rag(current_provider, current_key)
195
 
196
- # --- LLM SELECTION ---
197
- if "Groq" in current_provider:
198
- llm = ChatGroq(
199
- api_key=current_key,
200
- temperature=0,
201
- model_name="llama-3.3-70b-versatile"
202
- )
203
  else:
204
- # CRITICAL FIX: DISABLE SAFETY FILTERS
205
- # This prevents Gemini from returning an empty list or error when it gets confused
206
- safety = {
207
- HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
208
- HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
209
- HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
210
- HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
211
- }
212
-
213
  llm = ChatGoogleGenerativeAI(
214
- google_api_key=current_key,
215
  temperature=0,
216
  model="gemini-1.5-flash",
217
- transport="rest",
218
- safety_settings=safety # <--- APPLY FIX
219
  )
220
 
221
  # --- RAG CHAIN ---
222
  rag_chain = (
223
  {"context": retriever | (lambda d: "\n".join([x.page_content for x in d])), "question": RunnablePassthrough()}
224
- | PromptTemplate.from_template("Answer based on context:\n{context}\nQuestion: {question}")
225
  | llm
226
  | StrOutputParser()
227
  )
228
-
229
- @tool
230
- def consult_policy_doc(query: str) -> str:
231
- """Consults Policy Documents. Input should be a question string."""
232
- return rag_chain.invoke(query)
233
-
234
- tools = [get_credit_score, get_account_status, check_pr_status, consult_policy_doc]
235
-
236
- # --- REACT PROMPT ---
237
- template = '''Answer the following questions as best you can. You have access to the following tools:
238
-
239
- {tools}
240
-
241
- Use the following format:
242
 
243
- Question: the input question you must answer
244
- Thought: you should always think about what to do
245
- Action: the action to take, should be one of [{tool_names}]
246
- Action Input: the input to the action
247
- Observation: the result of the action
248
- ... (this Thought/Action/Action Input/Observation can repeat N times)
249
- Thought: I now know the final answer
250
- Final Answer: the final answer to the original input question
251
-
252
- Begin!
253
-
254
- Question: {input}
255
- Thought:{agent_scratchpad}'''
256
-
257
- prompt = PromptTemplate.from_template(template)
258
-
259
- # --- AGENT CREATION ---
260
- agent = create_react_agent(llm, tools, prompt)
261
-
262
- agent_executor = AgentExecutor(
263
- agent=agent,
264
- tools=tools,
265
- verbose=True,
266
- return_intermediate_steps=True,
267
- handle_parsing_errors=True # Auto-fix formatting errors
268
- )
269
 
270
  # --- UI ---
271
  col1, col2 = st.columns([1, 2])
@@ -278,43 +254,32 @@ Thought:{agent_scratchpad}'''
278
 
279
  with col2:
280
  if btn:
281
- query = f"Process Loan for Customer ID: {uid}. "
282
- if use_sim:
283
- query += f"SIMULATION ACTIVE. Use Score {sim_score} and Status '{sim_status}'. Do NOT query credit/status tools. Only query Name."
284
- else:
285
- query += "Query SQL tools for Name, Email, Nationality, Status, Score."
286
-
287
- query += " Check Policies using 'consult_policy_doc'. Output a Final Report Table."
288
-
289
- with st.status(f"🤖 Agent ({current_provider}) Working...", expanded=True) as status:
290
- st_callback = StreamlitCallbackHandler(st.container())
291
  try:
292
- start_time = time.time()
293
- res = agent_executor.invoke({"input": query}, {"callbacks": [st_callback]})
294
- st.session_state.execution_time = time.time() - start_time
295
- update_metrics(metrics_placeholder)
296
- status.update(label="✅ Done", state="complete", expanded=False)
297
  except Exception as e:
298
- st.error(f"Agent Error: {e}")
299
- st.stop()
 
300
 
301
- st.success("### 📋 Final Report")
302
- final_output = res.get('output', "Error generating report.")
303
- st.markdown(final_output)
304
 
305
  with st.expander("Trace"):
306
- steps = res.get("intermediate_steps", [])
307
- for action, obs in steps:
308
- st.markdown(f"**Tool:** `{action.tool}`\n**Result:** `{obs}`")
309
-
310
  if not use_sim:
311
  st.divider()
312
- with st.expander("✉️ Email Draft"):
313
- try:
314
- email = llm.invoke(f"Draft email for: {final_output}").content
315
- st.text_area("Draft", value=email, height=200)
316
- except:
317
- pass
318
 
319
- elif not st.session_state.get('auth_status', False):
320
- st.info("👈 Select Provider & Validate Key in Sidebar")
 
6
  import sqlite3
7
  import shutil
8
  import asyncio
9
+ import re
10
 
11
  # ==========================================
12
  # 0. ASYNC FIX
 
17
  asyncio.set_event_loop(asyncio.new_event_loop())
18
 
19
  # ==========================================
20
+ # 1. CONFIG & IMPORTS
21
  # ==========================================
22
  st.set_page_config(page_title="Bank Loan Agent", layout="wide")
23
  warnings.filterwarnings("ignore")
24
 
 
 
 
25
  DB_FILE = "bank.db"
26
  INDEX_PATH = "faiss_index"
27
  REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate Policy.pdf"]
28
 
29
  try:
30
  from langchain_groq import ChatGroq
31
+ from langchain_google_genai import ChatGoogleGenerativeAI
32
  import google.generativeai as genai
 
33
  from langchain_huggingface import HuggingFaceEmbeddings
34
  from langchain_community.vectorstores import FAISS
 
35
  from langchain_community.document_loaders import PyPDFLoader
36
  from langchain_text_splitters import CharacterTextSplitter
37
  from langchain_core.prompts import PromptTemplate
38
  from langchain_core.runnables import RunnablePassthrough
39
  from langchain_core.output_parsers import StrOutputParser
40
  from langchain_core.tools import tool
 
41
  except ImportError as e:
42
  st.error(f"❌ Import Error: {e}")
43
  st.stop()
44
 
45
  # ==========================================
46
+ # 2. DATABASE & TOOLS SETUP
47
  # ==========================================
48
  def init_db():
49
  if os.path.exists(DB_FILE): return
 
68
  return cursor.fetchone()
69
  except Exception as e: return f"DB Error: {e}"
70
 
71
+ # --- TOOL FUNCTIONS (Pure Python) ---
72
+ def tool_get_credit_score(user_id):
73
+ """Input: User ID. Returns Credit Score."""
 
 
 
74
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
75
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
76
  return f"Credit Score: {row[0]}" if (row and not isinstance(row, str)) else "User ID not found."
77
 
78
+ def tool_get_account_status(user_id):
79
+ """Input: User ID. Returns Name, Nationality, Status."""
 
80
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
81
  row = run_query("SELECT Name, Nationality, Account_Status, Email FROM account_status WHERE ID = ?", (clean_id,))
82
  if row and not isinstance(row, str):
83
  return f"Customer Name: {row[0]}, Nationality: {row[1]}, Status: {row[2]}, Email: {row[3]}"
84
  return "User ID not found."
85
 
86
+ def tool_check_pr_status(user_id):
87
+ """Input: User ID. Returns PR Status."""
 
88
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
89
  row = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", (clean_id,))
90
  if not row or (isinstance(row, str) and "no such column" in row.lower()):
 
92
  return f"PR Status: {row[0]}" if (row and not isinstance(row, str)) else "PR Status: False."
93
 
94
  # ==========================================
95
+ # 3. MANUAL AGENT ENGINE (The Fix)
96
+ # ==========================================
97
+ class ManualReActAgent:
98
+ def __init__(self, llm, tools_map, rag_chain):
99
+ self.llm = llm
100
+ self.tools = tools_map
101
+ self.rag_chain = rag_chain
102
+ self.max_steps = 6
103
+
104
+ def run(self, query):
105
+ """Runs the ReAct loop manually to avoid Library Parsing Errors."""
106
+
107
+ # 1. DEFINE PROMPT
108
+ tool_desc = "\n".join([f"- {name}: {func.__doc__}" for name, func in self.tools.items()])
109
+ system_prompt = f"""You are a Loan Risk Officer. Answer the question using the tools below.
110
+
111
+ TOOLS:
112
+ {tool_desc}
113
+ - consult_policy_doc: Consult policy PDF for risk rules. Input: a question string.
114
+
115
+ FORMAT:
116
+ Thought: <reasoning>
117
+ Action: <tool_name>
118
+ Action Input: <input>
119
+ Observation: <result>
120
+ ... (repeat)
121
+ Final Answer: <answer>
122
+
123
+ Begin!
124
+ Question: {query}
125
+ """
126
+ history = system_prompt
127
+ logs = []
128
+
129
+ # 2. LOOP
130
+ for i in range(self.max_steps):
131
+ # A. Call LLM
132
+ response = self.llm.invoke(history).content
133
+ history += response + "\n"
134
+
135
+ # B. Parse "Action"
136
+ action_match = re.search(r"Action:\s*(.+)", response)
137
+ input_match = re.search(r"Action Input:\s*(.+)", response)
138
+
139
+ # C. Check for Final Answer (Stop Condition)
140
+ if "Final Answer:" in response:
141
+ final_ans = response.split("Final Answer:")[-1].strip()
142
+ return final_ans, logs
143
+
144
+ # D. Execute Tool
145
+ if action_match and input_match:
146
+ tool_name = action_match.group(1).strip()
147
+ tool_input = input_match.group(1).strip()
148
+
149
+ # Strip quotes if present
150
+ tool_input = tool_input.strip('"').strip("'")
151
+
152
+ logs.append((tool_name, tool_input))
153
+
154
+ # Execute
155
+ observation = f"Error: Tool {tool_name} not found."
156
+ if tool_name in self.tools:
157
+ try:
158
+ observation = self.tools[tool_name](tool_input)
159
+ except Exception as e:
160
+ observation = f"Tool Error: {e}"
161
+ elif tool_name == "consult_policy_doc":
162
+ try:
163
+ observation = self.rag_chain.invoke(tool_input)
164
+ except Exception as e:
165
+ observation = f"RAG Error: {e}"
166
+
167
+ obs_str = f"Observation: {observation}\n"
168
+ history += obs_str
169
+ else:
170
+ # If LLM didn't output an action but didn't finish, force it
171
+ if i == self.max_steps - 1:
172
+ return response, logs
173
+ history += "Observation: Please continue. If you have the answer, say 'Final Answer:'.\n"
174
+
175
+ return "Agent timed out.", logs
176
+
177
+ # ==========================================
178
+ # 4. UI & SETUP
179
  # ==========================================
180
  st.title("🤖 Multi-Model Loan Assessor")
181
  pdfs_missing = [f for f in REQUIRED_PDFS if not os.path.exists(f)]
182
 
 
 
 
 
 
 
183
  with st.sidebar:
184
  st.header("🔐 Authentication")
185
+ provider = st.radio("Model:", ["Groq (Llama-3)", "Google (Gemini)"])
186
 
187
+ if 'api_key' not in st.session_state: st.session_state.api_key = None
188
+
189
+ key_input = st.text_input("API Key", type="password")
190
+ if st.button("Set Key"):
191
+ st.session_state.api_key = key_input
192
+ st.success("Key Set!")
193
+ st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ if st.button("♻️ Reset"):
 
196
  if os.path.exists(INDEX_PATH): shutil.rmtree(INDEX_PATH)
197
  st.cache_resource.clear()
198
  st.rerun()
199
 
200
+ if st.session_state.api_key:
201
+ # --- RAG SETUP ---
 
 
 
 
 
 
 
 
 
 
 
 
202
  @st.cache_resource
203
+ def setup_rag():
204
+ if pdfs_missing: return None
205
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
 
206
  if os.path.exists(INDEX_PATH):
207
  return FAISS.load_local(INDEX_PATH, embeddings, allow_dangerous_deserialization=True).as_retriever()
208
+ documents = []
209
+ for f in REQUIRED_PDFS: documents.extend(PyPDFLoader(f).load())
210
+ splits = CharacterTextSplitter(chunk_size=600, chunk_overlap=50).split_documents(documents)
211
+ vectorstore = FAISS.from_documents(splits, embeddings)
212
+ vectorstore.save_local(INDEX_PATH)
213
+ return vectorstore.as_retriever()
 
 
214
 
215
  with st.spinner("Loading AI..."):
216
+ retriever = setup_rag()
217
 
218
+ # --- LLM SETUP ---
219
+ if "Groq" in provider:
220
+ llm = ChatGroq(api_key=st.session_state.api_key, temperature=0, model_name="llama-3.3-70b-versatile")
 
 
 
 
221
  else:
222
+ # Using Gemini 1.5 Flash with REST transport
 
 
 
 
 
 
 
 
223
  llm = ChatGoogleGenerativeAI(
224
+ google_api_key=st.session_state.api_key,
225
  temperature=0,
226
  model="gemini-1.5-flash",
227
+ transport="rest"
 
228
  )
229
 
230
  # --- RAG CHAIN ---
231
  rag_chain = (
232
  {"context": retriever | (lambda d: "\n".join([x.page_content for x in d])), "question": RunnablePassthrough()}
233
+ | PromptTemplate.from_template("Info: {context}\nQ: {question}\nA:")
234
  | llm
235
  | StrOutputParser()
236
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
+ # --- AGENT INSTANCE ---
239
+ tools_map = {
240
+ "get_credit_score": tool_get_credit_score,
241
+ "get_account_status": tool_get_account_status,
242
+ "check_pr_status": tool_check_pr_status
243
+ }
244
+ agent = ManualReActAgent(llm, tools_map, rag_chain)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  # --- UI ---
247
  col1, col2 = st.columns([1, 2])
 
254
 
255
  with col2:
256
  if btn:
257
+ query = f"Process Loan for ID {uid}. "
258
+ if use_sim: query += f"SIMULATION: Score {sim_score}, Status '{sim_status}'. Do NOT query DB for score/status."
259
+ else: query += "Query DB for all info."
260
+ query += " Check policies. Report Risk, Rate, and Decision."
261
+
262
+ with st.status(f"🤖 {provider} Agent Running...", expanded=True):
263
+ st.write("Thinking...")
 
 
 
264
  try:
265
+ # Run Manual Loop
266
+ final_res, logs = agent.run(query)
267
+ st.write("✅ Done!")
 
 
268
  except Exception as e:
269
+ st.error(f"Error: {e}")
270
+ final_res = "Failed."
271
+ logs = []
272
 
273
+ st.success("### 📋 Report")
274
+ st.markdown(final_res)
 
275
 
276
  with st.expander("Trace"):
277
+ for tool_name, tool_in in logs:
278
+ st.markdown(f"**Tool:** `{tool_name}` | **Input:** `{tool_in}`")
279
+
 
280
  if not use_sim:
281
  st.divider()
282
+ st.text_area("✉️ Email Draft", value=llm.invoke(f"Draft email for: {final_res}").content)
 
 
 
 
 
283
 
284
+ else:
285
+ st.info("👈 Enter API Key")