Spaces:
Running
Running
| """ | |
| Advanced RAG techniques for improved retrieval and generation | |
| Includes: Query Expansion, Reranking, Contextual Compression, Hybrid Search | |
| """ | |
| from typing import List, Dict, Optional, Tuple | |
| import numpy as np | |
| from dataclasses import dataclass | |
| import re | |
@dataclass
class RetrievedDocument:
    """A single document hit returned from the vector database.

    Bug fix: the ``@dataclass`` decorator was missing, so keyword
    construction (``RetrievedDocument(id=..., text=..., ...)``, used
    throughout this module) raised ``TypeError``.
    """
    id: str          # unique identifier of the point in the vector store
    text: str        # raw document text ("" when the payload had none)
    confidence: float  # retrieval/rerank score; higher means more relevant
    metadata: Dict   # arbitrary payload stored alongside the vector
| class AdvancedRAG: | |
| """Advanced RAG system with modern techniques""" | |
| def __init__(self, embedding_service, qdrant_service): | |
| self.embedding_service = embedding_service | |
| self.qdrant_service = qdrant_service | |
| def expand_query(self, query: str) -> List[str]: | |
| """ | |
| Expand query with related terms and variations | |
| Simple rule-based expansion for Vietnamese queries | |
| """ | |
| queries = [query] | |
| # Add query variations | |
| # Remove question words for alternative search | |
| question_words = ['ai', 'gì', 'nào', 'đâu', 'khi nào', 'như thế nào', | |
| 'tại sao', 'có', 'là', 'được', 'không'] | |
| query_lower = query.lower() | |
| for qw in question_words: | |
| if qw in query_lower: | |
| variant = query_lower.replace(qw, '').strip() | |
| if variant and variant != query_lower: | |
| queries.append(variant) | |
| # Extract key nouns/phrases (simple approach) | |
| words = query.split() | |
| if len(words) > 3: | |
| # Take important words (skip first question word) | |
| key_phrases = ' '.join(words[1:]) if words[0].lower() in question_words else ' '.join(words[:3]) | |
| if key_phrases not in queries: | |
| queries.append(key_phrases) | |
| return queries[:3] # Return top 3 variations | |
| def multi_query_retrieval( | |
| self, | |
| query: str, | |
| top_k: int = 5, | |
| score_threshold: float = 0.5 | |
| ) -> List[RetrievedDocument]: | |
| """ | |
| Retrieve documents using multiple query variations | |
| Combines results from all query variations | |
| """ | |
| expanded_queries = self.expand_query(query) | |
| all_results = {} # Use dict to deduplicate by doc_id | |
| for q in expanded_queries: | |
| # Generate embedding for each query variant | |
| query_embedding = self.embedding_service.encode_text(q) | |
| # Search in Qdrant | |
| results = self.qdrant_service.search( | |
| query_embedding=query_embedding, | |
| limit=top_k, | |
| score_threshold=score_threshold | |
| ) | |
| # Add to results (keep highest score for duplicates) | |
| for result in results: | |
| doc_id = result["id"] | |
| if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence: | |
| all_results[doc_id] = RetrievedDocument( | |
| id=doc_id, | |
| text=result["metadata"].get("text", ""), | |
| confidence=result["confidence"], | |
| metadata=result["metadata"] | |
| ) | |
| # Sort by confidence and return top_k | |
| sorted_results = sorted(all_results.values(), key=lambda x: x.confidence, reverse=True) | |
| return sorted_results[:top_k] | |
| def rerank_documents( | |
| self, | |
| query: str, | |
| documents: List[RetrievedDocument], | |
| use_cross_encoder: bool = False | |
| ) -> List[RetrievedDocument]: | |
| """ | |
| Rerank documents based on semantic similarity | |
| Simple reranking using embedding similarity (can be upgraded to cross-encoder) | |
| """ | |
| if not documents: | |
| return documents | |
| # Simple reranking: recalculate similarity with original query | |
| query_embedding = self.embedding_service.encode_text(query) | |
| reranked = [] | |
| for doc in documents: | |
| # Get document embedding | |
| doc_embedding = self.embedding_service.encode_text(doc.text) | |
| # Calculate cosine similarity | |
| similarity = np.dot(query_embedding.flatten(), doc_embedding.flatten()) | |
| # Combine with original confidence (weighted average) | |
| new_score = 0.6 * similarity + 0.4 * doc.confidence | |
| reranked.append(RetrievedDocument( | |
| id=doc.id, | |
| text=doc.text, | |
| confidence=float(new_score), | |
| metadata=doc.metadata | |
| )) | |
| # Sort by new score | |
| reranked.sort(key=lambda x: x.confidence, reverse=True) | |
| return reranked | |
| def compress_context( | |
| self, | |
| query: str, | |
| documents: List[RetrievedDocument], | |
| max_tokens: int = 500 | |
| ) -> List[RetrievedDocument]: | |
| """ | |
| Compress context to most relevant parts | |
| Remove redundant information and keep only relevant sentences | |
| """ | |
| compressed_docs = [] | |
| for doc in documents: | |
| # Split into sentences | |
| sentences = self._split_sentences(doc.text) | |
| # Score each sentence based on relevance to query | |
| scored_sentences = [] | |
| query_words = set(query.lower().split()) | |
| for sent in sentences: | |
| sent_words = set(sent.lower().split()) | |
| # Simple relevance: word overlap | |
| overlap = len(query_words & sent_words) | |
| if overlap > 0: | |
| scored_sentences.append((sent, overlap)) | |
| # Sort by relevance and take top sentences | |
| scored_sentences.sort(key=lambda x: x[1], reverse=True) | |
| # Reconstruct compressed text (up to max_tokens) | |
| compressed_text = "" | |
| word_count = 0 | |
| for sent, score in scored_sentences: | |
| sent_words = len(sent.split()) | |
| if word_count + sent_words <= max_tokens: | |
| compressed_text += sent + " " | |
| word_count += sent_words | |
| else: | |
| break | |
| # If nothing selected, take original first part | |
| if not compressed_text.strip(): | |
| compressed_text = doc.text[:max_tokens * 5] # Rough estimate | |
| compressed_docs.append(RetrievedDocument( | |
| id=doc.id, | |
| text=compressed_text.strip(), | |
| confidence=doc.confidence, | |
| metadata=doc.metadata | |
| )) | |
| return compressed_docs | |
| def _split_sentences(self, text: str) -> List[str]: | |
| """Split text into sentences (Vietnamese-aware)""" | |
| # Simple sentence splitter | |
| sentences = re.split(r'[.!?]+', text) | |
| return [s.strip() for s in sentences if s.strip()] | |
| def hybrid_rag_pipeline( | |
| self, | |
| query: str, | |
| top_k: int = 5, | |
| score_threshold: float = 0.5, | |
| use_reranking: bool = True, | |
| use_compression: bool = True, | |
| max_context_tokens: int = 500 | |
| ) -> Tuple[List[RetrievedDocument], Dict]: | |
| """ | |
| Complete advanced RAG pipeline | |
| 1. Multi-query retrieval | |
| 2. Reranking | |
| 3. Contextual compression | |
| """ | |
| stats = { | |
| "original_query": query, | |
| "expanded_queries": [], | |
| "initial_results": 0, | |
| "after_rerank": 0, | |
| "after_compression": 0 | |
| } | |
| # Step 1: Multi-query retrieval | |
| expanded_queries = self.expand_query(query) | |
| stats["expanded_queries"] = expanded_queries | |
| documents = self.multi_query_retrieval( | |
| query=query, | |
| top_k=top_k * 2, # Get more candidates for reranking | |
| score_threshold=score_threshold | |
| ) | |
| stats["initial_results"] = len(documents) | |
| # Step 2: Reranking (optional) | |
| if use_reranking and documents: | |
| documents = self.rerank_documents(query, documents) | |
| documents = documents[:top_k] # Keep top_k after reranking | |
| stats["after_rerank"] = len(documents) | |
| # Step 3: Contextual compression (optional) | |
| if use_compression and documents: | |
| documents = self.compress_context( | |
| query=query, | |
| documents=documents, | |
| max_tokens=max_context_tokens | |
| ) | |
| stats["after_compression"] = len(documents) | |
| return documents, stats | |
| def format_context_for_llm( | |
| self, | |
| documents: List[RetrievedDocument], | |
| include_metadata: bool = True | |
| ) -> str: | |
| """ | |
| Format retrieved documents into context string for LLM | |
| Uses better structure for improved LLM understanding | |
| """ | |
| if not documents: | |
| return "" | |
| context_parts = ["RELEVANT CONTEXT:\n"] | |
| for i, doc in enumerate(documents, 1): | |
| context_parts.append(f"\n--- Document {i} (Relevance: {doc.confidence:.2%}) ---") | |
| context_parts.append(doc.text) | |
| if include_metadata and doc.metadata: | |
| # Add useful metadata | |
| meta_str = [] | |
| for key, value in doc.metadata.items(): | |
| if key not in ['text', 'texts'] and value: | |
| meta_str.append(f"{key}: {value}") | |
| if meta_str: | |
| context_parts.append(f"[Metadata: {', '.join(meta_str)}]") | |
| context_parts.append("\n--- End of Context ---\n") | |
| return "\n".join(context_parts) | |
| def build_rag_prompt( | |
| self, | |
| query: str, | |
| context: str, | |
| system_message: str = "You are a helpful AI assistant." | |
| ) -> str: | |
| """ | |
| Build optimized RAG prompt for LLM | |
| Uses best practices for prompt engineering | |
| """ | |
| prompt_template = f"""{system_message} | |
| {context} | |
| INSTRUCTIONS: | |
| 1. Answer the user's question using ONLY the information provided in the context above | |
| 2. If the context doesn't contain relevant information, say "Tôi không tìm thấy thông tin liên quan trong dữ liệu." | |
| 3. Cite relevant parts of the context when answering | |
| 4. Be concise and accurate | |
| 5. Answer in Vietnamese if the question is in Vietnamese | |
| USER QUESTION: {query} | |
| YOUR ANSWER:""" | |
| return prompt_template | |