# RAGTechniquesComparisonTool / AdvancedRag.py
# (HuggingFace upload page residue removed; commit 6c044be)
# advanced_retrieval.py
import os
from typing import List, Dict, Any, Tuple
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain.schema import Document
from langchain.load import dumps, loads
from bs4.filter import SoupStrainer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from operator import itemgetter
import asyncio
from sentence_transformers import CrossEncoder
load_dotenv()
class AdvancedRetriever:
    """Base class for the advanced RAG retrieval strategies below.

    On construction it loads one web page, splits it into three chunk
    granularities (small / medium / large), and builds one Chroma vector
    store per granularity for subclasses to search against.

    NOTE(review): construction performs network I/O (page fetch, OpenAI
    embedding calls) and model loading — it is not cheap.
    """

    def __init__(self, link: str):
        # URL of the page to index.
        self.link = link
        # temperature=0 for deterministic output in downstream chains.
        self.llm = ChatOpenAI(temperature=0)
        self.embeddings = OpenAIEmbeddings()
        # Cross-encoder used by the reranking strategy; loaded once here
        # so every subclass shares it.
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        # Load and process documents
        self._load_documents()
        self._create_vector_stores()

    def _load_documents(self):
        """Load and chunk documents with different strategies"""
        loader = WebBaseLoader(
            web_path=(self.link,),
            bs_kwargs=dict(
                # Keep only the article body/title/header elements
                # (blog-post layout); everything else is dropped at
                # parse time.
                parse_only=SoupStrainer(
                    class_=("post-content", "post-title", "post-header")
                )
            )
        )
        docs = loader.load()
        # Small chunks for precise retrieval
        small_splitter = RecursiveCharacterTextSplitter(
            chunk_size=200,
            chunk_overlap=50,
        )
        self.small_chunks = small_splitter.split_documents(docs)
        # Large chunks for context
        large_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100,
        )
        self.large_chunks = large_splitter.split_documents(docs)
        # Medium chunks (original)
        medium_splitter = RecursiveCharacterTextSplitter(
            chunk_size=300,
            chunk_overlap=50,
        )
        self.medium_chunks = medium_splitter.split_documents(docs)

    def _create_vector_stores(self):
        """Create vector stores for different chunk sizes"""
        # Distinct collection names keep the three granularities
        # separate within the same Chroma backend.
        self.small_vectorstore = Chroma.from_documents(
            documents=self.small_chunks,
            embedding=self.embeddings,
            collection_name="small_chunks"
        )
        self.large_vectorstore = Chroma.from_documents(
            documents=self.large_chunks,
            embedding=self.embeddings,
            collection_name="large_chunks"
        )
        self.medium_vectorstore = Chroma.from_documents(
            documents=self.medium_chunks,
            embedding=self.embeddings,
            collection_name="medium_chunks"
        )
class MultiQueryRetrieval(AdvancedRetriever):
    """Generate multiple diverse queries and merge results"""

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Return up to *k* unique documents for *question*.

        Fans the question out into several LLM-generated query variants,
        searches the medium-chunk store with each variant plus the
        original question, then merges and deduplicates the hits.
        """
        # Generate multiple query perspectives
        query_generation_prompt = ChatPromptTemplate.from_template("""
You are an AI assistant that generates multiple search queries from different perspectives.
Generate 4 diverse search queries that would help answer this question: {question}
Focus on different aspects and use varied vocabulary.
Each query should be on a separate line.
""")
        generate_queries = (
            query_generation_prompt
            | self.llm
            | StrOutputParser()
            # One query per line; drop blank lines the LLM may emit
            # (plain split('\n') would pass them on to the search).
            | (lambda x: [q for q in (line.strip() for line in x.split('\n')) if q])
        )
        queries = generate_queries.invoke({"question": question})
        queries.append(question)  # Include original query
        # Retrieve documents for each query
        all_docs = []
        for query in queries:
            all_docs.extend(self.medium_vectorstore.similarity_search(query, k=k))
        # Remove duplicates and return top k
        return self._deduplicate_documents(all_docs)[:k]

    def _deduplicate_documents(self, docs: List[Document]) -> List[Document]:
        """Remove duplicates by exact page content, preserving order.

        Set-based lookup replaces the original O(n^2) pairwise scan and
        matches the dedup style used by SemanticRouting.
        """
        seen_content = set()
        unique_docs = []
        for doc in docs:
            if doc.page_content not in seen_content:
                seen_content.add(doc.page_content)
                unique_docs.append(doc)
        return unique_docs
class ParentChildRetrieval(AdvancedRetriever):
    """Retrieve small chunks but return larger parent context"""

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Search small chunks for precision, then swap each hit for the
        large chunk ("parent") that contains it so the caller gets more
        surrounding context. Returns at most *k* parent documents."""
        # Over-fetch small chunks so enough distinct parents survive
        precise_hits = self.small_vectorstore.similarity_search(question, k=k * 2)
        parents = []
        for hit in precise_hits:
            candidate = self._find_parent_chunk(hit)
            # Skip missing or already-collected parents
            if not candidate or candidate in parents:
                continue
            parents.append(candidate)
        return parents[:k]

    def _find_parent_chunk(self, small_doc: Document) -> Document:
        """Return the first large chunk whose text contains *small_doc*'s
        content, falling back to *small_doc* itself when none matches."""
        needle = small_doc.page_content
        return next(
            (large for large in self.large_chunks if needle in large.page_content),
            small_doc,
        )
class ContextualCompression(AdvancedRetriever):
    """Compress retrieved chunks to focus on relevant information"""

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Return up to *k* documents, each LLM-compressed down to the
        material relevant to *question*.

        Over-fetches 2*k candidates from the medium-chunk store, asks the
        LLM to extract only the relevant portion of each, and drops
        candidates whose extract is too short (< 50 chars) as irrelevant.
        """
        # Initial retrieval (over-fetch so dropped docs can be replaced)
        docs = self.medium_vectorstore.similarity_search(question, k=k * 2)
        compression_prompt = ChatPromptTemplate.from_template("""
Given this question: {question}
Extract only the most relevant information from this text that helps answer the question.
Remove any irrelevant details while preserving key facts and context.
Text: {text}
Relevant extract:
""")
        # Compose the chain ONCE; the original rebuilt it on every loop
        # iteration even though it is loop-invariant.
        compression_chain = compression_prompt | self.llm | StrOutputParser()
        compressed_docs = []
        for doc in docs:
            compressed_content = compression_chain.invoke(
                {"question": question, "text": doc.page_content}
            )
            # Only keep if compression resulted in meaningful content
            if len(compressed_content.strip()) > 50:
                compressed_docs.append(
                    Document(
                        page_content=compressed_content,
                        metadata=doc.metadata,
                    )
                )
        return compressed_docs[:k]
class CrossEncoderReranking(AdvancedRetriever):
    """Use cross-encoder for better relevance scoring"""

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Over-fetch 3*k candidates via embedding similarity, rerank them
        with the cross-encoder, and return the *k* most relevant."""
        candidates = self.medium_vectorstore.similarity_search(question, k=k * 3)
        if not candidates:
            return []
        # Score every (query, passage) pair with the cross-encoder
        pairs = [[question, doc.page_content] for doc in candidates]
        relevance = self.cross_encoder.predict(pairs)
        # Highest relevance first (stable sort, same as list.sort)
        ranked = sorted(zip(candidates, relevance), key=itemgetter(1), reverse=True)
        return [doc for doc, _ in ranked[:k]]
class SemanticRouting(AdvancedRetriever):
    """Route queries to specialized retrievers based on query type"""

    def __init__(self, link: str):
        super().__init__(link)
        # Classifies an incoming question into one of four categories;
        # the label drives which retrieval strategy handles the query.
        self.query_classifier_prompt = ChatPromptTemplate.from_template("""
Classify this query into one of these categories:
1. FACTUAL - Asking for specific facts, definitions, or data
2. CONCEPTUAL - Asking for explanations, processes, or how things work
3. COMPARATIVE - Comparing different concepts, methods, or approaches
4. ANALYTICAL - Requiring analysis, reasoning, or synthesis
Query: {question}
Respond with only the category name (FACTUAL, CONCEPTUAL, COMPARATIVE, or ANALYTICAL):
""")

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Classify *question* and dispatch to the matching strategy.

        The label is upper-cased so a lowercase LLM reply still routes
        correctly; anything unrecognized falls through to ANALYTICAL.
        """
        query_type = (
            self.query_classifier_prompt
            | self.llm
            | StrOutputParser()
        ).invoke({"question": question}).strip().upper()
        # Route to appropriate retrieval strategy
        if query_type == "FACTUAL":
            return self._factual_retrieval(question, k)
        elif query_type == "CONCEPTUAL":
            return self._conceptual_retrieval(question, k)
        elif query_type == "COMPARATIVE":
            return self._comparative_retrieval(question, k)
        else:  # ANALYTICAL (also the fallback for unexpected labels)
            return self._analytical_retrieval(question, k)

    def _factual_retrieval(self, question: str, k: int) -> List[Document]:
        """Precise retrieval for factual queries (small chunks)."""
        return self.small_vectorstore.similarity_search(question, k=k)

    def _conceptual_retrieval(self, question: str, k: int) -> List[Document]:
        """Broader context for conceptual queries (large chunks)."""
        return self.large_vectorstore.similarity_search(question, k=k)

    def _comparative_retrieval(self, question: str, k: int) -> List[Document]:
        """Multi-aspect retrieval for comparative queries.

        Extracts the concepts being compared and searches the medium
        store once per concept, then merges and deduplicates.
        """
        comparison_prompt = ChatPromptTemplate.from_template("""
Extract the main concepts being compared in this question: {question}
List them separated by commas:
""")
        concepts = (
            comparison_prompt
            | self.llm
            | StrOutputParser()
        ).invoke({"question": question})
        # max(1, ...) guards against a k=0 search when k < 2
        # (the original plain k // 2 produced zero-result searches).
        per_concept = max(1, k // 2)
        all_docs = []
        for concept in concepts.split(','):
            concept = concept.strip()
            if not concept:
                # Skip blanks produced by trailing or doubled commas
                continue
            all_docs.extend(
                self.medium_vectorstore.similarity_search(concept, k=per_concept)
            )
        return self._deduplicate_documents(all_docs)[:k]

    def _analytical_retrieval(self, question: str, k: int) -> List[Document]:
        """Comprehensive retrieval for analytical queries.

        NOTE: delegates to MultiQueryRetrieval, which re-fetches and
        re-indexes the same link — expensive per call.
        """
        multi_query = MultiQueryRetrieval(self.link)
        return multi_query.retrieve(question, k)

    def _deduplicate_documents(self, docs: List[Document]) -> List[Document]:
        """Remove duplicates by exact page content, preserving order."""
        unique_docs = []
        seen_content = set()
        for doc in docs:
            if doc.page_content not in seen_content:
                unique_docs.append(doc)
                seen_content.add(doc.page_content)
        return unique_docs
# Integration functions for your main app
def _generate_answer(docs: List[Document], question: str) -> str:
    """Produce the final answer from retrieved *docs*.

    Shared by all integration functions below — the original file
    duplicated this context-stuffing chain five times verbatim.
    """
    template = """Answer the following question based on this context:
{context}
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOpenAI(temperature=0)
    final_chain = prompt | llm | StrOutputParser()
    context = "\n\n".join(doc.page_content for doc in docs)
    return final_chain.invoke({"context": context, "question": question})


def get_answer_using_multi_query(link: str, question: str) -> str:
    """Multi-Query Retrieval implementation"""
    docs = MultiQueryRetrieval(link).retrieve(question)
    return _generate_answer(docs, question)


def get_answer_using_parent_child(link: str, question: str) -> str:
    """Parent-Child Retrieval implementation"""
    docs = ParentChildRetrieval(link).retrieve(question)
    return _generate_answer(docs, question)


def get_answer_using_contextual_compression(link: str, question: str) -> str:
    """Contextual Compression implementation"""
    docs = ContextualCompression(link).retrieve(question)
    return _generate_answer(docs, question)


def get_answer_using_cross_encoder(link: str, question: str) -> str:
    """Cross-Encoder Reranking implementation"""
    docs = CrossEncoderReranking(link).retrieve(question)
    return _generate_answer(docs, question)


def get_answer_using_semantic_routing(link: str, question: str) -> str:
    """Semantic Routing implementation"""
    docs = SemanticRouting(link).retrieve(question)
    return _generate_answer(docs, question)
# Example usage
# if __name__ == "__main__":
# link = "https://lilianweng.github.io/posts/2023-06-23-agent/"
# question = "What is task decomposition for LLM agents?"
# # Test all advanced retrieval techniques
# techniques = [
# ("Multi-Query Retrieval", get_answer_using_multi_query),
# ("Parent-Child Retrieval", get_answer_using_parent_child),
# ("Contextual Compression", get_answer_using_contextual_compression),
# ("Cross-Encoder Reranking", get_answer_using_cross_encoder),
# ("Semantic Routing", get_answer_using_semantic_routing),
# ]
# for name, func in techniques:
# print(f"\n=== {name} ===")
# try:
# answer = func(link, question)
# print(answer)
# except Exception as e:
# print(f"Error: {e}")
# print("-" * 50)