Pudding48 commited on
Commit
d29e1bc
·
verified ·
1 Parent(s): 08b53ed

Update qabot.py

Browse files
Files changed (1) hide show
  1. qabot.py +60 -58
qabot.py CHANGED
@@ -1,58 +1,60 @@
1
- from langchain_community.llms import CTransformers
2
- from langchain.prompts import PromptTemplate
3
- from langchain_core.runnables import RunnableSequence
4
- from langchain.chains import RetrievalQA
5
- from langchain_community.embeddings import GPT4AllEmbeddings
6
- from langchain_community.vectorstores import FAISS
7
-
8
- # Cấu hình
9
- model_file = "model/tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
10
- vector_dp_path = "vectorstores/db_faiss"
11
-
12
- # Load LLM
13
- def load_llm(model_file):
14
- llm = CTransformers(
15
- model=model_file,
16
- model_type="llama",
17
- temperature=0.01,
18
- config={'gpu_layers': 0},
19
- max_new_tokens=128,
20
- context_length=512
21
- )
22
- return llm
23
-
24
- # Tạo prompt template
25
- def creat_prompt(template):
26
- prompt = PromptTemplate(template=template, input_variables=["context","question"])
27
- return prompt
28
-
29
- # Tạo pipeline chain (thay cho LLMChain)
30
- def create_qa_chain(prompt, llm, db):
31
- llm_chain = RetrievalQA.from_chain_type(
32
- llm = llm,
33
- chain_type = "stuff",
34
- retriever =db.as_retriever(search_kwargs = {"k":1}),
35
- return_source_documents = False,
36
- chain_type_kwargs={'prompt':prompt}
37
- )
38
- return llm_chain
39
-
40
- def read_vector_db():
41
- embedding_model = GPT4AllEmbeddings(model_file = "model/all-minilm-l6-v2-q4_0.gguf")
42
- db = FAISS.load_local(vector_dp_path, embedding_model,allow_dangerous_deserialization=True)
43
- return db
44
-
45
- db = read_vector_db()
46
- llm = load_llm(model_file)
47
- # Mẫu prompt
48
- template = """<|im_start|>system\nSử dụng thông tin sau đây để trả lời câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời\n
49
- {context}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant"""
50
-
51
- # Khởi tạo các thành phần
52
- prompt = creat_prompt(template)
53
- llm_chain =create_qa_chain(prompt, llm, db)
54
-
55
- # Chạy thử chain
56
- question = "Khoa công nghệ thông tin thành lập năm nào ?"
57
- response = llm_chain.invoke({"query": question})
58
- print(response)
 
 
 
1
+ from langchain_community.llms import CTransformers
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain_core.runnables import RunnableSequence
4
+ from langchain.chains import RetrievalQA
5
+ from langchain_community.embeddings import GPT4AllEmbeddings
6
+ from langchain_community.vectorstores import FAISS
7
+
8
+ # Cấu hình
9
+ model_file = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
10
+ vector_dp_path = "vectorstores/db_faiss"
11
+
12
+ os.makedirs(vector_dp_path, exist_ok=True)
13
+
14
+ # Load LLM
15
+ def load_llm(model_file):
16
+ llm = CTransformers(
17
+ model=model_file,
18
+ model_type="llama",
19
+ temperature=0.01,
20
+ config={'gpu_layers': 0},
21
+ max_new_tokens=128,
22
+ context_length=512
23
+ )
24
+ return llm
25
+
26
+ # Tạo prompt template
27
+ def creat_prompt(template):
28
+ prompt = PromptTemplate(template=template, input_variables=["context","question"])
29
+ return prompt
30
+
31
+ # Tạo pipeline chain (thay cho LLMChain)
32
+ def create_qa_chain(prompt, llm, db):
33
+ llm_chain = RetrievalQA.from_chain_type(
34
+ llm = llm,
35
+ chain_type = "stuff",
36
+ retriever =db.as_retriever(search_kwargs = {"k":1}),
37
+ return_source_documents = False,
38
+ chain_type_kwargs={'prompt':prompt}
39
+ )
40
+ return llm_chain
41
+
42
+ def read_vector_db():
43
+ embedding_model = GPT4AllEmbeddings(model_file = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf")
44
+ db = FAISS.load_local(vector_dp_path, embedding_model,allow_dangerous_deserialization=True)
45
+ return db
46
+
47
+ db = read_vector_db()
48
+ llm = load_llm(model_file)
49
+ # Mẫu prompt
50
+ template = """<|im_start|>system\nSử dụng thông tin sau đây để trả lời câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời\n
51
+ {context}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant"""
52
+
53
+ # Khởi tạo các thành phần
54
+ prompt = creat_prompt(template)
55
+ llm_chain =create_qa_chain(prompt, llm, db)
56
+
57
+ # Chạy thử chain
58
+ question = "Khoa công nghệ thông tin thành lập năm nào ?"
59
+ response = llm_chain.invoke({"query": question})
60
+ print(response)