"""
rag_pipeline.py
───────────────
Orchestrates the full RAG pipeline: query → retrieve → generate → answer.

This module is the single integration point between the vector store and
the LLM. The UI layer (app.py) calls only this module; it knows nothing
about FAISS or Groq directly.

Pipeline steps
──────────────
1. Validate query (non-empty, reasonable length)
2. Retrieve top-k relevant chunks from FAISS
3. Pass chunks + query to the LLM for grounded generation
4. Return the answer (and optionally the source snippets for transparency)
"""
import logging
from dataclasses import dataclass
from groq import Groq
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
import llm as llm_module
import vector_store as vs_module
from config import cfg
logger = logging.getLogger(__name__)
MAX_QUERY_LENGTH = 1000 # characters
# ── Data classes ──────────────────────────────────────────────────────────────
@dataclass
class RAGResponse:
    """Result of one pipeline run: the answer plus its supporting evidence.

    Attributes
    ----------
    answer : str
        Generated answer text (or a validation-error message).
    sources : list[Document]
        Retrieved chunks that grounded the answer; empty when validation
        failed or nothing was retrieved.
    query : str
        The query that was actually processed (stripped), or the raw input
        when validation failed.
    """

    answer: str
    sources: list[Document]
    query: str

    def format_sources(self) -> str:
        """Return a compact source-citation string for display in the UI.

        Emits one markdown line per source document: a bold 1-based index,
        the ``source`` and ``page`` metadata when present, then an italicised
        snippet of the chunk text (first 120 characters, newlines flattened).
        Returns "" when there are no sources.
        """
        if not self.sources:
            return ""
        lines = []
        for i, doc in enumerate(self.sources, 1):
            src = doc.metadata.get("source", "")
            page = doc.metadata.get("page", "")
            flat = doc.page_content.replace("\n", " ")
            # Fix: the ellipsis literal had been mojibake-corrupted ("β¦");
            # also only append it when the snippet was actually truncated,
            # so short chunks no longer imply hidden text.
            snippet = flat[:120] + ("…" if len(flat) > 120 else "")
            label = f"**[{i}]**"
            if src:
                label += f" {src}"
            if page:
                label += f" p.{page}"
            lines.append(f"{label}: _{snippet}_")
        return "\n".join(lines)
# ── Pipeline class ────────────────────────────────────────────────────────────
class RAGPipeline:
    """
    Stateful pipeline object. Instantiated once at app startup and reused
    for every student query throughout the session.

    Parameters
    ----------
    index : FAISS
        Pre-built vector index used for retrieval.
    groq_client : Groq
        Initialised Groq client used for generation.
    """

    def __init__(self, index: FAISS, groq_client: Groq) -> None:
        self._index = index
        self._client = groq_client
        logger.info("RAGPipeline ready ✓")  # mojibake check-mark repaired

    # ── Public ────────────────────────────────────────────────────────────

    def query(self, user_query: str) -> RAGResponse:
        """
        Run the full RAG pipeline for a single student question.

        Parameters
        ----------
        user_query : str
            Raw question text from the student.

        Returns
        -------
        RAGResponse
            Contains the answer string and the source Documents used.
            On validation failure the answer is an error message and
            ``sources`` is empty.
        """
        validated = self._validate_query(user_query)
        if validated is None:
            # Derive the limit from MAX_QUERY_LENGTH instead of hard-coding
            # "1000", so the message can never drift from the actual check.
            return RAGResponse(
                answer=(
                    "Please enter a valid question (non-empty, under "
                    f"{MAX_QUERY_LENGTH} characters)."
                ),
                sources=[],
                query=user_query,
            )
        # Truncate the logged query to keep log lines bounded.
        logger.info("Processing query: '%s'", validated[:80])
        # Step 1 – retrieve the top-k most relevant chunks from FAISS.
        context_docs = vs_module.retrieve(self._index, validated, k=cfg.top_k)
        # Step 2 – grounded generation over the retrieved context.
        answer = llm_module.generate_answer(self._client, validated, context_docs)
        return RAGResponse(answer=answer, sources=context_docs, query=validated)

    # ── Internal ──────────────────────────────────────────────────────────

    @staticmethod
    def _validate_query(query: str) -> str | None:
        """Return the stripped query if non-empty and within MAX_QUERY_LENGTH, else None."""
        stripped = query.strip()
        if not stripped or len(stripped) > MAX_QUERY_LENGTH:
            return None
        return stripped
# ── Factory function ──────────────────────────────────────────────────────────
def build_pipeline() -> RAGPipeline:
    """
    Convenience factory: load data, build index, init LLM, return pipeline.

    Import and call this once from app.py.

    Returns
    -------
    RAGPipeline
        Fully initialised pipeline ready to serve queries.
    """
    from data_loader import load_documents  # local import avoids circular deps

    logger.info("=== Building AstroBot RAG Pipeline ===")
    docs = load_documents()
    index = vs_module.build_index(docs)
    client = llm_module.create_client()
    pipeline = RAGPipeline(index=index, groq_client=client)
    logger.info("=== AstroBot Pipeline Ready ✓ ===")  # mojibake check-mark repaired
    return pipeline