Clocksp committed on
Commit
794d1bc
·
verified ·
1 Parent(s): 197d34d

Update src/utils/rag_chain.py

Browse files
Files changed (1) hide show
  1. src/utils/rag_chain.py +304 -305
src/utils/rag_chain.py CHANGED
@@ -1,305 +1,304 @@
1
- from typing import List, Dict, Any, Optional
2
- from langchain_google_genai import ChatGoogleGenerativeAI
3
- from langchain_groq import ChatGroq
4
- from langchain_classic.chains import RetrievalQA
5
- from langchain_classic.prompts import PromptTemplate
6
- from langchain_classic.schema import Document
7
- from langchain_classic.callbacks.base import BaseCallbackHandler
8
- from utils.vector_store import VectorStoreManager
9
- from config import Config
10
- class StreamHandler(BaseCallbackHandler):
11
- """Callback handler for streaming responses"""
12
-
13
- def __init__(self):
14
- self.text = ""
15
-
16
- def on_llm_new_token(self, token: str, **kwargs) -> None:
17
- """Handle new token from LLM"""
18
- self.text += token
19
- print(token, end="", flush=True)
20
-
21
-
22
- class InsuranceRAGChain:
23
- """RAG chain for insurance document Q&A"""
24
-
25
- def __init__(self, vector_store_manager: Optional[VectorStoreManager] = None):
26
- """
27
- Initialize RAG chain
28
-
29
- Args:
30
- vector_store_manager: Optional VectorStoreManager instance
31
- """
32
- # Initialize vector store manager
33
- self.vs_manager = vector_store_manager or VectorStoreManager()
34
-
35
- # Initialize Gemini model
36
- self.llm = ChatGoogleGenerativeAI(
37
- model=Config.GEMINI_MODEL,
38
- google_api_key=Config.GEMINI_API_KEY,
39
- temperature=Config.GEMINI_TEMPERATURE,
40
- max_output_tokens=Config.GEMINI_MAX_OUTPUT_TOKENS,
41
- )
42
-
43
- # Create prompt template
44
- self.prompt_template = PromptTemplate(
45
- template=Config.RAG_PROMPT_TEMPLATE,
46
- input_variables=["context", "question"]
47
- )
48
-
49
- print("RAG chain initialized")
50
-
51
- def create_qa_chain(self, chain_type: str = "stuff") -> RetrievalQA:
52
- """
53
- Create a RetrievalQA chain
54
-
55
- Args:
56
- chain_type: Type of chain ("stuff", "map_reduce", "refine")
57
- "stuff" - puts all docs in context (best for most cases)
58
-
59
- Returns:
60
- RetrievalQA chain
61
- """
62
- retriever = self.vs_manager.get_retriever()
63
-
64
- qa_chain = RetrievalQA.from_chain_type(
65
- llm=self.llm,
66
- chain_type=chain_type,
67
- retriever=retriever,
68
- return_source_documents=True,
69
- chain_type_kwargs={"prompt": self.prompt_template}
70
- )
71
-
72
- return qa_chain
73
-
74
- def query(self, question: str, return_sources: bool = True) -> Dict[str, Any]:
75
- """
76
- Query the RAG system
77
-
78
- Args:
79
- question: User's question
80
- return_sources: Whether to return source documents
81
-
82
- Returns:
83
- Dictionary with answer and optional source documents
84
- """
85
- try:
86
- # Create QA chain
87
- qa_chain = self.create_qa_chain()
88
-
89
- # Run query
90
- result = qa_chain.invoke({"query": question})
91
-
92
- response = {
93
- "answer": result["result"],
94
- "question": question
95
- }
96
-
97
- if return_sources and "source_documents" in result:
98
- response["sources"] = self._format_sources(result["source_documents"])
99
- response["source_documents"] = result["source_documents"]
100
-
101
- return response
102
-
103
- except Exception as e:
104
- print(f" Error during query: {str(e)}")
105
- raise
106
-
107
- def query_with_context(
108
- self,
109
- question: str,
110
- conversation_history: Optional[List[Dict[str, str]]] = None
111
- ) -> Dict[str, Any]:
112
- """
113
- Query with conversation context
114
-
115
- Args:
116
- question: User's question
117
- conversation_history: List of previous Q&A pairs
118
-
119
- Returns:
120
- Dictionary with answer and sources
121
- """
122
- # Build contextualized question if history exists
123
- if conversation_history and len(conversation_history) > 0:
124
- context = "\n".join([
125
- f"Previous Q: {item['question']}\nPrevious A: {item['answer']}"
126
- for item in conversation_history[-3:] # Last 3 turns
127
- ])
128
- contextualized_question = f"Conversation context:\n{context}\n\nCurrent question: {question}"
129
- else:
130
- contextualized_question = question
131
-
132
- return self.query(contextualized_question, return_sources=True)
133
-
134
- def query_specific_section(
135
- self,
136
- question: str,
137
- section_type: str
138
- ) -> Dict[str, Any]:
139
- """
140
- Query a specific section type (exclusions, addons, coverage, etc.)
141
-
142
- Args:
143
- question: User's question
144
- section_type: Section to search in
145
-
146
- Returns:
147
- Dictionary with answer and sources
148
- """
149
- try:
150
- # Get relevant documents from specific section
151
- docs = self.vs_manager.search_by_section_type(
152
- query=question,
153
- section_type=section_type,
154
- k=5
155
- )
156
-
157
- if not docs:
158
- return {
159
- "answer": f"No relevant information found in {section_type} section.",
160
- "question": question,
161
- "sources": []
162
- }
163
-
164
- # Build context from retrieved documents
165
- context = "\n\n".join([doc.page_content for doc in docs])
166
-
167
- # Format prompt
168
- prompt = self.prompt_template.format(
169
- context=context,
170
- question=question
171
- )
172
-
173
- # Get response from LLM
174
- response = self.llm.invoke(prompt)
175
-
176
- return {
177
- "answer": response.content,
178
- "question": question,
179
- "sources": self._format_sources(docs),
180
- "source_documents": docs
181
- }
182
-
183
- except Exception as e:
184
- print(f"Error querying specific section: {str(e)}")
185
- raise
186
-
187
- def compare_addons(self, addon_names: List[str]) -> Dict[str, Any]:
188
- """
189
- Compare multiple add-ons
190
-
191
- Args:
192
- addon_names: List of add-on names to compare
193
-
194
- Returns:
195
- Dictionary with comparison and sources
196
- """
197
- question = f"Compare the following add-ons and explain their key differences, coverage, and benefits: {', '.join(addon_names)}"
198
-
199
- return self.query_specific_section(question, section_type="addons")
200
-
201
- def find_coverage_gaps(self, current_coverage_description: str) -> Dict[str, Any]:
202
- """
203
- Identify potential coverage gaps
204
-
205
- Args:
206
- current_coverage_description: Description of current coverage
207
-
208
- Returns:
209
- Dictionary with gap analysis and recommendations
210
- """
211
- question = f"""Based on this current coverage: {current_coverage_description}
212
-
213
- Please identify:
214
- 1. What scenarios or risks are NOT covered
215
- 2. What add-ons or riders could fill these gaps
216
- 3. Which gaps are most important to address"""
217
-
218
- return self.query(question, return_sources=True)
219
-
220
- def explain_terms(self, terms: List[str]) -> Dict[str, Any]:
221
- """
222
- Explain insurance terms in plain language
223
-
224
- Args:
225
- terms: List of insurance terms to explain
226
-
227
- Returns:
228
- Dictionary with explanations
229
- """
230
- question = f"Explain these insurance terms in simple language: {', '.join(terms)}"
231
-
232
- return self.query(question, return_sources=True)
233
-
234
- def format_sources(self, documents: List[Document]) -> List[Dict[str, Any]]:
235
- """
236
- Format source documents for display
237
-
238
- Args:
239
- documents: List of source documents
240
-
241
- Returns:
242
- List of formatted source information
243
- """
244
- sources = []
245
- for i, doc in enumerate(documents, 1):
246
- source_info = {
247
- "index": i,
248
- "source_file": doc.metadata.get("source_file", "Unknown"),
249
- "page": doc.metadata.get("page", "Unknown"),
250
- "section_type": doc.metadata.get("section_type", "general"),
251
- "content_preview": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
252
- }
253
- sources.append(source_info)
254
-
255
- return sources
256
-
257
- def stream_query(self, question: str) -> tuple[str, List[Dict[str, Any]]]:
258
- """
259
- Query with streaming response
260
-
261
- Args:
262
- question: User's question
263
-
264
- Returns:
265
- Tuple of (answer, sources)
266
- """
267
- try:
268
- # Get relevant documents using invoke method
269
- retriever = self.vs_manager.get_retriever()
270
- docs = retriever.invoke(question)
271
-
272
- if not docs:
273
- return "No relevant information found in the documents.", []
274
-
275
- # Build context
276
- context = "\n\n".join([doc.page_content for doc in docs])
277
-
278
- # Format prompt
279
- prompt = self.prompt_template.format(
280
- context=context,
281
- question=question
282
- )
283
-
284
- # Stream response
285
- print("\n Assistant: ", end="")
286
- stream_handler = StreamHandler()
287
-
288
- streaming_llm = ChatGoogleGenerativeAI(
289
- model=Config.GEMINI_MODEL,
290
- google_api_key=Config.GEMINI_API_KEY,
291
- temperature=Config.GEMINI_TEMPERATURE,
292
- streaming=True,
293
- callbacks=[stream_handler]
294
- )
295
-
296
- streaming_llm.invoke(prompt)
297
- print("\n")
298
-
299
- return stream_handler.text, self._format_sources(docs)
300
-
301
- except Exception as e:
302
- print(f" Error during streaming query: {str(e)}")
303
- raise
304
-
305
-
 
1
+ from typing import List, Dict, Any, Optional
2
+ from langchain_google_genai import ChatGoogleGenerativeAI
3
+ from langchain_classic.chains import RetrievalQA
4
+ from langchain_classic.prompts import PromptTemplate
5
+ from langchain_classic.schema import Document
6
+ from langchain_classic.callbacks.base import BaseCallbackHandler
7
+ from utils.vector_store import VectorStoreManager
8
+ from config import Config
9
class StreamHandler(BaseCallbackHandler):
    """Callback handler for streaming responses.

    Accumulates every token it receives in ``self.text`` and echoes each one
    to stdout as it arrives.
    """

    def __init__(self) -> None:
        """Start with an empty transcript."""
        # Give the callback base class (and any mixins) a chance to run their
        # own initialisation before we add our state.
        super().__init__()
        self.text: str = ""

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Handle new token from LLM: record it and print it immediately."""
        self.text += token
        # Flush on every token so the output appears incrementally instead of
        # waiting for stdout's buffer to fill.
        print(token, end="", flush=True)
19
+
20
+
21
class InsuranceRAGChain:
    """RAG chain for insurance document Q&A.

    Pairs a ``VectorStoreManager`` (retrieval) with a Gemini chat model
    configured from ``Config`` (generation). The query methods return the
    model's answer together with a display-friendly summary of the documents
    that were retrieved for it.
    """

    def __init__(self, vector_store_manager: Optional["VectorStoreManager"] = None):
        """
        Initialize RAG chain

        Args:
            vector_store_manager: Optional VectorStoreManager instance. A new
                ``VectorStoreManager()`` is created only when this is None.
        """
        # Compare against None explicitly so a caller-supplied manager is
        # never discarded just because it happens to evaluate as falsy.
        if vector_store_manager is None:
            vector_store_manager = VectorStoreManager()
        self.vs_manager = vector_store_manager

        # Initialize Gemini model
        self.llm = ChatGoogleGenerativeAI(
            model=Config.GEMINI_MODEL,
            google_api_key=Config.GEMINI_API_KEY,
            temperature=Config.GEMINI_TEMPERATURE,
            max_output_tokens=Config.GEMINI_MAX_OUTPUT_TOKENS,
        )

        # Create prompt template; it is filled with the retrieved context and
        # the user's question.
        self.prompt_template = PromptTemplate(
            template=Config.RAG_PROMPT_TEMPLATE,
            input_variables=["context", "question"],
        )

        print("RAG chain initialized")

    def create_qa_chain(self, chain_type: str = "stuff") -> "RetrievalQA":
        """
        Create a RetrievalQA chain

        Args:
            chain_type: Type of chain ("stuff", "map_reduce", "refine")
                "stuff" - puts all docs in context (best for most cases)

        Returns:
            RetrievalQA chain wired to this instance's LLM, retriever and
            prompt template, configured to return its source documents.
        """
        return RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type=chain_type,
            retriever=self.vs_manager.get_retriever(),
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.prompt_template},
        )

    def query(self, question: str, return_sources: bool = True) -> Dict[str, Any]:
        """
        Query the RAG system

        Args:
            question: User's question
            return_sources: Whether to return source documents

        Returns:
            Dictionary with "answer" and "question"; when ``return_sources``
            is true and the chain reported source documents, also "sources"
            (formatted summaries) and "source_documents" (the raw documents).

        Raises:
            Exception: Any error raised while building or running the chain
                is printed and then re-raised unchanged.
        """
        try:
            # A fresh chain is built per call so it always uses the vector
            # store manager's current retriever.
            qa_chain = self.create_qa_chain()
            result = qa_chain.invoke({"query": question})

            response: Dict[str, Any] = {
                "answer": result["result"],
                "question": question,
            }

            if return_sources and "source_documents" in result:
                source_documents = result["source_documents"]
                response["sources"] = self._format_sources(source_documents)
                response["source_documents"] = source_documents

            return response

        except Exception as e:
            print(f" Error during query: {e}")
            raise

    def query_with_context(
        self,
        question: str,
        conversation_history: Optional[List[Dict[str, str]]] = None,
    ) -> Dict[str, Any]:
        """
        Query with conversation context

        Args:
            question: User's question
            conversation_history: List of previous Q&A pairs; each item is
                read through its "question" and "answer" keys.

        Returns:
            Dictionary with answer and sources (see ``query``). Note that the
            "question" field of the result holds the contextualized question
            when history was supplied.
        """
        # Without history the question is sent as-is.
        if not conversation_history:
            return self.query(question, return_sources=True)

        # Only the last 3 turns are included to keep the prompt short.
        recent_turns = conversation_history[-3:]
        context = "\n".join(
            f"Previous Q: {item['question']}\nPrevious A: {item['answer']}"
            for item in recent_turns
        )
        contextualized_question = (
            f"Conversation context:\n{context}\n\nCurrent question: {question}"
        )
        return self.query(contextualized_question, return_sources=True)

    def query_specific_section(
        self,
        question: str,
        section_type: str,
    ) -> Dict[str, Any]:
        """
        Query a specific section type (exclusions, addons, coverage, etc.)

        Args:
            question: User's question
            section_type: Section to search in

        Returns:
            Dictionary with "answer", "question", "sources" and
            "source_documents". When no document matches, the answer is a
            fixed "not found" message and both source lists are empty.

        Raises:
            Exception: Any retrieval or LLM error is printed and re-raised.
        """
        try:
            # Get relevant documents from specific section
            docs = self.vs_manager.search_by_section_type(
                query=question,
                section_type=section_type,
                k=5,
            )

            if not docs:
                # Same keys as the success path so callers can read the
                # result without checking which branch produced it.
                return {
                    "answer": f"No relevant information found in {section_type} section.",
                    "question": question,
                    "sources": [],
                    "source_documents": [],
                }

            prompt = self.prompt_template.format(
                context=self._join_context(docs),
                question=question,
            )

            # The LLM is called directly (no RetrievalQA chain) because the
            # documents were already selected above.
            response = self.llm.invoke(prompt)

            return {
                "answer": response.content,
                "question": question,
                "sources": self._format_sources(docs),
                "source_documents": docs,
            }

        except Exception as e:
            print(f"Error querying specific section: {e}")
            raise

    def compare_addons(self, addon_names: List[str]) -> Dict[str, Any]:
        """
        Compare multiple add-ons

        Args:
            addon_names: List of add-on names to compare

        Returns:
            Dictionary with comparison and sources, produced by
            ``query_specific_section`` restricted to the "addons" section.
        """
        joined_names = ", ".join(addon_names)
        question = (
            "Compare the following add-ons and explain their key differences, "
            f"coverage, and benefits: {joined_names}"
        )
        return self.query_specific_section(question, section_type="addons")

    def find_coverage_gaps(self, current_coverage_description: str) -> Dict[str, Any]:
        """
        Identify potential coverage gaps

        Args:
            current_coverage_description: Description of current coverage

        Returns:
            Dictionary with gap analysis and recommendations (see ``query``).
        """
        # Assembled line by line so the source file's indentation never ends
        # up inside the prompt text.
        question = (
            f"Based on this current coverage: {current_coverage_description}\n"
            "\n"
            "Please identify:\n"
            "1. What scenarios or risks are NOT covered\n"
            "2. What add-ons or riders could fill these gaps\n"
            "3. Which gaps are most important to address"
        )
        return self.query(question, return_sources=True)

    def explain_terms(self, terms: List[str]) -> Dict[str, Any]:
        """
        Explain insurance terms in plain language

        Args:
            terms: List of insurance terms to explain

        Returns:
            Dictionary with explanations (see ``query``).
        """
        joined_terms = ", ".join(terms)
        question = f"Explain these insurance terms in simple language: {joined_terms}"
        return self.query(question, return_sources=True)

    @staticmethod
    def _join_context(docs) -> str:
        """Concatenate the documents' text, separated by blank lines."""
        return "\n\n".join(doc.page_content for doc in docs)

    def _format_sources(self, documents: List["Document"]) -> List[Dict[str, Any]]:
        """
        Format source documents for display

        Args:
            documents: List of source documents

        Returns:
            One dictionary per document with a 1-based "index", the
            "source_file", "page" and "section_type" metadata (defaulting to
            "Unknown", "Unknown" and "general"), and a "content_preview" of
            at most 200 characters followed by "..." when truncated.
        """
        sources = []
        for index, doc in enumerate(documents, 1):
            text = doc.page_content
            preview = f"{text[:200]}..." if len(text) > 200 else text
            metadata = doc.metadata
            sources.append({
                "index": index,
                "source_file": metadata.get("source_file", "Unknown"),
                "page": metadata.get("page", "Unknown"),
                "section_type": metadata.get("section_type", "general"),
                "content_preview": preview,
            })
        return sources

    def format_sources(self, documents: List["Document"]) -> List[Dict[str, Any]]:
        """
        Format source documents for display (public entry point)

        Kept for callers that use the public name; the query methods use
        ``_format_sources``, to which this simply delegates.

        Args:
            documents: List of source documents

        Returns:
            List of formatted source information
        """
        return self._format_sources(documents)

    def stream_query(self, question: str) -> tuple[str, List[Dict[str, Any]]]:
        """
        Query with streaming response

        Tokens are printed to stdout as they arrive and also collected into
        the returned answer.

        Args:
            question: User's question

        Returns:
            Tuple of (answer, sources). When nothing is retrieved, a fixed
            "not found" message and an empty list are returned without
            calling the LLM.

        Raises:
            Exception: Any retrieval or LLM error is printed and re-raised.
        """
        try:
            # Get relevant documents using invoke method
            retriever = self.vs_manager.get_retriever()
            docs = retriever.invoke(question)

            if not docs:
                return "No relevant information found in the documents.", []

            prompt = self.prompt_template.format(
                context=self._join_context(docs),
                question=question,
            )

            # Stream response
            print("\n Assistant: ", end="")
            stream_handler = StreamHandler()

            # A dedicated streaming client is created per call so the handler
            # collecting this answer is not shared with other requests. It
            # mirrors self.llm's settings, including the output-token cap.
            streaming_llm = ChatGoogleGenerativeAI(
                model=Config.GEMINI_MODEL,
                google_api_key=Config.GEMINI_API_KEY,
                temperature=Config.GEMINI_TEMPERATURE,
                max_output_tokens=Config.GEMINI_MAX_OUTPUT_TOKENS,
                streaming=True,
                callbacks=[stream_handler],
            )

            streaming_llm.invoke(prompt)
            print("\n")

            return stream_handler.text, self._format_sources(docs)

        except Exception as e:
            print(f" Error during streaming query: {e}")
            raise
303
+
304
+