File size: 5,460 Bytes
8521f60
 
 
 
 
 
 
 
 
 
bdb49ae
8521f60
 
 
 
 
 
 
 
 
 
 
 
 
79bdbbe
 
bdb49ae
 
79bdbbe
bdb49ae
79bdbbe
 
8521f60
 
 
 
 
 
12409b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8521f60
 
12409b1
 
 
8521f60
 
 
 
 
79bdbbe
 
 
 
 
 
 
 
 
8521f60
 
 
 
cdf4160
79bdbbe
8521f60
4dc151e
8521f60
79bdbbe
 
 
 
 
 
 
8521f60
 
79bdbbe
 
 
cdf4160
79bdbbe
 
 
8521f60
 
 
 
 
bdb49ae
 
 
 
 
 
 
 
 
8521f60
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""High‑level RAG pipeline orchestration."""

from __future__ import annotations
import logging
from typing import Dict, Any, List

from .config import PipelineConfig
from .retrievers import bm25, dense, hybrid
from .generators.hf_generator import HFGenerator
from .retrievers.base import Retriever, Context
from .rerankers.cross_encoder import CrossEncoderReranker
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class RAGPipeline:
    """Run retrieval → (optional) reranking → generation in a single object.

    The pipeline is configured entirely from a :class:`PipelineConfig`;
    the reranking stage is skipped when ``cfg.reranker.enable`` is false.
    """

    def __init__(self, cfg: PipelineConfig):
        """Build the retriever, generator and (optionally) reranker from *cfg*."""
        self.cfg = cfg
        self.retriever: Retriever = self._build_retriever(cfg)
        self.generator = HFGenerator(
            model_name=cfg.generator.model_name, device=cfg.generator.device
        )
        # `None` marks the single-stage (no rerank) path throughout the class.
        self.reranker: CrossEncoderReranker | None = None
        if cfg.reranker.enable:
            self.reranker = CrossEncoderReranker(
                cfg.reranker.model_name,
                device=cfg.reranker.device,
                max_length=cfg.reranker.max_length,
            )

    # ---------------------------------------------------------------------
    # Public API
    # ---------------------------------------------------------------------
    def run(self, question: str) -> Dict[str, Any]:
        """Answer *question* end-to-end.

        Returns a dict with keys ``question``, ``raw_retrieval`` (first-stage
        hits with their retrieval scores), ``reranked`` (cross-encoder hits,
        empty list when reranking is disabled), ``contexts`` (the passage
        texts fed to the generator) and ``answer``.
        """
        logger.info("Question: %s", question)

        # 1. first-stage retrieval — over-fetch when a reranker will prune later
        initial: List[Context] = self.retriever.retrieve(
            question, top_k=self._first_stage_k()
        )

        raw_hits = [
            {"text": c.text, "id": c.id, "score": getattr(c, "retrieval_score", None)}
            for c in initial
        ]

        # 2. reranking (if enabled)
        if self.reranker:
            reranked: List[Context] = self.reranker.rerank(
                question, initial, k=self._final_k()
            )
            reranked_hits = [
                {
                    "text": c.text,
                    "id": c.id,
                    "score": getattr(c, "cross_encoder_score", None),
                }
                for c in reranked
            ]
            contexts_for_gen = reranked
        else:
            reranked_hits = []
            contexts_for_gen = initial

        # 3. generation (delegated so run() and callers of _generate agree)
        answer = self._generate(question, contexts_for_gen)

        return {
            "question": question,
            "raw_retrieval": raw_hits,
            "reranked": reranked_hits,
            "contexts": [c.text for c in contexts_for_gen],
            "answer": answer,
        }

    __call__ = run  # allow `pipeline(question)` as a shorthand for `pipeline.run(question)`

    def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run the pipeline over a batch of ``{'question': str, 'id': Any}`` entries.

        Missing ``question`` keys default to the empty string; missing ``id``
        keys default to ``None``.

        NOTE: for backward compatibility each result's ``"answer"`` value is
        the FULL :meth:`run` result dict (so the generated text lives at
        ``result["answer"]["answer"]``), not just the answer string.
        """
        results: list[dict[str, Any]] = []
        for entry in queries:
            q = entry.get("question", "")
            results.append({"id": entry.get("id"), "question": q, "answer": self.run(q)})
        return results

    # ---------------------------------------------------------------------
    # Private helpers
    # ---------------------------------------------------------------------
    def _first_stage_k(self) -> int:
        """How many passages the first-stage retriever should fetch."""
        if self.reranker:
            return self.cfg.reranker.first_stage_k
        return self.cfg.retriever.top_k

    def _final_k(self) -> int:
        """How many passages survive reranking (falls back to retriever top_k)."""
        return self.cfg.reranker.final_k or self.cfg.retriever.top_k

    def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
        """Instantiate the retriever named by ``cfg.retriever.name``.

        Raises:
            ValueError: if the name is not one of ``bm25``/``dense``/``hybrid``.
        """
        r = cfg.retriever
        name = r.name
        if name == "bm25":
            # NOTE(review): bm25 reads `r.bm25_idx` while hybrid reads
            # `r.bm25_index` — verify against PipelineConfig that both
            # fields really exist; one of the two looks like a typo.
            return bm25.BM25Retriever(bm25_idx=str(r.bm25_idx), doc_store=str(r.doc_store))
        if name == "dense":
            return dense.DenseRetriever(
                faiss_index=str(r.faiss_index),
                doc_store=str(r.doc_store),
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        if name == "hybrid":
            return hybrid.HybridRetriever(
                str(r.bm25_index),
                str(r.faiss_index),
                doc_store=str(r.doc_store),
                alpha=r.alpha,
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        raise ValueError(f"Unsupported retriever '{name}'")

    def _retrieve(self, question: str) -> List[Context]:
        """Retrieve (and, when enabled, rerank) contexts for *question*."""
        logger.info("Retrieving top‑%d passages", self.cfg.retriever.top_k)
        initial = self.retriever.retrieve(question, top_k=self._first_stage_k())

        if self.reranker:
            logger.info("Re-ranking %d docs with cross-encoder ...", len(initial))
            initial = self.reranker.rerank(question, initial, k=self._final_k())

        return initial

    def _generate(self, question: str, contexts: List[Context]) -> str:
        """Generate an answer to *question* grounded in *contexts*."""
        texts = [c.text for c in contexts]
        logger.info("Generating answer with %d context passages", len(texts))
        return self.generator.generate(
            question,
            texts,
            max_new_tokens=self.cfg.generator.max_new_tokens,
            temperature=self.cfg.generator.temperature,
        )