Create rag_chain.py
Browse files- rag_chain.py +163 -0
rag_chain.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain.chains import RetrievalQA
|
| 2 |
+
from langchain.llms import LlamaCpp
|
| 3 |
+
from retriever import load_db
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
from langchain.document_loaders import PyPDFLoader , DirectoryLoader
|
| 6 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter , CharacterTextSplitter
|
| 7 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 8 |
+
from langchain.vectorstores import FAISS
|
| 9 |
+
from langchain_community.llms import LlamaCpp
|
| 10 |
+
from langchain.chains import LLMChain
|
| 11 |
+
from langchain.prompts import PromptTemplate
|
| 12 |
+
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
|
| 13 |
+
from langchain.prompts.chat import (
|
| 14 |
+
ChatPromptTemplate,
|
| 15 |
+
SystemMessagePromptTemplate,
|
| 16 |
+
HumanMessagePromptTemplate
|
| 17 |
+
)
|
| 18 |
+
from langchain.chains import RetrievalQA
|
| 19 |
+
|
| 20 |
+
# Tải model GGUF từ Hugging Face Hub
|
| 21 |
+
MODEL_PATH = hf_hub_download(
|
| 22 |
+
repo_id="ohaiyo123/SEG_Llama2Lora",
|
| 23 |
+
filename="llama 2 7b hf chat_Lora.gguf", # chính xác với tên file bạn đã upload
|
| 24 |
+
cache_dir="model_cache" # nơi lưu tạm trong container
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Khởi tạo LLaMA local
|
| 28 |
+
llm = LlamaCpp(
|
| 29 |
+
model_path= MODEL_PATH,
|
| 30 |
+
n_gpu_layers= -1,
|
| 31 |
+
n_batch=512,
|
| 32 |
+
n_ctx=2048,
|
| 33 |
+
f16_kv=True,
|
| 34 |
+
temperature=0.01,
|
| 35 |
+
callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
|
| 36 |
+
verbose=False,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def create_db_from_text():
|
| 40 |
+
raw_text = """Ngân hàng là một tổ chức tài chính cung cấp các dịch vụ như gửi tiền, cho vay, chuyển khoản và thanh toán. Tại Việt Nam, các ngân hàng thương mại đóng vai trò quan trọng trong việc hỗ trợ doanh nghiệp và cá nhân tiếp cận nguồn vốn.
|
| 41 |
+
Một số ngân hàng lớn bao gồm Vietcombank, BIDV, VietinBank và Techcombank."""
|
| 42 |
+
|
| 43 |
+
text_splitter = CharacterTextSplitter(
|
| 44 |
+
separator = "\n",
|
| 45 |
+
chunk_size=500,
|
| 46 |
+
chunk_overlap=50,
|
| 47 |
+
length_function=len
|
| 48 |
+
)
|
| 49 |
+
chunks = text_splitter.split_text(raw_text)
|
| 50 |
+
#embeding
|
| 51 |
+
embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
|
| 52 |
+
db = FAISS.from_texts(chunks, embbeding_model)
|
| 53 |
+
db.save_local(vector_db_path)
|
| 54 |
+
return db
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def create_db_from_file():
|
| 63 |
+
loader = DirectoryLoader(dpf_data_path,
|
| 64 |
+
glob="*.pdf",
|
| 65 |
+
loader_cls=PyPDFLoader)
|
| 66 |
+
documents = loader.load()
|
| 67 |
+
|
| 68 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
| 69 |
+
separators=[
|
| 70 |
+
"\n\n",
|
| 71 |
+
"\n",
|
| 72 |
+
" ",
|
| 73 |
+
".",
|
| 74 |
+
",",
|
| 75 |
+
"\u200b", # Zero-width space
|
| 76 |
+
"\uff0c", # Fullwidth comma
|
| 77 |
+
"\u3001", # Ideographic comma
|
| 78 |
+
"\uff0e", # Fullwidth full stop
|
| 79 |
+
"\u3002", # Ideographic full stop
|
| 80 |
+
"",
|
| 81 |
+
],
|
| 82 |
+
chunk_size=500,
|
| 83 |
+
chunk_overlap=50
|
| 84 |
+
)
|
| 85 |
+
chunks = text_splitter.split_documents(documents)
|
| 86 |
+
#embbeding
|
| 87 |
+
embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
|
| 88 |
+
db = FAISS.from_documents(chunks,embbeding_model)
|
| 89 |
+
db.save_local(vector_db_path)
|
| 90 |
+
return db
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
#load llm
|
| 96 |
+
def load_llm(model_file):
|
| 97 |
+
|
| 98 |
+
llm = LlamaCpp(
|
| 99 |
+
model_path= model_file,
|
| 100 |
+
n_gpu_layers= -1,
|
| 101 |
+
n_batch=512,
|
| 102 |
+
n_ctx=2048,
|
| 103 |
+
f16_kv=True,
|
| 104 |
+
temperature=0.01,
|
| 105 |
+
callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
|
| 106 |
+
verbose=True,
|
| 107 |
+
|
| 108 |
+
)
|
| 109 |
+
return llm
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
#cấu trúc prompt
|
| 116 |
+
def create_prompt():
|
| 117 |
+
system_tm = SystemMessagePromptTemplate.from_template(
|
| 118 |
+
"Sử dụng thông tin sau đây để trả lời câu hỏi.\n"
|
| 119 |
+
"Nếu bạn không biết câu trả lời thì hãy nói rằng bạn không biết, đừng cố tạo ra câu trả lời\n\n"
|
| 120 |
+
"{context}"
|
| 121 |
+
)
|
| 122 |
+
human_tm = HumanMessagePromptTemplate.from_template("{question}")
|
| 123 |
+
return ChatPromptTemplate.from_messages([system_tm, human_tm])
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def create_qna_chain(llm,db):
|
| 131 |
+
prompt = create_prompt()
|
| 132 |
+
llm_chain = RetrievalQA.from_chain_type(llm=llm,
|
| 133 |
+
chain_type="stuff",
|
| 134 |
+
retriever=db.as_retriever(search_kwargs={"k":3}),
|
| 135 |
+
return_source_documents=True,
|
| 136 |
+
chain_type_kwargs={"prompt":prompt}
|
| 137 |
+
)
|
| 138 |
+
return llm_chain
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
#read from vector_data_base
|
| 145 |
+
def read_vectors_db():
|
| 146 |
+
#embbeding
|
| 147 |
+
embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
|
| 148 |
+
db = FAISS.load_local(vector_db_path,embbeding_model, allow_dangerous_deserialization=True)
|
| 149 |
+
return db
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
#test chain
|
| 155 |
+
#read vector db
|
| 156 |
+
db = read_vectors_db()
|
| 157 |
+
#load model
|
| 158 |
+
llm = load_llm(model_file)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# gop prompt vao llm
|
| 162 |
+
|
| 163 |
+
llm_chain = create_qna_chain(llm,db)
|