Spaces:
Running
Running
| from typing import List, Dict, Any | |
| from langchain.schema import Document | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnablePassthrough | |
| from project.source.data_preparation import DataPreparation | |
| from project.model.retriever import DocumentRetriever | |
| from project.model.reranking import DocumentReranker | |
| from project.utils.model_loader import ModelLoader | |
| from project.prompts.prompt_template import RAG_PROMPT | |
| from project.logger.logging import get_logger | |
# Module-level logger for this file, obtained from the project's get_logger helper.
| logger = get_logger(__name__) | |
class RAGPipeline:
    """Wire document preparation, retrieval, reranking and an LLM into one chain.

    Construction loads the LLM and creates the helper components. ``setup``
    must then be called to build the vector store and the chain before
    ``invoke`` or ``get_retrieved_documents`` can be used.
    """

    def __init__(self, config_path: str = None) -> None:
        """Create the pipeline components.

        Args:
            config_path: Passed unchanged to ``ModelLoader``,
                ``DocumentRetriever`` and ``DocumentReranker``. How ``None``
                is interpreted is decided by those classes.
        """
        self.config_path = config_path
        self.model_loader = ModelLoader(config_path)
        self.llm = self.model_loader.load_llm()
        self.data_prep = DataPreparation()
        self.retriever_module = DocumentRetriever(config_path)
        self.reranker = DocumentReranker(config_path)
        # Both stay None until setup() has run; the public methods check them.
        self.chain = None
        self.retriever = None
        logger.info("RAGPipeline initialized")

    def setup(
        self,
        pdf_path: str = None,
        use_attention_paper: bool = True,
    ) -> None:
        """Prepare documents, index them and build the chain.

        Args:
            pdf_path: Forwarded to ``DataPreparation.prepare_documents``.
            use_attention_paper: Forwarded to
                ``DataPreparation.prepare_documents``.
        """
        chunks = self.data_prep.prepare_documents(
            pdf_path=pdf_path,
            use_attention_paper=use_attention_paper,
        )
        self.retriever_module.create_vectorstore(chunks)
        self.retriever = self.retriever_module.get_base_retriever()
        self._build_chain()
        logger.info("RAG pipeline setup complete")

    def _retrieve_and_rerank(self, query: str) -> List[Document]:
        """Fetch documents for ``query`` and return the reranker's result."""
        retrieved_docs = self.retriever.invoke(query)
        return self.reranker.rerank(query, retrieved_docs)

    def _format_docs(self, docs: List[Document]) -> str:
        """Render documents as numbered sections separated by blank lines.

        Each section is ``"Document <n>:"`` (1-based) on its own line followed
        by the document's ``page_content``. An empty list yields ``""``.
        """
        return "\n\n".join(
            f"Document {position}:\n{doc.page_content}"
            for position, doc in enumerate(docs, start=1)
        )

    def _build_chain(self) -> None:
        """Assemble the chain: input mapping, prompt, LLM, string parser."""
        prompt_inputs = {
            # Retrieval runs each time the chain is invoked, not at build time.
            "context": lambda x: self._format_docs(
                self._retrieve_and_rerank(x["question"])
            ),
            "question": lambda x: x["question"],
        }
        self.chain = prompt_inputs | RAG_PROMPT | self.llm | StrOutputParser()
        logger.info("RAG chain built successfully")

    def invoke(self, query: str) -> str:
        """Run the chain for ``query`` and return its output.

        Raises:
            ValueError: If ``setup`` has not built the chain yet.
        """
        if self.chain is None:
            raise ValueError("Pipeline not setup. Call setup() first.")
        response = self.chain.invoke({"question": query})
        # Plain literal: the original f-string had no placeholders.
        logger.info("Query processed successfully")
        return response

    def get_retrieved_documents(self, query: str) -> List[Document]:
        """Return the reranked documents for ``query`` without calling the LLM.

        Raises:
            ValueError: If ``setup`` has not created the retriever yet.
        """
        if self.retriever is None:
            raise ValueError("Pipeline not setup. Call setup() first.")
        return self._retrieve_and_rerank(query)