Sush commited on
Commit
5c1da1f
·
1 Parent(s): c98c28c

Changes for Citation

Browse files
Files changed (2) hide show
  1. agents/rag_agent.py +42 -32
  2. app.py +11 -13
agents/rag_agent.py CHANGED
@@ -80,58 +80,68 @@ def load_rag_agent(vectorstore_path: str = "vectorstore/"):
80
  # Grounded prompt
81
  prompt_template = """You are a helpful HDFC Bank policy assistant.
82
 
83
- Use ONLY the context below to answer the customer's question.
84
 
85
- IMPORTANT:
86
- - Always include the sources at the end of your answer.
87
- - Also include a short explanation titled "Why this answer?"
88
- - Explain briefly how the answer was derived from context
89
- - The sources are provided in the context.
90
- - Format sources exactly like this:
91
 
92
- Sources:
93
- - file1.pdf
94
- - file2.pdf
95
 
96
- If the answer is not in the context, say:
97
- "I don't have enough information in the policy documents to answer this. Please contact HDFC Bank directly."
98
 
99
- Context:
100
- {context}
101
 
102
- Customer Question: {question}
103
 
104
- Answer:"""
105
 
106
  prompt = PromptTemplate(
107
  template=prompt_template,
108
  input_variables=["context", "question"]
109
  )
110
 
111
- def format_docs(docs):
112
- formatted = []
113
  sources = []
114
 
115
  for doc in docs:
116
  source = doc.metadata.get("source", "Unknown")
117
  filename = os.path.basename(source)
118
-
119
  sources.append(filename)
120
- formatted.append(doc.page_content)
121
 
122
- context = "\n\n".join(formatted)
123
 
124
- unique_sources = list(set(sources))
125
- source_text = "\n\nSources:\n" + "\n".join(f"- {s}" for s in unique_sources)
 
126
 
127
- return context + source_text
 
128
 
129
- # LCEL chain
130
- rag_chain = (
131
- {"context": retriever | format_docs, "question": RunnablePassthrough()}
132
- | prompt
133
- | llm
134
- | StrOutputParser()
135
- )
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- return rag_chain
 
80
  # Grounded prompt
81
  prompt_template = """You are a helpful HDFC Bank policy assistant.
82
 
83
+ Use ONLY the context below to answer the customer's question.
84
 
85
+ IMPORTANT:
86
+ - Always include the sources at the end of your answer.
87
+ - Also include a short explanation titled "Why this answer?"
88
+ - Explain briefly how the answer was derived from context
89
+ - The sources are provided in the context.
90
+ - Format sources exactly like this:
91
 
92
+ Sources:
93
+ - file1.pdf
94
+ - file2.pdf
95
 
96
+ If the answer is not in the context, say:
97
+ "I don't have enough information in the policy documents to answer this. Please contact HDFC Bank directly."
98
 
99
+ Context:
100
+ {context}
101
 
102
+ Customer Question: {question}
103
 
104
+ Answer:"""
105
 
106
  prompt = PromptTemplate(
107
  template=prompt_template,
108
  input_variables=["context", "question"]
109
  )
110
 
111
+ def extract_sources(docs):
 
112
  sources = []
113
 
114
  for doc in docs:
115
  source = doc.metadata.get("source", "Unknown")
116
  filename = os.path.basename(source)
 
117
  sources.append(filename)
 
118
 
119
+ return list(set(sources))
120
 
121
+ def run_rag(question):
122
+ # 1. Retrieve documents
123
+ docs = retriever.get_relevant_documents(question)
124
 
125
+ # 2. Create context for LLM
126
+ context = "\n\n".join(doc.page_content for doc in docs)
127
 
128
+ # 3. Extract sources separately
129
+ sources = []
130
+ for doc in docs:
131
+ source = doc.metadata.get("source", "Unknown")
132
+ filename = os.path.basename(source)
133
+ sources.append(filename)
134
+
135
+ sources = list(set(sources))
136
+
137
+ # 4. Call LLM
138
+ response = llm.invoke(
139
+ prompt.format(context=context, question=question)
140
+ )
141
+
142
+ return {
143
+ "answer": response.content,
144
+ "sources": sources
145
+ }
146
 
147
+ return run_rag
app.py CHANGED
@@ -77,27 +77,25 @@ if query := st.chat_input("Ask your banking question here..."):
77
  "agent_used": "",
78
  "response": ""
79
  })
80
-
81
- response = format_currency(result["response"])
 
82
  agent_used = result["agent_used"].upper()
83
 
84
- # Show which agent handled it
85
  if agent_used == "RAG":
86
- st.caption(" Answered by: Policy Agent (RAG)")
87
  else:
88
- st.caption(" Answered by: Data Agent (SQL)")
89
-
90
- if "Sources:" in response:
91
- answer, sources = response.split("Sources:")
92
 
93
  # Show answer
94
- st.markdown(answer)
95
 
96
- # Show sources nicely
 
97
  st.markdown("### Sources")
98
- for s in sources.strip().split("\n"):
99
- if s.strip():
100
- st.markdown(f"- {s.replace('-', '').strip()}")
101
  else:
102
  st.markdown(response)
103
 
 
77
  "agent_used": "",
78
  "response": ""
79
  })
80
+
81
+ response = result["answer"]
82
+ sources = result["sources"]
83
  agent_used = result["agent_used"].upper()
84
 
85
+ # Show agent
86
  if agent_used == "RAG":
87
+ st.caption("Answered by: Policy Agent (RAG)")
88
  else:
89
+ st.caption("Answered by: Data Agent (SQL)")
 
 
 
90
 
91
  # Show answer
92
+ st.markdown(response)
93
 
94
+ # Show sources
95
+ if sources:
96
  st.markdown("### Sources")
97
+ for s in sources:
98
+ st.markdown(f"- {s}")
 
99
  else:
100
  st.markdown(response)
101