Embeddings_Chat / advanced_rag.py
Haiss123's picture
Upload 20 files
6b98b09 verified
raw
history blame
10.2 kB
"""
Advanced RAG techniques for improved retrieval and generation
Includes: Query Expansion, Reranking, Contextual Compression, Hybrid Search
"""
from typing import List, Dict, Optional, Tuple
import numpy as np
from dataclasses import dataclass
import re
@dataclass
class RetrievedDocument:
    """A single document hit returned from the vector database search."""
    id: str  # unique point/document id as returned by the vector store
    text: str  # document body (pulled from the payload's "text" field)
    confidence: float  # retrieval/rerank score; higher means more relevant
    metadata: Dict  # full payload dict attached to the stored point
class AdvancedRAG:
"""Advanced RAG system with modern techniques"""
def __init__(self, embedding_service, qdrant_service):
self.embedding_service = embedding_service
self.qdrant_service = qdrant_service
def expand_query(self, query: str) -> List[str]:
"""
Expand query with related terms and variations
Simple rule-based expansion for Vietnamese queries
"""
queries = [query]
# Add query variations
# Remove question words for alternative search
question_words = ['ai', 'gì', 'nào', 'đâu', 'khi nào', 'như thế nào',
'tại sao', 'có', 'là', 'được', 'không']
query_lower = query.lower()
for qw in question_words:
if qw in query_lower:
variant = query_lower.replace(qw, '').strip()
if variant and variant != query_lower:
queries.append(variant)
# Extract key nouns/phrases (simple approach)
words = query.split()
if len(words) > 3:
# Take important words (skip first question word)
key_phrases = ' '.join(words[1:]) if words[0].lower() in question_words else ' '.join(words[:3])
if key_phrases not in queries:
queries.append(key_phrases)
return queries[:3] # Return top 3 variations
def multi_query_retrieval(
self,
query: str,
top_k: int = 5,
score_threshold: float = 0.5
) -> List[RetrievedDocument]:
"""
Retrieve documents using multiple query variations
Combines results from all query variations
"""
expanded_queries = self.expand_query(query)
all_results = {} # Use dict to deduplicate by doc_id
for q in expanded_queries:
# Generate embedding for each query variant
query_embedding = self.embedding_service.encode_text(q)
# Search in Qdrant
results = self.qdrant_service.search(
query_embedding=query_embedding,
limit=top_k,
score_threshold=score_threshold
)
# Add to results (keep highest score for duplicates)
for result in results:
doc_id = result["id"]
if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
all_results[doc_id] = RetrievedDocument(
id=doc_id,
text=result["metadata"].get("text", ""),
confidence=result["confidence"],
metadata=result["metadata"]
)
# Sort by confidence and return top_k
sorted_results = sorted(all_results.values(), key=lambda x: x.confidence, reverse=True)
return sorted_results[:top_k]
def rerank_documents(
self,
query: str,
documents: List[RetrievedDocument],
use_cross_encoder: bool = False
) -> List[RetrievedDocument]:
"""
Rerank documents based on semantic similarity
Simple reranking using embedding similarity (can be upgraded to cross-encoder)
"""
if not documents:
return documents
# Simple reranking: recalculate similarity with original query
query_embedding = self.embedding_service.encode_text(query)
reranked = []
for doc in documents:
# Get document embedding
doc_embedding = self.embedding_service.encode_text(doc.text)
# Calculate cosine similarity
similarity = np.dot(query_embedding.flatten(), doc_embedding.flatten())
# Combine with original confidence (weighted average)
new_score = 0.6 * similarity + 0.4 * doc.confidence
reranked.append(RetrievedDocument(
id=doc.id,
text=doc.text,
confidence=float(new_score),
metadata=doc.metadata
))
# Sort by new score
reranked.sort(key=lambda x: x.confidence, reverse=True)
return reranked
def compress_context(
self,
query: str,
documents: List[RetrievedDocument],
max_tokens: int = 500
) -> List[RetrievedDocument]:
"""
Compress context to most relevant parts
Remove redundant information and keep only relevant sentences
"""
compressed_docs = []
for doc in documents:
# Split into sentences
sentences = self._split_sentences(doc.text)
# Score each sentence based on relevance to query
scored_sentences = []
query_words = set(query.lower().split())
for sent in sentences:
sent_words = set(sent.lower().split())
# Simple relevance: word overlap
overlap = len(query_words & sent_words)
if overlap > 0:
scored_sentences.append((sent, overlap))
# Sort by relevance and take top sentences
scored_sentences.sort(key=lambda x: x[1], reverse=True)
# Reconstruct compressed text (up to max_tokens)
compressed_text = ""
word_count = 0
for sent, score in scored_sentences:
sent_words = len(sent.split())
if word_count + sent_words <= max_tokens:
compressed_text += sent + " "
word_count += sent_words
else:
break
# If nothing selected, take original first part
if not compressed_text.strip():
compressed_text = doc.text[:max_tokens * 5] # Rough estimate
compressed_docs.append(RetrievedDocument(
id=doc.id,
text=compressed_text.strip(),
confidence=doc.confidence,
metadata=doc.metadata
))
return compressed_docs
def _split_sentences(self, text: str) -> List[str]:
"""Split text into sentences (Vietnamese-aware)"""
# Simple sentence splitter
sentences = re.split(r'[.!?]+', text)
return [s.strip() for s in sentences if s.strip()]
def hybrid_rag_pipeline(
self,
query: str,
top_k: int = 5,
score_threshold: float = 0.5,
use_reranking: bool = True,
use_compression: bool = True,
max_context_tokens: int = 500
) -> Tuple[List[RetrievedDocument], Dict]:
"""
Complete advanced RAG pipeline
1. Multi-query retrieval
2. Reranking
3. Contextual compression
"""
stats = {
"original_query": query,
"expanded_queries": [],
"initial_results": 0,
"after_rerank": 0,
"after_compression": 0
}
# Step 1: Multi-query retrieval
expanded_queries = self.expand_query(query)
stats["expanded_queries"] = expanded_queries
documents = self.multi_query_retrieval(
query=query,
top_k=top_k * 2, # Get more candidates for reranking
score_threshold=score_threshold
)
stats["initial_results"] = len(documents)
# Step 2: Reranking (optional)
if use_reranking and documents:
documents = self.rerank_documents(query, documents)
documents = documents[:top_k] # Keep top_k after reranking
stats["after_rerank"] = len(documents)
# Step 3: Contextual compression (optional)
if use_compression and documents:
documents = self.compress_context(
query=query,
documents=documents,
max_tokens=max_context_tokens
)
stats["after_compression"] = len(documents)
return documents, stats
def format_context_for_llm(
self,
documents: List[RetrievedDocument],
include_metadata: bool = True
) -> str:
"""
Format retrieved documents into context string for LLM
Uses better structure for improved LLM understanding
"""
if not documents:
return ""
context_parts = ["RELEVANT CONTEXT:\n"]
for i, doc in enumerate(documents, 1):
context_parts.append(f"\n--- Document {i} (Relevance: {doc.confidence:.2%}) ---")
context_parts.append(doc.text)
if include_metadata and doc.metadata:
# Add useful metadata
meta_str = []
for key, value in doc.metadata.items():
if key not in ['text', 'texts'] and value:
meta_str.append(f"{key}: {value}")
if meta_str:
context_parts.append(f"[Metadata: {', '.join(meta_str)}]")
context_parts.append("\n--- End of Context ---\n")
return "\n".join(context_parts)
def build_rag_prompt(
self,
query: str,
context: str,
system_message: str = "You are a helpful AI assistant."
) -> str:
"""
Build optimized RAG prompt for LLM
Uses best practices for prompt engineering
"""
prompt_template = f"""{system_message}
{context}
INSTRUCTIONS:
1. Answer the user's question using ONLY the information provided in the context above
2. If the context doesn't contain relevant information, say "Tôi không tìm thấy thông tin liên quan trong dữ liệu."
3. Cite relevant parts of the context when answering
4. Be concise and accurate
5. Answer in Vietnamese if the question is in Vietnamese
USER QUESTION: {query}
YOUR ANSWER:"""
return prompt_template