larrysim committed on
Commit
3e46227
·
verified ·
1 Parent(s): e54fecb

Update app.py

Browse files

fix the model error

Files changed (1) hide show
  1. app.py +143 -118
app.py CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
2
  import pandas as pd
3
  import os
4
  import warnings
5
- import time
6
  import sqlite3
7
  import shutil
8
  import asyncio
@@ -27,9 +26,13 @@ INDEX_PATH = "faiss_index"
27
  REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate Policy.pdf"]
28
 
29
  try:
 
30
  from langchain_groq import ChatGroq
31
- from langchain_google_genai import ChatGoogleGenerativeAI
 
32
  import google.generativeai as genai
 
 
33
  from langchain_huggingface import HuggingFaceEmbeddings
34
  from langchain_community.vectorstores import FAISS
35
  from langchain_community.document_loaders import PyPDFLoader
@@ -37,13 +40,12 @@ try:
37
  from langchain_core.prompts import PromptTemplate
38
  from langchain_core.runnables import RunnablePassthrough
39
  from langchain_core.output_parsers import StrOutputParser
40
- from langchain_core.tools import tool
41
  except ImportError as e:
42
  st.error(f"❌ Import Error: {e}")
43
  st.stop()
44
 
45
  # ==========================================
46
- # 2. DATABASE & TOOLS SETUP
47
  # ==========================================
48
  def init_db():
49
  if os.path.exists(DB_FILE): return
@@ -68,15 +70,13 @@ def run_query(query, params=()):
68
  return cursor.fetchone()
69
  except Exception as e: return f"DB Error: {e}"
70
 
71
- # --- TOOL FUNCTIONS (Pure Python) ---
72
  def tool_get_credit_score(user_id):
73
- """Input: User ID. Returns Credit Score."""
74
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
75
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
76
  return f"Credit Score: {row[0]}" if (row and not isinstance(row, str)) else "User ID not found."
77
 
78
  def tool_get_account_status(user_id):
79
- """Input: User ID. Returns Name, Nationality, Status."""
80
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
81
  row = run_query("SELECT Name, Nationality, Account_Status, Email FROM account_status WHERE ID = ?", (clean_id,))
82
  if row and not isinstance(row, str):
@@ -84,7 +84,6 @@ def tool_get_account_status(user_id):
84
  return "User ID not found."
85
 
86
  def tool_check_pr_status(user_id):
87
- """Input: User ID. Returns PR Status."""
88
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
89
  row = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", (clean_id,))
90
  if not row or (isinstance(row, str) and "no such column" in row.lower()):
@@ -92,194 +91,220 @@ def tool_check_pr_status(user_id):
92
  return f"PR Status: {row[0]}" if (row and not isinstance(row, str)) else "PR Status: False."
93
 
94
  # ==========================================
95
- # 3. MANUAL AGENT ENGINE (The Fix)
96
  # ==========================================
97
- class ManualReActAgent:
98
- def __init__(self, llm, tools_map, rag_chain):
99
- self.llm = llm
 
100
  self.tools = tools_map
101
  self.rag_chain = rag_chain
102
- self.max_steps = 6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  def run(self, query):
105
- """Runs the ReAct loop manually to avoid Library Parsing Errors."""
106
-
107
- # 1. DEFINE PROMPT
108
  tool_desc = "\n".join([f"- {name}: {func.__doc__}" for name, func in self.tools.items()])
109
- system_prompt = f"""You are a Loan Risk Officer. Answer the question using the tools below.
 
110
 
111
- TOOLS:
112
  {tool_desc}
113
- - consult_policy_doc: Consult policy PDF for risk rules. Input: a question string.
114
 
115
- FORMAT:
 
 
116
  Thought: <reasoning>
117
  Action: <tool_name>
118
- Action Input: <input>
119
  Observation: <result>
120
- ... (repeat)
121
- Final Answer: <answer>
122
 
123
  Begin!
124
- Question: {query}
125
  """
126
- history = system_prompt
127
  logs = []
128
 
129
- # 2. LOOP
130
  for i in range(self.max_steps):
131
- # A. Call LLM
132
- response = self.llm.invoke(history).content
133
  history += response + "\n"
134
 
135
- # B. Parse "Action"
 
 
 
 
136
  action_match = re.search(r"Action:\s*(.+)", response)
137
  input_match = re.search(r"Action Input:\s*(.+)", response)
138
-
139
- # C. Check for Final Answer (Stop Condition)
140
- if "Final Answer:" in response:
141
- final_ans = response.split("Final Answer:")[-1].strip()
142
- return final_ans, logs
143
 
144
- # D. Execute Tool
145
  if action_match and input_match:
146
  tool_name = action_match.group(1).strip()
147
- tool_input = input_match.group(1).strip()
148
-
149
- # Strip quotes if present
150
- tool_input = tool_input.strip('"').strip("'")
151
 
152
- logs.append((tool_name, tool_input))
153
 
154
  # Execute
155
- observation = f"Error: Tool {tool_name} not found."
156
  if tool_name in self.tools:
157
- try:
158
- observation = self.tools[tool_name](tool_input)
159
- except Exception as e:
160
- observation = f"Tool Error: {e}"
161
  elif tool_name == "consult_policy_doc":
162
- try:
163
- observation = self.rag_chain.invoke(tool_input)
164
- except Exception as e:
165
- observation = f"RAG Error: {e}"
166
 
167
- obs_str = f"Observation: {observation}\n"
168
- history += obs_str
 
169
  else:
170
- # If LLM didn't output an action but didn't finish, force it
171
- if i == self.max_steps - 1:
172
- return response, logs
173
- history += "Observation: Please continue. If you have the answer, say 'Final Answer:'.\n"
174
 
175
  return "Agent timed out.", logs
176
 
177
  # ==========================================
178
- # 4. UI & SETUP
179
  # ==========================================
180
  st.title("🤖 Multi-Model Loan Assessor")
181
  pdfs_missing = [f for f in REQUIRED_PDFS if not os.path.exists(f)]
182
 
183
  with st.sidebar:
184
  st.header("🔐 Authentication")
185
- provider = st.radio("Model:", ["Groq (Llama-3)", "Google (Gemini)"])
186
 
187
- if 'api_key' not in st.session_state: st.session_state.api_key = None
188
 
189
- key_input = st.text_input("API Key", type="password")
190
- if st.button("Set Key"):
191
- st.session_state.api_key = key_input
192
- st.success("Key Set!")
193
- st.rerun()
194
 
195
- if st.button("♻️ Reset"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  if os.path.exists(INDEX_PATH): shutil.rmtree(INDEX_PATH)
197
  st.cache_resource.clear()
198
  st.rerun()
199
 
200
- if st.session_state.api_key:
201
  # --- RAG SETUP ---
202
  @st.cache_resource
203
  def setup_rag():
204
  if pdfs_missing: return None
 
205
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
 
206
  if os.path.exists(INDEX_PATH):
207
  return FAISS.load_local(INDEX_PATH, embeddings, allow_dangerous_deserialization=True).as_retriever()
208
- documents = []
209
- for f in REQUIRED_PDFS: documents.extend(PyPDFLoader(f).load())
210
- splits = CharacterTextSplitter(chunk_size=600, chunk_overlap=50).split_documents(documents)
 
211
  vectorstore = FAISS.from_documents(splits, embeddings)
212
  vectorstore.save_local(INDEX_PATH)
213
  return vectorstore.as_retriever()
214
 
215
- with st.spinner("Loading AI..."):
216
- retriever = setup_rag()
217
 
218
- # --- LLM SETUP ---
219
- if "Groq" in provider:
220
- llm = ChatGroq(api_key=st.session_state.api_key, temperature=0, model_name="llama-3.3-70b-versatile")
221
- else:
222
- # Using Gemini 1.5 Flash with REST transport
223
- llm = ChatGoogleGenerativeAI(
224
- google_api_key=st.session_state.api_key,
225
- temperature=0,
226
- model="gemini-1.5-flash",
227
- transport="rest"
228
- )
229
-
230
- # --- RAG CHAIN ---
231
- rag_chain = (
232
- {"context": retriever | (lambda d: "\n".join([x.page_content for x in d])), "question": RunnablePassthrough()}
233
- | PromptTemplate.from_template("Info: {context}\nQ: {question}\nA:")
234
- | llm
235
- | StrOutputParser()
236
- )
237
-
238
- # --- AGENT INSTANCE ---
239
- tools_map = {
240
  "get_credit_score": tool_get_credit_score,
241
  "get_account_status": tool_get_account_status,
242
  "check_pr_status": tool_check_pr_status
243
  }
244
- agent = ManualReActAgent(llm, tools_map, rag_chain)
 
 
 
 
 
245
 
246
  # --- UI ---
247
  col1, col2 = st.columns([1, 2])
248
  with col1:
249
  uid = st.text_input("Customer ID", "1111")
250
  use_sim = st.checkbox("Simulation Mode")
251
- sim_score = st.slider("Score", 300, 900, 450) if use_sim else 0
252
- sim_status = st.selectbox("Status", ["good-standing", "closed", "delinquent"]) if use_sim else ""
253
- btn = st.button("Assess Risk", type="primary")
254
 
255
  with col2:
256
  if btn:
257
- query = f"Process Loan for ID {uid}. "
258
- if use_sim: query += f"SIMULATION: Score {sim_score}, Status '{sim_status}'. Do NOT query DB for score/status."
259
- else: query += "Query DB for all info."
260
- query += " Check policies. Report Risk, Rate, and Decision."
261
-
262
- with st.status(f"🤖 {provider} Agent Running...", expanded=True):
263
- st.write("Thinking...")
264
- try:
265
- # Run Manual Loop
266
- final_res, logs = agent.run(query)
267
- st.write("✅ Done!")
268
- except Exception as e:
269
- st.error(f"Error: {e}")
270
- final_res = "Failed."
271
- logs = []
272
-
273
- st.success("### 📋 Report")
274
- st.markdown(final_res)
275
 
276
  with st.expander("Trace"):
277
- for tool_name, tool_in in logs:
278
- st.markdown(f"**Tool:** `{tool_name}` | **Input:** `{tool_in}`")
279
 
280
  if not use_sim:
281
  st.divider()
282
- st.text_area("✉️ Email Draft", value=llm.invoke(f"Draft email for: {final_res}").content)
 
283
 
284
  else:
285
- st.info("👈 Enter API Key")
 
2
  import pandas as pd
3
  import os
4
  import warnings
 
5
  import sqlite3
6
  import shutil
7
  import asyncio
 
26
  REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate Policy.pdf"]
27
 
28
  try:
29
+ # GROQ (Keep LangChain)
30
  from langchain_groq import ChatGroq
31
+
32
+ # GOOGLE (Use Raw SDK - More Stable)
33
  import google.generativeai as genai
34
+
35
+ # SHARED UTILS
36
  from langchain_huggingface import HuggingFaceEmbeddings
37
  from langchain_community.vectorstores import FAISS
38
  from langchain_community.document_loaders import PyPDFLoader
 
40
  from langchain_core.prompts import PromptTemplate
41
  from langchain_core.runnables import RunnablePassthrough
42
  from langchain_core.output_parsers import StrOutputParser
 
43
  except ImportError as e:
44
  st.error(f"❌ Import Error: {e}")
45
  st.stop()
46
 
47
  # ==========================================
48
+ # 2. DATABASE & TOOLS
49
  # ==========================================
50
  def init_db():
51
  if os.path.exists(DB_FILE): return
 
70
  return cursor.fetchone()
71
  except Exception as e: return f"DB Error: {e}"
72
 
73
+ # --- DIRECT TOOL FUNCTIONS ---
74
  def tool_get_credit_score(user_id):
 
75
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
76
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
77
  return f"Credit Score: {row[0]}" if (row and not isinstance(row, str)) else "User ID not found."
78
 
79
  def tool_get_account_status(user_id):
 
80
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
81
  row = run_query("SELECT Name, Nationality, Account_Status, Email FROM account_status WHERE ID = ?", (clean_id,))
82
  if row and not isinstance(row, str):
 
84
  return "User ID not found."
85
 
86
  def tool_check_pr_status(user_id):
 
87
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
88
  row = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", (clean_id,))
89
  if not row or (isinstance(row, str) and "no such column" in row.lower()):
 
91
  return f"PR Status: {row[0]}" if (row and not isinstance(row, str)) else "PR Status: False."
92
 
93
  # ==========================================
94
+ # 3. HYBRID AGENT ENGINE (The Solution)
95
  # ==========================================
96
class HybridAgent:
    """Manual ReAct loop that talks to Groq (via LangChain) or Gemini (via the raw Google SDK).

    The agent asks the model for one Thought / Action / Action Input step at a
    time, runs the requested tool itself, appends the real result as an
    ``Observation:`` line, and repeats until the model emits ``Final Answer:``
    or ``max_steps`` is exhausted.
    """

    # Shown in the prompt for tools whose function has no docstring, so the
    # model never sees the literal text "None" as a tool description.
    _MISSING_TOOL_DOC = "No description available."
    _FINAL_MARKER = "Final Answer:"
    _RAG_TOOL = "consult_policy_doc"

    def __init__(self, provider, api_key, tools_map, rag_chain):
        """Store the configuration and create the client for the chosen provider.

        Args:
            provider: Provider label; a client is created when it contains
                "Groq" and/or "Google".
            api_key: API key handed to the selected provider's client.
            tools_map: Mapping of tool name -> callable taking one string argument.
            rag_chain: Object with an ``invoke(question)`` method, used for the
                ``consult_policy_doc`` tool.
        """
        self.provider = provider
        self.api_key = api_key
        self.tools = tools_map
        self.rag_chain = rag_chain
        self.max_steps = 8

        # Always define both attributes so a mismatched provider string fails
        # with a clear message inside call_llm instead of an AttributeError.
        self.groq_chat = None
        self.gemini_model = None

        if "Groq" in provider:
            self.groq_chat = ChatGroq(api_key=api_key, model_name="llama-3.3-70b-versatile", temperature=0)

        if "Google" in provider:
            genai.configure(api_key=api_key)
            self.gemini_model = genai.GenerativeModel('gemini-1.5-flash')

    def _complete(self, prompt):
        """Send ``prompt`` to the active provider and return its text; raises on failure."""
        if "Groq" in self.provider:
            return self.groq_chat.invoke(prompt).content
        return self.gemini_model.generate_content(prompt).text

    def call_llm(self, prompt):
        """Return the model's reply to ``prompt``, or an error string if the call fails.

        Both providers now report failures the same way ("<Provider> Error: ...")
        instead of Groq raising while Gemini returns text.
        """
        try:
            return self._complete(prompt)
        except Exception as e:
            label = "Groq" if "Groq" in self.provider else "Gemini"
            return f"{label} Error: {str(e)}"

    def _describe_tools(self):
        """Build the bullet list of tools for the prompt, tolerating missing docstrings."""
        lines = []
        for name, func in self.tools.items():
            # Collapse whitespace so multi-line docstrings stay on one bullet.
            doc = " ".join((getattr(func, "__doc__", None) or "").split())
            lines.append(f"- {name}: {doc or self._MISSING_TOOL_DOC}")
        return "\n".join(lines)

    def _build_prompt(self, query):
        """Return the initial prompt containing the request, the tools and the step format."""
        tool_desc = self._describe_tools()
        return f"""You are a Loan Officer. Solve this request: "{query}"

TOOLS AVAILABLE:
{tool_desc}
- consult_policy_doc: Search PDF policies. Input: question string.

RULES:
1. You run in a loop. OUTPUT ONLY ONE STEP AT A TIME.
2. Format:
Thought: <reasoning>
Action: <tool_name>
Action Input: <input_string>
Observation: <result>
...
Final Answer: <the full report>
3. Stop after "Action Input:". The system writes the Observation for you.

Begin!
"""

    def _run_tool(self, tool_name, val):
        """Execute one tool call and return its result (or an error string)."""
        if tool_name in self.tools:
            try:
                return self.tools[tool_name](val)
            except Exception as e:
                return f"Error: {e}"
        if tool_name == self._RAG_TOOL:
            try:
                return self.rag_chain.invoke(val)
            except Exception as e:
                return f"RAG Error: {e}"
        # Name the valid tools so the model can correct itself on the next step.
        known = ", ".join([*self.tools, self._RAG_TOOL])
        return f"Error: Tool not found. Valid tools: {known}"

    def run(self, query):
        """Run the ReAct loop for ``query``.

        Returns:
            A ``(answer, logs)`` tuple. ``answer`` is the text after
            "Final Answer:", the last raw response if the model never used the
            format, "Agent timed out." if the step budget ended on a tool call,
            or "LLM Error: ..." if the provider call failed. ``logs`` is a list
            of ``(tool_name, tool_input)`` pairs in execution order.
        """
        history = self._build_prompt(query)
        logs = []

        for i in range(self.max_steps):
            # 1. Get LLM response. A failed call ends the run immediately rather
            #    than feeding the error text back to the model for every step.
            try:
                response = self._complete(history)
            except Exception as e:
                return f"LLM Error: {e}", logs

            # 2. Discard anything from a model-written "Observation:" onward:
            #    only tool output produced here counts as an observation.
            step = re.split(r"^[ \t]*Observation:", response, maxsplit=1, flags=re.MULTILINE)[0].rstrip()

            # 3. Parse the tool call. [ \t]* (not \s*) keeps each match on its
            #    own line so an empty "Action:" cannot swallow the next line.
            action_match = re.search(r"Action:[ \t]*(.+)", step)
            input_match = re.search(r"Action Input:[ \t]*(.+)", step)
            has_call = bool(action_match and input_match)

            # 4. Final answer, unless a tool call precedes it in the same step
            #    (then the answer was written before the tool actually ran).
            final_at = step.find(self._FINAL_MARKER)
            if final_at != -1:
                call_first = has_call and max(action_match.start(), input_match.start()) < final_at
                if not call_first:
                    return step.split(self._FINAL_MARKER)[-1].strip(), logs
                step = step[:final_at].rstrip()

            history += step + "\n"

            if has_call:
                # Models sometimes wrap names in markdown or quotes.
                tool_name = action_match.group(1).strip().strip("`*'\" ")
                val = input_match.group(1).strip().strip('"').strip("'")

                logs.append((tool_name, val))

                # Feed the real result back to the model.
                result = self._run_tool(tool_name, val)
                history += f"Observation: {result}\n"
            else:
                # No tool call and no final answer: nudge the model, or give up
                # with the raw text when this was the last allowed step.
                if i == self.max_steps - 1:
                    return response, logs
                history += "Observation: Please continue. Use 'Final Answer:' when done.\n"

        return "Agent timed out.", logs
186
 
187
  # ==========================================
188
# 4. UI & LOGIC
# ==========================================
# Top-level Streamlit script: the sidebar validates an API key for the chosen
# provider; once authenticated, the main area builds the policy retriever and
# the agent, then runs an assessment when "Assess" is clicked.
st.title("🤖 Multi-Model Loan Assessor")
pdfs_missing = [f for f in REQUIRED_PDFS if not os.path.exists(f)]

with st.sidebar:
    st.header("🔐 Authentication")
    provider_opt = st.radio("Model:", ["Groq (Llama-3)", "Google (Gemini)"])

    if 'auth' not in st.session_state:
        st.session_state.auth = False

    # A key validated for one provider must not carry over to the other.
    if st.session_state.get('last_provider') != provider_opt:
        st.session_state.auth = False
        st.session_state.pop('key', None)
        st.session_state.last_provider = provider_opt

    if not st.session_state.auth:
        key_in = st.text_input("API Key", type="password")
        if st.button("Validate"):
            if not key_in.strip():
                st.error("Invalid: API key is empty.")
            else:
                try:
                    if "Groq" in provider_opt:
                        # Validate against the same model the agent uses rather
                        # than whatever default the library ships with.
                        ChatGroq(api_key=key_in, model_name="llama-3.3-70b-versatile").invoke("Hi")
                    else:
                        genai.configure(api_key=key_in)
                        # list_models() is lazy; pull one item so the request is
                        # actually sent and a bad key raises here.
                        next(iter(genai.list_models()), None)
                except Exception as e:
                    st.error(f"Invalid: {e}")
                else:
                    st.session_state.auth = True
                    st.session_state.key = key_in
                    st.success("Valid!")
                    st.rerun()
    else:
        st.success("Active")
        if st.button("Logout"):
            st.session_state.auth = False
            # Do not keep the secret in session state after logout.
            st.session_state.pop('key', None)
            st.rerun()

    if st.button("♻️ Reset DB"):
        # Drop the saved FAISS index and cached resources so they are rebuilt.
        if os.path.exists(INDEX_PATH):
            shutil.rmtree(INDEX_PATH)
        st.cache_resource.clear()
        st.rerun()

if st.session_state.auth:
    # --- RAG SETUP ---
    @st.cache_resource
    def setup_rag():
        """Load the saved FAISS index, or build and save it from the policy PDFs."""
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        if os.path.exists(INDEX_PATH):
            # NOTE(security): allow_dangerous_deserialization unpickles the index
            # file; only acceptable because the index is written by this app.
            return FAISS.load_local(INDEX_PATH, embeddings, allow_dangerous_deserialization=True).as_retriever()

        docs = []
        for f in REQUIRED_PDFS:
            docs.extend(PyPDFLoader(f).load())
        splits = CharacterTextSplitter(chunk_size=600, chunk_overlap=50).split_documents(docs)
        vectorstore = FAISS.from_documents(splits, embeddings)
        vectorstore.save_local(INDEX_PATH)
        return vectorstore.as_retriever()

    # Check for missing PDFs outside the cached function so a "no PDFs" result
    # is not cached and the retriever appears as soon as the files do.
    if pdfs_missing:
        st.warning(f"Missing policy PDFs: {', '.join(pdfs_missing)}")
        retriever = None
    else:
        retriever = setup_rag()

    # --- POLICY LOOKUP TOOL ---
    def query_rag(q):
        """Return the retrieved policy passages for ``q`` as plain text.

        No LLM is called here; the agent's own model reads the raw context.
        """
        if not retriever:
            return "No PDFs found."
        docs = retriever.invoke(q)
        ctx = "\n".join([d.page_content for d in docs])
        return f"Context from Policy: {ctx}"

    # Agent tools map
    tools = {
        "get_credit_score": tool_get_credit_score,
        "get_account_status": tool_get_account_status,
        "check_pr_status": tool_check_pr_status
    }

    class _PolicyLookup:
        """Minimal object exposing ``invoke`` as the agent expects of its rag_chain."""
        invoke = staticmethod(query_rag)

    agent = HybridAgent(provider_opt, st.session_state.key, tools, _PolicyLookup())

    # --- UI ---
    col1, col2 = st.columns([1, 2])
    with col1:
        uid = st.text_input("Customer ID", "1111")
        use_sim = st.checkbox("Simulation Mode")
        s_score = st.slider("Score", 300, 900, 450) if use_sim else 0
        s_status = st.selectbox("Status", ["good-standing", "closed", "delinquent"]) if use_sim else ""
        btn = st.button("Assess")

    with col2:
        if btn:
            q = f"Process Loan ID {uid}. "
            if use_sim:
                q += f"SIMULATION: Score {s_score}, Status '{s_status}'. Skip DB for those."
            else:
                q += "Query DB for all data."
            q += " Check Policy. Report Risk, Rate, Decision."

            ans, logs = None, []
            with st.status("Agent Working...", expanded=True):
                # Report agent failures in the page instead of crashing the script.
                try:
                    ans, logs = agent.run(q)
                    st.write("Done!")
                except Exception as e:
                    st.error(f"Error: {e}")

            if ans is not None:
                st.success("### Final Report")
                st.markdown(ans)

            with st.expander("Trace"):
                for tool_name, tool_in in logs:
                    st.write(f"**{tool_name}**: {tool_in}")

            # Only draft an email when there is a real, non-simulated report.
            if ans is not None and not use_sim:
                st.divider()
                with st.expander("Draft Email"):
                    st.text_area("Content", value=agent.call_llm(f"Draft email for: {ans}"))

else:
    st.info("👈 Login Required")