larrysim committed on
Commit
fa537a2
·
verified ·
1 Parent(s): 01a9a12

Update app.py

Browse files

fix model error

Files changed (1) hide show
  1. app.py +43 -36
app.py CHANGED
@@ -26,14 +26,14 @@ INDEX_PATH = "faiss_index"
26
  REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate Policy.pdf"]
27
 
28
  try:
29
- # GROQ (Keep LangChain)
30
  from langchain_groq import ChatGroq
31
 
32
- # GOOGLE (Use Raw SDK - Stable)
33
  import google.generativeai as genai
34
  from google.generativeai.types import HarmCategory, HarmBlockThreshold
35
 
36
- # SHARED UTILS
37
  from langchain_huggingface import HuggingFaceEmbeddings
38
  from langchain_community.vectorstores import FAISS
39
  from langchain_community.document_loaders import PyPDFLoader
@@ -71,7 +71,7 @@ def run_query(query, params=()):
71
  return cursor.fetchone()
72
  except Exception as e: return f"DB Error: {e}"
73
 
74
- # --- DIRECT TOOL FUNCTIONS ---
75
  def tool_get_credit_score(user_id):
76
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
77
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
@@ -92,7 +92,7 @@ def tool_check_pr_status(user_id):
92
  return f"PR Status: {row[0]}" if (row and not isinstance(row, str)) else "PR Status: False."
93
 
94
  # ==========================================
95
- # 3. HYBRID AGENT ENGINE (The Solution)
96
  # ==========================================
97
  class HybridAgent:
98
  def __init__(self, provider, api_key, tools_map, rag_chain):
@@ -101,25 +101,47 @@ class HybridAgent:
101
  self.tools = tools_map
102
  self.rag_chain = rag_chain
103
  self.max_steps = 8
 
104
 
105
  # Initialize Groq
106
  if "Groq" in provider:
107
  self.groq_chat = ChatGroq(api_key=api_key, model_name="llama-3.3-70b-versatile", temperature=0)
108
 
109
- # Initialize Gemini (Native SDK)
110
  if "Google" in provider:
111
  genai.configure(api_key=api_key)
112
- # FIX: Use 'gemini-pro' (Stable) instead of 'flash' (404 Error)
113
- self.gemini_model = genai.GenerativeModel('gemini-pro')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  def call_llm(self, prompt):
116
- """Switches between LangChain (Groq) and Raw SDK (Gemini)"""
117
  if "Groq" in self.provider:
118
  return self.groq_chat.invoke(prompt).content
119
  else:
120
- # Native Google Call
121
  try:
122
- # Safety Settings to prevent blocking
123
  safety = {
124
  HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
125
  HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
@@ -155,15 +177,12 @@ Begin!
155
  logs = []
156
 
157
  for i in range(self.max_steps):
158
- # 1. Get LLM Response
159
  response = self.call_llm(history)
160
  history += response + "\n"
161
 
162
- # 2. Check for Final Answer
163
  if "Final Answer:" in response:
164
  return response.split("Final Answer:")[-1].strip(), logs
165
 
166
- # 3. Parse Tool Call
167
  action_match = re.search(r"Action:\s*(.+)", response)
168
  input_match = re.search(r"Action Input:\s*(.+)", response)
169
 
@@ -173,7 +192,6 @@ Begin!
173
 
174
  logs.append((tool_name, val))
175
 
176
- # Execute
177
  result = "Error: Tool not found"
178
  if tool_name in self.tools:
179
  try: result = self.tools[tool_name](val)
@@ -182,11 +200,8 @@ Begin!
182
  try: result = self.rag_chain.invoke(val)
183
  except Exception as e: result = f"RAG Error: {e}"
184
 
185
- # Feed back
186
- obs = f"Observation: {result}\n"
187
- history += obs
188
  else:
189
- # Force agent to continue if it stops early
190
  if i == self.max_steps - 1: return response, logs
191
  history += "Observation: Please continue. Use 'Final Answer:' when done.\n"
192
 
@@ -204,7 +219,6 @@ with st.sidebar:
204
 
205
  if 'auth' not in st.session_state: st.session_state.auth = False
206
 
207
- # Reset if provider changes
208
  if st.session_state.get('last_provider') != provider_opt:
209
  st.session_state.auth = False
210
  st.session_state.last_provider = provider_opt
@@ -213,12 +227,12 @@ with st.sidebar:
213
  key_in = st.text_input("API Key", type="password")
214
  if st.button("Validate"):
215
  try:
216
- # Validation Logic
217
  if "Groq" in provider_opt:
218
  ChatGroq(api_key=key_in).invoke("Hi")
219
  else:
220
  genai.configure(api_key=key_in)
221
- genai.list_models()
 
222
 
223
  st.session_state.auth = True
224
  st.session_state.key = key_in
@@ -238,13 +252,10 @@ with st.sidebar:
238
  st.rerun()
239
 
240
  if st.session_state.auth:
241
- # --- RAG SETUP ---
242
  @st.cache_resource
243
  def setup_rag():
244
  if pdfs_missing: return None
245
- # Always use HuggingFace embeddings (Free, Fast, Compatible)
246
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
247
-
248
  if os.path.exists(INDEX_PATH):
249
  return FAISS.load_local(INDEX_PATH, embeddings, allow_dangerous_deserialization=True).as_retriever()
250
 
@@ -257,28 +268,20 @@ if st.session_state.auth:
257
 
258
  retriever = setup_rag()
259
 
260
- # --- RAG CHAIN FOR TOOLS ---
261
- # Use Groq for RAG processing if available (faster), otherwise skip or use simplified
262
  def query_rag(q):
263
  if not retriever: return "No PDFs found."
264
  docs = retriever.invoke(q)
265
- ctx = "\n".join([d.page_content for d in docs])
266
- return f"Context from Policy: {ctx}"
267
 
268
- # Agent Tools Map
269
  tools = {
270
  "get_credit_score": tool_get_credit_score,
271
  "get_account_status": tool_get_account_status,
272
  "check_pr_status": tool_check_pr_status
273
  }
274
 
275
- # Initialize Hybrid Agent
276
- # For RAG, we pass a simple lambda that calls our query_rag function
277
  rag_lambda = type('RAG', (object,), {"invoke": lambda self, x: query_rag(x)})()
278
-
279
  agent = HybridAgent(provider_opt, st.session_state.key, tools, rag_lambda)
280
 
281
- # --- UI ---
282
  col1, col2 = st.columns([1, 2])
283
  with col1:
284
  uid = st.text_input("Customer ID", "1111")
@@ -295,8 +298,12 @@ if st.session_state.auth:
295
  q += " Check Policy. Report Risk, Rate, Decision."
296
 
297
  with st.status("Agent Working...", expanded=True):
298
- ans, logs = agent.run(q)
299
- st.write("Done!")
 
 
 
 
300
 
301
  st.success("### Final Report")
302
  st.markdown(ans)
 
26
  REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate Policy.pdf"]
27
 
28
  try:
29
+ # GROQ
30
  from langchain_groq import ChatGroq
31
 
32
+ # GOOGLE (Native SDK)
33
  import google.generativeai as genai
34
  from google.generativeai.types import HarmCategory, HarmBlockThreshold
35
 
36
+ # SHARED
37
  from langchain_huggingface import HuggingFaceEmbeddings
38
  from langchain_community.vectorstores import FAISS
39
  from langchain_community.document_loaders import PyPDFLoader
 
71
  return cursor.fetchone()
72
  except Exception as e: return f"DB Error: {e}"
73
 
74
+ # --- TOOLS ---
75
  def tool_get_credit_score(user_id):
76
  clean_id = ''.join(filter(str.isdigit, str(user_id)))
77
  row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
 
92
  return f"PR Status: {row[0]}" if (row and not isinstance(row, str)) else "PR Status: False."
93
 
94
  # ==========================================
95
+ # 3. HYBRID AGENT (Dynamic Model Loader)
96
  # ==========================================
97
  class HybridAgent:
98
  def __init__(self, provider, api_key, tools_map, rag_chain):
 
101
  self.tools = tools_map
102
  self.rag_chain = rag_chain
103
  self.max_steps = 8
104
+ self.gemini_model = None
105
 
106
  # Initialize Groq
107
  if "Groq" in provider:
108
  self.groq_chat = ChatGroq(api_key=api_key, model_name="llama-3.3-70b-versatile", temperature=0)
109
 
110
+ # Initialize Gemini with DYNAMIC DISCOVERY
111
  if "Google" in provider:
112
  genai.configure(api_key=api_key)
113
+ self.gemini_model = self._find_best_gemini_model()
114
+
115
def _find_best_gemini_model(self):
    """Return a ``genai.GenerativeModel`` for the most preferred available model.

    Lists the models reported by ``genai.list_models()`` that advertise
    ``generateContent`` and picks the first match in this order:

    1. a name containing both "flash" and "1.5"
    2. a name containing both "pro" and "1.5"
    3. a name containing "gemini-pro"
    4. the first listed model, whatever it is

    If no model is listed, ``'gemini-pro'`` is used. If anything raises
    during discovery or construction, ``'gemini-1.5-flash'`` is used.
    """
    # Ordered from most to least preferred; each entry is a predicate on
    # the model name.
    preferences = (
        lambda name: "flash" in name and "1.5" in name,
        lambda name: "pro" in name and "1.5" in name,
        lambda name: "gemini-pro" in name,
    )
    try:
        available_models = [
            m.name
            for m in genai.list_models()
            if 'generateContent' in m.supported_generation_methods
        ]

        # Scan the whole list once per preference so a higher-priority
        # match always wins regardless of listing order.
        for matches in preferences:
            for name in available_models:
                if matches(name):
                    return genai.GenerativeModel(name)

        # Nothing matched a preference: take the first usable model.
        if available_models:
            return genai.GenerativeModel(available_models[0])

        # The listing was empty; fall back to a fixed name.
        return genai.GenerativeModel('gemini-pro')
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        return genai.GenerativeModel('gemini-1.5-flash')
138
 
139
  def call_llm(self, prompt):
 
140
  if "Groq" in self.provider:
141
  return self.groq_chat.invoke(prompt).content
142
  else:
 
143
  try:
144
+ # Disable safety to prevent "list index out of range" errors
145
  safety = {
146
  HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
147
  HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
 
177
  logs = []
178
 
179
  for i in range(self.max_steps):
 
180
  response = self.call_llm(history)
181
  history += response + "\n"
182
 
 
183
  if "Final Answer:" in response:
184
  return response.split("Final Answer:")[-1].strip(), logs
185
 
 
186
  action_match = re.search(r"Action:\s*(.+)", response)
187
  input_match = re.search(r"Action Input:\s*(.+)", response)
188
 
 
192
 
193
  logs.append((tool_name, val))
194
 
 
195
  result = "Error: Tool not found"
196
  if tool_name in self.tools:
197
  try: result = self.tools[tool_name](val)
 
200
  try: result = self.rag_chain.invoke(val)
201
  except Exception as e: result = f"RAG Error: {e}"
202
 
203
+ history += f"Observation: {result}\n"
 
 
204
  else:
 
205
  if i == self.max_steps - 1: return response, logs
206
  history += "Observation: Please continue. Use 'Final Answer:' when done.\n"
207
 
 
219
 
220
  if 'auth' not in st.session_state: st.session_state.auth = False
221
 
 
222
  if st.session_state.get('last_provider') != provider_opt:
223
  st.session_state.auth = False
224
  st.session_state.last_provider = provider_opt
 
227
  key_in = st.text_input("API Key", type="password")
228
  if st.button("Validate"):
229
  try:
 
230
  if "Groq" in provider_opt:
231
  ChatGroq(api_key=key_in).invoke("Hi")
232
  else:
233
  genai.configure(api_key=key_in)
234
+ # Quick list check to validate key
235
+ [m.name for m in genai.list_models()]
236
 
237
  st.session_state.auth = True
238
  st.session_state.key = key_in
 
252
  st.rerun()
253
 
254
  if st.session_state.auth:
 
255
  @st.cache_resource
256
  def setup_rag():
257
  if pdfs_missing: return None
 
258
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
 
259
  if os.path.exists(INDEX_PATH):
260
  return FAISS.load_local(INDEX_PATH, embeddings, allow_dangerous_deserialization=True).as_retriever()
261
 
 
268
 
269
  retriever = setup_rag()
270
 
 
 
271
  def query_rag(q):
272
  if not retriever: return "No PDFs found."
273
  docs = retriever.invoke(q)
274
+ return "Context: " + "\n".join([d.page_content for d in docs])
 
275
 
 
276
  tools = {
277
  "get_credit_score": tool_get_credit_score,
278
  "get_account_status": tool_get_account_status,
279
  "check_pr_status": tool_check_pr_status
280
  }
281
 
 
 
282
  rag_lambda = type('RAG', (object,), {"invoke": lambda self, x: query_rag(x)})()
 
283
  agent = HybridAgent(provider_opt, st.session_state.key, tools, rag_lambda)
284
 
 
285
  col1, col2 = st.columns([1, 2])
286
  with col1:
287
  uid = st.text_input("Customer ID", "1111")
 
298
  q += " Check Policy. Report Risk, Rate, Decision."
299
 
300
  with st.status("Agent Working...", expanded=True):
301
+ try:
302
+ ans, logs = agent.run(q)
303
+ st.write("Done!")
304
+ except Exception as e:
305
+ st.error(f"Execution Error: {e}")
306
+ ans, logs = "Error", []
307
 
308
  st.success("### Final Report")
309
  st.markdown(ans)