Spaces:
Sleeping
Sleeping
File size: 5,460 Bytes
8521f60 bdb49ae 8521f60 79bdbbe bdb49ae 79bdbbe bdb49ae 79bdbbe 8521f60 12409b1 8521f60 12409b1 8521f60 79bdbbe 8521f60 cdf4160 79bdbbe 8521f60 4dc151e 8521f60 79bdbbe 8521f60 79bdbbe cdf4160 79bdbbe 8521f60 bdb49ae 8521f60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
"""High‑level RAG pipeline orchestration."""
from __future__ import annotations
import logging
from typing import Dict, Any, List
from .config import PipelineConfig
from .retrievers import bm25, dense, hybrid
from .generators.hf_generator import HFGenerator
from .retrievers.base import Retriever, Context
from .rerankers.cross_encoder import CrossEncoderReranker
logger = logging.getLogger(__name__)
# NOTE: this module previously called logging.basicConfig(level=logging.INFO)
# at import time.  A library module must not configure the root logger — that
# is the application entry point's decision (see the Python Logging HOWTO,
# "Configuring Logging for a Library").  This module only emits records via
# ``logger``; handlers/levels are attached by whoever imports it.
class RAGPipeline:
    """Run retrieval → generation → scoring in a single object.

    The pipeline wires together three configurable stages:

    1. retrieval  — a BM25, dense, or hybrid retriever (``_build_retriever``)
    2. reranking  — optional cross-encoder reranking of first-stage hits
    3. generation — an ``HFGenerator`` that produces the final answer

    Parameters
    ----------
    cfg:
        Fully-populated :class:`PipelineConfig` describing all three stages.
    """

    def __init__(self, cfg: PipelineConfig):
        self.cfg = cfg
        self.retriever: Retriever = self._build_retriever(cfg)
        self.generator = HFGenerator(
            model_name=cfg.generator.model_name, device=cfg.generator.device
        )
        # The reranker is optional; ``self.reranker is None`` means "skip stage 2".
        if cfg.reranker.enable:
            self.reranker = CrossEncoderReranker(
                cfg.reranker.model_name,
                device=cfg.reranker.device,
                max_length=cfg.reranker.max_length,
            )
        else:
            self.reranker = None

    # ---------------------------------------------------------------------
    # Public API
    # ---------------------------------------------------------------------
    def run(self, question: str) -> Dict[str, Any]:
        """Answer *question*, returning the full trace of every stage.

        Returns
        -------
        dict with keys ``question``, ``raw_retrieval`` (first-stage hits),
        ``reranked`` (empty list when reranking is disabled), ``contexts``
        (the texts actually fed to the generator) and ``answer``.
        """
        logger.info("Question: %s", question)
        # 1. first-stage retrieval: fetch a larger candidate pool when a
        #    reranker will narrow it down afterwards.
        k_first = self.cfg.reranker.first_stage_k if self.reranker else self.cfg.retriever.top_k
        initial: List[Context] = self.retriever.retrieve(question, top_k=k_first)
        raw_hits = self._hits(initial, "retrieval_score")
        # 2. optional cross-encoder reranking
        if self.reranker:
            final_k = self.cfg.reranker.final_k or self.cfg.retriever.top_k
            reranked: List[Context] = self.reranker.rerank(question, initial, k=final_k)
            reranked_hits = self._hits(reranked, "cross_encoder_score")
            contexts_for_gen = reranked
        else:
            reranked_hits = []
            contexts_for_gen = initial
        # 3. generation — delegated to _generate so run() and external callers
        #    of the helper share a single code path.
        answer = self._generate(question, contexts_for_gen)
        return {
            "question": question,
            "raw_retrieval": raw_hits,
            "reranked": reranked_hits,
            "contexts": [c.text for c in contexts_for_gen],
            "answer": answer,
        }

    __call__ = run  # alias: pipeline(question) == pipeline.run(question)

    def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run the pipeline over a batch of ``{'question': str, 'id': Any}`` dicts.

        NOTE(review): the ``"answer"`` key of each item holds the *entire*
        dict returned by :meth:`run` (not just the answer string).  Existing
        callers may depend on that shape, so it is preserved here.
        """
        results: list[dict[str, Any]] = []
        for entry in queries:
            q = entry.get("question", "")
            doc_id = entry.get("id")
            result = self.run(q)
            results.append({"id": doc_id, "question": q, "answer": result})
        return results

    # ---------------------------------------------------------------------
    # Private helpers
    # ---------------------------------------------------------------------
    @staticmethod
    def _hits(contexts: List[Context], score_attr: str) -> List[Dict[str, Any]]:
        """Serialize *contexts* to plain dicts, reading *score_attr* when present."""
        return [
            {"text": c.text, "id": c.id, "score": getattr(c, score_attr, None)}
            for c in contexts
        ]

    def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
        """Instantiate the retriever named by ``cfg.retriever.name``.

        Raises
        ------
        ValueError
            If the configured name is not one of ``bm25``/``dense``/``hybrid``.
        """
        r = cfg.retriever
        name = r.name
        if name == "bm25":
            # NOTE(review): this branch reads ``r.bm25_idx`` while the hybrid
            # branch reads ``r.bm25_index`` — confirm both attributes exist on
            # the config, or unify the field name.
            return bm25.BM25Retriever(bm25_idx=str(r.bm25_idx), doc_store=str(r.doc_store))
        if name == "dense":
            return dense.DenseRetriever(
                faiss_index=str(r.faiss_index),
                doc_store=str(r.doc_store),
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        if name == "hybrid":
            return hybrid.HybridRetriever(
                str(r.bm25_index),
                str(r.faiss_index),
                doc_store=str(r.doc_store),
                alpha=r.alpha,
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        raise ValueError(f"Unsupported retriever '{name}'")

    def _retrieve(self, question: str) -> List[Context]:
        """Retrieve (and, when enabled, rerank) contexts for *question*."""
        logger.info("Retrieving top‑%d passages", self.cfg.retriever.top_k)
        k_first = self.cfg.reranker.first_stage_k if self.reranker else self.cfg.retriever.top_k
        initial = self.retriever.retrieve(question, top_k=k_first)
        if self.reranker:
            final_k = self.cfg.reranker.final_k or self.cfg.retriever.top_k
            logger.info("Re-ranking %d docs with cross-encoder ...", len(initial))
            initial = self.reranker.rerank(question, initial, k=final_k)
        return initial

    def _generate(self, question: str, contexts: List[Context]) -> str:
        """Generate the final answer from *question* plus context passages."""
        texts = [c.text for c in contexts]
        logger.info("Generating answer with %d context passages", len(texts))
        return self.generator.generate(
            question,
            texts,
            max_new_tokens=self.cfg.generator.max_new_tokens,
            temperature=self.cfg.generator.temperature,
        )
|