Sush commited on
Commit ·
5c1da1f
1
Parent(s): c98c28c
Changes for Citation
Browse files- agents/rag_agent.py +42 -32
- app.py +11 -13
agents/rag_agent.py
CHANGED
|
@@ -80,58 +80,68 @@ def load_rag_agent(vectorstore_path: str = "vectorstore/"):
|
|
| 80 |
# Grounded prompt
|
| 81 |
prompt_template = """You are a helpful HDFC Bank policy assistant.
|
| 82 |
|
| 83 |
-
Use ONLY the context below to answer the customer's question.
|
| 84 |
|
| 85 |
-
IMPORTANT:
|
| 86 |
-
- Always include the sources at the end of your answer.
|
| 87 |
-
- Also include a short explanation titled "Why this answer?"
|
| 88 |
-
- Explain briefly how the answer was derived from context
|
| 89 |
-
- The sources are provided in the context.
|
| 90 |
-
- Format sources exactly like this:
|
| 91 |
|
| 92 |
-
Sources:
|
| 93 |
-
- file1.pdf
|
| 94 |
-
- file2.pdf
|
| 95 |
|
| 96 |
-
If the answer is not in the context, say:
|
| 97 |
-
"I don't have enough information in the policy documents to answer this. Please contact HDFC Bank directly."
|
| 98 |
|
| 99 |
-
Context:
|
| 100 |
-
{context}
|
| 101 |
|
| 102 |
-
Customer Question: {question}
|
| 103 |
|
| 104 |
-
Answer:"""
|
| 105 |
|
| 106 |
prompt = PromptTemplate(
|
| 107 |
template=prompt_template,
|
| 108 |
input_variables=["context", "question"]
|
| 109 |
)
|
| 110 |
|
| 111 |
-
def
|
| 112 |
-
formatted = []
|
| 113 |
sources = []
|
| 114 |
|
| 115 |
for doc in docs:
|
| 116 |
source = doc.metadata.get("source", "Unknown")
|
| 117 |
filename = os.path.basename(source)
|
| 118 |
-
|
| 119 |
sources.append(filename)
|
| 120 |
-
formatted.append(doc.page_content)
|
| 121 |
|
| 122 |
-
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
|
|
|
| 126 |
|
| 127 |
-
|
|
|
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
-
return
|
|
|
|
| 80 |
# Grounded prompt
|
| 81 |
prompt_template = """You are a helpful HDFC Bank policy assistant.
|
| 82 |
|
| 83 |
+
Use ONLY the context below to answer the customer's question.
|
| 84 |
|
| 85 |
+
IMPORTANT:
|
| 86 |
+
- Always include the sources at the end of your answer.
|
| 87 |
+
- Also include a short explanation titled "Why this answer?"
|
| 88 |
+
- Explain briefly how the answer was derived from context
|
| 89 |
+
- The sources are provided in the context.
|
| 90 |
+
- Format sources exactly like this:
|
| 91 |
|
| 92 |
+
Sources:
|
| 93 |
+
- file1.pdf
|
| 94 |
+
- file2.pdf
|
| 95 |
|
| 96 |
+
If the answer is not in the context, say:
|
| 97 |
+
"I don't have enough information in the policy documents to answer this. Please contact HDFC Bank directly."
|
| 98 |
|
| 99 |
+
Context:
|
| 100 |
+
{context}
|
| 101 |
|
| 102 |
+
Customer Question: {question}
|
| 103 |
|
| 104 |
+
Answer:"""
|
| 105 |
|
| 106 |
prompt = PromptTemplate(
|
| 107 |
template=prompt_template,
|
| 108 |
input_variables=["context", "question"]
|
| 109 |
)
|
| 110 |
|
| 111 |
+
def extract_sources(docs):
|
|
|
|
| 112 |
sources = []
|
| 113 |
|
| 114 |
for doc in docs:
|
| 115 |
source = doc.metadata.get("source", "Unknown")
|
| 116 |
filename = os.path.basename(source)
|
|
|
|
| 117 |
sources.append(filename)
|
|
|
|
| 118 |
|
| 119 |
+
return list(set(sources))
|
| 120 |
|
| 121 |
+
def run_rag(question):
|
| 122 |
+
# 1. Retrieve documents
|
| 123 |
+
docs = retriever.get_relevant_documents(question)
|
| 124 |
|
| 125 |
+
# 2. Create context for LLM
|
| 126 |
+
context = "\n\n".join(doc.page_content for doc in docs)
|
| 127 |
|
| 128 |
+
# 3. Extract sources separately
|
| 129 |
+
sources = []
|
| 130 |
+
for doc in docs:
|
| 131 |
+
source = doc.metadata.get("source", "Unknown")
|
| 132 |
+
filename = os.path.basename(source)
|
| 133 |
+
sources.append(filename)
|
| 134 |
+
|
| 135 |
+
sources = list(set(sources))
|
| 136 |
+
|
| 137 |
+
# 4. Call LLM
|
| 138 |
+
response = llm.invoke(
|
| 139 |
+
prompt.format(context=context, question=question)
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
return {
|
| 143 |
+
"answer": response.content,
|
| 144 |
+
"sources": sources
|
| 145 |
+
}
|
| 146 |
|
| 147 |
+
return run_rag
|
app.py
CHANGED
|
@@ -77,27 +77,25 @@ if query := st.chat_input("Ask your banking question here..."):
|
|
| 77 |
"agent_used": "",
|
| 78 |
"response": ""
|
| 79 |
})
|
| 80 |
-
|
| 81 |
-
response =
|
|
|
|
| 82 |
agent_used = result["agent_used"].upper()
|
| 83 |
|
| 84 |
-
# Show
|
| 85 |
if agent_used == "RAG":
|
| 86 |
-
st.caption("
|
| 87 |
else:
|
| 88 |
-
st.caption("
|
| 89 |
-
|
| 90 |
-
if "Sources:" in response:
|
| 91 |
-
answer, sources = response.split("Sources:")
|
| 92 |
|
| 93 |
# Show answer
|
| 94 |
-
|
| 95 |
|
| 96 |
-
# Show sources
|
|
|
|
| 97 |
st.markdown("### Sources")
|
| 98 |
-
for s in sources
|
| 99 |
-
|
| 100 |
-
st.markdown(f"- {s.replace('-', '').strip()}")
|
| 101 |
else:
|
| 102 |
st.markdown(response)
|
| 103 |
|
|
|
|
| 77 |
"agent_used": "",
|
| 78 |
"response": ""
|
| 79 |
})
|
| 80 |
+
|
| 81 |
+
response = result["answer"]
|
| 82 |
+
sources = result["sources"]
|
| 83 |
agent_used = result["agent_used"].upper()
|
| 84 |
|
| 85 |
+
# Show agent
|
| 86 |
if agent_used == "RAG":
|
| 87 |
+
st.caption("Answered by: Policy Agent (RAG)")
|
| 88 |
else:
|
| 89 |
+
st.caption("Answered by: Data Agent (SQL)")
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# Show answer
|
| 92 |
+
st.markdown(response)
|
| 93 |
|
| 94 |
+
# Show sources
|
| 95 |
+
if sources:
|
| 96 |
st.markdown("### Sources")
|
| 97 |
+
for s in sources:
|
| 98 |
+
st.markdown(f"- {s}")
|
|
|
|
| 99 |
else:
|
| 100 |
st.markdown(response)
|
| 101 |
|