Spaces:
Configuration error
Configuration error
Upload 38 files
Browse files- .gitattributes +2 -0
- Chroma_db/readme.txt +0 -0
- Config/__pycache__/config.cpython-310.pyc +0 -0
- Config/config.py +14 -0
- Faiss_db/readme.txt +0 -0
- Faiss_db/sss1/index.faiss +0 -0
- Faiss_db/sss1/index.pkl +3 -0
- Neo4j/__pycache__/graph_extract.cpython-310.pyc +0 -0
- Neo4j/__pycache__/neo4j_op.cpython-310.pyc +0 -0
- Neo4j/graph_extract.py +69 -0
- Neo4j/neo4j_op.py +105 -0
- Ollama_api/__pycache__/ollama_api.cpython-310.pyc +0 -0
- Ollama_api/ollama_api.py +21 -0
- embeding/__pycache__/asr_utils.cpython-310.pyc +0 -0
- embeding/__pycache__/chromadb.cpython-310.pyc +0 -0
- embeding/__pycache__/elasticsearchStore.cpython-310.pyc +0 -0
- embeding/__pycache__/faissdb.cpython-310.pyc +0 -0
- embeding/asr_utils.py +17 -0
- embeding/chromadb.py +134 -0
- embeding/elasticsearchStore.py +147 -0
- embeding/faissdb.py +138 -0
- embeding/tmp.txt +2 -0
- img/graph-tool.png +3 -0
- img/readme.txt +1 -0
- img/zhu.png +3 -0
- img/zhuye.png +0 -0
- img//345/244/215/346/235/202/346/226/271/345/274/217.png +0 -0
- img//345/276/256/344/277/241/345/233/276/347/211/207_20240524180648.jpg +0 -0
- rag/__init__.py +0 -0
- rag/__pycache__/__init__.cpython-310.pyc +0 -0
- rag/__pycache__/config.cpython-310.pyc +0 -0
- rag/__pycache__/rag_class.cpython-310.pyc +0 -0
- rag/__pycache__/rerank.cpython-310.pyc +0 -0
- rag/__pycache__/rerank.cpython-39.pyc +0 -0
- rag/__pycache__/rerank_code.cpython-310.pyc +0 -0
- rag/rag_class.py +169 -0
- rag/rerank_code.py +21 -0
- test/__init__.py +0 -0
- test/graph2neo4j.py +25 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
img/graph-tool.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
img/zhu.png filter=lfs diff=lfs merge=lfs -text
|
Chroma_db/readme.txt
ADDED
|
File without changes
|
Config/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (362 Bytes). View file
|
|
|
Config/config.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Vector database backend selector:
#   1 -> Chroma, 2 -> FAISS, 3 -> ElasticsearchStore
VECTOR_DB = 2

# Persist directory (or the "es" marker) for the chosen backend; any other
# value of VECTOR_DB falls back to the Chroma directory, as before.
_DB_DIRECTORIES = {2: "./Faiss_db/", 3: "es"}
DB_directory = _DB_DIRECTORIES.get(VECTOR_DB, "./Chroma_db/")

# Neo4j connection settings.
neo4j_host = "bolt://localhost:7687"
neo4j_name = "neo4j"
neo4j_pwd = "12345678"
# Tested llama3:8b, gemma2:9b, qwen2:7b, glm4:9b, arcee-ai/arcee-agent:latest;
# qwen2:7b currently gives the best extraction results.
neo4j_model = "qwen2:7b"
Faiss_db/readme.txt
ADDED
|
File without changes
|
Faiss_db/sss1/index.faiss
ADDED
|
Binary file (82 kB). View file
|
|
|
Faiss_db/sss1/index.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2bb588f4bd46218f42b045c42163bdcf3cc76a19e37458823ceaeaf8a1454e3b
|
| 3 |
+
size 9362
|
Neo4j/__pycache__/graph_extract.cpython-310.pyc
ADDED
|
Binary file (2.51 kB). View file
|
|
|
Neo4j/__pycache__/neo4j_op.cpython-310.pyc
ADDED
|
Binary file (3.89 kB). View file
|
|
|
Neo4j/graph_extract.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain_community.llms import Ollama
from Config.config import neo4j_model

# Tested llama3:8b, gemma2:9b, qwen2:7b, glm4:9b, arcee-ai/arcee-agent:latest;
# qwen2:7b currently works best for this extraction task.
llm = Ollama(model=neo4j_model)

# Template JSON embedded into the prompt below so the model mirrors its shape.
json_example = {'edges': [
    {
        'label': 'label 1',
        'source': 'source 1',
        'target': 'target 1'},
    {
        'label': 'label 1',
        'source': 'source 1',
        'target': 'target 1'}
],
    'nodes': [{'name': 'label 1'},
              {'name': 'label 2'},
              {'name': 'label 3'}]
}

# System prompt sent to the LLM (Chinese, kept verbatim — it is runtime data):
# instructs the model to emit a knowledge-graph JSON with `nodes`/`edges`
# shaped like json_example, with no markdown code fences around it.
__retriever_prompt = f"""
您是一名专门从事知识图谱创建的人工智能专家,目标是根据给定的输入或请求捕获关系。
基于各种形式的用户输入,如段落、电子邮件、文本文件等。
你的任务是根据输入创建一个知识图谱。
nodes中每个元素只有一个name参数,name对应的值是一个实体,实体来自输入的词语或短语。
edges还必须有一个label参数,其中label是输入中的直接词语或短语,edges中的source和target取自nodes中的name。

仅使用JSON进行响应,其格式可以在python中进行jsonify,并直接输入cy.add(data),
您可以参考给定的示例:{json_example}。存储node和edge的数组中,最后一个元素后边不要有逗号,
确保边的目标和源与现有节点匹配。
不要在JSON的上方和下方包含markdown三引号,直接用花括号括起来。
"""
def generate_graph_info(raw_text: str) -> str | None:
    """Ask the LLM to extract a knowledge graph from *raw_text*.

    Retries up to three times when the model returns a suspiciously short
    answer (fewer than 10 characters); the last answer obtained is printed
    and returned either way.
    """
    system_msg = {"role": "system",
                  "content": "你现在扮演信息抽取的角色,要求根据用户输入和AI的回答,正确提取出信息,记得不多对实体进行翻译。"}
    messages = [system_msg,
                {"role": "user", "content": raw_text},
                {"role": "user", "content": __retriever_prompt}]
    print("解析中....")
    attempt = 0
    while attempt < 3:
        graph_info_result = llm.invoke(messages)
        if len(graph_info_result) >= 10:
            break
        # Too short to be a usable graph — log the attempt index and retry.
        print("-------", attempt, "-------------------")
        attempt += 1
    print(graph_info_result)
    return graph_info_result
| 58 |
+
|
| 59 |
+
def update_graph(raw_text):
    """Extract a knowledge graph from *raw_text* via the LLM and parse it.

    Returns the parsed graph dict ({'nodes': [...], 'edges': [...]}) on
    success, or an {'error': ...} dict when the model output cannot be
    parsed.
    """
    import ast  # local import: only needed here, keeps module imports untouched

    try:
        result = generate_graph_info(raw_text)
        if '```' in result:
            # Strip a markdown code fence the model sometimes adds despite the
            # prompt, plus a possible "json" language tag on the fence line.
            payload = result.split('```', 2)[1].replace("json", '')
        else:
            payload = str(result)
        # SECURITY fix: the previous implementation ran eval() on raw LLM
        # output, which can execute arbitrary code.  ast.literal_eval only
        # accepts Python literals (dicts/lists/strings/numbers), which is
        # exactly the shape the prompt requests, and raises on anything else.
        graph_data = ast.literal_eval(payload)
        return graph_data
    except Exception as e:
        return {'error': f"Error parsing graph data: {str(e)}"}
Neo4j/neo4j_op.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from py2neo import Graph, Node, Relationship
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, \
    UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter


class KnowledgeGraph:
    """Thin CRUD wrapper around a Neo4j database (via py2neo), plus document
    loading/splitting helpers used to feed text into graph extraction."""

    def __init__(self, uri, user, password):
        self.graph = Graph(uri, auth=(user, password))

    def parse_data(self, file):
        """Load *file* into langchain Documents, choosing a loader by name.

        Bug fix: ``data`` used to be unbound (NameError) when the file name
        matched none of the recognised types; it now defaults to [].
        """
        data = []
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except Exception:
                # Not parseable as CSV — fall back to a plain text loader.
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        return data

    # Split the given files into overlapping text chunks.
    def split_files(self, files, chunk_size=500, chunk_overlap=100):
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始创建数据库 ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = text_splitter.split_documents(tmps)

        return splits

    def create_node(self, label, properties):
        """Return an existing node matching (label, properties), creating it
        if absent — i.e. a get-or-create."""
        matcher = self.graph.nodes.match(label, **properties)
        if matcher.first():
            return matcher.first()
        node = Node(label, **properties)
        self.graph.create(node)
        return node

    def create_relationship(self, label1, properties1, label2, properties2, relationship_type,
                            relationship_properties=None):
        """Get-or-create a relationship between two (get-or-created) nodes.

        Bug fix: the default for ``relationship_properties`` was a mutable
        ``{}`` shared across calls; it is now ``None`` with the same effect.
        """
        if relationship_properties is None:
            relationship_properties = {}
        node1 = self.create_node(label1, properties1)
        node2 = self.create_node(label2, properties2)

        # Reuse an existing relationship whose properties all match.
        matcher = self.graph.match((node1, node2), r_type=relationship_type)
        for rel in matcher:
            if all(rel[key] == value for key, value in relationship_properties.items()):
                return rel

        relationship = Relationship(node1, relationship_type, node2, **relationship_properties)
        self.graph.create(relationship)
        return relationship

    def delete_node(self, label, properties):
        """Delete the first node matching (label, properties); True if deleted."""
        matcher = self.graph.nodes.match(label, **properties)
        node = matcher.first()
        if node:
            self.graph.delete(node)
            return True
        return False

    def update_node(self, label, identifier, updates):
        """Apply *updates* to the first node matching (label, identifier);
        returns the node, or None when no match exists."""
        matcher = self.graph.nodes.match(label, **identifier)
        node = matcher.first()
        if node:
            for key, value in updates.items():
                node[key] = value
            self.graph.push(node)
            return node
        return None

    def find_node(self, label, properties):
        """Return all nodes matching (label, properties) as a list."""
        matcher = self.graph.nodes.match(label, **properties)
        return list(matcher)

    def create_nodes(self, label, properties_list):
        """Get-or-create one node per properties dict; returns them in order."""
        nodes = []
        for properties in properties_list:
            node = self.create_node(label, properties)
            nodes.append(node)
        return nodes

    def create_relationships(self, relationships):
        """Get-or-create each (label1, props1, label2, props2, type) tuple."""
        created_relationships = []
        for rel in relationships:
            label1, properties1, label2, properties2, relationship_type = rel
            relationship = self.create_relationship(label1, properties1, label2, properties2, relationship_type)
            created_relationships.append(relationship)
        return created_relationships
Ollama_api/__pycache__/ollama_api.cpython-310.pyc
ADDED
|
Binary file (721 Bytes). View file
|
|
|
Ollama_api/ollama_api.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import requests
import json

# Ollama's local REST endpoint listing the installed models.
_OLLAMA_TAGS_URL = "http://localhost:11434/api/tags"


def _model_names(predicate):
    """Fetch the local Ollama model list and return names where *predicate*
    (a callable taking the model name) is true.

    Factored out of get_llm/get_embeding_model, which previously duplicated
    the request/parse/filter loop (with a `respone` typo).
    """
    response = requests.get(url=_OLLAMA_TAGS_URL)
    result = json.loads(response.content)
    return [m["name"] for m in result["models"] if predicate(m["name"])]


def get_llm():
    """Return chat-capable model names (excludes code and embedding models)."""
    return _model_names(lambda name: "code" not in name and "embed" not in name)


def get_embeding_model():
    """Return the names of installed embedding models."""
    return _model_names(lambda name: "embed" in name)
embeding/__pycache__/asr_utils.cpython-310.pyc
ADDED
|
Binary file (634 Bytes). View file
|
|
|
embeding/__pycache__/chromadb.cpython-310.pyc
ADDED
|
Binary file (3.91 kB). View file
|
|
|
embeding/__pycache__/elasticsearchStore.cpython-310.pyc
ADDED
|
Binary file (4.18 kB). View file
|
|
|
embeding/__pycache__/faissdb.cpython-310.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
embeding/asr_utils.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#coding:utf-8
from funasr import AutoModel

# paraformer-zh is a multi-functional ASR model; enable the vad / punc
# (and optionally spk) sub-models as needed.  NOTE: the model is loaded at
# import time, so importing this module is expensive.
model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc",
                  # spk_model="cam++"
                  )


def get_spk_txt(file):
    """Transcribe the audio *file* and write the text to embeding/tmp.txt.

    Returns the path of the written file so callers can feed it to a
    document loader.
    """
    res = model.generate(input=file,
                         batch_size_s=300,
                         hotword='魔搭')
    print(res[0]["text"])
    fw = "embeding/tmp.txt"
    # Bug fix: use a context manager so the handle is closed even if the
    # write raises (previously bare open()/close() with no try/finally).
    with open(fw, "w", encoding="utf-8") as f:
        f.write('"context"\n' + res[0]["text"])
    return fw
embeding/chromadb.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, \
    UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .asr_utils import get_spk_txt


class ChromaDB():
    """Chroma-backed vector store manager: create, extend, inspect and delete
    collections built from text/office/PDF/markdown/audio files."""

    def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Chroma_db/"):
        self.embedding = OllamaEmbeddings(model=embedding)
        self.persist_directory = persist_directory
        self.chromadb = Chroma(persist_directory=persist_directory)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)

    def parse_data(self, file):
        """Load *file* into langchain Documents, choosing a loader by name.

        Bug fix: ``data`` used to be unbound (NameError) when the file name
        matched none of the recognised types; it now defaults to [].
        """
        data = []
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except Exception:
                # Not parseable as CSV — fall back to a plain text loader.
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to a temp text file, then load that file,
            # re-pointing each chunk's source metadata at the original audio.
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
            tmp = []
            for i in data:
                i.metadata["source"] = file
                tmp.append(i)
            data = tmp
        return data

    # Create a new collection *c_name* and populate it from *files*.
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始创建数据库 ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = self.chromadb.from_documents(documents=splits, collection_name=c_name,
                                                   embedding=self.embedding, persist_directory=self.persist_directory)
        print("数据块总量:", vectorstore._collection.count())

        return vectorstore

    # Append *files* to the existing collection *c_name*.
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始添加文件...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = Chroma(persist_directory=self.persist_directory, collection_name=c_name,
                             embedding_function=self.embedding)
        vectorstore.add_documents(splits)
        print("数据块总量:", vectorstore._collection.count())

        return vectorstore

    # Delete every chunk whose source matches one of *del_files_name*.
    def del_files(self, del_files_name, c_name):
        vectorstore = self.chromadb._client.get_collection(c_name)
        del_ids = []
        vec_dict = vectorstore.get()
        for id, md in zip(vec_dict["ids"], vec_dict["metadatas"]):
            for dl in del_files_name:
                if dl in md["source"]:
                    del_ids.append(id)
        vectorstore.delete(ids=del_ids)
        print("数据块总量:", vectorstore.count())

        return vectorstore

    # Drop the whole collection *c_name*.
    def delete_collection(self, c_name):
        self.chromadb._client.delete_collection(c_name)

    # List the names of all existing collections.
    def get_all_collections_name(self):
        cl_names = []
        test = self.chromadb._client.list_collections()
        for i in range(len(test)):
            cl_names.append(test[i].name)
        return cl_names

    # List the distinct source files stored in collection *c_name*.
    def get_collcetion_content_files(self, c_name):
        vectorstore = self.chromadb._client.get_collection(c_name)
        c_files = []
        vec_dict = vectorstore.get()
        for md in vec_dict["metadatas"]:
            c_files.append(md["source"])
        return list(set(c_files))
embeding/elasticsearchStore.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from elasticsearch import Elasticsearch
from langchain_elasticsearch.vectorstores import ElasticsearchStore
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, \
    UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .asr_utils import get_spk_txt
import requests


class ElsStore():
    """Elasticsearch-backed vector store manager exposing the same interface
    as ChromaDB / FaissDB (create/extend/delete an index of document chunks)."""

    def __init__(self, embedding="mofanke/acge_text_embedding:latest", es_url="http://localhost:9200",
                 index_name='test_index'):
        self.embedding = OllamaEmbeddings(model=embedding)
        self.es_url = es_url
        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=index_name,
            embedding=self.embedding
        )

    def parse_data(self, file):
        """Load *file* into langchain Documents, choosing a loader by name.

        Bug fix: ``data`` used to be unbound (NameError) when the file name
        matched none of the recognised types; it now defaults to [].
        """
        data = []
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except Exception:
                # Not parseable as CSV — fall back to a plain text loader.
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to a temp text file, then load that file,
            # re-pointing each chunk's source metadata at the original audio.
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
            tmp = []
            for i in data:
                i.metadata["source"] = file
                tmp.append(i)
            data = tmp
        return data

    def get_count(self, c_name):
        """Return the number of documents stored in index *c_name*."""
        # Build a low-level client from the configured URL,
        # e.g. "http://localhost:9200" -> host "localhost", port 9200.
        es = Elasticsearch([{
            'host': self.es_url.split(":")[1][2:],
            'port': int(self.es_url.split(":")[2]),
            'scheme': 'http'  # protocol to use
        }])

        response = es.count(index=c_name)
        return response['count']

    # Create a new index *c_name* and populate it from *files*.
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始创建数据库 ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        self.elastic_vector_search = ElasticsearchStore.from_documents(
            documents=splits,
            embedding=self.embedding,
            es_url=self.es_url,
            index_name=c_name,
        )

        # Refresh so the freshly indexed documents are immediately visible.
        self.elastic_vector_search.client.indices.refresh(index=c_name)

        print("数据块总量:", self.get_count(c_name))

        return self.elastic_vector_search

    # Append *files* to the existing index *c_name*.
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始添加文件...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=c_name,
            embedding=self.embedding
        )
        self.elastic_vector_search.add_documents(splits)
        self.elastic_vector_search.client.indices.refresh(index=c_name)
        print("数据块总量:", self.get_count(c_name))

        return self.elastic_vector_search

    def delete_collection(self, c_name):
        """Delete index *c_name* via the ES REST API; returns a status message.

        Bug fix: the status messages used to hard-code 'test-basic1' instead
        of the index actually being deleted.
        """
        url = self.es_url + "/" + c_name
        response = requests.delete(url)

        if response.status_code == 200:
            return f"索引 '{c_name}' 已成功删除。"
        elif response.status_code == 404:
            return f"索引 '{c_name}' 不存在。"
        else:
            return f"删除索引时出错: {response.status_code}, {response.text}"

    # List all existing index names.
    def get_all_collections_name(self):
        indices = self.elastic_vector_search.client.indices.get_alias()
        index_names = list(indices.keys())

        return index_names

    def get_collcetion_content_files(self, c_name):
        # Not implemented for Elasticsearch yet; kept for interface parity.
        return []

    # Delete given files from an index.
    def del_files(self, del_files_name, c_name):
        # Not implemented for Elasticsearch yet; kept for interface parity.
        return None
embeding/faissdb.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.vectorstores import FAISS
|
| 2 |
+
from langchain_community.embeddings import OllamaEmbeddings
|
| 3 |
+
from langchain_community.document_loaders import TextLoader,UnstructuredCSVLoader, UnstructuredPDFLoader,UnstructuredWordDocumentLoader,UnstructuredExcelLoader,UnstructuredMarkdownLoader
|
| 4 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 5 |
+
import shutil
|
| 6 |
+
import os
|
| 7 |
+
from .asr_utils import get_spk_txt
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FaissDB():
|
| 11 |
+
def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Faiss_db/"):
|
| 12 |
+
|
| 13 |
+
self.embedding = OllamaEmbeddings(model=embedding)
|
| 14 |
+
self.persist_directory = persist_directory
|
| 15 |
+
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50, add_start_index=True)
|
| 16 |
+
|
| 17 |
+
def parse_data(self,file):
|
| 18 |
+
if "txt" in file.lower() or "csv" in file.lower():
|
| 19 |
+
try:
|
| 20 |
+
loaders = UnstructuredCSVLoader(file)
|
| 21 |
+
data = loaders.load()
|
| 22 |
+
except:
|
| 23 |
+
loaders = TextLoader(file,encoding="utf-8")
|
| 24 |
+
data = loaders.load()
|
| 25 |
+
if ".doc" in file.lower() or ".docx" in file.lower():
|
| 26 |
+
loaders = UnstructuredWordDocumentLoader(file)
|
| 27 |
+
data = loaders.load()
|
| 28 |
+
if "pdf" in file.lower():
|
| 29 |
+
loaders = UnstructuredPDFLoader(file)
|
| 30 |
+
data = loaders.load()
|
| 31 |
+
if ".xlsx" in file.lower():
|
| 32 |
+
loaders = UnstructuredExcelLoader(file)
|
| 33 |
+
data = loaders.load()
|
| 34 |
+
if ".md" in file.lower():
|
| 35 |
+
loaders = UnstructuredMarkdownLoader(file)
|
| 36 |
+
data = loaders.load()
|
| 37 |
+
if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
|
| 38 |
+
# 语音解析成文字
|
| 39 |
+
fw = get_spk_txt(file)
|
| 40 |
+
loaders = UnstructuredCSVLoader(fw)
|
| 41 |
+
data = loaders.load()
|
| 42 |
+
tmp = []
|
| 43 |
+
for i in data:
|
| 44 |
+
i.metadata["source"] = file
|
| 45 |
+
tmp.append(i)
|
| 46 |
+
data = tmp
|
| 47 |
+
return data
|
| 48 |
+
|
| 49 |
+
# 创建 新的collection 并且初始化
|
| 50 |
+
def create_collection(self, files, c_name,chunk_size=200, chunk_overlap=50):
|
| 51 |
+
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
| 52 |
+
print("开始创建数据库 ....")
|
| 53 |
+
tmps = []
|
| 54 |
+
for file in files:
|
| 55 |
+
data = self.parse_data(file)
|
| 56 |
+
tmps.extend(data)
|
| 57 |
+
|
| 58 |
+
splits = self.text_splitter.split_documents(tmps)
|
| 59 |
+
|
| 60 |
+
vectorstore = FAISS.from_documents(documents=splits,
|
| 61 |
+
embedding=self.embedding)
|
| 62 |
+
vectorstore.save_local(self.persist_directory + c_name)
|
| 63 |
+
print("数据块总量:", vectorstore.index.ntotal)
|
| 64 |
+
|
| 65 |
+
return vectorstore
|
| 66 |
+
|
| 67 |
+
# 添加 数据到已有数据库
|
| 68 |
+
def add_chroma(self, files, c_name,chunk_size=200, chunk_overlap=50):
|
| 69 |
+
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
| 70 |
+
print("开始添加文件...")
|
| 71 |
+
tmps = []
|
| 72 |
+
for file in files:
|
| 73 |
+
data = self.parse_data(file)
|
| 74 |
+
tmps.extend(data)
|
| 75 |
+
|
| 76 |
+
splits = self.text_splitter.split_documents(tmps)
|
| 77 |
+
|
| 78 |
+
vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
|
| 79 |
+
allow_dangerous_deserialization=True)
|
| 80 |
+
vectorstore.add_documents(documents=splits)
|
| 81 |
+
vectorstore.save_local("Faiss_db/" + c_name)
|
| 82 |
+
print("数据块总量:", vectorstore.index.ntotal)
|
| 83 |
+
|
| 84 |
+
return vectorstore
|
| 85 |
+
|
| 86 |
+
# 删除 某个collection中的 某个文件
|
| 87 |
+
def del_files(self, del_files_name, c_name):
|
| 88 |
+
|
| 89 |
+
vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
|
| 90 |
+
allow_dangerous_deserialization=True)
|
| 91 |
+
del_ids = []
|
| 92 |
+
vec_dict = vectorstore.docstore._dict
|
| 93 |
+
for id, md in vec_dict.items():
|
| 94 |
+
for dl in del_files_name:
|
| 95 |
+
if dl in md.metadata["source"]:
|
| 96 |
+
del_ids.append(id)
|
| 97 |
+
vectorstore.delete(ids=del_ids)
|
| 98 |
+
vectorstore.save_local(self.persist_directory + c_name)
|
| 99 |
+
print("数据块总量:", vectorstore.index.ntotal)
|
| 100 |
+
|
| 101 |
+
return vectorstore
|
| 102 |
+
|
| 103 |
+
# 删除某个 知识库 collection
|
| 104 |
+
# 删除某个 知识库 collection
def delete_collection(self, c_name):
    """Remove an entire knowledge-base collection directory from disk."""
    target_dir = self.persist_directory + c_name
    shutil.rmtree(target_dir)
|
| 106 |
+
|
| 107 |
+
# 获取目前所有 collection
|
| 108 |
+
# 获取目前所有 collection
def get_all_collections_name(self):
    """List the names of every collection (sub-directory) currently stored
    under the persist directory."""
    root = self.persist_directory
    names = []
    for entry in os.listdir(root):
        # Only directories are collections; loose files are ignored.
        if os.path.isdir(root + entry):
            names.append(entry)

    return names
|
| 112 |
+
|
| 113 |
+
# 获取 collection中的所有文件
|
| 114 |
+
# 获取 collection中的所有文件
def get_collcetion_content_files(self, c_name):
    """Return the deduplicated source-file names stored in a collection."""
    vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
                                   allow_dangerous_deserialization=True)
    # Every stored chunk carries its originating file in metadata["source"];
    # a set comprehension collapses duplicates in one pass.
    sources = {doc.metadata["source"] for doc in vectorstore.docstore._dict.values()}

    return list(sources)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# if __name__ == "__main__":
|
| 126 |
+
# chromadb = FaissDB()
|
| 127 |
+
# c_name = "sss3"
|
| 128 |
+
#
|
| 129 |
+
# print(chromadb.get_all_collections_name())
|
| 130 |
+
# chromadb.create_collection(["data/jl.txt", "data/jl.pdf"], c_name=c_name)
|
| 131 |
+
# print(chromadb.get_all_collections_name())
|
| 132 |
+
# chromadb.add_chroma(["data/tmp.txt"], c_name=c_name)
|
| 133 |
+
# print(c_name, "包含的文件:", chromadb.get_collcetion_content_files(c_name))
|
| 134 |
+
# chromadb.del_files(["data/tmp.txt"], c_name=c_name)
|
| 135 |
+
# print(c_name, "包含的文件:", chromadb.get_collcetion_content_files(c_name))
|
| 136 |
+
# print(chromadb.get_all_collections_name())
|
| 137 |
+
# chromadb.delete_collection(c_name=c_name)
|
| 138 |
+
# print(chromadb.get_all_collections_name())
|
embeding/tmp.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"context"
|
| 2 |
+
你是不是觉得自己说话的声音直来直去呢?现在告诉你一个主持人吐字的小秘密,那就是每个字在口腔当中像是翻跟头一样打一圈再出来。比如说故人西辞黄鹤楼,而不是故人西辞黄鹤楼。再比如说乌衣巷口夕阳斜,而不是乌衣巷口夕阳斜,你也试试看抖音。
|
img/graph-tool.png
ADDED
|
Git LFS Details
|
img/readme.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
1
|
img/zhu.png
ADDED
|
Git LFS Details
|
img/zhuye.png
ADDED
|
img//345/244/215/346/235/202/346/226/271/345/274/217.png
ADDED
|
img//345/276/256/344/277/241/345/233/276/347/211/207_20240524180648.jpg
ADDED
|
rag/__init__.py
ADDED
|
File without changes
|
rag/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (132 Bytes). View file
|
|
|
rag/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (364 Bytes). View file
|
|
|
rag/__pycache__/rag_class.cpython-310.pyc
ADDED
|
Binary file (5.39 kB). View file
|
|
|
rag/__pycache__/rerank.cpython-310.pyc
ADDED
|
Binary file (878 Bytes). View file
|
|
|
rag/__pycache__/rerank.cpython-39.pyc
ADDED
|
Binary file (869 Bytes). View file
|
|
|
rag/__pycache__/rerank_code.cpython-310.pyc
ADDED
|
Binary file (883 Bytes). View file
|
|
|
rag/rag_class.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.vectorstores import Chroma,FAISS
|
| 2 |
+
from langchain_community.llms import Ollama
|
| 3 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 4 |
+
from langchain_community.embeddings import OllamaEmbeddings
|
| 5 |
+
from langchain_core.runnables import RunnablePassthrough
|
| 6 |
+
from operator import itemgetter
|
| 7 |
+
from langchain.prompts import ChatPromptTemplate
|
| 8 |
+
from rerank_code import rerank_topn
|
| 9 |
+
from Config.config import VECTOR_DB,DB_directory
|
| 10 |
+
from langchain_elasticsearch.vectorstores import ElasticsearchStore
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RAG_class:
    """Retrieval-augmented QA helper.

    Wraps an Ollama chat model, an Ollama embedding model and one of three
    vector stores (Chroma / FAISS / Elasticsearch, chosen via the
    ``VECTOR_DB`` config flag), and exposes several answering strategies:
    plain retrieval, rerank-then-answer, and recursive multi-question
    decomposition.
    """

    def __init__(self, model="qwen2:7b", embed="milkey/dmeta-embedding-zh:f16", c_name="sss1",
                 persist_directory="E:/pycode/jupyter_code/langGraph/sss2/chroma.sqlite3/", es_url="http://localhost:9200"):
        # Prompt for the single-shot / rerank QA chains.
        template = """
        根据上下文回答以下问题,不要自己发挥,要根据以下参考内容总结答案,如果以下内容无法得到答案,就返回无法根据参考内容获取答案,

        参考内容为:{context}

        问题: {question}
        """
        self.prompts = ChatPromptTemplate.from_template(template)

        # 使用 问题扩展+结果递归方式得到最终答案
        # Prompt that expands one user question into several sub-questions.
        template1 = """你是一个乐于助人的助手,可以生成与输入问题相关的多个子问题。
        目标是将输入分解为一组可以单独回答的子问题/子问题。
        生成多个与以下内容相关的搜索查询:{question}
        输出4个相关问题,以换行符隔开:"""
        self.prompt_questions = ChatPromptTemplate.from_template(template1)

        # 构建 问答对
        # Prompt that answers a question given accumulated Q/A pairs + context.
        template2 = """
        以下是您需要回答的问题:

        \n--\n {question} \n---\n

        以下是任何可用的背景问答对:

        \n--\n {q_a_pairs} \n---\n

        以下是与该问题相关的其他上下文:

        \n--\n {context} \n---\n

        使用以上上下文和背景问答对来回答问题,问题是:{question} ,答案是:
        """
        self.decomposition_prompt = ChatPromptTemplate.from_template(template2)

        self.llm = Ollama(model=model)
        self.embeding = OllamaEmbeddings(model=embed)

        # BUGFIX: the vector store used to be initialised twice — once
        # unconditionally (raising uncaught if the DB is absent, defeating the
        # try/except below) and once more inside the try.  Initialise it only
        # once, inside the try, so "model-only" usage still works.
        try:
            if VECTOR_DB == 1:
                self.vectstore = Chroma(embedding_function=self.embeding, collection_name=c_name,
                                        persist_directory=persist_directory)
            elif VECTOR_DB == 2:
                self.vectstore = FAISS.load_local(folder_path=persist_directory + c_name,
                                                  embeddings=self.embeding,
                                                  allow_dangerous_deserialization=True)
            elif VECTOR_DB == 3:
                self.vectstore = ElasticsearchStore(
                    es_url=es_url,
                    index_name=c_name,
                    embedding=self.embeding
                )
            self.retriever = self.vectstore.as_retriever()
        except Exception as e:
            print("仅模型时无需加载数据库", e)

    # Post-processing
    def format_docs(self, docs):
        """Join retrieved document chunks into one context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    # 传统方式召回,单问题召回,然后llm总结答案回答
    def simple_chain(self, question):
        """Classic single-question retrieval: retrieve, stuff, answer."""
        _chain = (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | self.prompts
            | self.llm
            | StrOutputParser()
        )
        answer = _chain.invoke({"question": question})
        return answer

    def rerank_chain(self, question):
        """Retrieve 10 candidates, rerank to the top 5, then answer."""
        retriever = self.vectstore.as_retriever(search_kwargs={"k": 10})
        docs = retriever.invoke(question)
        docs = rerank_topn(question, docs, N=5)
        _chain = (
            self.prompts
            | self.llm
            | StrOutputParser()
        )
        answer = _chain.invoke({"context": self.format_docs(docs), "question": question})
        return answer

    def format_qa_pairs(self, question, answer):
        """Render one question/answer pair as a text block."""
        formatted_string = ""
        formatted_string += f"Question: {question}\nAnswer:{answer}\n\n"
        return formatted_string

    # 获取问题的 扩展问题
    def decomposition_chain(self, question):
        """Expand *question* into sub-questions; the original is appended last."""
        _chain = (
            {"question": RunnablePassthrough()}
            | self.prompt_questions
            | self.llm
            | StrOutputParser()
            | (lambda x: x.split("\n"))
        )

        questions = _chain.invoke({"question": question}) + [question]

        return questions

    # 多问题递归召回,每次召回后,问题和答案同时作为下一次召回的参考,再次用新问题召回
    def rag_chain(self, questions):
        """Answer each question in turn, feeding all earlier Q/A pairs back
        into the prompt; returns the answer to the last question."""
        q_a_pairs = ""
        answer = ""
        for q in questions:
            _chain = (
                {"context": itemgetter("question") | self.retriever,
                 "question": itemgetter("question"),
                 "q_a_pairs": itemgetter("q_a_pairs")
                 }
                | self.decomposition_prompt
                | self.llm
                | StrOutputParser()
            )

            answer = _chain.invoke({"question": q, "q_a_pairs": q_a_pairs})
            # BUGFIX: accumulate every Q/A pair.  The old code overwrote the
            # history with the latest pair and then concatenated that same
            # pair with itself, so earlier answers were lost.
            # (Also normalised the misspelled internal key "q_a_paris".)
            q_a_pairs = q_a_pairs + "\n----\n" + self.format_qa_pairs(q, answer)
        return answer

    # 将聊天历史格式化为一个字符串
    def format_chat_history(self, history):
        """Flatten [(role, content), ...] chat history into plain text."""
        formatted_history = ""
        for role, content in history:
            formatted_history += f"{role}: {content}\n"
        return formatted_history

    # 基于ollama大模型的大模型 多轮对话,不使用知识库的
    def mult_chat(self, chat_history):
        """Multi-turn chat directly against the LLM (no knowledge base)."""
        # Format the chat history, then ask the model for the next reply.
        formatted_history = self.format_chat_history(chat_history)

        response = self.llm.invoke(formatted_history)
        return response
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# if __name__ == "__main__":
|
| 164 |
+
# rag = RAG_class(model="deepseek-r1:14b")
|
| 165 |
+
# question = "人卫社官网网址是?"
|
| 166 |
+
# questions = rag.decomposition_chain(question)
|
| 167 |
+
# print(questions)
|
| 168 |
+
# answer = rag.rag_chain(questions)
|
| 169 |
+
# print(answer)
|
rag/rerank_code.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 3 |
+
|
| 4 |
+
# Cross-encoder reranker: a local BGE-reranker-large checkpoint that scores
# (query, passage) pairs.  Loaded once at import time from a local path
# (no network), then switched to inference mode with eval().
# NOTE(review): the hard-coded Windows path assumes the model was downloaded
# to E:\model — confirm for other deployments.
tokenizer = AutoTokenizer.from_pretrained('E:\\model\\bge-reranker-large')
model = AutoModelForSequenceClassification.from_pretrained('E:\\model\\bge-reranker-large')
model.eval()
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def rerank_topn(question, docs, N=5):
    """Rerank *docs* against *question* with the BGE cross-encoder and
    return the top-N documents, best score first."""
    # One (query, passage) pair per candidate document.
    pairs = [[question, d.page_content] for d in docs]

    with torch.no_grad():
        encoded = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        logits = model(**encoded, return_dict=True).logits.view(-1, ).float()
        # argsort ascending, reversed -> indices ordered by descending score.
        top_idx = logits.argsort().numpy()[::-1][:N]

    return [docs[i] for i in top_idx]
|
test/__init__.py
ADDED
|
File without changes
|
test/graph2neo4j.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import sys
|
| 3 |
+
sys.path.append(r"..//")#
|
| 4 |
+
from Neo4j.neo4j_op import KnowledgeGraph
|
| 5 |
+
from Neo4j.graph_extract import update_graph
|
| 6 |
+
from Config.config import neo4j_host,neo4j_name,neo4j_pwd
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
kg = KnowledgeGraph(neo4j_host,neo4j_name,neo4j_pwd)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
if __name__ == "__main__":
|
| 14 |
+
|
| 15 |
+
text = """范冰冰,1981年9月16日生于山东青岛,毕业于上海师范大学谢晋影视艺术学院,中国女演员,歌手。
|
| 16 |
+
1998年参演电视剧《还珠格格》成名。2004年主演电影《手机》获得第27届大众电影百花奖最佳女演员奖。"""
|
| 17 |
+
res = update_graph(text)
|
| 18 |
+
# 批量创建节点
|
| 19 |
+
nodes = kg.create_nodes("node", res["nodes"])
|
| 20 |
+
print(nodes)
|
| 21 |
+
# 批量创建关系
|
| 22 |
+
relationships = kg.create_relationships([
|
| 23 |
+
("node", {"name": edge["source"]}, "node", {"name": edge["target"]}, edge["label"]) for edge in res["edges"]
|
| 24 |
+
])
|
| 25 |
+
print(relationships)
|