from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

from llm.answer_generator import GroqAnswerGenerator
from llm.query_refiner import QueryRefiner
from src.embedding import EmbeddingModel
from src.preprocess import Preprocessor
from src.utils import extract_text_from_pdf, get_pdf_from_url, load_config


class ChatPipeline:
    """End-to-end RAG chat pipeline over a single PDF.

    Fetches (or loads) a PDF, extracts and preprocesses its text, splits it
    into chunks, indexes the chunks in a Chroma vector store, and answers
    questions against the retriever with a Groq-backed LLM.
    """

    def __init__(self, arxiv_id: str = None):
        """
        Args:
            arxiv_id (str, optional): arXiv identifier of the paper to chat
                about. May also be supplied later via ``setup()``.
        """
        # BUG FIX: the original assigned ``self.arxiv_id = None``, silently
        # discarding the constructor argument. Store it instead.
        self.arxiv_id = arxiv_id
        self.config = load_config()
        self.chatbot_config = load_config("./configs/llm_producer.yaml")
        self.chunks = None
        self.retriever = None
        # BUG FIX: query() checks ``if not self.chatbot``; without these
        # defaults an un-setup pipeline raised AttributeError instead of the
        # intended ValueError.
        self.chatbot = None
        self.query_refiner = None

    def _preprocess_docs(self, docs):
        """Run each document's text through the Preprocessor, in place.

        Args:
            docs (list): Documents exposing a ``page_content`` attribute.

        Returns:
            list: The same list, with each ``page_content`` preprocessed.

        Raises:
            ValueError: If ``docs`` is empty or an item lacks ``page_content``.
            TypeError: If ``docs`` is not a list.
        """
        if not docs:
            raise ValueError("No documents provided for preprocessing.")
        if not isinstance(docs, list):
            raise TypeError("Expected a list of documents for preprocessing.")
        if not all(hasattr(doc, 'page_content') for doc in docs):
            raise ValueError("All documents must have a 'page_content' attribute.")
        preprocessor = Preprocessor()
        for doc in docs:  # dropped unused enumerate() index from the original
            doc.page_content = preprocessor(doc.page_content)
        return docs

    def _create_chunks(self, docs):
        """Split preprocessed documents into overlapping chunks.

        Args:
            docs (list): Preprocessed documents.

        Returns:
            list: Document chunks sized per the ``text_splitter`` config.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config["text_splitter"]["chunk_size"],
            chunk_overlap=self.config["text_splitter"]["chunk_overlap"],
        )
        return text_splitter.split_documents(docs)

    def _create_vector_store(self, chunks):
        """Embed the chunks into a persisted Chroma store and set the retriever.

        Side effects:
            Writes the index to ``config['vector_db']['path']`` and assigns
            ``self.retriever``.

        Args:
            chunks (list): Document chunks to index.
        """
        embedding_model = EmbeddingModel(
            model_type=self.config['embedding']['model_type'],
            model_name=self.config['embedding']['model_name'],
        )
        vector_store = Chroma.from_documents(
            documents=chunks,
            embedding=embedding_model.model,
            persist_directory=self.config['vector_db']['path'],
        )
        vector_store.persist()
        self.retriever = vector_store.as_retriever(
            search_kwargs=self.config['vector_db']['search_kwargs']
        )

    def _initialize_from_documents(self, documents):
        """Shared tail of setup()/setup_from_pdf(): preprocess, chunk, index,
        then build the query refiner and the Groq answer generator."""
        self.query_refiner = QueryRefiner()
        preprocessed_docs = self._preprocess_docs(documents)
        self.chunks = self._create_chunks(preprocessed_docs)
        self._create_vector_store(self.chunks)
        self.chatbot = GroqAnswerGenerator(
            model_name=self.chatbot_config['model_name'],
            temperature=self.chatbot_config['temperature'],
            max_tokens=self.chatbot_config['max_tokens'],
            retriever=self.retriever,
        )

    def setup(self, arxiv_id: str):
        """Download the arXiv PDF and initialize the full pipeline.

        Args:
            arxiv_id (str): arXiv identifier of the paper.

        Raises:
            ValueError: If ``arxiv_id`` is falsy.
        """
        self.arxiv_id = arxiv_id
        if not self.arxiv_id:
            raise ValueError("arxiv_id must be provided to setup the pipeline.")
        save_dir = self.config['storage']['save_pdf_path']
        get_pdf_from_url(self.arxiv_id, save_dir)
        documents = extract_text_from_pdf(f"{save_dir}/{self.arxiv_id}.pdf")
        self._initialize_from_documents(documents)

    def setup_from_pdf(self, pdf_path: str):
        """Initialize the pipeline from a local PDF file.

        Args:
            pdf_path (str): Path to the PDF on disk.

        Raises:
            ValueError: If ``pdf_path`` is falsy.
        """
        if not pdf_path:
            raise ValueError("pdf_path must be provided to setup the pipeline.")
        documents = extract_text_from_pdf(pdf_path)
        self._initialize_from_documents(documents)

    def query(self, prompt: str, refine_query: bool = True):
        """Answer a question using the retriever-backed chatbot.

        Args:
            prompt (str): The user's question.
            refine_query (bool): When True, rewrite the prompt with the
                QueryRefiner before answering.

        Returns:
            str: The chatbot's answer.

        Raises:
            ValueError: If the pipeline has not been set up yet.
        """
        if not self.chatbot:
            raise ValueError("Chatbot is not initialized. Call setup() method first.")
        if refine_query:
            refined_query = self.query_refiner.refine(prompt)
            return self.chatbot.generate_answer(refined_query)
        return self.chatbot.generate_answer(prompt)