File size: 5,137 Bytes
9c37331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e712e61
 
9c37331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from src.utils import load_config
from dotenv import load_dotenv
from src.utils import get_pdf_from_url
from src.preprocess import Preprocessor
from src.embedding import EmbeddingModel
from src.utils import extract_text_from_pdf
from langchain.vectorstores import Chroma
from llm.answer_generator import GroqAnswerGenerator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llm.query_refiner import QueryRefiner


class ChatPipeline:
    """End-to-end RAG chat pipeline for a single paper.

    Downloads (or loads) a PDF, preprocesses and chunks its text, indexes the
    chunks in a Chroma vector store, and answers user queries with a
    Groq-backed generator, optionally refining the query first.
    """

    def __init__(self, arxiv_id: str = None):
        """
        Initialize the pipeline and load configuration.

        Args:
            arxiv_id (str, optional): arXiv identifier of the paper. May also
                be provided later via setup().
        """
        # BUG FIX: the original assigned None unconditionally, silently
        # discarding the constructor argument.
        self.arxiv_id = arxiv_id
        self.config = load_config()
        self.chatbot_config = load_config("./configs/llm_producer.yaml")
        self.chunks = None
        self.retriever = None
        # BUG FIX: initialize so query() can raise a clear ValueError when
        # setup() was never called, instead of an AttributeError.
        self.chatbot = None
        self.query_refiner = None

    def _preprocess_docs(self, docs):
        """
        Preprocess each document's text in place using the Preprocessor class.

        Args:
            docs (list): Documents, each exposing a ``page_content`` attribute.

        Returns:
            list: The same documents with preprocessed ``page_content``.

        Raises:
            ValueError: If ``docs`` is empty or any document lacks a
                ``page_content`` attribute.
            TypeError: If ``docs`` is not a list.
        """
        if not docs:
            raise ValueError("No documents provided for preprocessing.")
        if not isinstance(docs, list):
            raise TypeError("Expected a list of documents for preprocessing.")
        if not all(hasattr(doc, 'page_content') for doc in docs):
            raise ValueError("All documents must have a 'page_content' attribute.")

        preprocessor = Preprocessor()
        for doc in docs:
            doc.page_content = preprocessor(doc.page_content)
        return docs

    def _create_chunks(self, docs):
        """
        Split the preprocessed documents into chunks.

        Args:
            docs (list): List of preprocessed documents.

        Returns:
            list: List of document chunks.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config["text_splitter"]["chunk_size"],
            chunk_overlap=self.config["text_splitter"]["chunk_overlap"]
        )
        return text_splitter.split_documents(docs)

    def _create_vector_store(self, chunks):
        """
        Build a persistent Chroma vector store from the chunks and set
        ``self.retriever`` from it.

        Args:
            chunks (list): List of document chunks.
        """
        embedding_model = EmbeddingModel(model_type=self.config['embedding']['model_type'],
                                         model_name=self.config['embedding']['model_name'])
        vector_store = Chroma.from_documents(
            documents=chunks,
            embedding=embedding_model.model,
            persist_directory=self.config['vector_db']['path']
        )
        vector_store.persist()
        self.retriever = vector_store.as_retriever(search_kwargs=self.config['vector_db']['search_kwargs'])

    def _build_from_documents(self, documents):
        """Shared tail of setup()/setup_from_pdf(): refine, preprocess, chunk,
        index, and construct the answer generator."""
        self.query_refiner = QueryRefiner()
        preprocessed_docs = self._preprocess_docs(documents)
        self.chunks = self._create_chunks(preprocessed_docs)
        self._create_vector_store(self.chunks)
        self.chatbot = GroqAnswerGenerator(
            model_name=self.chatbot_config['model_name'],
            temperature=self.chatbot_config['temperature'],
            max_tokens=self.chatbot_config['max_tokens'],
            retriever=self.retriever
        )

    def setup(self, arxiv_id: str):
        """
        Set up the pipeline for an arXiv paper: download the PDF, then build
        the index and the chatbot.

        Args:
            arxiv_id (str): arXiv identifier of the paper.

        Raises:
            ValueError: If ``arxiv_id`` is empty or None.
        """
        # Validate BEFORE assigning so a failed call leaves state untouched
        # (the original assigned first, then checked).
        if not arxiv_id:
            raise ValueError("arxiv_id must be provided to setup the pipeline.")
        self.arxiv_id = arxiv_id

        get_pdf_from_url(self.arxiv_id, self.config['storage']['save_pdf_path'])
        documents = extract_text_from_pdf(f"{self.config['storage']['save_pdf_path']}/{self.arxiv_id}.pdf")
        self._build_from_documents(documents)

    def setup_from_pdf(self, pdf_path: str):
        """
        Set up the pipeline from a local PDF file.

        Args:
            pdf_path (str): Path to the PDF file.

        Raises:
            ValueError: If ``pdf_path`` is empty or None.
        """
        if not pdf_path:
            raise ValueError("pdf_path must be provided to setup the pipeline.")

        documents = extract_text_from_pdf(pdf_path)
        self._build_from_documents(documents)

    def query(self, prompt: str, refine_query: bool = True):
        """
        Query the chatbot with a prompt.

        Args:
            prompt (str): The prompt to query the chatbot with.
            refine_query (bool): If True, refine the prompt with the
                QueryRefiner before answering.

        Returns:
            str: The response from the chatbot.

        Raises:
            ValueError: If setup() / setup_from_pdf() was never called.
        """
        if not self.chatbot:
            raise ValueError("Chatbot is not initialized. Call setup() method first.")

        if refine_query:
            refined_query = self.query_refiner.refine(prompt)
            return self.chatbot.generate_answer(refined_query)
        return self.chatbot.generate_answer(prompt)