# farmerbot/deploy/create_retriever.py
# Hosting-page header (not code): Nelly43's picture / Update app / 0c6d13f
import os
import glob
from langchain_community.document_loaders import Docx2txtLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.retrievers import EnsembleRetriever
from ragatouille import RAGPretrainedModel
# Function to load and process documents
def docs_return(flag):
    """Load every ``.docx`` and ``.txt`` file found directly under ``rag_data/``.

    Args:
        flag: ``0`` selects the list-of-documents result; any other value
            selects the single joined string.

    Returns:
        When ``flag == 0``: a flat list of the document objects returned by
        the loaders, ``.docx`` files first and ``.txt`` files second.
        Otherwise: the ``page_content`` of the first document of each file,
        joined with a blank line between files (``""`` when nothing is found).
    """
    directory_path = 'rag_data/'

    # glob makes no ordering promise, so each listing is sorted to keep the
    # document order (and anything indexed from it) stable across runs.
    # One entry per source file: the list of documents its loader produced.
    per_file = [
        Docx2txtLoader(path).load()
        for path in sorted(glob.glob(directory_path + '*.docx'))
    ]
    per_file += [
        TextLoader(path).load()
        for path in sorted(glob.glob(directory_path + '*.txt'))
    ]

    if flag == 0:
        return [doc for documents in per_file for doc in documents]

    # A loader that yields nothing for a file is skipped rather than
    # triggering an IndexError on documents[0].
    return '\n\n'.join(
        str(documents[0].page_content) for documents in per_file if documents
    )
# Function to get or download the embedding model
def get_embedding_model(model_name):
    """Create a ``HuggingFaceEmbeddings`` instance for ``model_name``.

    A directory ``embedding_model/<model_name with '/' replaced by '_'>`` is
    preferred when it exists; otherwise ``model_name`` is passed through
    unchanged (the log message indicates it is expected to be downloaded).

    Args:
        model_name: Model identifier, e.g. ``"BAAI/bge-large-en-v1.5"``.

    Returns:
        The ``HuggingFaceEmbeddings`` object built from the chosen source.
    """
    cached_dir = "embedding_model/" + model_name.replace('/', '_')

    if os.path.exists(cached_dir):
        print(f"Loading local model from {cached_dir}")
        source = cached_dir
    else:
        print(f"Downloading model {model_name}")
        source = model_name

    return HuggingFaceEmbeddings(model_name=source)
# Function to return different types of text splitters
def get_text_splitter(splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000):
    """Construct the text splitter selected by ``splitter_type``.

    Args:
        splitter_type: ``'character'``, ``'recursive'`` or ``'token'``.
        chunk_size: Chunk size for the character and recursive splitters.
        chunk_overlap: Overlap between chunks, used by all three splitters.
        separator: Separator for the character splitter only.
        max_tokens: Chunk size for the token splitter (``chunk_size`` is not
            used in that case).

    Returns:
        A ``CharacterTextSplitter``, ``RecursiveCharacterTextSplitter`` or
        ``TokenTextSplitter`` instance.

    Raises:
        ValueError: If ``splitter_type`` is not one of the supported names.
    """
    supported = ('character', 'recursive', 'token')
    if splitter_type not in supported:
        raise ValueError("Unsupported splitter type. Choose from 'character', 'recursive', or 'token'.")

    if splitter_type == 'character':
        return CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
        )

    if splitter_type == 'recursive':
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

    # Only 'token' remains; it is sized by max_tokens rather than chunk_size.
    return TokenTextSplitter(chunk_size=max_tokens, chunk_overlap=chunk_overlap)
# Retriever using Chroma and HuggingFace embeddings
def retriever_chroma(flag, model_name="BAAI/bge-large-en-v1.5", splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000):
    """Return an MMR retriever backed by a Chroma vector store.

    Args:
        flag: Falsy builds the store from the documents returned by
            ``docs_return(0)`` and persists it to ``./chroma_db``; truthy
            reopens the store from that same directory without re-indexing.
        model_name: Embedding model identifier handed to
            ``get_embedding_model``.
        splitter_type: Splitter selector forwarded to ``get_text_splitter``
            (build path only).
        chunk_size: Forwarded to ``get_text_splitter`` (build path only).
        chunk_overlap: Forwarded to ``get_text_splitter`` (build path only).
        separator: Forwarded to ``get_text_splitter`` (build path only).
        max_tokens: Forwarded to ``get_text_splitter`` (build path only).

    Returns:
        A retriever created with ``search_type="mmr"``, ``k=4`` and
        ``fetch_k=10``.
    """
    # Single source of truth so the build and reload paths agree on location.
    persist_directory = "./chroma_db"

    # Load or download the embedding model
    embeddings = get_embedding_model(model_name)

    if not flag:
        # Load, split and index the documents
        all_doc = docs_return(0)
        text_splitter = get_text_splitter(
            splitter_type=splitter_type,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            max_tokens=max_tokens,
        )
        docs = text_splitter.split_documents(documents=all_doc)
        vectordb = Chroma.from_documents(docs, embeddings, persist_directory=persist_directory)
    else:
        # Chroma has no ``load_local`` (that is the FAISS API); a persisted
        # collection is reopened through the constructor instead.
        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    return vectordb.as_retriever(
        search_type="mmr", search_kwargs={"k": 4, "fetch_k": 10}
    )
# ColBERT retriever
def colbert_retriever():
    """Index the joined corpus with ColBERT and return a LangChain retriever.

    The corpus is the single string produced by ``docs_return(1)``; it is
    passed as a one-element collection and indexed under the name
    ``ensemble_colbert`` with ``split_documents=True`` and a maximum document
    length of 256.

    Returns:
        The retriever from ``as_langchain_retriever(k=3)``.
    """
    corpus_text = docs_return(1)
    colbert_model = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")

    index_options = {
        "collection": [corpus_text],
        "index_name": "ensemble_colbert",
        "max_document_length": 256,
        "split_documents": True,
    }
    colbert_model.index(**index_options)

    return colbert_model.as_langchain_retriever(k=3)
# Ensemble retriever
def ensemble_retriever(model_name="BAAI/bge-large-en-v1.5", splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000):
    """Combine the ColBERT and Chroma retrievers with equal weights.

    The ColBERT retriever is built first, then a freshly indexed Chroma
    retriever (``flag=False``) using the given embedding model and splitter
    settings.

    Args:
        model_name: Embedding model identifier for the Chroma retriever.
        splitter_type: Forwarded to ``retriever_chroma``.
        chunk_size: Forwarded to ``retriever_chroma``.
        chunk_overlap: Forwarded to ``retriever_chroma``.
        separator: Forwarded to ``retriever_chroma``.
        max_tokens: Forwarded to ``retriever_chroma``.

    Returns:
        An ``EnsembleRetriever`` over both retrievers, weighted 0.5 each.
    """
    chroma_options = dict(
        model_name=model_name,
        splitter_type=splitter_type,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separator=separator,
        max_tokens=max_tokens,
    )

    colbert = colbert_retriever()
    chroma = retriever_chroma(False, **chroma_options)

    return EnsembleRetriever(retrievers=[colbert, chroma], weights=[0.50, 0.50])
# Example usage (the 'token' splitter is sized by max_tokens; chunk_size only affects 'character' and 'recursive'):
# dat = ensemble_retriever(model_name="sentence-transformers/all-mpnet-base-v2", splitter_type='token', max_tokens=500)
# data = dat.invoke("What is KUET?")
# context = ""
# for x in data[:2]:
# context += (x.page_content) + "\n"