RAG_SEG / rag_chain.py
ohaiyo123's picture
Create rag_chain.py
1a4320d verified
from langchain.chains import RetrievalQA
from langchain.llms import LlamaCpp
from retriever import load_db
from huggingface_hub import hf_hub_download
from langchain.document_loaders import PyPDFLoader , DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter , CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain_community.llms import LlamaCpp
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate
)
from langchain.chains import RetrievalQA
# Tải model GGUF từ Hugging Face Hub
MODEL_PATH = hf_hub_download(
repo_id="ohaiyo123/SEG_Llama2Lora",
filename="llama 2 7b hf chat_Lora.gguf", # chính xác với tên file bạn đã upload
cache_dir="model_cache" # nơi lưu tạm trong container
)
# Khởi tạo LLaMA local
llm = LlamaCpp(
model_path= MODEL_PATH,
n_gpu_layers= -1,
n_batch=512,
n_ctx=2048,
f16_kv=True,
temperature=0.01,
callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
verbose=False,
)
def create_db_from_text():
raw_text = """Ngân hàng là một tổ chức tài chính cung cấp các dịch vụ như gửi tiền, cho vay, chuyển khoản và thanh toán. Tại Việt Nam, các ngân hàng thương mại đóng vai trò quan trọng trong việc hỗ trợ doanh nghiệp và cá nhân tiếp cận nguồn vốn.
Một số ngân hàng lớn bao gồm Vietcombank, BIDV, VietinBank và Techcombank."""
text_splitter = CharacterTextSplitter(
separator = "\n",
chunk_size=500,
chunk_overlap=50,
length_function=len
)
chunks = text_splitter.split_text(raw_text)
#embeding
embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
db = FAISS.from_texts(chunks, embbeding_model)
db.save_local(vector_db_path)
return db
def create_db_from_file():
loader = DirectoryLoader(dpf_data_path,
glob="*.pdf",
loader_cls=PyPDFLoader)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
separators=[
"\n\n",
"\n",
" ",
".",
",",
"\u200b", # Zero-width space
"\uff0c", # Fullwidth comma
"\u3001", # Ideographic comma
"\uff0e", # Fullwidth full stop
"\u3002", # Ideographic full stop
"",
],
chunk_size=500,
chunk_overlap=50
)
chunks = text_splitter.split_documents(documents)
#embbeding
embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
db = FAISS.from_documents(chunks,embbeding_model)
db.save_local(vector_db_path)
return db
#load llm
def load_llm(model_file):
llm = LlamaCpp(
model_path= model_file,
n_gpu_layers= -1,
n_batch=512,
n_ctx=2048,
f16_kv=True,
temperature=0.01,
callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
verbose=True,
)
return llm
#cấu trúc prompt
def create_prompt():
system_tm = SystemMessagePromptTemplate.from_template(
"Sử dụng thông tin sau đây để trả lời câu hỏi.\n"
"Nếu bạn không biết câu trả lời thì hãy nói rằng bạn không biết, đừng cố tạo ra câu trả lời\n\n"
"{context}"
)
human_tm = HumanMessagePromptTemplate.from_template("{question}")
return ChatPromptTemplate.from_messages([system_tm, human_tm])
def create_qna_chain(llm,db):
prompt = create_prompt()
llm_chain = RetrievalQA.from_chain_type(llm=llm,
chain_type="stuff",
retriever=db.as_retriever(search_kwargs={"k":3}),
return_source_documents=True,
chain_type_kwargs={"prompt":prompt}
)
return llm_chain
#read from vector_data_base
def read_vectors_db():
#embbeding
embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
db = FAISS.load_local(vector_db_path,embbeding_model, allow_dangerous_deserialization=True)
return db
#test chain
#read vector db
db = read_vectors_db()
#load model
llm = load_llm(model_file)
# gop prompt vao llm
llm_chain = create_qna_chain(llm,db)