Spaces:
Sleeping
Sleeping
File size: 5,137 Bytes
9c37331 e712e61 9c37331 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | from src.utils import load_config
from dotenv import load_dotenv
from src.utils import get_pdf_from_url
from src.preprocess import Preprocessor
from src.embedding import EmbeddingModel
from src.utils import extract_text_from_pdf
from langchain.vectorstores import Chroma
from llm.answer_generator import GroqAnswerGenerator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llm.query_refiner import QueryRefiner
class ChatPipeline:
    """End-to-end RAG chat pipeline over a single paper.

    Workflow: fetch/load a PDF, extract and preprocess its text, split it
    into chunks, index the chunks in a Chroma vector store, then answer
    user queries with a Groq-backed generator grounded on the retriever.
    """

    def __init__(self, arxiv_id: str = None):
        """Load configuration and prepare (but do not build) the pipeline.

        Args:
            arxiv_id (str, optional): Paper identifier; may also be supplied
                later via :meth:`setup`.
        """
        # Bug fix: the original assigned None here, silently discarding the
        # constructor argument.
        self.arxiv_id = arxiv_id
        self.config = load_config()
        self.chatbot_config = load_config("./configs/llm_producer.yaml")
        self.chunks = None
        self.retriever = None
        # Populated by setup()/setup_from_pdf(); declared here so query()
        # can fail with a clear ValueError instead of an AttributeError.
        self.chatbot = None
        self.query_refiner = None

    def _preprocess_docs(self, docs):
        """Clean each document's text in place using the Preprocessor.

        Args:
            docs (list): Documents exposing a ``page_content`` attribute.

        Returns:
            list: The same documents with ``page_content`` preprocessed.

        Raises:
            ValueError: If ``docs`` is empty or an item lacks ``page_content``.
            TypeError: If ``docs`` is not a list.
        """
        if not docs:
            raise ValueError("No documents provided for preprocessing.")
        if not isinstance(docs, list):
            raise TypeError("Expected a list of documents for preprocessing.")
        if not all(hasattr(doc, 'page_content') for doc in docs):
            raise ValueError("All documents must have a 'page_content' attribute.")
        preprocessor = Preprocessor()
        for doc in docs:  # index was unused in the original enumerate loop
            doc.page_content = preprocessor(doc.page_content)
        return docs

    def _create_chunks(self, docs):
        """Split preprocessed documents into overlapping chunks.

        Args:
            docs (list): Preprocessed documents.

        Returns:
            list: Document chunks sized per the ``text_splitter`` config.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config["text_splitter"]["chunk_size"],
            chunk_overlap=self.config["text_splitter"]["chunk_overlap"]
        )
        return text_splitter.split_documents(docs)

    def _create_vector_store(self, chunks):
        """Embed chunks into a persisted Chroma store and set the retriever.

        Args:
            chunks (list): Document chunks to index.

        Side effects:
            Persists the store to ``config['vector_db']['path']`` and sets
            ``self.retriever``.
        """
        embedding_model = EmbeddingModel(model_type=self.config['embedding']['model_type'],
                                         model_name=self.config['embedding']['model_name'])
        vector_store = Chroma.from_documents(
            documents=chunks,
            embedding=embedding_model.model,
            persist_directory=self.config['vector_db']['path']
        )
        vector_store.persist()
        self.retriever = vector_store.as_retriever(search_kwargs=self.config['vector_db']['search_kwargs'])

    def _build_index_and_chatbot(self, documents):
        """Shared tail of setup()/setup_from_pdf(): preprocess, chunk, index,
        and initialize the answer generator against the fresh retriever."""
        preprocessed_docs = self._preprocess_docs(documents)
        self.chunks = self._create_chunks(preprocessed_docs)
        self._create_vector_store(self.chunks)
        self.chatbot = GroqAnswerGenerator(
            model_name=self.chatbot_config['model_name'],
            temperature=self.chatbot_config['temperature'],
            max_tokens=self.chatbot_config['max_tokens'],
            retriever=self.retriever
        )

    def setup(self, arxiv_id: str):
        """Build the pipeline for an arXiv paper: download the PDF, extract
        its text, and index it.

        Args:
            arxiv_id (str): arXiv identifier of the paper.

        Raises:
            ValueError: If ``arxiv_id`` is falsy.
        """
        # Validate before mutating state, so a failed call leaves the
        # instance unchanged (the original assigned first, then checked).
        if not arxiv_id:
            raise ValueError("arxiv_id must be provided to setup the pipeline.")
        self.arxiv_id = arxiv_id
        self.query_refiner = QueryRefiner()
        get_pdf_from_url(self.arxiv_id, self.config['storage']['save_pdf_path'])
        documents = extract_text_from_pdf(f"{self.config['storage']['save_pdf_path']}/{self.arxiv_id}.pdf")
        self._build_index_and_chatbot(documents)

    def setup_from_pdf(self, pdf_path: str):
        """Build the pipeline from a local PDF file.

        Args:
            pdf_path (str): Path to the PDF on disk.

        Raises:
            ValueError: If ``pdf_path`` is falsy.
        """
        if not pdf_path:
            raise ValueError("pdf_path must be provided to setup the pipeline.")
        self.query_refiner = QueryRefiner()
        documents = extract_text_from_pdf(pdf_path)
        self._build_index_and_chatbot(documents)

    def query(self, prompt: str, refine_query: bool = True):
        """Answer a prompt using the retrieval-grounded chatbot.

        Args:
            prompt (str): The user's question.
            refine_query (bool): If True, rewrite the prompt with the
                QueryRefiner before answering.

        Returns:
            str: The chatbot's answer.

        Raises:
            ValueError: If the pipeline has not been set up yet.
        """
        if self.chatbot is None:
            raise ValueError("Chatbot is not initialized. Call setup() method first.")
        if refine_query:
            prompt = self.query_refiner.refine(prompt)
        return self.chatbot.generate_answer(prompt)
|