ohaiyo123 commited on
Commit
1a4320d
·
verified ·
1 Parent(s): 80acdb5

Create rag_chain.py

Browse files
Files changed (1) hide show
  1. rag_chain.py +163 -0
rag_chain.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains import RetrievalQA
2
+ from langchain.llms import LlamaCpp
3
+ from retriever import load_db
4
+ from huggingface_hub import hf_hub_download
5
+ from langchain.document_loaders import PyPDFLoader , DirectoryLoader
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter , CharacterTextSplitter
7
+ from langchain.embeddings import HuggingFaceEmbeddings
8
+ from langchain.vectorstores import FAISS
9
+ from langchain_community.llms import LlamaCpp
10
+ from langchain.chains import LLMChain
11
+ from langchain.prompts import PromptTemplate
12
+ from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
13
+ from langchain.prompts.chat import (
14
+ ChatPromptTemplate,
15
+ SystemMessagePromptTemplate,
16
+ HumanMessagePromptTemplate
17
+ )
18
+ from langchain.chains import RetrievalQA
19
+
20
+ # Tải model GGUF từ Hugging Face Hub
21
+ MODEL_PATH = hf_hub_download(
22
+ repo_id="ohaiyo123/SEG_Llama2Lora",
23
+ filename="llama 2 7b hf chat_Lora.gguf", # chính xác với tên file bạn đã upload
24
+ cache_dir="model_cache" # nơi lưu tạm trong container
25
+ )
26
+
27
+ # Khởi tạo LLaMA local
28
+ llm = LlamaCpp(
29
+ model_path= MODEL_PATH,
30
+ n_gpu_layers= -1,
31
+ n_batch=512,
32
+ n_ctx=2048,
33
+ f16_kv=True,
34
+ temperature=0.01,
35
+ callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
36
+ verbose=False,
37
+ )
38
+
39
+ def create_db_from_text():
40
+ raw_text = """Ngân hàng là một tổ chức tài chính cung cấp các dịch vụ như gửi tiền, cho vay, chuyển khoản và thanh toán. Tại Việt Nam, các ngân hàng thương mại đóng vai trò quan trọng trong việc hỗ trợ doanh nghiệp và cá nhân tiếp cận nguồn vốn.
41
+ Một số ngân hàng lớn bao gồm Vietcombank, BIDV, VietinBank và Techcombank."""
42
+
43
+ text_splitter = CharacterTextSplitter(
44
+ separator = "\n",
45
+ chunk_size=500,
46
+ chunk_overlap=50,
47
+ length_function=len
48
+ )
49
+ chunks = text_splitter.split_text(raw_text)
50
+ #embeding
51
+ embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
52
+ db = FAISS.from_texts(chunks, embbeding_model)
53
+ db.save_local(vector_db_path)
54
+ return db
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+ def create_db_from_file():
63
+ loader = DirectoryLoader(dpf_data_path,
64
+ glob="*.pdf",
65
+ loader_cls=PyPDFLoader)
66
+ documents = loader.load()
67
+
68
+ text_splitter = RecursiveCharacterTextSplitter(
69
+ separators=[
70
+ "\n\n",
71
+ "\n",
72
+ " ",
73
+ ".",
74
+ ",",
75
+ "\u200b", # Zero-width space
76
+ "\uff0c", # Fullwidth comma
77
+ "\u3001", # Ideographic comma
78
+ "\uff0e", # Fullwidth full stop
79
+ "\u3002", # Ideographic full stop
80
+ "",
81
+ ],
82
+ chunk_size=500,
83
+ chunk_overlap=50
84
+ )
85
+ chunks = text_splitter.split_documents(documents)
86
+ #embbeding
87
+ embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
88
+ db = FAISS.from_documents(chunks,embbeding_model)
89
+ db.save_local(vector_db_path)
90
+ return db
91
+
92
+
93
+
94
+
95
+ #load llm
96
+ def load_llm(model_file):
97
+
98
+ llm = LlamaCpp(
99
+ model_path= model_file,
100
+ n_gpu_layers= -1,
101
+ n_batch=512,
102
+ n_ctx=2048,
103
+ f16_kv=True,
104
+ temperature=0.01,
105
+ callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
106
+ verbose=True,
107
+
108
+ )
109
+ return llm
110
+
111
+
112
+
113
+
114
+
115
+ #cấu trúc prompt
116
+ def create_prompt():
117
+ system_tm = SystemMessagePromptTemplate.from_template(
118
+ "Sử dụng thông tin sau đây để trả lời câu hỏi.\n"
119
+ "Nếu bạn không biết câu trả lời thì hãy nói rằng bạn không biết, đừng cố tạo ra câu trả lời\n\n"
120
+ "{context}"
121
+ )
122
+ human_tm = HumanMessagePromptTemplate.from_template("{question}")
123
+ return ChatPromptTemplate.from_messages([system_tm, human_tm])
124
+
125
+
126
+
127
+
128
+
129
+
130
+ def create_qna_chain(llm,db):
131
+ prompt = create_prompt()
132
+ llm_chain = RetrievalQA.from_chain_type(llm=llm,
133
+ chain_type="stuff",
134
+ retriever=db.as_retriever(search_kwargs={"k":3}),
135
+ return_source_documents=True,
136
+ chain_type_kwargs={"prompt":prompt}
137
+ )
138
+ return llm_chain
139
+
140
+
141
+
142
+
143
+
144
+ #read from vector_data_base
145
+ def read_vectors_db():
146
+ #embbeding
147
+ embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
148
+ db = FAISS.load_local(vector_db_path,embbeding_model, allow_dangerous_deserialization=True)
149
+ return db
150
+
151
+
152
+
153
+
154
+ #test chain
155
+ #read vector db
156
+ db = read_vectors_db()
157
+ #load model
158
+ llm = load_llm(model_file)
159
+
160
+
161
+ # gop prompt vao llm
162
+
163
+ llm_chain = create_qna_chain(llm,db)