# DemoChatBot / rag_pipeline.py
# Author: OnlyTheTruth03 — Initial Commit (721ca73, verified)
"""
rag_pipeline.py
───────────────
Orchestrates the full RAG pipeline: query β†’ retrieve β†’ generate β†’ answer.
This module is the single integration point between the vector store and
the LLM. The UI layer (app.py) calls only this module; it knows nothing
about FAISS or Groq directly.
Pipeline steps
──────────────
1. Validate query (non-empty, reasonable length)
2. Retrieve top-k relevant chunks from FAISS
3. Pass chunks + query to the LLM for grounded generation
4. Return the answer (and optionally the source snippets for transparency)
"""
import logging
from dataclasses import dataclass
from groq import Groq
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
import llm as llm_module
import vector_store as vs_module
from config import cfg
logger = logging.getLogger(__name__)
MAX_QUERY_LENGTH = 1000 # characters
# ── Data classes ──────────────────────────────────────────────────────────────
@dataclass
class RAGResponse:
    """Result of one RAG pipeline run.

    Attributes
    ----------
    answer : str
        The LLM-generated answer text (or the validation-failure message).
    sources : list[Document]
        Retrieved chunks the answer was grounded on; empty when the query
        failed validation.
    query : str
        The query as processed (stripped when valid, raw otherwise).
    """

    answer: str
    sources: list[Document]
    query: str

    def format_sources(self) -> str:
        """Return a compact source-citation string for display in the UI.

        Each source renders as one Markdown line of the form
        ``**[i]** <source> p.<page>: _<snippet>_``, where source and page
        are included only when present in the document metadata.
        Returns "" when there are no sources.
        """
        if not self.sources:
            return ""
        lines = []
        for i, doc in enumerate(self.sources, 1):
            src = doc.metadata.get("source", "")
            page = doc.metadata.get("page", "")
            snippet = doc.page_content[:120].replace("\n", " ")
            # Fix: only show an ellipsis when the snippet was actually
            # truncated — previously "…" was appended unconditionally,
            # falsely implying short chunks were cut off.
            if len(doc.page_content) > 120:
                snippet += "…"
            label = f"**[{i}]**"
            if src:
                label += f" {src}"
            if page:
                label += f" p.{page}"
            lines.append(f"{label}: _{snippet}_")
        return "\n".join(lines)
# ── Pipeline class ────────────────────────────────────────────────────────────
class RAGPipeline:
    """
    Stateful pipeline object. Instantiated once at app startup and reused
    for every student query throughout the session.
    """

    def __init__(self, index: FAISS, groq_client: Groq) -> None:
        """Store the retrieval index and the LLM client used by every query."""
        self._index = index
        self._client = groq_client
        logger.info("RAGPipeline ready βœ“")

    # ── Public ────────────────────────────────────────────────────────────────
    def query(self, user_query: str) -> RAGResponse:
        """
        Run the full RAG pipeline for a single student question.

        Parameters
        ----------
        user_query : str
            Raw question text from the student.

        Returns
        -------
        RAGResponse
            Contains the answer string and the source Documents used.
            Invalid queries (empty, or longer than MAX_QUERY_LENGTH)
            yield a canned message with no sources.
        """
        validated = self._validate_query(user_query)
        if validated is None:
            # Surface rejected queries in the log for debugging.
            logger.warning("Rejected invalid query (empty or too long)")
            return RAGResponse(
                # Fix: derive the limit from MAX_QUERY_LENGTH instead of
                # hard-coding "1000", so the message stays in sync with
                # the actual validation rule.
                answer=(
                    "Please enter a valid question (non-empty, under "
                    f"{MAX_QUERY_LENGTH} characters)."
                ),
                sources=[],
                query=user_query,
            )
        logger.info("Processing query: '%s'", validated[:80])
        # Step 1 — Retrieve the top-k most relevant chunks from the index.
        context_docs = vs_module.retrieve(self._index, validated, k=cfg.top_k)
        # Step 2 — Generate a grounded answer from query + retrieved context.
        answer = llm_module.generate_answer(self._client, validated, context_docs)
        return RAGResponse(answer=answer, sources=context_docs, query=validated)

    # ── Internal ──────────────────────────────────────────────────────────────
    @staticmethod
    def _validate_query(query: str) -> str | None:
        """Return the stripped query if valid (non-empty and no longer
        than MAX_QUERY_LENGTH), else None."""
        stripped = query.strip()
        if not stripped or len(stripped) > MAX_QUERY_LENGTH:
            return None
        return stripped
# ── Factory function ─────────────────────────────────────────────────────────
def build_pipeline() -> RAGPipeline:
    """
    Convenience factory: load data, build index, init LLM, return pipeline.

    Import and call this once from app.py.
    """
    # Imported here (not at module top) to break a circular dependency
    # with data_loader.
    from data_loader import load_documents

    logger.info("=== Building AstroBot RAG Pipeline ===")
    documents = load_documents()
    faiss_index = vs_module.build_index(documents)
    groq_client = llm_module.create_client()
    rag = RAGPipeline(index=faiss_index, groq_client=groq_client)
    logger.info("=== AstroBot Pipeline Ready βœ“ ===")
    return rag