ShynBui committed on
Commit
7d023d2
·
verified ·
1 Parent(s): 7323674

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +23 -94
utils.py CHANGED
@@ -1,108 +1,37 @@
1
- from langchain_community.document_loaders import TextLoader
2
- from langchain_community.docstore.document import Document
3
- from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
4
- from langchain_community.vectorstores import Chroma
5
- from langchain_community.embeddings import HuggingFaceEmbeddings
6
- from langchain_community.retrievers import BM25Retriever
7
- from langchain.llms import OpenAI
8
- from langchain_openai import ChatOpenAI
9
- from langchain.chains import RetrievalQA
10
  import os
 
 
 
 
 
 
 
11
 
12
def split_with_source(text, source):
    """Split *text* into chunks and tag each chunk with its *source*.

    Uses a newline-separated CharacterTextSplitter (256-char chunks, no
    overlap, start indexes recorded) and writes the originating file name
    into each Document's ``source`` metadata so retrieval results can be
    traced back to their file.

    Args:
        text: raw document text to split.
        source: identifier (file name) stored in each chunk's metadata.

    Returns:
        list of langchain Documents with ``metadata["source"]`` set.
    """
    splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=256,
        chunk_overlap=0,
        length_function=len,
        add_start_index=True,
    )
    documents = splitter.create_documents([text])
    # Tag every chunk so downstream QA can cite where it came from.
    for doc in documents:
        doc.metadata["source"] = source
    return documents
 
 
27
 
28
 
29
def count_files_in_folder(folder_path):
    """Return the number of directory entries inside *folder_path*.

    Prints a warning and returns ``None`` when the path is not an existing
    directory.  NOTE(review): this counts every entry reported by
    os.listdir, including subdirectories — confirm that is intended.
    """
    # Guard clause: bail out early on an invalid directory path.
    if not os.path.isdir(folder_path):
        print("Đường dẫn không hợp lệ.")  # "Invalid path."
        return None
    # os.listdir yields all entries (files and subdirectories alike).
    return len(os.listdir(folder_path))
42
-
43
def get_document_from_raw_text():
    """Load every file under ./raw_data and return its chunked Documents.

    Each file is read as UTF-8, double newlines are collapsed to single
    ones, and the text is chunked via split_with_source() with the file
    name recorded as each chunk's source.

    Returns:
        list of langchain Documents covering the whole raw_data corpus.

    Fix: the original seeded the list with an empty placeholder
    ``Document(page_content="", metadata={'source': 0})``, which ended up
    embedded and retrievable; it is removed here along with debug prints.
    """
    documents = []
    raw_dir = os.path.join(os.getcwd(), "raw_data")
    for file_name in os.listdir(raw_dir):
        with open(os.path.join(raw_dir, file_name), 'r', encoding="utf-8") as file:
            # Light preprocessing: collapse blank lines before chunking.
            content = file.read().replace('\n\n', "\n")
            documents.extend(split_with_source(content, file_name))
    return documents
65
-
66
def load_the_embedding_retrieve(is_ready=False, k=3, model='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
    """Return a Chroma dense (embedding) retriever over the corpus.

    Args:
        is_ready: if True, reopen the persisted Chroma index in ./Data
            instead of re-embedding the raw documents.
        k: number of documents returned per query.
        model: HuggingFace sentence-embedding model name.

    Returns:
        a langchain retriever backed by a Chroma vector store.
    """
    embeddings = HuggingFaceEmbeddings(model_name=model)
    if is_ready:
        # Reuse the index previously persisted to ./Data.
        retriever = Chroma(
            persist_directory=os.path.join(os.getcwd(), "Data"),
            embedding_function=embeddings,
        ).as_retriever(search_kwargs={"k": k})
    else:
        # Embed the freshly loaded raw_data documents in memory.
        documents = get_document_from_raw_text()
        retriever = Chroma.from_documents(documents, embeddings).as_retriever(
            search_kwargs={"k": k}
        )
    return retriever
81
-
82
def load_the_bm25_retrieve(k=3):
    """Return a BM25 (sparse, keyword-based) retriever over raw_data.

    Args:
        k: number of documents the retriever returns per query.
    """
    corpus = get_document_from_raw_text()
    sparse_retriever = BM25Retriever.from_documents(corpus)
    sparse_retriever.k = k
    return sparse_retriever
88
-
89
def get_qachain(llm_name="gpt-3.5-turbo-0125", chain_type="stuff", retriever=None, return_source_documents=True):
    """Build a RetrievalQA chain around a zero-temperature ChatOpenAI LLM.

    Args:
        llm_name: OpenAI chat model name.
        chain_type: langchain combine-documents strategy (e.g. "stuff").
        retriever: retriever supplying context documents.
        return_source_documents: include source chunks in the response.
    """
    chat_model = ChatOpenAI(temperature=0, model_name=llm_name)
    return RetrievalQA.from_chain_type(
        llm=chat_model,
        chain_type=chain_type,
        retriever=retriever,
        return_source_documents=return_source_documents,
    )
96
-
97
def process_llm_response(llm_response):
    """Print the chain's answer, then the source of every cited chunk.

    Args:
        llm_response: RetrievalQA output dict with 'result' and
            'source_documents' keys.
    """
    answer = llm_response['result']
    print(answer)
    print('\n\nSources:')
    cited = (doc.metadata['source'] for doc in llm_response["source_documents"])
    for src in cited:
        print(src)
102
 
 
103
 
 
104
 
 
105
 
 
106
 
107
 
 
 
108
 
 
 
 
1
+ import gradio as gr
 
 
 
 
 
 
 
 
2
  import os
3
+ from langchain.retrievers import EnsembleRetriever
4
+ from utils import *
5
+ import requests
6
+ from pyvi import ViTokenizer, ViPosTagger
7
+ import time
8
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
9
+ import torch
10
 
11
# Build the two retrieval channels once at import time: a dense
# (embedding-based) retriever and a sparse BM25 retriever, each returning
# the top-3 chunks.  NOTE(review): is_ready=False re-embeds the whole
# corpus on every startup — confirm that is intended.
retriever = load_the_embedding_retrieve(is_ready=False, k=3)
bm25_retriever = load_the_bm25_retrieve(k=3)

# Hybrid retrieval: merge both channels with equal weight.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, retriever], weights=[0.5, 0.5]
)
17
 
18
 
 
 
 
 
 
19
 
20
def greet2(quote):
    """Answer *quote* with the hybrid RAG pipeline and return the text.

    Builds a RetrievalQA chain over the module-level ensemble retriever,
    overrides its first chat-message template with the PROMPT environment
    variable (raises KeyError if unset), runs the query, and returns the
    'result' field of the chain's response.

    NOTE(review): the chain is rebuilt on every call — presumably cheap
    relative to the LLM call, but confirm.
    """
    qa_chain = get_qachain(retriever=ensemble_retriever)
    # Inject the deployment-specific system prompt into the chain's
    # prompt template before running the query.
    prompt = os.environ['PROMPT']
    qa_chain.combine_documents_chain.llm_chain.prompt.messages[0].prompt.template = prompt
    llm_response = qa_chain(quote)
    return llm_response['result']
31
 
32
 
33
if __name__ == "__main__":
    # Launch the Gradio UI; greet2 answers each submitted question.
    # Fix: removed an unused sample-query variable that was assigned but
    # never passed anywhere.
    iface = gr.Interface(fn=greet2, inputs="text", outputs="text")
    iface.launch()