Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| 虫群智能体系统 — 统一记忆核心 | |
| 合并 enhanced_memory_core / v2 / simple_memory_core 三者功能 | |
| SQLite存储 + 重要性评分 + 自动分类 + 过期清理 | |
| """ | |
| import hashlib | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import sqlite3 | |
| import threading | |
| import time | |
| from collections import Counter | |
| from datetime import datetime, timedelta | |
| from typing import Dict, List, Optional | |
| from core.types import MemoryCategory, MemoryRecord | |
| logger = logging.getLogger(__name__) | |
| # Default SQLite database path, used when MemoryCore is created without an explicit db_path | |
| DEFAULT_DB_PATH = "/home/admin/swarm/data/memory.db" | |
| class MemoryCore: | |
| """统一记忆核心 — 单例""" | |
| _instance = None | |
def __new__(cls, db_path: str = DEFAULT_DB_PATH):
    """Return the shared MemoryCore instance, creating it on first use.

    The first call builds the instance: it records ``db_path``, creates the
    lock, initializes the database schema and the category keyword table,
    and starts the background cleanup thread. Every later call returns that
    same instance and ignores its own ``db_path`` argument.

    Args:
        db_path: Path of the SQLite database file. Only honored by the
            call that actually creates the instance.

    Returns:
        The shared MemoryCore instance.
    """
    if cls._instance is None:
        instance = super().__new__(cls)
        instance._db_path = db_path
        instance._lock = threading.Lock()
        instance._init_db()
        instance._init_categories()
        # Start the background cleanup only after setup has succeeded.
        instance._start_cleanup_thread()
        # Publish last: if any step above raises, no half-initialized object
        # is cached and the next call retries from scratch.
        cls._instance = instance
    return cls._instance
| # ============================================================ | |
| # 初始化 | |
| # ============================================================ | |
def _init_db(self):
    """Create the ``memories`` table and its lookup indexes if missing.

    Ensures the parent directory of ``self._db_path`` exists, creates the
    table, then creates one index each on ``user_id``, ``category``,
    ``importance`` and ``expires_at``. Index creation is best-effort: a
    failure is logged and does not abort initialization.

    Raises:
        OSError: If the database directory cannot be created.
        sqlite3.Error: If the database cannot be opened or the table
            cannot be created.
    """
    # A bare name such as ":memory:" has no directory part to create.
    db_dir = os.path.dirname(self._db_path)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    # Index name -> indexed column. Spelled out explicitly because the
    # column is not always the index name minus "idx_": the columns are
    # user_id and expires_at, not user and expires.
    indexes = {
        "idx_user": "user_id",
        "idx_category": "category",
        "idx_importance": "importance",
        "idx_expires": "expires_at",
    }

    conn = sqlite3.connect(self._db_path)
    try:
        c = conn.cursor()
        c.execute("""
            CREATE TABLE IF NOT EXISTS memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                memory_id TEXT UNIQUE,
                user_id TEXT,
                conversation_id TEXT,
                title TEXT,
                user_message TEXT,
                ai_response TEXT,
                category TEXT DEFAULT 'general',
                importance REAL DEFAULT 0.5,
                priority INTEGER DEFAULT 1,
                access_count INTEGER DEFAULT 0,
                tags TEXT,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                last_accessed DATETIME,
                expires_at DATETIME
            )
        """)
        for index_name, column in indexes.items():
            try:
                # Identifiers come from the fixed table above, never from
                # caller input, so interpolating them is safe.
                c.execute(
                    f"CREATE INDEX IF NOT EXISTS {index_name} "
                    f"ON memories({column})"
                )
            except sqlite3.Error:
                # Indexes only speed up lookups; keep going without one,
                # but make the failure visible instead of swallowing it.
                logger.warning(
                    "could not create index %s on memories(%s)",
                    index_name, column, exc_info=True,
                )
        conn.commit()
    finally:
        # Release the connection even when schema creation fails.
        conn.close()
def _init_categories(self):
    """Build the keyword table used to classify messages automatically.

    Stores on ``self._category_keywords`` a mapping from each
    MemoryCategory to the substrings that select it. Entries are inserted
    in the order listed below.
    """
    keyword_table = (
        (MemoryCategory.GREETING, ["你好", "hello", "hi", "早上好", "晚上好", "您好"]),
        (MemoryCategory.PERSONAL, ["名字", "年龄", "工作", "兴趣", "喜欢", "我", "我的"]),
        (MemoryCategory.INFORMATION, ["什么", "如何", "为什么", "怎么", "解释", "说明", "介绍"]),
        (MemoryCategory.TASK, ["做", "执行", "完成", "处理", "帮助", "需要", "请"]),
        (MemoryCategory.CREATION, ["写", "创作", "生成", "创建", "编写", "设计", "开发"]),
    )
    self._category_keywords = dict(keyword_table)
def _start_cleanup_thread(self):
    """Launch ``self._cleanup_worker`` on a background daemon thread."""
    # Daemon thread: it does not keep the interpreter alive at shutdown.
    threading.Thread(target=self._cleanup_worker, daemon=True).start()
| # ============================================================ | |
| # 存储 | |
| # ============================================================ | |
def store(self, user_id: str, conversation_id: str, title: str,
          user_message: str, ai_response: str,
          category: Optional["MemoryCategory"] = None,
          custom_tags: Optional[List[str]] = None,
          retention_days: Optional[int] = None) -> str:
    """Persist one conversation turn and return its ``memory_id``.

    Category, importance score, tags and expiry time are derived
    automatically unless the caller supplies them.

    Args:
        user_id: Owner of the memory; stored as given.
        conversation_id: Conversation the turn belongs to; stored as given.
        title: Title stored with the record.
        user_message: The user's message.
        ai_response: The assistant's reply.
        category: Explicit category. When None it is detected from
            ``user_message``.
        custom_tags: Extra tags merged with the generated ones; duplicates
            are dropped.
        retention_days: Days until the record expires. None selects the
            per-category default; zero or a negative value stores no
            expiry at all.

    Returns:
        The 16-hex-character id of the new row.

    Raises:
        sqlite3.Error: If the row cannot be written.
    """
    if category is None:
        category = self._detect_category(user_message)

    # One timestamp for the whole record, so created_at, last_accessed and
    # the expiry offset all agree with each other.
    now = datetime.now()
    now_iso = now.isoformat()

    # The random salt keeps ids unique: without it, two stores for the same
    # user and conversation inside one clock tick hash to the same id and
    # violate the UNIQUE constraint on memory_id.
    seed = (
        f"{user_id}_{conversation_id}_{now.timestamp()}_{os.urandom(8).hex()}"
    )
    memory_id = hashlib.md5(seed.encode()).hexdigest()[:16]

    importance = self._calc_importance(user_message, ai_response, category)

    tags = self._generate_tags(user_message, ai_response)
    if custom_tags:
        # dict.fromkeys de-duplicates while keeping a stable order, so the
        # stored JSON is reproducible from run to run.
        tags = list(dict.fromkeys([*tags, *custom_tags]))

    if retention_days is None:
        retention_days = self._default_retention(category)
    expires_at = None
    if retention_days > 0:
        expires_at = (now + timedelta(days=retention_days)).isoformat()

    with self._lock:
        conn = sqlite3.connect(self._db_path)
        try:
            conn.execute("""
                INSERT INTO memories
                (memory_id, user_id, conversation_id, title,
                 user_message, ai_response, category, importance,
                 priority, tags, created_at, last_accessed, expires_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                memory_id, user_id, conversation_id, title,
                user_message, ai_response, category.value, importance,
                1, json.dumps(tags, ensure_ascii=False),
                now_iso, now_iso, expires_at,
            ))
            conn.commit()
        finally:
            # Release the connection even when the insert fails.
            conn.close()

    logger.debug("记忆存储: %s [%s] 重要度=%.2f",
                 memory_id, category.value, importance)
    return memory_id
| # ============================================================ | |
| # 检索 | |
| # ============================================================ | |
def retrieve(self, query: str, user_id: Optional[str] = None,
             top_k: int = 10, category: Optional["MemoryCategory"] = None,
             min_importance: float = 0.0) -> List[Dict]:
    """Return non-expired memories matching the filters, most important first.

    Rows are ordered by importance, then by most recent access, and capped
    at ``top_k``. The ids of the returned rows are passed to
    ``self._increment_access``; the ``access_count`` in each result is the
    value read before that update.

    Args:
        query: Substring that must occur in the user message or the AI
            response. An empty value disables the text filter.
        user_id: Restrict results to this user; a falsy value disables
            the filter.
        top_k: Maximum number of rows to return.
        category: Restrict results to this category; None disables the
            filter.
        min_importance: Minimum importance; only applied when greater
            than zero.

    Returns:
        A list of dicts with the keys memory_id, user_id, conversation_id,
        title, user_message, ai_response, category, importance,
        access_count, tags (decoded from JSON) and created_at.

    Raises:
        sqlite3.Error: If the query fails.
    """
    # Parenthesized on purpose: AND binds tighter than OR in SQL, so
    # without the parentheses every never-expiring row would bypass the
    # user, category, importance and text filters below.
    conditions = ["(expires_at IS NULL OR expires_at > ?)"]
    params: list = [datetime.now().isoformat()]
    if user_id:
        conditions.append("user_id = ?")
        params.append(user_id)
    if category:
        conditions.append("category = ?")
        params.append(category.value)
    if min_importance > 0:
        conditions.append("importance >= ?")
        params.append(min_importance)
    if query:
        # Escape LIKE wildcards so the query text is matched literally.
        escaped = (
            query.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        pattern = f"%{escaped}%"
        conditions.append(
            "(user_message LIKE ? ESCAPE '\\' OR ai_response LIKE ? ESCAPE '\\')"
        )
        params.extend([pattern, pattern])

    # Only fixed fragments are interpolated; all values go through params.
    where = " AND ".join(conditions)
    sql = f"""
        SELECT memory_id, user_id, conversation_id, title,
               user_message, ai_response, category, importance,
               access_count, tags, created_at
        FROM memories WHERE {where}
        ORDER BY importance DESC, last_accessed DESC
        LIMIT ?
    """

    with self._lock:
        conn = sqlite3.connect(self._db_path)
        try:
            rows = conn.execute(sql, [*params, top_k]).fetchall()
        finally:
            # Release the connection even when the query fails.
            conn.close()

    results = [
        {
            "memory_id": row[0],
            "user_id": row[1],
            "conversation_id": row[2],
            "title": row[3],
            "user_message": row[4],
            "ai_response": row[5],
            "category": row[6],
            "importance": row[7],
            "access_count": row[8],
            "tags": json.loads(row[9]) if row[9] else [],
            "created_at": row[10],
        }
        for row in rows
    ]

    # Record the access for every row handed back to the caller.
    if results:
        self._increment_access([r["memory_id"] for r in results])
    return results
def get_relevant_context(self, query: str, user_id: str = None,
                         top_k: int = 5) -> str:
    """Format the memories matching ``query`` as one context string.

    Delegates the lookup to ``self.retrieve`` and renders each hit as a
    two-line user/assistant exchange, with the exchanges separated by a
    ``---`` line.

    Args:
        query: Text forwarded to ``retrieve``.
        user_id: Optional user filter forwarded to ``retrieve``.
        top_k: Maximum number of memories to include.

    Returns:
        The joined text, or an empty string when nothing matches.
    """
    hits = self.retrieve(query, user_id=user_id, top_k=top_k)
    if not hits:
        return ""
    return "\n---\n".join(
        f"用户: {hit['user_message']}\n助手: {hit['ai_response']}"
        for hit in hits
    )
| # ============================================================ | |
| # 统计与维护 | |
| # ============================================================ | |
def get_stats(self) -> Dict:
    """Summarize the contents of the ``memories`` table.

    Returns:
        A dict with ``total_memories`` (row count),
        ``category_distribution`` (category -> row count),
        ``avg_importance`` (mean importance rounded to three decimals,
        0.0 for an empty table) and ``db_path``.

    Raises:
        sqlite3.Error: If a query fails.
    """
    conn = sqlite3.connect(self._db_path)
    try:
        total = conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
        cat_dist = dict(
            conn.execute(
                "SELECT category, COUNT(*) FROM memories GROUP BY category"
            ).fetchall()
        )
        # AVG over an empty table yields NULL; report that as 0.0.
        avg_imp = (
            conn.execute("SELECT AVG(importance) FROM memories").fetchone()[0]
            or 0.0
        )
    finally:
        # Release the connection even when a query fails.
        conn.close()

    return {
        "total_memories": total,
        "category_distribution": cat_dist,
        "avg_importance": round(avg_imp, 3),
        "db_path": self._db_path,
    }
def cleanup_expired(self) -> int:
    """Delete every memory whose ``expires_at`` is at or before now.

    Rows without an expiry (NULL ``expires_at``) are never removed.

    Returns:
        The number of rows deleted.

    Raises:
        sqlite3.Error: If the delete fails.
    """
    conn = sqlite3.connect(self._db_path)
    try:
        cursor = conn.execute(
            "DELETE FROM memories WHERE expires_at <= ?",
            (datetime.now().isoformat(),),
        )
        deleted = cursor.rowcount
        conn.commit()
    finally:
        # Release the connection even when the delete fails.
        conn.close()

    if deleted:
        logger.info("清理过期记忆: %d条", deleted)
    return deleted
| # ============================================================ | |
| # 内部方法 | |
| # ============================================================ | |
def _detect_category(self, text: str) -> MemoryCategory:
    """Classify ``text`` by the first category whose keyword it contains.

    The text is lower-cased, then the entries of
    ``self._category_keywords`` are tried in iteration order; the first
    category with any keyword occurring as a substring wins.

    Args:
        text: Message to classify.

    Returns:
        The matching category, or ``MemoryCategory.GENERAL`` when no
        keyword occurs in the text.
    """
    lowered = text.lower()
    matching = (
        candidate
        for candidate, keywords in self._category_keywords.items()
        if any(keyword in lowered for keyword in keywords)
    )
    return next(matching, MemoryCategory.GENERAL)
def _calc_importance(self, user_msg: str, ai_resp: str,
                     category: MemoryCategory) -> float:
    """Score how important a conversation turn is, in the range 0.0 to 1.0.

    The score is a weighted sum of three parts: a per-category weight
    (40%), the combined message length scaled so that 500 characters or
    more count as 1.0 (30%), and an emphasis score of 0.15 per emphasis
    keyword found in the user message, capped at 0.3 (30%).

    Args:
        user_msg: The user's message.
        ai_resp: The assistant's reply.
        category: Category of the turn; categories without a configured
            weight use 0.4.

    Returns:
        The score clamped to [0.0, 1.0] and rounded to three decimals.
    """
    category_weight = {
        MemoryCategory.PERSONAL: 1.0,
        MemoryCategory.CREATION: 0.9,
        MemoryCategory.TASK: 0.8,
        MemoryCategory.INFORMATION: 0.6,
        MemoryCategory.GENERAL: 0.4,
        MemoryCategory.GREETING: 0.3,
    }.get(category, 0.4)

    length_score = min((len(user_msg) + len(ai_resp)) / 500.0, 1.0)

    emphasis_words = ("重要", "关键", "核心", "必须", "需要")
    emphasis_hits = sum(1 for word in emphasis_words if word in user_msg)
    emphasis_score = min(emphasis_hits * 0.15, 0.3)

    raw = category_weight * 0.4 + length_score * 0.3 + emphasis_score * 0.3
    clamped = min(max(raw, 0.0), 1.0)
    return round(clamped, 3)
def _generate_tags(self, user_msg: str, ai_resp: str) -> List[str]:
    """Derive descriptive tags from the text of a conversation turn.

    A tag is emitted when any of its keywords occurs as a substring of the
    lower-cased user message and AI response joined by a space.

    Args:
        user_msg: The user's message.
        ai_resp: The assistant's reply.

    Returns:
        The matching tags, each at most once, in the order the rules are
        listed below.
    """
    text = (user_msg + " " + ai_resp).lower()
    tag_rules = {
        "question": ["什么", "如何", "为什么", "怎么"],
        "creation": ["写", "创作", "生成", "创建"],
        "technical": ["技术", "开发", "编程", "算法", "系统"],
        "business": ["商业", "市场", "营销", "客户"],
        "positive": ["好", "棒", "优秀", "喜欢", "满意"],
        "negative": ["不好", "糟糕", "失望", "问题", "错误"],
    }
    # Rule names are unique dict keys, so no de-duplication is needed;
    # keeping rule order makes the result the same on every run.
    return [
        tag
        for tag, keywords in tag_rules.items()
        if any(keyword in text for keyword in keywords)
    ]
def _default_retention(self, category: MemoryCategory) -> int:
    """Return the default number of days a memory of ``category`` is kept.

    Args:
        category: Category to look up.

    Returns:
        The configured retention in days, or 60 for a category that has
        no entry in the table.
    """
    days_by_category = dict((
        (MemoryCategory.GREETING, 30),
        (MemoryCategory.INFORMATION, 90),
        (MemoryCategory.TASK, 180),
        (MemoryCategory.CREATION, 365),
        (MemoryCategory.PERSONAL, 730),
        (MemoryCategory.GENERAL, 60),
    ))
    try:
        return days_by_category[category]
    except KeyError:
        return 60
def _increment_access(self, memory_ids: List[str]):
    """Record one access for each of the given memories.

    Adds one to ``access_count`` and sets ``last_accessed`` to the current
    time for every id in ``memory_ids``. An id listed more than once is
    counted once per occurrence; unknown ids are ignored.

    Args:
        memory_ids: Ids of the memories that were just read.

    Raises:
        sqlite3.Error: If the update fails.
    """
    if not memory_ids:
        # Nothing to record; do not open a connection for a no-op.
        return

    now = datetime.now().isoformat()
    conn = sqlite3.connect(self._db_path)
    try:
        conn.executemany(
            "UPDATE memories SET access_count = access_count + 1, "
            "last_accessed = ? WHERE memory_id = ?",
            [(now, memory_id) for memory_id in memory_ids],
        )
        conn.commit()
    finally:
        # Release the connection even when the update fails.
        conn.close()
def _cleanup_worker(self, interval_seconds: float = 3600):
    """Purge expired memories forever, pausing between passes.

    Intended as the target of the background cleanup thread. The loop
    never returns; an exception raised by one pass is logged and the
    next pass still runs.

    Args:
        interval_seconds: Pause between two passes, in seconds. Defaults
            to one hour.
    """
    while True:
        try:
            self.cleanup_expired()
        except Exception as e:
            # Top-level boundary of the worker: log with the traceback and
            # keep the loop alive rather than let the thread die.
            logger.exception("记忆清理异常: %s", e)
        time.sleep(interval_seconds)