larrysim commited on
Commit
b76bcbc
·
verified ·
1 Parent(s): b3f77bc

Update app.py

Browse files

fix list error

Files changed (1) hide show
  1. app.py +20 -11
app.py CHANGED
@@ -8,7 +8,7 @@ import shutil
8
  import asyncio
9
 
10
  # ==========================================
11
- # 0. ASYNC FIX (CRITICAL)
12
  # ==========================================
13
  try:
14
  asyncio.get_running_loop()
@@ -30,7 +30,7 @@ REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate P
30
 
31
  try:
32
  from langchain_groq import ChatGroq
33
- from langchain_google_genai import ChatGoogleGenerativeAI
34
  import google.generativeai as genai
35
 
36
  from langchain_huggingface import HuggingFaceEmbeddings
@@ -38,11 +38,11 @@ try:
38
  from langchain_community.callbacks import StreamlitCallbackHandler
39
  from langchain_community.document_loaders import PyPDFLoader
40
  from langchain_text_splitters import CharacterTextSplitter
41
- from langchain_core.prompts import PromptTemplate # Changed for ReAct
42
  from langchain_core.runnables import RunnablePassthrough
43
  from langchain_core.output_parsers import StrOutputParser
44
  from langchain_core.tools import tool
45
- from langchain.agents import AgentExecutor, create_react_agent # Changed Agent Type
46
  except ImportError as e:
47
  st.error(f"❌ Import Error: {e}")
48
  st.stop()
@@ -78,7 +78,7 @@ def run_query(query, params=()):
78
  # ==========================================
79
  @tool
80
  def get_credit_score(user_id: str) -> str:
81
- """Queries SQL DB for Credit Score. Input is just the numeric ID string."""
82
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
83
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
84
  return f"Credit Score: {row[0]}" if (row and not isinstance(row, str)) else "User ID not found."
@@ -201,12 +201,21 @@ if st.session_state.get('auth_status', False):
201
  model_name="llama-3.3-70b-versatile"
202
  )
203
  else:
204
- # Use Gemini 1.5 Flash
 
 
 
 
 
 
 
 
205
  llm = ChatGoogleGenerativeAI(
206
  google_api_key=current_key,
207
  temperature=0,
208
  model="gemini-1.5-flash",
209
- transport="rest"
 
210
  )
211
 
212
  # --- RAG CHAIN ---
@@ -224,7 +233,7 @@ if st.session_state.get('auth_status', False):
224
 
225
  tools = [get_credit_score, get_account_status, check_pr_status, consult_policy_doc]
226
 
227
- # --- REACT AGENT TEMPLATE (Universal Compatibility) ---
228
  template = '''Answer the following questions as best you can. You have access to the following tools:
229
 
230
  {tools}
@@ -247,7 +256,7 @@ Thought:{agent_scratchpad}'''
247
 
248
  prompt = PromptTemplate.from_template(template)
249
 
250
- # Switch to create_react_agent (More Robust than tool_calling_agent)
251
  agent = create_react_agent(llm, tools, prompt)
252
 
253
  agent_executor = AgentExecutor(
@@ -255,9 +264,10 @@ Thought:{agent_scratchpad}'''
255
  tools=tools,
256
  verbose=True,
257
  return_intermediate_steps=True,
258
- handle_parsing_errors=True
259
  )
260
 
 
261
  col1, col2 = st.columns([1, 2])
262
  with col1:
263
  uid = st.text_input("Customer ID", "1111")
@@ -293,7 +303,6 @@ Thought:{agent_scratchpad}'''
293
  st.markdown(final_output)
294
 
295
  with st.expander("Trace"):
296
- # Handle ReAct agent steps format
297
  steps = res.get("intermediate_steps", [])
298
  for action, obs in steps:
299
  st.markdown(f"**Tool:** `{action.tool}`\n**Result:** `{obs}`")
 
8
  import asyncio
9
 
10
  # ==========================================
11
+ # 0. ASYNC FIX
12
  # ==========================================
13
  try:
14
  asyncio.get_running_loop()
 
30
 
31
  try:
32
  from langchain_groq import ChatGroq
33
+ from langchain_google_genai import ChatGoogleGenerativeAI, HarmBlockThreshold, HarmCategory
34
  import google.generativeai as genai
35
 
36
  from langchain_huggingface import HuggingFaceEmbeddings
 
38
  from langchain_community.callbacks import StreamlitCallbackHandler
39
  from langchain_community.document_loaders import PyPDFLoader
40
  from langchain_text_splitters import CharacterTextSplitter
41
+ from langchain_core.prompts import PromptTemplate
42
  from langchain_core.runnables import RunnablePassthrough
43
  from langchain_core.output_parsers import StrOutputParser
44
  from langchain_core.tools import tool
45
+ from langchain.agents import AgentExecutor, create_react_agent
46
  except ImportError as e:
47
  st.error(f"❌ Import Error: {e}")
48
  st.stop()
 
78
  # ==========================================
79
  @tool
80
  def get_credit_score(user_id: str) -> str:
81
+ """Queries SQL DB for Credit Score. Input is ID string."""
82
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
83
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
84
  return f"Credit Score: {row[0]}" if (row and not isinstance(row, str)) else "User ID not found."
 
201
  model_name="llama-3.3-70b-versatile"
202
  )
203
  else:
204
+ # CRITICAL FIX: DISABLE SAFETY FILTERS
205
+ # This prevents Gemini from returning an empty list or error when it gets confused
206
+ safety = {
207
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
208
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
209
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
210
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
211
+ }
212
+
213
  llm = ChatGoogleGenerativeAI(
214
  google_api_key=current_key,
215
  temperature=0,
216
  model="gemini-1.5-flash",
217
+ transport="rest",
218
+ safety_settings=safety # <--- APPLY FIX
219
  )
220
 
221
  # --- RAG CHAIN ---
 
233
 
234
  tools = [get_credit_score, get_account_status, check_pr_status, consult_policy_doc]
235
 
236
+ # --- REACT PROMPT ---
237
  template = '''Answer the following questions as best you can. You have access to the following tools:
238
 
239
  {tools}
 
256
 
257
  prompt = PromptTemplate.from_template(template)
258
 
259
+ # --- AGENT CREATION ---
260
  agent = create_react_agent(llm, tools, prompt)
261
 
262
  agent_executor = AgentExecutor(
 
264
  tools=tools,
265
  verbose=True,
266
  return_intermediate_steps=True,
267
+ handle_parsing_errors=True # Auto-fix formatting errors
268
  )
269
 
270
+ # --- UI ---
271
  col1, col2 = st.columns([1, 2])
272
  with col1:
273
  uid = st.text_input("Customer ID", "1111")
 
303
  st.markdown(final_output)
304
 
305
  with st.expander("Trace"):
 
306
  steps = res.get("intermediate_steps", [])
307
  for action, obs in steps:
308
  st.markdown(f"**Tool:** `{action.tool}`\n**Result:** `{obs}`")