# ScholarBot — src/pipeline.py
# Uploaded by vinny4 — "added refiner in upload pdf" (commit e712e61)
from src.utils import load_config
from dotenv import load_dotenv
from src.utils import get_pdf_from_url
from src.preprocess import Preprocessor
from src.embedding import EmbeddingModel
from src.utils import extract_text_from_pdf
from langchain.vectorstores import Chroma
from llm.answer_generator import GroqAnswerGenerator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llm.query_refiner import QueryRefiner
class ChatPipeline:
    """End-to-end RAG pipeline over a single PDF.

    Downloads (or loads) a PDF, preprocesses and chunks its text, indexes
    the chunks into a persisted Chroma vector store, and exposes a chat
    interface backed by a Groq LLM with optional query refinement.
    """

    def __init__(self, arxiv_id: str = None):
        """Load configuration and prepare (empty) pipeline state.

        Args:
            arxiv_id (str, optional): arXiv identifier of the target paper.
                May also be provided later through setup().
        """
        # Fix: the original discarded the constructor argument and always
        # stored None here.
        self.arxiv_id = arxiv_id
        self.config = load_config()
        self.chatbot_config = load_config("./configs/llm_producer.yaml")
        self.chunks = None
        self.retriever = None
        # Created by setup()/setup_from_pdf(); initialized here so query()
        # raises a clear ValueError instead of an AttributeError when the
        # pipeline was never set up.
        self.chatbot = None
        self.query_refiner = None

    def _preprocess_docs(self, docs):
        """Clean each document's text in place using Preprocessor.

        Args:
            docs (list): Documents exposing a mutable ``page_content`` str.

        Returns:
            list: The same list, with each ``page_content`` preprocessed.

        Raises:
            ValueError: If ``docs`` is empty or an item lacks ``page_content``.
            TypeError: If ``docs`` is not a list.
        """
        if not docs:
            raise ValueError("No documents provided for preprocessing.")
        if not isinstance(docs, list):
            raise TypeError("Expected a list of documents for preprocessing.")
        if not all(hasattr(doc, 'page_content') for doc in docs):
            raise ValueError("All documents must have a 'page_content' attribute.")
        preprocessor = Preprocessor()
        # The original enumerated the list but never used the index.
        for doc in docs:
            doc.page_content = preprocessor(doc.page_content)
        return docs

    def _create_chunks(self, docs):
        """Split preprocessed documents into overlapping chunks.

        Chunk size and overlap come from the ``text_splitter`` config section.

        Args:
            docs (list): Preprocessed documents.

        Returns:
            list: Document chunks.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config["text_splitter"]["chunk_size"],
            chunk_overlap=self.config["text_splitter"]["chunk_overlap"]
        )
        return text_splitter.split_documents(docs)

    def _create_vector_store(self, chunks):
        """Embed chunks into a persisted Chroma store and set the retriever.

        Args:
            chunks (list): Document chunks to index.
        """
        embedding_model = EmbeddingModel(
            model_type=self.config['embedding']['model_type'],
            model_name=self.config['embedding']['model_name']
        )
        vector_store = Chroma.from_documents(
            documents=chunks,
            embedding=embedding_model.model,
            persist_directory=self.config['vector_db']['path']
        )
        vector_store.persist()
        self.retriever = vector_store.as_retriever(
            search_kwargs=self.config['vector_db']['search_kwargs']
        )

    def _index_documents(self, documents):
        """Preprocess, chunk, and index documents, then build the chatbot.

        Shared tail of setup() and setup_from_pdf() — the original duplicated
        this sequence verbatim in both methods.

        Args:
            documents (list): Raw documents extracted from a PDF.
        """
        preprocessed_docs = self._preprocess_docs(documents)
        self.chunks = self._create_chunks(preprocessed_docs)
        self._create_vector_store(self.chunks)
        self.chatbot = GroqAnswerGenerator(
            model_name=self.chatbot_config['model_name'],
            temperature=self.chatbot_config['temperature'],
            max_tokens=self.chatbot_config['max_tokens'],
            retriever=self.retriever
        )

    def setup(self, arxiv_id: str):
        """Download an arXiv PDF and build the full QA pipeline for it.

        Args:
            arxiv_id (str): arXiv identifier of the paper to fetch.

        Raises:
            ValueError: If ``arxiv_id`` is falsy.
        """
        # Validate BEFORE mutating state — the original assigned first, so a
        # bad value still clobbered self.arxiv_id.
        if not arxiv_id:
            raise ValueError("arxiv_id must be provided to setup the pipeline.")
        self.arxiv_id = arxiv_id
        self.query_refiner = QueryRefiner()
        get_pdf_from_url(self.arxiv_id, self.config['storage']['save_pdf_path'])
        documents = extract_text_from_pdf(
            f"{self.config['storage']['save_pdf_path']}/{self.arxiv_id}.pdf"
        )
        self._index_documents(documents)

    def setup_from_pdf(self, pdf_path: str):
        """Build the full QA pipeline from a local PDF file.

        Args:
            pdf_path (str): Path to the PDF on disk.

        Raises:
            ValueError: If ``pdf_path`` is falsy.
        """
        if not pdf_path:
            raise ValueError("pdf_path must be provided to setup the pipeline.")
        self.query_refiner = QueryRefiner()
        documents = extract_text_from_pdf(pdf_path)
        self._index_documents(documents)

    def query(self, prompt: str, refine_query: bool = True):
        """Answer a question about the indexed paper.

        Args:
            prompt (str): The user question.
            refine_query (bool): When True, rewrite the prompt with the
                QueryRefiner before answering.

        Returns:
            str: The chatbot's answer.

        Raises:
            ValueError: If setup()/setup_from_pdf() has not been called.
        """
        # getattr guards instances created without running __init__.
        if not getattr(self, "chatbot", None):
            raise ValueError("Chatbot is not initialized. Call setup() method first.")
        if refine_query:
            refined_query = self.query_refiner.refine(prompt)
            return self.chatbot.generate_answer(refined_query)
        return self.chatbot.generate_answer(prompt)