Spaces:
Sleeping
Sleeping
| from src.utils import load_config | |
| from dotenv import load_dotenv | |
| from src.utils import get_pdf_from_url | |
| from src.preprocess import Preprocessor | |
| from src.embedding import EmbeddingModel | |
| from src.utils import extract_text_from_pdf | |
| from langchain.vectorstores import Chroma | |
| from llm.answer_generator import GroqAnswerGenerator | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from llm.query_refiner import QueryRefiner | |
class ChatPipeline:
    """End-to-end retrieval-augmented chat pipeline for a single paper.

    Flow: fetch/load a PDF -> extract text -> preprocess -> chunk ->
    embed into a Chroma vector store -> answer queries with a Groq LLM,
    optionally refining the user query first.
    """

    def __init__(self, arxiv_id: str = None):
        """Initialize the pipeline and load configuration.

        Args:
            arxiv_id (str, optional): arXiv identifier of the paper. May also
                be supplied later via ``setup()``.
        """
        # Bug fix: the original discarded the constructor argument
        # (``self.arxiv_id = None`` regardless of what was passed).
        self.arxiv_id = arxiv_id
        self.config = load_config()
        self.chatbot_config = load_config("./configs/llm_producer.yaml")
        self.chunks = None
        self.retriever = None
        # Bug fix: initialize these so query() can raise a clear ValueError
        # on an un-setup pipeline instead of an AttributeError.
        self.chatbot = None
        self.query_refiner = None

    def _preprocess_docs(self, docs):
        """Clean each document's text in place using ``Preprocessor``.

        Args:
            docs (list): Documents, each exposing a ``page_content`` attribute.

        Returns:
            list: The same documents with ``page_content`` preprocessed.

        Raises:
            ValueError: If ``docs`` is empty or a document lacks
                ``page_content``.
            TypeError: If ``docs`` is not a list.
        """
        if not docs:
            raise ValueError("No documents provided for preprocessing.")
        if not isinstance(docs, list):
            raise TypeError("Expected a list of documents for preprocessing.")
        if not all(hasattr(doc, 'page_content') for doc in docs):
            raise ValueError("All documents must have a 'page_content' attribute.")
        preprocessor = Preprocessor()
        for doc in docs:
            doc.page_content = preprocessor(doc.page_content)
        return docs

    def _create_chunks(self, docs):
        """Split preprocessed documents into overlapping chunks.

        Args:
            docs (list): List of preprocessed documents.

        Returns:
            list: List of document chunks sized per the text-splitter config.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config["text_splitter"]["chunk_size"],
            chunk_overlap=self.config["text_splitter"]["chunk_overlap"]
        )
        return text_splitter.split_documents(docs)

    def _create_vector_store(self, chunks):
        """Embed chunks into a persisted Chroma store and set the retriever.

        Args:
            chunks (list): List of document chunks.

        Side effects:
            Persists the store to ``config['vector_db']['path']`` and sets
            ``self.retriever``.
        """
        embedding_model = EmbeddingModel(
            model_type=self.config['embedding']['model_type'],
            model_name=self.config['embedding']['model_name']
        )
        vector_store = Chroma.from_documents(
            documents=chunks,
            embedding=embedding_model.model,
            persist_directory=self.config['vector_db']['path']
        )
        vector_store.persist()
        self.retriever = vector_store.as_retriever(
            search_kwargs=self.config['vector_db']['search_kwargs']
        )

    def _build_index_and_chatbot(self, documents):
        """Shared tail of setup: preprocess, chunk, index, build the chatbot.

        Factored out of ``setup``/``setup_from_pdf``, which previously
        duplicated this sequence verbatim.
        """
        self.query_refiner = QueryRefiner()
        preprocessed_docs = self._preprocess_docs(documents)
        self.chunks = self._create_chunks(preprocessed_docs)
        self._create_vector_store(self.chunks)
        self.chatbot = GroqAnswerGenerator(
            model_name=self.chatbot_config['model_name'],
            temperature=self.chatbot_config['temperature'],
            max_tokens=self.chatbot_config['max_tokens'],
            retriever=self.retriever
        )

    def setup(self, arxiv_id: str):
        """Set up the pipeline by downloading and indexing an arXiv paper.

        Args:
            arxiv_id (str): arXiv identifier of the paper to fetch.

        Raises:
            ValueError: If ``arxiv_id`` is falsy.
        """
        # Validate BEFORE mutating state (the original assigned first, so a
        # failed call left self.arxiv_id clobbered with the bad value).
        if not arxiv_id:
            raise ValueError("arxiv_id must be provided to setup the pipeline.")
        self.arxiv_id = arxiv_id
        get_pdf_from_url(self.arxiv_id, self.config['storage']['save_pdf_path'])
        documents = extract_text_from_pdf(
            f"{self.config['storage']['save_pdf_path']}/{self.arxiv_id}.pdf"
        )
        self._build_index_and_chatbot(documents)

    def setup_from_pdf(self, pdf_path: str):
        """Set up the pipeline from a local PDF file.

        Args:
            pdf_path (str): Path to the PDF file on disk.

        Raises:
            ValueError: If ``pdf_path`` is falsy.
        """
        if not pdf_path:
            raise ValueError("pdf_path must be provided to setup the pipeline.")
        documents = extract_text_from_pdf(pdf_path)
        self._build_index_and_chatbot(documents)

    def query(self, prompt: str, refine_query: bool = True):
        """Query the chatbot with a prompt.

        Args:
            prompt (str): The user's question.
            refine_query (bool): If True, pass the prompt through the query
                refiner before answering.

        Returns:
            str: The response from the chatbot.

        Raises:
            ValueError: If the pipeline has not been set up yet.
        """
        if not self.chatbot:
            raise ValueError("Chatbot is not initialized. Call setup() method first.")
        if refine_query:
            refined_query = self.query_refiner.refine(prompt)
            return self.chatbot.generate_answer(refined_query)
        return self.chatbot.generate_answer(prompt)