Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import os | |
| import warnings | |
| import time | |
| import sqlite3 | |
| import shutil | |
# ------------------------------------------------------------------
# 1. PAGE CONFIG (MUST BE FIRST)
#    Kept ahead of every other st.* call in this script.
# ------------------------------------------------------------------
st.set_page_config(
    page_title="Bank Loan Agent (SQL)",
    layout="wide",
)

# Suppress all Python warnings for the rest of the run.
warnings.filterwarnings("ignore")

# ------------------------------------------------------------------
# 2. GLOBAL CONSTANTS & IMPORTS
# ------------------------------------------------------------------
# SQLite file that init_db() fills from the CSV exports.
DB_FILE = "bank.db"
# Directory where the FAISS index is saved and reloaded from.
INDEX_PATH = "faiss_index"
# Policy documents that must exist next to the script before RAG can start.
REQUIRED_PDFS = [
    "Bank Loan Overall Risk Policy.pdf",
    "Bank Loan Interest Rate Policy.pdf",
]
# Third-party LangChain stack. A missing package is reported in the page and
# the script run is halted instead of failing with a raw traceback.
try:
    from langchain_groq import ChatGroq
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS
    from langchain_community.callbacks import StreamlitCallbackHandler
    from langchain_community.document_loaders import PyPDFLoader
    from langchain_text_splitters import CharacterTextSplitter
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
    from langchain_core.runnables import RunnablePassthrough
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.tools import tool
    from langchain.agents import AgentExecutor, create_tool_calling_agent
except ImportError as e:
    # The leading glyph was mis-encoded ("β"); restored as the error mark.
    st.error(f"❌ Critical Import Error: {e}")
    st.stop()
| # ========================================== | |
| # 3. DATABASE SETUP | |
| # ========================================== | |
def init_db():
    """Build the SQLite database from the CSV exports, once.

    Does nothing when ``DB_FILE`` already exists. Otherwise each known CSV is
    loaded into its own table: column names are stripped of surrounding
    whitespace and an ``ID`` column, when present, is stored as text so it
    can be matched against string IDs.

    Problems are reported in the Streamlit page and never raised: a missing
    CSV produces a warning, and a CSV that cannot be read or written produces
    an error for that table only; the remaining tables are still attempted.
    """
    if os.path.exists(DB_FILE):
        return

    csv_files = {
        "credit_score": "credit_score.csv",
        "account_status": "account_status.csv",
        "pr_status": "pr_status.csv",
    }

    conn = sqlite3.connect(DB_FILE)
    try:
        for table, file in csv_files.items():
            if not os.path.exists(file):
                # Previously skipped silently, leaving a DB without the table.
                st.warning(f"DB Init: '{file}' not found; table '{table}' was not created.")
                continue
            try:
                df = pd.read_csv(file)
                df.columns = [c.strip() for c in df.columns]
                if 'ID' in df.columns:
                    df['ID'] = df['ID'].astype(str)
                df.to_sql(table, conn, if_exists='replace', index=False)
            except Exception as e:
                # Startup boundary: one bad CSV must not take the app down or
                # block the other tables, but the failure has to be visible.
                st.error(f"DB Init Error ({table}): {e}")
    finally:
        conn.close()


# Initialize DB on startup
init_db()
| # Helper for SQL tools | |
def run_query(query, params=()):
    """Run one SQL statement against ``DB_FILE`` and return its first row.

    Args:
        query: SQL text, using ``?`` placeholders for values.
        params: Values bound to the placeholders.

    Returns:
        The first result row as a tuple, ``None`` when the statement yields
        no rows, or the string ``"DB Error: <message>"`` when SQLite reports
        a failure. Callers distinguish the error case with ``isinstance``.
    """
    try:
        conn = sqlite3.connect(DB_FILE)
        try:
            # ``with conn`` only scopes the transaction (commit/rollback);
            # it does not close the connection, hence the explicit close.
            with conn:
                return conn.execute(query, params).fetchone()
        finally:
            conn.close()
    except sqlite3.Error as e:
        return f"DB Error: {e}"
| # ========================================== | |
| # 4. DEFINE TOOLS | |
| # ========================================== | |
# The docstring is kept short and unchanged on purpose: when this function is
# registered as an agent tool, the docstring is the description the model sees.
# Returns "Credit Score: <value>" on a hit, the "DB Error: ..." text when the
# query itself failed, and a not-found message otherwise.
def get_credit_score(user_id: str) -> str:
    """Queries SQL DB for Credit Score."""
    # Keep only digit characters so input such as "ID: 1111" still matches.
    clean_id = ''.join(filter(str.isdigit, str(user_id)))
    row = run_query("SELECT Credit_Score FROM credit_score WHERE ID = ?", (clean_id,))
    if isinstance(row, str):
        # run_query reports failures as a string; do not present a broken
        # database as a missing customer.
        return row
    if row:
        return f"Credit Score: {row[0]}"
    return "User ID not found in Credit DB."
# The docstring is kept short and unchanged on purpose: when this function is
# registered as an agent tool, the docstring is the description the model sees.
# Returns one line with name, nationality, status and email on a hit, the
# "DB Error: ..." text when the query itself failed, and a not-found message
# otherwise.
def get_account_status(user_id: str) -> str:
    """Queries SQL DB for Name, Nationality, Status, and Email."""
    # Keep only digit characters so input such as "ID: 1111" still matches.
    clean_id = ''.join(filter(str.isdigit, str(user_id)))
    row = run_query(
        "SELECT Name, Nationality, Account_Status, Email FROM account_status WHERE ID = ?",
        (clean_id,)
    )
    if isinstance(row, str):
        # run_query reports failures as a string; do not present a broken
        # database as a missing customer.
        return row
    if row:
        return f"Customer Name: {row[0]}, Nationality: {row[1]}, Status: {row[2]}, Email: {row[3]}"
    return "User ID not found in Account DB."
# The docstring is left exactly as it was: if this function is registered as
# an agent tool, the docstring is the description the model sees.
def check_pr_status(user_id: str) -> str:
    """Queries SQL DB for PR Status."""
    digits_only = ''.join(ch for ch in str(user_id) if ch.isdigit())
    lookup_args = (digits_only,)

    result = run_query("SELECT PR_Status FROM pr_status WHERE ID = ?", lookup_args)

    # Second attempt with the alternative column name, taken when the first
    # lookup returned nothing or failed because PR_Status does not exist.
    first_failed_on_column = isinstance(result, str) and "no such column" in result.lower()
    if first_failed_on_column or not result:
        result = run_query("SELECT Is_PR FROM pr_status WHERE ID = ?", lookup_args)

    # NOTE(review): any remaining failure, including a DB error string, is
    # reported as "False (Record not found)"; confirm this default is intended.
    if isinstance(result, str) or not result:
        return "PR Status: False (Record not found)"
    return f"PR Status: {result[0]}"
| # ========================================== | |
| # 5. STREAMLIT APP UI | |
| # ========================================== | |
# Page heading. The robot emoji was mis-encoded in the title ("π€"); restored.
st.title("🤖 Multi-Policy Loan Assessor (SQL + RAG)")
st.markdown("Agent connects to **SQLite Database** and **Persistent Vector Store**")

# Calculate missing PDFs globally so everyone can see it
# (read by both the sidebar status and the RAG setup).
pdfs_missing = [f for f in REQUIRED_PDFS if not os.path.exists(f)]
| # --- METRICS FUNCTION --- | |
def update_metrics(placeholder):
    """Render the timing KPIs into ``placeholder``.

    Compares the last recorded ``st.session_state.execution_time`` (seconds)
    with a fixed 15-minute manual baseline and shows the AI time and the time
    saved. Draws nothing until an execution time has been recorded.
    """
    if 'execution_time' not in st.session_state:
        return

    baseline_seconds = 15 * 60
    elapsed_seconds = st.session_state.execution_time
    saved_seconds = baseline_seconds - elapsed_seconds
    saved_percent = (saved_seconds / baseline_seconds) * 100

    with placeholder.container():
        left_kpi, right_kpi = st.columns(2)
        left_kpi.metric("AI Processing", f"{elapsed_seconds:.1f}s")
        right_kpi.metric(
            "Time Saved",
            f"{saved_seconds/60:.1f} min",
            delta=f"{saved_percent:.1f}% faster",
        )
| # --- SIDEBAR --- | |
# Sidebar: API-key validation, maintenance actions, readiness status, metrics.
# Mis-encoded emoji were restored where the surviving characters or the
# message context identify them.
# NOTE(review): a lone "π" is what remains of an emoji whose original glyph
# cannot be recovered; those strings are left exactly as found.
with st.sidebar:
    st.header("π Authentication")
    if 'is_key_valid' not in st.session_state:
        st.session_state['is_key_valid'] = False

    if not st.session_state['is_key_valid']:
        api_key_input = st.text_input("Enter Groq API Key", type="password", key="input_key")
        if st.button("Validate API Key"):
            if not api_key_input:
                st.error("⚠️ Please enter a key.")
            else:
                # Only the probe call is guarded; the success path, including
                # st.rerun(), runs in ``else`` so it is never routed through
                # the "invalid key" handler.
                try:
                    with st.spinner("Validating..."):
                        test_llm = ChatGroq(api_key=api_key_input, model_name="llama-3.3-70b-versatile")
                        test_llm.invoke("Test")
                except Exception as e:
                    st.error(f"❌ Invalid Key: {e}")
                else:
                    st.session_state['groq_api_key'] = api_key_input
                    st.session_state['is_key_valid'] = True
                    st.success("✅ Valid Key!")
                    time.sleep(0.5)
                    st.rerun()
    else:
        st.success("✅ API Key Active")
        if st.button("🔴 Reset Key"):
            st.session_state['is_key_valid'] = False
            st.session_state['groq_api_key'] = None
            st.rerun()

    st.divider()
    st.subheader("🛠️ System Maintenance")
    if st.button("♻️ Rebuild Knowledge Base"):
        # Drop the saved index and the cached resources so both are rebuilt.
        if os.path.exists(INDEX_PATH):
            shutil.rmtree(INDEX_PATH)
        st.cache_resource.clear()
        st.success("Cache cleared.")
        time.sleep(1)
        st.rerun()
    if st.button("💾 Reload CSVs to DB"):
        # init_db() is a no-op while the file exists, so remove it first.
        if os.path.exists(DB_FILE):
            os.remove(DB_FILE)
        init_db()
        st.success("Database refreshed.")

    st.divider()
    # List the database file too; previously a missing DB with all PDFs
    # present produced the uninformative "Missing: []".
    missing_assets = list(pdfs_missing)
    if not os.path.exists(DB_FILE):
        missing_assets.insert(0, DB_FILE)
    if not missing_assets:
        st.success("✅ System Ready")
    else:
        st.warning(f"⚠️ Missing: {missing_assets}")

    st.header("π Metrics")
    # Placeholder is filled now and refreshed after each assessment run.
    metrics_placeholder = st.empty()
    update_metrics(metrics_placeholder)
| # --- MAIN LOGIC --- | |
# Main page: runs only with a validated Groq key, otherwise shows a hint.
# Builds the RAG chain and the tool-calling agent, then renders the
# assessment form, the agent run, its trace and an optional email draft.
# NOTE(review): a lone "π" in a label is what remains of an emoji whose
# original glyph cannot be recovered; those strings are left as found.
if st.session_state.get('is_key_valid', False):
    groq_api_key = st.session_state['groq_api_key']
    # Kept for backward compatibility. The LLM below receives the key
    # explicitly, so it does not depend on this process-wide variable.
    os.environ["GROQ_API_KEY"] = groq_api_key

    # --- RAG SETUP ---
    @st.cache_resource(show_spinner=False)
    def _build_retriever():
        """Load the saved FAISS index, or build and save it from the PDFs.

        Cached as a resource so the embedding model and index are not
        rebuilt on every rerun; the sidebar's "Rebuild Knowledge Base"
        button clears this cache.
        """
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        if os.path.exists(INDEX_PATH):
            # NOTE(security): the flag opts in to unsafe deserialization;
            # only load an index directory that this app wrote itself.
            return FAISS.load_local(INDEX_PATH, embeddings, allow_dangerous_deserialization=True).as_retriever()
        documents = []
        for pdf_file in REQUIRED_PDFS:
            documents.extend(PyPDFLoader(pdf_file).load())
        text_splitter = CharacterTextSplitter(chunk_size=600, chunk_overlap=50)
        final_docs = text_splitter.split_documents(documents)
        vectorstore = FAISS.from_documents(final_docs, embeddings)
        vectorstore.save_local(INDEX_PATH)
        return vectorstore.as_retriever()

    def setup_rag():
        """Return the policy retriever, halting the run if PDFs are missing."""
        if pdfs_missing:
            st.error(f"Missing PDFs: {pdfs_missing}")
            st.stop()
        return _build_retriever()

    with st.spinner("Initializing AI..."):
        retriever = setup_rag()
        llm = ChatGroq(api_key=groq_api_key, temperature=0, model_name="llama-3.3-70b-versatile")

    rag_prompt = ChatPromptTemplate.from_template("Answer based on context:\n{context}\nQuestion: {question}")
    rag_chain = (
        {"context": retriever | (lambda d: "\n".join([x.page_content for x in d])), "question": RunnablePassthrough()}
        | rag_prompt | llm | StrOutputParser()
    )

    @tool
    def consult_policy_doc(query: str) -> str:
        """Consults Policy Documents for Risk Rules."""
        return rag_chain.invoke(query)

    # The agent needs tool objects, not bare functions. Anything that is
    # still a plain callable is wrapped; existing tools pass through as-is.
    tools = [
        candidate if hasattr(candidate, "invoke") else tool(candidate)
        for candidate in (get_credit_score, get_account_status, check_pr_status, consult_policy_doc)
    ]

    # ============================================================
    # MODIFIED PROMPT: Enforcing the PDF Steps Structure
    # (step count corrected to 5 to match the steps listed below and the
    # "5-step flow" wording of the user queries)
    # ============================================================
    system_instruction = """You are a strict Bank Loan Officer.
You MUST execute the loan assessment following strictly these 5 steps and this exact output format.
IMPORTANT FORMATTING RULES:
1. Use '###' (Heading 3) for all Step titles.
2. Use '**' (Bold) for all labels.
3. Do NOT use Heading 1 (#) or Heading 2 (##) to ensure consistent font size.
REQUIRED FLOW:
Step 1.Retrieve information for customer information
Credit Score: [Score] , Account Status: [Status] , Nationality: [Nationality]
Step 2. Check PR Status (For Non-Singapore this extra Step is needed)
**LOGIC RULE:** If 'check_pr_status' returns "False" or "Record not found", you MUST interpret this as "Non-PR" in your final report. Do NOT write "False".
Step 3.Check Overall Risk
Credit Score: [Score] , Account Status: [Status] -> overall risk: Consult policy doc for the risk matrix to decide this)
Step 4.Check interest rate
overall risk: [Level] -> [Rate]%
(Consult policy doc for interest rates)
Step 5. Report
Recommend the loan interest rate [Rate]%
INSTRUCTIONS:
1. Use SQL tools to get Name, ID, Email, Score, Status, Nationality.
2. If Nationality is NOT Singaporean, you MUST check PR status.
3. Use 'consult_policy_doc' to find the Risk Matrix and Interest Rates.
4. Provide a Final Recommendation Report that MUST include:
- Customer Name, ID, Email
- Risk Level, Interest Rate
- Final Decision (Approve/Reject)
- Justification for Decision (Cite specific PDF policies)
5. Format it in a clear markdown table.
"""
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_instruction),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    agent = create_tool_calling_agent(llm, tools, prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)

    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("1. Customer Details")
        uid = st.text_input("Customer ID", "1111")
        use_simulation = st.checkbox("Simulation Mode")
        # Defaults are replaced by the widgets when simulation is enabled.
        sim_score = 650
        sim_status = "good-standing"
        if use_simulation:
            sim_score = st.slider("Sim Credit Score", 300, 900, 450, step=10)
            sim_status = st.selectbox("Sim Status", ["good-standing", "closed", "delinquent"])
        st.divider()
        btn = st.button("Assess Loan Risk", type="primary")

    with col2:
        if btn:
            # We simplified the query here because the strict instructions are now in the System Prompt
            if use_simulation:
                query = f"""
Process Loan for Customer ID: {uid}.
*** SIMULATION MODE ***
1. DO NOT query 'get_credit_score' or 'account_status' for Score/Status.
2. USE: Score: {sim_score}, Status: {sim_status}
3. Query 'get_account_status' ONLY for Name/Nationality/Email.
4. Follow the strict 5-step flow defined in the system instructions.
"""
            else:
                query = f"""
Process Loan for Customer ID: {uid}.
1. Query SQL tools for Name, Email, Nationality, Status, Score.
2. IF Nationality is 'Singaporean', SKIP 'check_pr_status'.
3. Follow the strict 5-step flow defined in the system instructions.
"""
            with st.status("🤖 Agent is processing...", expanded=True) as status:
                st_callback = StreamlitCallbackHandler(st.container())
                try:
                    # perf_counter is monotonic, so the elapsed time cannot be
                    # distorted by a wall-clock adjustment during the run.
                    start_time = time.perf_counter()
                    res = agent_executor.invoke({"input": query}, {"callbacks": [st_callback]})
                    st.session_state.execution_time = time.perf_counter() - start_time
                    update_metrics(metrics_placeholder)
                    status.update(label="✅ Complete!", state="complete", expanded=False)
                except Exception as e:
                    # UI boundary: show the failure and end this script run so
                    # nothing below reads the unset ``res``.
                    st.error(f"Error: {e}")
                    st.stop()

            st.success("### π Final Recommendation")
            with st.container(border=True):
                st.markdown(res['output'])
            with st.expander("π Detailed Trace"):
                steps = res.get("intermediate_steps", [])
                for step_no, (action, observation) in enumerate(steps, start=1):
                    st.markdown(f"**Step {step_no}:** Tool `{action.tool}` | Output: `{observation}`")
            if not use_simulation:
                st.divider()
                with st.expander("✉️ Draft Email"):
                    email_prompt = f"Write a formal email based on this decision: {res['output']}"
                    with st.spinner("Drafting..."):
                        email_draft = llm.invoke(email_prompt).content
                    st.text_area("Email Draft", value=email_draft, height=200)
else:
    st.info("π Please validate your Groq API Key.")