Sush committed on
Commit
c62ce64
·
0 Parent(s):

Initial commit: Banking Intelligence Assistant

Browse files
.env.example ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ GROQ_API_KEY=your_groq_api_key_here
2
+ HUGGINGFACE_TOKEN=your_huggingface_token_here
.github/workflows/ci-cd.yml ADDED
File without changes
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment
2
+ .env
3
+ __pycache__/
4
+ *.pyc
5
+ *.pyo
6
+
7
+ # Vectorstore (too large for git)
8
+ vectorstore/
9
+
10
+ # Jupyter checkpoints
11
+ .ipynb_checkpoints/
12
+
13
+ # Mac system files
14
+ .DS_Store
15
+
16
+ # Virtual environment
17
+ venv/
18
+ .venv/
Dockerfile ADDED
File without changes
README.md ADDED
File without changes
agents/__init__.py ADDED
File without changes
agents/orchestrator.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph, END
2
+ from typing import TypedDict
3
+
4
class AgentState(TypedDict):
    # State dict shared by all LangGraph nodes in the orchestrator.
    query: str       # raw user question, supplied by the caller
    agent_used: str  # "rag" or "sql" — set by the router node
    response: str    # final answer text — set by the rag/sql node
8
+
9
def build_orchestrator(rag_chain, sql_agent):
    """Compile a LangGraph workflow that routes each query to the RAG or SQL agent.

    The router counts keyword hits for each agent; a strictly higher SQL score
    sends the query to the SQL agent, otherwise (including ties) it goes to RAG.

    Args:
        rag_chain: runnable answering policy questions; invoked with the raw
            query string and expected to return the answer string.
        sql_agent: object exposing ``invoke({"input": ...}) -> {"output": ...}``.

    Returns:
        A compiled LangGraph app; invoke it with an ``AgentState``-shaped dict.
    """
    sql_keywords = (
        "transaction", "balance", "how many", "outstanding",
        "credit card", "merchant", "customer", "branch",
        "average", "total", "count", "highest", "lowest",
        "failed", "blocked", "overdue", "statement",
    )
    rag_keywords = (
        "policy", "rule", "guideline", "what is", "how does",
        "eligibility", "penalty", "interest rate", "fee",
        "grievance", "kyc", "document", "complaint", "process",
        "minimum balance", "loan", "terms", "conditions",
    )

    def classify(state: AgentState) -> AgentState:
        # Keyword-hit counts decide the branch; ties default to the RAG agent.
        text = state["query"].lower()
        hits_sql = len([kw for kw in sql_keywords if kw in text])
        hits_rag = len([kw for kw in rag_keywords if kw in text])
        state["agent_used"] = "sql" if hits_sql > hits_rag else "rag"
        return state

    def answer_with_rag(state: AgentState) -> AgentState:
        # RAG chain takes the bare question string.
        state["response"] = rag_chain.invoke(state["query"])
        return state

    def answer_with_sql(state: AgentState) -> AgentState:
        # SQL agent uses a dict interface: {"input": ...} -> {"output": ...}.
        state["response"] = sql_agent.invoke({"input": state["query"]})["output"]
        return state

    graph = StateGraph(AgentState)
    graph.add_node("router", classify)
    graph.add_node("rag", answer_with_rag)
    graph.add_node("sql", answer_with_sql)
    graph.set_entry_point("router")
    # Branch on the label the router stored in state.
    graph.add_conditional_edges(
        "router", lambda s: s["agent_used"], {"rag": "rag", "sql": "sql"}
    )
    graph.add_edge("rag", END)
    graph.add_edge("sql", END)
    return graph.compile()
agents/rag_agent.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_groq import ChatGroq
3
+ from langchain_community.vectorstores import FAISS
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ from langchain_core.prompts import PromptTemplate
6
+ from langchain_core.runnables import RunnablePassthrough
7
+ from langchain_core.output_parsers import StrOutputParser
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv()
12
+
13
def load_rag_agent(vectorstore_path: str = "vectorstore/"):
    """Build the policy-QA RAG chain on top of a persisted FAISS index.

    Args:
        vectorstore_path: directory containing the saved FAISS index.

    Returns:
        An LCEL runnable: ``invoke(question: str) -> answer: str``.
    """
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # The index was produced locally by ingest.py, so deserialization is
    # explicitly allowed here.
    index = FAISS.load_local(
        vectorstore_path,
        embedder,
        allow_dangerous_deserialization=True,
    )

    # MMR retrieval: fetch 20 candidates, return 4 balancing relevance vs.
    # diversity (lambda_mult=0.7 leans toward relevance).
    doc_retriever = index.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 4, "fetch_k": 20, "lambda_mult": 0.7},
    )

    chat_model = ChatGroq(
        api_key=os.getenv("GROQ_API_KEY"),
        model_name="llama-3.1-8b-instant",
        temperature=0,
    )

    # Prompt is grounded: the model must answer only from retrieved context.
    grounded_prompt = PromptTemplate(
        template="""You are a helpful HDFC Bank policy assistant.
Use ONLY the context below to answer the customer's question.
If the answer is not in the context, say "I don't have enough information
in the policy documents to answer this. Please contact HDFC Bank directly."

Context:
{context}

Customer Question: {question}

Answer:""",
        input_variables=["context", "question"],
    )

    def join_documents(docs):
        # Concatenate chunk texts with blank lines between them.
        return "\n\n".join(doc.page_content for doc in docs)

    return (
        {"context": doc_retriever | join_documents, "question": RunnablePassthrough()}
        | grounded_prompt
        | chat_model
        | StrOutputParser()
    )
agents/sql_agent.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_groq import ChatGroq
3
+ from langchain_community.utilities import SQLDatabase
4
+ from langchain_core.prompts import PromptTemplate
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.runnables import RunnableLambda
7
+ from dotenv import load_dotenv
8
+
9
+ load_dotenv()
10
+
11
def load_sql_agent(db_path: str = "data/database/banking.db"):
    """Build a custom two-step SQL chain (question -> SQL -> answer) using LCEL.

    A hand-rolled chain is more reliable than a ReAct-style SQL agent when the
    backing model is a Llama variant.

    Args:
        db_path: path to the SQLite database file.

    Returns:
        An object exposing ``invoke({"input": question}) -> {"output": answer}``,
        the interface the orchestrator expects.

    Raises:
        FileNotFoundError: if ``db_path`` does not exist. Without this check,
            sqlite silently creates an empty database and every query fails
            later with a confusing "no such table" error.
    """
    # Fail fast on a missing DB file instead of letting sqlite create an empty one.
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"SQLite database not found: {db_path}")

    db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

    llm = ChatGroq(
        api_key=os.getenv("GROQ_API_KEY"),
        model_name="llama-3.3-70b-versatile",
        temperature=0  # deterministic SQL generation
    )

    # Step 1 — Generate SQL from question
    sql_generation_prompt = PromptTemplate(
        template="""You are a SQL expert. Given the database schema and question, write a SQLite SQL query.
Return ONLY the SQL query, nothing else. No explanation, no markdown, no backticks.

Database Schema:
{schema}

Question: {question}

SQL Query:""",
        input_variables=["schema", "question"]
    )

    # Step 2 — Generate final answer from SQL result
    answer_prompt = PromptTemplate(
        template="""Given the question, SQL query, and result, write a clear answer.

Question: {question}
SQL Query: {query}
SQL Result: {result}

Answer:""",
        input_variables=["question", "query", "result"]
    )

    def run_sql_chain(question: str) -> str:
        """Run the full pipeline for one question; errors become a friendly message."""
        try:
            schema = db.get_table_info()

            # Generate SQL
            sql_chain = sql_generation_prompt | llm | StrOutputParser()
            sql_query = sql_chain.invoke({
                "schema": schema,
                "question": question
            }).strip()

            # Strip markdown fences the model sometimes adds despite instructions.
            sql_query = sql_query.replace("```sql", "").replace("```", "").strip()

            # SECURITY NOTE(review): model-generated SQL is executed verbatim.
            # Keep this DB read-only / non-sensitive, or add a statement
            # whitelist (e.g. allow only SELECT) before production use.
            result = db.run(sql_query)

            # Generate answer
            answer_chain = answer_prompt | llm | StrOutputParser()
            return answer_chain.invoke({
                "question": question,
                "query": sql_query,
                "result": result
            })

        except Exception as e:
            # Surface a readable message to the chat UI rather than crashing.
            return f"I encountered an error processing your query: {str(e)}"

    class SQLChainWrapper:
        """Dict-in/dict-out adapter matching the orchestrator's agent interface."""

        def invoke(self, input_dict):
            question = input_dict.get("input", "")
            return {"output": run_sql_chain(question)}

    return SQLChainWrapper()
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Streamlit chat UI for the HDFC Banking Intelligence Assistant.

Loads the RAG and SQL agents once (cached), then routes every chat message
through the LangGraph orchestrator and renders the answer with a caption
naming the agent that produced it.
"""
import sys
import os
# Make sibling packages (agents/) importable when launched via `streamlit run`.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import streamlit as st
from agents.rag_agent import load_rag_agent
from agents.sql_agent import load_sql_agent
from agents.orchestrator import build_orchestrator

# ── PAGE CONFIG ──
st.set_page_config(
    page_title="HDFC Banking Intelligence Assistant",
    page_icon="🏦",
    layout="centered"
)

# ── HEADER ──
st.title(" HDFC Banking Intelligence Assistant")
st.markdown("""
Ask me anything about **HDFC Bank policies** or your **account & transaction data**.
I'll automatically route your question to the right agent.
""")
st.divider()

# ── LOAD AGENTS (cached so they load only once) ──
@st.cache_resource
def load_agents():
    """Build both agents and the orchestrator; cached across reruns."""
    with st.spinner("Loading agents... please wait "):
        rag_chain = load_rag_agent()
        sql_agent = load_sql_agent()
        orchestrator = build_orchestrator(rag_chain, sql_agent)
        return orchestrator

orchestrator = load_agents()

# ── CHAT HISTORY ──
# Streamlit reruns the script on every interaction, so the transcript lives
# in session_state.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat history
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# ── SAMPLE QUESTIONS ──
# Only shown before the first message is sent.
if len(st.session_state.messages) == 0:
    st.markdown("#### Try asking:")
    col1, col2 = st.columns(2)
    with col1:
        # Policy (RAG) examples
        st.info(" What is the minimum balance for a savings account?")
        st.info(" How can I raise a grievance against HDFC Bank?")
        st.info(" What are the KYC documents required?")
    with col2:
        # Data (SQL) examples
        st.info(" Which customers have overdue credit cards?")
        st.info(" Which merchant has the highest transactions?")
        st.info(" What is the average balance by account type?")

# ── CHAT INPUT ──
if query := st.chat_input("Ask your banking question here..."):

    # Add user message
    st.session_state.messages.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    # Get response: the orchestrator fills in agent_used and response.
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            result = orchestrator.invoke({
                "query": query,
                "agent_used": "",
                "response": ""
            })

        response = result["response"]
        agent_used = result["agent_used"].upper()

        # Show which agent handled it
        if agent_used == "RAG":
            st.caption(" Answered by: Policy Agent (RAG)")
        else:
            st.caption(" Answered by: Data Agent (SQL)")

        st.markdown(response)

    # Save assistant message (agent label embedded so history shows it too)
    st.session_state.messages.append({
        "role": "assistant",
        "content": f"*[{agent_used} Agent]*\n\n{response}"
    })
data/database/banking.db ADDED
Binary file (86 kB). View file
 
ingest.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from langchain_community.document_loaders import PyPDFLoader
4
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+
8
# ── CONFIGURATION ──
DOCS_PATH = "data/documents/"      # source policy PDFs live here
VECTORSTORE_PATH = "vectorstore/"  # FAISS index output directory (gitignored)

# Policy PDFs to ingest; files missing from DOCS_PATH are skipped by ingest()
# with a warning.
PDF_FILES = [
    "hdfc_credit_card_policy.pdf",
    "hdfc_customer_compensation_policy.pdf",
    "hdfc_grievance_policy.pdf",
    "hdfc_personal_loan_agreement.pdf",
    "hdfc_savings_account_charges.pdf",
    "hdfc_general_terms_conditions.pdf"
]
20
+
21
def clean_text(text: str) -> str:
    """Normalize raw PDF chunk text before indexing.

    Removes the document-classification watermark and "as on DD.MM.YYYY" date
    stamps, collapses long horizontal whitespace runs, and normalizes blank
    lines to a single paragraph break.

    Args:
        text: raw text extracted from a PDF page/chunk.

    Returns:
        The cleaned, stripped text.
    """
    # Drop the repeated classification watermark (hyphen or en dash variant).
    text = re.sub(r'Classification\s*[-–]\s*Internal', '', text)
    # Drop "as on DD.MM.YYYY" date stamps.
    text = re.sub(r'as on \d{2}\.\d{2}\.\d{4}', '', text)
    # Collapse 3+ spaces/tabs WITHOUT touching newlines. (The previous
    # r'\s{3,}' pass also matched newline+space runs, destroying the very
    # paragraph breaks the \n-collapsing pass tries to preserve.)
    text = re.sub(r'[^\S\n]{3,}', ' ', text)
    # Collapse 3+ consecutive newlines into one paragraph break.
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()
27
+
28
def ingest():
    """Load the policy PDFs, chunk and clean them, and persist a FAISS index.

    Missing PDFs are skipped with a warning. The index is written to
    VECTORSTORE_PATH (created if absent).
    """
    print(" Loading PDFs...")
    pages = []
    for filename in PDF_FILES:
        pdf_path = os.path.join(DOCS_PATH, filename)
        if not os.path.exists(pdf_path):
            print(f" Skipping missing file: {filename}")
        else:
            file_pages = PyPDFLoader(pdf_path).load()
            pages.extend(file_pages)
            print(f" {filename} — {len(file_pages)} pages")

    print(f"\n Total pages: {len(pages)}")

    # Chunk the pages
    print("\n Splitting into chunks...")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=["\n\n", "\n", ".", " "]
    )
    raw_chunks = text_splitter.split_documents(pages)

    # Clean each chunk in place, then drop near-empty remnants
    for piece in raw_chunks:
        piece.page_content = clean_text(piece.page_content)
    kept_chunks = [piece for piece in raw_chunks if len(piece.page_content) > 50]
    print(f" Chunks after cleaning: {len(kept_chunks)}")

    # Embed and persist
    print("\n Building FAISS index...")
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    index = FAISS.from_documents(kept_chunks, embedder)

    os.makedirs(VECTORSTORE_PATH, exist_ok=True)
    index.save_local(VECTORSTORE_PATH)
    print(f" FAISS index saved to '{VECTORSTORE_PATH}'")
    print(f" Total vectors: {index.index.ntotal}")

if __name__ == "__main__":
    ingest()
notebooks/prototype.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
langchain
langchain-community
langchain-groq
langchain-huggingface
langchain-text-splitters
langgraph
faiss-cpu
sentence-transformers
pypdf
streamlit
python-dotenv
sqlalchemy
pandas
groq