Insurance-AI / src /utils /rag_chain.py
Clocksp's picture
Update src/utils/rag_chain.py
e4ef6ab verified
from typing import List, Dict, Any, Optional
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_classic.chains import RetrievalQA
from langchain_classic.prompts import PromptTemplate
from langchain_classic.schema import Document
from langchain_classic.callbacks.base import BaseCallbackHandler
from utils.vector_store import VectorStoreManager
from config import Config
class StreamHandler(BaseCallbackHandler):
"""Callback handler for streaming responses"""
def __init__(self):
self.text = ""
def on_llm_new_token(self, token: str, **kwargs) -> None:
"""Handle new token from LLM"""
self.text += token
print(token, end="", flush=True)
class InsuranceRAGChain:
"""RAG chain for insurance document Q&A"""
def __init__(self, vector_store_manager: Optional[VectorStoreManager] = None):
"""
Initialize RAG chain
Args:
vector_store_manager: Optional VectorStoreManager instance
"""
# Initialize vector store manager
self.vs_manager = vector_store_manager or VectorStoreManager()
# Initialize Gemini model
self.llm = ChatGoogleGenerativeAI(
model=Config.GEMINI_MODEL,
google_api_key=Config.GEMINI_API_KEY,
temperature=Config.GEMINI_TEMPERATURE,
max_output_tokens=Config.GEMINI_MAX_OUTPUT_TOKENS,
)
# Create prompt template
self.prompt_template = PromptTemplate(
template=Config.RAG_PROMPT_TEMPLATE,
input_variables=["context", "question"]
)
print("RAG chain initialized")
def create_qa_chain(self, chain_type: str = "stuff") -> RetrievalQA:
"""
Create a RetrievalQA chain
Args:
chain_type: Type of chain ("stuff", "map_reduce", "refine")
"stuff" - puts all docs in context (best for most cases)
Returns:
RetrievalQA chain
"""
retriever = self.vs_manager.get_retriever()
qa_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type=chain_type,
retriever=retriever,
return_source_documents=True,
chain_type_kwargs={"prompt": self.prompt_template}
)
return qa_chain
def query(self, question: str, return_sources: bool = True) -> Dict[str, Any]:
"""
Query the RAG system
Args:
question: User's question
return_sources: Whether to return source documents
Returns:
Dictionary with answer and optional source documents
"""
try:
# Create QA chain
qa_chain = self.create_qa_chain()
# Run query
result = qa_chain.invoke({"query": question})
response = {
"answer": result["result"],
"question": question
}
if return_sources and "source_documents" in result:
response["sources"] = self._format_sources(result["source_documents"])
response["source_documents"] = result["source_documents"]
return response
except Exception as e:
print(f" Error during query: {str(e)}")
raise
def query_with_context(
self,
question: str,
conversation_history: Optional[List[Dict[str, str]]] = None
) -> Dict[str, Any]:
"""
Query with conversation context
Args:
question: User's question
conversation_history: List of previous Q&A pairs
Returns:
Dictionary with answer and sources
"""
# Build contextualized question if history exists
if conversation_history and len(conversation_history) > 0:
context = "\n".join([
f"Previous Q: {item['question']}\nPrevious A: {item['answer']}"
for item in conversation_history[-3:] # Last 3 turns
])
contextualized_question = f"Conversation context:\n{context}\n\nCurrent question: {question}"
else:
contextualized_question = question
return self.query(contextualized_question, return_sources=True)
def query_specific_section(
self,
question: str,
section_type: str
) -> Dict[str, Any]:
"""
Query a specific section type (exclusions, addons, coverage, etc.)
Args:
question: User's question
section_type: Section to search in
Returns:
Dictionary with answer and sources
"""
try:
# Get relevant documents from specific section
docs = self.vs_manager.search_by_section_type(
query=question,
section_type=section_type,
k=5
)
if not docs:
return {
"answer": f"No relevant information found in {section_type} section.",
"question": question,
"sources": []
}
# Build context from retrieved documents
context = "\n\n".join([doc.page_content for doc in docs])
# Format prompt
prompt = self.prompt_template.format(
context=context,
question=question
)
# Get response from LLM
response = self.llm.invoke(prompt)
return {
"answer": response.content,
"question": question,
"sources": self._format_sources(docs),
"source_documents": docs
}
except Exception as e:
print(f"Error querying specific section: {str(e)}")
raise
def compare_addons(self, addon_names: List[str]) -> Dict[str, Any]:
"""
Compare multiple add-ons
Args:
addon_names: List of add-on names to compare
Returns:
Dictionary with comparison and sources
"""
question = f"Compare the following add-ons and explain their key differences, coverage, and benefits: {', '.join(addon_names)}"
return self.query_specific_section(question, section_type="addons")
def find_coverage_gaps(self, current_coverage_description: str) -> Dict[str, Any]:
"""
Identify potential coverage gaps
Args:
current_coverage_description: Description of current coverage
Returns:
Dictionary with gap analysis and recommendations
"""
question = f"""Based on this current coverage: {current_coverage_description}
Please identify:
1. What scenarios or risks are NOT covered
2. What add-ons or riders could fill these gaps
3. Which gaps are most important to address"""
return self.query(question, return_sources=True)
def explain_terms(self, terms: List[str]) -> Dict[str, Any]:
"""
Explain insurance terms in plain language
Args:
terms: List of insurance terms to explain
Returns:
Dictionary with explanations
"""
question = f"Explain these insurance terms in simple language: {', '.join(terms)}"
return self.query(question, return_sources=True)
def _format_sources(self, documents: List[Document]) -> List[Dict[str, Any]]:
"""
Format source documents for display
Args:
documents: List of source documents
Returns:
List of formatted source information
"""
sources = []
for i, doc in enumerate(documents, 1):
source_info = {
"index": i,
"source_file": doc.metadata.get("source_file", "Unknown"),
"page": doc.metadata.get("page", "Unknown"),
"section_type": doc.metadata.get("section_type", "general"),
"content_preview": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
}
sources.append(source_info)
return sources
def stream_query(self, question: str) -> tuple[str, List[Dict[str, Any]]]:
"""
Query with streaming response
Args:
question: User's question
Returns:
Tuple of (answer, sources)
"""
try:
# Get relevant documents using invoke method
retriever = self.vs_manager.get_retriever()
docs = retriever.invoke(question)
if not docs:
return "No relevant information found in the documents.", []
# Build context
context = "\n\n".join([doc.page_content for doc in docs])
# Format prompt
prompt = self.prompt_template.format(
context=context,
question=question
)
# Stream response
print("\n Assistant: ", end="")
stream_handler = StreamHandler()
streaming_llm = ChatGoogleGenerativeAI(
model=Config.GEMINI_MODEL,
google_api_key=Config.GEMINI_API_KEY,
temperature=Config.GEMINI_TEMPERATURE,
streaming=True,
callbacks=[stream_handler]
)
streaming_llm.invoke(prompt)
print("\n")
return stream_handler.text, self._format_sources(docs)
except Exception as e:
print(f" Error during streaming query: {str(e)}")
raise