larrysim commited on
Commit
b3f77bc
Β·
verified Β·
1 Parent(s): 6382199

Update app.py

Browse files

fix list error

Files changed (1) hide show
  1. app.py +48 -25
app.py CHANGED
@@ -38,11 +38,11 @@ try:
38
  from langchain_community.callbacks import StreamlitCallbackHandler
39
  from langchain_community.document_loaders import PyPDFLoader
40
  from langchain_text_splitters import CharacterTextSplitter
41
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
42
  from langchain_core.runnables import RunnablePassthrough
43
  from langchain_core.output_parsers import StrOutputParser
44
  from langchain_core.tools import tool
45
- from langchain.agents import AgentExecutor, create_tool_calling_agent
46
  except ImportError as e:
47
  st.error(f"❌ Import Error: {e}")
48
  st.stop()
@@ -78,14 +78,14 @@ def run_query(query, params=()):
78
  # ==========================================
79
  @tool
80
  def get_credit_score(user_id: str) -> str:
81
- """Queries SQL DB for Credit Score."""
82
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
83
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
84
  return f"Credit Score: {row[0]}" if (row and not isinstance(row, str)) else "User ID not found."
85
 
86
  @tool
87
  def get_account_status(user_id: str) -> str:
88
- """Queries SQL DB for Name, Nationality, Status, and Email."""
89
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
90
  row = run_query("SELECT Name, Nationality, Account_Status, Email FROM account_status WHERE ID = ?", (clean_id,))
91
  if row and not isinstance(row, str):
@@ -94,7 +94,7 @@ def get_account_status(user_id: str) -> str:
94
 
95
  @tool
96
  def check_pr_status(user_id: str) -> str:
97
- """Queries SQL DB for PR Status."""
98
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
99
  row = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", (clean_id,))
100
  if not row or (isinstance(row, str) and "no such column" in row.lower()):
@@ -201,7 +201,7 @@ if st.session_state.get('auth_status', False):
201
  model_name="llama-3.3-70b-versatile"
202
  )
203
  else:
204
- # Use Gemini 1.5 Flash (Better for Tools)
205
  llm = ChatGoogleGenerativeAI(
206
  google_api_key=current_key,
207
  temperature=0,
@@ -212,29 +212,49 @@ if st.session_state.get('auth_status', False):
212
  # --- RAG CHAIN ---
213
  rag_chain = (
214
  {"context": retriever | (lambda d: "\n".join([x.page_content for x in d])), "question": RunnablePassthrough()}
215
- | ChatPromptTemplate.from_template("{context}\nQ:{question}") | llm | StrOutputParser()
 
 
216
  )
217
 
218
  @tool
219
  def consult_policy_doc(query: str) -> str:
220
- """Consults Policy Documents."""
221
  return rag_chain.invoke(query)
222
 
223
  tools = [get_credit_score, get_account_status, check_pr_status, consult_policy_doc]
224
 
225
- prompt = ChatPromptTemplate.from_messages([
226
- ("system", "Act as a Loan Officer. Query SQL DB for info. Check Policies via tool. Output Markdown report."),
227
- ("human", "{input}"),
228
- MessagesPlaceholder(variable_name="agent_scratchpad"),
229
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
- # --- AGENT EXECUTOR WITH ERROR HANDLING ---
232
  agent_executor = AgentExecutor(
233
- agent=create_tool_calling_agent(llm, tools, prompt),
234
  tools=tools,
235
  verbose=True,
236
  return_intermediate_steps=True,
237
- # CRITICAL FIX: This handles the 'list object no attribute get' error
238
  handle_parsing_errors=True
239
  )
240
 
@@ -248,10 +268,13 @@ if st.session_state.get('auth_status', False):
248
 
249
  with col2:
250
  if btn:
251
- query = f"Process Loan {uid}. "
252
- if use_sim: query += f"SIMULATION: Use Score {sim_score}, Status {sim_status}. Only query Name from DB."
253
- else: query += "Query SQL for all data."
254
- query += " Check Policies. Output Final Report."
 
 
 
255
 
256
  with st.status(f"πŸ€– Agent ({current_provider}) Working...", expanded=True) as status:
257
  st_callback = StreamlitCallbackHandler(st.container())
@@ -262,17 +285,17 @@ if st.session_state.get('auth_status', False):
262
  update_metrics(metrics_placeholder)
263
  status.update(label="βœ… Done", state="complete", expanded=False)
264
  except Exception as e:
265
- # Fallback Logic: If Agent fails, print error clearly
266
- st.error(f"Agent Logic Error: {e}")
267
  st.stop()
268
 
269
  st.success("### πŸ“‹ Final Report")
270
- # Handle case where output might be in a different key due to error
271
  final_output = res.get('output', "Error generating report.")
272
  st.markdown(final_output)
273
 
274
  with st.expander("Trace"):
275
- for action, obs in res.get("intermediate_steps", []):
 
 
276
  st.markdown(f"**Tool:** `{action.tool}`\n**Result:** `{obs}`")
277
 
278
  if not use_sim:
@@ -282,7 +305,7 @@ if st.session_state.get('auth_status', False):
282
  email = llm.invoke(f"Draft email for: {final_output}").content
283
  st.text_area("Draft", value=email, height=200)
284
  except:
285
- st.warning("Could not draft email due to generation error.")
286
 
287
  elif not st.session_state.get('auth_status', False):
288
  st.info("πŸ‘ˆ Select Provider & Validate Key in Sidebar")
 
38
  from langchain_community.callbacks import StreamlitCallbackHandler
39
  from langchain_community.document_loaders import PyPDFLoader
40
  from langchain_text_splitters import CharacterTextSplitter
41
+ from langchain_core.prompts import PromptTemplate # Changed for ReAct
42
  from langchain_core.runnables import RunnablePassthrough
43
  from langchain_core.output_parsers import StrOutputParser
44
  from langchain_core.tools import tool
45
+ from langchain.agents import AgentExecutor, create_react_agent # Changed Agent Type
46
  except ImportError as e:
47
  st.error(f"❌ Import Error: {e}")
48
  st.stop()
 
78
  # ==========================================
79
  @tool
80
  def get_credit_score(user_id: str) -> str:
81
+ """Queries SQL DB for Credit Score. Input is just the numeric ID string."""
82
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
83
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
84
  return f"Credit Score: {row[0]}" if (row and not isinstance(row, str)) else "User ID not found."
85
 
86
  @tool
87
  def get_account_status(user_id: str) -> str:
88
+ """Queries SQL DB for Name, Nationality, Status, and Email. Input is ID string."""
89
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
90
  row = run_query("SELECT Name, Nationality, Account_Status, Email FROM account_status WHERE ID = ?", (clean_id,))
91
  if row and not isinstance(row, str):
 
94
 
95
  @tool
96
  def check_pr_status(user_id: str) -> str:
97
+ """Queries SQL DB for PR Status. Input is ID string."""
98
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
99
  row = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", (clean_id,))
100
  if not row or (isinstance(row, str) and "no such column" in row.lower()):
 
201
  model_name="llama-3.3-70b-versatile"
202
  )
203
  else:
204
+ # Use Gemini 1.5 Flash
205
  llm = ChatGoogleGenerativeAI(
206
  google_api_key=current_key,
207
  temperature=0,
 
212
  # --- RAG CHAIN ---
213
  rag_chain = (
214
  {"context": retriever | (lambda d: "\n".join([x.page_content for x in d])), "question": RunnablePassthrough()}
215
+ | PromptTemplate.from_template("Answer based on context:\n{context}\nQuestion: {question}")
216
+ | llm
217
+ | StrOutputParser()
218
  )
219
 
220
  @tool
221
  def consult_policy_doc(query: str) -> str:
222
+ """Consults Policy Documents. Input should be a question string."""
223
  return rag_chain.invoke(query)
224
 
225
  tools = [get_credit_score, get_account_status, check_pr_status, consult_policy_doc]
226
 
227
+ # --- REACT AGENT TEMPLATE (Universal Compatibility) ---
228
+ template = '''Answer the following questions as best you can. You have access to the following tools:
229
+
230
+ {tools}
231
+
232
+ Use the following format:
233
+
234
+ Question: the input question you must answer
235
+ Thought: you should always think about what to do
236
+ Action: the action to take, should be one of [{tool_names}]
237
+ Action Input: the input to the action
238
+ Observation: the result of the action
239
+ ... (this Thought/Action/Action Input/Observation can repeat N times)
240
+ Thought: I now know the final answer
241
+ Final Answer: the final answer to the original input question
242
+
243
+ Begin!
244
+
245
+ Question: {input}
246
+ Thought:{agent_scratchpad}'''
247
+
248
+ prompt = PromptTemplate.from_template(template)
249
+
250
+ # Switch to create_react_agent (More Robust than tool_calling_agent)
251
+ agent = create_react_agent(llm, tools, prompt)
252
 
 
253
  agent_executor = AgentExecutor(
254
+ agent=agent,
255
  tools=tools,
256
  verbose=True,
257
  return_intermediate_steps=True,
 
258
  handle_parsing_errors=True
259
  )
260
 
 
268
 
269
  with col2:
270
  if btn:
271
+ query = f"Process Loan for Customer ID: {uid}. "
272
+ if use_sim:
273
+ query += f"SIMULATION ACTIVE. Use Score {sim_score} and Status '{sim_status}'. Do NOT query credit/status tools. Only query Name."
274
+ else:
275
+ query += "Query SQL tools for Name, Email, Nationality, Status, Score."
276
+
277
+ query += " Check Policies using 'consult_policy_doc'. Output a Final Report Table."
278
 
279
  with st.status(f"πŸ€– Agent ({current_provider}) Working...", expanded=True) as status:
280
  st_callback = StreamlitCallbackHandler(st.container())
 
285
  update_metrics(metrics_placeholder)
286
  status.update(label="βœ… Done", state="complete", expanded=False)
287
  except Exception as e:
288
+ st.error(f"Agent Error: {e}")
 
289
  st.stop()
290
 
291
  st.success("### πŸ“‹ Final Report")
 
292
  final_output = res.get('output', "Error generating report.")
293
  st.markdown(final_output)
294
 
295
  with st.expander("Trace"):
296
+ # Handle ReAct agent steps format
297
+ steps = res.get("intermediate_steps", [])
298
+ for action, obs in steps:
299
  st.markdown(f"**Tool:** `{action.tool}`\n**Result:** `{obs}`")
300
 
301
  if not use_sim:
 
305
  email = llm.invoke(f"Draft email for: {final_output}").content
306
  st.text_area("Draft", value=email, height=200)
307
  except:
308
+ pass
309
 
310
  elif not st.session_state.get('auth_status', False):
311
  st.info("πŸ‘ˆ Select Provider & Validate Key in Sidebar")