| | import os |
| | import sys |
| | import re |
| | import json |
| | sys.path.append(os.path.dirname(os.path.dirname(__file__))) |
| | import tempfile |
| | from dotenv import load_dotenv, find_dotenv |
| | from embedding.call_embedding import get_embedding |
| | from langchain.document_loaders import UnstructuredFileLoader |
| | from langchain.document_loaders import UnstructuredMarkdownLoader |
| | from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter |
| | from langchain.document_loaders import PyMuPDFLoader |
| | from langchain.document_loaders import UnstructuredWordDocumentLoader |
| | from langchain.vectorstores import Chroma |
| | from langchain.schema import Document |
| |
|
| | |
| | |
| |
|
| | |
# Local cache directory for downloaded models (<repo-root>/models).
CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")

# Point the HuggingFace cache at the local models directory so downloads land there.
os.environ['HF_HOME'] = CACHE_DIR

# Default locations of the raw knowledge files and the persisted Chroma vector DB.
DEFAULT_DB_PATH = "./knowledge_db/sanguo_characters"
DEFAULT_PERSIST_PATH = "./vector_db/chroma_sanguo"
| |
|
| |
|
| |
|
| |
|
class CharacterTextSplitter(TextSplitter):
    """Text splitter specialised for character JSON data.

    Finds JSON objects embedded in raw text, normalises legacy field names,
    and renders each character as one human-readable text chunk.
    """

    def split_text(self, text: str) -> list[str]:
        """Split *text* into one formatted chunk per embedded character JSON object.

        Objects that fail to parse, lack a ``name`` field, or miss an expected
        attribute are reported to stdout and skipped.
        """
        # Matches a JSON object, tolerating one level of nested objects
        # (e.g. the dicts inside the "skills" list).
        pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'

        chunks = []
        for match in re.finditer(pattern, text):
            raw = match.group()
            try:
                char_data = json.loads(raw)

                # A character entry is unusable without a name.
                if 'name' not in char_data:
                    print(f"警告:发现缺少name字段的JSON数据: {raw[:100]}...")
                    continue

                # Legacy data uses "stamina_cost"; normalise to "endurance_cost".
                if 'skills' in char_data:
                    for skill in char_data['skills']:
                        if 'stamina_cost' in skill:
                            skill['endurance_cost'] = skill.pop('stamina_cost')

                # Render the character as a readable text block.
                char_text = f"角色:{char_data['name']}\n"
                char_text += f"攻击力:{char_data['attack']}\n"
                char_text += f"防御力:{char_data['defense']}\n"
                char_text += f"体力:{char_data['stamina']}\n"
                char_text += f"耐力:{char_data['endurance']}\n"
                char_text += f"法力:{char_data['mana']}\n"
                char_text += f"闪避:{char_data['dodge']}\n"
                char_text += f"速度:{char_data['speed']}\n"
                char_text += "技能:\n"
                for skill in char_data['skills']:
                    char_text += f"- {skill['name']}:{skill['effect']}\n"
                    # Costs are only printed when both are present.
                    if 'endurance_cost' in skill and 'mana_cost' in skill:
                        char_text += f"  耐力消耗:{skill['endurance_cost']},法力消耗:{skill['mana_cost']}\n"
                chunks.append(char_text)
            except json.JSONDecodeError as e:
                print(f"JSON解析错误: {e}")
                print(f"问题数据: {raw[:100]}...")
                continue
            except KeyError as e:
                # A required attribute (attack/defense/...) was missing.
                print(f"缺少字段: {e}")
                print(f"问题数据: {raw[:100]}...")
                continue
        return chunks

    def split_documents(self, documents: list[Document]) -> list[Document]:
        """Split each document's content, carrying its metadata onto every chunk.

        Bug fix: the original called ``split_text`` twice per document (once for
        the texts, once more just to count metadata copies); split once instead.
        """
        result = []
        for doc in documents:
            for chunk in self.split_text(doc.page_content):
                result.append(Document(page_content=chunk, metadata=doc.metadata))
        return result
| |
|
| |
|
def get_files(dir_path):
    """Recursively collect the full paths of every file under *dir_path*."""
    return [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(dir_path)
        for name in names
    ]
| |
|
| |
|
def file_loader(file, loaders):
    """Append an appropriate langchain document loader for *file* to *loaders*.

    Gradio temp-file wrappers are unwrapped to their underlying path, and
    directories are recursed into. Unsupported extensions are ignored, as are
    markdown files whose path matches the exclusion keywords. Returns None;
    works purely by side effect on *loaders*.
    """
    # Gradio hands over tempfile wrapper objects; use the underlying path.
    # NOTE(review): _TemporaryFileWrapper is a private tempfile API — TODO confirm
    # it still exists on the target Python version.
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    if not os.path.isfile(file):
        # Directory: recurse into each entry (plain loop instead of a
        # throwaway list comprehension used only for side effects).
        for entry in os.listdir(file):
            file_loader(os.path.join(file, entry), loaders)
        return
    # splitext is robust where split('.')[-1] was not: a dot-less file named
    # "pdf" no longer counts as a PDF, and dots in parent directories are safe.
    file_type = os.path.splitext(file)[1].lstrip('.').lower()
    if file_type == 'pdf':
        loaders.append(PyMuPDFLoader(file))
    elif file_type == 'md':
        # Skip markdown files whose path matches the exclusion keywords.
        if not re.search(r"不存在|风控", file):
            loaders.append(UnstructuredMarkdownLoader(file))
    elif file_type == 'txt':
        loaders.append(UnstructuredFileLoader(file))
    elif file_type == 'docx':
        loaders.append(UnstructuredWordDocumentLoader(file))
| |
|
| |
|
def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
    """Build the vector DB when *embeddings* names a supported backend.

    Args:
        files: path(s) to the knowledge files.
        embeddings: embedding backend name; only 'openai', 'm3e' and
            'zhipuai' trigger a build, anything else is silently ignored.
        persist_directory: where the Chroma index is persisted.

    Returns:
        An empty string in all cases (UI callback contract).
    """
    # Membership test instead of a chained or-comparison; create_db persists
    # the DB itself, so its return value is deliberately unused here.
    if embeddings in ('openai', 'm3e', 'zhipuai'):
        create_db(files, persist_directory, embeddings)
    return ""
| |
|
| |
|
def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
    """
    Load files, split them into character chunks, embed them, and build a
    persisted Chroma vector database.

    Args:
        files: path (or list of paths) to knowledge files or directories.
        persist_directory: where the Chroma index is persisted on disk.
        embeddings: embedding backend name (str) or a ready embedding object.

    Returns:
        The created Chroma vector store; None when no documents (or no split
        chunks) were found; an error string when *files* is None.
    """
    if files is None:
        return "can't load empty file"
    if not isinstance(files, list):
        files = [files]

    print(f"正在处理文件路径: {files}")

    loaders = []
    # Plain loop: file_loader works purely by side effect on `loaders`.
    for file in files:
        file_loader(file, loaders)
    print(f"找到的加载器数量: {len(loaders)}")

    docs = []
    for loader in loaders:
        if loader is not None:
            loaded_docs = loader.load()
            print(f"\n加载的文档数量: {len(loaded_docs)}")

            if loaded_docs:
                # Show a sample so loading problems are visible in the logs.
                print("\n文档内容示例:")
                print("-" * 50)
                print(loaded_docs[0].page_content[:500])
                print("-" * 50)
                print("\n文档元数据:")
                print(loaded_docs[0].metadata)
                print("-" * 50)
            docs.extend(loaded_docs)

    print(f"\n总文档数量: {len(docs)}")

    if len(docs) == 0:
        print("警告:没有找到任何文档!")
        return None

    text_splitter = CharacterTextSplitter()
    split_docs = text_splitter.split_documents(docs)
    print(f"\n分割后的文档数量: {len(split_docs)}")

    if len(split_docs) == 0:
        print("警告:分割后没有文档!")
        return None

    # Dump the split chunks next to the vector DB for manual inspection.
    split_docs_dir = os.path.join(os.path.dirname(persist_directory), "split_docs")
    os.makedirs(split_docs_dir, exist_ok=True)
    split_docs_file = os.path.join(split_docs_dir, "split_documents.txt")

    with open(split_docs_file, "w", encoding="utf-8") as f:
        for i, doc in enumerate(split_docs, 1):
            f.write(f"\n文档 {i}:\n")
            f.write("-" * 50 + "\n")
            f.write(doc.page_content)
            f.write("\n" + "-" * 50 + "\n")

    print(f"\n分割后的文档已保存到: {split_docs_file}")

    # A string selects a named backend; anything else is assumed to already
    # be an embedding object and used as-is.
    if isinstance(embeddings, str):
        embeddings = get_embedding(embedding=embeddings)

    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory,
        collection_metadata={"hnsw:space": "cosine"}
    )

    vectordb.persist()
    return vectordb
| |
|
| |
|
def presit_knowledge_db(vectordb):
    """Persist the given vector database to disk.

    Args:
        vectordb: the vector store to persist.
    """
    vectordb.persist()
| |
|
| |
|
def load_knowledge_db(path, embeddings):
    """Load a previously persisted Chroma vector database.

    Args:
        path: directory where the vector DB was persisted.
        embeddings: embedding model the DB was built with.

    Returns:
        The loaded Chroma vector store.
    """
    return Chroma(persist_directory=path, embedding_function=embeddings)
| |
|
| |
|
if __name__ == "__main__":
    # Build the vector DB from the default paths using the m3e embedding backend.
    create_db(embeddings="m3e")
| |
|