File size: 5,066 Bytes
721ca73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""

rag_pipeline.py

───────────────

Orchestrates the full RAG pipeline: query β†’ retrieve β†’ generate β†’ answer.



This module is the single integration point between the vector store and

the LLM. The UI layer (app.py) calls only this module; it knows nothing

about FAISS or Groq directly.



Pipeline steps

──────────────

  1. Validate query (non-empty, reasonable length)

  2. Retrieve top-k relevant chunks from FAISS

  3. Pass chunks + query to the LLM for grounded generation

  4. Return the answer (and optionally the source snippets for transparency)

"""

import logging
from dataclasses import dataclass

from groq import Groq
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

import llm as llm_module
import vector_store as vs_module
from config import cfg

logger = logging.getLogger(__name__)

MAX_QUERY_LENGTH = 1000  # characters


# ── Data classes ──────────────────────────────────────────────────────────────

@dataclass
class RAGResponse:
    """Result of one RAG pipeline run, produced by RAGPipeline.query().

    Attributes
    ----------
    answer  : the LLM-generated answer (or a validation-error message)
    sources : the retrieved Documents the answer was grounded on
    query   : the query string the pipeline actually processed
    """

    answer: str
    sources: list[Document]
    query: str

    def format_sources(self) -> str:
        """Return a compact source-citation string for display in the UI.

        One Markdown line per source — ``**[i]** <source> p.<page>: _<snippet>_``
        — or the empty string when there are no sources.
        """
        if not self.sources:
            return ""
        lines = []
        for i, doc in enumerate(self.sources, 1):
            src  = doc.metadata.get("source", "")
            page = doc.metadata.get("page",   "")
            snippet = doc.page_content[:120].replace("\n", " ")
            # Only signal truncation when the content really was cut off.
            if len(doc.page_content) > 120:
                snippet += "…"
            label = f"**[{i}]**"
            if src:
                label += f" {src}"
            # Page numbers can legitimately be 0 (zero-indexed PDF loaders);
            # a plain truthiness test would silently drop them, so compare
            # against the explicit "" sentinel from .get() instead.
            if page != "":
                label += f" p.{page}"
            lines.append(f"{label}: _{snippet}_")
        return "\n".join(lines)


# ── Pipeline class ────────────────────────────────────────────────────────────

class RAGPipeline:
    """Stateful pipeline object.

    Instantiated once at app startup and reused for every student query
    throughout the session.
    """

    def __init__(self, index: FAISS, groq_client: Groq) -> None:
        """Store the pre-built FAISS index and Groq client for reuse."""
        self._index  = index
        self._client = groq_client
        logger.info("RAGPipeline ready ✓")

    # ── Public ────────────────────────────────────────────────────────────────

    def query(self, user_query: str) -> RAGResponse:
        """Run the full RAG pipeline for a single student question.

        Parameters
        ----------
        user_query : str
            Raw question text from the student.

        Returns
        -------
        RAGResponse
            Contains the answer string and the source Documents used.
            On validation failure the answer explains the constraint and
            ``sources`` is empty.
        """
        validated = self._validate_query(user_query)
        if validated is None:
            # Keep the user-facing limit in sync with MAX_QUERY_LENGTH
            # instead of hardcoding "1000" here.
            return RAGResponse(
                answer=(
                    "Please enter a valid question (non-empty, under "
                    f"{MAX_QUERY_LENGTH} characters)."
                ),
                sources=[],
                query=user_query,
            )

        logger.info("Processing query: '%s'", validated[:80])

        # Step 1 — Retrieve
        context_docs = vs_module.retrieve(self._index, validated, k=cfg.top_k)

        # Step 2 — Generate
        answer = llm_module.generate_answer(self._client, validated, context_docs)

        return RAGResponse(answer=answer, sources=context_docs, query=validated)

    # ── Internal ──────────────────────────────────────────────────────────────

    @staticmethod
    def _validate_query(query: str) -> str | None:
        """Return the stripped query if valid, else None.

        Valid means: non-empty after stripping and at most
        MAX_QUERY_LENGTH characters long.
        """
        stripped = query.strip()
        if not stripped or len(stripped) > MAX_QUERY_LENGTH:
            return None
        return stripped


# ── Factory function ─────────────────────────────────────────────────────────

def build_pipeline() -> RAGPipeline:
    """Convenience factory: load data, build index, init LLM, return pipeline.

    Import and call this once from app.py.
    """
    # Local import avoids a circular dependency with data_loader.
    from data_loader import load_documents

    logger.info("=== Building AstroBot RAG Pipeline ===")

    documents   = load_documents()
    faiss_index = vs_module.build_index(documents)
    groq_client = llm_module.create_client()
    pipeline    = RAGPipeline(index=faiss_index, groq_client=groq_client)

    logger.info("=== AstroBot Pipeline Ready ✓ ===")
    return pipeline