File size: 5,706 Bytes
a226682 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import hashlib
import os
from typing import Dict, List, Optional

import chromadb
from chromadb.config import Settings

from config import Config
from core.openai_client import OpenAIClient
class MemoryManager:
    """Vector memory manager: stores and retrieves text chunks related to a character.

    Each character gets its own ChromaDB collection inside
    ``Config.VECTOR_DB_PATH``. The collection is created without an explicit
    embedding function, so ``add``/``query`` calls below embed documents with
    ChromaDB's own default embedder (not with :meth:`get_embedding`).
    """

    # Number of chunks sent to ChromaDB per write call.
    _BATCH_SIZE = 100
    # ChromaDB limits collection names to 63 characters.
    _MAX_COLLECTION_NAME_LEN = 63

    def __init__(self, character_name: str):
        """Open (or create) the vector-store collection for ``character_name``.

        Args:
            character_name: Name of the character whose memory is managed.
        """
        self.character_name = character_name
        self.client = OpenAIClient.get_client()

        # Make sure the vector database directory exists.
        os.makedirs(Config.VECTOR_DB_PATH, exist_ok=True)
        self.chroma_client = self._create_chroma_client()

        # One independent collection per character.
        metadata = {"character": character_name}
        collection_name = self._collection_name(character_name)
        try:
            self.collection = self.chroma_client.get_or_create_collection(
                name=collection_name,
                metadata=metadata,
            )
        except Exception as e:
            print(f"创建集合时出错: {e}")
            # The readable name was rejected (presumably by ChromaDB's naming
            # rules, e.g. for non-ASCII names). Fall back to a name derived
            # from a stable digest so the same character always maps to the
            # same collection across runs; built-in hash() of a str is
            # randomized per process and would lose the stored memory.
            collection_name = self._fallback_collection_name(character_name)
            self.collection = self.chroma_client.get_or_create_collection(
                name=collection_name,
                metadata=metadata,
            )

    @staticmethod
    def _create_chroma_client():
        """Return a ChromaDB client that persists to ``Config.VECTOR_DB_PATH``."""
        path = Config.VECTOR_DB_PATH
        if hasattr(chromadb, "PersistentClient"):
            # chromadb >= 0.4: Client(Settings(persist_directory=...)) quietly
            # builds an in-memory client, so the persistent API is required.
            return chromadb.PersistentClient(
                path=path,
                settings=Settings(anonymized_telemetry=False),
            )
        # chromadb < 0.4: persistence requires the duckdb+parquet backend.
        return chromadb.Client(Settings(
            chroma_db_impl="duckdb+parquet",
            persist_directory=path,
            anonymized_telemetry=False,
        ))

    @classmethod
    def _collection_name(cls, character_name: str) -> str:
        """Return the readable collection name: ``char_<name>``, lowercased, truncated."""
        name = f"char_{character_name.replace(' ', '_').lower()}"
        return name[:cls._MAX_COLLECTION_NAME_LEN]

    @staticmethod
    def _fallback_collection_name(character_name: str) -> str:
        """Return an ASCII-only collection name that is stable across processes."""
        raw = character_name.encode("utf-8", "surrogatepass")
        digest = hashlib.sha256(raw).hexdigest()[:16]
        return f"char_{digest}"

    @staticmethod
    def _select_chunk_ids(character_chunks: List[int], total: int) -> List[int]:
        """Return valid chunk ids in first-seen order, without duplicates.

        Ids outside ``[0, total)`` are dropped; negative ids would otherwise
        index from the end of the chunk list and store the wrong text.
        """
        return [cid for cid in dict.fromkeys(character_chunks) if 0 <= cid < total]

    def add_text_chunks(self, chunks: List[Dict], character_chunks: List[int]):
        """Add the text chunks in which the character appears.

        Args:
            chunks: All text chunks; each entry must have ``'text'`` and
                ``'start'`` keys.
            character_chunks: Indices into ``chunks`` where the character
                appears. Duplicates and out-of-range indices are ignored.
        """
        chunk_ids = self._select_chunk_ids(character_chunks, len(chunks))
        if not chunk_ids:
            return

        documents = [chunks[cid]['text'] for cid in chunk_ids]
        metadatas = [
            {'chunk_id': cid, 'position': chunks[cid]['start']}
            for cid in chunk_ids
        ]
        ids = [f"chunk_{cid}" for cid in chunk_ids]

        try:
            # Prefer upsert so re-indexing chunks already stored in the
            # persistent collection is idempotent; very old chromadb
            # versions only provide add.
            write = getattr(self.collection, "upsert", None) or self.collection.add
            # Write in batches to avoid sending too much at once.
            step = self._BATCH_SIZE
            for start in range(0, len(documents), step):
                write(
                    documents=documents[start:start + step],
                    metadatas=metadatas[start:start + step],
                    ids=ids[start:start + step],
                )
            print(f"已为 {self.character_name} 添加 {len(documents)} 个文本块到向量库")
        except Exception as e:
            print(f"添加文本块到向量库失败: {e}")
            print("将继续运行,但不使用记忆功能")

    def search_relevant_context(self, query: str, n_results: Optional[int] = None) -> List[str]:
        """Retrieve stored text chunks relevant to ``query``.

        Args:
            query: Query text.
            n_results: Maximum number of results; ``None`` means
                ``Config.MAX_MEMORY_RETRIEVAL``. Non-positive values return
                an empty list.

        Returns:
            The matching text chunks, or an empty list when nothing is stored
            or the lookup fails.
        """
        if n_results is None:
            n_results = Config.MAX_MEMORY_RETRIEVAL
        if n_results <= 0:
            return []

        try:
            collection_count = self.collection.count()
            if collection_count == 0:
                return []
            # Never ask for more results than the collection holds.
            results = self.collection.query(
                query_texts=[query],
                n_results=min(n_results, collection_count),
            )
            if results and results['documents']:
                return results['documents'][0]
            return []
        except Exception as e:
            print(f"检索失败: {e}")
            return []

    def get_embedding(self, text: str) -> List[float]:
        """Return the OpenAI embedding vector for ``text``.

        Args:
            text: Input text.

        Returns:
            The embedding vector, or an empty list if the API call fails.
        """
        try:
            response = self.client.embeddings.create(
                model=Config.EMBEDDING_MODEL,
                input=text,
            )
            return response.data[0].embedding
        except Exception as e:
            print(f"获取嵌入失败: {e}")
            return []

    def get_statistics(self) -> Dict:
        """Return statistics about this character's memory store.

        Returns:
            A dict with ``'character'``, ``'chunk_count'`` and
            ``'collection_name'``; placeholder values are used if the
            collection cannot be queried.
        """
        try:
            return {
                'character': self.character_name,
                'chunk_count': self.collection.count(),
                'collection_name': self.collection.name,
            }
        except Exception:
            return {
                'character': self.character_name,
                'chunk_count': 0,
                'collection_name': 'unknown',
            }

    def clear(self):
        """Delete all stored chunks and leave an empty, usable collection behind."""
        try:
            name = self.collection.name
            self.chroma_client.delete_collection(name)
            # Recreate the collection so later add/search calls do not hit a
            # handle to a deleted collection.
            self.collection = self.chroma_client.get_or_create_collection(
                name=name,
                metadata={"character": self.character_name},
            )
            print(f"已清空 {self.character_name} 的记忆库")
        except Exception as e:
            print(f"清空记忆库失败: {e}")