# github-actions[bot]
# Deploy from GitHub Actions
# dfa6a46
from typing import Any, Dict, List, Optional

from langchain.schema import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from project.source.data_preparation import DataPreparation
from project.model.retriever import DocumentRetriever
from project.model.reranking import DocumentReranker
from project.utils.model_loader import ModelLoader
from project.prompts.prompt_template import RAG_PROMPT
from project.logger.logging import get_logger
logger = get_logger(__name__)
class RAGPipeline:
    """Retrieve, rerank, prompt and generate an answer for a question.

    The pipeline wires together the project's document preparation,
    retriever, reranker and LLM loader. Construction only instantiates
    those collaborators; ``setup()`` must be called before ``invoke()`` or
    ``get_retrieved_documents()`` can be used.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Instantiate the pipeline's collaborators.

        Args:
            config_path: Forwarded unchanged to ``ModelLoader``,
                ``DocumentRetriever`` and ``DocumentReranker``. ``None``
                is passed through as-is; how each of them treats it is
                defined by those classes, not here.
        """
        self.config_path = config_path
        self.model_loader = ModelLoader(config_path)
        self.llm = self.model_loader.load_llm()
        self.data_prep = DataPreparation()
        self.retriever_module = DocumentRetriever(config_path)
        self.reranker = DocumentReranker(config_path)
        # Both stay None until setup() runs; the public entry points use
        # them to detect an un-initialised pipeline.
        self.chain = None
        self.retriever = None
        logger.info("RAGPipeline initialized")

    def setup(
        self,
        pdf_path: Optional[str] = None,
        use_attention_paper: bool = True,
    ) -> None:
        """Prepare documents, build the vector store and assemble the chain.

        Args:
            pdf_path: Forwarded to ``DataPreparation.prepare_documents``.
            use_attention_paper: Forwarded to
                ``DataPreparation.prepare_documents``.
        """
        chunks = self.data_prep.prepare_documents(
            pdf_path=pdf_path,
            use_attention_paper=use_attention_paper,
        )
        self.retriever_module.create_vectorstore(chunks)
        self.retriever = self.retriever_module.get_base_retriever()
        self._build_chain()
        logger.info("RAG pipeline setup complete")

    def _retrieve_and_rerank(self, query: str) -> List[Document]:
        """Return the retriever's hits for *query* after reranking them."""
        candidates = self.retriever.invoke(query)
        return self.reranker.rerank(query, candidates)

    def _format_docs(self, docs: List[Document]) -> str:
        """Render documents as numbered sections separated by blank lines.

        Each document becomes ``"Document <n>:\\n<page_content>"`` with
        ``n`` starting at 1. An empty *docs* yields an empty string.
        """
        return "\n\n".join(
            f"Document {position}:\n{doc.page_content}"
            for position, doc in enumerate(docs, start=1)
        )

    def _build_chain(self) -> None:
        """Assemble ``self.chain``: context/question -> prompt -> LLM -> text."""
        # The mapping is piped into RAG_PROMPT, so both entries are computed
        # from the same input dict; "context" triggers retrieval + reranking.
        # NOTE(review): relies on RAG_PROMPT being a LangChain runnable so
        # that ``dict | runnable`` is accepted - confirm in prompt_template.
        prompt_inputs = {
            "context": lambda x: self._format_docs(
                self._retrieve_and_rerank(x["question"])
            ),
            "question": lambda x: x["question"],
        }
        self.chain = prompt_inputs | RAG_PROMPT | self.llm | StrOutputParser()
        logger.info("RAG chain built successfully")

    def invoke(self, query: str) -> str:
        """Run the full chain for *query* and return its output.

        Raises:
            ValueError: If ``setup()`` has not built the chain yet.
        """
        if self.chain is None:
            raise ValueError("Pipeline not setup. Call setup() first.")
        response = self.chain.invoke({"question": query})
        logger.info("Query processed successfully")
        return response

    def get_retrieved_documents(self, query: str) -> List[Document]:
        """Return the reranked documents for *query* without calling the LLM.

        Raises:
            ValueError: If ``setup()`` has not created the retriever yet.
        """
        if self.retriever is None:
            raise ValueError("Pipeline not setup. Call setup() first.")
        return self._retrieve_and_rerank(query)