import hashlib
import os
from typing import List, Dict, Optional

import chromadb
from chromadb.config import Settings

from config import Config
from core.openai_client import OpenAIClient


class MemoryManager:
    """Vector memory manager that stores and retrieves text chunks for one character.

    Each character gets its own ChromaDB collection located under
    ``Config.VECTOR_DB_PATH``.
    """

    # Number of documents sent to ChromaDB per ``add`` call.
    _BATCH_SIZE = 100
    # ChromaDB limits the length of collection names.
    _MAX_COLLECTION_NAME_LENGTH = 63

    def __init__(self, character_name: str):
        """Open (or create) the vector collection belonging to a character.

        Args:
            character_name: Display name of the character; used to derive the
                collection name and stored in the collection metadata.
        """
        self.character_name = character_name
        self.client = OpenAIClient.get_client()

        # Make sure the vector database directory exists before opening it.
        os.makedirs(Config.VECTOR_DB_PATH, exist_ok=True)

        self.chroma_client = self._create_chroma_client()
        self.collection = self._open_collection()

    @staticmethod
    def _create_chroma_client():
        """Build a ChromaDB client, preferring the on-disk persistent client."""
        # NOTE(review): on chromadb releases that ship ``PersistentClient``,
        # ``chromadb.Client(Settings(persist_directory=...))`` appears to
        # succeed but yield an in-memory client, so trying it first would
        # silently skip persistence. Confirm against the installed version.
        persistent_factory = getattr(chromadb, "PersistentClient", None)
        if persistent_factory is not None:
            try:
                return persistent_factory(
                    path=Config.VECTOR_DB_PATH,
                    settings=Settings(anonymized_telemetry=False),
                )
            except Exception as e:
                print(f"创建持久化客户端失败: {e}")

        # Legacy path for chromadb versions without PersistentClient.
        return chromadb.Client(Settings(
            persist_directory=Config.VECTOR_DB_PATH,
            anonymized_telemetry=False
        ))

    def _open_collection(self):
        """Get or create this character's collection, with a safe-name fallback."""
        metadata = {"character": self.character_name}

        primary_name = f"char_{self.character_name.replace(' ', '_').lower()}"
        primary_name = primary_name[:self._MAX_COLLECTION_NAME_LENGTH]
        try:
            return self.chroma_client.get_or_create_collection(
                name=primary_name,
                metadata=metadata
            )
        except Exception as e:
            print(f"创建集合时出错: {e}")

        # The readable name was rejected, so fall back to a simpler one.
        # A cryptographic digest is used instead of the built-in ``hash()``
        # because string hashes are salted per process; the fallback name must
        # be identical on every run so the same collection is found again.
        digest = hashlib.sha256(self.character_name.encode("utf-8")).hexdigest()
        fallback_name = f"char_{digest[:16]}"
        return self.chroma_client.get_or_create_collection(
            name=fallback_name,
            metadata=metadata
        )

    def add_text_chunks(self, chunks: List[Dict], character_chunks: List[int]):
        """Add the text chunks in which the character appears.

        Failures while writing to the vector store are reported with ``print``
        and do not raise, so the caller can continue without memory support.

        Args:
            chunks: All text chunks; each one is read through its ``'text'``
                and ``'start'`` keys.
            character_chunks: Indices into ``chunks`` where the character
                appears. Out-of-range, negative and repeated indices are
                skipped.
        """
        documents = []
        metadatas = []
        ids = []
        seen_ids = set()

        for chunk_id in character_chunks:
            # Negative indices would silently wrap around to the end of the
            # list, and repeated indices would produce duplicate record ids
            # inside a single add() call.
            if chunk_id in seen_ids or not 0 <= chunk_id < len(chunks):
                continue
            seen_ids.add(chunk_id)

            chunk = chunks[chunk_id]
            documents.append(chunk['text'])
            metadatas.append({
                'chunk_id': chunk_id,
                'position': chunk['start']
            })
            ids.append(f"chunk_{chunk_id}")

        if not documents:
            return

        try:
            # Add in batches to avoid sending too much in one request.
            for start in range(0, len(documents), self._BATCH_SIZE):
                end = start + self._BATCH_SIZE
                self.collection.add(
                    documents=documents[start:end],
                    metadatas=metadatas[start:end],
                    ids=ids[start:end]
                )
            print(f"已为 {self.character_name} 添加 {len(documents)} 个文本块到向量库")
        except Exception as e:
            print(f"添加文本块到向量库失败: {e}")
            print("将继续运行,但不使用记忆功能")

    def search_relevant_context(self, query: str, n_results: Optional[int] = None) -> List[str]:
        """Retrieve stored text chunks relevant to a query.

        Args:
            query: Query text.
            n_results: Maximum number of results; when omitted (or falsy),
                ``Config.MAX_MEMORY_RETRIEVAL`` is used.

        Returns:
            The matching text chunks, or an empty list when the collection is
            empty or the lookup fails.
        """
        n_results = n_results or Config.MAX_MEMORY_RETRIEVAL

        try:
            available = self.collection.count()
            if available == 0:
                return []

            # Never request more results than the collection holds.
            results = self.collection.query(
                query_texts=[query],
                n_results=min(n_results, available)
            )

            if results and results['documents']:
                return results['documents'][0]
            return []
        except Exception as e:
            print(f"检索失败: {e}")
            return []

    def get_embedding(self, text: str) -> List[float]:
        """Fetch the embedding vector for a piece of text.

        Args:
            text: Input text.

        Returns:
            The embedding vector, or an empty list when the request fails.
        """
        try:
            response = self.client.embeddings.create(
                model=Config.EMBEDDING_MODEL,
                input=text
            )
            return response.data[0].embedding
        except Exception as e:
            print(f"获取嵌入失败: {e}")
            return []

    def get_statistics(self) -> Dict:
        """Return summary statistics about the memory store.

        Returns:
            A dict with ``'character'``, ``'chunk_count'`` and
            ``'collection_name'``; placeholder values are returned when the
            collection cannot be queried.
        """
        try:
            count = self.collection.count()
            return {
                'character': self.character_name,
                'chunk_count': count,
                'collection_name': self.collection.name
            }
        except Exception:
            # Narrower than a bare ``except`` so KeyboardInterrupt/SystemExit
            # still propagate.
            return {
                'character': self.character_name,
                'chunk_count': 0,
                'collection_name': 'unknown'
            }

    def clear(self):
        """Empty the memory store by dropping and recreating the collection."""
        try:
            collection_name = self.collection.name
            self.chroma_client.delete_collection(collection_name)
            # Recreate the collection under the same name so this manager
            # remains usable; otherwise self.collection would keep referring
            # to the collection that was just deleted.
            self.collection = self.chroma_client.get_or_create_collection(
                name=collection_name,
                metadata={"character": self.character_name}
            )
            print(f"已清空 {self.character_name} 的记忆库")
        except Exception as e:
            print(f"清空记忆库失败: {e}")