# loan_agent / app.py
# Last commit by larrysim: "Update app.py" (3ee6129, verified)
import streamlit as st
import pandas as pd
import os
import warnings
import time
import sqlite3
import shutil
# ==========================================
# 1. PAGE CONFIG (MUST BE FIRST)
# ==========================================
# st.set_page_config has to run before any other Streamlit command in the
# script, which is why it precedes even the guarded imports below.
st.set_page_config(page_title="Bank Loan Agent (SQL)", layout="wide")
# Suppress warnings
# NOTE(review): this ignores every Python warning for the whole process,
# including deprecation warnings from the LangChain packages imported below.
warnings.filterwarnings("ignore")
# ==========================================
# 2. GLOBAL CONSTANTS & IMPORTS
# ==========================================
# SQLite file that init_db() builds from the customer CSV files.
DB_FILE = "bank.db"
# Directory where the FAISS vector index is saved to and reloaded from.
INDEX_PATH = "faiss_index"
# Policy PDFs that feed the RAG knowledge base; all of them must be present.
REQUIRED_PDFS = ["Bank Loan Overall Risk Policy.pdf", "Bank Loan Interest Rate Policy.pdf"]
# Third-party AI stack. Imported after set_page_config so that a missing
# package is reported on the page (st.error) and the script run is halted
# with st.stop() rather than ending in an unhandled traceback.
try:
    from langchain_groq import ChatGroq
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS
    from langchain_community.callbacks import StreamlitCallbackHandler
    from langchain_community.document_loaders import PyPDFLoader
    from langchain_text_splitters import CharacterTextSplitter
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
    from langchain_core.runnables import RunnablePassthrough
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.tools import tool
    from langchain.agents import AgentExecutor, create_tool_calling_agent
except ImportError as e:
    st.error(f"❌ Critical Import Error: {e}")
    st.stop()
# ==========================================
# 3. DATABASE SETUP
# ==========================================
def init_db():
    """Build the SQLite database ``DB_FILE`` from the customer CSV files.

    Does nothing if ``DB_FILE`` already exists. Every CSV listed below that is
    present on disk is loaded into a table of the same base name (replacing
    any existing table). Column headers are stripped of surrounding
    whitespace, and an ``ID`` column, if present, is stored as text because
    the lookup tools query it with string parameters.

    The database is first written to a temporary file in the same directory
    and only moved into place once at least one table has been loaded. A
    failed or empty build therefore never leaves a ``DB_FILE`` behind, which
    would otherwise make every later call return early and the sidebar
    consider the database present.

    Problems are reported in the Streamlit UI (``st.warning`` / ``st.error``)
    instead of being raised, so the rest of the app can still render.
    """
    import tempfile

    if os.path.exists(DB_FILE):
        return
    csv_files = {
        "credit_score": "credit_score.csv",
        "account_status": "account_status.csv",
        "pr_status": "pr_status.csv",
    }
    tmp_path = None
    try:
        # Same directory as DB_FILE so the final os.replace() is a rename,
        # not a cross-filesystem copy.
        db_dir = os.path.dirname(os.path.abspath(DB_FILE))
        fd, tmp_path = tempfile.mkstemp(suffix=".db", dir=db_dir)
        os.close(fd)  # sqlite3 reopens the (empty) file itself
        loaded = 0
        conn = sqlite3.connect(tmp_path)
        try:
            for table, file in csv_files.items():
                if not os.path.exists(file):
                    st.warning(f"CSV not found, skipping table '{table}': {file}")
                    continue
                try:
                    df = pd.read_csv(file)
                    df.columns = [str(c).strip() for c in df.columns]
                    if "ID" in df.columns:
                        df["ID"] = df["ID"].astype(str)
                    df.to_sql(table, conn, if_exists="replace", index=False)
                except Exception as e:
                    # One bad file should not block the other tables, but it
                    # must not fail silently either.
                    st.warning(f"Could not load '{file}' into table '{table}': {e}")
                else:
                    loaded += 1
            conn.commit()
        finally:
            conn.close()
        if loaded:
            os.replace(tmp_path, DB_FILE)
            tmp_path = None  # now owned by DB_FILE; nothing to clean up
        else:
            st.error("DB Init Error: no CSV file could be loaded.")
    except Exception as e:
        st.error(f"DB Init Error: {e}")
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)


# Initialize DB on startup
init_db()
# Helper for SQL tools
def run_query(query, params=()):
    """Execute one SQL statement against ``DB_FILE`` and return its first row.

    Args:
        query: SQL text using ``?`` placeholders.
        params: Values bound to the placeholders (never formatted into the
            SQL string).

    Returns:
        The first result row as a tuple, ``None`` when the statement yields
        no row, or a ``"DB Error: ..."`` string if anything failed (including
        a missing database file). Callers detect the error case with
        ``isinstance(row, str)``.
    """
    from pathlib import Path

    try:
        # mode=rw opens an existing database read-write but, unlike a plain
        # sqlite3.connect(path), never creates a missing file. An empty
        # bank.db created here would make init_db() skip the CSV import.
        db_uri = Path(DB_FILE).resolve().as_uri() + "?mode=rw"
        conn = sqlite3.connect(db_uri, uri=True)
        try:
            # The connection's context manager only commits / rolls back;
            # it does not close the connection, hence the explicit close().
            with conn:
                return conn.execute(query, params).fetchone()
        finally:
            conn.close()
    except Exception as e:
        return f"DB Error: {e}"
# ==========================================
# 4. DEFINE TOOLS
# ==========================================
@tool
def get_credit_score(user_id: str) -> str:
    """Queries SQL DB for Credit Score."""
    # With LangChain's @tool the docstring above becomes the tool description
    # shown to the LLM, so it is intentionally left unchanged.
    # Keep digits only (e.g. "ID-1111" -> "1111"); init_db() stores IDs as text.
    clean_id = ''.join(filter(str.isdigit, str(user_id)))
    row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
    # run_query reports failures as a string. Pass that on instead of claiming
    # the customer does not exist, so the agent can tell the two cases apart.
    if isinstance(row, str):
        return f"Credit DB lookup failed: {row}"
    if row:
        return f"Credit Score: {row[0]}"
    return "User ID not found in Credit DB."
@tool
def get_account_status(user_id: str) -> str:
    """Queries SQL DB for Name, Nationality, Status, and Email."""
    # With LangChain's @tool the docstring above becomes the tool description
    # shown to the LLM, so it is intentionally left unchanged.
    # Keep digits only (e.g. "ID-1111" -> "1111"); init_db() stores IDs as text.
    clean_id = ''.join(filter(str.isdigit, str(user_id)))
    row = run_query(
        "SELECT Name, Nationality, Account_Status, Email FROM account_status WHERE ID = ?",
        (clean_id,),
    )
    # run_query reports failures as a string. Surface it rather than telling
    # the agent the customer does not exist.
    if isinstance(row, str):
        return f"Account DB lookup failed: {row}"
    if row:
        name, nationality, status, email = row
        return f"Customer Name: {name}, Nationality: {nationality}, Status: {status}, Email: {email}"
    return "User ID not found in Account DB."
@tool
def check_pr_status(user_id: str) -> str:
    """Queries SQL DB for PR Status."""
    # With LangChain's @tool the docstring above becomes the tool description
    # shown to the LLM, so it is intentionally left unchanged.
    # Keep digits only (e.g. "ID-1111" -> "1111"); init_db() stores IDs as text.
    clean_id = ''.join(filter(str.isdigit, str(user_id)))
    row = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", (clean_id,))
    # The table may name the column Is_PR instead of PR_Status; retry with
    # that name only when the first column does not exist (a plain
    # "not found" needs no second query).
    if isinstance(row, str) and "no such column" in row.lower():
        row = run_query("SELECT Is_PR FROM pr_status WHERE ID = ?", (clean_id,))
    # Any other DB failure must not masquerade as "Record not found": the
    # system prompt tells the agent to treat that answer as Non-PR.
    if isinstance(row, str):
        return f"PR DB lookup failed: {row}"
    if row:
        return f"PR Status: {row[0]}"
    return "PR Status: False (Record not found)"
# ==========================================
# 5. STREAMLIT APP UI
# ==========================================
st.title("πŸ€– Multi-Policy Loan Assessor (SQL + RAG)")
st.markdown("Agent connects to **SQLite Database** and **Persistent Vector Store**")
# Required policy PDFs that are not on disk. Computed once at module level
# because both the sidebar status panel and the RAG setup read it.
pdfs_missing = []
for pdf_name in REQUIRED_PDFS:
    if not os.path.exists(pdf_name):
        pdfs_missing.append(pdf_name)
# --- METRICS FUNCTION ---
def update_metrics(placeholder):
    """Render the AI-vs-manual timing KPIs into ``placeholder``.

    Renders nothing until ``st.session_state.execution_time`` (seconds the
    last agent run took) has been set. Time saved is measured against a
    fixed manual baseline of 15 minutes per assessment.
    """
    if 'execution_time' not in st.session_state:
        return
    baseline_s = 15 * 60  # assumed manual processing time, in seconds
    elapsed_s = st.session_state.execution_time
    saved_s = baseline_s - elapsed_s
    saved_pct = saved_s / baseline_s * 100
    with placeholder.container():
        left, right = st.columns(2)
        left.metric("AI Processing", f"{elapsed_s:.1f}s")
        right.metric("Time Saved", f"{saved_s/60:.1f} min", delta=f"{saved_pct:.1f}% faster")
# --- SIDEBAR ---
with st.sidebar:
    # ---- API key: validated once, then kept in session state ----
    st.header("πŸ” Authentication")
    if 'is_key_valid' not in st.session_state:
        st.session_state['is_key_valid'] = False
    if not st.session_state['is_key_valid']:
        api_key_input = st.text_input("Enter Groq API Key", type="password", key="input_key")
        if st.button("Validate API Key"):
            if not api_key_input:
                st.error("⚠️ Please enter a key.")
            else:
                # Only the network round trip belongs in the try; state
                # changes and the rerun happen on the success path below.
                try:
                    with st.spinner("Validating..."):
                        test_llm = ChatGroq(api_key=api_key_input, model_name="llama-3.3-70b-versatile")
                        test_llm.invoke("Test")
                except Exception as e:
                    st.error(f"❌ Invalid Key: {e}")
                else:
                    st.session_state['groq_api_key'] = api_key_input
                    st.session_state['is_key_valid'] = True
                    st.success("βœ… Valid Key!")
                    time.sleep(0.5)
                    st.rerun()
    else:
        st.success("βœ… API Key Active")
        if st.button("πŸ”΄ Reset Key"):
            st.session_state['is_key_valid'] = False
            st.session_state['groq_api_key'] = None
            st.rerun()
    st.divider()

    # ---- Maintenance: drop and rebuild the persisted artifacts ----
    st.subheader("πŸ› οΈ System Maintenance")
    if st.button("♻️ Rebuild Knowledge Base"):
        if os.path.exists(INDEX_PATH):
            shutil.rmtree(INDEX_PATH)
        # Also drops the cached retriever so the index is rebuilt on rerun.
        st.cache_resource.clear()
        st.success("Cache cleared.")
        time.sleep(1)
        st.rerun()
    if st.button("πŸ’Ύ Reload CSVs to DB"):
        if os.path.exists(DB_FILE):
            os.remove(DB_FILE)
        init_db()
        # Only claim success if the rebuild actually produced a database.
        if os.path.exists(DB_FILE):
            st.success("Database refreshed.")
        else:
            st.error("Database could not be rebuilt from the CSV files.")
    st.divider()

    # ---- Readiness: list everything that is missing, DB included ----
    missing_items = list(pdfs_missing)
    if not os.path.exists(DB_FILE):
        missing_items.insert(0, DB_FILE)
    if missing_items:
        st.warning(f"⚠️ Missing: {missing_items}")
    else:
        st.success("βœ… System Ready")

    # ---- Metrics: placeholder is refreshed again after each agent run ----
    st.header("πŸ“Š Metrics")
    metrics_placeholder = st.empty()
    update_metrics(metrics_placeholder)
# --- MAIN LOGIC ---
if st.session_state.get('is_key_valid', False):
    # Hand the key to each client explicitly rather than exporting it through
    # os.environ: the environment is shared by every session served by this
    # process, so one user's key would otherwise be used for other users.
    groq_api_key = st.session_state['groq_api_key']

    # --- RAG SETUP ---
    @st.cache_resource
    def setup_rag():
        """Return a retriever over the policy PDFs, building the FAISS index if needed.

        Cached by Streamlit across reruns. Loads the index saved at INDEX_PATH
        when it exists; otherwise splits the PDFs into ~600-character chunks,
        embeds them and saves the new index. Halts the script run (st.stop)
        when any required PDF is missing.
        """
        if pdfs_missing:
            st.error(f"Missing PDFs: {pdfs_missing}")
            st.stop()
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        if os.path.exists(INDEX_PATH):
            # SECURITY: load_local unpickles the stored docstore. This is only
            # acceptable because INDEX_PATH is written by this app via
            # save_local below; never point it at an untrusted index.
            return FAISS.load_local(INDEX_PATH, embeddings, allow_dangerous_deserialization=True).as_retriever()
        documents = []
        for pdf_file in REQUIRED_PDFS:
            documents.extend(PyPDFLoader(pdf_file).load())
        text_splitter = CharacterTextSplitter(chunk_size=600, chunk_overlap=50)
        final_docs = text_splitter.split_documents(documents)
        vectorstore = FAISS.from_documents(final_docs, embeddings)
        vectorstore.save_local(INDEX_PATH)
        return vectorstore.as_retriever()

    with st.spinner("Initializing AI..."):
        retriever = setup_rag()
    llm = ChatGroq(api_key=groq_api_key, temperature=0, model_name="llama-3.3-70b-versatile")

    # Retrieval chain: join the retrieved chunks into one context string.
    rag_prompt = ChatPromptTemplate.from_template("Answer based on context:\n{context}\nQuestion: {question}")
    rag_chain = (
        {"context": retriever | (lambda d: "\n".join([x.page_content for x in d])), "question": RunnablePassthrough()}
        | rag_prompt | llm | StrOutputParser()
    )

    @tool
    def consult_policy_doc(query: str) -> str:
        """Consults Policy Documents for Risk Rules."""
        return rag_chain.invoke(query)

    tools = [get_credit_score, get_account_status, check_pr_status, consult_policy_doc]

    # System prompt enforcing the step structure from the policy PDFs.
    # NOTE: this is a prompt *template*; literal curly braces would be parsed
    # as template variables, so none are used here.
    system_instruction = """You are a strict Bank Loan Officer.
You MUST execute the loan assessment following strictly these 5 steps and this exact output format.
IMPORTANT FORMATTING RULES:
1. Use '###' (Heading 3) for all Step titles.
2. Use '**' (Bold) for all labels.
3. Do NOT use Heading 1 (#) or Heading 2 (##) to ensure consistent font size.
REQUIRED FLOW:
Step 1.Retrieve information for customer information
Credit Score: [Score] , Account Status: [Status] , Nationality: [Nationality]
Step 2. Check PR Status (For Non-Singapore this extra Step is needed)
**LOGIC RULE:** If 'check_pr_status' returns "False" or "Record not found", you MUST interpret this as "Non-PR" in your final report. Do NOT write "False".
Step 3.Check Overall Risk
Credit Score: [Score] , Account Status: [Status] -> overall risk: Consult policy doc for the risk matrix to decide this)
Step 4.Check interest rate
overall risk: [Level] -> [Rate]%
(Consult policy doc for interest rates)
Step 5. Report
Recommend the loan interest rate [Rate]%
INSTRUCTIONS:
1. Use SQL tools to get Name, ID, Email, Score, Status, Nationality.
2. If Nationality is NOT Singaporean, you MUST check PR status.
3. Use 'consult_policy_doc' to find the Risk Matrix and Interest Rates.
4. Provide a Final Recommendation Report that MUST include:
- Customer Name, ID, Email
- Risk Level, Interest Rate
- Final Decision (Approve/Reject)
- Justification for Decision (Cite specific PDF policies)
5. Format it in a clear markdown table.
"""
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_instruction),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    agent = create_tool_calling_agent(llm, tools, prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)

    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("1. Customer Details")
        uid = st.text_input("Customer ID", "1111")
        use_simulation = st.checkbox("Simulation Mode")
        # Defaults are only used in simulation mode, where the widgets override them.
        sim_score = 650
        sim_status = "good-standing"
        if use_simulation:
            sim_score = st.slider("Sim Credit Score", 300, 900, 450, step=10)
            sim_status = st.selectbox("Sim Status", ["good-standing", "closed", "delinquent"])
        st.divider()
        btn = st.button("Assess Loan Risk", type="primary")

    with col2:
        if btn:
            # The detailed step rules live in the system prompt; the user
            # message only carries the customer ID and mode-specific overrides.
            if use_simulation:
                query = f"""
Process Loan for Customer ID: {uid}.
*** SIMULATION MODE ***
1. DO NOT query 'get_credit_score' or 'account_status' for Score/Status.
2. USE: Score: {sim_score}, Status: {sim_status}
3. Query 'get_account_status' ONLY for Name/Nationality/Email.
4. Follow the strict 5-step flow defined in the system instructions.
"""
            else:
                query = f"""
Process Loan for Customer ID: {uid}.
1. Query SQL tools for Name, Email, Nationality, Status, Score.
2. IF Nationality is 'Singaporean', SKIP 'check_pr_status'.
3. Follow the strict 5-step flow defined in the system instructions.
"""
            with st.status("πŸ€– Agent is processing...", expanded=True) as status:
                st_callback = StreamlitCallbackHandler(st.container())
                try:
                    # perf_counter is monotonic, unlike time.time(), so the
                    # measured duration cannot jump with clock adjustments.
                    start_time = time.perf_counter()
                    res = agent_executor.invoke({"input": query}, {"callbacks": [st_callback]})
                    st.session_state.execution_time = time.perf_counter() - start_time
                    update_metrics(metrics_placeholder)
                    status.update(label="βœ… Complete!", state="complete", expanded=False)
                except Exception as e:
                    # Don't leave the status widget spinning on failure.
                    status.update(label="❌ Failed", state="error")
                    st.error(f"Error: {e}")
                    st.stop()

            st.success("### πŸ“‹ Final Recommendation")
            with st.container(border=True):
                st.markdown(res['output'])
            with st.expander("πŸ” Detailed Trace"):
                steps = res.get("intermediate_steps", [])
                for step_no, (action, observation) in enumerate(steps, start=1):
                    st.markdown(f"**Step {step_no}:** Tool `{action.tool}` | Output: `{observation}`")
            if not use_simulation:
                st.divider()
                with st.expander("βœ‰οΈ Draft Email"):
                    email_prompt = f"Write a formal email based on this decision: {res['output']}"
                    with st.spinner("Drafting..."):
                        email_draft = llm.invoke(email_prompt).content
                    st.text_area("Email Draft", value=email_draft, height=200)
else:
    st.info("πŸ‘ˆ Please validate your Groq API Key.")