"""End-to-end RAG engine. Orchestrates retrieval → context formatting → LLM generation. This is the main entry point for answering questions about research papers. """ import logging from dataclasses import dataclass, field from src.generation.llm_backend_base import GenerationConfig, GenerationResult, LLMBackend from src.retrieval.pipeline import RetrievalPipeline, RetrievalResult logger = logging.getLogger(__name__) @dataclass class RAGResponse: """Full RAG response with answer, sources, and metadata.""" answer: str sources: list[dict] model: str usage: dict = field(default_factory=dict) def format_context(results: list[RetrievalResult]) -> str: """Format retrieval results into a context block for the LLM prompt. Each chunk is wrapped with its paper metadata so the LLM can cite sources. """ if not results: return "No relevant papers found." blocks = [] for i, r in enumerate(results, 1): header = f"[{i}] \"{r.title}\" ({r.venue or 'unknown venue'}, {r.year})" blocks.append(f"{header}\n{r.chunk_text}") return "\n\n---\n\n".join(blocks) def build_prompt(query: str, context: str) -> str: """Build the user prompt with context and question.""" return ( f"Below are excerpts from relevant research papers.\n\n" f"{context}\n\n" f"---\n\n" f"Based on the above excerpts, answer the following question. " f"Cite papers by their number (e.g., [1], [2]) when referencing specific findings.\n\n" f"Question: {query}" ) class RAGEngine: """Orchestrates retrieval and generation for end-to-end RAG. Usage: engine = RAGEngine(pipeline, llm_backend) response = engine.query("What is LoRA?") print(response.answer) """ def __init__( self, retrieval_pipeline: RetrievalPipeline, llm_backend: LLMBackend, config: GenerationConfig | None = None, ): self.pipeline = retrieval_pipeline self.llm = llm_backend self.config = config or GenerationConfig() def query( self, question: str, top_k: int = 5, source_top_k: int = 20, where: dict | None = None, ) -> RAGResponse: """Answer a question using retrieval-augmented generation. Args: question: The user's natural-language question. top_k: Number of chunks used as LLM generation context. source_top_k: Number of chunks to retrieve for the source list (returns more papers than used for generation). where: Optional metadata filter for retrieval (e.g., year, venue). Returns: RAGResponse with the answer, source papers, and metadata. """ logger.info("RAG query: %r (top_k=%d, source_top_k=%d)", question, top_k, source_top_k) # Step 1: Retrieve relevant chunks (more than needed for generation) results = self.pipeline.search(query=question, top_k=source_top_k, where=where) logger.info("Retrieved %d chunks", len(results)) # Step 2: Format context from top_k chunks only (for LLM prompt) context_results = results[:top_k] context = format_context(context_results) # Track which papers were used for generation context context_paper_ids = {r.paper_id for r in context_results} # Step 3: Build prompt and generate prompt = build_prompt(question, context) try: gen_result: GenerationResult = self.llm.generate( prompt=prompt, system_prompt=self.config.system_prompt, max_tokens=self.config.max_tokens, temperature=self.config.temperature, ) except Exception as exc: logger.error("LLM generation failed: %s", exc) gen_result = GenerationResult( answer=( "⚠ The language model is temporarily unavailable. " "Here are the most relevant papers found for your query — " "browse the sources below for details." 
), model="none (generation failed)", usage={}, ) # Step 4: Build source list from ALL results (deduplicated by paper_id) seen_papers: set[str] = set() sources = [] for r in results: if r.paper_id not in seen_papers: seen_papers.add(r.paper_id) sources.append({ "paper_id": r.paper_id, "title": r.title, "year": r.year, "venue": r.venue, "chunk_type": r.chunk_type, "used_in_answer": r.paper_id in context_paper_ids, }) return RAGResponse( answer=gen_result.answer, sources=sources, model=gen_result.model, usage=gen_result.usage, )