""" Advanced RAG Engine for easyResearch. Implements production-grade retrieval with: - Hybrid Search: Dense vectors + BM25 sparse retrieval - Cross-Encoder Re-ranking optimized for RTX 3050 (4GB VRAM) - Reciprocal Rank Fusion for score combination - Full observability integration """ from __future__ import annotations import gc import os import re from dataclasses import dataclass from typing import Any, Literal import torch from dotenv import load_dotenv from langchain_core.documents import Document from langchain_core.messages import AIMessage, HumanMessage from langchain_core.prompts import ChatPromptTemplate from langchain_groq import ChatGroq from rank_bm25 import BM25Okapi from sentence_transformers import CrossEncoder from core.embedder import get_vector_store, embedding_model from core.observability import RAGTracer, rag_logger, log_execution_time, log_gpu_memory from config import ( DEVICE, MAX_HISTORY_MESSAGES, RERANKER_MODEL, HYBRID_WEIGHT_RERANK, HYBRID_WEIGHT_BM25, MIN_SCORE_THRESHOLD, LLM_MODEL_GROQ, LLM_TEMPERATURE, LLM_MAX_TOKENS, ) load_dotenv() # ────────────────────────────────────────────────────────────────────────────── # Configuration # ────────────────────────────────────────────────────────────────────────────── @dataclass class RetrievalConfig: """Configuration for hybrid retrieval.""" dense_k: int = 20 # Initial dense retrieval count sparse_k: int = 20 # BM25 retrieval count rerank_top_k: int = 10 # Final docs after re-ranking # Score fusion weights dense_weight: float = 0.5 sparse_weight: float = 0.2 rerank_weight: float = 0.3 # Thresholds min_score_threshold: float = MIN_SCORE_THRESHOLD # Re-ranker settings (optimized for 4GB VRAM) reranker_batch_size: int = 8 # Small batch for memory efficiency use_fp16: bool = True # Half precision for memory savings # Global re-ranker (lazy loaded) _reranker: CrossEncoder | None = None def get_reranker() -> CrossEncoder: """Lazy load cross-encoder with memory optimization.""" global _reranker if _reranker is None: rag_logger.info(f"Loading cross-encoder: {RERANKER_MODEL} on {DEVICE}") # Configure for low VRAM usage _reranker = CrossEncoder( RERANKER_MODEL, device=DEVICE, max_length=512, # Limit input length for memory ) # Use FP16 if on CUDA for memory efficiency if DEVICE == "cuda" and torch.cuda.is_available(): _reranker.model.half() rag_logger.info("Cross-encoder using FP16 for memory efficiency") return _reranker # ────────────────────────────────────────────────────────────────────────────── # Prompts # ────────────────────────────────────────────────────────────────────────────── CONTEXTUALIZE_PROMPT = ChatPromptTemplate.from_messages([ ( "system", "You are a question reformulation expert. Reformulate the user's latest " "question into a standalone question that can be understood WITHOUT the " "chat history.\n\n" "RULES:\n" "1. Replace pronouns (it, this, they, he, she…) with actual terms from history.\n" "2. Incorporate previous topic for follow-ups.\n" "3. If already self-contained, return AS-IS.\n" "4. NEVER answer the question.\n" "5. Keep the same language.\n" "6. Be concise but complete." ), ("placeholder", "{chat_history}"), ("human", "Reformulate this question: {input}"), ]) RAG_PROMPT_WITH_HISTORY = ChatPromptTemplate.from_messages([ ( "system", "You are a helpful AI research assistant with access to a document database.\n" "Answer the user's question based ONLY on the provided context below.\n\n" "GUIDELINES:\n" "1. If the answer is not in the context, say you don't know.\n" "2. Use the conversation summary for flow.\n" "3. Be concise. Use bullet points when appropriate.\n" "4. Cite source document names when possible.\n" "5. Answer in the SAME language as the question.\n\n" "CONVERSATION SUMMARY:\n{conversation_summary}", ), ("human", "RETRIEVED DOCUMENTS:\n{context}\n\nQUESTION:\n{question}"), ]) RAG_PROMPT_NO_HISTORY = ChatPromptTemplate.from_messages([ ( "system", "You are a helpful AI research assistant with access to a document database.\n" "Answer the user's question based ONLY on the provided context below.\n\n" "GUIDELINES:\n" "1. If the answer is not in the context, say you don't know.\n" "2. Be concise. Use bullet points when appropriate.\n" "3. Cite source document names when possible.\n" "4. Answer in the SAME language as the question.", ), ("human", "RETRIEVED DOCUMENTS:\n{context}\n\nQUESTION:\n{question}"), ]) # ────────────────────────────────────────────────────────────────────────────── # Hybrid Search Implementation # ────────────────────────────────────────────────────────────────────────────── def _tokenize(text: str) -> list[str]: """Simple tokenizer for BM25.""" return re.findall(r"\w+", text.lower()) @log_execution_time def bm25_search( documents: list[Document], query: str, top_k: int = 20, ) -> list[tuple[Document, float]]: """ Perform BM25 sparse retrieval on documents. Returns documents with normalized BM25 scores. """ if not documents: return [] corpus = [_tokenize(doc.page_content) for doc in documents] bm25 = BM25Okapi(corpus) query_tokens = _tokenize(query) scores = bm25.get_scores(query_tokens) # Normalize scores to [0, 1] max_score = float(max(scores)) if max(scores) > 0 else 1.0 normalized_scores = [float(s) / max_score for s in scores] # Pair documents with scores and sort doc_scores = list(zip(documents, normalized_scores)) doc_scores.sort(key=lambda x: x[1], reverse=True) return doc_scores[:top_k] @log_execution_time def dense_search( collection_name: str, query: str, k: int = 20, filter_dict: dict | None = None, ) -> list[tuple[Document, float]]: """ Perform dense vector similarity search. Returns documents with similarity scores. """ db = get_vector_store(collection_name) if filter_dict: from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchText conditions = [] if "format" in filter_dict: conditions.append( FieldCondition(key="metadata.format", match=MatchValue(value=filter_dict["format"])) ) if "source" in filter_dict: conditions.append( FieldCondition(key="metadata.source", match=MatchText(text=filter_dict["source"])) ) qdrant_filter = Filter(must=conditions) if conditions else None results = db.similarity_search_with_score( query, k=k, filter=qdrant_filter ) else: results = db.similarity_search_with_score(query, k=k) # Convert to (doc, score) tuples - note: Qdrant returns distance, lower is better # We invert to higher is better for consistency doc_scores = [(doc, 1.0 - min(score, 1.0)) for doc, score in results] return doc_scores @log_gpu_memory @log_execution_time def rerank_documents( query: str, documents: list[Document], config: RetrievalConfig, ) -> list[tuple[Document, float]]: """ Re-rank documents using cross-encoder. Optimized for RTX 3050 4GB VRAM with batch processing. """ if not documents: return [] reranker = get_reranker() # Create query-document pairs pairs = [[query, doc.page_content[:512]] for doc in documents] # Truncate for memory # Batch prediction for memory efficiency all_scores = [] batch_size = config.reranker_batch_size for i in range(0, len(pairs), batch_size): batch = pairs[i:i + batch_size] with torch.no_grad(): if DEVICE == "cuda": with torch.cuda.amp.autocast(): # Mixed precision scores = reranker.predict(batch, show_progress_bar=False) else: scores = reranker.predict(batch, show_progress_bar=False) all_scores.extend(scores.tolist() if hasattr(scores, 'tolist') else list(scores)) # Clear GPU cache after each batch if DEVICE == "cuda": torch.cuda.empty_cache() gc.collect() # Normalize scores to [0, 1] using sigmoid-like transformation min_score, max_score = min(all_scores), max(all_scores) score_range = max_score - min_score if max_score != min_score else 1.0 normalized_scores = [(s - min_score) / score_range for s in all_scores] # Pair and sort doc_scores = list(zip(documents, normalized_scores)) doc_scores.sort(key=lambda x: x[1], reverse=True) return doc_scores def reciprocal_rank_fusion( rankings: list[list[tuple[Document, float]]], k: int = 60, ) -> list[tuple[Document, float]]: """ Combine multiple rankings using Reciprocal Rank Fusion. RRF score = Σ 1 / (k + rank_i) for each ranking """ doc_scores: dict[int, float] = {} doc_map: dict[int, Document] = {} for ranking in rankings: for rank, (doc, _) in enumerate(ranking, start=1): # Use content hash as doc identifier doc_id = hash(doc.page_content[:200]) if doc_id not in doc_scores: doc_scores[doc_id] = 0.0 doc_map[doc_id] = doc doc_scores[doc_id] += 1.0 / (k + rank) # Sort by RRF score sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True) return [(doc_map[doc_id], score) for doc_id, score in sorted_docs] @log_execution_time def hybrid_search( collection_name: str, query: str, config: RetrievalConfig | None = None, filter_dict: dict | None = None, ) -> list[tuple[Document, float]]: """ Perform hybrid search combining dense and sparse retrieval. Pipeline: 1. Dense vector search (semantic similarity) 2. BM25 sparse search (keyword matching) 3. Reciprocal Rank Fusion 4. Cross-encoder re-ranking """ config = config or RetrievalConfig() # Step 1: Dense retrieval rag_logger.debug(f"Dense search: k={config.dense_k}") dense_results = dense_search(collection_name, query, k=config.dense_k, filter_dict=filter_dict) dense_docs = [doc for doc, _ in dense_results] if not dense_docs: rag_logger.warning("No documents found in dense search") return [] # Step 2: BM25 on retrieved documents (not full corpus for efficiency) rag_logger.debug(f"BM25 search on {len(dense_docs)} docs") bm25_results = bm25_search(dense_docs, query, top_k=config.sparse_k) # Step 3: Reciprocal Rank Fusion rag_logger.debug("Combining with RRF") fused_results = reciprocal_rank_fusion([dense_results, bm25_results]) # Take top candidates for re-ranking (limit for memory) candidates = [doc for doc, _ in fused_results[:min(len(fused_results), 30)]] # Step 4: Cross-encoder re-ranking rag_logger.debug(f"Re-ranking {len(candidates)} candidates") reranked = rerank_documents(query, candidates, config) # Apply score threshold and limit final_results = [ (doc, score) for doc, score in reranked if score >= config.min_score_threshold ][:config.rerank_top_k] rag_logger.info(f"Hybrid search returned {len(final_results)} docs") return final_results # ────────────────────────────────────────────────────────────────────────────── # Query Processing # ────────────────────────────────────────────────────────────────────────────── def _needs_contextualization(question: str) -> bool: """Check if question needs context from chat history.""" patterns = [ r"\b(it|its|this|that|these|those|they|them|their|he|she|him|her)\b", r"\b(the same|above|previous|mentioned|said|such)\b", r"\b(what about|how about|and the|also the|another)\b", # Vietnamese r"\b(nó|này|đó|ở trên|như vậy|còn|thế thì|vậy thì)\b", r"\b(cái này|cái đó|điều đó|vấn đề này|chúng)\b", ] q_lower = question.lower() return any(re.search(p, q_lower) for p in patterns) def _summarize_conversation(chat_history: list[dict], max_messages: int = 5) -> str: """Create a conversation summary for context.""" if not chat_history or len(chat_history) <= 1: return "This is the beginning of the conversation." recent = chat_history[-max_messages:] parts = [] for msg in recent: role = "User" if msg["role"] == "user" else "Assistant" content = msg["content"][:200] + "…" if len(msg["content"]) > 200 else msg["content"] parts.append(f"- {role}: {content}") return "\n".join(parts) def _get_llm( api_key: str | None = None, ) -> ChatGroq: """Get configured LLM instance with error handling.""" key = api_key or os.getenv("GROQ_API_KEY") if not key: raise ValueError("Groq API key required") from pydantic import SecretStr return ChatGroq( model=LLM_MODEL_GROQ, temperature=LLM_TEMPERATURE, max_tokens=LLM_MAX_TOKENS, api_key=SecretStr(key), ) def contextualize_query( question: str, chat_history: list[dict], llm: ChatGroq, ) -> str: """Reformulate question to be standalone using chat history.""" if not chat_history or not _needs_contextualization(question): return question try: # Convert to LangChain message format recent = chat_history[-MAX_HISTORY_MESSAGES:-1] if len(chat_history) > MAX_HISTORY_MESSAGES else chat_history[:-1] hist_lc = [] for msg in recent: if msg["role"] == "user": hist_lc.append(HumanMessage(content=msg["content"])) else: hist_lc.append(AIMessage(content=msg["content"])) chain = CONTEXTUALIZE_PROMPT | llm result = chain.invoke({ "chat_history": hist_lc, "input": question, }) content = result.content standalone = content.strip() if isinstance(content, str) else str(content) rag_logger.debug(f"Contextualized: '{question[:50]}' -> '{standalone[:50]}'") return standalone except Exception as e: rag_logger.warning(f"Contextualization failed: {e}") return question # ────────────────────────────────────────────────────────────────────────────── # Main RAG Query Function # ────────────────────────────────────────────────────────────────────────────── def query_rag( question: str, collection_name: str, chat_history: list[dict] | None = None, k_target: int = 10, user_api_key: str | None = None, format_filter: str | None = None, source_filter: str | None = None, retrieval_config: RetrievalConfig | None = None, **kwargs, ) -> dict[str, Any]: """ Execute RAG query with hybrid search and full observability. Args: question: User's question collection_name: Qdrant collection/workspace name chat_history: Previous conversation messages k_target: Target number of documents to retrieve user_api_key: Optional user-provided API key format_filter: Filter by document format (e.g., "pdf") source_filter: Filter by source filename retrieval_config: Custom retrieval configuration Returns: Dictionary with answer, sources, debug info, and trace data """ chat_history = chat_history or [] config = retrieval_config or RetrievalConfig(rerank_top_k=k_target) # Build filter dict filter_dict = {} if format_filter: filter_dict["format"] = format_filter if source_filter: filter_dict["source"] = source_filter with RAGTracer(workspace=collection_name, query=question) as tracer: try: # Initialize LLM llm = _get_llm(user_api_key) model_name = LLM_MODEL_GROQ # Contextualize query if needed with tracer.stage("contextualization"): standalone_question = contextualize_query( question, chat_history, llm ) tracer.set_standalone_query(standalone_question) # Hybrid retrieval with tracer.stage("retrieval"): results = hybrid_search( collection_name, standalone_question, config=config, filter_dict=filter_dict if filter_dict else None, ) tracer.log_retrieval([doc for doc, _ in results]) if not results: return { "answer": "I could not find relevant information in the documents.", "sources": [], "standalone_question": standalone_question, "pipeline_info": {"status": "no_docs_found"}, } # Log re-ranking scores tracer.log_reranking( [doc for doc, _ in results], [score for _, score in results], ) # Build context final_docs = [doc for doc, _ in results] context_text = "\n\n---\n\n".join([ f"[Source: {d.metadata.get('source', 'Unknown')}]\n" f"{d.metadata.get('parent_content', d.page_content)}" for d in final_docs ]) # Generate response with tracer.stage("generation"): has_history = len(chat_history) > 1 if has_history: summary = _summarize_conversation(chat_history) messages = RAG_PROMPT_WITH_HISTORY.format_messages( context=context_text, question=question, conversation_summary=summary, ) else: messages = RAG_PROMPT_NO_HISTORY.format_messages( context=context_text, question=question, ) response = llm.invoke(messages) content = response.content answer_text = content.strip() if isinstance(content, str) else str(content) # Collect sources source_names = list({d.metadata.get("source", "Unknown") for d in final_docs}) # Log generation details tracer.log_generation( context_length=len(context_text), llm_provider="groq", llm_model=model_name, response_length=len(answer_text), sources=source_names, ) return { "answer": answer_text, "sources": source_names, "standalone_question": standalone_question, "raw_docs": [ { "source": d.metadata.get("source", "Unknown"), "score": results[i][1] if i < len(results) else 0.0, "preview": d.page_content[:200], } for i, d in enumerate(final_docs[:5]) ], "pipeline_info": { "trace_id": tracer.trace.trace_id, "retrieval_count": len(results), "model": model_name, "contextualized": tracer.trace.contextualized, }, } except ValueError as e: # API key or config errors return { "answer": f"❌ Configuration error: {e}", "sources": [], "pipeline_info": {"error": str(e)}, } except Exception as e: rag_logger.exception("RAG query failed") return { "answer": f"❌ Query failed: {e}", "sources": [], "pipeline_info": {"error": str(e)}, } # Alias for backward compatibility query_rag_system = query_rag