Spaces:
Sleeping
Sleeping
| """High‑level RAG pipeline orchestration.""" | |
| from __future__ import annotations | |
| import logging | |
| from typing import Dict, Any, List | |
| from .config import PipelineConfig | |
| from .retrievers import bm25, dense, hybrid | |
| from .generators.hf_generator import HFGenerator | |
| from .retrievers.base import Retriever, Context | |
| from .rerankers.cross_encoder import CrossEncoderReranker | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig(level=logging.INFO) | |
class RAGPipeline:
    """Run retrieval → optional reranking → generation in a single object.

    The pipeline is assembled from a ``PipelineConfig``: a first-stage
    retriever (``bm25``, ``dense`` or ``hybrid``), an ``HFGenerator`` and,
    when ``cfg.reranker.enable`` is truthy, a ``CrossEncoderReranker``.
    Instances are callable: ``pipeline(question)`` is an alias for
    ``pipeline.run(question)``.
    """

    def __init__(self, cfg: PipelineConfig):
        """Build the retriever, generator and (optionally) reranker from *cfg*.

        Raises:
            ValueError: if ``cfg.retriever.name`` is not a supported retriever.
        """
        self.cfg = cfg
        self.retriever: Retriever = self._build_retriever(cfg)
        self.generator = HFGenerator(
            model_name=cfg.generator.model_name, device=cfg.generator.device
        )
        # ``None`` means "reranking disabled"; every later check tests identity
        # so a reranker object that happens to be falsy is still used.
        self.reranker: CrossEncoderReranker | None = None
        if cfg.reranker.enable:
            self.reranker = CrossEncoderReranker(
                cfg.reranker.model_name,
                device=cfg.reranker.device,
                max_length=cfg.reranker.max_length,
            )

    # ---------------------------------------------------------------------
    # Public API
    # ---------------------------------------------------------------------
    def run(self, question: str) -> dict[str, Any]:
        """Answer *question* and return every intermediate artefact.

        Returns:
            A dict with the keys ``question``, ``raw_retrieval`` (first-stage
            hits as ``{"text", "id", "score"}`` dicts), ``reranked`` (same
            shape, empty list when reranking is disabled), ``contexts`` (the
            passage texts handed to the generator) and ``answer``.
        """
        logger.info("Question: %s", question)

        # 1. raw retrieval
        initial: list[Context] = self.retriever.retrieve(
            question, top_k=self._first_stage_k()
        )
        # Snapshot the first-stage hits before the reranker sees the contexts.
        raw_hits = self._hits(initial, "retrieval_score")

        # 2. reranking (if enabled)
        if self.reranker is not None:
            contexts: list[Context] = self.reranker.rerank(
                question, initial, k=self._final_k()
            )
            reranked_hits = self._hits(contexts, "cross_encoder_score")
        else:
            contexts = initial
            reranked_hits = []

        # 3. generation
        answer = self._generate(question, contexts)

        return {
            "question": question,
            "raw_retrieval": raw_hits,
            "reranked": reranked_hits,
            "contexts": [c.text for c in contexts],
            "answer": answer,
        }

    __call__ = run  # alias

    def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run the pipeline over a list of ``{'question': str, 'id': Any}`` dicts.

        A missing ``question`` is treated as the empty string and a missing
        ``id`` as ``None``.

        Returns:
            One ``{"id", "question", "answer"}`` dict per entry, in input
            order. Note that ``"answer"`` holds the *full* result dict
            produced by :meth:`run`, not just the generated text.
        """
        results: list[dict[str, Any]] = []
        for entry in queries:
            question = entry.get("question", "")
            results.append(
                {
                    "id": entry.get("id"),
                    "question": question,
                    "answer": self.run(question),
                }
            )
        return results

    # ---------------------------------------------------------------------
    # Private helpers
    # ---------------------------------------------------------------------
    def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
        """Instantiate the retriever selected by ``cfg.retriever.name``.

        Raises:
            ValueError: if the name is not ``"bm25"``, ``"dense"`` or
                ``"hybrid"``.
        """
        r = cfg.retriever
        name = r.name
        if name == "bm25":
            # NOTE(review): this branch reads ``r.bm25_idx`` while the hybrid
            # branch below reads ``r.bm25_index`` — confirm against the
            # retriever config which attribute name is actually defined.
            return bm25.BM25Retriever(
                bm25_idx=str(r.bm25_idx), doc_store=str(r.doc_store)
            )
        if name == "dense":
            return dense.DenseRetriever(
                faiss_index=str(r.faiss_index),
                doc_store=str(r.doc_store),
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        if name == "hybrid":
            return hybrid.HybridRetriever(
                str(r.bm25_index),
                str(r.faiss_index),
                doc_store=str(r.doc_store),
                alpha=r.alpha,
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        raise ValueError(f"Unsupported retriever '{name}'")

    def _first_stage_k(self) -> int:
        """Return how many passages the first-stage retriever is asked for.

        With a reranker this is ``cfg.reranker.first_stage_k``; without one
        it is ``cfg.retriever.top_k``.
        """
        if self.reranker is not None:
            return self.cfg.reranker.first_stage_k
        return self.cfg.retriever.top_k

    def _final_k(self) -> int:
        """Return how many passages the reranker is asked to keep.

        Falls back to ``cfg.retriever.top_k`` when ``cfg.reranker.final_k``
        is falsy (``None`` or ``0``).
        """
        return self.cfg.reranker.final_k or self.cfg.retriever.top_k

    @staticmethod
    def _hits(contexts: list[Context], score_attr: str) -> list[dict[str, Any]]:
        """Convert contexts to ``{"text", "id", "score"}`` dicts.

        ``score`` is read from the attribute named *score_attr* and is
        ``None`` when a context does not carry that attribute.
        """
        return [
            {"text": c.text, "id": c.id, "score": getattr(c, score_attr, None)}
            for c in contexts
        ]

    def _retrieve(self, question: str) -> list[Context]:
        """Return the final contexts for *question* (retrieval + reranking).

        Unlike :meth:`run`, this keeps no record of the first-stage hits.
        """
        k_first = self._first_stage_k()
        # Log the number actually requested: with a reranker enabled this is
        # ``first_stage_k``, not ``retriever.top_k``.
        logger.info("Retrieving top‑%d passages", k_first)
        contexts = self.retriever.retrieve(question, top_k=k_first)
        if self.reranker is not None:
            logger.info("Re-ranking %d docs with cross-encoder ...", len(contexts))
            contexts = self.reranker.rerank(question, contexts, k=self._final_k())
        return contexts

    def _generate(self, question: str, contexts: list[Context]) -> str:
        """Generate an answer to *question* from the texts of *contexts*.

        Generation limits (``max_new_tokens``, ``temperature``) come from
        ``cfg.generator``.
        """
        texts = [c.text for c in contexts]
        logger.info("Generating answer with %d context passages", len(texts))
        return self.generator.generate(
            question,
            texts,
            max_new_tokens=self.cfg.generator.max_new_tokens,
            temperature=self.cfg.generator.temperature,
        )