| | |
| |
|
| | import threading |
| | from toolbox import Singleton |
| | import os |
| | import shutil |
| | import os |
| | import uuid |
| | import tqdm |
| | from langchain.vectorstores import FAISS |
| | from langchain.docstore.document import Document |
| | from typing import List, Tuple |
| | import numpy as np |
| | from crazy_functions.vector_fns.general_file_loader import load_file |
| |
|
# Mapping from short embedding-model aliases to their HuggingFace model ids.
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec-base": "shibing624/text2vec-base-chinese",
    "text2vec": "GanymedeNil/text2vec-large-chinese",
}

# Default embedding model alias (a key of embedding_model_dict).
EMBEDDING_MODEL = "text2vec"

# Device on which embeddings are computed.
EMBEDDING_DEVICE = "cpu"

# Prompt template (Chinese) used to build a knowledge-grounded question.
# Runtime string with {context} / {question} placeholders — do not translate.
PROMPT_TEMPLATE = """已知信息:
{context} 

根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""

# Maximum sentence length (characters) used when splitting loaded files.
SENTENCE_SIZE = 100

# Maximum characters of neighboring context merged around a matched chunk.
CHUNK_SIZE = 250

# Number of dialogue turns of history kept for the LLM.
LLM_HISTORY_LEN = 3

# Number of top hits returned by vector search.
VECTOR_SEARCH_TOP_K = 5

# Score threshold for vector search; 0 disables threshold filtering.
VECTOR_SEARCH_SCORE_THRESHOLD = 0

# Bundled nltk data directory ("nltk_data" next to this package's parent dir).
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

# Random per-process user flag.
FLAG_USER_NAME = uuid.uuid4().hex

# Whether cross-domain (CORS) access is allowed.
OPEN_CROSS_DOMAIN = False
| |
|
def similarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4
) -> List[Tuple[Document, float]]:
    """FAISS similarity search that optionally merges neighboring chunks.

    Intended to be called as an unbound method with a FAISS vector store as
    `self`; the caller must have set `self.chunk_conent` (merge on/off),
    `self.score_threshold` and `self.chunk_size` beforehand.

    Returns a list of Documents whose metadata["score"] holds the integer
    distance reported by FAISS.
    NOTE(review): with the usual L2 index a lower score means more similar —
    confirm against the actual index metric.
    """

    def seperate_list(ls: List[int]) -> List[List[int]]:
        # Split a sorted list of ints into runs of consecutive values,
        # e.g. [1, 2, 3, 7, 8] -> [[1, 2, 3], [7, 8]].
        lists = []
        ls1 = [ls[0]]
        for i in range(1, len(ls)):
            if ls[i - 1] + 1 == ls[i]:
                ls1.append(ls[i])
            else:
                lists.append(ls1)
                ls1 = [ls[i]]
        lists.append(ls1)
        return lists

    scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
    docs = []
    id_set = set()
    store_len = len(self.index_to_docstore_id)
    for j, i in enumerate(indices[0]):
        # Skip FAISS padding hits (-1) and hits whose score exceeds the
        # threshold (a threshold of 0 disables filtering).
        if i == -1 or 0 < self.score_threshold < scores[0][j]:
            continue
        _id = self.index_to_docstore_id[i]
        doc = self.docstore.search(_id)
        if not self.chunk_conent:
            # Merging disabled: return raw matched documents with scores.
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            doc.metadata["score"] = int(scores[0][j])
            docs.append(doc)
            continue
        id_set.add(i)
        docs_len = len(doc.page_content)
        # Expand symmetrically around hit i, absorbing neighboring chunks
        # from the same source until chunk_size characters are reached.
        # NOTE(review): the loop variable shadows parameter `k`; `k` is not
        # read after the index search above, so this is harmless but confusing.
        for k in range(1, max(i, store_len - i)):
            break_flag = False
            for l in [i + k, i - k]:
                if 0 <= l < len(self.index_to_docstore_id):
                    _id0 = self.index_to_docstore_id[l]
                    doc0 = self.docstore.search(_id0)
                    if docs_len + len(doc0.page_content) > self.chunk_size:
                        break_flag = True
                        break
                    elif doc0.metadata["source"] == doc.metadata["source"]:
                        docs_len += len(doc0.page_content)
                        id_set.add(l)
            if break_flag:
                break
    if not self.chunk_conent:
        return docs
    if len(id_set) == 0 and self.score_threshold > 0:
        # Every hit was filtered out by the score threshold.
        return []
    id_list = sorted(list(id_set))
    id_lists = seperate_list(id_list)
    # Stitch each run of consecutive chunk ids back into one Document.
    for id_seq in id_lists:
        for id in id_seq:
            if id == id_seq[0]:
                _id = self.index_to_docstore_id[id]
                doc = self.docstore.search(_id)
            else:
                _id0 = self.index_to_docstore_id[id]
                doc0 = self.docstore.search(_id0)
                doc.page_content += " " + doc0.page_content
        if not isinstance(doc, Document):
            raise ValueError(f"Could not find document for id {_id}, got {doc}")
        # Score of the merged doc = best (minimum) score among its member
        # chunks that actually appeared in the search hits.
        doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
        doc.metadata["score"] = int(doc_score)
        docs.append(doc)
    return docs
| |
|
| |
|
class LocalDocQA:
    """Minimal local document QA helper.

    Builds or extends a FAISS vector store from files on disk and retrieves
    the chunks most related to a query, formatted into a prompt.
    """
    llm: object = None
    embeddings: object = None
    top_k: int = VECTOR_SEARCH_TOP_K
    chunk_size: int = CHUNK_SIZE
    chunk_conent: bool = True
    score_threshold: int = VECTOR_SEARCH_SCORE_THRESHOLD

    def init_cfg(self,
                 top_k=VECTOR_SEARCH_TOP_K,
                 ):
        """Configure retrieval parameters; no LLM is attached here."""
        self.llm = None
        self.top_k = top_k

    def init_knowledge_vector_store(self,
                                    filepath,
                                    vs_path=None,  # str | os.PathLike; old annotation `str or os.PathLike` evaluated to just `str`
                                    sentence_size=SENTENCE_SIZE,
                                    text2vec=None):
        """Load file(s) and build (or extend) the FAISS vector store.

        filepath: a single file path, a directory path, or a list of paths.
        vs_path: directory where the FAISS index is saved/loaded.
        sentence_size: max sentence length handed to load_file.
        text2vec: embedding function used by FAISS.
        Returns (vs_path, loaded_files) on success, None on load failure.
        Raises RuntimeError when no document could be loaded at all.
        """
        loaded_files = []
        failed_files = []
        if isinstance(filepath, str):
            if not os.path.exists(filepath):
                print("路径不存在")
                return None
            elif os.path.isfile(filepath):
                file = os.path.split(filepath)[-1]
                try:
                    # BUGFIX: honor the sentence_size argument
                    # (was hard-coded to the SENTENCE_SIZE constant).
                    docs = load_file(filepath, sentence_size)
                    print(f"{file} 已成功加载")
                    loaded_files.append(filepath)
                except Exception as e:
                    print(e)
                    print(f"{file} 未能成功加载")
                    return None
            elif os.path.isdir(filepath):
                docs = []
                # BUGFIX: `import tqdm` binds the module; calling it raised
                # TypeError. The progress-bar class is tqdm.tqdm.
                for file in tqdm.tqdm(os.listdir(filepath), desc="加载文件"):
                    fullfilepath = os.path.join(filepath, file)
                    try:
                        docs += load_file(fullfilepath, sentence_size)
                        loaded_files.append(fullfilepath)
                    except Exception as e:
                        print(e)
                        failed_files.append(file)

                if len(failed_files) > 0:
                    print("以下文件未能成功加载:")
                    for file in failed_files:
                        print(f"{file}\n")

        else:
            # filepath is a list of individual file paths.
            docs = []
            for file in filepath:
                docs += load_file(file, sentence_size)
                print(f"{file} 已成功加载")
                loaded_files.append(file)

        if len(docs) > 0:
            print("文件加载完毕,正在生成向量库")
            if vs_path and os.path.isdir(vs_path):
                try:
                    # Extend an existing on-disk index when possible.
                    self.vector_store = FAISS.load_local(vs_path, text2vec)
                    self.vector_store.add_documents(docs)
                except Exception:
                    # BUGFIX: was a bare `except:` (also caught SystemExit /
                    # KeyboardInterrupt). Fall back to a fresh index.
                    self.vector_store = FAISS.from_documents(docs, text2vec)
            else:
                self.vector_store = FAISS.from_documents(docs, text2vec)

            self.vector_store.save_local(vs_path)
            return vs_path, loaded_files
        else:
            raise RuntimeError("文件加载失败,请检查文件格式是否正确")

    def get_loaded_file(self, vs_path):
        """Return the set of source names in the docstore, relative to vs_path."""
        ds = self.vector_store.docstore
        return set([ds._dict[k].metadata['source'].split(vs_path)[-1] for k in ds._dict])

    def get_knowledge_based_conent_test(self, query, vs_path, chunk_conent,
                                        score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
                                        vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE,
                                        text2vec=None):
        """Load the store at vs_path and retrieve chunks related to `query`.

        Returns (response, prompt): response is {"query", "source_documents"};
        prompt is "" when nothing relevant was found.
        """
        self.vector_store = FAISS.load_local(vs_path, text2vec)
        # Parameters consumed by similarity_search_with_score_by_vector.
        self.vector_store.chunk_conent = chunk_conent
        self.vector_store.score_threshold = score_threshold
        self.vector_store.chunk_size = chunk_size

        embedding = self.vector_store.embedding_function.embed_query(query)
        related_docs_with_score = similarity_search_with_score_by_vector(self.vector_store, embedding, k=vector_search_top_k)

        if not related_docs_with_score:
            response = {"query": query,
                        "source_documents": []}
            return response, ""

        prompt = f"{query}. 你必须利用以下文档中包含的信息回答这个问题: \n\n---\n\n"
        prompt += "\n\n".join([f"({k}): " + doc.page_content for k, doc in enumerate(related_docs_with_score)])
        prompt += "\n\n---\n\n"
        # Drop characters that cannot round-trip through utf-8.
        prompt = prompt.encode('utf-8', 'ignore').decode()

        response = {"query": query, "source_documents": related_docs_with_score}
        return response, prompt
| |
|
| |
|
| |
|
| |
|
def construct_vector_store(vs_id, vs_path, files, sentence_size, history, one_conent, one_content_segmentation, text2vec):
    """Copy `files` into the knowledge-base directory vs_path/vs_id and
    build/extend the FAISS vector store for them.

    history / one_conent / one_content_segmentation are accepted only for
    signature compatibility and are not used here.
    Returns (local_doc_qa, vs_path) with vs_path the concrete store dir.
    Raises AssertionError when an input file does not exist.
    """
    for file in files:
        assert os.path.exists(file), "输入文件不存在:" + file
    # Make sure our bundled nltk data directory is searched first.
    import nltk
    if NLTK_DATA_PATH not in nltk.data.path:
        nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg()
    filelist = []
    target_dir = os.path.join(vs_path, vs_id)
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(target_dir, exist_ok=True)
    for file in files:
        # Gradio upload objects carry .name; plain strings are used as-is.
        file_name = file.name if not isinstance(file, str) else file
        filename = os.path.split(file_name)[-1]
        shutil.copyfile(file_name, os.path.join(target_dir, filename))
        filelist.append(os.path.join(target_dir, filename))
    vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, target_dir, sentence_size, text2vec)
    # BUGFIX: removed dead code — a `file_status` message was built from
    # loaded_files but never used or returned.
    return local_doc_qa, vs_path
| |
|
@Singleton
class knowledge_archive_interface():
    """Singleton facade around the local knowledge archive.

    Feeds files into a per-id FAISS vector store and answers queries against
    it. A lock serializes archive builds and queries across threads.
    """

    def __init__(self) -> None:
        self.threadLock = threading.Lock()   # serializes feed/answer operations
        self.current_id = ""                 # id of the currently loaded archive
        self.kai_path = None                 # path of the current vector store
        self.qa_handle = None                # LocalDocQA instance for current_id
        self.text2vec_large_chinese = None   # lazily loaded embedding model

    def get_chinese_text2vec(self):
        """Lazily download/load the text2vec-large-chinese embedding model."""
        if self.text2vec_large_chinese is None:
            from toolbox import ProxyNetworkActivate
            print('Checking Text2vec ...')
            from langchain.embeddings.huggingface import HuggingFaceEmbeddings
            with ProxyNetworkActivate('Download_LLM'):
                self.text2vec_large_chinese = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")

        return self.text2vec_large_chinese

    def feed_archive(self, file_manifest, vs_path, id="default"):
        """Add the files in file_manifest to the archive `id`.

        BUGFIX: uses `with self.threadLock` so the lock is released even when
        construct_vector_store raises (the original acquire/release pair
        leaked the lock on exception, deadlocking later calls).
        """
        with self.threadLock:
            self.current_id = id
            self.qa_handle, self.kai_path = construct_vector_store(
                vs_id=self.current_id,
                vs_path=vs_path,
                files=file_manifest,
                sentence_size=100,
                history=[],
                one_conent="",
                one_content_segmentation="",
                text2vec=self.get_chinese_text2vec(),
            )

    def get_current_archive_id(self):
        """Return the id of the archive currently loaded."""
        return self.current_id

    def get_loaded_file(self, vs_path):
        """Return the set of files loaded into the current archive."""
        return self.qa_handle.get_loaded_file(vs_path)

    def answer_with_archive_by_id(self, txt, id, vs_path):
        """Retrieve chunks related to `txt` from archive `id`; return (resp, prompt).

        BUGFIX: same lock-leak fix as feed_archive — the lock is now released
        on any exception raised while switching archives or querying.
        """
        with self.threadLock:
            if not self.current_id == id:
                # Switch to (and reload) the vector store for this archive id.
                self.current_id = id
                self.qa_handle, self.kai_path = construct_vector_store(
                    vs_id=self.current_id,
                    vs_path=vs_path,
                    files=[],
                    sentence_size=100,
                    history=[],
                    one_conent="",
                    one_content_segmentation="",
                    text2vec=self.get_chinese_text2vec(),
                )
            # Query-time retrieval parameters (locals shadowing the
            # module-level defaults on purpose).
            VECTOR_SEARCH_SCORE_THRESHOLD = 0
            VECTOR_SEARCH_TOP_K = 4
            CHUNK_SIZE = 512
            resp, prompt = self.qa_handle.get_knowledge_based_conent_test(
                query=txt,
                vs_path=self.kai_path,
                score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
                vector_search_top_k=VECTOR_SEARCH_TOP_K,
                chunk_conent=True,
                chunk_size=CHUNK_SIZE,
                text2vec=self.get_chinese_text2vec(),
            )
        return resp, prompt