swarm-backend / core /memory_model.py
lk080424's picture
Upload folder using huggingface_hub
17fba62 verified
#!/usr/bin/env python3
"""
虫群v7 — 记忆模型(Memory Model)
参数化个人记忆系统:将交互记录编码为可检索的结构
设计思路:
- 传统方案:对话存数据库,检索靠关键词匹配
- 虫群方案:记忆以向量索引+时序结构存储,检索更精准
- 未来方向:将记忆编码为模型参数(模型即数据库)
- 当前实现:轻量级向量索引 + 时序衰减 + 关键词增强
"""
import hashlib
import json
import logging
import math
import os
import re
import time
from datetime import datetime
from typing import Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# ============================================================
# 记忆条目
# ============================================================
class MemoryEntry:
    """A single remembered interaction (one user message and one AI reply).

    Attributes:
        entry_id: 12-hex-character id derived from the message, the reply and
            the creation time.
        user_message: The user's message, stored in full.
        ai_response: The AI reply, stored in full.
        intent: Intent label supplied by the caller (may be empty).
        route: Route label supplied by the caller (may be empty).
        timestamp: Caller-supplied timestamp string, or the local ISO-8601
            time at construction when none is given.
        created_at: Creation time as a Unix timestamp (seconds).
        keywords: Up to 20 tokens extracted from ``user_message``.
        access_count: Counter starting at 0; this class never changes it.
        last_access: Unix timestamp, initialised to ``created_at``.
    """

    # Characters replaced by a space before splitting.
    # NOTE(review): the previous literal was r'[...""''()\s]'. The embedded ''
    # ended the raw string early, so Python concatenated it with a non-raw
    # '()\s]' (an invalid "\s" escape). The class below is exactly what that
    # expression compiled to, so single quotes are still NOT stripped -
    # confirm whether they were meant to be.
    _PUNCT_RE = re.compile(r'[,。!?、;:""()\s]')
    # Separators used to cut the cleaned text into candidate words.
    _SPLIT_RE = re.compile(r'[\s,.\-!?;:]+')

    def __init__(self, user_message: str, ai_response: str,
                 intent: str = "", route: str = "",
                 timestamp: str = ""):
        # Read the clock once so the id and created_at describe the same instant.
        self.created_at = time.time()
        self.entry_id = hashlib.md5(
            f"{user_message}{ai_response}{self.created_at}".encode()
        ).hexdigest()[:12]
        self.user_message = user_message
        self.ai_response = ai_response
        self.intent = intent
        self.route = route
        self.timestamp = timestamp or datetime.now().isoformat()
        # Lightweight keyword extraction (no real word segmentation).
        self.keywords = self._extract_keywords(user_message)
        # Access bookkeeping.
        self.access_count = 0
        self.last_access = self.created_at

    def _extract_keywords(self, text: str) -> List[str]:
        """Return at most 20 tokens of length >= 2 found in ``text``.

        Listed punctuation and whitespace are replaced by spaces, the result
        is split on whitespace and common ASCII separators, and tokens shorter
        than two characters are dropped. Text without any separator (for
        example an unpunctuated Chinese sentence) stays a single token.
        """
        cleaned = self._PUNCT_RE.sub(' ', text)
        tokens = [w for w in self._SPLIT_RE.split(cleaned) if len(w) >= 2]
        return tokens[:20]

    def to_dict(self) -> Dict:
        """Return a compact summary of the entry.

        The message and reply are cut to 100 characters and only the first
        five keywords are included; ``route`` and ``created_at`` are omitted.
        """
        return {
            "entry_id": self.entry_id,
            "user_message": self.user_message[:100],
            "ai_response": self.ai_response[:100],
            "intent": self.intent,
            "keywords": self.keywords[:5],
            "timestamp": self.timestamp,
            "access_count": self.access_count,
        }
# ============================================================
# 记忆模型核心
# ============================================================
class MemoryModel:
    """Per-user interaction memory with keyword, recency and frequency scoring.

    Main entry points:
        1. ``retrieve_context()``: short text summary of related memories.
        2. ``retrieve()``: scored list of memory summaries.
        3. ``store_interaction()``: append a new interaction.
        4. ``encode_realtime()``: store plus concept-weight update.
        5. ``fork_for_task()`` / ``merge_from()``: task-specific sub-memories.

    Entries are kept in ``_entries`` in insertion order; several methods rely
    on the tail of that list being the most recent interactions.
    """

    # Directory holding one "<user_id>.json" file per user.
    DATA_DIR = "/home/admin/swarm/data/memory"
    MAX_ENTRIES = 5000  # cap on the number of entries kept
    DECAY_HALF_LIFE = 86400 * 7  # recency half-life in seconds (7 days)

    # Tokenisation patterns shared by retrieve() and _extract_concepts().
    # NOTE(review): the previous literal r'[...""''()\s]' was split in two by
    # the embedded '' (leaving "\s" in a non-raw string). This single raw
    # literal is exactly what it compiled to; single quotes are not stripped.
    _PUNCT_RE = re.compile(r'[,。!?、;:""()\s]')
    _SPLIT_RE = re.compile(r'[\s,.\-!?;:]+')

    def __init__(self, user_id: str = "default"):
        """Create the model for ``user_id`` and load its persisted entries."""
        self.user_id = user_id
        self._entries: List["MemoryEntry"] = []
        self._keyword_index: Dict[str, List[str]] = {}  # keyword -> [entry_ids]
        self._intent_index: Dict[str, List[str]] = {}   # intent -> [entry_ids]
        self._param_dict: Dict[str, float] = {}         # concept -> weight
        self._loaded = False
        # Counters reported by get_stats().
        self._store_count = 0
        self._retrieve_count = 0
        self._hit_count = 0
        # Persisted entries are read right away (no-op when no file exists).
        self._load()

    # ------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------
    @classmethod
    def _tokenize(cls, text: str) -> List[str]:
        """Split ``text`` into tokens of length >= 2 (punctuation removed)."""
        cleaned = cls._PUNCT_RE.sub(' ', text)
        return [w for w in cls._SPLIT_RE.split(cleaned) if len(w) >= 2]

    def _ensure_params(self) -> Dict[str, float]:
        """Return the concept-weight dict, creating it when missing.

        Instances built with ``__new__`` (bypassing ``__init__``) may lack the
        attribute, so it is created on demand.
        """
        params = getattr(self, "_param_dict", None)
        if params is None:
            params = {}
            self._param_dict = params
        return params

    def _index_entry(self, entry: "MemoryEntry") -> None:
        """Register ``entry`` in both the keyword index and the intent index."""
        for kw in entry.keywords:
            self._keyword_index.setdefault(kw, []).append(entry.entry_id)
        if entry.intent:
            self._intent_index.setdefault(entry.intent, []).append(entry.entry_id)

    # ============================================================
    # Core interface
    # ============================================================
    def retrieve_context(self, message: str) -> str:
        """Return a one-line text summary of the top 3 memories for ``message``.

        Each memory is rendered as a fixed Chinese template containing the
        first 50 characters of the stored message and reply; items are joined
        with ";". Returns "" when nothing is retrieved.

        The retrieval and hit counters are updated by ``retrieve()`` only, so
        one call here counts as exactly one retrieval.
        """
        entries = self.retrieve(message, top_k=3)
        if not entries:
            return ""
        parts = []
        for e in entries:
            msg = e.get("user_message", "")
            resp = e.get("ai_response", "")
            parts.append(f"之前你问过「{msg[:50]}」,回答是「{resp[:50]}」")
        return ";".join(parts)

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return up to ``top_k`` memory summaries ranked by score.

        Score = keyword overlap (0-0.5) + recency decay (0-0.3)
        + access frequency (0-0.2). Entries scoring 0.05 or less are skipped.
        Because recency alone can exceed that threshold, recent entries may be
        returned even without any keyword overlap.

        Every returned entry has its ``access_count`` incremented and its
        ``last_access`` refreshed. Each result is ``entry.to_dict()`` plus a
        ``"_score"`` key rounded to 3 decimals.

        Args:
            query: Free text to match against stored keywords.
            top_k: Maximum number of results; values <= 0 yield ``[]``.
        """
        self._retrieve_count += 1
        if not self._entries or top_k <= 0:
            return []

        query_words = set(self._tokenize(query))
        now = time.time()
        scored: List[Tuple[float, "MemoryEntry"]] = []
        for entry in self._entries:
            score = 0.0
            # Keyword part: share of the query words found in the entry.
            overlap = query_words.intersection(entry.keywords)
            if overlap:
                score += len(overlap) / len(query_words) * 0.5
            # Recency part; 0.693 approximates ln 2 so the weight halves every
            # DECAY_HALF_LIFE seconds. Future timestamps are treated as "now".
            age_seconds = max(now - entry.created_at, 0.0)
            score += math.exp(-0.693 * age_seconds / self.DECAY_HALF_LIFE) * 0.3
            # Frequency part, saturating at 10 accesses.
            score += min(entry.access_count / 10.0, 1.0) * 0.2
            if score > 0.05:
                scored.append((score, entry))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        results = []
        for score, entry in scored[:top_k]:
            entry.access_count += 1
            entry.last_access = now
            item = entry.to_dict()
            item["_score"] = round(score, 3)
            results.append(item)
        if results:
            self._hit_count += 1
        return results

    def store_interaction(self, message: str, result: str,
                          analysis=None) -> str:
        """Store one interaction and return its ``entry_id``.

        Args:
            message: The user message.
            result: The AI reply.
            analysis: Optional object whose ``intent`` and ``route``
                attributes are copied onto the entry; missing attributes
                are treated as "".

        Side effects: evicts entries once ``MAX_ENTRIES`` is exceeded and
        persists to disk on every 10th store.
        """
        self._store_count += 1
        intent = getattr(analysis, "intent", "") or ""
        route = getattr(analysis, "route", "") or ""
        entry = MemoryEntry(
            user_message=message,
            ai_response=result,
            intent=intent,
            route=route,
        )
        self._entries.append(entry)
        self._index_entry(entry)
        if len(self._entries) > self.MAX_ENTRIES:
            self._evict()
        if self._store_count % 10 == 0:
            self._save()
        return entry.entry_id

    # ============================================================
    # Eviction
    # ============================================================
    def _evict(self, count: int = 100):
        """Trim the store to ``MAX_ENTRIES - count`` entries.

        Entries with the lowest ``access_count / (1 + age_days)`` are removed
        first; ties (for example all never-accessed entries) are broken by
        age so the oldest go first. Survivors keep their original order,
        which keeps the tail of ``_entries`` the most recent interactions.
        """
        keep = max(self.MAX_ENTRIES - count, 0)
        excess = len(self._entries) - keep
        if excess <= 0:
            return
        now = time.time()

        def retention_key(entry):
            # Lower key = evicted earlier.
            age_days = max(now - entry.created_at, 0.0) / 86400
            return (entry.access_count / (1 + age_days), entry.created_at)

        removed = sorted(self._entries, key=retention_key)[:excess]
        removed_ids = {id(e) for e in removed}
        self._entries = [e for e in self._entries if id(e) not in removed_ids]
        self._rebuild_index()
        logger.debug(f"记忆淘汰: 移除{len(removed)}条")

    def _rebuild_index(self):
        """Recompute the keyword and intent indexes from ``_entries``."""
        self._keyword_index.clear()
        self._intent_index.clear()
        for entry in self._entries:
            self._index_entry(entry)

    # ============================================================
    # Persistence
    # ============================================================
    def _load(self):
        """Load this user's entries from ``DATA_DIR`` (best effort, once).

        A missing file is not an error. On any read or parse failure a warning
        is logged and no entries are added. Keywords are re-extracted by
        ``MemoryEntry``; the persisted ``entry_id``, ``access_count`` and
        ``created_at`` are restored so ids stay stable across restarts.
        """
        if self._loaded:
            return
        self._loaded = True
        filepath = os.path.join(self.DATA_DIR, f"{self.user_id}.json")
        if not os.path.exists(filepath):
            return
        loaded = []
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            for item in data.get("entries", []):
                entry = MemoryEntry(
                    user_message=item.get("user_message", ""),
                    ai_response=item.get("ai_response", ""),
                    intent=item.get("intent", ""),
                    route=item.get("route", ""),
                    timestamp=item.get("timestamp", ""),
                )
                # Keep the stored id; a regenerated one would defeat the
                # id-based de-duplication in merge_from().
                entry.entry_id = item.get("entry_id") or entry.entry_id
                entry.access_count = item.get("access_count", 0)
                entry.created_at = item.get("created_at", time.time())
                loaded.append(entry)
        except Exception as e:  # deliberate best effort at the I/O boundary
            logger.warning(f"记忆加载失败: {e}")
            return
        self._entries.extend(loaded)
        self._rebuild_index()
        logger.info(f"记忆加载: {len(self._entries)}条 (用户: {self.user_id})")

    def _save(self):
        """Persist the last ``MAX_ENTRIES`` entries to disk (best effort).

        The data is written to a temporary sibling file and then moved into
        place with ``os.replace`` so an interrupted write cannot truncate the
        existing file. Any failure, including an unwritable directory, is
        logged as a warning and never raised.
        """
        filepath = os.path.join(self.DATA_DIR, f"{self.user_id}.json")
        tmp_path = f"{filepath}.tmp"
        try:
            os.makedirs(self.DATA_DIR, exist_ok=True)
            data = {
                "user_id": self.user_id,
                "version": 1,
                "entries": [
                    {
                        "entry_id": e.entry_id,
                        "user_message": e.user_message,
                        "ai_response": e.ai_response,
                        "intent": e.intent,
                        "route": e.route,
                        "timestamp": e.timestamp,
                        "created_at": e.created_at,
                        "access_count": e.access_count,
                        "keywords": e.keywords,
                    }
                    for e in self._entries[-self.MAX_ENTRIES:]
                ]
            }
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, filepath)
            logger.debug(f"记忆保存: {len(self._entries)}条")
        except Exception as e:  # deliberate best effort at the I/O boundary
            logger.warning(f"记忆保存失败: {e}")

    # ============================================================
    # Statistics
    # ============================================================
    def get_stats(self) -> Dict:
        """Return entry/index sizes and store/retrieve/hit counters.

        ``hit_rate`` is ``hit_count / retrieve_count`` rounded to 3 decimals
        (0 when nothing has been retrieved yet).
        """
        return {
            "user_id": self.user_id,
            "total_entries": len(self._entries),
            "keyword_index_size": len(self._keyword_index),
            "intent_index_size": len(self._intent_index),
            "store_count": self._store_count,
            "retrieve_count": self._retrieve_count,
            "hit_count": self._hit_count,
            "hit_rate": round(
                self._hit_count / max(self._retrieve_count, 1), 3
            ),
        }

    # ============================================================
    # v7.1 additions: concept weights, task forks, long context
    # ============================================================
    def encode_realtime(self, message: str, result: str, analysis=None):
        """Store an interaction and update the concept weights.

        Calls ``store_interaction()``, then ``_update_memory_params()``, and
        on every 20th store prunes the weights via ``_micro_finetune()``.
        No model training takes place.

        Returns:
            The ``entry_id`` of the stored interaction.
        """
        entry_id = self.store_interaction(message, result, analysis)
        self._update_memory_params(message, result)
        if self._store_count % 20 == 0:
            self._micro_finetune()
        return entry_id

    def _update_memory_params(self, message: str, result: str):
        """Update the concept -> weight dictionary from one interaction.

        A concept seen again is multiplied by 1.1 (capped at 2.0); a new one
        starts at 1.0. Afterwards every weight is multiplied by 0.999 and
        concepts below 0.1 are dropped.
        """
        params = self._ensure_params()
        for concept in self._extract_concepts(message + " " + result):
            if concept in params:
                params[concept] = min(params[concept] * 1.1, 2.0)
            else:
                params[concept] = 1.0
        # Iterate over a snapshot of the keys because entries may be deleted.
        for key in list(params):
            params[key] *= 0.999
            if params[key] < 0.1:
                del params[key]

    def _extract_concepts(self, text: str) -> List[str]:
        """Return the first 10 tokens of ``text`` (same tokenizer as retrieve)."""
        return self._tokenize(text)[:10]

    def _micro_finetune(self):
        """Keep only the 500 highest-weighted concepts (no-op when empty)."""
        params = getattr(self, "_param_dict", None)
        if not params:
            return
        strongest = sorted(
            params.items(), key=lambda kv: kv[1], reverse=True
        )[:500]
        self._param_dict = dict(strongest)
        logger.debug(f"记忆微调: 保留{len(self._param_dict)}个概念参数")

    def fork_for_task(self, task_name: str) -> 'MemoryModel':
        """Create a task-specific model seeded with related memories.

        The fork's ``user_id`` is ``"<user_id>__<task_name>"``. It receives
        every entry sharing at least one keyword with the tokens of
        ``task_name``, plus the concept weights that are task keywords or
        exceed 1.0. Nothing is read from disk for the fork.

        Entries are shared by reference, not copied, so access counters
        updated through the fork are visible in this model as well.

        Returns:
            A new ``MemoryModel`` instance.
        """
        forked = MemoryModel.__new__(MemoryModel)
        forked.user_id = f"{self.user_id}__{task_name}"
        forked._entries = []
        forked._keyword_index = {}
        forked._intent_index = {}
        forked._loaded = True
        forked._store_count = 0
        forked._retrieve_count = 0
        forked._hit_count = 0

        task_keywords = set(self._extract_concepts(task_name))
        for entry in self._entries:
            if task_keywords.intersection(entry.keywords):
                forked._entries.append(entry)
                forked._index_entry(entry)

        own_params = getattr(self, "_param_dict", None) or {}
        forked._param_dict = {
            k: v for k, v in own_params.items()
            if k in task_keywords or v > 1.0
        }
        logger.info(f"记忆分叉: {task_name}, 复制{len(forked._entries)}条相关记忆")
        return forked

    def merge_from(self, other: 'MemoryModel'):
        """Merge another model (typically a fork) back into this one.

        Entries whose ``entry_id`` is not yet present are appended and
        indexed. Concept weights are merged by taking the larger value.
        The entry cap is enforced and the result is persisted.
        """
        existing_ids = {e.entry_id for e in self._entries}
        merged = 0
        for entry in other._entries:
            if entry.entry_id in existing_ids:
                continue
            self._entries.append(entry)
            self._index_entry(entry)
            existing_ids.add(entry.entry_id)
            merged += 1

        other_params = getattr(other, "_param_dict", None)
        if other_params:
            params = self._ensure_params()
            for k, v in other_params.items():
                params[k] = max(params.get(k, v), v)

        if len(self._entries) > self.MAX_ENTRIES:
            self._evict()
        self._save()
        logger.info(f"记忆合并: +{merged}条新记忆")

    def get_unlimited_context(self, query: str, max_tokens: int = 4000) -> str:
        """Build a bounded context string for long conversations.

        The result is composed of, in output order:
            1. a summary line of the 10 highest-weighted concepts (if any),
            2. up to 5 retrieved memories not already in the recent window,
            3. the last 10 stored interactions, oldest first.

        ``max_tokens`` is applied as a character budget to parts 2 and 3;
        the concept line and the "---" separators are not counted.
        """
        parts = []
        char_count = 0

        # 1. Sliding window over the most recent interactions.
        recent = self._entries[-10:]
        for e in reversed(recent):
            line = f"用户: {e.user_message[:100]}\n助手: {e.ai_response[:100]}"
            if char_count + len(line) > max_tokens:
                break
            parts.insert(0, line)
            char_count += len(line)

        # 2. Relevant older memories, skipping those already shown.
        seen_ids = {e.entry_id for e in recent}
        for r in self.retrieve(query, top_k=5):
            rid = r.get("entry_id", "")
            if rid in seen_ids:
                continue
            line = f"[历史] 用户: {r['user_message'][:80]}\n[历史] 助手: {r['ai_response'][:80]}"
            if char_count + len(line) > max_tokens:
                break
            parts.insert(0, line)
            char_count += len(line)
            seen_ids.add(rid)

        # 3. Concept summary goes first in the output.
        params = getattr(self, "_param_dict", None)
        if params:
            top_concepts = sorted(
                params.items(), key=lambda kv: kv[1], reverse=True
            )[:10]
            concept_str = "核心概念: " + ", ".join(
                f"{k}({v:.1f})" for k, v in top_concepts
            )
            parts.insert(0, concept_str)
        return "\n---\n".join(parts)

    def get_param_stats(self) -> Dict:
        """Return the concept count, the top 5 concepts and the mean weight."""
        params = getattr(self, '_param_dict', {})
        return {
            "concept_count": len(params),
            "top_concepts": sorted(
                params.items(), key=lambda kv: kv[1], reverse=True
            )[:5] if params else [],
            "avg_weight": sum(params.values()) / len(params) if params else 0,
        }
# ============================================================
# 记忆矩阵: 多任务记忆的统一管理
# ============================================================
class MemoryMatrix:
    """Manage one main memory plus any number of task-specific forks.

    Layout:
        main memory (general interactions)
        +-- task memory "physics"
        +-- task memory "coding"
        +-- ... created on demand

    Usage:
        matrix = MemoryMatrix("user_001")
        matrix.activate("physics")    # switch to the physics memory
        matrix.store("...", "...")    # goes into the active memory
        matrix.deactivate()           # merge back and return to main
    """

    def __init__(self, user_id: str = "default"):
        """Create the main memory for ``user_id`` with no active task."""
        self.user_id = user_id
        self._main = MemoryModel(user_id)
        self._forks: Dict[str, "MemoryModel"] = {}
        self._active: Optional[str] = None  # name of the active task, if any

    def activate(self, task_name: str):
        """Make ``task_name`` the active memory, forking it from main if new."""
        if task_name not in self._forks:
            self._forks[task_name] = self._main.fork_for_task(task_name)
            logger.info(f"创建任务记忆: {task_name}")
        self._active = task_name

    def deactivate(self):
        """Merge the active task memory into main and switch back to main.

        The fork itself is kept so a later ``activate()`` can reuse it.
        """
        if self._active and self._active in self._forks:
            self._main.merge_from(self._forks[self._active])
        self._active = None

    @property
    def current(self) -> "MemoryModel":
        """The active task memory, or the main memory when none is active."""
        if self._active and self._active in self._forks:
            return self._forks[self._active]
        return self._main

    def store(self, message: str, result: str, analysis=None):
        """Store an interaction in the current memory; returns its entry id."""
        return self.current.store_interaction(message, result, analysis)

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Retrieve from the current memory."""
        return self.current.retrieve(query, top_k)

    def get_context(self, query: str) -> str:
        """Return context text, task memory first and main memory as backup.

        With no active task this is the main memory's context. With an active
        task: both contexts labelled and joined when both exist, otherwise
        whichever one is non-empty (or "" when neither has a hit).
        """
        active = self.current
        task_ctx = active.retrieve_context(query)
        # Nothing to supplement when the current memory already is main.
        if not self._active or active is self._main:
            return task_ctx
        main_ctx = self._main.retrieve_context(query)
        if task_ctx and main_ctx:
            return f"[任务:{self._active}] {task_ctx}\n[通用] {main_ctx}"
        # Fall back to main when the task memory has nothing for this query.
        return task_ctx or main_ctx

    def list_tasks(self) -> List[str]:
        """Return the names of all task memories created so far."""
        return list(self._forks)

    def get_status(self) -> Dict:
        """Return the active task, main-memory stats and per-task stats."""
        return {
            "user_id": self.user_id,
            "active_task": self._active,
            "main_stats": self._main.get_stats(),
            "task_count": len(self._forks),
            "tasks": {name: mem.get_stats() for name, mem in self._forks.items()},
        }